【Pytorch】(十三)模型部署: TorchScript

文章目录

  • (十三)模型部署: TorchScript
    • Pytorch动态图的优缺点
    • TorchScript
    • Pytorch模型转换为TorchScript
      • torch.jit.trace
      • torch.jit.script
      • trace和script的区别总结
      • trace 和script 混合使用
      • 保存和加载模型

(十三)模型部署: TorchScript

Pytorch动态图的优缺点

与Tensorflow使用静态计算图不同,PyTorch 使用的是动态计算图:

动态图允许在运行时渐进地构建计算图,使得模型设计更加灵活。开发者可以使用 Python 的控制流结构(如循环、条件语句等)来动态地定义模型的结构,从而更容易实现复杂的模型逻辑。

这种计算方式更直观,更pythonic。开发者可以更容易地理解和调试模型各个模块,快速地修改、迭代模型。

然而,与静态图相比,动态图的执行效率可能会较低。因为动态图难以进行一些计算图的优化,如运算符融合、图优化等。而且,动态图依赖于Python 环境。这些因素使得动态图不适合在低延迟要求较高的生产环境下部署。

因此,在部署Pytorch训练后的模型时,需要将动态图转换为静态图,这就要用到TorchScript。

TorchScript

TorchScript是PyTorch模型的一种静态图表示形式,支持模型的部署优化、跨平台部署以及与其他深度学习框架的集成:

  • 模型的部署优化:TorchScript 可以帮助优化 PyTorch 模型以提高性能和效率。通过将模型转换为静态图形式,TorchScript 可以应用各种优化技术,如运算符融合、图优化等,从而加速模型执行并降低内存消耗。
  • 跨平台部署:将模型转换为 TorchScript 格式可以实现跨平台部署,模型可以在没有 Python 环境的情况下运行。这对于在生产环境中部署模型到服务器、移动设备或边缘设备上非常有用。
  • 与其他框架集成:通过将 PyTorch 模型转换为 TorchScript 格式,可以更方便地与其他深度学习框架进行交互。例如,可以将TorchScript 进一步转换为 ONNX 格式,从而与 TensorFlow 等其他框架进行集成和交互操作。

Pytorch模型转换为TorchScript

torch.jit.tracetorch.jit.script 是 PyTorch 中用于模型转换为 TorchScript 格式的工具,但它们有不同的作用和使用场景。

torch.jit.trace

通过torch.jit.trace 将 没有控制流的MyCell 模块转化为TorchScript:


import torch  # This is all you need to use both PyTorch and TorchScript!

torch.manual_seed(191009)  # set the seed for reproducibility


class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)

torch.jit.trace调用了my_cell,记录了模块计算时发生的操作,并创建了一个torch.jit.ScriptModule的实例(TracedModule是其实例)traced_celltraced_cell 记录了my_cell的计算图。我们可以使用.graph属性来查看:

print(traced_cell.graph)
graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)

然而,图中包含的大多数信息对我们没有用处。我们可以使用.code属性对其进行Python语法解释:

print(traced_cell.code)
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)

调用traced_cell会产生与Python模块实例my_cell() 相同的结果:

print(my_cell(x, h))
print(traced_cell(x, h))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))

torch.jit.script

我们先尝试通过torch.jit.trace 将 带有控制流的MyCell 模块转化为TorchScript:

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)
/var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:261: TracerWarning:

Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

def forward(self,
    argument_1: Tensor) -> NoneType:
  return None

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = (linear).forward(x, )
  _1 = (dg).forward(_0, )
  _2 = torch.tanh(torch.add(_0, h))
  return (_2, _2)

可以看到,if-else分支并没有被表示出来。为什么?
trace记录代码运行发生的操作,并构造一个ScriptModule。控制流中只有一种情况被记录了下来,其他情况都被忽略了。

这就需要用到torch.jit.script了:

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)
def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)

可以考到,控制流也被记录了下来。
现在让我们尝试运行该程序:

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))
(tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],
        [ 0.5228,  0.7122,  0.6985, -0.0656],
        [ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>), tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],
        [ 0.5228,  0.7122,  0.6985, -0.0656],
        [ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>))

trace和script的区别总结

  • torch.jit.tracetorch.jit.trace 用于将一个具体的输入示例追踪(trace)模型的一次计算过程,从而生成一个 TorchScript 模型。对于动态控制流(如条件语句),它只会记录每个分支中的一种情况。因此,它不适用于无固定形状输入、具有动态控制流的模型。

  • torch.jit.scripttorch.jit.script 用于将整个 PyTorch 模型转换为 TorchScript 模型,包括模型的所有逻辑和控制流。script适用于无固定形状输入、具有动态控制流的模型 。但是,它可能会把保存一些多余的代码, 产生额外的性能开销。

因此,可以将两者混合使用,扬长避短。

trace 和script 混合使用

torch.jit.tracetorch.jit.script 可以混合使用: 复杂模型中静态部分用torch.jit.trace进行转换, 动态部分用torch.jit.script 进行转换,以发挥各自的优势。以下是两个可能的情况:

  • torch.jit.script内联traced模块的代码,
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)
  • torch.jit.trace内联scripted模块的代码,
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)

保存和加载模型

  • traced.save : 保存TorchScript

  • torch.jit.load : 加载TorchScript

traced.save('wrapped_rnn.pt')

loaded = torch.jit.load('wrapped_rnn.pt')

print(loaded)
print(loaded.code)
RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)

参考:
https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/578714.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于java+springboot+vue实现的医疗挂号管理系统(文末源码+Lw)203

摘 要 在如今社会上&#xff0c;关于信息上面的处理&#xff0c;没有任何一个企业或者个人会忽视&#xff0c;如何让信息急速传递&#xff0c;并且归档储存查询&#xff0c;采用之前的纸张记录模式已经不符合当前使用要求了。所以&#xff0c;对医疗挂号信息管理的提升&#x…

Pytorch 之torch.nn初探 卷积--Convolution Layers

任务描述 本关任务&#xff1a; 本关提供了一个Variable 类型的变量input&#xff0c;按照要求创建一 Conv1d变量conv&#xff0c;对input应用卷积操作并赋值给变量 output&#xff0c;并输出output 的大小。 相关知识 卷积的本质就是用卷积核的参数来提取原始数据的特征&a…

OpenHarmony语言基础类库【@ohos.util.Stack (线性容器Stack)】

ohos.util.Stack (线性容器Stack) Stack基于数组的数据结构实现&#xff0c;特点是先进后出&#xff0c;只能在一端进行数据的插入和删除。 Stack和[Queue]相比&#xff0c;Queue基于循环队列实现&#xff0c;只能在一端删除&#xff0c;另一端插入&#xff0c;而Stack都在一…

[Qt的学习日常]--信号和槽

前言 作者&#xff1a;小蜗牛向前冲 名言&#xff1a;我可以接受失败&#xff0c;但我不能接受放弃 如果觉的博主的文章还不错的话&#xff0c;还请点赞&#xff0c;收藏&#xff0c;关注&#x1f440;支持博主。如果发现有问题的地方欢迎❀大家在评论区指正 本期学习&#xff…

PyQt6 优化操作:建立侧边栏,要求可拖拽改变宽度,可用按钮控制侧边栏的展开和收起

1. 官方文档 QSplitter — PyQt Documentation v6.6.0 2. 效果展示 可拖拽改变宽度比例 点击按钮快速收起或展开侧边栏 点击按钮&#xff0c;侧边栏收起&#xff0c;同时按钮图标变为向左箭头 (对应展开功能)&#xff0c;再次点击按钮&#xff0c;侧边栏展开&#xff0c;同…

Pycharm新建工程时使用Python自带解释器的方法

Pycharm新建工程时使用Python自带解释器的方法 新建Project时最好不要新建Python解释器&#xff0c;实践证明&#xff0c;自己新建的Python解释器容易出现各种意想不到的问题。 那么怎样使用Python安装时自带的解释器呢&#xff1f; 看下面的三张截图大家就清楚了。 我的Pyth…

英智数字孪生机器人解决方案,赋能仓库物流模式全面升级

工业机械臂、仓储机器人、物流机器人等模式的机器人系统在现代产业中扮演着愈发重要的角色&#xff0c;他们的发展推动了自动化和智能化水平的提高&#xff0c;有助于为制造业、物流业、医疗保健业和服务业等行业创造新效率并提升人们的生活质量。 行业面临的挑战 机器人开发、…

Windows安装Elasticsearch 7.9.2

1 下载 https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.9.2-windows-x86_64.zip 2 配置 进入config目录&#xff0c;打开elasticsearch.yml文件&#xff0c;给集群和节点配置名称。 cluster.name: my-es node.name: node-1 3 启动 打开bin目录&am…

Docker之常见FAQ记录清单

一、前言 本文记录Docker使用过程中遇见的问题&#xff0c;供后续回顾参考。 关联资源&#xff1a;网络Docker博客、官方FAQ、文档、Docker 从入门到实践、中文社区、riptutorial 二、问题及处理记录 2.1、docker容器内没有vi,nano等编辑器 1&#xff09;如果宿主机本地有&a…

vs2019 - warning LNK4098 : 默认库“msvcrt.lib”与其他库的使用冲突

文章目录 vs2019 - warning LNK4098 : 默认库“msvcrt.lib”与其他库的使用冲突概述笔记实验 - 编译静态库实验 - 编译主工程&#xff0c;包含静态库实验主工程和静态库编译设置不同时的编译报错和警告备注备注 - 判断/Mdd, /MdEND vs2019 - warning LNK4098 : 默认库“msvcrt.…

[SWPUCTF-2022-新生赛]ez_sql

title:[SWPUCTF 2022 新生赛]ez_sql 审题 根据提示&#xff0c;POST传参 得到假的flag 判断类型 字符型注入 判断列数 发现空格和’or’被过滤 重新构造 nss-1/**/oorrder/**/by/**/4#发现为3个字段 采用联合注入union 爆库 发现union被过滤&#xff0c;双写union绕过 发…

以生命健康为中心的物联网旅居养老运营平台

随着科技的飞速发展和人口老龄化的日益加剧&#xff0c;养老问题逐渐成为社会关注的焦点。传统的养老模式已经难以满足现代老年人的多元化需求&#xff0c;因此&#xff0c;构建一个以生命健康为中心的物联网旅居养老运营平台显得尤为重要。 以生命健康为中心的物联网旅居养老运…

两大成果发布!“大规模量子云算力集群”和高性能芯片展示中国科技潜力

在当前的科技领域&#xff0c;量子计算的进步正日益引起全球的关注。中国在这一领域的进展尤为显著&#xff0c;今天&#xff0c;北京量子信息科学研究院&#xff08;以下简称北京量子院&#xff09;和中国科学院量子信息与量子科技创新研究院&#xff08;以下简称量子创新院&a…

【c++】深入剖析与动手实践:C++中Stack与Queue的艺术

&#x1f525;个人主页&#xff1a;Quitecoder &#x1f525;专栏&#xff1a;c笔记仓 朋友们大家好&#xff0c;本篇文章我们来到STL新的内容&#xff0c;stack和queue 目录 1. stack的介绍与使用函数介绍例题一&#xff1a;最小栈例题二&#xff1a;栈的压入、弹出队列栈的模…

Docker 的数据管理 与 Docker 镜像的创建

目录 一、Docker 的数据管理 1.1.数据卷 1.2.数据卷容器 1.3.容器互联&#xff08;使用centos镜像&#xff09; 二、Docker 镜像的创建 2.1.基于现有镜像创建 2.2.基于本地模板创建 2.3.基于Dockerfile创建 2.3.1联合文件系统&#xff08;UnionFs&#xff09; 2.3.2…

GDPU 竞赛技能实践 天码行空9

1. 埃式筛法 求区间[2, n]内所有的素数对 &#x1f496; Main.java import java.util.Scanner;public class Main {static int N (int) 1e8, cnt 0;static int[] p new int[N];static boolean[] st new boolean[N];public static void main(String[] args){Scanner sc …

使用grasshopper修改梁的起始点方向

一般北方向朝上的情况&#xff0c;梁的方向从南向北&#xff0c;从西向东。 现在使用grasshopper来判断起始点坐标&#xff0c;分辨是否错误。 交换起始点这个&#xff0c;我实在不会用电池操作&#xff0c;只好敲python代码实现了。代码如下&#xff1a; 如果会敲代码的同学…

66.网络游戏逆向分析与漏洞攻防-利用数据包构建角色信息-重新规划游戏分析信息的输出

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 如果看不懂、不知道现在做的什么&#xff0c;那就跟着做完看效果&#xff0c;代码看不懂是正常的&#xff0c;只要会抄就行&#xff0c;抄着抄着就能懂了 内容…

Apache Seata的可观测实践

title: Seata的可观测实践 keywords: [Seata、分布式事务、数据一致性、微服务、可观测] description: 本文介绍Seata在可观测领域的探索和实践 author: 刘戎-Seata 本文来自 Apache Seata官方文档&#xff0c;欢迎访问官网&#xff0c;查看更多深度文章。 Seata简介 Seata的…

STM32单片机通过ST-Link 烧录和调试

系列文章目录 STM32单片机系列专栏 C语言术语和结构总结专栏 1. ST-LINK V2 ST LINK v2下载器用于STM32单片机&#xff0c;可以下载程序、调试程序、读取芯片数据&#xff0c;解除芯片读写保护等等&#xff0c;辅助软件用的是STM32 ST-LINK Utility。 STM32 ST-LINK Utility…
最新文章