AI模型推理(5)——实战篇(持续更新)

前言

本文主要通过实战的方式,记录各种模型推理的方法

模型训练

首先我们先使用Pytorch训练一个最简单的十分类神经网络,如下:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 加载训练数据
training_data = datasets.FashionMNIST(root=r"./Datasets/",train=True,download=True,transform=ToTensor(),
)# 加载验证数据
test_data = datasets.FashionMNIST(root=r"./Datasets/",train=False,download=True,transform=ToTensor(),
)# Create data loaders.
batch_size = 16
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")# 定义神经网络模型
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device)
# print(model)# 定义损失函数,优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)# 定义训练过程
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")# 定义验证方法(在验证数据集中进行验证)
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")epochs = 100
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

模型推理

Pytorch模型

Pytorch官方入门文档所给出的模型持久化及加载方法,使用torch.save()方法对模型进行持久化,所保存的模型为动态图模型。如下:

# (需承接上面的训练代码,才可正常运行)
# 保存模型
model_path = "./model"
if not os.path.isdir(model_path):os.makedirs(model_path)torch.save(model.state_dict(), os.path.join(model_path, 'model.pth'))
print("Saved PyTorch Model State to model.pth")# 加载模型进行推理
model = NeuralNetwork()
model.load_state_dict(torch.load("./model/model.pth"))classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')

TorchScript

TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法,是一种静态图模型。TorchScript模型可以从Python进程中保存,并加载到没有Python依赖的进程中。使用方法如下:

/* 保存模型 */
# 通过trace的方法生成IR需要一个输入样例 
dummy_input = torch.rand(1, 1, 28, 28) # IR生成 
with torch.no_grad(): jit_model = torch.jit.trace(model, dummy_input) # 将模型序列化 
jit_model.save('./model/jit_model.pt') /* 加载、推理模型 */
# 加载序列化后的模型 
jit_model = torch.jit.load('./model/jit_model.pt') x, y = test_data[0][0], test_data[0][1]
start_time = time.time()
pred = jit_model.forward(x)
print(f"spend time: {time.time()-start_time}")
print(pred[0].argmax(0))

参考文档

Save and Load the Model — PyTorch Tutorials 2.1.1+cu121 documentation

TorchScript — PyTorch master documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/178984.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何运用AppLink平台中的数据连接器组件

AppLink平台组件组成 AppLink平台组件分成三个板块触发事件组件、基础组件和数据连接器 数据连接器组件里面有10个组件,目前也在不断新增更多的数据连接器,那他们在AppLink平台里的原理、触发动作以及怎么使用呢?接下来用MySQL和TimescaleD…

在线陪诊系统: 医疗科技的崭新前沿

在医学科技的快速发展中,在线陪诊系统正成为医疗服务领域的创新力量。通过结合互联网和先进的远程技术,这一系统为患者和医生提供了更为便捷、高效的医疗体验。本文将深入探讨在线陪诊系统的技术背后的核心代码和实现原理。 技术背后的关键代码 在线陪…

用于图像分类任务的经典神经网络综述

🎀个人主页: https://zhangxiaoshu.blog.csdn.net 📢欢迎大家:关注🔍点赞👍评论📝收藏⭐️,如有错误敬请指正! 💕未来很长,值得我们全力奔赴更美好的生活&…

autojs-图片篇(一)

注释很详细,直接上代码 自动点击授予权限的操作 //安卓版本高于Android 9 if(device.sdkInt>28){//等待截屏权限申请并同意threads.start(function () {packageName(com.android.systemui).text(立即开始).waitFor();text(立即开始).click();}); }截图查相应位置…

开关电源做“做安规”请记住这 4 点!

1.定义 为了保证人身安全,财产,环境等不受伤害和损失,所做出的规定。 2.安规所涉及的要求 a.电击 b.火灾 c.电磁辐射 d.环境污染 e.化学辐射 f.能量冲击 g.化学腐蚀 h.机械伤害和热伤害 3.世界主要安规体系 a.IEC体系----以欧盟为代表 b.UL体系----以美国为代表…

55.跳跃游戏

原题链接&#xff1a;55.跳跃游戏 思路&#xff1a; 看代码注释 全代码&#xff1a; class Solution { public:bool canJump(vector<int>& nums) {int cover 0;if (nums.size() 1) return true; // 只有一个元素&#xff0c;就是能达到for (int i 0; i < co…

探索性因子分析流程

探索性因子分析的步骤&#xff1a; 接下来&#xff0c;通过一个案例演示因子分析&#xff08;探索性因子分析&#xff09;的各个步骤应该如何进行。 案例&#xff1a;欲探究我国不同省份铁路运输能力情况&#xff0c;收集到部分相关数据如下&#xff1a; 上传数据至SPSSAU系统…

echarts 水波图

echarts 水波图 安装 npm install echarts --save npm install echarts-liquidfill --save引入 import * as echarts from echarts; import echarts-liquidfill;html <div id"chart1" ref"chart1" class"chart1"></div>css .cha…

leetcode做题笔记1670. 设计前中后队列

请你设计一个队列&#xff0c;支持在前&#xff0c;中&#xff0c;后三个位置的 push 和 pop 操作。 请你完成 FrontMiddleBack 类&#xff1a; FrontMiddleBack() 初始化队列。void pushFront(int val) 将 val 添加到队列的 最前面 。void pushMiddle(int val) 将 val 添加到…

RequestContextHolder 类简介

RequestContextHolder 类简介 RequestContextHolder是Spring Framework中的一个类&#xff0c;用于在多线程环境中存储和访问HTTP请求的上下文信息。它允许在Spring应用程序中从任何位置访问当前请求的相关信息&#xff0c;如HTTP头部、会话数据等&#xff0c;而无需将请求对象…

C语言实现串的部分算法

一、简介 串&#xff08;string&#xff09;&#xff08;或字符串&#xff09;是由零个或多个字符组成的有序序列&#xff0c;一般记为 sa1a2....an s为串的名&#xff0c;用单引号括起来的时字符序列串的值&#xff0c;串中字符的数目n称为串的长度。 零个字符的串称为空串…

C语言--每日选择题--Day28

第一题 1. 设a和b均为double型变量&#xff0c;且a5.5、b2.5&#xff0c;则表达式(int)ab/b的值是&#xff08; &#xff09; A&#xff1a;6.500000 B&#xff1a;6 C&#xff1a;5.500000 D&#xff1a;6.000000 答案及解析 D 本题考查的是不同数据类型之间的变量进行运算时…

常见面试题-Redis 切片集群以及主节点选举机制

Redis 切片集群了解吗&#xff1f; 答&#xff1a; Redis 切片集群是目前使用比较多的方案&#xff0c;Redis 切面集群支持多个主从集群进行横向扩容&#xff0c;架构如下&#xff1a; 使用切片集群有什么好处&#xff1f; 提升 Redis 读写性能&#xff0c;之前的主从模式中&…

使用C语言库函数qsort排序注意点

目录 题目背景错误C语言代码&#xff1a;正确C语言代码&#xff1a;注意点 题目背景 高校团委组织校园歌手比赛&#xff0c;进入决赛的校园歌手有10位,歌手编号从1到10进行编号。组委会随机抽取方式产生了决赛次序为&#xff1a;3,1,9,10,2,7,5,8,4,6。比赛现场有5个评委为参赛…

【Docker项目实战】使用Docker部署Plik临时文件上传系统

【Docker实战项目】使用Docker部署Plik 临时文件上传系统 一、Plik介绍1.1 Plik简介1.2 Plik特点 二、本地环境介绍2.1 本地环境规划2.2 本次实践介绍 三、本地环境检查3.1 检查Docker服务状态3.2 检查Docker版本3.3 检查docker compose 版本 四、下载Plik镜像五、部署Plik临时…

PLC:200smart

PLC&#xff1a;200smart 第十章、数据类型、数据存储1、数据类型1.1、有符号数1.2、有符号数 2、传送指令 第十一章、比较指令、整数、浮点数的运算1、比较指令1、运算指令1.1、浮点数运算1.2、整数运算 第十章、数据类型、数据存储 1、数据类型 数据类型分为两大类 无符号数…

Java中的mysql——面试题+答案——第24期

当涉及MySQL时&#xff0c;面试题可以涵盖更多高级主题、安全性和实践经验。 MySQL中的存储引擎InnoDB和MyISAM的区别是什么&#xff1f; 答案&#xff1a; InnoDB支持事务&#xff0c;而MyISAM不支持。InnoDB使用行级锁&#xff0c;而MyISAM使用表级锁。InnoDB支持外键&#x…

【小布_ORACLE】Part11-1--RMAN Backups笔记

Oracle的数据备份于恢复RMAN Backups 学习第11章需要掌握&#xff1a; 一.RMAN的备份类型 二.使用backup命令创建备份集 三.创建备份文件 四.备份归档日志文件 五.使用RMAN的copy命令创建镜像拷贝 文章目录 Oracle的数据备份于恢复RMAN Backups1.RMAN Backup Concepts&#x…

用了这7款html网页制作软件,你会爱上编程!

制作网页是一个复杂的过程&#xff0c;需要注意到各种细节&#xff0c;只有依靠出色的技术能力和强大的工具&#xff0c;我们才能真正达到我们的目标。幸运的是&#xff0c;有很多优秀的HTML网页设计软件可以让整个流程变得更加轻松和高效。以下就是我们经过深思熟虑和严格筛选…

Redis 的过期策略都有哪些?

思考:假如redis的key过期之后&#xff0c;会立即删除吗&#xff1f; Redis对数据设置数据的有效时间&#xff0c;数据过期以后&#xff0c;就需要将数据从内存中删除掉。可以按照不同的规则进行删除&#xff0c;这种删除规则就被称之为数据的删除策略&#xff08;数据过期策略…