CANNOpsTransformer融合因果一维卷积

📅 2026/7/4 21:24:06 👁️ 阅读次数 📝 编程学习
CANNOpsTransformer融合因果一维卷积

FusedCausalConv1d

【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer

产品支持情况

产品是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品×
Atlas A2 训练系列产品/Atlas A2 推理系列产品×
Atlas 200I/500 A2 推理产品×
Atlas 推理系列产品×
Atlas 训练系列产品×

功能说明

  • 算子功能:对序列执行因果一维卷积,沿序列维度使用缓存数据(长度为卷积核宽减1)对各序列头部进行padding,确保输出依赖当前及历史输入;卷积完成后,将当前序列尾部的数据(长度为卷积核宽减1)更新到缓存;在因果一维卷积输出的基础上,将原始输入加到输出上以实现残差连接。

  • 本算子支持以下场景:

    • 场景一(prefill场景):

      x: [cu_seq_len, dim] weight: [K, dim],其中K=3 conv_states: [-1, K-1, dim] query_start_loc: [batch+1] cache_indices: [batch] initial_state_mode: [batch] bias: [dim](无作用) num_accepted_tokens: [batch](无作用) y: [cu_seq_len, dim] run_mode: 0

      其中cu_seq_len为batch内所有变长序列拼接后的总长度。

    • 场景二(decode场景 - 变长序列):

      x: [cu_seq_len, dim] weight: [K, dim],其中K=3 conv_states: [-1, state_len, dim] query_start_loc: [batch+1] cache_indices: [batch] initial_state_mode: [batch] bias: [dim](无作用) num_accepted_tokens: [batch](用于投机解码) y: [cu_seq_len, dim] run_mode: 1

      其中state_len必须大于所有batch中最大的token个数加1。

    • 场景三(decode场景 - 固定batch):

      x: [batch, m+1, dim] weight: [K, dim],其中K=3 conv_states: [-1, K-1+m, dim] query_start_loc: [batch+1](无作用) cache_indices: [batch] initial_state_mode: [batch] bias: [dim](无作用) num_accepted_tokens: [batch](用于投机解码,m为投机token个数) y: [batch, m+1, dim] run_mode: 1
  • 计算公式:

    K是卷积核宽度(固定为3),L是原始序列长度,dim是特征维度。

    1. 缓存拼接:

    $$ x'[i, dim] = \begin{cases} cacheState[i, dim], & 0 \leq i < K-1 \ x[i - (K-1), dim], & K-1 \leq i < L + K - 1 \end{cases} $$

    1. 因果1维卷积:

    $$ y[i, dim] = \sum_{k=0}^{K-1} w[k, dim] \cdot x'[i + k, dim] $$

    1. 缓存更新:

    $$ cacheState[i, dim] = x'[L + i, dim], \quad i = 0, 1, \dots, K-2 $$

    1. 残差连接(可选):

    $$ y[i, dim] += x[i, dim] $$

参数说明

参数名输入/输出/属性描述数据类型数据格式
x输入输入序列,对应公式中x。FLOAT16、BFLOAT16ND
weight输入因果1维卷积核,K固定为3,对应公式中w。数据类型与x一致ND
conv_states输入/输出缓存状态张量,存储各序列的历史token数据,各序列计算完成后原地更新,对应公式中cacheState。数据类型与x一致ND
query_start_loc可选输入序列起始位置索引,记录各序列在拼接张量x中的起始位置。query_start_loc[i]表示第i个序列的起始偏移。INT32ND
cache_indices可选输入缓存索引,指定每个序列对应的缓存状态在conv_states中的索引。INT32ND
initial_state_mode可选输入初始状态标志,表示各序列是否使用缓存数据:0=零填充,1=使用缓存,2=使用缓存但前K-1个输出置0。INT32ND
bias可选输入卷积的偏置。数据类型与x一致ND
num_accepted_tokens可选输入decode场景下的投机token个数。INT32ND
activation_mode属性激活函数类型,取值为0、1、2。
0:None;
1:silu;
2:swish。
INT-
pad_slot_id属性用于跳过不需要参与计算的batch,-1表示不跳过。当cache_indices[i]==pad_slot_id时跳过该batch。INT-
run_mode属性用于判断是prefill场景或decode场景,取值为0、1。
0:prefill场景;
1:decode场景。
INT-
residual_connection属性是否做残差连接,取值为0、1。
0:不做残差连接;
1:输出y和输入x相加后输出。
INT-
y输出输出序列,shape与x一致,对应公式中y。数据类型与x一致ND

约束说明

  • 输入shape限制:

    • prefill场景:
      • x支持2维[cu_seq_len, dim]。
      • weight必须是2维[K, dim],其中K固定为3。
      • conv_states必须是3维[..., K-1, dim],第0维大小不固定且大于等于batch。
      • cu_seq_len范围[batch, 65536],dim范围[128, 16384]且是128的倍数,batch范围[1, 256]。
    • decode场景(固定batch):
      • x支持3维[batch, m+1, dim]。
      • weight必须是2维[K, dim],其中K固定为3。
      • conv_states必须是3维[..., K-1+m, dim],第0维大小不固定且大于等于batch。
      • m范围[0, 5],dim范围[128, 16384]且是128的倍数,batch范围[1, 256]。
    • decode场景(变长序列):
      • x支持2维[cu_seq_len, dim]。
      • weight必须是2维[K, dim],其中K固定为3。
      • conv_states必须是3维[..., state_len, dim],第0维大小不固定且大于等于batch,state_len必须大于所有batch中最大的token个数加K-1。
      • cu_seq_len范围[batch, batch*6],每个batch的token个数范围为[1, 6]。dim范围[128, 16384]且是128的倍数,batch范围[1, 256]。
  • 输入值域限制:

    • query_start_loc是累计偏移量,取值范围[0, cu_seq_len],长度为batch+1,query_start_loc[i]表示第i个序列的起始偏移,query_start_loc[batch+1]表示最后一个序列的结束位置。
    • cache_indices长度为batch,指定每个序列对应的缓存槽索引。
    • num_accepted_tokens分为None和非None,非None情况下长度为batch,每个元素取值不超过当前batch的token个数且大于0。

    调用说明

    调用方式样例代码说明
    aclnn接口test_aclnn_fused_causal_conv1d通过aclnnFusedCausalConv1d调用FusedCausalConv1d算子

【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考