使用Pytorch导出自定义ONNX算子

在实际部署模型时有时可能会遇到想用的算子无法导出onnx,但实际部署的框架是支持该算子的。此时可以通过自定义onnx算子的方式导出onnx模型(注:自定义onnx算子导出onnx模型后是无法使用onnxruntime推理的)。下面给出个具体应用中的示例:需要导出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又无法正常导出该算子,故可通过如下自定义算子代码导出。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypes


class CustomAffineGrid(Function):
    @staticmethod
    def forward(ctx, theta: torch.Tensor, size: torch.Tensor):
        grid = F.affine_grid(theta=theta, size=size.cpu().tolist())
        return grid

    @staticmethod
    def symbolic(g: torch.Graph, theta: torch.Tensor, size: torch.Tensor):
        return g.op("AffineGrid", theta, size)


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, theta: torch.Tensor, size: torch.Tensor):
        grid = CustomAffineGrid.apply(theta, size)
        x = F.grid_sample(x, grid=grid, mode="bilinear", padding_mode="zeros")
        return x


def main():
    with torch.inference_mode():
        custum_model = MyModel()
        x = torch.randn(1, 3, 224, 224)
        theta = torch.randn(1, 2, 3)
        size = torch.as_tensor([1, 3, 512, 512])

        torch.onnx.export(model=custum_model,
                          args=(x, theta, size),
                          f="custom.onnx",
                          input_names=["input0_x", "input1_theta", "input2_size"],
                          output_names=["output"],
                          dynamic_axes={"input0_x": {2: "h0", 3: "w0"},
                                        "output": {2: "h1", 3: "w1"}},
                          opset_version=16,
                          operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)


if __name__ == '__main__':
    main()

在上面代码中,通过继承torch.autograd.Function父类的方式实现导出自定义算子,继承该父类后需要用户自己实现forward以及symbolic两个静态方法,其中forward方法是在pytorch正常推理时调用的函数,而symbolic方法是在导出onnx时调用的函数。对于forward方法需要按照正常的pytorch语法来实现,其中第一个参数必须是ctx但对于当前导出onnx场景可以不用管它,后面的参数是实际自己传入的参数。对于symbolic方法的第一个必须是g,后面的参数任为实际自己传入的参数,然后通过g.op方法指定具体导出自定义算子的名称,以及输入的参数(注:上面示例中传入的都是Tensor所以可以直接传入,对与非Tensor的参数可见下面一个示例)。最后在使用时直接调用自己实现类的apply方法即可。使用netron打开自己导出的onnx文件,可以看到如下所示网络结构。
在这里插入图片描述

有时按照使用的推理框架导出自定义算子时还需要设置一些参数(非Tensor)那么可以参考如下示例,例如要导出int型的参数k那么可以通过传入k_i来指定,要导出float型的参数scale那么可以通过传入scale_f来指定,要导出string型的参数clockwise那么可以通过传入clockwise_s来指定:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypes


class CustomRot90AndScale(Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        x = torch.rot90(x, k=1, dims=(3, 2))  # clockwise 90
        x *= 1.2
        return x

    @staticmethod
    def symbolic(g: torch.Graph, x: torch.Tensor):
        return g.op("Rot90AndScale", x, k_i=1, scale_f=1.2, clockwise_s="yes")


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        return CustomRot90AndScale.apply(x)


def main():
    with torch.inference_mode():
        custum_model = MyModel()
        x = torch.randn(1, 3, 224, 224)

        torch.onnx.export(model=custum_model,
                          args=(x,),
                          f="custom_rot90.onnx",
                          input_names=["input"],
                          output_names=["output"],
                          dynamic_axes={"input": {2: "h0", 3: "w0"},
                                        "output": {2: "w0", 3: "h0"}},
                          opset_version=16,
                          operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)


if __name__ == '__main__':
    main()

使用netron打开自己导出的onnx文件,可以看到如下所示信息。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/439189.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

米酒生产加工污水处理需要哪些工艺设备

米酒生产加工过程中产生的污水是一项重要的环境问题,需要采用适当的工艺设备进行处理。下面将介绍一些常用的污水处理工艺设备。 首先,生产过程中的污水需要进行初级处理,常见的设备包括格栅和砂池。格栅用于去除污水中的大颗粒杂质&#xff…

python导出数据到sqlite中

import sqlite3# 数据 data [{username: 张三, age: 33, score: 13},{username: 李四, age: 44, score: 14},{username: 王五, age: 55, score: 15}, ]# 连接SQLite数据库(如果不存在则创建) conn sqlite3.connect(test.db)# 创建游标对象 cursor con…

神经网络8-注意力机制

注意力机制(Attention Mechanism)源于对人类视觉的研究。在认知科学中,由于信息处理的瓶颈,人类会选择性地关注所有信息的一部分,同时忽略其他可见的信息。这种机制被称为注意力机制。举个例子来说,当我们观…

【排序算法】深入理解插入排序算法:从原理到实现

1. 引言 排序算法是计算机科学中的基本问题之一,它的目标是将一组元素按照某种规则进行排列。插入排序是其中一种简单但有效的排序算法,通过逐步构建有序序列来实现排序。本文将从原理、时间复杂度、应用场景、优缺点等方面深入探讨插入排序算法&#x…

keepalived原理以及lvs、nginx跟keeplived的运用

keepalived基础 keepalived的原理是根据vrrp协议(主备模式)去设定的 vrrp技术相关原理 状态机; 优先级0~255 心跳线1秒 vrrp工作模式 双主双备模式 VRRP负载分担过程 vrrp安全认证:使用共享密匙 keepalived工具介绍 keepal…

如何压缩图片大小到100kb以下?

如何压缩图片大小到100kb以下?不知道上班族小伙伴有没有发现,当我们工作中使用图片的时候经常遇到遇到一个尴尬的情况,例如我们需要网某个网站上传一张图片的时候,会被限制要求图片大小不能超过100kb,如果超过就无法进…

基于uniapp cli项目开发的老项目,运行报错path.replace is not a function

项目:基于uniapp cli的微信小程序老项目 问题:git拉取代码,npm安装包时就报错; cnpm能安装成功包,运行报错 三种方法尝试解决: 更改代码,typeof pathstring的话,才走path.replace…

JVM 面试题

1、什么情况下会发生栈内存溢出。 栈内存溢出通常发生在以下几种情况中: 函数递归调用过深: 当函数递归调用自身且没有合适的退出条件时,每次递归调用都会在栈上分配一个新的栈帧来存储局部变量、返回地址等信息。如果递归层次过多&#xff…

计讯物联山体滑坡地质灾害监测方案为灾区保驾护航

针对我国某些地区频繁爆发山体滑坡的情况,计讯物联深入调研滑坡体自动监测、无线通讯、险情预报等方面,自主研发反应快速高效、可广泛应用的山体滑坡地质灾害监测方案,全面掌握山体滑坡信息,为当地居民留有余裕的逃生时间。 计讯物…

Docker 配置阿里云镜像加速器

一、首先需要创建一个阿里云账号 二、登录阿里云账号 三、进入控制台 四、搜索容器镜像服务,并选择 五、选择镜像工具中的镜像加速 六 、配置镜像源 注意:有/etc/docker文件夹的直接从第二个命令开始

JavaSE-10(JDK8 新特性-万字总结)

新特性概述 Lambda 表达式函数式接口引用Stream API接口中的默认方法 / 静态方法新时间日期 API其他新特性 Lambda表达式/匿名函数 在 Java 中,匿名函数通常是指 Lambda 表达式。 Lambda 表达式允许你以一种简洁、紧凑的方式编写匿名函数,而不必创建…

Windows系统安装MongoDB并结合内网穿透实现公网访问本地数据库

文章目录 前言1. 安装数据库2. 内网穿透2.1 安装cpolar内网穿透2.2 创建隧道映射2.3 测试随机公网地址远程连接 3. 配置固定TCP端口地址3.1 保留一个固定的公网TCP端口地址3.2 配置固定公网TCP端口地址3.3 测试固定地址公网远程访问 前言 MongoDB是一个基于分布式文件存储的数…

Zookeeper基础知识:成功分布式系统的关键

文章目录 一、引言二、Zookeeper介绍三、Zookeeper安装四、Zookeeper架构【重点】4.1 Zookeeper树形结构4.2 znode类型4.3 Zookeeper的监听通知机制 五、Zookeeper常用操作5.1 zk常用命令5.2 Java连接Zookeeper5.3 Java操作Znode节点5.4 监听通知机制 六、Zookeeper集群【重点】…

基于AI软件平台 HEGERLS智能托盘四向车机器人物流仓储解决方案持续升级

随着各大中小型企业对仓储需求的日趋复杂,柔性、离散的物流子系统也不断涌现,各种多类型的智能移动机器人、自动化仓储装备大量陆续的应用于物流行业中,但仅仅依靠传统的物流技术和单点的智能化设备,已经无法更有效的应对这些挑战…

【学习心得】websocket协议简介并与http协议对比

一、轮询和长轮询 在websocket协议出现之前,要想实现服务器和客户端的双向持久通信采取的是Ajax轮询。它的原理是每隔一段时间客户端就给服务器发送请求找服务器要数据。 让我们通过一个生活化的比喻来解释轮询和长轮询假设你正在与一位不怎么主动说话的老大爷&…

学习人工智能:吴恩达《AI for everyone》2019 第3周:实现智能音箱和自动驾驶的几个步骤;无监督学习;增强学习

吴恩达 Andrew Ng, 斯坦福大学前教授,Google Brain项目发起人、领导者。 Coursera 的联合创始人和联合主席,在 Coursera 上有十万用户的《机器学习》课程;斯坦福大学计算机科学前教授。百度前副总裁、前首席科学家;谷…

开发知识点-Apache Struts2框架

Apache Struts2 介绍S2-001S2CVE-2023-22530介绍 Apache Struts2是一个基于MVC(模型-视图-控制器)设计模式的Web应用程序框架,它是Apache旗下的一个开源项目,并且是Struts1的下一代产品。Struts2是在Struts1和WebWork的技术基础上合并出来的全新web框架,其核心是WebWork。…

网络入侵检测系统之Suricata(十)--ICMP实现详解

ICMP协议 Common header 0 1 2 40 1 2 3 4 5 6 7 8 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 0 1 2 3 4--------------------------------| Type | Code | Checksum |-----…

VUE_自适应布局-postcss-pxtorem,nuxt页面自适配

postcss-pxtorem是一个PostCSS插件,用于将CSS中的像素值转换为rem单位,以实现响应式布局和适配不同屏幕尺寸的需求。 它的适配原理是将CSS中的像素值除以一个基准值,通常是设计稿的宽度,然后将结果转换为rem单位。这样&#xff0…

根据用户名称实现单点登录

一、参数格式 二、后端实现 Controller层 public class IAccessTokenLoginController extends BaseController {Autowiredprivate ISysUserService sysUserService;Autowiredprivate ISingleTokenServiceImpl tokenService;/*** 登录方法** return 结果*/PostMapping("/l…
最新文章