一文通透各种注意力:从多头注意力MHA到分组查询注意力GQA、多查询注意力MQA

前言

通过本博客内之前的文章可知,自回归解码的标准做法是缓存序列中先前标记的键(K)和值(V) 对,从而加快注意力计算速度。然而,随着上下文窗口或批量大小的增加,多头注意力 (MHA)模型中与 KV 缓存大小相关的内存成本显着增长

对于较大的模型,KV 缓存大小成为瓶颈,键和值投影可以在多个头之间共享,而不会大幅降低性能,可以使用

  • 具有单个 KV 投影的原始多查询格式(MQA),ChatGLM2-6B即用的这个
    不过,多查询注意(Multi-query attention,简称MQA)只使用一个键值头,虽大大加快了解码器推断的速度,但MQA可能导致质量下降,而且仅仅为了更快的推理而训练一个单独的模型可能是不可取的
  • 或具有多个 KV 投影的分组查询注意力(grouped-query attention,简称GQA),LLaMA2和Mistral均用的这个
    这是一种多查询注意的泛化,它通过折中(多于一个且少于查询头的数量,比如4个)键值头的数量,使得经过强化训练的GQA以与MQA相当的速度达到接近多头注意力的质量,即速度快 质量高

经实验论证,GQA 变体在大多数评估任务上的表现与 MHA 基线相当,并且平均优于 MQA 变体

多头注意力MHA分组查询注意力GQA多查询注意力MQA
LLaMA2
Mistral
ChatGLM2

以下是这三种注意力机制在结构上的对比

第一部分 多头注意力

// 待更

第二部分 LLaMA2之分组查询注意力——Grouped-Query Attention

23年,Google的研究者们提出了一种新的方法,即分组查询注意(GQA,论文地址为:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints)

// 待更

第三部分 ChatGLM2之多查询注意力(Muti Query Attention)

3.1 MQA的核心特征:各自Query矩阵,但共享Key 和 Value 矩阵

多查询注意力(Muti Query Attention)是 19 年Google一研究者提出的一种新的 Attention 机制(对应论文为:Fast Transformer Decoding: One Write-Head is All You Need、这是其解读之一),其能够在保证模型效果的同时加快 decoder 生成 token 的速度

那其与17年 Google提出的transformer中多头注意力机制(简称MHA)有啥本质区别呢?有意思的是,区别在于:

  • 我们知道MHA的每个头都各自有一份不同的Key、Query、Value矩阵
  • 而MQA 让所有的头之间 共享 同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量
    总之,MQA 实际上是将 head 中的 key 和 value 矩阵抽出来单独存为一份共享参数,而 query 则是依旧保留在原来的 head 中,每个 head 有一份自己独有的 query 参数

下图对比了多头注意力(Multi-Head Attention)、LLaMA2中分组查询注意力(Grouped-Query Attention)、多查询注意力(Muti Query Attention)的差别

总之,MHA 和 MQA 之间的区别只在于建立 Wqkv Layer 上

# Multi Head Attention
self.Wqkv = nn.Linear(                        # 【关键】Multi-Head Attention 的创建方法
    self.d_model, 
    3 * self.d_model,                         # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
    device=device
)

query, key, value = qkv.chunk(                # 【关键】每个 tensor 都是 (1, 512, 768)
    3, 
    dim=2
)


# Multi Query Attention
self.Wqkv = nn.Linear(                                # 【关键】Multi-Query Attention 的创建方法
    d_model,
    d_model + 2 * self.head_dim,                      # 只创建 query 的 head 向量,所以只有 1 个 d_model
    device=device,                                    # 而 key 和 value 不再具备单独的头向量
)

query, key, value = qkv.split(                        # query -> (1, 512, 768)
    [self.d_model, self.head_dim, self.head_dim],     # key   -> (1, 512, 96)
    dim=2                                             # value -> (1, 512, 96)
)

对比上面的代码,你可以发现

  • 在 MHA 中,query, key, value 每个向量均有 768 维度
  • 而在 MQA 中,只有 query 是 768 维,而 key 和 value 均只剩下 96 维了,恰好是 1 个 head_dim 的维度

因此,可以确认:在 MQA 中,除了 query 向量还保存着 8 个头,key 和 value 向量都只剩 1 个「公共头」了,这也正好印证了论文中所说的「所有 head 之间共享一份 key 和 value 的参数」

剩下的问题就是如何将这 1 份参数同时让 8 个头都使用,代码里使用矩阵乘法 matmul 来广播,使得每个头都乘以这同一个 tensor,以此来实现参数共享:

def scaled_multihead_dot_product_attention(
        query,
        key,
        value,
        n_heads,
        multiquery=False,
    ):
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)         # (1, 512, 768) -> (1, 8, 512, 96)
    kv_n_heads = 1 if multiquery else n_heads
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)        # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery 
                                                                    # (1, 512, 96) -> (1, 1, 96, 512)  if multiquery
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)      # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery 
                                                                    # (1, 512, 96) -> (1, 1, 512, 96)  if multiquery
    
    attn_weight = q.matmul(k) * softmax_scale                       # (1, 8, 512, 512)
    attn_weight = torch.softmax(attn_weight, dim=-1)                # (1, 8, 512, 512)

    out = attn_weight.matmul(v)                                     # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)
    out = rearrange(out, 'b h s d -> b s (h d)')                    # (1, 512, 768)

    return out, attn_weight, past_key_value

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/116757.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【多线程】Lambda表达式

package org.example;public class TestLambda {public static void main(String[] args) {Like likenew Like();like.lambda();}}//定义一个函数式接口 interface ILike{void lambda(); }//实现类 class Like implements ILike{Overridepublic void lambda() {System.out.prin…

Excel自学三部曲_Part3:Excel工作场景实战(四)

文章目录 四、高级函数与数据连接1. 多窗口操作2. VLOOKUP函数3. XLOOKUP函数4. CSV数据格式 四、高级函数与数据连接 1. 多窗口操作 如何将两张子表数据(战区信息、城市信息)连接到主表数据(成交数据),增加主要数据的…

“一键批量拆分HTML文本,高效整理文件,提升工作效率“

您是否曾经被大量的HTML文本文件困扰,难以找到所需的特定信息?现在,我们向您推荐一款强大的工具,它能够一键拆分HTML文本,让您轻松实现文件整理,提高工作效率! 首先,在首助编辑高手…

人工智能基础_机器学习014_BGD批量梯度下降公式更新_进一步推导_SGD随机梯度下降和MBGD小批量梯度下降公式进一步推导---人工智能工作笔记0054

然后我们先来看BGD批量梯度下降,可以看到这里,其实这个公式来源于 梯度下降的公式对吧,其实就是对原始梯度下降公式求偏导以后的梯度下降公式,然后 使用所有样本进行梯度下降得来的,可以看到* 1/n 其实就是求了一个平均数对吧.所有样本的平均数. 然后我们看,我们这里* 1/n那么…

API接口安全设计

简介 HTTP接口是互联网各系统之间对接的重要方式之一,使用HTTP接口开发和调用都很方便,也是被大量采用的方式,它可以让不同系统之间实现数据的交换和共享。 由于HTTP接口开放在互联网上,所以我们就需要有一定的安全措施来保证接口…

C++11 initializer_list 轻量级初始化列表的使用场景(让自定义类可以用初始化列表的形式来实例化对象)

initializer_list 是 C11 中的一个特性&#xff0c;它允许你使用花括号 {} 中的值列表来初始化容器或数组。通常用于初始化标准库容器&#xff0c;比如 std::vector、std::set、std::map 以及数组。 场景一&#xff1a;用初始化列表初始化容器 std::vector<int> arr {…

【深度学习】pytorch——Autograd

笔记为自我总结整理的学习笔记&#xff0c;若有错误欢迎指出哟~ 深度学习专栏链接&#xff1a; http://t.csdnimg.cn/dscW7 pytorch——Autograd Autograd简介requires_grad计算图没有梯度追踪的张量ensor.data 、tensor.detach()非叶子节点的梯度计算图特点总结 利用Autograd实…

scrapy+selenium框架模拟登录

目录 一、cookie和session实现登录原理 二、模拟登录方法-Requests模块Cookie实现登录 三、cookiesession实现登录并获取数据 四、selenium使用基本代码 五、scrapyselenium实现登录 一、cookie和session实现登录原理 cookie:1.网站持久保存在浏览器中的数据2.可以是长期…

Day20力扣打卡

打卡记录 数组中两个数的最大异或值&#xff08;位运算&#xff09; 链接 二进制位上从高位向低位进行模拟&#xff0c;看数组中是否有满足此情况的数字。具体题解 class Solution { public:int findMaximumXOR(vector<int>& nums) {int mx *max_element(nums.be…

【存档】vscode配置latex环境

原来在另一台电脑上找了个教程配了一遍&#xff0c;这次重新配的时候&#xff0c;那个教程作者更新过后&#xff0c;改成只有linux的脚本了&#xff0c;所以存档一下。真香警告, 2023年初的vscodelatex写作 - 知乎 (zhihu.com) 环境&#xff1a; win10/win11vscodelatex work…

【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)

目录 0. 前言 1. Cifar10数据集 2. AlexNet网络模型 2.1 AlexNet的网络结构 2.2 激活函数ReLu 2.3 Dropout方法 2.4 数据增强 3. 使用GPU加速进行批量训练 4. 网络模型构建 5. 训练过程 6. 完整代码 0. 前言 按照国际惯例&#xff0c;首先声明&#xff1a;本文只是我…

分享81个工作总结PPT,总有一款适合您

分享81个工作总结PPT&#xff0c;总有一款适合您 PPT下载链接&#xff1a;https://pan.baidu.com/s/13hyrlZo2GhRoQjI-6z31-w?pwd8888 提取码&#xff1a;8888 Python采集代码下载链接&#xff1a;采集代码.zip - 蓝奏云 学习知识费力气&#xff0c;收集整理更不易。知识付…

IDEA创建Springboot多模块项目

一、创建父模块 File --> New --> Project &#xff0c;选择 “ Spring Initalizr ” &#xff0c;点击 Next Next Next --> Finish 二、创建子模块 右键根目录&#xff0c;New --> Module 选择 “ Spring Initializr ”&#xff0c;点击Next 此处注意T…

设置IDEA快捷生成方法头,类头注释

1.File->settings->editor->live templates进入Live Template界面进行设置&#xff1a; 下一步&#xff1a; 下一步&#xff1a; /*** Title: $title$* author: sunyanzeng* date: $datatime$*/在需要添加文件头的地方打出“aa”&#xff0c;回车&#xff0c;会自…

go语言 | grpc原理介绍(三)

了解 gRPC 通信模式中的消息流 gRPC 支持四种通信模式&#xff0c;分别是简单 RPC、服务端流式 RPC、客户端流式 RPC 和双向流式 RPC。 简单 RPC 在gRPC中&#xff0c;一个简单的RPC调用遵循请求-响应模型&#xff0c;通常涉及以下几个关键步骤和组件&#xff1a; 请求头&a…

Java自学第2课:Java语言基础知识要点

1 Java主类结构 任务&#xff1a;创建新项目名为item&#xff0c;包名为number&#xff0c;类名为first。 1.1 包声明 不指定包时&#xff0c;默认就是工程名&#xff0c;指定后&#xff0c;类文件可以分类了&#xff0c;是这意思吧。包就大概等于一个文件夹。而且在类文件中…

Educational Codeforces Round 2 D 计算几何

题目链接&#xff1a;Educational Codeforces Round 2 D 题目 给你两个圆。求它们相交处的面积。 输入 第一行包含三个整数 x1, y1, r1 (  - 109 ≤ x1, y1 ≤ 109, 1 ≤ r1 ≤ 109 ) - 第一个圆的圆心位置和半径。 第二行包含三个整数 x2, y2, r2 (  …

【IO多路转接】select编程模型

文章目录 1 :peach:五种IO模型:peach:1.1 :apple:阻塞IO:apple:1.2 :apple:非阻塞IO:apple:1.3 :apple:信号驱动IO:apple:1.4 :apple:IO多路转接:apple:1.5 :apple:异步IO:apple:1.6 :apple:同步通信&异步通信:apple:1.7 :apple:阻塞&非阻塞:apple:1.8 :apple:总结:app…

IDEA快捷键总结+常识积累

&#xff08;一&#xff09;常用快捷键总结 以下快捷键输入完成后按Tab键即可。 1、输入main public static void main(String[] args) {}2、输入sout System.out.println();3、输入fori for (int i 0; i < ; i) {}4、输入foreach&#xff08;增强for循环快捷键&#x…

蓝桥杯(C++ 扫雷)

题目&#xff1a; 思想&#xff1a; 1、遍历每个点是否有地雷&#xff0c;有地雷则直接返回为9&#xff0c;无地雷则遍历该点的周围八个点&#xff0c;计数一共有多少个地雷&#xff0c;则返回该数。 代码&#xff1a; #include<iostream> using namespace std; int g[…
最新文章