LLM 加速技巧:Muti Query Attention

MQA 是 19 年提出的一种新的 Attention 机制,其能够在保证模型效果的同时加快 decoder 生成 token 的速度。在大语言模型时代被广泛使用,很多LLM都采用了MQA,如Falcon、PaLM、StarCoder等。

在介绍MQA 之前,我们先回顾一下传统的多头注意力

Multi-Head Attention(MHA)

多头注意力是transformer 模型的默认注意力机制,如下图所示:

在文本生成方面,基于transformer 的自回归语言模型存在一个问题。在训练过程中可以获得真实的目标序列,并且可以有效地实现并行化。

但是在推理过程中,每个位置的查询都要处理在该位置或之前生成的所有键值对。也就是说自注意力层在特定位置的输出影响下一个令牌的生成,所以无法并行化,这使得推理变得非常的慢。

下图是基于transformer 解码器的自回归语言模型中自注意层的解码过程:

 defMHAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
     q=tf.einsum("bd, hdk−>bhk", x, P_q)
     new_K=tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, hdk−>bhk", x, P_k), axis=2)], axis=2)
     new_V=tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, hdv−>bhv", x, P_v), axis=2)], axis=2)
     logits=tf.einsum("bhk, bhmk−>bhm", q, new_K)
     weights=tf.softmax(logits)
     O=tf.einsum("bhm, bhmv−>bhv", weights, new_V)
     Y=tf.einsum("bhv, hdv−>bd", O, P_o)
     returnY, new_K, new_V

其中:

X:当前的输入张量,m为当前步,m+1为阶跃,形状为[b, d]

P_q, P_k:查询和键投影张量,形状为[h, d, k]

P_v:值投影张量,形状为[h, d, v]

P_o:学习到的线性投影,形状为[h, d, v]

Prev_K:上一步的关键张量,形状为[b, h, m, k]

Prev_V:前一步的Value张量,形状为[b, h, m, v]

new_K:加上当前步的键张量,形状为[b, h, m+1, k]

new_V:加了当前步长的Value张量,形状为[b, h, m+1, v]

维度表示如下:

M:先前执行的步骤数

B:批量大小

D:输入和输出的尺寸

H:注意力头数

k:Q,K张量的另一个维度

v: v张量的另一个维度

Multi-Query Attention(MQA)

MQA是多头注意的一种变体。

MQA的方法是保持Q的初始头数,但K和V只有一个头,这意味着所有Q个头共享相同的K和V,因此称为Multi-Query,如下图所示:

从论文的解释中可以看到,MQA 让所有的头之间 共享 同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。

MQA解码过程的代码本质上与MHA的代码相同,只是从中删除了表示头部尺寸的字母“h”。K, V, P_k, P_v的和方程:

 defMQAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
     q=tf.einsum("bd, hdk−>bhk", x, P_q)
     new_K=tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, dk−>bk", x, P_k), axis=2)], axis=2)
     new_V=tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, dv−>bv", x, P_v), axis=2)], axis=2)
     logits=tf.einsum("bhk, bmk−>bhm", q, new_K)
     weights=tf.softmax(logits)
     O=tf.einsum("bhm, bmv−>bhv", weights, new_V)
     Y=tf.einsum("bhv, hdv−>bd", O, P_o)
     returnY, new_K, new_V

上面都是tf的代码,如果阅读有问题,我从 llm-foundry项目中找到了pytorch的代码实现,这里只做个摘抄,有兴趣的请看原项目

 classMultiheadAttention(nn.Module):
 
     def__init__(
             self,
             d_model: int,
             n_heads: int,
             device: str
         ):
         """
         Multi Head init func.
 
         Args:
             d_model (int): hidden state size, e.g. 768
             n_heads (int): 设定的注意力头数, e.g. 8
             device (str): _description_
         """
         super().__init__()
 
         self.d_model=d_model
         self.n_heads=n_heads
     
         self.Wqkv=nn.Linear(                       # Multi-Head Attention 的创建方法
             self.d_model, 
             3*self.d_model,                        # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
             device=device
         )                                            # (d_model, 3 * d_model)
         self.attn_fn=scaled_multihead_dot_product_attention
         self.out_proj=nn.Linear(
             self.d_model, 
             self.d_model, 
             device=device
         )
 
     defforward(
         self,
         x
     ):
         """
         forward func.
 
         Args:
             x (tensor): (batch, hidden_state, d_model) e.g. -> (1, 768, 512)
 
         Returns:
             _type_: _description_
         """
         qkv=self.Wqkv(x)                            # (1, 768, 3 * 768)
 
         query, key, value=qkv.chunk(                # 每个 tensor 都是 (1, 512, 768)
             3, 
             dim=2
         )     
 
         context, attn_weights, past_key_value=self.attn_fn(
             query,
             key,
             value,
             self.n_heads
         )                                             # (1, 512, 768)
 
         returnself.out_proj(context), attn_weights, past_key_value
 
 
 classMultiQueryAttention(nn.Module):
     """Multi-Query self attention.
 
     Using torch or triton attention implemetation enables user to also use
     additive bias.
     """
 
     def__init__(
         self,
         d_model: int,
         n_heads: int,
         device: Optional[str] =None,
     ):
         super().__init__()
 
         self.d_model=d_model
         self.n_heads=n_heads
         self.head_dim=d_model//n_heads
 
         self.Wqkv=nn.Linear(                           # Multi-Query Attention 的创建方法
             d_model,
             d_model+2*self.head_dim,                 # 只创建 query 的 head 向量,所以只有 1 个 d_model
             device=device,                               # 而 key 和 value 则只共享各自的一个 head_dim 的向量
         )
 
         self.attn_fn=scaled_multihead_dot_product_attention
         self.out_proj=nn.Linear(
             self.d_model, 
             self.d_model, 
             device=device
         )
         self.out_proj._is_residual=True  # type: ignore
 
     defforward(
         self,
         x,
     ):
         qkv=self.Wqkv(x)                                           # (1, 512, 960)
 
         query, key, value=qkv.split(                               # query -> (1, 512, 768)
             [self.d_model, self.head_dim, self.head_dim],            # key   -> (1, 512, 96)
             dim=2                                                    # value -> (1, 512, 96)
         )
 
         context, attn_weights, past_key_value=self.attn_fn(
             query,
             key,
             value,
             self.n_heads,
             multiquery=True,
         )
 
         returnself.out_proj(context), attn_weights, past_key_value

从代码中可以看到所有 头之间共享一份 key 和 value 的参数,但是如何将这 1 份参数同时让 8 个头都使用呢?

代码里使用矩阵乘法 matmul 来广播,使得每个头都乘以这同一个 tensor,以此来实现参数共享,主要是这个函数:scaled_multihead_dot_product_attention

 defscaled_multihead_dot_product_attention(
         query,
         key,
         value,
         n_heads,
         past_key_value=None,
         softmax_scale=None,
         attn_bias=None,
         key_padding_mask=None,
         is_causal=False,
         dropout_p=0.0,
         training=False,
         needs_weights=False,
         multiquery=False,
     ):
     q=rearrange(query, 'b s (h d) -> b h s d', h=n_heads)         # (1, 512, 768) -> (1, 8, 512, 96)
     kv_n_heads=1ifmultiqueryelsen_heads
     k=rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)        # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery 
                                                                     # (1, 512, 96) -> (1, 1, 96, 512)  if multiquery
     v=rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)      # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery 
                                                                     # (1, 512, 96) -> (1, 1, 512, 96)  if multiquery
     
     attn_weight=q.matmul(k) *softmax_scale                       # (1, 8, 512, 512)
     attn_weight=torch.softmax(attn_weight, dim=-1)                # (1, 8, 512, 512)
 
     out=attn_weight.matmul(v)                                     # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)
     out=rearrange(out, 'b h s d -> b s (h d)')                    # (1, 512, 768)
 
     returnout, attn_weight, past_key_value

MQA指标测试

MQA能在多大程度上提高速度?让我们看看原文中提供的结果图表:

从上表可以看出,MQA在编码器上的速度提升不是很显著,但在解码器上的速度提升是相当显著的。

论文中也有关于质量的实验,结果表明MQA的性能与基线相比只是稍微低一些。降低应该是肯定的因为毕竟共享了参数,但是只要再可接受范围内并且能够大量提升速度这个降低就是可以接受的,对吧。

为什么MQA可以实现推理加速?

在MQA中,键张量和值张量的大小分别为b * k和b * v,而在MHA中,键张量和值张量的大小分别为b * h * k和b * h * v,其中h表示头的个数。

MQA通过以下方法实现推理加速:

1、KV缓存大小减少了h(头数量),这意味着需要存储在GPU内存中的张量也减少了。节省的空间可以用来增加批大小,从而提高效率。

2、减少了从内存中读取的数据量,从而减少了计算单元的等待时间,提高了计算利用率。

3、MQA有一个相对较小的KV数量,可以放入缓存(SRAM)中。MHA则需要较大的KV数量,不能完全存储在缓存中,需要从GPU内存(DRAM)读取,这很耗时。

总结

MQA是在2019年提出的,当时的应用还没有那么广泛。这是因为以前的模型不需要关心这些方面,例如,LSTM只需要维护一个状态,而不需要保留任何缓存。

当transformer最初被提出时,它主要用于Seq2Seq任务,特别是在Encoder-Decoder模型中。由于模型的规模不是很大,也并且没有太多的实际需求,所以MQA并没有引起太多的关注。

直到近年来(尤其是2023年开始)基于transformer的大型语言模型(如GPT)得到广泛应用后,推理的瓶颈才被人们重视。所以MQA才被发现非常有用,这主要是由于对大规模gpt式生成模型的实际需求。

最后我们再回顾以下这个论文:

https://avoid.overfit.cn/post/877de0f5a56d478d8133d75a05064e7e

作者:Florian June

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/436184.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

利用GPT开发应用001:GPT基础知识及LLM发展

文章目录 一、惊艳的GPT二、大语言模型LLMs三、自然语言处理NLP四、大语言模型LLM发展 一、惊艳的GPT 想象一下,您可以与计算机的交流速度与与朋友交流一样快。那会是什么样子?您可以创建哪些应用程序?这正是OpenAI正在助力构建的世界&#x…

Ethersacn的交易数据是什么样的(2)

分析 Raw Transanction RLP(Recursive Length Prefix)是一种以太坊中用于序列化数据的编码方式。它被用于将各种数据结构转换为二进制格式,以便在以太坊中传输和存储。RLP 是一种递归的编码方式,允许对复杂的数据结构进行编码。所…

word如何实现不同章节显示不同页眉

一、问题描述 写论文时遇到如下情形,第二章页眉跟第一章一样,如下图 二、解决方法 在第二章前一页空白处,选择依次布局→分隔符→下一页,如下图 双击第二章页眉,进入页眉编辑状态,点击链接到前一节按钮&a…

SOC设计:关于时钟门控的细节

有如下几个信号 输入信号 1、同步后的rstnsync_clk 2、时钟:clk 3、test_mode 4、软件控制信号:clk_sub_en 输出信号 1、clk_sub 功能:软件配置的使能信号clk_sub_en经过时钟clk 2拍同步处理后产生clk 域下的enable信号,然…

2024年腾讯云服务器99元一年,最新价格整理

腾讯云服务器99元一年是真的吗?真的,只是又降价了,现在只要61元一年,配置为2核2G3M轻量应用服务器,40GB SSD盘,腾讯云百科txybk.com分享腾讯云官方活动购买链接 https://curl.qcloud.com/oRMoSucP 活动打开…

Python编程实验六:面向对象应用

目录 一、实验目的与要求 二、实验内容 三、主要程序清单和程序运行结果 第1题 第2题 四、实验结果分析与体会 一、实验目的与要求 (1)通过本次实验,学生应掌握类的定义与对象的创建、类的继承与方法的覆盖; (2…

鸿道Intewell-Win_V2.1.3_kyland软件版本发布说明

一、软件发布版本信息 版本号:V2.1.3_kyland 版本发布类型:trail试用版本 二、版本特点 适配 E211-1370(J6412,8GB,256GB SSD)设备 三、运行环境推荐 Intewell developer可以运行在windows7及windows10 64位 四、支…

程序员书单推荐:从入门到精通的必读之作

在程序员的职业生涯中,阅读技术书籍是不断学习和提升自我的重要途径。本文将为你推荐一系列从入门到精通的程序员书单,帮助你系统地掌握编程知识、提高技能水平,并在职业生涯中取得更大的进步。 一、入门篇 《Head First C语言》&#xff1…

掌握流量主变现秘诀!视频号”今日话题”赛道,详解保姆式教学一体化实操玩法,助你轻松驾驭!

其实,这个领域的制作相当简单。 只需按照下面我提供的教程操作,基本上十分钟内就能完成一个视频。 掌握流量主变现秘诀!视频号”今日话题”赛道,详解保姆式教学一体化实操玩法,助你轻松驾驭! 就收益而言,…

何为时间复杂度和空间复杂度

时间复杂度和空间复杂度是用来评估算法性能的两个重要指标。 1. **时间复杂度**: - 时间复杂度描述了算法执行所需的时间量随输入数据规模的增加而增加的趋势。通常用大O符号(O)表示,表示算法的渐近上界。例如,O(n…

STM32(8)NVIC编程

中断源由部分片上外设产生 在misc.h中找,杂项 配置NVIC GPIO和AFIO不能产生中断源,但能通过EXTI,由EXTI产生中断源 NVIC不需要开启时钟,因为NVIC模块位于内核内部,芯片一上电就能工作。 中断响应函数 中断向量表在启…

rtthread stm32h743的使用(七)dac设备使用

我们要在rtthread studio 开发环境中建立stm32h743xih6芯片的工程。我们使用一块stm32h743及fpga的核心板完成相关实验,核心板如图: 1.我们还是先建立工程 2.生成工程后打开mx进行配置,时钟配置如前所讲,不在赘述 3.更改mx文件…

观其大略之HybridCLR学习笔记

问题背景 1 现有热更方案的开发效率、性能没有到达极限,还有提升的空间 2 ios多平台政策导致热更新受限问题,ios禁止jit。根据我查找的资料,ios的代码段启动的时候就确定了,不能增加新的代码段。IOS封了内存(或者堆&…

你不得不知道的Python AI库

Python是人工智能(AI)和机器学习(ML)领域中使用最广泛的编程语言之一,拥有丰富的库支持各种AI和ML任务。本文介绍一些经典的Python AI库。 1. NumPy 简介:NumPy(Numerical Python)…

开源工业软件:SCADA系统开源

PyScada是一个开源的scada系统 源代码地址 http://www.gitpp.com/huangtomy/pyscada-cn SCADA系统是Supervisory Control And Data Acquisition的缩写,即数据采集与监视控制系统。它是以计算机为基础的DCS与电力自动化监控系统,应用领域非常广&#x…

LeetCode.2917. 找出数组中的 K-or 值

题目 2917. 找出数组中的 K-or 值 分析 这道题其实是要我们求第i位二进制为1的元素个数至少为k,把符合条件的2^i全部加到一起。 因此,我们的思路就是枚举数组的每一位,并且进行以下两个步骤: 统计所有元素第i位1的个数cnt。…

哪个职业是科学育婴的好帮手?3月7日蚂蚁新村今日答案:育婴师

蚂蚁新村是一个虚拟社区。在这个虚拟社区中,用户可以参与各种活动,比如生产能量豆、做慈善捐赠等。同时,蚂蚁新村也提供了一些知识问答环节,用户在参与的过程中可以增进知识。这些问答内容往往涉及广泛的主题,如文化、…

Docker本地部署Redis容器结合内网穿透实现无公网ip远程连接

文章目录 前言1. 安装Docker步骤2. 使用docker拉取redis镜像3. 启动redis容器4. 本地连接测试4.1 安装redis图形化界面工具4.2 使用RDM连接测试 5. 公网远程访问本地redis5.1 内网穿透工具安装5.2 创建远程连接公网地址5.3 使用固定TCP地址远程访问 前言 本文主要介绍如何在Ub…

Vue项目实战--空间论坛(1)

环境准备 安装好node.js,Vue后 添加插件 router---路由,多页面的应用 vuex---在多个组件之间维护同一个数据 添加依赖 bootstrap---美工 popperjs/core vue项目介绍 views-----对应vue文件,页面 router-----路由,页面,c…

【深度学习笔记】优化算法——随机梯度下降

随机梯度下降 在前面的章节中,我们一直在训练过程中使用随机梯度下降,但没有解释它为什么起作用。为了澄清这一点,我们刚在 :numref:sec_gd中描述了梯度下降的基本原则。本节继续更详细地说明随机梯度下降(stochastic gradient d…
最新文章