6-2 pytorch中训练模型的3种方法

Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异。(养成自己的习惯)
有3类典型的训练循环代码风格:脚本形式训练循环,函数形式训练循环,类形式训练循环。
下面以minist数据集的多分类模型的训练为例,演示这3种训练模型的风格。
其中类形式训练循环我们同时演示torchkeras.KerasModel和torchkeras.LightModel两种示范。

准备数据

transform = transforms.Compose([transforms.ToTensor()])ds_train = torchvision.datasets.MNIST(root="./data/mnist/",train=True,download=True,transform=transform)
ds_val = torchvision.datasets.MNIST(root="./data/mnist/",train=False,download=True,transform=transform)dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=4)
dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=4)print(len(ds_train))
print(len(ds_val))

image.png

%matplotlib inline
%config InlineBackend.figure_format = 'svg'#查看部分样本
from matplotlib import pyplot as plt plt.figure(figsize=(8,8)) 
for i in range(9):img,label = ds_train[i] img = torch.squeeze(img) # 删除为1的维度ax=plt.subplot(3,3,i+1)ax.imshow(img.numpy())ax.set_title("label = %d"%label)ax.set_xticks([])ax.set_yticks([]) 
plt.show()

image.png

一、脚本风格

脚本风格的训练循环非常常见。

net = nn.Sequential()
net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3))
net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))
net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("dropout",nn.Dropout2d(p = 0.1))
net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
net.add_module("flatten",nn.Flatten())
net.add_module("linear1",nn.Linear(64,32))
net.add_module("relu",nn.ReLU())
net.add_module("linear2",nn.Linear(32,10))print(net)

image.png
代码量较多,可以查看最下方链接对应的notebook。

二、函数风格

该风格在脚本形式上做了进一步的函数封装。

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layers = nn.ModuleList([nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Dropout2d(p = 0.1),nn.AdaptiveMaxPool2d((1,1)),nn.Flatten(),nn.Linear(64,32),nn.ReLU(),nn.Linear(32,10)])def forward(self,x):for layer in self.layers:x = layer(x)return x
net = Net()
print(net)

image.png
代码量较多,可以查看最下方链接对应的notebook。

三、类风格

此处使用**torchkeras.KerasModel(其源码其实就是脚本风格中的代码)**高层次API接口中的fit方法训练模型。
使用该形式训练模型非常简洁明了。
先构建模型,同一二。

from torchkeras import KerasModel class Net(nn.Module):def __init__(self):super().__init__()self.layers = nn.ModuleList([nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Dropout2d(p = 0.1),nn.AdaptiveMaxPool2d((1,1)),nn.Flatten(),nn.Linear(64,32),nn.ReLU(),nn.Linear(32,10)])def forward(self,x):for layer in self.layers:x = layer(x)return xnet = Net() print(net)

使用kerasModel:

from torchmetrics import Accuracymodel = KerasModel(net,loss_fn=nn.CrossEntropyLoss(),metrics_dict = {"acc":Accuracy(task='multiclass',num_classes=10)},optimizer = torch.optim.Adam(net.parameters(),lr = 0.01)  )model.fit(train_data = dl_train,val_data= dl_val,epochs=10,patience=3,monitor="val_acc", mode="max",plot=True,cpu=True
)

训练过程:
image.png
其实编码训练代码按照自己的习惯即可,不必要按照以上三种方式。

参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/81918.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot集成Redis实现数据缓存

🌿欢迎来到衍生星球的CSDN博文🌿 🍁本文主要学习Spring Boot集成Redis实现数据缓存 🍁 🌱我是衍生星球,一个从事集成开发的打工人🌱 ⭐️喜欢的朋友可以关注一下🫰🫰&…

【最新面试问题记录持续更新,java,kotlin,android,flutter】

最近找工作,复习了下java相关的知识。发现已经对很多概念模糊了。记录一下。部分是往年面试题重新整理,部分是自己面试遇到的问题。持续更新中~ 目录 java相关1. 面向对象设计原则2. 面向对象的特征是什么3. 重载和重写4. 基本数据类型5. 装箱和拆箱6. …

华为aarch64架构的泰山服务器EulerOS 2.0 (SP8)系统离线安装saltstack3003.1实践

华为泰山服务器的CPU芯片架构为aarch64,所装系统为EulerOS 2.0 (SP8)aarch64系统,安装saltstack比较困难。本文讲解通过pip安装方式离线安装saltstack3003.1以进行集中化管理和维护。 一、系统环境 1、操作系统版本 [rootlocalhost ~]# cat /etc/os-r…

uniapp 可输入可选择的........框

安装 uniapp: uni-combox地址 vue页面 <uni-combox :border"false" input"selectname" focus"handleFocus" blur"handleBlur" :candidates"candidates" placeholder"请选择姓名" v-model"name"&g…

Flask实现Web服务调用Python程序

Flask实现Web服务调用Python程序_flask调用python程序_小白白程序员的博客-CSDN博客 【小沐学Python】Python实现Web服务器&#xff08;Flask入门&#xff09;_python flask web开发_爱看书的小沐的博客-CSDN博客

在微信公众号怎么实现投票活动

微信公众号实现投票活动的方法和步骤 一、投票活动的优势 通过投票活动&#xff0c;微信公众号可以实现用户参与、增加互动、了解用户需求等功能&#xff0c;同时也可以提升品牌知名度和用户粘性。以下是一些投票活动的优势&#xff1a; 增加用户参与度&#xff1a;通过投票活…

SpringCloudGateway网关实战(三)

SpringCloudGateway网关实战&#xff08;三&#xff09; 上一章节我们讲了gateway的内置过滤器Filter&#xff0c;本章节我们来讲讲全局过滤器。 自带全局过滤器 在实现自定义全局过滤器前&#xff0c; spring-cloud-starter-gateway依赖本身就自带一些全局过滤器&#xff0…

嵌入式-C语言中的常量

目录 一.简介 二.常量类型 一.简介 在C语言中&#xff0c;常量是指不可改变的值&#xff0c;即固定不变的数据。常量可以是数值、字符或字符串等类型的值。常量可以直接在代码中使用&#xff0c;而无需在程序运行时进行修改。 二.常量类型 C语言中的常量有以下几种类型&…

Unity中Shader的模板测试

文章目录 前言什么是模板测试1、模板缓冲区2、模板缓冲区中存储的值3、模板测试是什么&#xff08;看完以下流程就能知道模板测试是什么&#xff09;模板测试就是在渲染&#xff0c;后渲染的物体前&#xff0c;与渲染前的模板缓冲区的值进行比较&#xff0c;选出符合条件的部分…

sprinboot 引入 Elasticsearch 依赖包

1.springboot与es的版本有比较强的绑定关系&#xff0c;如果springboot工程引入es的依赖后报一些依赖的错误&#xff0c;那么就看表格中的对应关系&#xff0c;将sprinboot或者es的版本做对应的调整 2.本人是从springboot1.x升级到springboot2.x&#xff0c;做了排包工作 3.升级…

Linux学习记录——이십팔 网络基础(1)

文章目录 1、了解2、网络协议栈3、TCP/IP模型4、网络传输1、同一局域网&#xff08;子网&#xff09;2、局域网通信原理3、跨一个路由器的两个子网4、其它 详细的网络发展历史就不写了 1、了解 为什么会出现网络&#xff1f;一开始多个计算机之间想要共享文件&#xff0c;就得…

L1-006 连续因子

一个正整数 N 的因子中可能存在若干连续的数字。例如 630 可以分解为 3567&#xff0c;其中 5、6、7 就是 3 个连续的数字。给定任一正整数 N&#xff0c;要求编写程序求出最长连续因子的个数&#xff0c;并输出最小的连续因子序列。 输入格式&#xff1a; 输入在一行中给出一…

DirectX12(d3d12)初始化

一、前置要求 Windows 10及以上(安装有DirectX12)VisualStudio 2022 二、DirectX12入门 1.引用头文件 #include<Windows.h> #include<d3d12.h> #include<dxgi1_4.h>2.注册窗口类并初始化窗口 这里我们调用Windows API 通过应用程序的句柄来注册一个唯一…

算法通关村-----回溯模板如何解决排列组合问题

组合总和 问题描述 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target &#xff0c;找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 &#xff0c;并以列表形式返回。你可以按 任意顺序 返回这些组合。candidates 中的 同一个 数字可以 无限…

uniapp视频播放功能

UniApp提供了多种视频播放组件&#xff0c;包括视频播放器&#xff08;video&#xff09;、多媒体组件&#xff08;media&#xff09;、WebView&#xff08;内置Video标签&#xff09;等。其中&#xff0c;video和media组件是最常用的。 video组件 video组件是基于HTML5 vide…

default 和 delete 与默认构造函数 的使用

前言 使用default和delete关键字来干预编译器自动生成的函数。让我详细解释一下这些知识点&#xff1a; 正文 编译器生成的默认构造函数&#xff1a; 如果类A没有定义任何构造函数&#xff0c;那么编译器会自动生成一个无参的默认构造函数 A()。这个默认构造函数实际上是一个…

【AI视野·今日CV 计算机视觉论文速览 第248期】Mon, 18 Sep 2023

AI视野今日CS.CV 计算机视觉论文速览 Mon, 18 Sep 2023 Totally 83 papers &#x1f449;上期速览✈更多精彩请移步主页 Interesting: &#x1f4da;Robust e-NeRF,处理高速且大噪声事件相机流的NERF模型。(from NUS新加坡国立) 稀疏噪声事件与稠密事件数据的区别&#xff1a;…

阿里云无影云电脑和传统PC有什么区别?

阿里云无影云电脑和传统电脑PC有什么区别&#xff1f;区别大了&#xff0c;无影云电脑是云端的桌面服务&#xff0c;传统PC是本地的硬件计算机&#xff0c;无影云电脑的数据是保存在云端&#xff0c;本地传统PC的数据是保存在本地客户端&#xff0c;阿里云百科分享阿里云无影云…

解决Vue项目中的“Cannot find module ‘vue-template-compiler‘”错误

1. 问题描述 在Vue项目中&#xff0c;当我们使用Vue的单文件组件&#xff08;.vue文件&#xff09;时&#xff0c;有时会遇到以下错误信息&#xff1a; ERROR: Cannot find module vue-template-compiler这个错误通常发生在我们使用Vue的版本不匹配或者缺少必要的依赖模块时。…

什么是回归测试?

什么是回归测试&#xff1f; 回归测试被定义为一种软件测试类型&#xff0c;以确认最近的程序或代码更改未对现有功能产生不利影响。 回归测试只不过是全部或部分选择已执行的测试用例&#xff0c;然后重新执行以确保现有功能正常运行。 进行此测试是为了确保新代码更改不会…