Transformer详解:从放弃到入门(二)

多头注意力

  上篇文章中我们了解了词编码和位置编码,接下来我们介绍Transformer中的核心模块——多头注意力。

自注意力

  首先回顾下注意力机制,注意力机制允许模型为序列中不同的元素分配不同的权重。而自注意力中的"自"表示输入序列中的输入相互之间的注意力,即通过某种方式计算输入序列每个位置相互之间的相关性。具体的推导可以看这篇文章。
  对于Transformer编码器来说,给定一个输入序列 ( x 1 , . . . , x n ) (x_1,...,x_n) (x1,...,xn),这里假设是输入序列中第i个位置所对应的词嵌入。自注意力产生了一个新的相同长度的嵌入 ( y 1 , . . . , y 2 ) (y_1,...,y_2) (y1,...,y2),其中每个 y i y_i yi是所有 x j x_j xj的加权和(包括本身): y i = ∑ j α i j x j y_i=\sum_j\alpha_{ij}x_j yi=jαijxj  其中, α i j \alpha_{ij} αij是注意力权重,有 ∑ j α i j = 1 \sum_j\alpha_{ij}=1 jαij=1
  计算注意力的方式有很多种,最高效的是点积注意力,即两个输入之间做点积。点积的结果是一个实数范围内的标量,结果越大代表两个向量越相似。这是计算两个输入之间的注意力分数,将某个token与所有的输入进行计算,就可以得到n个注意力分数,经过Softmax归一化就可以得到权重向量 α \alpha α,其中 α i j \alpha_{ij} αij表示两个输入i和j之间的相关度(权重系数): α i j = s o f t m a x ( s c o r e ( x i , x j ) ) \alpha_{ij}=softmax(score(x_i,x_j)) αij=softmax(score(xi,xj))  得到了这些权重之后,就可以按照上面的公式对所有输入加权得到 y i y_i yi
  Transformer中的注意力会更加复杂一点,主要体现在两点:Q,K,V和缩放点积机制和多头注意力。

缩放点积注意力

  Transformer的多头注意力模块有三个输入:

  • Query: 与所有的输入进行比较,为当前关注的点。
  • Key:作为与Query进行比较的角色,用于计算和Query之间的相关性。
  • Value:用于计算当前注意力关注点的输出,根据注意力权重对不同的Value进行加权和。
      这三个输入都是由原始输入映射而来的,为了生成这三种不同的角色,Transformer分别引入了三个权重矩阵 W Q , W K , W V W^Q,W^K,W^V WQ,WK,WV,分别将每个输入投影到不同角色query,key和value表示: q = x W Q ; k = x W K ; v = x W V q=xW^Q; k=xW^K; v=xW^V q=xWQ;k=xWK;v=xWV  Query和Key是用于比较的,Value是用于提取特征的。通过将输入映射到不同的角色,使模型具有更强的学习能力。看到一个比较直观的解释:如果把注意力过程类比成搜索的话,那么假设在百度中输入"自然语言处理是什么",那么Query就是这个搜索的语句;Key相当于检索到的网页的标题;Value就是网页的内容。在这里插入图片描述  现在我们用Q和K矩阵计算亲和度矩阵sim matrix,矩阵中第i,j个位置,即指原序列第i个token和第j个token之间的亲和度(点积结果): s c o r e ( q i , k j ) = q i ⋅ k j score(q_i,k_j)=q_i\cdot k_j score(qi,kj)=qikj  点积的结果是一个标量,但这个结果可能非常大(不管是正的还是负的),这会使得softmax函数值进入一个导数非常小的区域。需要对这个注意力得分进行缩放,缩放使得分布更加平滑。一种缩放的方法是把点积结果除以一个和嵌入大小相关的因子(factor)。注意这是在传递给softmax之前进行的。Transformer的做法是除以query和key向量维度的平方根: s c o r e ( q i , k j ) = q i ⋅ k j d score(q_i,k_j)=\frac{q_i\cdot k_j}{\sqrt d} score(qi,kj)=d qikj  这是计算具体两个token向量的亲和度,我们可以用矩阵相乘实现批量操作,考虑一下维度的问题:包含batch大小和序列长度,输入的完整维度是(batch_size, seq_len, embed_dim)。我们得到query,key和value也是相同维度的,只是经过了不同的线性变换。那么亲和度矩阵的计算方法为: S e l f A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V SelfAttention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d_k})V SelfAttention(Q,K,V)=softmax(d kQKT)V

多头注意力

  上面介绍的缩放点积注意力把原始的x映射到不同的空间后,去做注意力。每次映射相当于是在特定空间中去建模特定的语义交互关系,类似卷积中的多通道可以得到多个特征图,那么多个注意力可以得到多个不同方面的语义交互关系。可以让模型更好地关注到不同位置的信息,捕捉到输入序列中不同依赖关系和语义信息。有助于处理长序列、解决语义消歧、句子表示等任务,提高模型的建模能力。Transformer中使用多头注意力实现这一点。
  对于多头注意力中的每个头i,都有自己不同的query,key和value矩阵,假设Q,K矩阵的维度是 d k d_k dk,V矩阵维度是 d v d_v dv,那么每个头的权重矩阵是 W i Q ∈ R d × d k , W i K ∈ R d × d k , , W i K ∈ R d × d v W_i^Q\in R^{d×d_k},W_i^K\in R^{d×d_k},,W_i^K\in R^{d×d_v} WiQRd×dk,WiKRd×dk,,WiKRd×dv,将他们与reshape后的输入相乘,得到的每个头的Q,K,V矩阵维度应该是 Q ∈ R s q l e n ∗ d k , K ∈ R s q l e n ∗ d k , V ∈ R s q l e n ∗ d v Q\in R^{sqlen*d_k},K\in R^{sqlen*d_k},V\in R^{sqlen*d_v} QRsqlendk,KRsqlendk,VRsqlendv其中,sqlen是序列长度, d k d_k dk是每个头的维度。
  得到这些多头注意力的组合以后,再把它们拼接起来,然后通过一个线性变化映射回原来的维度,保证输入和输出的维度一致: M u l t i H e a d A t t e n t i o n ( X ) = c o n c a t ( h e a d 1 , . . . , h e a d h ) W o MultiHeadAttention(X)=concat(head_1,...,head_h)W^o MultiHeadAttention(X)=concat(head1,...,headh)Wo h e a d i = S e l f A t t e n t i o n ( Q , K , V ) head_i=SelfAttention(Q,K,V) headi=SelfAttention(Q,K,V) Q = X W i Q ; K = X W I K ; V = X W I V Q=XW_i^Q;K=XW_I^K;V=XW_I^V Q=XWiQ;K=XWIK;V=XWIV
  下面是一个三个头的注意力示意图,在原论文中,d = 512,有h = 8个注意力头。由于每个头维度的减少,总的计算量和正常维度的单头注意力一样(8 × 64 = 512)。在这里插入图片描述  给出一个pytorch实现的例子,在forward方法中,首先利用三个线性变换分别计算query,key,value矩阵。接着拆分成多个头,传给attention方法计算多头注意力,然后合并多头注意力的结果。最后经过一个用作拼接的线性层。
  在下一篇文章中,我会继续讲解Transformer中剩下的其他组件。

class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        n_heads: int = 8,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_key = d_model // n_heads  # dimension of every head

        self.q = nn.Linear(d_model, d_model)  # query matrix
        self.k = nn.Linear(d_model, d_model)  # key matrix
        self.v = nn.Linear(d_model, d_model)  # value matrix
        self.concat = nn.Linear(d_model, d_model)  # output

        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x: Tensor, is_key: bool = False) -> Tensor:
        batch_size = x.size(0)
        # x (batch_size, seq_len, n_heads, d_key)
        x = x.view(batch_size, -1, self.n_heads, self.d_key)
        if is_key:
            # (batch_size, n_heads, d_key, seq_len)
            return x.permute(0, 2, 3, 1)
        # (batch_size, n_heads, seq_len, d_key)
        return x.transpose(1, 2)

    def merge_heads(self, x: Tensor) -> Tensor:
        x = x.transpose(1, 2).contiguous().view(x.size(0), -1, self.d_model)
        return x

    def attenion(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Tensor = None,
        keep_attentions: bool = False,
    ):
        scores = torch.matmul(query, key) / math.sqrt(self.d_key)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # weights (batch_size, n_heads, q_length, k_length)
        weights = self.dropout(torch.softmax(scores, dim=-1))
        # (batch_size, n_heads, q_length, k_length) x (batch_size, n_heads, v_length, d_key) -> (batch_size, n_heads, q_length, d_key)
        # assert k_length == v_length
        # attn_output (batch_size, n_heads, q_length, d_key)
        attn_output = torch.matmul(weights, value)

        if keep_attentions:
            self.weights = weights
        else:
            del weights

        return attn_output

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Tensor = None,
        keep_attentions: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """

        Args:
            query (Tensor): (batch_size, q_length, d_model)
            key (Tensor): (batch_size, k_length, d_model)
            value (Tensor): (batch_size, v_length, d_model)
            mask (Tensor, optional): mask for padding or decoder. Defaults to None.
            keep_attentions (bool): whether keep attention weigths or not. Defaults to False.

        Returns:
            output (Tensor): (batch_size, q_length, d_model) attention output
        """
        query, key, value = self.q(query), self.k(key), self.v(value)

        query, key, value = (
            self.split_heads(query),
            self.split_heads(key, is_key=True),
            self.split_heads(value),
        )

        attn_output = self.attenion(query, key, value, mask, keep_attentions)

        del query
        del key
        del value

        # Concat
        concat_output = self.merge_heads(attn_output)
        # the final liear
        # output (batch_size, q_length, d_model)
        output = self.concat(concat_output)

        return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/599102.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

相机内存卡格式化怎么恢复?恢复数据的3个方法

相机内存卡格式化后,许多用户都曾面临过照片丢失的困境。这些照片可能具有极高的纪念价值,也可能包含着重要的信息。因此如何有效地恢复这些照片变得至关重要。本文将详细介绍三种实用的恢复方法,帮助您找回那些珍贵的影像。 下面分享几个实…

RS2103XH 功能和参数介绍及规格书

RS2103XH 是一款单刀双掷(SPDT)模拟开关芯片,主要用于各种模拟信号的切换和控制。下面是一些其主要的功能和参数介绍: 主要功能特点: 模拟信号切换:能够连接和断开模拟信号路径,提供灵活的信号路…

经典面试题之滑动窗口专题

class Solution { public:int minSubArrayLen(int target, vector<int>& nums) {// 长度最小的子数组 // 大于等于 targetint min_len INT32_MAX;// 总和int sum 0;int start 0; // 起点for(int i 0; i< nums.size(); i) {sum nums[i];while(sum > targe…

京东淘宝1688商品采集商品数据抓取API

item_get-获得淘宝商品详情 item_search 关键字搜索商品 公共参数 请求地址: taobao/item_search 名称类型必须描述keyString是调用key&#xff08;必须以GET方式拼接在URL中&#xff09;secretString是调用密钥api_nameString是API接口名称&#xff08;包括在请求地址中&a…

安卓自动化脚本制作流程详解!

在移动应用日益普及的今天&#xff0c;安卓自动化脚本制作成为了开发者提高工作效率、减少重复劳动的重要手段&#xff0c;本文将详细介绍安卓自动化脚本的制作流程&#xff0c;并通过五段源代码的实例&#xff0c;帮助读者更好地理解和掌握这一过程。 一、安卓自动化脚本制作…

Vector Laboratories|用于生物偶联疗法BioDesign™ dPEG® Linker连接平台

术语dPEG代表“离散PEG&#xff08;discrete PEG&#xff09;”&#xff0c;这是一种均一的、单分子量&#xff08;MW&#xff09;、高纯度的新一代聚乙二醇聚合物。Vector Laboratorie采用其受专利保护的专有生产工艺&#xff0c;可生产提供适合于各种应用场景&#xff0c;具有…

卸载系统自带APP

Firefly RK3588 android 12自动多个系统软件&#xff0c;无法从UI界面进行手动删除。因此&#xff0c;考虑使用shell指令进行处理。 系统自动APP大多都安装在system/app目录下&#xff0c;且该目录多为只读。因此采用如下步骤&#xff0c; //Shell su adb shell su //重新挂载…

食品饮料-冲饮市场线上发展现状:香飘飘品牌监控数据分析

近期&#xff0c;老国货品牌香飘飘在国内备受关注&#xff0c;起因是某网友在日本华人超市内看到香飘飘Meco果汁茶产品包装统一增加了几组“海洋不是日本的下水道”、“请日本政客豪饮核污水”、“地球可以没有日本但不能没有海洋”等中日双语标语&#xff0c;正大光明讽刺日本…

Xinstall:专业的APP全渠道统计服务商,助力广告数据分析

在移动互联网时代&#xff0c;APP已成为企业营销的重要阵地。然而&#xff0c;随着竞争加剧&#xff0c;广告主们面临着如何精准衡量广告投放效果、优化投放策略等挑战。这时&#xff0c;专业的APP全渠道统计服务商——Xinstall便成为了广告主们的得力助手。 Xinstall作为国内…

2024最新行业领域名词解释大全

2024最新行业领域名词解释大全 &#x1f680; 大家好&#xff01;我是你们的老朋友猫头虎&#x1f42f;。今天要为大家带来2024年最新的行业领域名词解释大全&#xff01;在这个信息爆炸的时代&#xff0c;准确了解不同领域的行业动态、工作机会和职业前景至关重要。下面我会分…

创建操作手册知识库的终极指南

在繁忙的工作中&#xff0c;有一个方便好用的操作手册知识库能帮我们节省大量时间&#xff0c;避免走弯路。那么&#xff0c;如何创建这样一个知识库呢&#xff1f;下面就给大家讲解一下简单易学的创建步骤。 一、明确目标与需求 在创建操作手册知识库之前&#xff0c;首先要明…

Vue + Element-plus 快速入门

1. 构建项目 npm init vuelatest # 可选项一路回车&#xff0c;使用默认NO,按提示执行3条命令 cd 项目名 npm install npm run dev 2. 下载element-plus npm install element-plus --save 3.替换main.js import { createApp } from vue import ElementPlus from element-plu…

解决问题:Docker证书到期(Error grabbing logs: rpc error: code = Unknown)导致无法查看日志

问题描述 Docker查看日志时portainer报错信息如下&#xff1a; Error grabbing logs: rpc error: code Unknown desc warning: incomplete log stream. some logs could not be retrieved for the following reasons: node klf9fdsjjt5tb0w4hxgr4s231 is not available报错…

零基础代码随想录【Day27】|| 39. 组合总和,40.组合总和II, 131.分割回文串

目录 DAY27 39. 组合总和 解题思路&代码 40.组合总和II 解题思路&代码 131.分割回文串 解题思路&代码 DAY27 39. 组合总和 力扣题目链接(opens new window) 给定一个无重复元素的数组 candidates 和一个目标数 target &#xff0c;找出 candidates 中所有…

国家电网某地电力公司网络硬件综合监控运维项目

国家电网某地电力公司是国家电网有限公司的子公司&#xff0c;负责当地电网规划、建设、运营和供电服务&#xff0c;下属多家地市供电企业和检修公司、信息通信公司等业务支撑实施机构。 项目现状 随着公司信息化建设加速&#xff0c;其信息内网中存在大量物理服务器、存储设备…

牛客网刷题 | BC78 KiKi说祝福语

目前主要分为三个专栏&#xff0c;后续还会添加&#xff1a; 专栏如下&#xff1a; C语言刷题解析 C语言系列文章 我的成长经历 感谢阅读&#xff01; 初来乍到&#xff0c;如有错误请指出&#xff0c;感谢&#xff01; 描述 2020年来到了&#…

Optional学习记录

Optional出现的意义 在Java中&#xff0c;我们经常遇到的一种异常情况&#xff1a;空指针异常&#xff0c;在原本的编程中&#xff0c;为了避免这种异常&#xff0c;我们通常会向对象进行判断&#xff0c;然而&#xff0c;过多的判断语句会让我们的代码显得臃肿不堪。 所以在J…

通过mask得到bbox(numpy实现)

在SAM的加持下&#xff0c;我们很容易得到物体的mask&#xff0c;但是物体的bbox信息通常也很有用。那么&#xff0c;我们可以写一个函数&#xff0c;立马可以通过mask得到bbox。 代码如下&#xff1a; import numpy as npdef mask2bbox(mask):nonzero_indices np.nonzero(m…

淤地坝安全监测预警系统解决方案

一、方案背景 淤地坝是黄土高原地区人民群众长期同水土流失斗争实践中创造的一种行之有效的水土保持工程措施&#xff0c;在拦泥保土、减少入黄泥沙、防洪减灾、淤地造田、巩固退耕还林&#xff08;草&#xff09;、保障生态安全、促进粮食生产和水资源合理利用及经济社会稳定发…

做抖音小店需要注意什么?这几点很多人不知道,看完防踩坑

大家好&#xff0c;我是电商笨笨熊 抖音小店虽然推出了一段时间&#xff0c;但是依旧有新手玩家陆陆续续加入其中&#xff1b; 对于很多新手来说&#xff0c;只看到了其中红利&#xff0c;但却没有看到其中包含的一些运营小细节&#xff0c;且这些细节决定你店铺未来发展&…
最新文章