Pytorch手撸Attention

Pytorch手撸Attention

注释写的很详细了,对照着公式比较下更好理解,可以参考一下知乎的文章

注意力机制

在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size

        # 定义三个全连接层,用于生成查询(Q)、键(K)和值(V)
        # 用Linear线性层让q、k、y能更好的拟合实际需求
        self.value = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.query = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        # x 的形状应为 (batch_size批次数量, seq_len序列长度, embed_size嵌入维度)
        batch_size, seq_len, embed_size = x.shape

        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        
        # 计算注意力分数矩阵
        # 使用 Q 矩阵乘以 K 矩阵的转置来得到原始注意力分数
        # 注意力分数的形状为 [batch_size, seq_len, seq_len]
        # K.transpose(1,2)转置后[batch_size, embed_size, seq_len]
        # 为什么不直接使用 .T 直接转置?直接转置就成了[embed_size, seq_len,batch_size],不方便后续进行矩阵乘法
        attention_scores = torch.matmul(Q, K.transpose(1, 2)) / torch.sqrt(
            torch.tensor(self.embed_size, dtype=torch.float32))

        # 应用 softmax 获取归一化的注意力权重,dim=-1表示基于最后一个维度做softmax
        attention_weight = F.softmax(attention_scores, dim=-1)

        # 应用注意力权重到 V 矩阵,得到加权和
        # 输出的形状为 [batch_size, seq_len, embed_size]
        output = torch.matmul(attention_weight, V)

        return output

多头注意力机制

在这里插入图片描述

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        # 整除来确定每个头的维度
        self.head_dim = embed_size // num_heads
		
        # 加入断言,防止head_dim是小数,必须保证可以整除
        assert self.head_dim * num_heads == embed_size

        self.q = nn.Linear(embed_size, embed_size)
        self.k = nn.Linear(embed_size, embed_size)
        self.v = nn.Linear(embed_size, embed_size)
        self.out = nn.Linear(embed_size, embed_size)

    def forward(self, query, key, value):
        # N就是batch_size的数量
        N = query.shape[0]
        
        # *_len是序列长度
        q_len = query.shape[1]
        k_len = key.shape[1]
        v_len = value.shape[1]
		
        # 通过线性变换让矩阵更好的拟合
        queries = self.q(query)
        keys = self.k(key)
        values = self.v(value)
		
        # 重新构建多头的queries,permute调整tensor的维度顺序
        # 结合下文demo进行理解
        queries = queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        keys = keys.reshape(N, k_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        values = values.reshape(N, v_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
		
        # 计算多头注意力分数
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32))
        attention = F.softmax(attention_scores, dim=-1)
		
        # 整合多头注意力机制的计算结果
        out = torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size)
        # 过一遍线性函数
        out = self.out(out)

        return out

demo测试

self-attention测试
# 测试自注意力机制
batch_size = 2
seq_len = 3
embed_size = 4

# 生成一个随机数据 tensor
input_tensor = torch.rand(batch_size, seq_len, embed_size)

# 创建自注意力模型实例
model = SelfAttention(embed_size)

# print输入数据
print("输入数据 [batch_size, seq_len, embed_size]:")
print(input_tensor)

# 运行自注意力模型
output_tensor = model(input_tensor)

# print输出数据
print("输出数据 [batch_size, seq_len, embed_size]:")
print(output_tensor)

=======print=========

输入数据 [batch_size, seq_len, embed_size]:
tensor([[[0.7579, 0.7342, 0.1031, 0.8610],
         [0.8250, 0.0362, 0.8953, 0.1687],
         [0.8254, 0.8506, 0.9826, 0.0440]],

        [[0.0700, 0.4503, 0.1597, 0.6681],
         [0.8587, 0.4884, 0.4604, 0.2724],
         [0.5490, 0.7795, 0.7391, 0.9113]]])

输出数据 [batch_size, seq_len, embed_size]:
tensor([[[-0.3714,  0.6405, -0.0865, -0.0659],
         [-0.3748,  0.6389, -0.0861, -0.0706],
         [-0.3694,  0.6388, -0.0855, -0.0660]],

        [[-0.2365,  0.4541, -0.1811, -0.0354],
         [-0.2338,  0.4455, -0.1871, -0.0370],
         [-0.2332,  0.4458, -0.1867, -0.0363]]], grad_fn=<UnsafeViewBackward0>)
MultiHeadAttention

多头注意力机制务必自己debug一下,主要聚焦在理解如何拆分成多头的,不结合代码你很难理解多头的操作过程

1、queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 处理之后的 size = torch.Size([64, 8, 10, 16])

  • 通过上述操作,queries 张量的最终形状变为 [N, self.num_heads, q_len, self.head_dim]。这样的排列方式使得每个注意力头可以单独处理对应的序列部分,而每个头的处理仅关注其分配到的特定维度 self.head_dim
  • 这个形状是为了后续的矩阵乘法操作准备的,其中每个头的查询将与对应的键进行点乘,以计算注意力分数

2、attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt( torch.tensor(self.head_dim, dtype=torch.float32)) 将reshape后的quries的后两个维度进行转置后点乘,对应了 Q ⋅ K T Q \cdot K^T QKT ;根据demo这里的头数为8,所以公式中对应的下标 i i i 为8

3、在进行完多头注意力机制的计算后通过 torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size) 整合,变回原来的 [batch_size,seq_length,embed_size]形状

# 测试多头注意力
embed_size = 128  # 嵌入维度
num_heads = 8    # 头数
attention = MultiHeadAttention(embed_size, num_heads)

# 创建随机数据模拟 [batch_size, seq_length, embedding_dim]
batch_size = 64
seq_length = 10
dummy_values = torch.rand(batch_size, seq_length, embed_size)
dummy_keys = torch.rand(batch_size, seq_length, embed_size)
dummy_queries = torch.rand(batch_size, seq_length, embed_size)

# 计算多头注意力输出
output = attention(dummy_values, dummy_keys, dummy_queries)
print(output.shape)  # [batch_size, seq_length, embed_size]

=======print=========

torch.Size([64, 10, 128])

如果你难以理解权重矩阵的拼接和拆分,推荐李宏毅的attention课程(YouTobe)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/550412.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Sy-linux下常用的网络命令linux network commands

linux下的网络命令非常强大&#xff0c;这里根据教材需要&#xff0c;列出来常用的网络命令和场景实例&#xff0c;供参考。 一、命令列表&#xff1a; Command Description ip Manipulating routing to assigning and configuring network parameters traceroute Identi…

【Java】通过poi给word首页添加水印图片

背景&#xff1a; poi并没有提供直接插入水印图片的方法&#xff0c;目前需要再word的首页插入一张水印图片&#xff0c;于是就需要通过另一种方式&#xff0c;插入透明图片&#xff08;png格式&#xff09;并将图片设置为“浮于文字上方”的方式实现该需求。 所需jar&#xf…

Linux解压4GB以上zip文件

Linux使用unzip解压大于4GB文件&#xff0c;会出现以下错误&#xff1a; 解决方法 安装p7zip yum -y install p7zip执行命令&#xff1a; 7za x MSRVTT.zip

Spark-机器学习(2)特征工程之特征提取

在之前的文章中&#xff0c;我们了解我们的机器学习&#xff0c;了解我们spark机器学习中的MLIib算法库&#xff0c;知道它大概的模型&#xff0c;熟悉并认识它。想了解的朋友可以查看这篇文章。同时&#xff0c;希望我的文章能帮助到你&#xff0c;如果觉得我的文章写的不错&a…

HackMyVM-Connection

目录 信息收集 arp nmap WEB web信息收集 dirsearch smbclient put shell 提权 系统信息收集 suid gdb提权 信息收集 arp ┌─[rootparrot]─[~/HackMyVM] └──╼ #arp-scan -l Interface: enp0s3, type: EN10MB, MAC: 08:00:27:16:3d:f8, IPv4: 192.168.9.115 S…

js打印页面源码 ,打印选取的容器里的内容,打印指定内容

js打印页面源码 &#xff0c;打印选取的容器里的内容&#xff0c;打印指定内容 效果 代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge&…

FreeRTOS时间管理

FreeRTOS时间管理 主要要了解延时函数&#xff1a; 相对延时&#xff1a;指每次延时都是从执行函数vTaskDelay()开始&#xff0c;直到延时指定的时间结束。 绝对延时&#xff1a;指将整个任务的运行周期看成一个整体&#xff0c;适用于需要按照一定频率运行的任务。 函数 vTa…

PTA图论的搜索题

目录 7-1 列出连通集 题目 输入格式: 输出格式: 输入样例: 输出样例: AC代码 7-2 六度空间 题目 输入格式: 输出格式: 输入样例: 输出样例: 思路 AC代码 7-3 地下迷宫探索 题目 输入格式: 输出格式: 输入样例1: 输出样例1: 输入样例2: 输出样例2: 思路 …

MySQL 试图

视图功能在 5.0 以后的版本启用 视图是一张虚表。数据表确实包含了具体数据并且保存到硬盘中的实表。视图使用数据检索语句动态生 成的一张虚表。每一次数据服务重启或者系统重启之后&#xff0c;在数据库服务启动期间&#xff0c;会使用创建视图的语 句重新生成视图中的数据&…

这家物流装备公司突破天际:销售额飙升至10亿美元,引领仓储机器人革命!...

导语 大家好&#xff0c;我是智能仓储物流技术研习社的社长&#xff0c;老K。专注分享智能仓储物流技术、智能制造等内容。 新书《智能物流系统构成与技术实践》 法国的Exotec公司在仓储自动化领域取得了显著成就&#xff0c;其销售额已超过10亿美元&#xff0c;成为全球物料搬…

考研数学|《1800》《1000》《660》《880》如何搭配❓

这几本书都是不同阶段对应的习题册 我觉得最舒服的使用就是方式就是基础阶段用《1800题基础部分》然后强化阶段主要刷《880题》并且强化阶段带着刷《660题》 上面是我的使用方式。之所以没有刷《1000题》是因为这本习题册的难度对我来说还是太大了&#xff0c;并且计算量很大…

上海计算机学会 2023年10月月赛 乙组T3 树的连通子图(树、树形dp)

第三题&#xff1a;T3树的连通子图 标签&#xff1a;树、树形 d p dp dp题意&#xff1a;给定一棵 n n n个结点的树&#xff0c; 1 1 1号点为这棵树的根。计算这棵树连通子图的个数&#xff0c;答案对 1 , 000 , 000 , 007 1,000,000,007 1,000,000,007取余数。题解&#xff1…

HTML内联框架

前言&#xff1a; 我们有时候打开网页时会有广告窗的出现&#xff0c;而这些窗口并不是来自于本站的&#xff0c;而是来自于外部网页&#xff0c;只是被引用到了自己网页中而已。这一种技术可以通过内联来实现。 标签介绍&#xff1a; HTML 内联框架元素 (<iframe>) 表示…

基于Springboot的影城管理系统

基于SpringbootVue的影城管理系统的设计与实现 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringbootMybatis工具&#xff1a;IDEA、Maven、Navicat 系统展示 用户登录 首页展示 电影信息 电影资讯 后台登录页 后台首页 用户管理 电影类型管理 放映…

RAG (Retrieval Augmented Generation) 结合 LlamaIndex、Elasticsearch 和 Mistral

作者&#xff1a;Srikanth Manvi 在这篇文章中&#xff0c;我们将讨论如何使用 RAG 技术&#xff08;检索增强生成&#xff09;和 Elasticsearch 作为向量数据库来实现问答体验。我们将使用 LlamaIndex 和本地运行的 Mistral LLM。 在开始之前&#xff0c;我们将先了解一些术…

vue3 生命周期(生命周期钩子 vs 生命周期选项 vs 缓存实例的生命周期)

vue3 支持两种风格书写&#xff1a;选项式 API 和组合式 API 若采用组合式 API &#xff0c;则使用生命周期钩子若采用选项式 API &#xff0c;则使用生命周期选项两者选用一种即可&#xff0c;不建议同时使用&#xff0c;避免逻辑紊乱。 生命周期钩子 在 setup 中使用 onBefo…

Vue 阶段练习:记事本

将 Vue快速入门 和 Vue 指令的学习成果应用到实际场景中&#xff08;如该练习 记事本&#xff09;&#xff0c;我们能够解决实际问题并提升对 Vue 的技能掌握。 目录 功能展示 需求分析 我的代码 案例代码 知识点总结 功能展示 需求分析 列表渲染删除功能添加功能底部统计…

3D目标检测实用技巧(二)- 实现点云(or 体素)向图像平面的投影并可视化

一、引言 受Focals Conv的启发&#xff0c;该论文中通过将点云投影到图片中清晰展现出点云学习后的情况&#xff1a; 本次实现的是体素向图像投影并显示&#xff0c;实现出来的效果如下&#xff1a; 二、 实现细节 1、体素投影到图像坐标系 这里我们参考的是VirConv的投影函…

通过matlab分别对比PSO,反向学习PSO,多策略改进反向学习PSO三种优化算法

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 4.1 粒子群优化算法 (PSO) 4.2 反向学习粒子群优化算法 (OPSO) 4.3 多策略改进反向学习粒子群优化算法 (MSO-PSO) 5.完整程序 1.程序功能描述 分别对比PSO,反向学习PSO,多策略改进反向学…

百货商场用户画像描绘与价值分析

目录 内容概述数据说明实现目标技术点主要内容导入模块1.项目背景1.1 项目背景与挖掘目标 2.数据探索与预处理2.1 结合业务对数据进行探索并进行预处理2.2 将会员信息表和销售流水表关联与合并 3 统计分析3.1 分析会员的年龄构成、男女比例等基本信息3.2 分析会员的总订单占比&…
最新文章