Flash Attention(1):背景介绍,与传统Attention对比,前向反向算法解析

0 英文缩写

  • FA: Flash Attention
  • HBM:High Bandwidth Memory,高带宽显存

0 论文

[2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

中文:FlashAttention:一种具有 IO 感知,且兼具快速、内存高效的新型注意力算法

科研团队:斯坦福大学计算机系+纽约州立大学布法罗分校

发表时间:20220527

1 背景:

  • 背景1:应用广泛:Transformer 模型在图像分类、自然语言处理等分支领域中逐渐成为最为常见的架构
  • 背景2:模型扩展:随着技术不断进步,Transformer 模型在尺寸和深度等方面都进一步拓展
  • 背景3:算法复杂度特征:核心模块自注意力机制(self attention)的时间复杂度存储复杂度,均与输入长度(一般即为处理的序列长度)的平方成正比

结合背景123,可以发现更大的模型在更长的上下文背景上还存在着一定的挑战。

  • 背景4:计算读写开销:论文GPU内不同存储系统的速度举例如下:

    • GPU SRAM 读写(I/O)速度19 TB/s
    • GPU HBM 读写(I/O)速度 1.5 TB/s

    image-20231217134356291

2 相关方案

在此背景之下,有人提出一些近似自注意力的方法,旨在减少注意力计算和内存需求。

  • 稀疏近似
  • 低秩近似
  • 它们的组合

缺点:尽管这些方法可以将计算降低到线性或接近线性,但它们过于关注降低每秒所执行的浮点运算次数(FLops),换句话说更倾向于单纯降低计算复杂度。忽略来自内存访问(IO)的开销。不能实现更高且更有实用价值的计算加速范式。

3 传统Attention

(更详细的推导过程和描述可以参考前文)

Attention机制其核心为计算输入向量的相关程度,例如在翻译过程中,不同的英文对中文的依赖程度不同,Attention机制通常可以进行如下描述

3.1 输入输出定义

  • 输入1: Q Q Q 序列(query),其中 { Q = ( q 1 q 2 q 3 ⋮ q m ) ⏟ d k } m ∈ R m × d k , q i ∈ R 1 × d k ∣ i ∈ 1 , 2 , … , m } \left\{Q=\underbrace{\left(\begin{array}{c}q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_m \end{array}\right)}_{d_{k}}\} m \in\mathbb{R}^{m\times d_k}, q_{i}\in\mathbb{R}^{1\times d_k} \mid i\in 1,2, \ldots, m\right\} Q=dk q1q2q3qm }mRm×dk,qiR1×dki1,2,,m
  • 输入2: K K K 序列 (key),其中 { K = ( k 1 k 2 k 3 ⋮ k m ) ⏟ d k } m ∈ R m × d k , k i ∈ R 1 × d k ∣ i = 1 , 2 , … , m } \left\{K=\underbrace{\left(\begin{array}{c}k_1 \\ k_2 \\ k_3 \\ \vdots \\ k_m\end{array}\right)}_{d_{k}}\} m\in\mathbb{R}^{m\times d_k}, k_{i}\in \mathbb{R}^{1\times d_k} \mid i=1,2, \ldots, m\right\} K=dk k1k2k3km }mRm×dk,kiR1×dki=1,2,,m
  • 输入3: V V V 序列 (value) ,其中 { V = ( v 1 v 2 v 3 ⋮ v m ) ⏟ d v } m ∈ R m × d v , v i ∈ R 1 × d v ∣ i = 1 , 2 , … , m } \left\{V=\underbrace{\left(\begin{array}{c}v_1 \\ v_2 \\ v_3 \\ \vdots \\ v_m\end{array}\right)}_{d_{v}}\} m\in\mathbb{R}^{m\times d_v}, v_{i}\in \mathbb{R}^{1\times d_v} \mid i=1,2, \ldots, m\right\} V=dv v1v2v3vm }mRm×dv,viR1×dvi=1,2,,m
  • 输出为$\text { Attention }(Q, K, V) $ 向量,计算公式:

 Attention  ( Q , K , V ) ∈ R m × d v = softmax ⁡ ( Q K T d k ) V \text { Attention }(Q, K, V) \in\mathbb R^{m \times d_{v}}=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V  Attention (Q,K,V)Rm×dv=softmax(dk QKT)V

3.2 算法解析

第一步:矩阵乘法

为什么可以计算得到不同输入向量之间的得分

矩阵乘法

image-20210412163054048

假设共有十个输入向量,每个向量的长度为512,也即为 m = 10 m=10 m=10 d k = 512 d_k=512 dk=512

Q = ( q 1 [ 0 ] ⋯ q 1 [ d k ] ⋮ ⋯ ⋮ q 10 [ 0 ] ⋯ q 10 [ 511 ] ) = ( q 1 ⃗ ⋮ q 10 ⃗ ) Q=\left(\begin{array}{ccc} q_{1}[0] & \cdots & q_{1}[d_k] \\ \vdots & \cdots & \vdots \\ q_{10}[0] & \cdots & q_{10}[511] \end{array}\right) = \left(\begin{array}{c}\vec{q_{1}}\\\vdots\\ \vec{q_{10}} \end{array}\right) Q= q1[0]q10[0]q1[dk]q10[511] = q1 q10

K = ( k 1 [ 0 ] ⋯ k 1 [ 511 ] ⋮ ⋯ ⋮ k 10 [ 0 ] ⋯ k 10 [ 511 ] ) = ( k 1 ⃗ ⋮ k 10 ⃗ ) K=\left(\begin{array}{ccc}k_{1}[0] & \cdots & k_{1}[511] \\\vdots & \cdots & \vdots \\k_{10}[0] & \cdots & k_{10}[511]\end{array}\right) = \left(\begin{array}{c}\vec{k_{1}}\\\vdots\\ \vec{k_{10}} \end{array}\right) K= k1[0]k10[0]k1[511]k10[511] = k1 k10

相乘结果如下
Q ⋅ K T ∈ R m × m = ( q 1 ⃗ ⋮ q 10 ⃗ ) ⋅ ( k 1 ⃗ T ⋯ k 10 ⃗ T ) ( q 1 ⃗ ⋅ k 1 ⃗ T ⋯ q 1 ⃗ ⋅ k 10 ⃗ T ⋮ ⋯ ⋮ q 10 ⃗ ⋅ k 1 ⃗ T ⋯ q 10 ⃗ ⋅ k 10 ⃗ T ) = ( s 1 − 1 ⋯ s 1 − 10 ⋮ ⋯ ⋮ s 10 − 1 ⋯ s 10 − 10 ) Q \cdot K^T \in \mathbf{R}^{m\times m}= \left(\begin{array}{c}\vec{q_{1}}\\\vdots\\ \vec{q_{10}} \end{array}\right) \cdot \left(\vec{k_{1}}^T\cdots \vec{k_{10}}^T\right) \left(\begin{array}{ccc} \vec{q_{1}}\cdot\vec{k_{1}}^T & \cdots & \vec{q_{1}}\cdot\vec{k_{10}}^T \\\vdots & \cdots & \vdots \\\vec{q_{10}}\cdot\vec{k_{1}}^T& \cdots & \vec{q_{10}}\cdot\vec{k_{10}}^T\end{array}\right) =\left(\begin{array}{ccc}s_{1-1} & \cdots & s_{1-10} \\\vdots & \cdots & \vdots \\s_{10-1} & \cdots & s_{10-10}\end{array}\right) QKTRm×m= q1 q10 (k1 Tk10 T) q1 k1 Tq10 k1 Tq1 k10 Tq10 k10 T = s11s101s110s1010

矩阵 S S S中的每一个元素通过分别来自于 Q \mathbf{Q} Q K \mathbf{K} K的两个向量的点乘得到的,通过最原始的矩阵定义,可以得知两个向量的点乘意味着一个向量在另一个向量的投影,也可以李继伟表示向量 q i ⃗ \vec{q_{i}} qi k j ⃗ \vec{k_j} kj 的相似程度

第二步:scaling与归一化

除以一个数字 d k \sqrt{d_{k}} dk 的意义是:

  • 因为如果 d k d_k dk太大,点乘的值太大,如果不做scaling,结果就没有加法注意力好。
  • 为了不让输入太大,导致softmax函数被推动到非常平缓的区域。

将得到scaling后的相似度进行Softmax操作,假定Scaling之后相似度矩阵为
( s 1 − 1 ′ ⋯ s 1 − m ′ ⋮ ⋯ ⋮ s m − 1 ′ ⋯ s m − m ′ ) = ( s 1 − 1 / d k ⋯ s 1 − m / d k ⋮ ⋯ ⋮ s m − 1 / d k ⋯ s m − m / d k ) \left(\begin{array}{ccc}s'_{1-1} & \cdots & s'_{1-m} \\\vdots & \cdots & \vdots \\ s'_{m-1} & \cdots & s'_{m-m}\end{array}\right) = \left(\begin{array}{ccc}s_{1-1}/\sqrt{d_{k}} & \cdots & s_{1-m}/\sqrt{d_{k}} \\\vdots & \cdots & \vdots \\s_{m-1}/\sqrt{d_{k}} & \cdots & s_{m-m}/\sqrt{d_{k}}\end{array}\right) s11sm1s1msmm = s11/dk sm1/dk s1m/dk smm/dk
进行归一化
( s 1 − 1 ′ ′ ⋯ s 1 − m ′ ′ ⋮ ⋯ ⋮ s m − 1 ′ ′ ⋯ s m − m ′ ) = ( e s 1 − 1 ′ ∑ i = 1 m e s 1 − i ′ ⋯ e s 1 − m ′ ∑ i = 1 m e s 1 − i ′ ⋮ ⋯ ⋮ e s m − 1 ′ ∑ i = 1 m e s m − i ′ ⋯ e s m − m ′ ∑ i = 1 m e s m − i ′ ) \left(\begin{array}{ccc}s''_{1-1} & \cdots & s''_{1-m} \\\vdots & \cdots & \vdots \\ s''_{m-1} & \cdots & s'_{m-m}\end{array}\right) = \left(\begin{array}{ccc}\frac{e^{s'_{1-1}}} {\sum_{i=1}^{m} e^{s'_{1-i}} } & \cdots & \frac{e^{s'_{1-m}}} {\sum_{i=1}^{m} e^{s'_{1-i}} } \\\vdots & \cdots & \vdots \\ \frac{e^{s'_{m-1}}} {\sum_{i=1}^{m} e^{s'_{m-i}} } & \cdots & \frac{e^{s'_{m-m}}} {\sum_{i=1}^{m} e^{s'_{m-i}} } \end{array}\right) s11′′sm1′′s1m′′smm = i=1mes1ies11i=1mesmiesm1i=1mes1ies1mi=1mesmiesmm

如此实现一横行的加权和为1,不同的 v i v_i vi 向量获得的加权综合为1

第三步:加权输出

针对计算出来的权重 α i \alpha_{i} αi,通过权重对 V V V中所有的values进行加权求和计算,得到Attention向量
( s 1 − 1 ′ ⋯ s 1 − m ′ ⋮ ⋯ ⋮ s m − 1 ′ ⋯ s m − m ′ ) ( v 1 ⃗ ⋮ v m ⃗ ) \left(\begin{array}{ccc}s'_{1-1} & \cdots & s'_{1-m} \\\vdots & \cdots & \vdots \\ s'_{m-1} & \cdots & s'_{m-m}\end{array}\right)\left(\begin{array}{c}\vec{v_{1}}\\\vdots\\ \vec{v_{m}} \end{array}\right) s11sm1s1msmm v1 vm

3.3 读写IO伪代码

#########Standard Attention Implementation
Require: Matrices Q, K, V ∈ R^{N×d} in HBM.
1: Load Q, K by blocks from HBM, compute S = QK^{T}, write S to HBM.
2: Read S from HBM, compute P = softmax(S), write P to HBM.
3: Load P and V by blocks from HBM, compute O = PV, write O to HBM.
4: Return O.

3.3 关于Attention的总结

  • 采用点乘注意力,这种注意力机制对于加法注意力而言,更快,同时更节省空间。
  • attention抽象为对value的每个表示(token)进行加权,而加权的weight就是 attention weight,而 attention weight 就是根据 querykey 计算得到,其意义为:为了用 value 求出 query 的结果, 根据 querykey 来决定注意力应该放在value的哪部分。

image-20201223152516251

4 Flash Attention

4.1 背景分析

在标准注意力实现中,注意力的性能主要受限于内存带宽,是内存受限的。频繁地从HBM中读写 R N × N \mathbb{R}^{N \times N} RN×N的矩阵是影响性能的主要瓶颈。稀疏近似和低秩近似等近似注意力方法虽然减少了计算量FLOPs,但对于内存受限的操作,运行时间的瓶颈是从HBM中读写数据的耗时,减少计算量并不能有效地减少运行时间(wall-clock time)。针对内存受限的标准注意力,Flash Attention是IO感知的,目标是避免频繁地从HBM中读写数据

4.2 解决方案

从GPU显存分级来看,SRAM的读写速度比HBM高一个数量级,但内存大小要小很多。通过kernel融合的方式,将多个操作融合为一个操作,利用高速的SRAM进行计算,可以减少读写HBM的次数,从而有效减少内存受限操作的运行时间。但SRAM的内存大小有限,不可能一次性计算完整的注意力,因此必须进行分块计算,使得分块计算需要的内存不超过SRAM的大小。

问题一:为什么要进行分块计算呢?

内存受限 --> 减少HBM读写次数 --> kernel融合 --> 满足SRAM的内存大小 --> 分块计算

因此分块大小block_size不能太大,否则会导致存储内容踢出。

问题二:分块计算的难点是什么呢?

注意力机制的计算过程是“矩阵乘法 --> scale --> mask --> softmax --> dropout --> 矩阵乘法”,矩阵乘法和逐点操作(scale,mask,dropout)的分块计算是容易实现的,难点在于softmax的分块计算。由于计算softmax的归一化因子(分母)时,需要获取到完整的输入数据,进行分块计算的难度比较大。论文中也是重点对softmax的分块计算进行了阐述。

tiling的主要思想是分块计算注意力。分块计算的难点在于softmax的分块计算,softmax与矩阵 K K K 的列是耦合的,通过引入了两个额外的统计量 m ( x ) m(x) m(x) l ( x ) l(x) l(x)来进行解耦,实现了分块计算。需要注意的是,可以利用GPU多线程同时并行计算多个block的softmax。为了充分利用硬件性能,多个block的计算不是串行(sequential)的, 而是并行的

4.3 前向算法伪代码:Softmax的IO缩减

一个简单的例子实现分块计算Softmax

对向量 A = [ 1 , 2 , 3 , 4 ] A = [1,2,3,4] A=[1,2,3,4] 计算Softmax,分成两块 A 1 = [ 1 , 2 ] A_1 = [1,2] A1=[1,2] A 2 = [ 3 , 4 ] A_2 = [3,4] A2=[3,4] 进行计算。 计算block1和block2:

block1
m 1 = m a x ( [ 1 , 2 ] ) = 2 f 1 = [ e 1 − m 1 , e 2 − m 1 ] = [ e − 1 , e 0 ] l 1 = ∑ f 1 = e − 1 + e 0 o 1 = f 1 l 1 = [ e − 1 , e 0 ] e − 1 + e 0 = [ e − 1 e − 1 + e 0 , e 0 e − 1 + e 0 ] m_1 = max([1,2]) = 2\\ f_1 = [e^{1-m_1},e^{2-m_1}] = [e^{-1},e^0]\\ l_1 = \sum f_1 = e^{-1} + e^0\\ o_1 = \frac{f_1}{l_1} = \frac{[e^{-1},e^0]}{e^{-1} + e^0} = \left[ \frac{e^{-1}}{e^{-1} + e^0}, \frac{e^0}{e^{-1} + e^0}\right] m1=max([1,2])=2f1=[e1m1,e2m1]=[e1,e0]l1=f1=e1+e0o1=l1f1=e1+e0[e1,e0]=[e1+e0e1,e1+e0e0]
block2
m 2 = m a x ( [ 3 , 4 ] ) = 4 f 2 = [ e 3 − m 2 , e 4 − m 2 ] = [ e − 1 , e 0 ] l 2 = ∑ f 2 = e − 1 + e 0 o 2 = f 2 l 2 = [ e − 1 , e 0 ] e − 1 + e 0 = [ e − 1 e − 1 + e 0 , e 0 e − 1 + e 0 ] m_2 = max([3,4]) = 4\\ f_2 = [e^{3-m_2},e^{4-m_2}] = [e^{-1},e^0]\\ l_2 = \sum f_2 = e^{-1} + e^0\\ o_2 = \frac{f_2}{l_2} = \frac{[e^{-1},e^0]}{e^{-1} + e^0} = \left[ \frac{e^{-1}}{e^{-1} + e^0}, \frac{e^0}{e^{-1} + e^0}\right] m2=max([3,4])=4f2=[e3m2,e4m2]=[e1,e0]l2=f2=e1+e0o2=l2f2=e1+e0[e1,e0]=[e1+e0e1,e1+e0e0]
合并得到完整的softmax结果:
m = m a x ( m a x 1 , m a x 2 ) = 4 f = [ e m 1 − m f 1 , e m 2 − m ∗ f 2 ] = [ e − 3 , e − 2 , e − 1 , e 0 ] l = e m 1 − m l 1 , e m 2 − m ∗ l 2 = e − 3 + e − 2 + e − 1 + e 0 o = f l = [ e − 1 , e 0 ] e − 1 + e 0 = [ e − 1 e − 1 + e 0 , e 0 e − 1 + e 0 ] m = max(max_1,max_2) = 4\\ f = \left[e^{m_1-m}f_1,e^{m_2-m}*f_2\right] = \left[e^{-3},e^{-2},e^{-1},e^0\right]\\ l = e^{m_1-m}l_1,e^{m_2-m}*l_2 = e^{-3}+e^{-2}+e^{-1}+e^0\\ o = \frac{f}{l} = \frac{[e^{-1},e^0]}{e^{-1} + e^0} = \left[ \frac{e^{-1}}{e^{-1} + e^0}, \frac{e^0}{e^{-1} + e^0}\right] m=max(max1,max2)=4f=[em1mf1,em2mf2]=[e3,e2,e1,e0]l=em1ml1,em2ml2=e3+e2+e1+e0o=lf=e1+e0[e1,e0]=[e1+e0e1,e1+e0e0]

算法伪代码

在这里插入图片描述

备注:这是在在忽略mask和dropout的情况下,简化分析Flash Attention算法的前向计算过程

作用分析:

在Flash Attention的前向计算算法中可以看出,FlashAttention实现在不访问整个输入的情况下计算softmax,实现IO的较大缩减,标准Attention算法由于要计算softmax,而softmax都是按行来计算的,即在和 V \mathbf{V} V做矩阵乘之前,需要让 Q \mathbf{Q} Q K \mathbf{K} K 的各个分块完成整一行分块的计算得到Softmax的结果后,再和矩阵 V \mathbf{V} V分块做矩阵乘。而在Flash Attention中,将输入分割成块,并在输入块上进行多次传递,从而以增量方式执行softmax缩减

4.4 后向回传伪代码

将前文的前向计算抽象成如下模型,便于后文的引用
S = τ Q K ⊤ ∈ R N × N S masked  = M A S K ( S ) ∈ R N × N P = softmax ⁡ ( S masked  ) ∈ R N × N P dropped  = dropout ⁡ ( P , p drop  ) ∈ R N × N O = P dropped  V ∈ R N × d \begin{gathered} S=\tau Q K^{\top} \in \mathbb{R}^{N \times N} \\ S^{\text {masked }}=M A S K(S) \in \mathbb{R}^{N \times N} \\ P=\operatorname{softmax}\left(S^{\text {masked }}\right) \in \mathbb{R}^{N \times N} \\ P^{\text {dropped }}=\operatorname{dropout}\left(P, p_{\text {drop }}\right) \in \mathbb{R}^{N \times N} \\ O=P^{\text {dropped }} V \in \mathbb{R}^{N \times d} \end{gathered} S=τQKRN×NSmasked =MASK(S)RN×NP=softmax(Smasked )RN×NPdropped =dropout(P,pdrop )RN×NO=Pdropped VRN×d
在标准注意力实现中,后向传递计算 Q \mathbf{Q} Q K \mathbf{K} K V \mathbf{V} V的梯度时,需要用到中间矩阵 S ∈ R N × N \mathbf{S}\in\mathbb{R}^{N\times N} SRN×N P ∈ R N × N \mathbf{P}\in\mathbb{R}^{N\times N} PRN×N。Flash Attention没有保存这两个矩阵,而是保存了两个统计量 m ( x ) m(x) m(x) l ( x ) l(x) l(x),在后向传递时进行重计算。

在反向传递过程中, 需要计算损失函数 ϕ \phi ϕ O \mathbf{O} O Q \mathbf{Q} Q K \mathbf{K} K V \mathbf{V} V 的梯度。在给定 d O ∈ R N × d d \mathbf{O} \in \mathbb{R}^{N \times d} dORN×d 的情况下, 计算梯度 d Q ∈ R N × d d\mathbf{Q}\in \mathbb{R}^{N \times d} dQRN×d d K ∈ R N × d d\mathbf{K}\in \mathbb{R}^{N \times d} dKRN×d d V ∈ R N × d d\mathbf{V} \in \mathbb{R}^{N \times d} dVRN×d 。其中, d O d\mathbf{O} dO d Q d\mathbf{Q} dQ d K d\mathbf{K} dK d V d\mathbf{V} dV 分别表示为 ∂ ϕ ∂ O \frac{\partial \phi}{\partial \mathbf{O}} Oϕ ∂ ϕ ∂ Q \frac{\partial \phi}{\partial \mathbf{Q}} Qϕ ∂ ϕ ∂ K \frac{\partial \phi}{\partial \mathbf{K}} Kϕ ∂ ϕ ∂ V \frac{\partial \phi}{\partial \mathbf{V}} Vϕ

计算 d V d\mathbf{V} dV

梯度 d V d\mathbf{V} dV 是容易计算的。由 O = P V \mathbf{O}=\mathbf{P} \mathbf{V} O=PV,基于矩阵求导算法和链式法则, 得到矩阵形式的梯度 d V = P ⊤ d O d\mathbf{V}=\mathbf{P}^{\top} d \mathbf{O} dV=PdO 。在元素形式上,有:
d v j = ∑ i P i j d o i = ∑ i e ( q i ⊤ k j ) L i d o i d \mathbf{v}_j=\sum_i \mathbf{P}_{i j} d \mathbf{o}_i=\sum_i \frac{e^{(\mathbf{q}_i^{\top} k_j)}}{L_i} d \mathbf{o}_i dvj=iPijdoi=iLie(qikj)doi
之前已经计算好 L i L_i Li,就可以通过反复累加的方式计算得到 d v j d \mathbf{v}_j dvj

计算 d Q d\mathbf{Q} dQ d K d\mathbf{K} dK

梯度 d Q d\mathbf{Q} dQ K \mathbf{K} K 的计算是略微复杂的。首先要计算 d P d\mathbf{P} dP d S d\mathbf{S} dS 。由 O = P V \mathbf{O}=\mathbf{P} \mathbf{V} O=PV,得到矩阵形式的梯度 d P = d O V ⊤ d\mathbf{P}=d\mathbf{O} \mathbf{V}^{\top} dP=dOV 。在元素形式上,有:
d P i j = d o i ⊤ v j d \mathbf{P}_{i j}=d \mathbf{o}_i^{\top} \mathbf{v}_j dPij=doivj

P i : = softmax ⁡ ( S i : ) \mathbf{P}_{i:}=\operatorname{softmax}\left(\mathbf{S}_{i:}\right) Pi:=softmax(Si:) (表示 i i i的一整行)。基于 y = softmax ⁡ ( x ) y=\operatorname{softmax}(x) y=softmax(x) 的雅各比矩阵为 diag ⁡ ( y ) − y y ⊤ \operatorname{diag}(y)-y y^{\top} diag(y)yy 。可以得到:
d S i : = ( diag ⁡ ( P i : ) − P i : P i : ⊤ ) d P i : = P i : ∘ d P i : − ( P i : ⊤ d P i : ) P i : d \mathbf{S}_{i:}=\left(\operatorname{diag}\left(\mathbf{P}_{i:}\right)-\mathbf{P}_{i:} P_{i:}^{\top}\right) d \mathbf{P}_{i:}=\mathbf{P}_{i:} \circ d \mathbf{P}_{i:}-\left(P_{i:}^{\top} d \mathbf{P}_{i:}\right) \mathbf{P}_{i:} dSi:=(diag(Pi:)Pi:Pi:)dPi:=Pi:dPi:(Pi:dPi:)Pi:

其中 ∘ \circ 表示逐点相乘。

可以定义:
D i = P i : ⊤ d P i : = ∑ j e q i ⊤ k j L i d o i ⊤ v j = d o i ⊤ ∑ j e q i ⊤ k j L i v j = d o i ⊤ o i D_i=P_{i:}^{\top} d P_{i:}=\sum_j \frac{e^{q_i^{\top} k_j}}{L_i} d o_i^{\top} v_j=d o_i^{\top} \sum_j \frac{e^{q_i^{\top} k_j}}{L_i} v_j=d o_i^{\top} o_i Di=Pi:dPi:=jLieqikjdoivj=doijLieqikjvj=doioi

将该定义代回到上式中, 可以得到:
d S i : = P i : ∘ d P i : − D i P i : d S_{i:}=P_{i:} \circ d P_{i:}-D_i P_{i:} dSi:=Pi:dPi:DiPi:
因此,梯度 d S d\mathbf{S} dS 可以表示为以下形式:
d S i j = P i j d P i j − D i P i j = P i j ( d P i j − D i ) d \mathbf{S}_{i j}=\mathbf{P}_{i j} d \mathbf{P}_{i j}-\mathbf{D}_i \mathbf{P}_{i j}=\mathbf{P}_{i j}\left(d \mathbf{P}_{i j}-\mathbf{D}_i\right) dSij=PijdPijDiPij=Pij(dPijDi)

在计算得到 d P i j d \mathbf{P}_{i j} dPij d S i j d \mathbf{S}_{i j} dSij 后, 可以计算 d Q d\mathbf{Q} dQ d K d\mathbf{K} dK 。有前向计算公式 S i j = q i ⊤ k j \mathbf{S}_{i j}=\mathbf{q}_i^{\top} \mathbf{k}_j Sij=qikj, 可以得到:
d q i = ∑ j d S i j k j = ∑ j P i j ( d P i j − D i ) k j = ∑ j e ( q i ⊤ k j ) L i ( d o i ⊤ v j − D i ) k j d k j = ∑ i d S i j q i = ∑ i P i j ( d P i j − D i ) q i = ∑ i e ( q i ⊤ k j ) L i ( d o i ⊤ v j − D i ) q i \begin{gathered} d \mathbf{q}_i=\sum_j d \mathbf{S}_{i j} \mathbf{k}_j=\sum_j \mathbf{P}_{i j}\left(d \mathbf{P}_{i j}-\mathbf{D}_i\right) \mathbf{k}_j=\sum_j \frac{e^{(\mathbf{q}_i^{\top} \mathbf{k}_j)}}{\mathbf{L}_i}\left(d \mathbf{o}_i^{\top} \mathbf{v}_j-\mathbf{D}_i\right) \mathbf{k}_j \\ d \mathbf{k}_j=\sum_i d \mathbf{S}_{i j} \mathbf{q}_i=\sum_i \mathbf{P}_{i j}\left(d \mathbf{P}_{i j}-\mathbf{D}_i\right) \mathbf{q}_i=\sum_i \frac{e^{(\mathbf{q}_i^{\top} \mathbf{k}_j)}}{\mathbf{L}_i}\left(d \mathbf{o}_i^{\top} \mathbf{v}_j-\mathbf{D}_i\right) \mathbf{q}_i \end{gathered} dqi=jdSijkj=jPij(dPijDi)kj=jLie(qikj)(doivjDi)kjdkj=idSijqi=iPij(dPijDi)qi=iLie(qikj)(doivjDi)qi

与前向计算类似,在计算得到 L i \mathbf{L}_i Li 后, 就可以通过反复累加的方式计算得到 d q i d \mathbf{q}_i dqi d k j d \mathbf{k}_j dkj d v j d \mathbf{v}_j dvj 。避免了实例化矩阵 P \mathbf{P} P S \mathbf{S} S,节省了显存,后向传递的显存复杂度为 O ( N ) O(N) O(N)

作用分析

对比标准Attention算法的实现过程中,其需要将计算中的 S \mathbf{S} S P \mathbf{P} P写入到HBM中,而这些中间矩阵的大小与输入的序列长度有关且为二次型;

Flash Attention算法中,其并没有将 S \mathbf{S} S P \mathbf{P} P写入HBM中去,而是通过分块写入到HBM中去,存储前向传递的 softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从HBM中读取中间注意力矩阵的标准方法更快。即使由于重新计算导致 FLOPS 增加,但其运行速度更快并且使用更少的内存(序列长度线性),主要是因为大大减少了 HBM 访问量。

Flash Attention实现了不使用中间注意力矩阵,通过存储归一化因子来减少HBM内存的消耗。

5 总结

  • FA尽可能避免从HBM中读取和写入注意力矩阵,做到了:
  1. 在不访问整个输入的情况下计算softmax函数的IO缩减;
  2. 在后向传播中不存储中间注意力矩阵
  • 通过减少GPU内存读取/写入,FlashAttention的运行速度比PyTorch标准注意力快 2-4 倍,所需内存减少5-20倍。

6 参考文献

[2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

论文分享:新型注意力算法FlashAttention - 知乎

FlashAttention:加速计算,节省显存, IO感知的精确注意力 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/257988.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【jvm从入门到实战】(九) 垃圾回收(2)-垃圾回收器

垃圾回收器是垃圾回收算法的具体实现。 由于垃圾回收器分为年轻代和老年代,除了G1之外其他垃圾回收器必须成对组合进行使用 垃圾回收器的组合使用关系图如下。 常用的组合如下: Serial(新生代) Serial Old(老年代) Pa…

【每日一题】【12.19】1901.寻找峰值Ⅱ

🔥博客主页: A_SHOWY🎥系列专栏:力扣刷题总结录 数据结构 云计算 数字图像处理 力扣每日一题_ 1.题目链接 1901. 寻找峰值 IIhttps://leetcode.cn/problems/find-a-peak-element-ii/ 2.题目描述 看到这个时间复杂度就知道和昨…

关于折线回归

一、说明 今天的帖子主要是关于使用折线回归找到最佳值。即将某条曲线分解成包络线段,然后用分段回归方式优化。但它也涉及使用 SAS 和 R 的剂量反应研究和样条曲线。这不是第一篇关于这些主题的文章,但我确实想在其中添加折线。只是因为它还在使用。 二…

借助dayjs,把各种类型的日期转换成“YYYY-MM-DD“格式

记得先 npm install datajs <template><div class"home"></div> </template> <script lang"ts" setup> import { reactive, ref } from "vue";import dayjs from "dayjs"; import customParseFormat f…

助力智能人群检测计数,基于YOLOv7开发构建通用场景下人群检测计数识别系统

在一些人流量比较大的场合&#xff0c;或者是一些特殊时刻、时段、节假日等特殊时期下&#xff0c;密切关注当前系统所承载的人流量是十分必要的&#xff0c;对于超出系统负荷容量的情况做到及时预警对于管理团队来说是保障人员安全的重要手段&#xff0c;本文的主要目的是想要…

LED恒流调节器FP7126:引领LED照明和调光的新时代(调光电源、汽车大灯)

目录 一、FP7126概述 二、FP7126功能 三、应用领域 随着科技的进步&#xff0c;LED照明成为了当代照明产业的主力军。而在LED照明的核心技术中&#xff0c;恒流调节器是不可或缺的组成部分。今天&#xff0c;我将为大家介绍一款重要的恒流调节器FP7126&#xff0c;适用于LED…

Axure的案例演示

增删改查&#xff1a; 在中继器里面展示照片

App(Android)ICP备案号查询——————高仿微信

&#x1f604; 个人主页&#xff1a;✨拉莫帅-CSDN博客✨&#x1f914; 博文&#xff1a;132篇&#x1f525; 原创&#xff1a;130篇&#xff0c;转载&#xff1a;2篇&#x1f525; 总阅读量&#xff1a;388923❤️ 粉丝量&#xff1a;112&#x1f341; 感谢点赞和关注 &#x…

什么是集成测试?它和系统测试的区别是什么? 操作方法来了

01 什么是集成测试&#xff1f; 集成测试是软件测试的一种方法&#xff0c;用于测试不同的软件模块之间的交互和协作是否正常。集成测试的主要目的是确保不同的软件模块能够无缝协作&#xff0c;形成一个完整的软件系统&#xff0c;并且能够满足系统的需求和规格。 在集成测试…

Qt Q_DECL_OVERRIDE

Q_DECL_OVERRIDE也就是C的override&#xff08;重写函数&#xff09;&#xff0c;其目的就是为了防止写错虚函数,在重写虚函数时需要用到。 /* 鼠标按下事件 */ void mousePressEvent(QMouseEvent *event) Q_DECL_OVERRIDE; 参考: Qt Q_DECL_OVERRIDE - 一杯清酒邀明月 - 博客…

Android Studio问题解决:Gradle Download 下载超时 Connect reset

文章目录 一、遇到问题二、解决办法 一、遇到问题 Gradle Download下载超时Sync了很多次&#xff0c;一直失败 二、解决办法 手动通过gradle网站下载 https://gradle.org/releases/可能也会出现超时&#xff0c;最好开个VPN软件会比较快。 下载好的软件&#xff0c;放到本机的…

管理类联考——数学——真题篇——按题型分类——充分性判断题——蒙猜D

先看目录&#xff0c;除了2018年比较怪&#xff0c;其他最多2个D&#xff08;数学只有两个弟弟&#xff0c;一个大弟&#xff0c;一个小弟&#xff09; 文章目录 2023真题&#xff08;2023-16&#xff09;-D 2022真题&#xff08;2022-21&#xff09;-D-分析选项⇒是否等价⇒是…

使用极狐gitlab初始化导入本地项目

本地有项目的情况需要同步到极狐gitlab上 第一步&#xff1a; 在gitlab上新创建一个空项目 ⚠️⚠️⚠️这里需要注意红色圈住的地方一定不要选择&#xff0c;因为选择了这个后续会有不必要的麻烦 第二步 在本地项目中删除原来的.git文件(这一步如果是新项目可以忽略&#…

扑克牌炸金花

1.创建类 使用权限修饰符定义所需要参数&#xff0c;使用this关键字生成方法 public class gamejinhua { private String suit;//花色 private int rank;//数字 public gamejinhua(String suit, int rank) { this.suit suit; this.rank rank; } 2.使用快捷键生成get和…

静态库和动态库

静态库 编译&#xff08;链接&#xff09;时把静态库中相关代码复制到可执行文件中&#xff0c;程序中已包含代码&#xff0c;运行时不再需要静态库 占用更多磁盘和内存空间&#xff0c;但程序运行时无需加载库&#xff0c;运行速度快 升级时&#xff0c;程序需要重新编译链…

WPF仿网易云搭建笔记(7):HandyControl重构

文章目录 专栏和Gitee仓库前言相关文章 新建项目项目环境项目结构 代码结果结尾 专栏和Gitee仓库 WPF仿网易云 Gitee仓库 WPF仿网易云 CSDN博客专栏 前言 最近我发现Material Design UI的功能比较简单&#xff0c;想实现一些比较简单的功能&#xff0c;比如消息提示&#xff0…

电脑风扇控制软件Macs Fan Control mac支持多个型号

Macs Fan Control mac是一款专门为 Mac 用户设计的软件&#xff0c;它可以帮助用户控制和监控 Mac 设备的风扇速度和温度。这款软件允许用户手动调整风扇速度&#xff0c;以提高设备的散热效果&#xff0c;减少过热造成的风险。 Macs Fan Control 可以在菜单栏上显示当前系统温…

HTS318 红外热释传感器处理芯片 PIR控制芯片 用于红外感应灯、走廊灯等

HTS318是一颗高度集成的用于热释电红外传感器 (PIR) 的控制芯片。HTS318单片集成了热释电被动红外移动探测的所有必需组件模拟前端可以直接与模拟型PIR探测器使用电容连接&#xff0c;内置3V LDO&#xff0c;给PIR探测器供电。内置高精度模数转换器&#xff0c;可将探测器信号转…

漏刻有时数据可视化Echarts组件开发(45)机场流程导航线和指示点的开发记录

路径线 ECharts中的路径线是指用于连接起点和终点的线。在ECharts中&#xff0c;路径图主要用于带有起点和终点信息的线数据的绘制&#xff0c;如地图上的航班、路线等。路径线可以用于展示数据点之间的连接关系&#xff0c;以及数据点之间的相对位置。 {//路径图name: 路线图…

宣布推出 ML.NET 3.0

作者&#xff1a;Jeff Handley 排版&#xff1a;Alan Wang ML.NET 是面向 .NET 开发人员的开源、跨平台的机器学习框架&#xff0c;可将自定义机器学习模型集成到 .NET 应用程序中。ML.NET 3.0 版本现已发布&#xff0c;其中包含大量新功能和增强功能&#xff01; 此版本中的深…
最新文章