PyTorch之Torch Script的简单使用

一、参考资料

TorchScript 简介
Torch Script
Loading a TorchScript Model in C++
TorchScript 解读(一):初识 TorchScript
libtorch教程(一)开发环境搭建:VS+libtorch和Qt+libtorch

二、Torch Script模型格式

1. Torch Script简介

Torch Script 是一种序列化和优化 PyTorch 模型的格式,在优化过程中,一个 torch.nn.Module 模型会被转换成 Torch Script 的 torch.jit.ScriptModule 模型。通常,TorchScript 被当成一种中间表示使用。

Torch Script 的主要用途是进行模型部署,需要记录生成一个便于推理优化的 IR,对计算图的编辑通常都是面向性能提升等等,不会给模型本身添加新的功能。

模型格式支持语言适用场景
PyTorch modelPython模型训练
Torch ScriptC++模型推理,模型部署

2. 生成Torch Script模型

如何将PyTorch model格式转换为Torch Script,有两种方式: torch.jit.tracetorch.jit.script

As its name suggests, the primary interface to PyTorch is the Python programming language. While Python is a suitable and preferred language for many scenarios requiring dynamism and ease of iteration, there are equally many situations where precisely these properties of Python are unfavorable. One environment in which the latter often applies is production – the land of low latencies and strict deployment requirements. For production scenarios, C++ is very often the language of choice, even if only to bind it into another language like Java, Rust or Go. The following paragraphs will outline the path PyTorch provides to go from an existing Python model to a serialized representation that can be loaded and executed purely from C++, with no dependency on Python.

A PyTorch model’s journey from Python to C++ is enabled by Torch Script, a representation of a PyTorch model that can be understood, compiled and serialized by the Torch Script compiler. If you are starting out from an existing PyTorch model written in the vanilla “eager” API, you must first convert your model to Torch Script.

There exist two ways of converting a PyTorch model to Torch Script. The first is known as tracing, a mechanism in which the structure of the model is captured by evaluating it once using example inputs, and recording the flow of those inputs through the model. This is suitable for models that make limited use of control flow. The second approach is to add explicit annotations to your model that inform the Torch Script compiler that it may directly parse and compile your model code, subject to the constraints imposed by the Torch Script language.

2.1 trace跟踪模式

Converting to Torch Script via Tracing
How to convert your PyTorch model to TorchScript

功能:将不带控制流的模型转换为 Torch Script,并生成一个 ScriptModule 对象。

函数原型:torch.jit.trace

所谓 trace 指的是进行一次模型推理,在推理的过程中记录所有经过的计算,将这些记录整合成计算图,即模型的静态图。

trace跟踪模式的缺点是:无法识别出模型中的控制流(如循环)

To convert a PyTorch model to Torch Script via tracing, you must pass an instance of your model along with an example input to the torch.jit.trace function. This will produce a torch.jit.ScriptModule object with the trace of your model evaluation embedded in the module’s forward method.

import torch
import torchvision# An instance of your model.
model = torchvision.models.resnet18(pretrained=True)# Switch the model to eval model
model.eval()# An example input you would normally provide to your model's forward() method.
dummy_input = torch.rand(1, 3, 224, 224)# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, dummy_input)# Save the TorchScript model
traced_script_module.save("traced_resnet18_model.pt")output = traced_script_module(torch.ones(1, 3, 224, 224))# IR中间表示
print(traced_script_module.graph)
print(traced_script_module.code)# 调用traced_cell会产生与 Python 模块相同的结果
print(model(x, h))
print(traced_script_module(x, h))

2.2 script记录模式(带控制流)

功能:将带控制流的模型转换为 Torch Script,并生成一个 ScriptModule 对象。

函数原型:torch.jit.script

script记录模式,通过解析模型来正确记录所有的控制流。script记录模式直接解析网络定义的 python 代码,生成抽象语法树 AST。

Because the forward method of this module uses control flow that is dependent on the input, it is not suitable for tracing. Instead, we can convert it to a ScriptModule. In order to convert the module to the ScriptModule, one needs to compile the module with torch.jit.script as follows.

class MyModule(torch.nn.Module):def __init__(self, N, M):super(MyModule, self).__init__()self.weight = torch.nn.Parameter(torch.rand(N, M))def forward(self, input):if input.sum() > 0:output = self.weight.mv(input)else:output = self.weight + inputreturn outputmy_module = MyModule(10,20)
sm = torch.jit.script(my_module)# Save the ScriptModule
sm.save("my_module_model.pt")

2.3 trace格式转换

import torch
import torchvision
from unet import UNetmodel = UNet(3, 2) 
model.load_state_dict(torch.load("best_weights.pth"))
model.eval()example = torch.rand(1, 3, 320, 480) 
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

三、Loading a TorchScript Model in C++

1. C++加载 ScriptModule

创建一个简单的程序,目录结构如下所示:

example-app/CMakeLists.txtexample-app.cpp

1.1 example-app.cpp

#include <torch/script.h> // One-stop header.#include <iostream>
#include <memory>int main(int argc, const char* argv[]) {if (argc != 2) {std::cerr << "usage: example-app <path-to-exported-script-module>\n";return -1;}torch::jit::script::Module module;try {// Deserialize the ScriptModule from a file using torch::jit::load().module = torch::jit::load(argv[1]);}catch (const c10::Error& e) {std::cerr << "error loading the model\n";return -1;}std::cout << "ok\n";
}

The <torch/script.h> header encompasses all relevant includes from the LibTorch library necessary to run the example. Our application accepts the file path to a serialized PyTorch ScriptModule as its only command line argument and then proceeds to deserialize the module using the torch::jit::load() function, which takes this file path as input. In return we receive a torch::jit::script::Module object.

1.2 CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)find_package(Torch REQUIRED)add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)

1.3 编译执行

mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release
make

输出结果

root@4b5a67132e81:/example-app# mkdir build
root@4b5a67132e81:/example-app# cd build
root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Configuring done
-- Generating done
-- Build files have been written to: /example-app/build
root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app

1.4 执行结果

If we supply the path to the traced ResNet18 model traced_resnet_model.pt we created earlier to the resulting example-app binary, we should be rewarded with a friendly “ok”. Please note, if try to run this example with my_module_model.pt you will get an error saying that your input is of an incompatible shape. my_module_model.pt expects 1D instead of 4D.

root@4b5a67132e81:/example-app/build# ./example-app <path_to_model>/traced_resnet_model.pt
ok

2. C++推理ScriptModule

main()函数中添加模型推理的代码:

// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

编译执行

root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
root@4b5a67132e81:/example-app/build# ./example-app traced_resnet_model.pt
-0.2698 -0.0381  0.4023 -0.3010 -0.0448
[ Variable[CPUFloatType]{1,5} ]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/795524.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关于 elf loader 的编写

可以使用如下命令观看 elf 文件的信息 readelf -a build/ramdisk.img | vim -在编写 elf loader 的时候&#xff0c;实际上只有下图这一部分 “Program Headers” 是有用的 凡是类型为 “LOAD” 的就是需要加载进内存的部分 所以&#xff0c;只要把这些部分加载进内存里&…

数据库不用mmap

你确定你想用 MMAP 实现数据库么&#xff1f;_哔哩哔哩_bilibili MMAP 的随机读与顺序读的性能表现不好&#xff0c;以及对于写主要是不可控的刷入时机以及代码冗余&#xff0c;所以 MMAP 不适合在数据库中使用。 mmap是posix系统调用&#xff0c;它提供由操作系统管理内存映…

acwing算法提高之图论--最小生成树的典型应用

目录 1 介绍2 训练 1 介绍 本专题用来记录使用prim算法或kruskal算法求解的题目。 2 训练 题目1&#xff1a;1140最短网络 C代码如下&#xff0c; #include <iostream> #include <cstring>using namespace std;const int N 110, INF 0x3f3f3f3f; int g[N][N…

(C)1008 数组元素循环右移问题

1008 数组元素循环右移问题&#xff1a; 问题描述 输入样例&#xff1a; 6 2 1 2 3 4 5 6 输出样例&#xff1a; 5 6 1 2 3 4 解决方案&#xff1a; #include<stdio.h> #include<string.h> #include<math.h> int main(){int n,k,flag,y,x,final;int a[10000…

Flutter Boost 3

社区的 issue 没有收敛的趋势。 设计过于复杂&#xff0c;概念太多。这让一个新手看 FlutterBoost 的代码很吃力。 这些问题促使我们重新梳理设计&#xff0c;为了彻底解决这些顽固的问题&#xff0c;我们做一次大升级&#xff0c;我们把这次升级命名为 FlutterBoost 3.0&am…

合理早餐选择,稳定糖尿病血糖。

对于糖尿病患者来说&#xff0c;饮食管理是治疗的重要一环。不合理的早餐选择会导致血糖的波动。很多糖尿病朋友按时吃药&#xff0c;但是血糖就是稳定不住&#xff0c;之前看过一个例子&#xff0c;北京崇文门医院朱学敏主任接诊过一个患者&#xff0c;那个患者按时吃药&#…

LaTeX 空格与换行

任意多个空格与一个空格的功能相同。只有字符后面的空格是有效的&#xff0c;每行最前面的空格被忽略。单个换行被视作一个空格&#xff0c;连续两个换行表示分段。~被称作一种不可打断的空格&#xff0c;排版系统不会在这种空格之间换行。西文的逗号、句号和分号等标点后面应该…

Java | Leetcode Java题解之第8题字符串转换整数atoi

题目&#xff1a; 题解&#xff1a; class Solution {public int myAtoi(String str) {Automaton automaton new Automaton();int length str.length();for (int i 0; i < length; i) {automaton.get(str.charAt(i));}return (int) (automaton.sign * automaton.ans);} …

Android Studio学习10——资源res的使用

一、String,StringArray的使用 一次修改&#xff0c;多出生效 String StringArray 二、color的使用 颜色代码对应表 和上面的相似用法 三、Dimen(尺寸)的使用 用的少&#xff0c;一般直接写尺寸 四、如何写一个drawable作为背景 五、如何写一个可以改变的drawable(按钮按下…

IP地址:是给主机配置的,还是给网卡配置的?

在探索网络的奥秘时&#xff0c;我们经常会遇到一个看似简单但又复杂的问题&#xff1a;IP地址到底是配置在主机上&#xff0c;还是配置在网卡上&#xff1f;为什么我们通常说的是“主机IP地址”呢&#xff1f;让我们一起深入探讨。 1. 网卡与IP地址 &#x1f5a5;️&#x1f…

利用OllyDbg对程序内容进行修改实验

1.双击运行exe文件&#xff0c;出现如下弹窗 2.用ollydbg工具打开该执行文件&#xff0c;页面显示如下 3.在注释窗口执行以下操作 4.双击运行exe文件时&#xff0c;显示”Copied!”所以接下来在注释里找到这个字样&#xff0c;如下&#xff0c;我们需要把对话框中的内容修改为“…

SQL语句学习+牛客基础39SQL

什么是SQL&#xff1f; SQL (Structured Query Language:结构化查询语言) 是用于管理关系数据库管理系统&#xff08;RDBMS&#xff09;。 SQL 的范围包括数据插入、查询、更新和删除&#xff0c;数据库模式创建和修改&#xff0c;以及数据访问控制。 SQL语法 数据库表 一个…

ChatGPT(3.5版本)开放无需注册:算力背后的数据之战悄然打响

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢&#xff0c;在这里我会分享我的知识和经验。&am…

90天玩转Python-02-基础知识篇:初识Python与PyCharm

90天玩转Python-02-基础知识篇:初识Python与PyCharm 一、引言 二、Python语言探秘 三、如何学习Python 四、PyCharm IDE的安装与使用 安装步骤: 五、Python编程实例 实例1:列表操作

如何在 Mac 上恢复已删除的数据

如果您丢失了 Mac 上的数据&#xff0c;请不要绝望。恢复数据比您想象的要容易&#xff0c;并且有很多方法可以尝试。 在 Mac 上遭受数据丢失是每个人都认为永远不会发生在他们身上的事情之一......直到它发生。不过&#xff0c;请不要担心&#xff0c;因为您可以通过多种方法…

Java数据结构队列

队列(Queue) 概念 队列的使用 注意&#xff1a;Queue是个接口&#xff0c;在实例化时必须实例化LinkedList的对象&#xff0c;因为LinkedList实现了Queue接口。 import java.util.LinkedList; import java.util.Queue;public class Test {public static void main(String[]…

【事务注解✈️✈️】@Transactional注解在不同参数配置下的功能实现

目录 前言 使用场景 1.单个方法层面 2.类级别使用 3.指定异常回滚 4.跨方法调用事务管理 5.只读事务 ​ 6.设置超时时间&#xff0c;超时则自动回滚 7.隔离级别设置 章末 前言 小伙伴们大家好&#xff0c;ACID&#xff08;原子性&#xff0c;一致性&#xff0c;隔离…

用递归方法求n!(n的值最好小于13,int型数据分配4个字节,能表示的最大数为2147483647)

#include <stdio.h> // 主函数&#xff1a;计算并输出给定整数的阶乘 int main(){ // 声明阶乘函数 int fac(int n); int n, y; // 输入一个整型数字 printf("请输入一个整型数字:"); scanf("%d", &n); …

网络安全 | 什么是单点登录SSO?

关注WX&#xff1a;CodingTechWork SSO-概念 单点登录 (SSO) 是一种身份认证方法&#xff0c;用户一次可通过一组登录凭证登入会话&#xff0c;在该次会话期间无需再次登录&#xff0c;即可安全访问多个相关的应用和服务。SSO 通常用于管理一些环境中的身份验证&#xff0c;包…

Arcade 作用力

这个程序展示了简单的变换反馈的应用。变换反馈类似于渲染&#xff0c;但是输出的是一个缓冲区而不是帧缓冲区/屏幕。 这个例子展示了一个常见的ping-pong技术&#xff0c;即在两个缓冲区之间对具有位置和速度的点进行变换&#xff0c;这样我们就始终在前一状态上工作。 初始…