【Pytorch】(十三)PyTorch模型部署: TorchScript

文章目录

  • (十三)PyTorch模型部署
    • Pytorch动态图的优缺点
    • TorchScript
    • Pytorch模型转换为TorchScript
      • torch.jit.trace
      • torch.jit.script
      • trace和script的区别总结
      • script 和 trace 混合使用
      • 保存和加载模型

(十三)PyTorch模型部署

Pytorch动态图的优缺点

与Tensorflow使用静态计算图不同,PyTorch 使用的是动态计算图:

动态图允许在运行时渐进地构建计算图,使得模型设计更加灵活。开发者可以使用 Python 的控制流结构(如循环、条件语句等)来动态地定义模型的结构,从而更容易实现复杂的模型逻辑。

这种计算方式更直观,更pythonic。开发者可以更容易地理解和调试模型各个模块,快速地修改、迭代模型。

然而,与静态图相比,动态图的执行效率可能会较低。因为动态图难以进行一些计算图的优化,如运算符融合、图优化等。而且,动态图依赖于Python 环境。这些因素使得动态图不适合在低延迟要求较高的生产环境下部署。

因此,在部署Pytorch训练后的模型时,需要将动态图转换为静态图,这就要用到TorchScript。

TorchScript

TorchScript是PyTorch模型的一种静态图表示形式,支持模型的部署优化、跨平台部署以及与其他深度学习框架的集成:

  • 模型的部署优化:TorchScript 可以帮助优化 PyTorch 模型以提高性能和效率。通过将模型转换为静态图形式,TorchScript 可以应用各种优化技术,如运算符融合、图优化等,从而加速模型执行并降低内存消耗。
  • 跨平台部署:将模型转换为 TorchScript 格式可以实现跨平台部署,模型可以在没有 Python 环境的情况下运行。这对于在生产环境中部署模型到服务器、移动设备或边缘设备上非常有用。
  • 与其他框架集成:通过将 PyTorch 模型转换为 TorchScript 格式,可以更方便地与其他深度学习框架进行交互。例如,可以将TorchScript 进一步转换为 ONNX 格式,从而与 TensorFlow 等其他框架进行集成和交互操作。

Pytorch模型转换为TorchScript

torch.jit.tracetorch.jit.script 是 PyTorch 中用于模型转换为 TorchScript 格式的工具,但它们有不同的作用和使用场景。

torch.jit.trace

通过torch.jit.trace 将 没有控制流的MyCell 模块转化为TorchScript:


import torch  # This is all you need to use both PyTorch and TorchScript!torch.manual_seed(191009)  # set the seed for reproducibilityclass MyCell(torch.nn.Module):def __init__(self):super(MyCell, self).__init__()self.linear = torch.nn.Linear(4, 4)def forward(self, x, h):new_h = torch.tanh(self.linear(x) + h)return new_h, new_hmy_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
MyCell(original_name=MyCell(linear): Linear(original_name=Linear)
)

torch.jit.trace调用了my_cell,记录了模块计算时发生的操作,并创建了一个torch.jit.ScriptModule的实例(TracedModule是其实例)traced_celltraced_cell 记录了my_cell的计算图。我们可以使用.graph属性来查看:

print(traced_cell.graph)
graph(%self.1 : __torch__.MyCell,%x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),%h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):%linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)%20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)%11 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)return (%14)

然而,图中包含的大多数信息对我们没有用处。我们可以使用.code属性对其进行Python语法解释:

print(traced_cell.code)
def forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:linear = self.linear_0 = torch.tanh(torch.add((linear).forward(x, ), h))return (_0, _0)

调用traced_cell会产生与Python模块实例my_cell() 相同的结果:

print(my_cell(x, h))
print(traced_cell(x, h))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))

torch.jit.script

我们先尝试通过torch.jit.trace 将 带有控制流的MyCell 模块转化为TorchScript:

class MyDecisionGate(torch.nn.Module):def forward(self, x):if x.sum() > 0:return xelse:return -xclass MyCell(torch.nn.Module):def __init__(self, dg):super(MyCell, self).__init__()self.dg = dgself.linear = torch.nn.Linear(4, 4)def forward(self, x, h):new_h = torch.tanh(self.dg(self.linear(x)) + h)return new_h, new_hmy_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))print(traced_cell.dg.code)
print(traced_cell.code)
/var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:261: TracerWarning:Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!def forward(self,argument_1: Tensor) -> NoneType:return Nonedef forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:dg = self.dglinear = self.linear_0 = (linear).forward(x, )_1 = (dg).forward(_0, )_2 = torch.tanh(torch.add(_0, h))return (_2, _2)

可以看到,if-else分支并没有被表示出来。为什么?
trace记录代码运行发生的操作,并构造一个ScriptModule。控制流中只有一种情况被记录了下来,其他情况都被忽略了。

这就需要用到torch.jit.script了:

scripted_gate = torch.jit.script(MyDecisionGate())my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)print(scripted_gate.code)
print(scripted_cell.code)
def forward(self,x: Tensor) -> Tensor:if bool(torch.gt(torch.sum(x), 0)):_0 = xelse:_0 = torch.neg(x)return _0def forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:dg = self.dglinear = self.linear_0 = torch.add((dg).forward((linear).forward(x, ), ), h)new_h = torch.tanh(_0)return (new_h, new_h)

可以考到,控制流也被记录了下来。
现在让我们尝试运行该程序:

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))
(tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],[ 0.5228,  0.7122,  0.6985, -0.0656],[ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>), tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],[ 0.5228,  0.7122,  0.6985, -0.0656],[ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>))

trace和script的区别总结

  • torch.jit.tracetorch.jit.trace 用于将一个具体的输入示例追踪(trace)模型的一次计算过程,从而生成一个 TorchScript 模型。对于动态控制流(如条件语句),它只会记录每个分支中的一种情况。因此,它不适用于无固定形状输入、具有动态控制流的模型。

  • torch.jit.scripttorch.jit.script 用于将整个 PyTorch 模型转换为 TorchScript 模型,包括模型的所有逻辑和控制流。script适用于无固定形状输入、具有动态控制流的模型 。但是,它可能会把保存一些多余的代码, 产生额外的性能开销。

因此,可以将两者混合使用,扬长避短。

script 和 trace 混合使用

torch.jit.tracetorch.jit.script 可以混合使用: 复杂模型中静态部分用torch.jit.trace进行转换, 动态部分用torch.jit.script 进行转换,以发挥各自的优势。以下是两个可能的情况:

  • torch.jit.script内联traced模块的代码,
class MyRNNLoop(torch.nn.Module):def __init__(self):super(MyRNNLoop, self).__init__()self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))def forward(self, xs):h, y = torch.zeros(3, 4), torch.zeros(3, 4)for i in range(xs.size(0)):y, h = self.cell(xs[i], h)return y, hrnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
def forward(self,xs: Tensor) -> Tuple[Tensor, Tensor]:h = torch.zeros([3, 4])y = torch.zeros([3, 4])y0 = yh0 = hfor i in range(torch.size(xs, 0)):cell = self.cell_0 = (cell).forward(torch.select(xs, 0, i), h0, )y1, h1, = _0y0, h0 = y1, h1return (y0, h0)
  • torch.jit.trace内联scripted模块的代码,
class WrapRNN(torch.nn.Module):def __init__(self):super(WrapRNN, self).__init__()self.loop = torch.jit.script(MyRNNLoop())def forward(self, xs):y, h = self.loop(xs)return torch.relu(y)traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
def forward(self,xs: Tensor) -> Tensor:loop = self.loop_0, y, = (loop).forward(xs, )return torch.relu(y)

保存和加载模型

  • traced.save : 保存TorchScript

  • torch.jit.load : 加载TorchScript

traced.save('wrapped_rnn.pt')loaded = torch.jit.load('wrapped_rnn.pt')print(loaded)
print(loaded.code)
RecursiveScriptModule(original_name=WrapRNN(loop): RecursiveScriptModule(original_name=MyRNNLoop(cell): RecursiveScriptModule(original_name=MyCell(dg): RecursiveScriptModule(original_name=MyDecisionGate)(linear): RecursiveScriptModule(original_name=Linear)))
)
def forward(self,xs: Tensor) -> Tensor:loop = self.loop_0, y, = (loop).forward(xs, )return torch.relu(y)

参考:
https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2979.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

科学高效备考AMC8和AMC10竞赛,吃透2000-2024年1850道真题和解析

如何科学、有效地备考AMC8、AMC10美国数学竞赛&#xff1f;多做真题&#xff0c;吃透真题是科学有效的方法之一&#xff0c;通过做真题&#xff0c;可以帮助孩子找到真实竞赛的感觉&#xff0c;而且更加贴近比赛的内容&#xff0c;可以通过真题查漏补缺&#xff0c;更有针对性的…

成功解决ImportError: cannot import name ‘builder‘ from ‘google.protobuf.internal

成功解决ImportError: cannot import name builder from google.protobuf.internal 目录 解决问题 解决思路 解决方法 解决问题 ImportError: cannot import name builder from google.protobuf.internal 解决思路 导入错误:无法从“google.protobuf.internal”导入名称“…

在React函数组件中使用错误边界和errorElement进行错误处理

在React 18中,函数组件可以使用两种方式来处理错误: 使用 ErrorBoundary ErrorBoundary 是一种基于类的组件,可以捕获其子组件树中的任何 JavaScript 错误,并记录这些错误、渲染备用 UI 而不是冻结的组件树。 在函数组件中使用 ErrorBoundary,需要先创建一个基于类的 ErrorB…

网络通信安全

一、网络通信安全基础 TCP/IP协议简介 TCP/IP体系结构、以太网、Internet地址、端口 TCP/IP协议简介如下&#xff1a;&#xff08;from文心一言&#xff09; TCP/IP&#xff08;Transmission Control Protocol/Internet Protocol&#xff0c;传输控制协议/网际协议&#xff0…

用友NC Cloud importhttpscer接口任意文件上传漏洞

声明 本文仅用于技术交流&#xff0c;请勿用于非法用途 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;文章作者不为此承担任何责任。 一、漏洞描述 用友NC Cloud的importhttpscer接口如果存在任意文件上传…

开源文本嵌入模型M3E

进入正文前&#xff0c;先扯点题外话 这两天遇到一个棘手的问题&#xff0c;在用 docker pull 拉取镜像时&#xff0c;会报错&#xff1a; x509: certificate has expired or is not yet valid 具体是下面&#x1f447;这样的 rootDS918:/volume2/docker/xiaoya# docker pul…

恒峰智慧科技—森林守护者:森林消防泵如何助力灭火?

在茂密的森林中&#xff0c;一场突如其来的火灾可能带来无法估量的破坏。幸运的是&#xff0c;森林消防泵的出现&#xff0c;帮助我们对抗这些威胁。本文将深入探讨森林消防泵如何在灭火工作中发挥重要作用。 一、森林消防泵的功能和重要性&#xff1a; 首先&#xff0c;我们需…

探索人工智能的边界:GPT 4.0与文心一言 4.0免费使用体验全揭秘!

探索人工智能的边界&#xff1a;GPT与文心一言免费试用体验全揭秘&#xff01; 前言免费使用文心一言4.0的方法官方入口进入存在的问题免费使用文心一言4.0的方法 免费使用GPT4.0的方法官方入口进入存在的问题免费使用GPT4.0的方法 前言 未来已来&#xff0c;人工智能已经可以…

Matlab|基于元模型优化算法的主从博弈多虚拟电厂动态定价和能量管理

1 主要内容 该程序复现《基于元模型优化算法的主从博弈多虚拟电厂动态定价和能量管理》模型&#xff0c;建立运营商和多虚拟电厂的一主多从博弈模型&#xff0c;研究运营商动态定价行为和虚拟电厂能量管理模型&#xff0c;模型为双层&#xff0c;首先下层模型中&#xff0c;构建…

【Android】android 10 jar_sdk_library添加

前言 当前项目遇到客户&#xff0c;Android 10 平台&#xff0c;需要封装jar_sdk_library给第三方应用使用。其中jar_sdk_library中存在aidl文件。遇到无法编译通过问题。 解决 system/tools/aidl修改 Android.bp修改

frp改造Windows笔记本实现家庭版免费内网穿透

文章目录 前言frp原理Windows服务端IP检验IP固定软件下载端口放行端口映射开机启动 NAS客户端端口查询软件下载端口检验穿透测试自启设置 Ubuntu客户端软件下载后台启动 后记 前言 之前一直用花生壳远程控制一个服务器&#xff0c;但最近内网的网络策略似乎发生了变化&#xf…

信息系统项目管理师0068:数据标准化(5信息系统工程—5.2数据工程—5.2.2数据标准化)

点击查看专栏目录 文章目录 5.2.2数据标准化1.元数据标准化2.数据元标准化3.数据模式标准化4.数据分类与编码标准化5.数据标准化管理记忆要点总结5.2.2数据标准化 数据标准化是实现数据共享的基础。数据标准化主要为复杂的信息表达、分类和定位建立相应的原则和规范,使其简单化…

谷歌发布基于声学建模的无限虚拟房间增强现实鲁棒语音识别技术

声学室模拟允许在AR眼镜上以最少的真实数据进行训练&#xff0c;用于开发鲁棒的语音识别声音分离模型。 随着增强现实&#xff08;AR&#xff09;技术的强大和广泛应用&#xff0c;它能应用到各种日常情境中。我们对AR技术的潜能感到兴奋&#xff0c;并持续不断地开发和测试新…

Adobe Illustrator 2024 v28.4.1 (macOS, Windows) - 矢量绘图

Adobe Illustrator 2024 v28.4.1 (macOS, Windows) - 矢量绘图 Acrobat、After Effects、Animate、Audition、Bridge、Character Animator、Dimension、Dreamweaver、Illustrator、InCopy、InDesign、Lightroom Classic、Media Encoder、Photoshop、Premiere Pro、Adobe XD 请…

ChatGPT实战100例 - (18) 用事件风暴玩转DDD

文章目录 ChatGPT实战100例 - (18) 用事件风暴玩转DDD一、标准流程二、定义目标和范围三、准备工具和环境四、列举业务事件五、 组织和排序事件六、确定聚合并引入命令七、明确界限上下文八、识别领域事件和领域服务九、验证和修正模型十、生成并验证软件设计十一、总结 ChatGP…

解线性方程组——(Gauss-Seidel)高斯-赛德尔迭代法 | 北太天元

一、Gauss-Seidel迭代法 n 3 n3 n3时 A ( a 11 a 12 a 13 a 21 a 22 a 23 a 31 a 32 a 33 ) , b ( b 1 b 2 b 3 ) , A\begin{pmatrix} a_{11} & a_{12} &a_{13}\\ a_{21} & a_{22} &a_{23}\\ a_{31} & a_{32} &a_{33}\\ \end{pmatrix} ,\quad b\be…

缓存神器-JetCache

序言 今天和大家聊聊阿里的一款缓存神器 JetCache。 一、缓存在开发实践中的问题 1.1 缓存方案的可扩展性问题 谈及缓存&#xff0c;其实有许多方案可供选择。例如&#xff1a;Guava Cache、Caffine、Encache、Redis 等。 这些缓存技术都能满足我们的需求&#xff0c;但现…

《从零开始的Java世界》10File类与IO流

《从零开始的Java世界》系列主要讲解Javase部分&#xff0c;从最简单的程序设计到面向对象编程&#xff0c;再到异常处理、常用API的使用&#xff0c;最后到注解、反射&#xff0c;涵盖Java基础所需的所有知识点。学习者应该从学会如何使用&#xff0c;到知道其实现原理全方位式…

LAMP(Linux+Apache+MySQL+PHP)环境介绍、配置、搭建

LAMP(LinuxApacheMySQLPHP)环境介绍、配置、搭建 LAMP介绍 LAMP是由Linux&#xff0c; Apache&#xff0c; MySQL&#xff0c; PHP组成的&#xff0c;即把Apache、MySQL以及PHP安装在Linux系统上&#xff0c;组成一个环境来运行PHP的脚本语言。Apache是最常用的Web服务软件&a…

纸箱码垛机:从传统到智能,科技如何助力产业升级

随着科技的飞速发展&#xff0c;传统工业领域正经历着一场重要的变革。作为物流行业重要一环的纸箱码垛机&#xff0c;其从传统到智能的转型升级&#xff0c;不仅提高了生产效率&#xff0c;还大幅降低了人工成本&#xff0c;为产业升级提供了强大助力。星派将探讨纸箱码垛机的…