CANNBot: RoPE预计算参考

📅 2026/7/4 9:24:20 👁️ 阅读次数 📝 编程学习
CANNBot: RoPE预计算参考

RotaryEmbedding 预计算与调用模式参考

【免费下载链接】cannbot-skillsCANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体,本仓库为其提供可复用的 Skills 模块。项目地址: https://gitcode.com/cann/cannbot-skills

目的

本文档用于指导 RoPE 前置改造,收敛出更成熟的RotaryEmbedding模块化写法和调用方式。

适用场景

当目标算子或目标链路需要 RoPE 相关前置改造时,可优先参考本文档。例如:

  • 需要将分散的freqs_cis/cos/sin计算收敛成统一模块
  • 需要为npu_mla_prolog_v3npu_kv_rmsnorm_rope_cache_v2npu_fused_infer_attention_score_v2等路径准备稳定的cos/sinquery_rope / key_rope
  • 需要把Prefill / Decode的 RoPE 取值逻辑从各层局部切片改成模型级统一准备
  • 需要为PA / 非PA路径整理更清晰的 RoPE metadata 组织方式

推荐模式

优先参考deepseek-v3.2-expdeepseek-r1等成熟实现,并采用下面的共同思路:

  1. 将 RoPE 的预计算、缓存和取值逻辑收敛到统一位置,不再分散到各层或多个辅助函数中;实现位置可以是独立类、公共模块或模型内部统一入口
  2. 在模型初始化阶段完成第一次cos/sincache 预计算,并用register_buffer保存
  3. 在模型级或统一入口按当前阶段和位置语义取出实际需要的cos/sin
  4. 下游 Attention / MLA 层只消费已经准备好的cos/sincos_sin,避免每层各自从整表切片
  5. 若存在Prefill / DecodePA / 非PABNSD / TND等差异,应在统一入口收敛这些分支,而不是散落到每一层
  6. 若后续链路需要query_rope / key_rope或直接传rope_sin / rope_cos给融合算子,应在统一入口先整理好 RoPE 输入,再传入下游模块

仓库成熟参考

  • cann-recipes-infer/models/deepseek-v3.2-exp/models/modules.pycann-recipes-infer/models/deepseek-v3.2-exp/models/modeling_deepseek.py
    • 参考RotaryEmbedding类、_init_rope()、模型级统一取值
  • cann-recipes-infer/models/deepseek_r1/models/modules.pycann-recipes-infer/models/deepseek_r1/models/modeling_deepseek.py
    • 参考Prefill / DecodePA / 非PABNSD / TND等分支处理方式

最小样例

下面的样例保留成熟模式的核心结构,适合作为前置改造模板:

class RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings, base=10000.0): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._set_cos_sin_cache(max_position_embeddings, dtype=torch.get_default_dtype()) def _set_cos_sin_cache(self, seq_len, dtype): t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, hidden_states, position_ids): cos = torch.index_select(self.cos_cached, 0, position_ids.view(-1)).unsqueeze(1).unsqueeze(1) sin = torch.index_select(self.sin_cached, 0, position_ids.view(-1)).unsqueeze(1).unsqueeze(1) return cos.to(hidden_states.dtype), sin.to(hidden_states.dtype)
def _init_rope(self): self.rotary_emb = RotaryEmbedding( dim=self.config.qk_rope_head_dim, max_position_embeddings=self.config.max_position_embeddings, base=self.config.rope_theta, )
def forward(self, hidden_states, position_ids, ...): cos_sin = self.rotary_emb(hidden_states, position_ids) for layer in self.layers: hidden_states = layer(hidden_states, cos_sin=cos_sin, ...) return hidden_states

使用边界

本文档只用于指导 RoPE 预计算与调用路径改造,不负责:

  • 选择具体融合算子
  • 展开完整 Attention 替换流程
  • 定义query_rope / key_rope的最终算子接口细节

这些内容仍以主 skill 和对应算子文档为准。

【免费下载链接】cannbot-skillsCANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体,本仓库为其提供可复用的 Skills 模块。项目地址: https://gitcode.com/cann/cannbot-skills

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考