nn.embedding函数详解(pytorch)

提示:文章附有源码!!!

文章目录

  • 前言
  • 一、nn.embedding函数解释
  • 二、nn.embedding函数使用方法
  • 四、模型训练与预测的权重变化探讨


前言

最近发现prompt工程(如sam模型),也有transform的detr模型等都使用了nn.Embedding函数,对points、boxes或learn query进行编码或解码。因此,我想写一篇文章作为记录,本想简单对其 介绍,但写着写着就想把所有与它相关东西作为记录。本文章探讨了nn.Embedding参数、使用方法、模型训练与预测的变化,并附有列子源码作为支撑 ,呈现一个较为完善的理解内容。

一、nn.embedding函数解释

Embedding实际是一个索引表或查找表,它是符合随机初始化生成的正太分布的表,将输入向量化,其结构如下:

nn.Embedding(num_embeddings, embedding_dim)

第1个参数 num_embeddings 就是生成num_embeddings个嵌入向量。
第2个参数 embedding_dim 就是嵌入向量的维度,即用embedding_dim值的维数来表示一个基本单位。

当然,该函数还有很多其它参数,解释如下:

参数源码注释如下:

num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
                             therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
                             i.e. it remains as a fixed "pad". For a newly constructed Embedding,
                             the embedding vector at :attr:`padding_idx` will default to all zeros,
                             but can be updated to another value to be used as the padding vector.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                            is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
                                        the words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                         See Notes for more details regarding sparse gradients.

参数中文解释:

num_embeddings (python:int) – 词典的大小尺寸,比如总共出现5000个词,那就输入5000。此时index为(0-4999embedding_dim (python:int) – 嵌入向量的维度,即用多少维来表示一个符号。
padding_idx (python:int, optional) – 填充id,比如,输入长度为100,但是每次的句子长度并不一样,后面就需要用统一的数字填充,而这里就是指定这个数字,这样,网络在遇到填充id时,就不会计算其与其它符号的相关性。(初始化为0max_norm (python:float, optional) – 最大范数,如果嵌入向量的范数超过了这个界限,就要进行再归一化。
norm_type (python:float, optional) – 指定利用什么范数计算,并用于对比max_norm,默认为2范数。
scale_grad_by_freq (boolean, optional) – 根据单词在mini-batch中出现的频率,对梯度进行放缩。默认为False.
sparse (bool, optional) – 若为True,则与权重矩阵相关的梯度转变为稀疏张量

注:该函数服从正太分布,该函数可参与训练,我将在后面做解释。

二、nn.embedding函数使用方法

该函数实际是对词的编码,假如你有2句话,每句话有四个词,那么你想对每个词使用6个维度表达,其代码如下:

import torch.nn as nn
import torch
if __name__ == '__main__':
   embedding = nn.Embedding(100, 6)  # 我设置100个索引,每个使用6个维度表达。
   input = torch.LongTensor([[1, 2, 4, 5],
   [4, 3, 2, 3]])  # a batch of 2 samples of 4 indices each
   e = embedding(input)
   print('输出尺寸', e.shape)
   print('输出值:\n',e)
   weights=embedding.weight
   print('embed权重输出值:\n', weights[:6])
     

输出结果:
在这里插入图片描述

从图上可看出,输入编码是通过索引查找已编号embedding的权重,并将其赋值替换表达。换句话说,nn.Embedding(100, 6)生成正太分布100行6列数据,行必须超过输入句子词语长度,而句子每个词使用整数编码成索引,该索引对应之前embedding行寻找,得到对应行
维度,即可转为表达该词的特征向量。

四、模型训练与预测的权重变化探讨

之前已说过nn.Embedding()在训练过程中会发生变化,但在预测中将不在变化,应该是被训练成最佳词的向量维度表达,也就是说每个词唯一对应索引,被Embedding特征表达训练成最佳特征表达,也可说训练词索引特征表达固定。为探讨此过程,我写了对应示列,如下:

import torch
from torch.nn import Embedding

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.emb = Embedding(5, 3)
    def forward(self,vec):
        input = torch.tensor([0, 1, 2, 3, 4])
        emb_vec1 = self.emb(input)
        # print(emb_vec1)  ### 输出对同一组词汇的编码
        output = torch.einsum('ik, kj -> ij', emb_vec1, vec)
        return output
def simple_train():
    model = Model()
    vec = torch.randn((3, 1))
    label = torch.Tensor(5, 1).fill_(3)
    loss_fun = torch.nn.MSELoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.015)
    print('初始化emebding参数权重:\n',model.emb.weight)
    for iter_num in range(100):
        output = model(vec)
        loss = loss_fun(output, label)
        opt.zero_grad()
        loss.backward(retain_graph=True)
        opt.step()
        # print('第{}次迭代emebding参数权重{}:\n'.format(iter_num, model.emb.weight))
    print('训练后emebding参数权重:\n',model.emb.weight)
    torch.save(model.state_dict(),'./embeding.pth')
    return model

def simple_test():
   model = Model()
   ckpt = torch.load('./embeding.pth')
   model.load_state_dict(ckpt)

   model=model.eval()
   vec = torch.randn((3, 1))

   print('加载emebding参数权重:\n', model.emb.weight)
   for iter_num in range(100):
      output = model(vec)
   print('n次预测后emebding参数权重:\n', model.emb.weight)

if __name__ == '__main__':
   simple_train()  # 训练与保存权重
   simple_test()

结果如下:

在这里插入图片描述
训练代码参考博客:点击这里

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/118520.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GB28181学习(十五)——流传输方式

前言 基于GB/T28181-2022版本,实时流的传输方式包括3种: UDPTCP被动TCP主动 UDP 流程 注意: m字段指定传输方式为RTP/AVP; 抓包 SIP服务器发送INVITE请求; INVITE sip:xxx192.168.0.111:5060 SIP/2.0 Via: SIP…

【C++】智能指针【内存泄漏|智能指针原理及使用|RAII】

目录 1、了解内存泄露 1.1 内存泄漏的定义及危害 1.2 内存泄漏分类(了解) 1.3 如何检测内存泄漏(了解) 1.4如何避免内存泄漏 2、智能指针的引出 3、智能指针的使用及原理 3.1 RAII 3.2 智能指针的原理 3.3 std::auto_pt…

skynet学习笔记01— skynet开发环境搭建(超详细)与第一个skynet程序

00、参考资料 https://blog.csdn.net/qq769651718/category_7480207.html 01、前置准备 开发所在目录 mhzzjmhzzj-virtual-machine:~/work/skynetStudy$ pwd /home/mhzzj/work/skynetStudy前置准备 mhzzjmhzzj-virtual-machine:~/work/skynetStudy$ sudo apt install lua5…

手动关闭PS中的TopazStudio2的登录窗口

2021 adobe photoshop Topaz Studio 2 不是使用防火墙出站规则,是手动关闭的解决方案 点击社区-切换用户,登录窗口会出现X,可以手动关闭

【flutter no devices】

1.在环境变量增加 ANDROID_HOME 值为:C:\Users\Administrator\AppData\Local\Android\Sdk (Android sdk 位置) 2 环境变量的path里面增加2个值: %ANDROID_HOME%\platform-tools %ANDROID_HOME%\tools 3 打开cmd,或者在Android st…

uniapp使用抖音微信自定义组件

tt.vue中使用video-player组件 用到的目录如下: pages.json {"path": "pages/Tabbar/tt/tt","style": {"navigationBarTitleText": "","enablePullDownRefresh": false,// 使用自定义组件"using…

使用R语言构建HTTP爬虫:IP管理与策略

目录 摘要 一、HTTP爬虫与IP管理概述 二、使用R语言进行IP管理 三、爬虫的伦理与合规性 四、注意事项 结论 摘要 本文深入探讨了使用R语言构建HTTP爬虫时如何有效管理IP地址。由于网络爬虫高频、大量的请求可能导致IP被封禁,因此合理的IP管理策略显得尤为重要…

【JAVA学习笔记】63 -坦克大战1.3-敌方发射子弹,击中坦克消失并爆炸,敌人坦克随机移动,规定范围限制移动

项目代码 https://github.com/yinhai1114/Java_Learning_Code/tree/main/IDEA_Chapter18/src/com/yinhai/tankgame1_3 〇、要求 增加功能 1.让敌人的坦克也能够发射子弹(可以有多颗子弹) 2.当我方坦克击中敌人坦克时,敌人的坦克就消失,如果能做出爆炸效果更好. …

人工智能-卷积神经网络

从全连接层到卷积 我们之前讨论的多层感知机十分适合处理表格数据,其中行对应样本,列对应特征。 对于表格数据,我们寻找的模式可能涉及特征之间的交互,但是我们不能预先假设任何与特征交互相关的先验结构。 此时,多层感…

Java锁常见面试题

图片引用自:不可不说的Java“锁”事 - 美团技术团队 1 java内存模型 java内存模型(JMM)是线程间通信的控制机制。JMM定义了主内存和线程之间抽象关系。线程之间的共享变量存储在主内存中,每个线程都有一个私有的本地内存,本地内存中存储了该…

计算机毕业设计java+springboot+vue的旅游攻略平台

项目介绍 本系统结合计算机系统的结构、概念、模型、原理、方法,在计算机各种优势的情况下,采用JAVA语言,结合SpringBoot框架与Vue框架以及MYSQL数据库设计并实现的。员工管理系统主要包括个人中心、用户管理、攻略管理、审核信息管理、积分…

Crypto(9)[MRCTF2020]keyboard

下载题目,看看里面是什么 这是什么鬼,由题目可以获得线索,keyboard,不是键盘吗,然后看了看别人写的wp,发现是九键,有几个数字对应的密文就是第几个字母 比如第一个6,对应的字母是mno&#xff0c…

80个10倍提升Excel技能的ChatGPT提示

你是否厌倦了在使用Excel时感觉像个新手?你是否想将你的技能提升到更高的水平,成为真正的Excel大师?嗯,如果你正在使用ChatGPT,那么成为Excel专家简直易如反掌。 你只需要了解一些最有用的Excel提示,就能在…

龙迅LT8619C HDMI转LVDS/RGB/BT656/BT1120/BT601

LT8619C 描述: Lontium的LT8619C是一款高性能的HDMI/双模式DP接收器芯片,符合HDMI 1.4规范。TTL输出可支持RGB、BT656、BT1120,输出分辨率最多可支持4Kx2K30Hz。为了方便地实现一个多媒体系统,LT8619C支持8通道高质量的I2S音频或…

CSDN每日一题学习训练——Java版(两数相加、二叉树的锯齿形层序遍历、解数独)

版本说明 当前版本号[20231106]。 版本修改说明20231106初版 目录 文章目录 版本说明目录两数相加题目解题思路代码思路补充说明参考代码 二叉树的锯齿形层序遍历题目解题思路代码思路参考代码 解数独题目解题思路代码思路补充说明参考代码 两数相加 题目 给你两个 非空 的…

ARMday02(汇编语法、汇编指令)

汇编语法 汇编文件中的内容 1.伪操作:在汇编程序中不占用存储空间,但是可以在程序编译时起到引导和标识作用 .text .global .glbal .if .else .endif .data .word.... 2.汇编指令:每一条汇编指令都用来标识一个机器码,让计算机做…

【亚马逊云科技产品测评】活动征文|亚马逊云科技AWS之EC2详细测评

引言 (授权声明:本篇文章授权活动官方亚马逊云科技文章转发、改写权,包括不限于在 Developer Centre, 知乎,自媒体平台,第三方开发者媒体等亚马逊云科技官方渠道) 在当前的数字化时代,云服务已…

RK3568平台开发系列讲解(音视频篇)RTMP 推流

🚀返回专栏总目录 文章目录 一、RTMP 的工作原理二、RTMP 流媒体服务框架2.1、Nginx 流媒体服务器2.2、FFmpeg 推流沉淀、分享、成长,让自己和他人都能有所收获!😄 📢目前常见的视频监控和视频直播都是使用了 RTMP、RTSP、HLS、MPEG-DASH、 WebRTC流媒体传输协议等。 R…

Apache Doris (五十二): Doris Join类型 - Broadcast Join

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹哥教你大数据个人主页-哔哩哔哩视频 目录 1. Broadcast Join原理

虹科示波器 | 汽车免拆检测 | 2017款长安福特翼虎车发动机故障灯异常点亮

一、故障现象 一辆2017款长安福特翼虎车,搭载CAF488WQ9发动机,累计行驶里程约为8.9万km。该车因发动机故障灯异常点亮在其他维修厂检修,维修人员用故障检测仪检测,提示气缸3失火,调换火花塞、点火线圈及喷油器&#xf…