6-2 pytorch中训练模型的3种方法

Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异。(养成自己的习惯)
有3类典型的训练循环代码风格:脚本形式训练循环,函数形式训练循环,类形式训练循环。
下面以minist数据集的多分类模型的训练为例,演示这3种训练模型的风格。
其中类形式训练循环我们同时演示torchkeras.KerasModel和torchkeras.LightModel两种示范。

准备数据

transform = transforms.Compose([transforms.ToTensor()])ds_train = torchvision.datasets.MNIST(root="./data/mnist/",train=True,download=True,transform=transform)
ds_val = torchvision.datasets.MNIST(root="./data/mnist/",train=False,download=True,transform=transform)dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=4)
dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=4)print(len(ds_train))
print(len(ds_val))

image.png

%matplotlib inline
%config InlineBackend.figure_format = 'svg'#查看部分样本
from matplotlib import pyplot as plt plt.figure(figsize=(8,8)) 
for i in range(9):img,label = ds_train[i] img = torch.squeeze(img) # 删除为1的维度ax=plt.subplot(3,3,i+1)ax.imshow(img.numpy())ax.set_title("label = %d"%label)ax.set_xticks([])ax.set_yticks([]) 
plt.show()

image.png

一、脚本风格

脚本风格的训练循环非常常见。

net = nn.Sequential()
net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3))
net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))
net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("dropout",nn.Dropout2d(p = 0.1))
net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
net.add_module("flatten",nn.Flatten())
net.add_module("linear1",nn.Linear(64,32))
net.add_module("relu",nn.ReLU())
net.add_module("linear2",nn.Linear(32,10))print(net)

image.png
代码量较多,可以查看最下方链接对应的notebook。

二、函数风格

该风格在脚本形式上做了进一步的函数封装。

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layers = nn.ModuleList([nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Dropout2d(p = 0.1),nn.AdaptiveMaxPool2d((1,1)),nn.Flatten(),nn.Linear(64,32),nn.ReLU(),nn.Linear(32,10)])def forward(self,x):for layer in self.layers:x = layer(x)return x
net = Net()
print(net)

image.png
代码量较多,可以查看最下方链接对应的notebook。

三、类风格

此处使用**torchkeras.KerasModel(其源码其实就是脚本风格中的代码)**高层次API接口中的fit方法训练模型。
使用该形式训练模型非常简洁明了。
先构建模型,同一二。

from torchkeras import KerasModel class Net(nn.Module):def __init__(self):super().__init__()self.layers = nn.ModuleList([nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Dropout2d(p = 0.1),nn.AdaptiveMaxPool2d((1,1)),nn.Flatten(),nn.Linear(64,32),nn.ReLU(),nn.Linear(32,10)])def forward(self,x):for layer in self.layers:x = layer(x)return xnet = Net() print(net)

使用kerasModel:

from torchmetrics import Accuracymodel = KerasModel(net,loss_fn=nn.CrossEntropyLoss(),metrics_dict = {"acc":Accuracy(task='multiclass',num_classes=10)},optimizer = torch.optim.Adam(net.parameters(),lr = 0.01)  )model.fit(train_data = dl_train,val_data= dl_val,epochs=10,patience=3,monitor="val_acc", mode="max",plot=True,cpu=True
)

训练过程:
image.png
其实编码训练代码按照自己的习惯即可,不必要按照以上三种方式。

参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/81918.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot集成Redis实现数据缓存

🌿欢迎来到衍生星球的CSDN博文🌿 🍁本文主要学习Spring Boot集成Redis实现数据缓存 🍁 🌱我是衍生星球,一个从事集成开发的打工人🌱 ⭐️喜欢的朋友可以关注一下🫰🫰&…

【最新面试问题记录持续更新,java,kotlin,android,flutter】

最近找工作,复习了下java相关的知识。发现已经对很多概念模糊了。记录一下。部分是往年面试题重新整理,部分是自己面试遇到的问题。持续更新中~ 目录 java相关1. 面向对象设计原则2. 面向对象的特征是什么3. 重载和重写4. 基本数据类型5. 装箱和拆箱6. …

华为aarch64架构的泰山服务器EulerOS 2.0 (SP8)系统离线安装saltstack3003.1实践

华为泰山服务器的CPU芯片架构为aarch64,所装系统为EulerOS 2.0 (SP8)aarch64系统,安装saltstack比较困难。本文讲解通过pip安装方式离线安装saltstack3003.1以进行集中化管理和维护。 一、系统环境 1、操作系统版本 [rootlocalhost ~]# cat /etc/os-r…

uniapp 可输入可选择的........框

安装 uniapp: uni-combox地址 vue页面 <uni-combox :border"false" input"selectname" focus"handleFocus" blur"handleBlur" :candidates"candidates" placeholder"请选择姓名" v-model"name"&g…

在微信公众号怎么实现投票活动

微信公众号实现投票活动的方法和步骤 一、投票活动的优势 通过投票活动&#xff0c;微信公众号可以实现用户参与、增加互动、了解用户需求等功能&#xff0c;同时也可以提升品牌知名度和用户粘性。以下是一些投票活动的优势&#xff1a; 增加用户参与度&#xff1a;通过投票活…

Unity中Shader的模板测试

文章目录 前言什么是模板测试1、模板缓冲区2、模板缓冲区中存储的值3、模板测试是什么&#xff08;看完以下流程就能知道模板测试是什么&#xff09;模板测试就是在渲染&#xff0c;后渲染的物体前&#xff0c;与渲染前的模板缓冲区的值进行比较&#xff0c;选出符合条件的部分…

sprinboot 引入 Elasticsearch 依赖包

1.springboot与es的版本有比较强的绑定关系&#xff0c;如果springboot工程引入es的依赖后报一些依赖的错误&#xff0c;那么就看表格中的对应关系&#xff0c;将sprinboot或者es的版本做对应的调整 2.本人是从springboot1.x升级到springboot2.x&#xff0c;做了排包工作 3.升级…

Linux学习记录——이십팔 网络基础(1)

文章目录 1、了解2、网络协议栈3、TCP/IP模型4、网络传输1、同一局域网&#xff08;子网&#xff09;2、局域网通信原理3、跨一个路由器的两个子网4、其它 详细的网络发展历史就不写了 1、了解 为什么会出现网络&#xff1f;一开始多个计算机之间想要共享文件&#xff0c;就得…

DirectX12(d3d12)初始化

一、前置要求 Windows 10及以上(安装有DirectX12)VisualStudio 2022 二、DirectX12入门 1.引用头文件 #include<Windows.h> #include<d3d12.h> #include<dxgi1_4.h>2.注册窗口类并初始化窗口 这里我们调用Windows API 通过应用程序的句柄来注册一个唯一…

uniapp视频播放功能

UniApp提供了多种视频播放组件&#xff0c;包括视频播放器&#xff08;video&#xff09;、多媒体组件&#xff08;media&#xff09;、WebView&#xff08;内置Video标签&#xff09;等。其中&#xff0c;video和media组件是最常用的。 video组件 video组件是基于HTML5 vide…

default 和 delete 与默认构造函数 的使用

前言 使用default和delete关键字来干预编译器自动生成的函数。让我详细解释一下这些知识点&#xff1a; 正文 编译器生成的默认构造函数&#xff1a; 如果类A没有定义任何构造函数&#xff0c;那么编译器会自动生成一个无参的默认构造函数 A()。这个默认构造函数实际上是一个…

【AI视野·今日CV 计算机视觉论文速览 第248期】Mon, 18 Sep 2023

AI视野今日CS.CV 计算机视觉论文速览 Mon, 18 Sep 2023 Totally 83 papers &#x1f449;上期速览✈更多精彩请移步主页 Interesting: &#x1f4da;Robust e-NeRF,处理高速且大噪声事件相机流的NERF模型。(from NUS新加坡国立) 稀疏噪声事件与稠密事件数据的区别&#xff1a;…

阿里云无影云电脑和传统PC有什么区别?

阿里云无影云电脑和传统电脑PC有什么区别&#xff1f;区别大了&#xff0c;无影云电脑是云端的桌面服务&#xff0c;传统PC是本地的硬件计算机&#xff0c;无影云电脑的数据是保存在云端&#xff0c;本地传统PC的数据是保存在本地客户端&#xff0c;阿里云百科分享阿里云无影云…

计数排序与基数排序

计数排序与基数排序 计数排序 计数排序&#xff1a;使用一个数组记录序列中每一个数字出现的次数&#xff0c;将该数组的下标作为实际数据&#xff0c;元素的值作为数据出现的次数。例如对于序列[3,0,1,1,3,3,0,2]&#xff0c;统计的结果为&#xff1a; 0出现的次数&#xf…

Redis模块二:缓存分类 + Redis模块三:常见缓存(应用)

缓存大致可以分为两大类&#xff1a;1&#xff09;本地缓存 2&#xff09;分布式缓存 目录 本地缓存 分布式缓存 常见缓存的使用 本地缓存&#xff1a;Spring Cache 分布式缓存&#xff1a;Redis 本地缓存 本地缓存也叫单机缓存&#xff0c;也就是说可以应⽤在单机环…

ffmpeg编译 Error: operand type mismatch for `shr‘

错误如下&#xff1a; D:\msys2\tmp\ccUxvBjQ.s: Assembler messages: D:\msys2\tmp\ccUxvBjQ.s:345: Error: operand type mismatch for shr D:\msys2\tmp\ccUxvBjQ.s:410: Error: operand type mismatch for shr D:\msys2\tmp\ccUxvBjQ.s:470: Error: operand type mismatch…

[JAVAee]Spring项目的创建与基本使用

目录 Spring项目的创建 Spring中Bean对象的存储与获取 存储Bean对象 获取并使用Bean对象 getBean方法的重载 本文章介绍了Spring项目创建与使用的过程与一定的注意事项. Spring项目的创建 首先在IDEA中,新建一个Maven 第二步,在pom.xml中写入spring的依赖. pom.xml是mav…

情侣飞行棋 情侣小游戏 2023 抖音

飞行棋网站地址:https://effect.guoyaxue.top/fxq/index.html#/ 以及各种新版来袭&#xff1a; 以及各种情侣小游戏合集 https://fxnew.guoyaxue.top/#/

【微信小程序】swiper的使用

1.swiper的基本使用 <jxz-header></jxz-header> <view class"banner"><swiperprevious-margin"30rpx"autoplayinterval"2000"indicator-dotsindicator-color"rgba(0,0,0,0.3)"indicator-active-color"#bda…

一百七十九、Linux——Linux报错No package epel-release available

一、目的 在Linux中配置Xmanager服务时&#xff0c;执行脚本时Linux报错No package epel-release available 二、解决措施 &#xff08;一&#xff09;第一步&#xff0c;# wget http://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm &#xff08;二&…