【LLM加速】注意力优化(基于位置/内容的稀疏注意力 | flashattention)

note

(1)近似注意力:

  • Routing Transformer采用K-means 聚类方法,针对Query和Key进行聚类,类中心向量集合为 { μ i } i = 1 k \left\{\boldsymbol{\mu}_i\right\}_{i=1}^k {μi}i=1k ,其中k 是类中心的个数。每个Query 只与其处在相同簇 (Cluster) 下的Key 进行交互。
  • Reformer 则采用局部敏感哈希 (Local-Sensitive Hashing,LSH) 的方法为每个Query 选择Key-Value 对。其主要思想是使用LSH 函数对Query 和Key 进行哈希计算,将它们划分到多个桶内,以提升在同一个桶内的Query 和Key 参与交互的概率。

(2)在Transformer 结构中,自注意力机制的时间和存储复杂度与序列的长度呈平方的关系,因此占用了大量的计算设备内存并消耗了大量的计算资源。如何优化自注意力机制的时空复杂度、增强计算效率是大语言模型面临的重要问题。

  • 方法一:从近似注意力出发,旨在减少注意力计算和内存需求,提出了稀疏近似、低秩近似等方法。
  • 方法二:从计算加速设备本身的特性出发,研究如何更好地利用硬件特性对Transformer 中的注意力层进行高效计算。

(3)FlashAttention目标是尽可能高效地使用SRAM来加快计算速度,避免从全局内存中读取和写入注意力矩阵。达成该目标需要能做到在不访问整个输入的情况下计算Softmax函数,并且后向传播中不能存储中间注意力矩阵。

文章目录

  • note
  • 一、近似注意力
    • 1. 基于位置的稀疏注意力机制
    • 2. 基于内容的稀疏注意力机制
      • (1)Routing Transformer:使用聚类
      • (2)Reformer:使用LSH哈希
  • 二、计算加速
    • 1. GPU硬件基础知识
    • 2. flashattention
    • 3. 多查询注意力MQA
      • (1)MHA和MQA的区别
      • (2)MHA和MQA的具体代码
      • (3)使用矩阵乘法matmul广播实现参数共享
      • (4)tgi框架中的MQA
  • Reference

一、近似注意力

对一些训练好的Transformer 结构中的注意力矩阵进行分析时发现,其中很多是稀疏的,因此可以通过限制Query-Key 对的数量来降低计算复杂度。这类方法称为稀疏注意力(SparseAttention)机制。可以将稀疏化方法进一步分成基于位置的和基于内容信息的两类。

1. 基于位置的稀疏注意力机制

基于位置的稀疏注意力机制的基本类型如下图,主要包含如下五种类型:全局注意力(Global Attention)、带状注意力(Band Attention)、膨胀注意力(Dilated Attention)、随机注意力(Random Attention)、局部块注意力(Block Local Attention)。

这些注意力机制的区别主要在于它们如何选择序列中的元素来计算注意力权重,这直接影响计算复杂度、处理长距离依赖的能力以及对不同类型任务的适用性。每种注意力机制的关键区别和特点:

  1. 全局注意力(Global Attention):

    • 关键特点:在计算每个位置的注意力时,考虑序列中的所有其他位置。
    • 优点:能够捕获全局依赖性,理论上可以处理任意距离的关系。
    • 缺点:计算复杂度高,随序列长度的平方增长,不适合处理长序列。
  2. 带状注意力(Band Attention):

    • 关键特点:仅在每个位置的一个固定宽度的带内计算注意力权重,通常集中在序列的对角线附近。
    • 优点:减少了计算量,适合捕获局部依赖性。
    • 缺点:可能忽略重要的长距离依赖。
  3. 膨胀注意力(Dilated Attention):

    • 关键特点:通过引入膨胀因子来间隔地选择序列中的元素进行注意力计算,从而覆盖更广的范围。和CNN中的Dilated Conv类似,通过增加空隙以获取更大的感受野
    • 优点:在降低计算复杂度的同时,能够捕获更远的依赖性。
    • 缺点:可能不如全局注意力在捕捉所有长距离依赖上有效。
  4. 随机注意力(Random Attention):

    • 关键特点:随机选择序列中的位置来计算注意力权重。即通过随机采样,提升非局部的交互。
    • 优点:显著降低计算需求,引入随机性可能帮助模型探索更多的依赖关系。
    • 缺点:随机性可能导致忽略一些关键的依赖关系。
  5. 局部块注意力(Block Local Attention):

    • 关键特点:将序列分割成多个块,在这些局部块内计算注意力权重。使用多个不重叠的块Block来限制信息交互。
    • 优点:大幅降低计算复杂度,适合处理长序列。
    • 缺点:如果不允许跨块计算,则可能忽略块间的依赖关系。

总结来说,这些注意力机制通过不同的策略平衡计算复杂度和模型的捕获依赖能力。选择哪种注意力机制取决于特定任务的需求,例如处理长序列数据时可能更倾向于使用带状、膨胀、随机或局部块注意力机制,而在不那么受限于计算资源的情况下,全局注意力可能是最好的选择,因为它能够捕获全局依赖性。

在这里插入图片描述
下面给出带状注意力的栗子:

# query-shape: [bs, seq_len, emb_dim]
def band_attention(query, key, value, band_width):
    """
    Args:
        query, key, value: standard attention inputs
        band_width: The width of the band around the diagonal to compute attention.
    Returns:
        Tensor: The output of the attention mechanism.
    """
    batch_size, seq_len, d_k = query.size()
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

    # Create a mask to zero out attention scores outside the band
    idxs = torch.arange(seq_len).unsqueeze(0).to(query.device)
    mask = (idxs - idxs.transpose(0, 1)).abs().ge(band_width).to(scores.dtype)
    scores.masked_fill_(mask, float('-inf'))

    attention = F.softmax(scores, dim=-1)
    output = torch.matmul(attention, value)
    return output

# 测试的case
def band_attention_test():
    import torch
    # 假设输入数据的维度
    batch_size = 2
    seq_length = 10
    embed_size = 128
    heads = 8
    # 生成随机数据作为输入
    values = torch.rand(batch_size, seq_length, embed_size)
    keys = torch.rand(batch_size, seq_length, embed_size)
    queries = torch.rand(batch_size, seq_length, embed_size)
    # 定义带宽
    band_width = 3
    # 使用相同的随机数据输入
    band_attention_output = band_attention(queries, keys, values, band_width)
    # Band Attention Output Shape: torch.Size([2, 10, 128])
    print("Band Attention Output Shape:", band_attention_output.shape)

可以看到上面的mask矩阵确实是带状的:
在这里插入图片描述

现有的稀疏注意力机制,通常是基于上述五种基于位置的稀疏注意力机制的复合模式,下图给出了一些典型的稀疏注意力模型:

  • star-transformer:使用带状注意力和全局注意力的组合,只包括一个全局注意力节点和宽度为3的带状注意力,其中任意两个非相邻节点通过一个共享的全局注意力连接,而相邻节点则直接相连。
  • longformer:将上层中的一些带状注意力头部替换为具有扩张窗口的注意力,在增加感受野同时不增加计算量
  • ETC(Extended Transformer Construction):利用带状注意力和外部全局节点注意力(External Global-node Attention)的组合。ETC 稀疏注意力还包括一种掩码机制来处理结构化输入,并采用对比预测编码(Contrastive Predictive Coding,CPC)进行预训练。
  • BigBird:使用带状和全局注意力,还使用额外的随机注意力来近似全连接注意力,此外还揭示了稀疏编码器和稀疏解码器的使用可以模拟任何图灵机

在这里插入图片描述

2. 基于内容的稀疏注意力机制

基于内容的稀疏注意力机制根据输入数据创建稀疏注意力,其中一种很简单的方法是选择和给定查询 (Query) 有很高相似度的键 (Key)。

(1)Routing Transformer:使用聚类

(1)Routing Transformer采用K-means 聚类方法,针对Query和Key进行聚类,类中心向量集合为 { μ i } i = 1 k \left\{\boldsymbol{\mu}_i\right\}_{i=1}^k {μi}i=1k ,其中k 是类中心的个数。每个Query 只与其处在相同簇 (Cluster) 下的Key 进行交互。中心向量采用滑动平均的方法进行更新:
μ ~ ← μ ~ + ( 1 − λ ) ( ∑ i : μ ( q i ) = μ q i + ∑ j : μ ( k j ) = μ k j ) c μ ← λ c μ + ( 1 − λ ) ∣ μ ∣ μ ← μ ~ c μ \begin{gathered} \widetilde{\boldsymbol{\mu}} \leftarrow \tilde{\boldsymbol{\mu}}+(1-\lambda)\left(\sum_{i: \mu\left(\boldsymbol{q}_i\right)=\mu} \boldsymbol{q}_i+\sum_{j: \mu\left(\boldsymbol{k}_j\right)=\mu} \boldsymbol{k}_j\right) \\ c_\mu \leftarrow \lambda c_\mu+(1-\lambda)|\mu| \\ \mu \leftarrow \frac{\widetilde{\boldsymbol{\mu}}}{c_\mu} \end{gathered} μ μ~+(1λ) i:μ(qi)=μqi+j:μ(kj)=μkj cμλcμ+(1λ)μμcμμ

(2)Reformer:使用LSH哈希

(2)Reformer 则采用局部敏感哈希 (Local-Sensitive Hashing,LSH) 的方法为每个Query 选择Key-Value 对。其主要思想是使用LSH 函数对Query 和Key 进行哈希计算,将它们划分到多个桶内,以提升在同一个桶内的Query 和Key 参与交互的概率。假设 b b b 是桶的个数,给定一个大小为 [ D k , b / 2 ] [D k , b / 2] [Dkb/2] 的随机矩阵 R R R , LSH 函数的定义为:
h ( x ) = arg ⁡ max ⁡ ( [ x R ; − x R ] ) h(\boldsymbol{x})=\arg \max ([\boldsymbol{x} R ;-\boldsymbol{x} R]) h(x)=argmax([xR;xR])

如果 h q i = h k j h \boldsymbol{q}_i=h \boldsymbol{k}_j \quad hqi=hkj 时, q i \boldsymbol{q}_i qi 才可以与相应的Key-Value对进行交互。

二、计算加速

1. GPU硬件基础知识

NVIDIA GPU中的内存(显存)按照它们物理上是在GPU芯片内部还是板卡RAM存储芯片上,决定了它们的速度、大小以及访问限制。GPU显存分为:

  • 全局内存(Global memory)
  • 本地内存(Local memory)
  • 共享内存(Shared memory,SRAM)
  • 寄存器内存(Register memory)
  • 常量内存(Constant memory)
  • 纹理内存(Texture memory)

在这里插入图片描述

全局内存和本地内存使用的高带宽显存(High Bandwidth Memory,HBM)位于板卡RAM存储芯片上,该部分内存容量很大。全局内存是所有线程都可以访问,而本地内存则只能当前线程访问。NVIDIA H100中全局内存有80GB空间,其访问速度虽然可以达到3.35TB/s,但是如果全部线程同时访问全局内存时,其平均带宽仍然很低

共享内存和寄存器位于GPU芯片上,因此容量很小,并且共享内存只有在同一个GPU线程块(Thread Block)内的线程才可以共享访问,而寄存器仅限于同一个线程内部才能访问。NVIDIA H100中每个GPU线程块在流式多处理器(Stream Multi-processor,SM)可以使用的共享存储容量仅有228KB,但是其速度非常快,远高于全局内存的访问速度。

根据自注意力机制的原理,在GPU中进行计算时,传统的方法还需要引入两个中间矩阵 S 和 P 并存储到全局内存中。具体计算过程如下:
S = Q × K , P = Softmax ⁡ ( S ) , O = P × V \boldsymbol{S}=\boldsymbol{Q} \times \boldsymbol{K}, \boldsymbol{P}=\operatorname{Softmax}(\boldsymbol{S}), \boldsymbol{O}=\boldsymbol{P} \times \boldsymbol{V} S=Q×K,P=Softmax(S),O=P×V

按照上述计算过程,需要:

  • 首先从全局内存中读取矩阵 Q Q Q K K K ,并将计算好的矩阵 S S S再写入全局内存
  • 之后再从全局内存中获取矩阵 S S S ,计算Softmax得到矩阵 P P P 再写入全局内存
  • 之后读取矩阵 P P P 和矩阵 V V V ,计算得到矩阵 O O O

这样的过程会极大占用显存的带宽。在自注意力机制中,计算速度比内存速度快得多 ,因此计算效率越来越多地受到全局内存访问的瓶颈。

2. flashattention

FlashAttention就是通过利用GPU硬件中的特殊设计,针对全局内存和共享存储的I/O速度的不同,尽可能地避免HBM中读取或写入注意力矩阵。

FlashAttention目标是尽可能高效地使用SRAM来加快计算速度,避免从全局内存中读取和写入注意力矩阵。达成该目标需要能做到在不访问整个输入的情况下计算Softmax函数,并且后向传播中不能存储中间注意力矩阵

在这里插入图片描述

FlashAttention 就提出了不使用中间注意力矩阵,通过存储归一化因子来减少全局内存消耗的方法。

FlashAttention 算法并没有将S、P 整体写入全局内存,而是通过分块写入,存储前向传递的Softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从全局内存中读取中间注意力矩阵的标准方法更快。

虽然大幅减少了全局内存的访问量,重新计算也导致FLOP(FLOPS指标,Floating Point Operations per Second 指每秒浮点运算次数) 增加,但其运行的速度更快且使用的内存更少。

在这里插入图片描述

3. 多查询注意力MQA

多查询注意力(Multi Query Attention)是多头注意力的一种变体。其特点是,在多查询注意力中不同的注意力头共享一个键和值的集合,每个头只单独保留了一份查询参数,因此键和值的矩阵仅有一份,这大幅减少了显存占用,使其更高效。
在这里插入图片描述

由于多查询注意力改变了注意力机制的结构,因此模型通常需要从训练开始就支持多查询注意力。文献研究结果表明,可以通过对已经训练好的模型进行微调来添加多查询注意力支持,仅需要约5% 的原始训练数据量就可以达到不错的效果。

包括Falcon[64]、SantaCoder[65]、StarCoder[66] 在内的很多模型都采用了多查询注意力机制。

(1)MHA和MQA的区别

MHA 和 MQA 之间的区别主要在于建立 Wqkv Layer 上(如下代码)。在MQA中,除了query向量还保存8个头,key和value向量都只剩下1个【公共头】,即前面说的所有head之间共享一份key和value参数。

# Multi Head Attention
self.Wqkv = nn.Linear(                        # 【关键】Multi-Head Attention 的创建方法
    self.d_model, 
    3 * self.d_model,                         # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
    device=device
)
query, key, value = qkv.chunk(                # 【关键】每个 tensor 都是 (1, 512, 768)
    3, 
    dim=2
)
  
# Multi Query Attention
self.Wqkv = nn.Linear(                                # 【关键】Multi-Query Attention 的创建方法
    d_model,
    d_model + 2 * self.head_dim,                      # 只创建 query 的 head 向量,所以只有 1 个 d_model
    device=device,                                    # 而 key 和 value 不再具备单独的头向量
)
query, key, value = qkv.split(                        # query -> (1, 512, 768)
    [self.d_model, self.head_dim, self.head_dim],     # key   -> (1, 512, 96)
    dim=2                                             # value -> (1, 512, 96)
)

(2)MHA和MQA的具体代码

其中MultiheadAttentionMultiQueryAttention类完整的代码如下。

class MultiheadAttention(nn.Module):
 
     def __init__(
             self,
             d_model: int,
             n_heads: int,
             device: str
        ):
         """
        Multi Head init func.
 
        Args:
            d_model (int): hidden state size, e.g. 768
            n_heads (int): 设定的注意力头数, e.g. 8
            device (str): _description_
        """
         super().__init__()
 
         self.d_model = d_model
         self.n_heads = n_heads
     
         self.Wqkv = nn.Linear(                       # Multi-Head Attention 的创建方法
             self.d_model,
             3 * self.d_model,                        # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
             device=device
        )                                            # (d_model, 3 * d_model)
         self.attn_fn = scaled_multihead_dot_product_attention
         self.out_proj = nn.Linear(
             self.d_model,
             self.d_model,
             device=device
        )
 
     def forward(
         self,
         x
    ):
         """
        forward func.
 
        Args:
            x (tensor): (batch, hidden_state, d_model) e.g. -> (1, 768, 512)
 
        Returns:
            _type_: _description_
        """
         qkv = self.Wqkv(x)                            # (1, 768, 3 * 768)
 
         query, key, value = qkv.chunk(                # 每个 tensor 都是 (1, 512, 768)
             3,
             dim=2
        )    
 
         context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
             value,
             self.n_heads
        )                                             # (1, 512, 768)
 
         return self.out_proj(context), attn_weights, past_key_value
 
 
 class MultiQueryAttention(nn.Module):
     """Multi-Query self attention.
 
    Using torch or triton attention implemetation enables user to also use
    additive bias.
    """
 
     def __init__(
         self,
         d_model: int,
         n_heads: int,
         device: Optional[str] = None,
    ):
         super().__init__()
 
         self.d_model = d_model
         self.n_heads = n_heads
         self.head_dim = d_model // n_heads
 
         self.Wqkv = nn.Linear(                           # Multi-Query Attention 的创建方法
             d_model,
             d_model + 2 * self.head_dim,                 # 只创建 query 的 head 向量,所以只有 1 个 d_model
             device=device,                               # 而 key 和 value 则只共享各自的一个 head_dim 的向量
        )
 
         self.attn_fn = scaled_multihead_dot_product_attention
         self.out_proj = nn.Linear(
             self.d_model,
             self.d_model,
             device=device
        )
         self.out_proj._is_residual = True  # type: ignore
 
     def forward(
         self,
         x,
    ):
         qkv = self.Wqkv(x)                                           # (1, 512, 960)
 
         query, key, value = qkv.split(                               # query -> (1, 512, 768)
            [self.d_model, self.head_dim, self.head_dim],             # key   -> (1, 512, 96)
             dim=2                                                    # value -> (1, 512, 96)
        )
 
         context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
             value,
             self.n_heads,
             multiquery=True,
        )
 
         return self.out_proj(context), attn_weights, past_key_value

(1)初始化函数 __init__

  • __init__(self, d_model: int, n_heads: int, device: Optional[str] = None): 这是类的初始化函数,用于创建类的实例时初始化其属性。它接受三个参数:模型的维度 d_model、注意力头的数量 n_heads,以及设备 device(可选),用于指定模块运行的硬件(CPU或GPU)。
  • self.d_model = d_modelself.n_heads = n_heads: 这两行代码将传入的模型维度和头的数量保存为类的属性。
  • self.head_dim = d_model // n_heads: 计算每个头的维度,即将模型维度均分到每个头上。
  • self.Wgkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device): 创建一个线性层 Wgkv,用于生成查询(Q)、键(K)和值(V)。这个线性层的输出维度是 d_model + 2 * self.head_dim,意味着查询的维度保持为 d_model,而键和值的维度为 self.head_dim。这种设计减少了模型参数,因为它没有为键和值分别创建额外的线性变换。
  • self.attn_fn = scaled_multihead_dot_product_attentionself.out_proj = nn.Linear(self.d_model, self.d_model, device=device): 定义了一个注意力函数 attn_fn 和一个输出投影层 out_projattn_fn 负责计算多头点积注意力,而 out_proj 用于将注意力机制的输出转换回原始输入的维度。

(2)前向传播函数 forward

  • def forward(self, X): 定义了前向传播函数,它接收一个输入张量 X
  • gkv = self.Wgkv(X): 首先,输入通过 Wgkv 线性层,产生了合并的查询、键、值矩阵。
  • query, key, value = gkv.split([self.d_model, self.head_dim, self.head_dim], dim=2): 然后,将 gkv 拆分为查询、键和值三部分。注意拆分的维度与 Wgkv 层的输出设计相匹配。
  • context, attn_weights, past_key_value = self.attn_fn(query, key, value, self.n_heads, multiquery=True): 使用定义的注意力函数计算注意力,multiquery=True 参数指示使用多查询注意力机制。
  • return self.out_proj(context), attn_weights, past_key_value: 最后,将注意力的输出通过 out_proj 投影层,然后将结果、注意力权重和过去的键值对返回。

(3)使用矩阵乘法matmul广播实现参数共享

其中注意上面的scaled_multihead_dot_product_attention函数就是实现刚才说的一份key和value参数让多个头使用,使用矩阵乘法matmul进行广播,实现参数共享。

def scaled_multihead_dot_product_attention(
        query,
        key,
        value,
        n_heads,
        multiquery=False,
    ):
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)         # (1, 512, 768) -> (1, 8, 512, 96)
    kv_n_heads = 1 if multiquery else n_heads
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)        # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery 
                                                                    # (1, 512, 96) -> (1, 1, 96, 512)  if multiquery
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)      # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery 
                                                                    # (1, 512, 96) -> (1, 1, 512, 96)  if multiquery
    
    attn_weight = q.matmul(k) * softmax_scale                       # (1, 8, 512, 512)
    attn_weight = torch.softmax(attn_weight, dim=-1)                # (1, 8, 512, 512)
 
    out = attn_weight.matmul(v)                                     # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)
    out = rearrange(out, 'b h s d -> b s (h d)')                    # (1, 512, 768)
 
    return out, attn_weight, past_key_value

(4)tgi框架中的MQA

具体还可以参考tgi框架中的MQA代码:

class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        attn_impl = config.attn_config["attn_impl"]
        self.attn_impl = config.attn_config["attn_impl"]
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.d_model = config.d_model
        d_model = config.d_model
        self.n_heads = config.n_heads
        self.softmax_scale = config.attn_config["softmax_scale"]
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = config.attn_config["attn_pdrop"]
        # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
        self.Wqkv = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
        )
        fuse_splits = (d_model, d_model + self.head_dim)
        if self.qk_ln:
            raise NotImplementedError("qk_ln not supported")
        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    "While `attn_impl: triton` can be faster than `attn_impl: flash` "
                    + "it uses more memory. When training larger models this can trigger "
                    + "alloc retries which hurts performance. If encountered, we recommend "
                    + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
                )
        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    "Using `attn_impl: torch`. If your model does not use `alibi` or "
                    + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
                    + "we recommend using `attn_impl: triton`."
                )
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=not config.no_bias,
        )
        # self.out_proj._is_residual = True

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        (query, key, value) = qkv.split(
            [self.d_model, self.head_dim, self.head_dim], dim=2
        )
        key_padding_mask = attention_mask
        if self.qk_ln:
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)
        (context, attn_weights, past_key_value) = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,
        )
        return (self.out_proj(context), attn_weights, past_key_value)

Reference

[1] https://github.com/huggingface/text-generation-inference
[2] LLM 加速技巧:Muti Query Attention
[3] 训练模型算力的单位:FLOPs、FLOPS、Macs 与 估算模型(FC, CNN, LSTM, Transformers&&LLM)的FLOPs
[4] FlashAttention 的速度优化原理是怎样的?
[5] FlashAttention图解(如何加速Attention)
[6] flashattention论文:https://arxiv.org/pdf/2205.14135.pdf
[7] 全局注意力机制(global attention)详解与代码实现
[8] 每天学习一点点—大模型知识学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/464016.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

@RequestParam、@PathVariable、@RequestBody

1、中文翻译 RequestParam-请求参数、PathVariable-路径变量、RequestBody请求体 2、作用: Controller中获取前端传递的参数 3、从注解本身角度分析 3.1、PathVariable:路径变量 通过 PathVariable 可以将URL中占位符参数{xxx}绑定到处理器类的方法形…

【组合回溯】Leetcode 131. 分割回文串

【组合回溯】Leetcode 131. 分割回文串 解法 切割组合回溯 ---------------🎈🎈131. 分割回文串 题目链接🎈🎈------------------- 解法 切割组合回溯 全局变量:result存储所有path的集合,path用来记录切…

文件系统 与 软硬链接

目录 一、文件系统 认识磁盘 磁盘存储的逻辑抽象结构 块组的内容 inode Table Data blocks inode Bitmap Block Bitmap Group Descriptor Table Super Block 理解目录 二、软硬链接 软链接​ 硬链接 硬链接数 一、文件系统 之前的博客主题叫做"进程打开文…

Redisinsight默认端口改成5540了!网上的8001都是错误的

Redisinsight 打开白屏解决方法 最近发现一个很讨厌的bug,就是redisinsight运行之后,不行了,在网上找到的所有资料里面,redis insight都是运行在8001端口,但是我现在发现,变成了5540 所以对应的docker-com…

Node.js与webpack(三)

上一节:Node.js与Webpack笔记(二)-CSDN博客 从0来一遍(webpack项目) 将之前的webpack 的纯开发配置,重新创建空白项目,重新做一遍,捋一遍思路防止加入生产模式时候弄混 1.创建文件夹…

SVM-支持向量机实验分析(软硬间隔,线性核,高斯核)

目录 一、前言 二、实验 0. 导入包 1. 支持向量机带来的效果 2. 软硬间隔 3. 非线性支持向量机 4. 核函数变换 线性核 高斯核 对比不同的gamma值对结果的影响 一、前言 学习本文之前要具有SVM支持向量机的理论知识,可以参考支持向量机(Support Vector …

epoll怎么就高效了?

目录 摘要 1 举个栗子 2 从 epoll_create 开始 3 epoll_ctl,插入待监听的描述符 3.1 故事围绕 ep_item 展开 3.2 在 socket 等待队列上设置 epoll 回调 3.3 关系变得复杂 4 epoll_wait 等你 4.1 等待就绪事件 4.2 共享内存? 5 来了来了&#xf…

第 126 场 LeetCode 双周赛题解

A 求出加密整数的和 模拟 class Solution { public:int sumOfEncryptedInt(vector<int> &nums) {int res 0;for (auto x: nums) {string s to_string(x);char ch *max_element(s.begin(), s.end());for (auto &c: s)c ch;res stoi(s);}return res;} };B 执行…

Java学习笔记(15)

JDK7前时间相关类 Date时间类 Simpledateformat Format 格式化 Parse 解析 默认格式 指定格式 EE&#xff1a;表示周几 Parse&#xff1a;把字符串时间转成date对象 注意&#xff1a;创建对象的格式要和字符串的格式一样 Calendar日历类 不能创建对象 Getinstance 获取当…

8款手机宝藏APP,每款都非常强大实用!

1. 综合AI工具箱——HuluAI 综合AI工具https://h5.cxyhub.com/?invitationhmeEo7 HuluAI是一款聚合式全能AI工具&#xff0c;完美接入官方正版GPT4.0和Midjourney绘画&#xff01;。除此之外&#xff0c;还拥有文心一言语言大模型和DallE3绘图功能。经过长时间的稳定运行&am…

【数据结构】深入理解AVL树:实现和应用

AVL树是一种自平衡的二叉搜索树&#xff0c;它能够保持良好的平衡性质&#xff0c;使得在最坏情况下的时间复杂度也能保持在对数级别。本文将深入介绍AVL树的原理、实现和应用&#xff0c;并通过示例代码演示其基本操作。 文章目录 什么是AVL树&#xff1f;AVL树的实现在AVL树…

Linux - 安装 Jenkins(详细教程)

目录 前言一、简介二、安装前准备三、下载与安装四、配置镜像地址五、启动与关闭六、常用插件的安装 前言 虽然说网上有很多关于 Jenkins 安装的教程&#xff0c;但是大部分都不够详细&#xff0c;或者是需要搭配 docker 或者 k8s 等进行安装&#xff0c;对于新手小白而已&…

2024人工智能四大趋势→

2023年&#xff0c;世人见证了ChatGPT在全球范围的大火。以生成式人工智能为代表的新一代人工智能问世&#xff0c;改变了人工智能&#xff08;AI&#xff09;技术与应用的发展轨迹&#xff0c;加速了人与AI的互动进程&#xff0c;是人工智能发展史上的新里程碑。2024年&#x…

职场中的“跨界思维”:如何拓宽你的职业发展道路?

在当今职场&#xff0c;单一技能的竞争已经越来越激烈&#xff0c;具备跨界思维的人才越来越受到企业的青睐。本文将探讨职场中的“跨界思维”&#xff0c;帮助您拓宽职业发展道路&#xff0c;提升自身竞争力。 一、什么是跨界思维&#xff1f; 跨界思维&#xff0c;顾名思义&a…

【重新定义matlab强大系列十八】Matlab深度学习长短期记忆 (LSTM) 网络生成文本

&#x1f517; 运行环境&#xff1a;Matlab &#x1f6a9; 撰写作者&#xff1a;左手の明天 &#x1f947; 精选专栏&#xff1a;《python》 &#x1f525; 推荐专栏&#xff1a;《算法研究》 #### 防伪水印——左手の明天 #### &#x1f497; 大家好&#x1f917;&#x1f91…

Etcd 介绍与使用(入门篇)

etcd 介绍 etcd 简介 etc &#xff08;基于 Go 语言实现&#xff0c;&#xff09;在 Linux 系统中是配置文件目录名&#xff1b;etcd 就是配置服务&#xff1b; etcd 诞生于 CoreOS 公司&#xff0c;最初用于解决集群管理系统中 os 升级时的分布式并发控制、配置文件的存储与分…

Bean的作用域、Bean的自动装配、注解自动装配 (Spring学习笔记五)

1、Bean 的作用域 官网上显示有六种 1、Bean的作用域默认的是singleton&#xff08;单例模式的实现&#xff09; 也可以显示的设置&#xff08;单例模式的实现&#xff09; <!--用scope可以设置Bean的作用域--><bean id"user2" class"com.li.pojo.Us…

Elasticsearch从入门到精通-05ES匹配查询

Elasticsearch从入门到精通-05ES匹配查询 &#x1f44f;作者简介&#xff1a;大家好&#xff0c;我是程序员行走的鱼 &#x1f4d6; 本篇主要介绍和大家一块学习一下ES各种场景下的匹配查询,有助于我们在项目中进行综合使用 前提 创建索引并指定ik分词器: PUT /es_db {"…

[ Linux ] vim的使用(附:命令模式的常见命令列表)

1.下载安装 这里是在通过yum进行下载安装 yum install -y vim 2.了解 vim是一款编辑器&#xff0c;它具有多模式的特点 主要有&#xff1a;插入模式&#xff0c;命令模式&#xff0c;底行模式 3.使用 打开 vim 文件名 命令模式的常见命令列表 插入模式 按「 i 」切换…

建设IAM/IDM统一身份管理,实现系统之间的单点登录(SSO)

企业实施身份管理的现状&#xff1a; 1.身份存储分散&#xff0c;不能统一供应诸多应用系统&#xff0c;企业用户信息常常存在于多个系统&#xff0c;如HR系统有一套用户信息&#xff0c;OA系统也有一套用户信息&#xff0c;身份存储不集中&#xff0c;不能统一地为诸多应用系…