PyTorch(一)模型训练过程

PyTorch(一)模型训练过程

#c 总结 实践总结

该实践从「数据处理」开始到最后利用训练好的「模型」预测,感受到了整个模型的训练过程。其中也有部分知识点,例如定义神经网络,只是初步的模仿,有一个比较浅的认识,还需要继续学习原理。

整个流程:
「准备数据」,「创建数据加载器」,「选择训练设备」,「定义神经网络」,「定义损失函数和优化器」,「定义训练和测试函数」,「迭代训练」,「保存模型」,「加载模型」,「模型预测」。

相关知识点:
1.Dataset与DataLoader
2.迭代器
3.模型定义
4.损失函数与优化器
5.模型训练与测试
6.模型保存,加载与预测

1 数据处理

#d Dataset与DataLoader

在处理数据时,PyTorch有两个基本的原语来与数据交互:torch.utils.data.DataLoadertorch.utils.data.DatasetDataset 用于存储样本以及它们相应的标签,而 DataLoader 围绕 Dataset 封装了一个「迭代器」。
Dataset 类通常用来定义数据集,它包含了数据和标签。而 DataLoader 类则是用来批量加载数据集,支持自动加载、打乱数据、多线程加载等功能,使得数据的加载更加高效和灵活。

#e 导入库 Dataset与DataLoader

import torch
from torch import nn# 神经网络模块
from torch.utils.data import DataLoader# 数据加载器
from torchvision import datasets# 数据集
from torchvision.transforms import ToTensor# 图像转换为张量

#c 补充 特定领域库 Dataset与DataLoader

PyTorch 提供了特定领域的库,比如 TorchTextTorchVisionTorchAudio,它们都包含了Datasettorchvision.datasets 模块包含了许多现实世界视觉数据集的 Dataset 对象,例如 CIFAR、COCO。每个 TorchVisionDataset都包括两个参数:transformtarget_transform,它们分别用来修改样本和标签。

#d 迭代器(Iterable)

迭代器(Iterable)是一种允许程序员遍历一个容器(特别是列表等序列类型)的对象。在Python中,迭代器遵循迭代协议,即它们实现了__iter__()方法,该方法返回一个迭代器对象本身,这个对象还需要实现__next__()方法,该方法在每次迭代时返回容器中的下一个项目。通过提供一种统一、高效、按需处理数据的方式,极大地简化了数据遍历和处理的复杂性。

「迭代器」解决的问题:

  1. 统一的遍历接口:迭代器提供了一种统一的方法来遍历各种类型的数据容器(如列表、元组、字典等),而不需要知道容器的内部结构。
  2. 内存效率:迭代器允许按需遍历元素,而不是一次性将所有元素加载到内存中。这对于遍历大数据集特别有用,因为它可以显著减少程序的内存使用。
  3. 惰性计算:迭代器支持惰性计算,这意味着数据元素是在需要时才被计算和返回,而不是在迭代器创建时。这可以提高计算效率,特别是在处理复杂或无限的数据序列时。

没有「迭代器」的影响:

  1. 遍历复杂性增加:没有迭代器,程序员需要为不同类型的数据结构编写不同的遍历代码,这不仅增加了开发的复杂性,也降低了代码的可重用性。
  2. 内存效率降低:在处理大型数据集时,可能需要一次性将所有数据加载到内存中,这会导致显著的内存消耗,甚至可能导致内存不足的错误。
  3. 减少惰性计算的机会:没有迭代器机制,很难实现按需计算数据元素的逻辑,这可能导致不必要的计算开销,特别是在只需要数据集一小部分或者在数据集很大时。

#e 吃自助餐 迭代器

想象一下你在一家餐厅吃自助餐。自助餐提供了一个装满不同菜肴的长桌子,你拿着一个盘子,从一端开始,挨个检查每种菜肴,决定是否将其加入你的盘子。在这个过程中,你(顾客)就像一个迭代器,而长桌子上的菜肴就像是一个可迭代的容器。你一次检查一个菜肴,直到遍历完所有的菜肴,或者你的盘子满了为止。

#e 迭代访问列表 迭代器

假设我们有一个列表(List)numbers = [1, 2, 3, 4, 5],使用iter(numbers)创建了一个迭代器,它能够遍历列表numbers中的每个元素。使用next(iterator)可以获取容器中的下一个元素。当所有元素都被遍历完毕时,next()会抛出一个StopIteration异常,表示没有更多元素可以访问,这时我们结束循环。

numbers = [1, 2, 3, 4, 5]  # 可迭代的容器
iterator = iter(numbers)  # 创建迭代器while True:try:# 使用next()获取下一个元素number = next(iterator)print(number)except StopIteration:# 如果所有元素都遍历完毕,则结束循环break

#c 关联 相关概念

「迭代器」影响的「概念」:

  1. 可迭代对象(Iterable):任何实现了__iter__()方法的对象都是可迭代的,该方法需要返回一个迭代器对象。迭代器本身也是可迭代的,因为它实现了__iter__()方法,并返回自身。

  2. 生成器(Generator):生成器是一种特殊类型的迭代器,它使用函数加上yield语句来实现,无需手动实现__iter__()__next__()方法。生成器简化了迭代器的创建过程,直接受到了迭代器概念的启发。

  3. 循环(Loops):例如for循环和while循环,在Python中,for循环内部实际上使用迭代器来遍历可迭代对象。

  4. 函数式编程工具:如map()filter()reduce()等函数,它们接受一个函数和一个可迭代对象作为输入,内部通过迭代器遍历可迭代对象。

影响「迭代器」的概念:

  1. 面向对象编程(OOP):迭代器模式是面向对象设计模式的一部分,要求对象实现特定的接口(如Python中的__iter__()__next__()方法)。面向对象的概念提供了迭代器实现的框架。

  2. 惰性计算(Lazy Evaluation):惰性计算是指仅在真正需要计算结果时才进行计算。迭代器天然支持惰性计算,因为它们一次只处理集合中的一个元素。

  3. 函数式编程(Functional Programming):函数式编程强调使用函数来处理数据。迭代器与函数式编程紧密相关,因为迭代器提供了一种遍历和处理数据集合的方法,而不改变数据本身,这与函数式编程的不可变性原则相吻合。

#c 说明 数据集的选择

本次实践使用 FashionMNIST 数据集。该数据集是一个用于衣物识别的数据集,由Zalando(一家欧洲的在线时尚零售商)提供。它被设计为原始MNIST数据集的直接替代品,用于在机器学习和计算机视觉领域的基准测试中。FashionMNIST包含了10个类别的衣物图片,每个类别有7000张图片,整个数据集分为60000张训练图片和10000张测试图片。每张图片都是28x28像素的灰度图。这些类别包括:

  1. T-shirt/top(T恤/上衣)
  2. Trouser(裤子)
  3. Pullover(套衫)
  4. Dress(连衣裙)
  5. Coat(外套)
  6. Sandal(凉鞋)
  7. Shirt(衬衫)
  8. Sneaker(运动鞋)
  9. Bag(包)
  10. Ankle boot(短靴)

#e 下载数据集 数据集的选择

如果是自行搜集数据,比如利用爬虫获取自己想要的数据,获取的数据需要进行「数据处理」,例如「删除不符合数据」,「统一数据格式」,「去重」等方式。这里下载的数据已经是符合训练的数据格式,所以不需要进行对应的数据处理的环节。

# 下载训练数据集
train_data = datasets.FashionMNIST(root="data",  # 数据存储的路径train=True,   # 指定下载的是训练数据集download=True,  # 如果数据不存在,则通过网络下载transform=ToTensor()  # 将图片转换为Tensor
)# 下载测试数据集
test_data = datasets.FashionMNIST(root="data",  # 数据存储的路径train=False,  # 指定下载的是测试数据集download=True,  # 如果数据不存在,则通过网络下载transform=ToTensor()  # 将图片转换为Tensor
)

#d 数据加载

Dataset 作为参数传递给 DataLoaderDataLoaderdataset封装一个可迭代对象,并且支持自动批处理、采样、多进程数据加载等。

#e 加载代码 数据加载

在这里,定义了一个批量大小为64,即 dataloader 可迭代对象中的每个元素将返回一个包含64个特征和标签的批次。

batch_size = 64# 批大小# 创建数据加载器
train_dataloader = DataLoader(train_data, batch_size=batch_size)
#将dataset作为参数传入DataLoader,DataLoader会自动将数据分批,打乱数据,将数据加载到内存中
test_dataloader = DataLoader(test_data, batch_size=batch_size)for x,y in test_dataloader:print(f"Shape of x [N, C, H, W]: {x.shape}")#x.shape是一个4维张量,第一个维度是批大小,第二个维度是通道数,第三和第四维度是图像的高度和宽度print(f"Shape of y: {y.shape}, {y.dtype}")'''Shape of x [N, C, H, W]: torch.Size([64, 1, 28, 28])Shape of y: torch.Size([64]), torch.int64'''break

2 创建模型

#d 定义模型

在PyTorch中定义神经网络,需创建一个继承自nn.Module的类,并在__init__函数中定义神经网络的层,在forward函数中定义数据在神经网络中的传播路径。为了加速神经网络的训练,可以使用GPU或者MPS来训练模型。

#e 定义代码 定义模型

#使用cpu,gpu,mps的设备来训练模型
device =("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu"
)
print(f"Using {device} device")
#Using cuda deviceclass NeuralNetwork(nn.Module):def __init__(self):#定义神经网络的层super().__init__()#调用父类的构造函数self.flatten = nn.Flatten()#将28*28的图像展平为784的向量self.linear_relu_stack = nn.Sequential(#定义一个包含三个线性层的神经网络nn.Linear(28*28,512),#输入层nn.ReLU(),#激活函数nn.Linear(512,512),#隐藏层nn.ReLU(),#激活函数nn.Linear(512,10),#输出层)def forward(self,x):#定义数据在神经网络中的传播路径x = self.flatten(x)#将图像展平logits = self.linear_relu_stack(x)#将展平后的图像传入神经网络return logits#返回输出
model = NeuralNetwork().to(device)#将模型加载到设备上
print(model)
'''
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(linear_relu_stack): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): ReLU()(2): Linear(in_features=512, out_features=512, bias=True)(3): ReLU()(4): Linear(in_features=512, out_features=10, bias=True))
)
'''

3 优化模型参数

#d 定义训练参数

  1. 在训练模型之前,需要定义「损失函数(loss function)」[ 和「优化器(optimizer)」。概念解释(5 相关概念)

  2. 在单个训练循环中,模型会对分批提供它的「训练数据集」进行「预测」并通过「反向传播算法」预测误差以调整模型的参数。

  3. 检查模型在测试数据集上的性能,以确保它在学习.

  4. 训练过程在多个迭代(周期)中进行。在每个周期中,模型学习参数以做出更好的预测。在每个周期打印模型的准确率和损失,希望看到准确率随着每个周期的增加而提高,损失随着每个周期的减少。

#e 损失函数与优化器

loss_fn = nn.CrossEntropyLoss()#使用交叉熵损失函数
#使用随机梯度下降优化器,model.parameters()返回模型的参数,lr=1e-3是学习率
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

#e 训练函数

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)#数据集的大小model.train()#将模型设置为训练模式for batch, (X, y) in enumerate(dataloader):#遍历数据集X, y = X.to(device), y.to(device)#将数据加载到设备上# 计算预测误差pred = model(X)#对输入的数据进行预测loss = loss_fn(pred, y)#计算损失,差异越小,模型预测的越准确# 反向传播loss.backward()#反向传播算法optimizer.step()#优化器更新模型参数optimizer.zero_grad()#梯度清零if batch % 100 == 0:#每100个批次打印一次loss, current = loss.item(), (batch+1) * len(X)#打印损失和当前的批次的数据量print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")'''{loss:>7f}:表示损失值以浮点数形式打印,总宽度为7位,右对齐。{current:>5d}:表示当前处理的总数据量以整数形式打印,总宽度为5位,右对齐。{size:>5d}:表示整个数据集的大小以整数形式打印,总宽度为5位,右对齐'''

#e 测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)#数据集的大小num_batches = len(dataloader)#批次的数量model.eval()#将模型设置为评估模式test_loss, correct = 0, 0#初始化损失和正确的数量with torch.no_grad():#关闭梯度计算for X, y in dataloader:X, y = X.to(device), y.to(device)#将数据加载到设备上pred = model(X)#对输入的数据进行预测test_loss += loss_fn(pred, y).item()#计算损失correct += (pred.argmax(1) == y).type(torch.float).sum().item()#计算正确的数量'''pred.argmax(1)找出每个预测中概率最高的类别的索引,== y判断这些索引是否与真实标签相等。结果是一个布尔Tensor,通过.type(torch.float)转换为浮点数Tensor,然后使用.sum().item()计算并累加正确预测的总数。'''test_loss /= num_batches#计算平均损失correct /= size#计算正确率print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

#e 迭代训练

epochs = 5#迭代次数
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)#训练模型test(test_dataloader, model, loss_fn)#测试模型
print("Done!")#训练完成
'''
运行结果
Epoch 1
-------------------------------
loss: 2.304268  [   64/60000]
loss: 2.284021  [ 6464/60000]
loss: 2.263621  [12864/60000]
loss: 2.259448  [19264/60000]
loss: 2.231920  [25664/60000]
loss: 2.221592  [32064/60000]
loss: 2.215944  [38464/60000]
loss: 2.191191  [44864/60000]
loss: 2.177027  [51264/60000]
loss: 2.141848  [57664/60000]
Test Error: Accuracy: 58.7%, Avg loss: 2.137664Epoch 2
-------------------------------
loss: 2.147467  [   64/60000]
loss: 2.139907  [ 6464/60000]
loss: 2.077062  [12864/60000]
loss: 2.094236  [19264/60000]
loss: 2.030329  [25664/60000]
loss: 1.982215  [32064/60000]
loss: 1.997371  [38464/60000]
loss: 1.923110  [44864/60000]
loss: 1.913458  [51264/60000]
loss: 1.835431  [57664/60000]
Test Error: Accuracy: 61.3%, Avg loss: 1.839774
'''

4 模型的保存

#d 保存方式

保存模型的一种常见方法是序列化内部状态字典(包含模型参数)。

#e 实现代码 保存方式

torch.save(model.state_dict(), "./model.pth")
print("Saved PyTorch Model State to ./model.pth")
'''
Saved PyTorch Model State to ./model.pth
'''

5 模型加载与预测

#d 加载流程

加载模型的过程包括重新创建模型结构,并将状态字典加载到其中。

#e 加载代码 加载流程

model = NeuralNetwork().to(device)#创建模型,to(device)将模型加载到设备上
model.load_state_dict(torch.load("./model.pth"))#加载模型

#e 预测代码

#利用模型进行预测
classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]
model.eval()#将模型设置为评估模式
x, y = test_data[0][0], test_data[0][1]#获取测试数据
with torch.no_grad():#关闭梯度计算pred = model(x.to(device))#对输入的数据进行预测predicted, actual = classes[pred[0].argmax(0)], classes[y]#获取预测的类别和真实的类别print(f'Predicted: "{predicted}", Actual: "{actual}"')'''Predicted: "Ankle boot", Actual: "Ankle boot"'''

#c 备注 完整python文件

AI_series_learn/PyTorch/1.快速开始/basic.py at main · togetherhkl/AI_series_learn (github.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/31685.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

力扣456.132模式

力扣456.132模式 单调栈 维护单调递减的栈用k维护‘2’&#xff0c;每次出栈就更新**&#xff08;保证栈中元素始终大于k&#xff09;**当出现nums[i] < k时 说明存在‘1’又因为栈中存在‘3’因此就存在132模式序列 class Solution {public:bool find132pattern(vector&…

C语言C99标准、C11标准新增加的特性

C语言标准 C语言从其诞生至今&#xff0c;经历了多个标准的更新&#xff0c;主要标准包括&#xff1a; C89/C90 (ANSI C / ISO/IEC 9899:1990)&#xff1a;这是C语言的第一个官方标准&#xff0c;由ANSI于1989年发布&#xff0c;后被ISO采纳为国际标准&#xff0c;发布于1990年…

echarts+vue2实战(一)

目录 一、项目准备 二、(横向分页)柱状图 2.1、动态刷新 2.2、UI调整 2.3、分辨率适配 三、(竖向平移)柱状图 3.1、平移动画 3.2、不同数值显示不同颜色 四、(下拉切换)折线图 4.1、切换图表和分辨率适配 4.2、UI调整 五、(三级分类)饼图 5.1、数据切换 六、圆环…

使用Tkinter创建带查找功能的文本编辑器

使用Tkinter创建带查找功能的文本编辑器 介绍效果代码解析创建主窗口添加菜单栏实现文件操作实现查找 完整代码 介绍 在这篇博客中&#xff0c;我将分享如何使用Python的Tkinter库创建一个带有查找功能的简单文本编辑器。 效果 代码解析 创建主窗口 import tkinter as tkcl…

Offset Explorer 连接SASL PLAIN鉴权的Kafka

1、填写Kafka信息 2、配置鉴权信息 Security 选择 SASL PLAINTEXT JAAS Config 配置账号密码 org.apache.kafka.common.security.plain.PlainLoginModule required username"账号"password"密码";

[Vulnhub] Troll FTP匿名登录+定时任务权限提升

信息收集 IP AddressPorts Opening192.168.8.104TCP:21,22,80 $ nmap -sC -sV 192.168.8.104 -p- --min-rate 1000 Nmap scan report for 192.168.8.104 (192.168.8.104) Host is up (0.0042s latency). Not shown: 65532 closed tcp ports (conn-refused) PORT STATE SER…

openh264 宏块级码率控制源码分析

openh264 宏块级码率控制函数关系 宏块级核心函数分析 WelsRcMbInitGom函数 功能&#xff1a;openh264 码率控制框架中宏块级码率控制函数&#xff0c;根据是否启用GOM QP来决定如何设置宏块的QP值&#xff0c;以控制编码的质量和比特率。原理过程&#xff1a; 函数参数&…

“打造智能售货机系统,基于ruoyi微服务版本开源项目“

目录 # 开篇 售货机术语 1. 表设计说明 2. 页面展示 2.1 区域管理页面 2.2 合作商管理页面 2.3 点位管理页面 3. 建表资源 3.1 创建表的 SQL 语句&#xff08;包含字段备注&#xff09; 1. Region 表 2. Node 表 3. Partner 表 4. 创建 tb_vending_machine 表的 S…

学习java第一百零六天

Spring的后置处理器 BeanPostProcessor&#xff1a;Bean的后置处理器&#xff0c;主要在bean初始化前后工作。 InstantiationAwareBeanPostProcessor&#xff1a;继承于BeanPostProcessor&#xff0c;主要在实例化bean前后工作&#xff1b; AOP创建代理对象就是通过该接口实现…

用LangChain调用Ollama的时候一个小问题

说来让人无语&#xff0c;简单记录一下。安装好Ollama后&#xff0c;我们通常通过访问http://127.0.0.1:11434来测试其是否正常&#xff0c;通常会出来“Ollama is running”&#xff0c;然后我习惯性地从Chrome把地址拷贝到VS Code&#xff0c; oembed OllamaEmbeddings(bas…

【启明智显产品介绍】Model3C工业级HMI芯片详解专题(一)芯片性能

【启明智显产品介绍】工业级HMI芯片Model3C详解&#xff08;一&#xff09;芯片性能 Model3C 是一款基于 RISC-V 的高性能、国产自主、工业级高清显示与智能控制 MCU&#xff0c;配置平头哥E907&#xff0c;主频400MHz&#xff0c;强大的 2D 图形加速处理器、PNG/JPEG 解码引擎…

【Conda】修改 Conda 默认的虚拟环境位置

文章目录 问题描述分析与解决查看默认安装位置修改 .condarc 文件修改权限 参考资料 问题描述 Conda 的虚拟环境默认安装在 C 盘。时间久了&#xff0c;C 盘上的内存会被大量占用&#xff0c;影响电脑性能。于是想到修改虚拟环境的默认存放位置&#xff0c;改到自定义的位置。…

找不到d3dx9_43.dll无法继续执行代码的几种解决方法

在工作或生活使用电脑都会遇到丢失dll文件应用无法启动的情况&#xff0c;比如你安装完一款你最喜欢的游戏在启动的时候提示系统缺少d3dx9_39.dll、d3dx9_40.dll、d3dx9_41.dll、d3dx9_42.dll、d3dx9_43.dll、xinput1_3.dll 文件而无法正常游戏&#xff0c;或你在工作的时候安装…

分享HTML显示2D/3D时间

效果截图 实现代码 <!DOCTYPE html> <head> <title>three.jscannon.js Web 3D</title><meta charset"utf-8"><meta name"viewport" content"widthdevice-width,initial-scale1,maximum-scale1"><meta n…

图神经网络学习笔记

文章目录 一、图神经网络应用领域分析二、图基本模块定义三、邻接矩阵的定义四、GNN中常见任务五、消息传递计算方法六、多层GCN的作用七、GCN基本模型概述八、图卷积的基本计算方法九、邻接的矩阵的变换十、GCN变换原理解读 本笔记参考自b站up主小巴只爱学习的图神经网络教程 …

【Android面试八股文】你能说一说View绘制流程与自定义View注意点吗?

文章目录 一、自定义View的构造函数以及各参数的用法二、自定义View的几种方式三、自定义View的绘制流程四、自定义View需要注意的一些点五、举个例子一、自定义View的构造函数以及各参数的用法 在Android中,自定义View通常需要提供多个构造函数,以适应不同的使用场景。主要…

创建OpenWRT虚拟机

环境&#xff1a;Ubuntu 2204&#xff0c;VM VirtualBox 7.0.18 安装必备软件包&#xff1a; sudo apt update sudo apt install subversion automake make cmake uuid-dev gcc vim build-essential clang flex bison g gawk gcc-multilib g-multilib gettext git libncurses…

C语言中操作符详解(一)

众所周知&#xff0c;在我们的C语言中有着各式各样的操作符&#xff0c;并且在此之前呢&#xff0c;我们已经认识并运用了许许多多的操作符&#xff0c;都是诸君的老朋友了昂 操作符作为我们使用C语言的一个非常非常非常重要的工具&#xff0c;诸君一定要加以重视&#xff0c;…

大模型如何改变世界?李彦宏:未来至少一半人要学会“提问题“

2023年爆火的大模型&#xff0c;对我们来说意味着什么&#xff1f; 百度创始人、董事长兼CEO李彦宏认为&#xff0c;“大模型即将改变世界。” 5月26日&#xff0c;李彦宏参加了在北京举办的2023中关村论坛&#xff0c;发表了题为《大模型改变世界》的演讲。李彦宏认为&#…

在centos服务器上部署nginx容器

1.下载nginx镜像 2.导入镜像 docker load -i nginx.tar 3. 查看导入的镜像 docker images 4. 运行镜像 docker run -d -p 80:80 --name my-nginx nginx 5. 访问Nginx 其他 1.查看所有正在运行的Docker容器:docker ps 2.查看所有镜像:docker images 3.使用Docker命令…