tvm.frontend.from_pytorch详细介绍(1)

文章目录

  • 一、pytorch前端整体转化流程(部分)
    • 1.脚本化的pytorch模型
    • 2.内联优化(_run_jit_passes)
      • 2.1、内联优化
      • 2.2 什么是内联函数
    • 3.graph中的所有op(get_all_op_names)
    • 3.1 各个变量的值
      • 1 .graph
      • 2 .nodes
      • 3 .p nodes
      • 4、返回结果
  • 二、from_pytorch完整代码



一、pytorch前端整体转化流程(部分)

1.脚本化的pytorch模型

  脚本化 PyTorch 模型:首先,使用 torch.jit.script 函数对PyTorch 模型进行脚本化。这将把模型转换为 TorchScript 表示,可以独立于 Python 运行时加载和执行。

# 脚本化 PyTorch 模型
scripted_model = torch.jit.script(你的模型)
# 将脚本化的 PyTorch 模型转换为 Relay
relay_model, params = relay.frontend.from_pytorch(scripted_model, input_shapes=[(input_shape,)])

2.内联优化(_run_jit_passes)

在 JIT 过程中,为了处理 prim::CallMethod 的调用,需要执行内联传递操作,将方法调用的函数体内联展开,以便进行进一步的优化或处理。

2.1、内联优化

   这段代码的目的是在PyTorch的JIT编译过程中应用一些特定的转换和优化,以提高生成的图的性能和效率。

def _run_jit_passes(graph, enable_lower_all_tuples=True):"""The inline pass is necessary to unwrap prim::CallMethod"""# pylint: disable=c-extension-no-memberimport torchif is_version_greater_than("1.5.1"):# This is required for torchvision detection models from 1.6 above# It is the same as _jit_pass_inline, except that it has some special# case behaviors for some ops such as aten::__interpolate()torch._C._jit_pass_onnx_function_substitution(graph)else:torch._C._jit_pass_inline(graph)if enable_lower_all_tuples:

   JIT passes用于对PyTorch图进行转换和优化,以便在部署或执行期间提高性能。如果PyTorch的版本大于1.5.1,它会运行一系列的passes,如果PyTorch的版本低于等于1.5.1,它将只运行一个叫做_jit_pass_inline的pass,该pass用于内联函数调用。

2.2 什么是内联函数

  Inplace操作是指在原地(in-place)修改数据或对象,而不创建新的副本。它通常用于优化内存使用和减少计算开销。
  在PyTorch中,有一些操作允许以inplace的方式进行,即直接修改操作作用的张量,而不创建新的张量对象。这些操作通常以_结尾,例如add_、mul_、div_等。通过使用inplace操作,可以减少内存分配和数据拷贝的开销,提高代码的效率。
  需要注意的是,使用inplace操作时需要小心,因为它会直接修改原始数据,可能导致意外的副作用或不可逆的修改。因此,在使用inplace操作时,应确保了解其行为,并在适当的情况下使用它们。

3.graph中的所有op(get_all_op_names)

  这段代码的目的是获取输入图中所有操作符的名称,并以集合的形式返回。它通过遍历图中的节点和子块,收集并去重所有操作符的名称。调试断点的设置可能是为了在执行过程中检查和调试代码。

def get_all_op_names(graph):""" Return all operator names in the input graph """nodes = list(graph.nodes())prim_with_blocks = ["prim::If", "prim::Loop"]for prim in prim_with_blocks:prim_nodes = graph.findAllNodes(prim, recurse=True)for prim_node in prim_nodes:for block in prim_node.blocks():nodes += block.nodes()return set(node.kind() for node in nodes)

函数的主要逻辑如下:

  1. 获取图中的所有节点列表。
  2. 定义了一个名为prim_with_blocks的列表,其中包含了"prim::If"和"prim::Loop"这两个具有子块(blocks)的操作符名称。
  3. 对于prim_with_blocks列表中的每个操作符名称,遍历图中所有的该操作符节点(包括子块中的节点)。
  4. 将每个节点的子块中的节点也加入到nodes列表中。
  5. 最后,返回一个集合(set),其中包含了nodes列表中每个节点的操作符名称。

3.1 各个变量的值

1 .graph

def get_all_op_names(graph):
(Pdb) p graph
graph(%self.1 : __torch__.TempOpModel,%input : Float(1:48, 3:16, 16:1)):%2 : __torch__.torch.nn.modules.conv.ConvTranspose1d = prim::GetAttr[name="convtrans"](%self.1)%4 : Tensor = prim::GetAttr[name="weight"](%2)%5 : None = prim::Constant(), scope: __module.convtrans%6 : int = prim::Constant[value=2](), scope: __module.convtrans %7 : int[] = prim::ListConstruct(%6), scope: __module.convtrans%8 : int = prim::Constant[value=1](), scope: __module.convtrans %9 : int[] = prim::ListConstruct(%8), scope: __module.convtrans%10 : int = prim::Constant[value=1](), scope: __module.convtrans %11 : int[] = prim::ListConstruct(%10), scope: __module.convtrans%12 : bool = prim::Constant[value=1](), scope: __module.convtrans %13 : int = prim::Constant[value=0](), scope: __module.convtrans %14 : int[] = prim::ListConstruct(%13), scope: __module.convtrans%15 : int = prim::Constant[value=1](), scope: __module.convtrans %16 : bool = prim::Constant[value=0](), scope: __module.convtrans%17 : bool = prim::Constant[value=0](), scope: __module.convtrans%18 : bool = prim::Constant[value=1](), scope: __module.convtrans %19 : Float(1:198, 6:33, 33:1) = aten::_convolution(%input, %4, %5, %7, %9, %11, %12, %14, %15, %16, %17, %18), scope: __module.convtrans return (%19)

2 .nodes

nodes = list(graph.nodes())
(Pdb) p graph.nodes
<bound method PyCapsule.nodes of graph(%self.1 : __torch__.TempOpModel,%input : Float(1:48, 3:16, 16:1)):%2 : __torch__.torch.nn.modules.conv.ConvTranspose1d = prim::GetAttr[name="convtrans"](%self.1)%4 : Tensor = prim::GetAttr[name="weight"](%2)%5 : None = prim::Constant(), scope: __module.convtrans%6 : int = prim::Constant[value=2](), scope: __module.convtrans %7 : int[] = prim::ListConstruct(%6), scope: __module.convtrans%8 : int = prim::Constant[value=1](), scope: __module.convtrans %9 : int[] = prim::ListConstruct(%8), scope: __module.convtrans%10 : int = prim::Constant[value=1](), scope: __module.convtrans %11 : int[] = prim::ListConstruct(%10), scope: __module.convtrans%12 : bool = prim::Constant[value=1](), scope: __module.convtrans %13 : int = prim::Constant[value=0](), scope: __module.convtrans%14 : int[] = prim::ListConstruct(%13), scope: __module.convtrans%15 : int = prim::Constant[value=1](), scope: __module.convtrans %16 : bool = prim::Constant[value=0](), scope: __module.convtrans %17 : bool = prim::Constant[value=0](), scope: __module.convtrans%18 : bool = prim::Constant[value=1](), scope: __module.convtrans %19 : Float(1:198, 6:33, 33:1) = aten::_convolution(%input, %4, %5, %7, %9, %11, %12, %14, %15, %16, %17, %18), scope: __module.convtrans return (%19)

3 .p nodes

nodes = list(graph.nodes())
(Pdb) p nodes
[%2 : __torch__.torch.nn.modules.conv.ConvTranspose1d = prim::GetAttr[name="convtrans"](%self.1)
, %4 : Tensor = prim::GetAttr[name="weight"](%2)
, %5 : None = prim::Constant(), scope: __module.convtrans
, %6 : int = prim::Constant[value=2](), scope: __module.convtrans 
, %7 : int[] = prim::ListConstruct(%6), scope: __module.convtrans
, %8 : int = prim::Constant[value=1](), scope: __module.convtrans 
, %9 : int[] = prim::ListConstruct(%8), scope: __module.convtrans
, %10 : int = prim::Constant[value=1](), scope: __module.convtrans 
, %11 : int[] = prim::ListConstruct(%10), scope: __module.convtrans
, %12 : bool = prim::Constant[value=1](), scope: __module.convtrans 
, %13 : int = prim::Constant[value=0](), scope: __module.convtrans
, %14 : int[] = prim::ListConstruct(%13), scope: __module.convtrans
, %15 : int = prim::Constant[value=1](), scope: __module.convtrans 
, %16 : bool = prim::Constant[value=0](), scope: __module.convtrans 
, %17 : bool = prim::Constant[value=0](), scope: __module.convtrans 
, %18 : bool = prim::Constant[value=1](), scope: __module.convtrans 
, %19 : Float(1:198, 6:33, 33:1) = aten::_convolution(%input, %4, %5, %7, %9, %11, %12, %14, %15, %16, %17, %18), scope: __module.convtrans 
]

4、返回结果

  set(node.kind() for node in nodes)是一个表达式,用于创建一个集合(set),其中包含了nodes列表中每个节点的类型(kind)。

  在PyTorch中,node.kind()是用于获取节点类型的方法。每个节点在计算图中都有一个类型,表示该节点所执行的操作或功能。

  node.kind()返回一个表示节点类型的字符串,通常是以prim::开头,后面跟着具体的操作名称或标识符。例如,prim::Add、prim::Mul等。

(Pdb) p op_names
{'aten::_convolution', 'prim::Constant', 'prim::ListConstruct', 'prim::GetAttr'}

二、from_pytorch完整代码


```python
def from_pytorch(script_module,input_infos,custom_convert_map=None,default_dtype="float32",use_parser_friendly_name=False,keep_quantized_weight=False,export_renamed_c_graph_path=None,preserve_pytorch_scopes=False,
):"""Load PyTorch model in the form of a scripted PyTorch model and convert into relay.The companion parameters will be handled automatically.Parameters----------script_module : TopLevelTracedModule objectTorchScripted PyTorch graphNote: We currently only support traces (ie: torch.jit.trace(model, input))input_infos : List of tuplesCan be (input name, input shape) or (input name, (input shape, input types))Graph level input shape and type listThe same input names need to be used for deployment, so choose easy toremember names (such as: input0, input1)e.g.[('input0', (1, 2)), ('input1', (3, 4))]or[('input0', ((1, 2), 'int')), ('input1', ((3, 4), 'float'))]custom_convert_map : Dictionary of str to Relay opA custom op conversion map in the same format as _convert_map abovedefault_type : strThe default dtype to use when type information is not provided by PyTorch.use_parser_friendly_name : boolWhen True, replace '.' with `_' in a original parameter name.The Relay text parser treats a variable name followed by a period as a tuple element access,so a variable name like "dense.weight" cannot be parsed correctly.Use this option when you want to run the AnnotateSpans pass on the imported module.keep_quantized_weight : boolReturn quantized weights and bias, rather than float ones. PyTorch stores quantized weightsin a custom format, so we cannot directly access 8 bit weights as Numpy arrays. We usea PyTorch function to unpack quantized weights into float32 arrays and quantizationparameters. By default, we return float32 weights and rely on the QNN lowering and theRelay constant folding pass to quantize weights at compile time. In BYOC use cases, however,we cannot apply the constant folding pass on a QNN graph. If keep_quantized_weight is True,we quantize weights in the frontend using a function that is equivalent toqnn.op.quantize(...) operating on Numpy arrays.export_renamed_c_graph_path : str, optionalExport the renamed torch._C.Graph to the path.During the conversion, variable names in torch._C.Graph will be assigned based on their optypes. The exported text file can be the reference to spans.preserve_pytorch_scopes : boolWhen naming the nodes in the Relay graph, use the "scope name" from the Pytorch model.If false, a default namer is used that does not preserve the Pytorch scope names.Returns-------mod : tvm.IRModuleThe module that optimizations will be performed on.params : dict of str to tvm.runtime.NDArrayDict of converted parameters stored in tvm.runtime.ndarray format"""import torchmod = tvm.IRModule()prelude = Prelude(mod)enable_lower_all_tuples = Trueconverter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name, preserve_pytorch_scopes)graph = script_module.graph.copy()# Check if lower_all_tuples pass can be enabledgraph_inputs = list(graph.inputs())for inp in graph_inputs:if inp.type().kind() == "TupleType" or inp.type().kind() == "ListType":enable_lower_all_tuples = Falsebreak_run_jit_passes(graph, enable_lower_all_tuples)_redirect_inplace_output(graph)if custom_convert_map:converter.update_convert_map(custom_convert_map)op_names = get_all_op_names(graph)converter.report_missing_conversion(op_names)is_module = isinstance(script_module, torch.jit.ScriptModule)params = script_module.state_dict() if is_module else {}outputs = _get_relay_input_vars(graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module)if use_parser_friendly_name:new_names = [key.replace(".", "_") for key in params.keys()]params = dict(zip(new_names, params.values()))# rename _C.Graph here for constructing meaningful source name of graph nodes# by doing so, we could Use source_map as the reference to rename model parameterssource_map = _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes)param_vars, tensors, packed_param_map, param_debug_name_map = convert_params(graph, params, source_map, use_parser_friendly_name)tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}outputs.update(param_vars)# For quantized modelsquantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])if len(quantized_ops.intersection(set(op_names))) > 0:weight_quant_params = qnn_torch.get_weight_quant_params(script_module, packed_param_map.values())qnn_torch.inline_input_quant_params_for_fx(graph, tensors, param_debug_name_map)input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)qnn_torch.add_quant_params_to_outputs(outputs,packed_param_map,weight_quant_params,input_scales_for_bias,keep_quantized_weight,)qnn_torch.add_quant_params(tvm_params, weight_quant_params)converter.update_convert_map(qnn_torch.convert_map)operator_nodes = _get_operator_nodes(graph.nodes(),converter.source_map,converter.op_type_dict,use_parser_friendly_name,preserve_pytorch_scopes,)ret_name = _get_input_names(graph.return_node())outputs = converter.convert_operators(operator_nodes, outputs, ret_name)# ListConstruct kept original python list. Convert to tuple.outputs = [_expr.Tuple(output) if isinstance(output, list) else output for output in outputs]if len(outputs) > 1:ret = _expr.Tuple(outputs)else:ret = outputs[0]# Separate data inputs and parameters to make sure data inputs come first.func_args = []data_inputs = []for arg in _analysis.free_vars(ret):if arg.name_hint not in tvm_params.keys():data_inputs.append(arg)else:func_args.append(arg)# Ensures the order of data_input is the same as the order of inputs specified in input_info.order_input_infos = {input_info[0]: len(input_infos) - idx for idx, input_info in enumerate(input_infos)}data_inputs = sorted(data_inputs,key=lambda data_input: order_input_infos[data_input.name_hint]if data_input.name_hint in order_input_infoselse -1,reverse=True,)func_args = data_inputs + func_argsmod["main"] = tvm.relay.Function(func_args, ret)if export_renamed_c_graph_path:export_c_graph(export_renamed_c_graph_path, graph)return transform.RemoveUnusedFunctions()(mod), tvm_params

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/10900.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

国内智能搜索工具实战教程

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

3、Qt--配置文件的使用

开发平台&#xff1a;Win10 64位 开发环境&#xff1a;Qt Creator 13.0.0 构建环境&#xff1a;Qt 5.15.2 MSVC2019 64位 一、需求及方案 实际开发过程中&#xff0c;我们需要根据本地的配置文件&#xff0c;去配置我们的程序&#xff0c;比如数据库地址、网络地址等信息&…

分享10类正规的网上赚钱平台,让你摆脱单一收入

在这个互联网飞速发展的时代&#xff0c;你是否还在为单一的收入来源而焦虑&#xff1f;别担心&#xff0c;今天带你解锁10种网上赚钱的新姿势&#xff0c;让你的收入不再单一&#xff0c;甚至可能翻倍&#xff01; 1. 文库类&#xff1a;知识的变现 你知道吗&#xff1f;你的…

k8s 数据流向 与 核心概念详细介绍

目录 一 k8s 数据流向 1&#xff0c;超级详细版 2&#xff0c;核心主键及含义 3&#xff0c;K8S 创建Pod 流程 4&#xff0c;用户访问流程 二 Kubernetes 核心概念 1&#xff0c;Pod 1.1 Pod 是什么 1.2 pod 与容器的关系 1.3 pod中容器 的通信 2&#xff0c; …

imx91的uboot编译

一、准备操作 下载半导体厂家的uboot源码 如这里我要下载的是imx91的恩智浦linux芯片bootloader 进入半导体厂家官网 下载源码&#xff0c;略 更新linux源&#xff0c;这里我是替换成清华源 vi /etc/apt/sources.list deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ fo…

【江科大STM32学习笔记】新建工程

1.建立工程文件夹&#xff0c;Keil中新建工程&#xff0c;选择型号 2.工程文件夹里建立Start、Library、User等文件夹&#xff0c;复制固件库里面的文件到工程文件夹 为添加工程文件准备&#xff0c;建文件夹是因为文件比较多需要分类管理&#xff0c;需要用到的文件一定要复…

Web UI自动化测试--PO模式

没有PO实现的测试用例的问题: 重用性低:登录功能重复可维护性差:数据和代码混合可读性差:元素定位方法杂乱(id、xpath、css混杂)可读性差:不易识别操作的含义(特别是css和xpath语法)可维护性差:如果某个元素的属性改了,你要更改多次PO(Page Object Model)页面对象模型…

完全背包问题(c++)

完全背包问题 当前有 N 种物品&#xff0c;第 i 种物品的体积是 ci​&#xff0c;价值是 wi​。 每种物品的数量都是无限的&#xff0c;可以选择任意数量放入背包。 现有容量为 V 的背包&#xff0c;请你放入若干物品&#xff0c;使总体积不超过 V&#xff0c;并且总价值尽可…

el-table被点击行添加背景颜色

在 Element UI 的 el-table 组件中&#xff0c;你可以通过使用行样式 row-style 属性来为被点击的行添加颜色。首先&#xff0c;你需要在数据中定义一个对象来存储被点击行的索引&#xff0c;然后在 row-style 函数中根据这个索引来返回不同的样式。 以下是一个示例&#xff1…

YOLOv8+CLIP实现图文特征匹配

本文通过结合YOLOv8s的高效物体检测能力与CLIP的先进图像-文本匹配技术&#xff0c;展示了深度学习在处理和分析复杂多模态数据中的潜力。这种技术的应用不仅限于学术研究&#xff0c;还能广泛应用于工业、商业和日常技术产品中&#xff0c;以实现更智能的人机交互和信息处理。…

新年首站 | 宝兰德教育行业信创新动力发展研讨会顺利召开

近日&#xff0c;宝兰德携手慧点数码、安超云共同举办了教育行业信创新动力发展研讨会。会议邀请了中国人民公安大学、中国戏曲学院、北京航空航天大学、北京理工大学、华北电力大学、中国矿业大学、北京服装学院、北京城市学院等数十所高校信息中心负责人、专家出席了本次会议…

LeetCode 题目 120:三角形最小路径和

作者介绍&#xff1a;10年大厂数据\经营分析经验&#xff0c;现任字节跳动数据部门负责人。 会一些的技术&#xff1a;数据分析、算法、SQL、大数据相关、python&#xff0c;欢迎探讨交流 欢迎加入社区&#xff1a;码上找工作 作者专栏每日更新&#xff1a; LeetCode解锁1000题…

WEB后端复习——javabean与会话cookie、session

JavaBean 是一种符合特定命名约定的 Java 类&#xff0c;它通常用于封装数据。 JavaBean 的主要特点是&#xff1a; 1. 无参构造器&#xff1a;JavaBean 必须有一个公共的&#xff08;public&#xff09;无参构造方法&#xff0c;以便于反射时能够创建对象实例。 2. 属性&…

Android的视图显示和管理机制:layout view window WindowManager Canvas Surface

在Android系统中&#xff0c;Layout view window WindowManager Canvas Surface SurfaceFlinger这些组件协同工作&#xff0c;以实现图形的绘制和显示。需要搞明白这些组件是什么时候创建的以及他们之间的结构关系。 从上到下的层级关系&#xff1a;用户在View上进行操作&…

嵌入式交叉编译:ffmpeg及相关库

目前只编译了部分。其他库需要时再说。 fdk-aac 嵌入式交叉编译&#xff1a;linux fdk-aac-CSDN博客 libvpx 这个最麻烦&#xff0c;还是编译通过啦。 嵌入式交叉编译&#xff1a;libvpx&#xff08;全网首发&#xff09;-CSDN博客 x265 嵌入式交叉编译&#xff1a;x265…

考研踩坑经验分享

文章目录 写在前面自身情况简介自身学习路线优点坑点 学习路线建议1、2和3月份3、4和5月份6、7和8月份9、10月份11、12月份 一些私货建议结尾 写在前面 考研是一件非常有盼头的事&#xff0c;但绝对不是一件容易的事。 如果你不能做好来年三月份出成绩时&#xff0c;坦然接受…

Ubuntu 下使用 Scons 交叉编译嘉楠堪智 CanMV K230 大小核 Coremark 程序

在 Ubuntu 下使用 SCons 进行交叉编译嘉楠堪智 CanMV K230 大小核&#xff08;不同的玄铁 C908 核心&#xff09;的 C 程序&#xff0c;以 Coremark 程序为例&#xff0c;顺便测试一下大小核和编译器的性能。 2024年3月14日&#xff0c;嘉楠科技宣布推出了全球首款支持 RISC-V…

# 从浅入深 学习 SpringCloud 微服务架构(十七)--Spring Cloud config(1)

从浅入深 学习 SpringCloud 微服务架构&#xff08;十七&#xff09;–Spring Cloud config&#xff08;1&#xff09; 一、配置中心的 概述 1、配置中心概述 对于传统的单体应用而言&#xff0c;常使用配置文件来管理所有配置&#xff0c;比如 SpringBoot 的 application.y…

消费金融平台公司如何做大做强自营产品

本文来自于2019年的某次内部分享沟通会&#xff0c;部分敏感内容已做删减。

油泼辣子在食品类别可以申请成商标不!

前阵韩国人在美国申请“chili crunch”油泼辣子作为商标&#xff0c;还准备禁止华人餐馆使用投诉侵权并索赔&#xff0c;普推知产老杨在USPTO上面检索发现&#xff0c;这个人申请的主要是30类方便食品的调味品&#xff0c;商标分类是全球通用的。 商标名称不能申请本类所属的通…