torch.compile模型编译加速

一、定义

  1. 定义
  2. 接口介绍
  3. 案例

二、实现

  1. 定义

    1. torch.compile 是加速 PyTorch 代码的最新方法! torch.compile 通过 JIT 将 PyTorch 代码编译成优化的内核,使 PyTorch 代码运行得更快,大部分过程仅需修改一行代码。
    2. torch.compile 的一个重要组件就是 TorchDynamo。TorchDynamo 负责将任意 Python 代码即时编译成 FX Graph(计算图),然后可以进一步优化。TorchDynamo 通过在运行时分析 Python 字节码并检测对 PyTorch 操作的调用来提取 FX Graph。
    3. torch.compile 的另一个重要组件 TorchInductor 会将 FX Graph 进一步编译成优化的内核。TorchDynamo 允许使用不同的后端,所以为了检查 TorchDynamo 输出的 FX Graph,可以创建一个自定义后端来输出 FX Graph 并简单地返回 Graph 未优化的前向内容。
    4. 允许自定义函数
      开始编译的时候需要耗费大量的时间,即第一次请求,时间较长。
      5. 详情见: https://pytorch.org/docs/stable/torch.compiler.html
      https://pytorch.org/get-started/pytorch-2.0/
  2. 接口介绍

modoel_compile = torch.compile(model, mode="reduce-overhead")
(默认)default: 适合加速大模型,编译速度快且无需额外存储空间
reduce-overhead:适合加速小模型,需要额外存储空间
max-autotune:编译速度非常耗时,但提供最快的加速
  1. 案例
import torch
def foo(x, y):a = torch.sin(x)b = torch.cos(x)return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
#方式二
@torch.compile
def opt_foo2(x, y):a = torch.sin(x)b = torch.cos(x)return a + b
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
方式三
class MyModule(torch.nn.Module):def __init__(self):super().__init__()self.lin = torch.nn.Linear(100, 10)def forward(self, x):return torch.nn.functional.relu(self.lin(x))
mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))

训练

import torch
import torchvision.models as modelsmodel = models.resnet18().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
compiled_model = torch.compile(model)x = torch.randn(16, 3, 224, 224).cuda()
optimizer.zero_grad()
out = compiled_model(x)
out.sum().backward()
optimizer.step()

保存:

torch.save(optimized_model.state_dict(), "foo.pt")
# both these lines of code do the same thing
torch.save(model.state_dict(), "foo.pt")

推理:

# API Not Final
exported_model = torch._dynamo.export(model, input)
torch.save(exported_model, "foo.pt")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/47192.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

利用 VAE、GAN 和 Transformer 释放生成式 AI

利用 VAE、GAN 和 Transformer 释放生成式 AI 文章目录 一、介绍1.1 学习目标1.2 定义生成式 AI 二、生成式 AI 的力量三、变分自动编码器 (VAE)3.1 定义编码器和解码器模型3.2 定义采样函数3.3 定义损失函数3.4 编译和训练模型 四、生成对抗网络 &#…

Spring Framework各种jar包官网下载2024年最新下载官方渠道。

Spring其实就是一个大家族,它包含了Spring Framework,Spring Boot等一系列技术,它其实就是由许许多多的jar包构成,我们要使用Spring的框架,就要去下载支持这个框架的jar包即可。 1.官网下载Spring Framework的jar包 官…

R语言学习笔记10-向量-矩阵-数组-数据框-列表对比

R语言学习笔记10-向量-矩阵-数组-数据框-列表对比 向量(Vector)矩阵(Matrix)数组(Array)数据框(Data Frame)列表(List)综合分析和对比 在R语言中,…

算法训练营第38天|1049. 最后一块石头的重量 II|494. 目标和|474.一和零

1049. 最后一块石头的重量 II 思路:本题思路为尽可能的将石头分成两堆。可以看成有一个容量为总和一半的背包,尽可能装满这个背包。 494. 目标和 思路:首先要把这道题转化为背包问题,这道题本质上是要将数组分成两个子集。其中一…

vue3 + antd + typeScript 封装一个高仿的ProTable(2)

前言 因为我想要一个类似ProTable高级组件的表单,但是查询之后发现没有,所以就自己写一个,这个版本会更加完善。 功能 1.封装表格request请求集中(分页、筛选、过滤),让功能使用起来更加简单 const sourceRequest = async (params:any, pagination:any, filters:any, …

java通过jwt生成Token

定义 JWT(JSON Web Token)简而言之,JWT是一个加密的字符串,JWT传输的信息经过了数字签名,因此传输的信息可以被验证和信任。一般被用来在身份提供者和服务提供者间传递被认证用户的身份信息,以便于从资源服…

React@16.x(60)Redux@4.x(9)- 实现 applyMiddleware

目录 1,applyMiddleware 原理2,实现2.1,applyMiddleware2.1.1,compose 方法2.1.2,applyMiddleware 2.2,修改 createStore 接上篇文章:Redux中间件介绍。 1,applyMiddleware 原理 R…

iOS——MRC与ARC以及自动释放池深入底层学习

MRC与ARC再回顾 在前面,我们简单学了MRC与ARC。MRC指手动内存管理,需要开发者使用retain、release等手动管理对象的引用计数,确保对象在必要时被释放。ARC指自动内存管理,由编译器自动管理对象的引用计数,开发者不需要…

基于深度学习的机器人控制

基于深度学习的机器人控制技术结合了深度学习模型和机器人操作,旨在提升机器人在复杂环境中的自适应能力和智能行为。这项技术在自动驾驶、工业自动化、医疗辅助等领域有着广泛的应用。以下是对这一领域的系统介绍: 1. 任务和目标 机器人控制的主要任务…

基于springboot和mybatis的RealWorld后端项目实战一之hello-springboot

新建Maven项目 注意archetype选择quickstart pom.xml 修改App.java App.java同级目录新增controller包 HelloController.java package org.example.controller;import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotatio…

浅析stm32启动文件

浅析stm32启动文件 文章目录 浅析stm32启动文件1.什么是启动文件?2.启动文件的命名规则3.stm32芯片的命名规则 1.什么是启动文件? 我们来看gpt给出的答案: STM32的启动文件是一个关键的汇编语言源文件,它负责在微控制器上电或复位…

【简历】惠州某二本学院:前端简历指导,秋招面试通过率为0

注:为保证用户信息安全,姓名和学校等信息已经进行同层次变更,内容部分细节也进行了部分隐藏 简历说明 这是一份25届二本同学,投递前端职位的简历,那么在校招环节二本同学主要针对的还是小公司,这个学校因为…

LVS+Nginx高可用集群---搭建高可用集群负载均衡

1.LVS简介 Lvs(Linux Virtual Server):使用集群,对于整个用户来说是透明,用户访问的时候是单个高性能的整体。道理与nginx类似 LVS网络拓扑图:是基于四层。 用户通过浏览器发送请求,然后到达LVS.Lvs根据相应算法将…

AI PC创造新商机,ISP与HPD集成单芯片方案受欢迎

今年以来,AI PC逐渐成为市场的焦点,因为AI PC给多年一成不变的PC市场带来了新的看点,也给了消费者升级的理由。今年是AI PC的元年,上半年不论是芯片厂商,还是PC厂商都在AI PC市场快速布局。AI PC相关的大模型、生态&am…

ollama + fastgpt 搭建免费本地知识库

目录 1、ollama ollama的一些操作命令: 使用的方式: 2、fastgpt 快速部署: 修改配置: config.json: docker-compose.yml: 运行fastgpt: 访问OneApi: 添加令牌和渠道: 登陆fastgpt,创建知识库和应用 3、总结: 附录: 1. 11434是ollama的端口: 2. m3e 测…

处理多维特征的输入(Multiple Dimension Input)

输入x有多个特征features,最终得到输出y的类别。 在上一节提到,左边是我们最开始了解的线性回归,右边是我们的logistics回归(返回值为一个离散的集合)。对于本节,就是在logistics回归输入x的基础上让其多一…

中伟视界:矿山智能化——AI引领创新,行车不行人检测算法实现实时预警,防范行车不行人事故发生

行车不行人检测AI分析算法通过利用人工智能和深度学习技术,对井下行人和车辆的行驶情况进行实时检测和识别。该算法在提升矿山安全管理、减少事故发生方面具有重要作用。本文将详细介绍该AI算法的识别过程、应用场景及其技术特点。 一、识别过程 行车不行人检测AI分…

LM算法与TRF算法(含有在ICP配准情境下的两种算法对应代码)

在 ICP 配准中,使用LM算法通常会遇到找到的对应点对数量不足的问题 因为使用 Levenberg-Marquardt (LM) 算法进行最小二乘优化时,残差的数量小于变量的数量。 实际应用: ICP配准过程:针对两个三维点云数据,两个点云上均有相互对应的3D关键点。我需要在每个点云上的每个关…

3 万字 25 道 Nginx经典面试题总结

🍅 作者简介:哪吒,CSDN2021博客之星亚军🏆、新星计划导师✌、博客专家💪 🍅 哪吒多年工作总结:Java学习路线总结,搬砖工逆袭Java架构师 🍅 技术交流:定期更新…

Hadoop安装报错

报错:ERROR 2023-03-09 21:33:00,178 NetUtil.py:97 - SSLError: Failed to connect. Please check openssl library versions. 解决方案: 在安装失败得客户端执行 编辑 /etc/python/cert-verification.cfg 配置文件,将 [https] 节的 verify 项 设为禁用…