【Pytorch】(十三)模型部署: TorchScript

文章目录

  • (十三)模型部署: TorchScript
    • Pytorch动态图的优缺点
    • TorchScript
    • Pytorch模型转换为TorchScript
      • torch.jit.trace
      • torch.jit.script
      • trace和script的区别总结
      • trace 和script 混合使用
      • 保存和加载模型

(十三)模型部署: TorchScript

Pytorch动态图的优缺点

与Tensorflow使用静态计算图不同,PyTorch 使用的是动态计算图:

动态图允许在运行时渐进地构建计算图,使得模型设计更加灵活。开发者可以使用 Python 的控制流结构(如循环、条件语句等)来动态地定义模型的结构,从而更容易实现复杂的模型逻辑。

这种计算方式更直观,更pythonic。开发者可以更容易地理解和调试模型各个模块,快速地修改、迭代模型。

然而,与静态图相比,动态图的执行效率可能会较低。因为动态图难以进行一些计算图的优化,如运算符融合、图优化等。而且,动态图依赖于Python 环境。这些因素使得动态图不适合在低延迟要求较高的生产环境下部署。

因此,在部署Pytorch训练后的模型时,需要将动态图转换为静态图,这就要用到TorchScript。

TorchScript

TorchScript是PyTorch模型的一种静态图表示形式,支持模型的部署优化、跨平台部署以及与其他深度学习框架的集成:

  • 模型的部署优化:TorchScript 可以帮助优化 PyTorch 模型以提高性能和效率。通过将模型转换为静态图形式,TorchScript 可以应用各种优化技术,如运算符融合、图优化等,从而加速模型执行并降低内存消耗。
  • 跨平台部署:将模型转换为 TorchScript 格式可以实现跨平台部署,模型可以在没有 Python 环境的情况下运行。这对于在生产环境中部署模型到服务器、移动设备或边缘设备上非常有用。
  • 与其他框架集成:通过将 PyTorch 模型转换为 TorchScript 格式,可以更方便地与其他深度学习框架进行交互。例如,可以将TorchScript 进一步转换为 ONNX 格式,从而与 TensorFlow 等其他框架进行集成和交互操作。

Pytorch模型转换为TorchScript

torch.jit.tracetorch.jit.script 是 PyTorch 中用于模型转换为 TorchScript 格式的工具,但它们有不同的作用和使用场景。

torch.jit.trace

通过torch.jit.trace 将 没有控制流的MyCell 模块转化为TorchScript:


import torch  # This is all you need to use both PyTorch and TorchScript!torch.manual_seed(191009)  # set the seed for reproducibilityclass MyCell(torch.nn.Module):def __init__(self):super(MyCell, self).__init__()self.linear = torch.nn.Linear(4, 4)def forward(self, x, h):new_h = torch.tanh(self.linear(x) + h)return new_h, new_hmy_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
MyCell(original_name=MyCell(linear): Linear(original_name=Linear)
)

torch.jit.trace调用了my_cell,记录了模块计算时发生的操作,并创建了一个torch.jit.ScriptModule的实例(TracedModule是其实例)traced_celltraced_cell 记录了my_cell的计算图。我们可以使用.graph属性来查看:

print(traced_cell.graph)
graph(%self.1 : __torch__.MyCell,%x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),%h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):%linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)%20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)%11 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)return (%14)

然而,图中包含的大多数信息对我们没有用处。我们可以使用.code属性对其进行Python语法解释:

print(traced_cell.code)
def forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:linear = self.linear_0 = torch.tanh(torch.add((linear).forward(x, ), h))return (_0, _0)

调用traced_cell会产生与Python模块实例my_cell() 相同的结果:

print(my_cell(x, h))
print(traced_cell(x, h))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))

torch.jit.script

我们先尝试通过torch.jit.trace 将 带有控制流的MyCell 模块转化为TorchScript:

class MyDecisionGate(torch.nn.Module):def forward(self, x):if x.sum() > 0:return xelse:return -xclass MyCell(torch.nn.Module):def __init__(self, dg):super(MyCell, self).__init__()self.dg = dgself.linear = torch.nn.Linear(4, 4)def forward(self, x, h):new_h = torch.tanh(self.dg(self.linear(x)) + h)return new_h, new_hmy_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))print(traced_cell.dg.code)
print(traced_cell.code)
/var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:261: TracerWarning:Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!def forward(self,argument_1: Tensor) -> NoneType:return Nonedef forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:dg = self.dglinear = self.linear_0 = (linear).forward(x, )_1 = (dg).forward(_0, )_2 = torch.tanh(torch.add(_0, h))return (_2, _2)

可以看到,if-else分支并没有被表示出来。为什么?
trace记录代码运行发生的操作,并构造一个ScriptModule。控制流中只有一种情况被记录了下来,其他情况都被忽略了。

这就需要用到torch.jit.script了:

scripted_gate = torch.jit.script(MyDecisionGate())my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)print(scripted_gate.code)
print(scripted_cell.code)
def forward(self,x: Tensor) -> Tensor:if bool(torch.gt(torch.sum(x), 0)):_0 = xelse:_0 = torch.neg(x)return _0def forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:dg = self.dglinear = self.linear_0 = torch.add((dg).forward((linear).forward(x, ), ), h)new_h = torch.tanh(_0)return (new_h, new_h)

可以考到,控制流也被记录了下来。
现在让我们尝试运行该程序:

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))
(tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],[ 0.5228,  0.7122,  0.6985, -0.0656],[ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>), tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],[ 0.5228,  0.7122,  0.6985, -0.0656],[ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>))

trace和script的区别总结

  • torch.jit.tracetorch.jit.trace 用于将一个具体的输入示例追踪(trace)模型的一次计算过程,从而生成一个 TorchScript 模型。对于动态控制流(如条件语句),它只会记录每个分支中的一种情况。因此,它不适用于无固定形状输入、具有动态控制流的模型。

  • torch.jit.scripttorch.jit.script 用于将整个 PyTorch 模型转换为 TorchScript 模型,包括模型的所有逻辑和控制流。script适用于无固定形状输入、具有动态控制流的模型 。但是,它可能会把保存一些多余的代码, 产生额外的性能开销。

因此,可以将两者混合使用,扬长避短。

trace 和script 混合使用

torch.jit.tracetorch.jit.script 可以混合使用: 复杂模型中静态部分用torch.jit.trace进行转换, 动态部分用torch.jit.script 进行转换,以发挥各自的优势。以下是两个可能的情况:

  • torch.jit.script内联traced模块的代码,
class MyRNNLoop(torch.nn.Module):def __init__(self):super(MyRNNLoop, self).__init__()self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))def forward(self, xs):h, y = torch.zeros(3, 4), torch.zeros(3, 4)for i in range(xs.size(0)):y, h = self.cell(xs[i], h)return y, hrnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
def forward(self,xs: Tensor) -> Tuple[Tensor, Tensor]:h = torch.zeros([3, 4])y = torch.zeros([3, 4])y0 = yh0 = hfor i in range(torch.size(xs, 0)):cell = self.cell_0 = (cell).forward(torch.select(xs, 0, i), h0, )y1, h1, = _0y0, h0 = y1, h1return (y0, h0)
  • torch.jit.trace内联scripted模块的代码,
class WrapRNN(torch.nn.Module):def __init__(self):super(WrapRNN, self).__init__()self.loop = torch.jit.script(MyRNNLoop())def forward(self, xs):y, h = self.loop(xs)return torch.relu(y)traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
def forward(self,xs: Tensor) -> Tensor:loop = self.loop_0, y, = (loop).forward(xs, )return torch.relu(y)

保存和加载模型

  • traced.save : 保存TorchScript

  • torch.jit.load : 加载TorchScript

traced.save('wrapped_rnn.pt')loaded = torch.jit.load('wrapped_rnn.pt')print(loaded)
print(loaded.code)
RecursiveScriptModule(original_name=WrapRNN(loop): RecursiveScriptModule(original_name=MyRNNLoop(cell): RecursiveScriptModule(original_name=MyCell(dg): RecursiveScriptModule(original_name=MyDecisionGate)(linear): RecursiveScriptModule(original_name=Linear)))
)
def forward(self,xs: Tensor) -> Tensor:loop = self.loop_0, y, = (loop).forward(xs, )return torch.relu(y)

参考:
https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/4120.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

掌静脉识别关键技术研究综述

掌静脉识别作为一种新兴的红外生物识别技术&#xff0c;因其高安全性、活体检测性等优势已成为当前生物特征识别领域中的研究热点之一。近年来&#xff0c;该领域的大量研究通过引入深度学习方法推动了掌静脉识别技术的发展。为了掌握掌静脉识别领域最新研究现状及发展方向&…

ddos云服务器有哪些防御方法和优势

本文将介绍云服务器遇到DDoS攻击的应对方法&#xff0c;包括流量清洗、负载均衡、防火墙设置和CDN加速等。同时&#xff0c;文章还介绍了ddos云服务器的防御优势&#xff0c;包括高防护能力、自动化防御、实时监控和报警以及弹性扩展等。通过这些防御方法和ddos云服务器的应用&…

React复习笔记

基础语法 创建项目 借助脚手架&#xff0c;新建一个React项目(可以使用vite或者cra&#xff0c;这里使用cra) npx create-react-app 项目名 create-react-app是React脚手架的名称 启动项目 npm start 或者 yarn start src是源文件index.js相当于Vue的main.js文件。整个…

vue3 ——笔记 (条件渲染,列表渲染,事件处理)

条件渲染 v-if v-if 指令用于条件性地渲染一块内容&#xff0c;只有v-if的表达式返回值为真才会渲染 v-else v-else 为 v-if 添加一个 else 区块 v-else 必须在v-if或v-else-if后 v-else-if v-else-if 是v-if 的区块 可以连续多次重复使用 v-show 按条件显示元素 v-sh…

【Linux系统化学习】生产者消费者模型(阻塞队列和环形队列)

目录 生产者消费者模型 什么是生产者消费者模型 为什么要使用生产者消费者模型 生产者消费者模型的优点 为什么生产者和生产者要互斥&#xff1f; 为什么消费者和消费者要互斥&#xff1f; 为什么生产者和消费者既是互斥又是同步&#xff1f; 基于BlockingQueue的生产者…

26版SPSS操作教程(高级教程第十六章)

目录 前言 粉丝及官方意见说明 第十六章一些学习笔记 第十六章一些操作方法 多维尺度分析 不考虑个体差异的多维尺度分析模型 假设数据 具体操作 结果解释 选择不同距离的排列方式 考虑个体差异的多维尺度分析模型&#xff08;INDSCAL&#xff0c;individual differ…

[C++ QT项目实战]----系统实现双击表格某一行,表格数据不再更新,可以查看该行所有信息,选中表更新之后,数据可以继续更新

前言 在需要庞大的数据量的系统中&#xff0c;基于合适的功能对数据进行观察和使用至关重要&#xff0c;本篇在自己项目实战的基础上&#xff0c;基于C QT编程语言&#xff0c;对其中一个数据功能进行分析和代码实现&#xff0c;希望可以有所帮助。一些特殊原因&#xff0c;图片…

车道分割YOLOV8-SEG

车道分割YOLOV8-SEG&#xff0c;训练得到PT模型&#xff0c;然后转换成ONNX&#xff0c;OPENCV的DNN调用&#xff0c;支持C,PYTHON,ANDROID开发 车道分割YOLOV8-SEG

数据污染对大型语言模型的潜在影响

大型语言模型&#xff08;LLMs&#xff09;中存在的数据污染是一个重要问题&#xff0c;可能会影响它们在各种任务中的表现。这指的是LLMs的训练数据中包含了来自下游任务的测试数据。解决数据污染问题至关重要&#xff0c;因为它可能导致结果偏倚&#xff0c;并影响LLMs在其他…

python三维交互可视化工具plotly使用

三维数据可视化工具使用 import plotly.graph_objects as go import numpy as np# 生成随机点 data np.random.uniform(-3,3,(100000, 2)) Z np.exp(-((data[:, 0] - 0)**2 / (2*1**2) (data[:, 1] - 0)**2 / (2*1**2)))scatter1 go.Scatter3d(xdata[:, 0], ydata[:, 1], …

【项目】仿muduo库One Thread One Loop式主从Reactor模型实现高并发服务器(Http板块)

【项目】仿muduo库One Thread One Loop式主从Reactor模型实现高并发服务器&#xff08;Http板块&#xff09; 一、思路图二、Util板块1、Splite板块&#xff08;分词&#xff09;&#xff08;1&#xff09;代码&#xff08;2&#xff09;测试及测试结果i、第一种测试ii、第二种…

关于discuz论坛网址优化的一些记录(伪静态)

最近网站刚上线&#xff0c;针对SEO做了些操作&#xff0c;为了方便网站网页被收录&#xff0c;特此记录下 1.开启伪静态 按照操作勾选所有项&#xff0c;然后点击查看伪静态规则 2.打开宝塔&#xff0c;找到左侧列表的网站&#xff0c;然后找到相应站点的设置。把discuz自动…

STM32的端口引脚的复用功能及重映射功能解析

目录 STM32的端口引脚的复用功能及重映射功能解析 复用功能 复用功能的初始化 重映射功能 重映射功能的初始化 复用功能和重映射的区别 部分重映射与完全重映射 补充 STM32的端口引脚的复用功能及重映射功能解析 复用功能 首先、我们可以这样去理解stm32引脚的复用功能…

SD-WAN怎样助力企业网络升级

随着企业规模的持续扩张&#xff0c;其网络建设的重要性日益凸显&#xff0c;成为业务成功的基石。尤其对于中小企业而言&#xff0c;信息化和电脑化已成为推动生产力和竞争力提升的关键所在。办公室自动化、数据库、ERP、CRM、物流供应链等关键业务应用的不断增加&#xff0c;…

css 文字左右抖动效果

<template><div class"box"><div class"shake shape">抖动特效交字11</div></div> </template><script setup></script><style scope> .shape {margin: 50px;width: 200px;height: 50px;line-heigh…

计算机存储原理.2

1.主存储器与CPU之间的连接 2.存储器芯片的输入输出信号 3.增加主存的存储字长 3.1位扩展 数据总线的利用成分是不充分的(单块只能读写一位)&#xff0c;为了解决这个问题所以引出了位扩展。 使用多块存储芯片解决这个问题。 3.2字扩展 因为存储器买的是8k*8位的&am…

Linear Secret-Sharing Scheme(LSSS) Monotone Span Program(MSP)

参考文献&#xff1a; [KW93] Karchmer M, Wigderson A. On span programs[C]//[1993] Proceedings of the Eigth Annual Structure in Complexity Theory Conference. IEEE, 1993: 102-111.[CDM00] Cramer R, Damgrd I, Maurer U. General secure multi-party computation fr…

【探索Java编程:从入门到入狱】Day2

&#x1f36c; 博主介绍&#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 hacker-routing &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【应急响应】 【Java、PHP】 【VulnHub靶场复现】【面试分析】 &#x1f389;点赞➕评论➕收…

js[黑马笔记]

js基础 基础语法 输入输出 变量 数组 常量 数据类型 类型转换 运算符 语句 数组 函数 调用方式 函数名() 匿名函数 使用: 1.函数表达式 2.立即执行函数 对象 内置对象 web API DOM document object Model元素操作 获取元素 设置元素 定时器 DOM事件基础 事件监听 事件类…

流量网关与服务网关的区别:(面试题,掌握)

流量网关&#xff1a;&#xff08;如Nignx&#xff0c;OpenResty&#xff0c;Kong&#xff09;是指提供全局性的、与后端业务应用无关的策略&#xff0c;例如 HTTPS证书认证、Web防火墙、全局流量监控&#xff0c;黑白名单等。 服务网关&#xff1a;&#xff08;如Spring Clou…