模型端侧适配技能之ONNX 模型拆分

📅 2026/7/3 18:01:55 👁️ 阅读次数 📝 编程学习
模型端侧适配技能之ONNX 模型拆分

ONNX 模型拆分完整技巧(工具、场景、坑点、示例代码)

拆分 ONNX ≠ 单纯提速,拆分是手段,提速只是其中一种场景,很多时候拆分反而会变慢。

一、需要拆分 ONNX 模型的核心场景

拆分模型不是单纯为了提速,是解决各类工程痛点的手段,分为性能优化类工程兼容类业务调试类三大场景:

(一)拆分后可提升推理速度的场景

  1. 异构硬件流水线并行设备同时具备 GPU/CPU/NPU/DLA/DSP 多种算力单元,将预处理、主干网络、NMS 后处理分配给不同硬件异步并行执行,硬件资源不闲置,整体推理耗时下降。
  2. 规避算子不兼容,减少 CPU 降级推理TensorRT、RKNN、Tengine 等端侧加速引擎对 Loop、If、复杂索引、动态 NMS 等算子支持差;拆分后把不兼容算子单独切出交由 CPU 运行,主干网络完整走硬件加速,避免整张模型全部软推理。
  3. 多任务多头按需推理,减少无效计算一体化多任务模型(检测 + 分割 + 分类)每次推理会执行全部分支;拆分后可按需只运行需要的子模型,跳过无关分支,降低计算量。
  4. 超大模型分片,解决显存 / 内存溢出BEV、大分割模型一次性加载权重会出现 OOM;拆分后分段加载、推理后释放权重,保证模型正常运行,间接避免卡顿、崩溃带来的效率损失。

(二)拆分不提速,仅解决工程 / 调试问题的场景

  1. 精度问题定位调试单独提取中间特征子图,对比框架导出与 ONNX 中间输出,快速定位误差层。
  2. 业务模块化复用主干网络通用,检测头、分割头可单独替换、迭代,无需重新导出完整模型。
  3. 分层交付与权限隔离模型主干、后处理分开交付,区分不同模块使用权限。
  4. 自定义算子改造单独切出目标算子所在子图,单独替换为 CUDA / 硬件自定义算子。
  5. 多工具链适配主干用 TensorRT 编译加速,后处理用 ONNXRuntime 运行,两类工具无法同时处理一张完整模型。

(三)不建议拆分的场景(拆分反而减速)

仅单 GPU 硬件、拆分后串行运行多个子模型,会产生多重损耗:多次创建推理会话、GPU/CPU 数据来回拷贝、算子融合优化碎片化、权重重复加载,推理耗时上升 20%~50%。

二、4 种主流拆分工具对比

工具优势适用场景局限
onnx.utils.extract_model(onnx 原生)零额外依赖、极简 API简单直线图、无分支 / 残差残差跳跃连接、循环子图易报错
onnx-graphsurgeon(TensorRT 配套)图编辑能力最强,可改节点、重连线、处理残差复杂网络、残差、多分支、跨硬件拆分需额外安装,仅 Python
onnxslim轻量化、支持 CLI 命令行、一键提取子图快速裁剪、删除冗余输出复杂图修改能力弱于 graphsurgeon
框架侧预拆分(PyTorch/Paddle 导出时分段)无 ONNX 图修复、张量名天然对齐模型还未导出 onnx 阶段拆分已有的 onnx 文件无法使用

三、基础拆分方法 1:原生 onnx extract_model(最简单)

原理

指定子图输入张量名子图输出张量名,自动提取中间子图,自动携带所需权重 initializerONNX。

示例代码

import onnx # 原始模型、输出子模型 src_onnx = "full_model.onnx" sub_onnx = "backbone_sub.onnx" # 1. 用Netron打开onnx,找到分割边界张量名 # 例如:原图输入是["images"], 分割点张量为"feat_out" sub_inputs = ["images"] sub_outputs = ["feat_out"] # 提取子模型 onnx.utils.extract_model( src_onnx, sub_onnx, input_names=sub_inputs, output_names=sub_outputs ) # 校验子图合法性 model = onnx.load(sub_onnx) onnx.checker.check_model(model) print("子模型拆分完成")

适用限制

  • 不能切割If/Loop等带子图的控制流算子;
  • 残差跳跃连接跨分割边界会丢失张量,直接报错;
  • 仅适合单向无分支的简单网络(分类主干)。

四、基础拆分方法 2:OnnxSlim(命令行 + Python,快速裁剪)

CLI 一行拆分(推荐快速调试)

# 只保留从images输入到feat_out输出的子图 onnxslim full.onnx sub.onnx --inputs images --outputs feat_out # 只删除多余输出,保留原图输入 onnxslim full.onnx head.onnx --outputs det_out0,det_out1

Python API

import onnxslim onnxslim.slim("full.onnx", "sub.onnx", inputs=["images"], outputs=["feat_out"])

五、高级拆分方法 3:onnx-graphsurgeon(复杂网络首选,处理残差 / 多分支)

核心优势

手动遍历张量、重定向输入输出、清理孤立节点、完美兼容残差、跳跃连接、多分支网络,解决 extract_model 残差报错问题。

实战代码:切分 Backbone 与检测头

import onnx import onnx_graphsurgeon as gs # 1. 加载图 model = onnx.load("yolov8_full.onnx") graph = gs.import_onnx(model) tensors = graph.tensors() # 2. 定义分割边界张量(Netron查看) split_tensor = tensors["backbone_feat"] # 主干输出,头分支输入 origin_input = tensors["images"] # ========== 拆分1:主干子图 ========== # 新图输入:原图输入;新图输出:分割张量 graph_backbone = graph.copy() graph_backbone.inputs = [origin_input] graph_backbone.outputs = [split_tensor] # 清理无用节点、权重 graph_backbone.cleanup().toposort() onnx.save(gs.export_onnx(graph_backbone), "backbone.onnx") # ========== 拆分2:检测头子图 ========== graph_head = graph.copy() # 头模型输入改为分割张量,输出保留原图所有输出 graph_head.inputs = [split_tensor] graph_head.cleanup().toposort() onnx.save(gs.export_onnx(graph_head), "det_head.onnx")

关键技巧:处理跨分支残差

  1. 拆分前先graph.cleanup()移除冗余 Identity;
  2. 所有跳跃连接张量不能跨分割边界,分割线必须放在残差 Add 之后;
  3. 多输出场景:将所有分支末端张量统一设为子图 output。

六、框架侧预拆分(导出前拆分,最优方案)

在 PyTorch 导出 ONNX 前直接拆分子网络,张量名天然对齐,无图修复问题,推荐新项目使用。

import torch from torchvision import models model = models.resnet50(pretrained=True).eval() dummy = torch.randn(1,3,640,640) # 拆分1:主干 backbone = torch.nn.Sequential(*list(model.children())[:-2]) torch.onnx.export( backbone, dummy, "backbone.onnx", input_names=["img"], output_names=["feat"], opset_version=17 ) # 拆分2:分类头 feat_dummy = torch.randn(1,2048,20,20) head = torch.nn.Sequential(model.avgpool, model.fc) torch.onnx.export( head, feat_dummy, "cls_head.onnx", input_names=["feat"], output_names=["pred"], opset_version=17 )

七、工程通用拆分技巧

1. 分割线选择黄金规则

  1. 禁止切残差 / 跳跃连接中间:分割点放在 Add/Concat 之后,不要切 Shortcut 支路;
  2. 避开控制流算子内部:If、Loop、Scan 的子图不能被分割线截断;
  3. 边界张量尽量选 Identity 输出:Netron 中插入 Identity 节点作为分割标记,方便定位;
  4. 多头模型统一在分支起点分割:多检测头、多分割头从主干输出处一刀切。

2. 拆分前预处理(大幅降低报错)

  1. 简化模型:onnxsim full.onnx sim.onnx,消除冗余 Reshape、Identity;
  2. 推理 shape 推导:onnx.shape_inference.infer_shapes(model),子图 shape 校验;
  3. 外部权重分离:超大模型拆分前导出外部数据,避免 onnx 文件过大:
    from onnx.external_data_helper import convert_model_to_external_data model = onnx.load("big.onnx") convert_model_to_external_data(model, location="weights.bin") onnx.save(model, "big_split.onnx")

3. 拆分后校验三板斧

# 1. 格式合法性校验 onnx.checker.check_model(sub_model) # 2. 推理输出对齐(原始图vs子图拼接输出) import onnxruntime as ort # 原图推理 ori_sess = ort.InferenceSession("full.onnx") ori_out = ori_sess.run(None, {"images": rand_input}) # 子图串联推理 b_sess = ort.InferenceSession("backbone.onnx") feat = b_sess.run(None, {"images": rand_input})[0] h_sess = ort.InferenceSession("head.onnx") sub_out = h_sess.run(None, {"backbone_feat": feat}) # 输出误差校验 import numpy as np print(np.allclose(ori_out[0], sub_out[0], atol=1e-5))

4. 特殊场景拆分方案

场景 A:前后处理拆分(CPU 预处理,GPU 推理)
  • 预处理(Resize/Normalize/Transpose)拆为独立子模型,CPU 执行;
  • 推理主干 GPU 执行;
  • 后处理 NMS/ArgMax 单独拆分,DLA/CPU 执行;
场景 B:多输出多头拆分
# extract_model同时提取多个输出,拆出单头 onnx.utils.extract_model( "full.onnx", "seg_head.onnx", input_names=["feat"], output_names=["seg_out"] # 只保留分割输出,丢弃检测输出 )
场景 C:算子不兼容拆分(如 TensorRT 不支持循环)
  1. Netron 定位不支持算子的输入输出张量;
  2. 分割线放在该算子前后,拆出不兼容子图;
  3. 子图用 ONNXRuntime/CPU 自定义算子推理,其余用 TensorRT。

八、高频报错与解决方案

  1. extract_model 报错:张量不存在原因:分割线跨残差支路,shortcut 张量未包含在子图; 解决:改用 onnx-graphsurgeon 复制完整图后裁剪,或移动分割点到 Add 后。
  2. 子图推理 shape 不匹配原因:未做 shape 推理、动态维度冲突; 解决:拆分前执行infer_shapes,固定输入 shape 或统一 dynamic 维度。
  3. 拆分后权重丢失原因:initializer 仅被分支使用,裁剪时被清理; 解决:使用graph.copy()完整复制原图再裁剪,不要直接修改原图。
  4. Loop/If 子图拆分失败官方限制:分割线不能切割控制流子图,需将整个 Loop 作为一个完整子图拆分。

九、工具选型速记

  1. 简单直线网络、快速调试 →onnx.utils.extract_model/onnxslim
  2. ResNet、YOLO、U-Net 带残差 / 多分支 →onnx-graphsurgeon
  3. 模型还未导出、原生 PyTorch/Paddle → 框架内分段导出(最优)
  4. 命令行批量拆分脚本 → onnxslim CLI

以上内容,要是对您有用,请给个赞和关注,感谢您的支持。