PyTorch 学习笔记

环境:python3.8 + PyTorch2.4.1+cpu + PyCharm

参考链接:

快速入门 — PyTorch 教程 2.6.0+cu124 文档

PyTorch 文档 — PyTorch 2.4 文档

快速入门

导入库

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

加载数据集

使用 FashionMNIST 数据集。每个 TorchVision 都包含两个参数: 分别是 修改样本 和 标签。

# Download training data from open datasets.
training_data = datasets.FashionMNIST(root="data",     # 数据集存储的位置train=True,      # 加载训练集(True则加载训练集)download=True,   # 如果数据集在指定目录中不存在,则下载(True才会下载)transform=ToTensor(), # 应用于图像的转换列表,例如转换为张量和归一化
)# Download test data from open datasets.
test_data = datasets.FashionMNIST(root="data",train=False,     # 加载测试集(False则加载测试集)download=True,transform=ToTensor(),
)

创建数据加载器

batch_size = 64# Create data loaders.
# DataLoader():batch_size每个批次的大小,shuffle=True则打乱数据
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader: # 遍历训练数据加载器,x相当于图片,y相当于标签print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

 

创建模型

为了在 PyTorch 中定义神经网络,我们创建一个继承 来自 nn.模块。我们定义网络的各层 ,并在函数中指定数据如何通过网络。要加速 作,我们将其移动到 CUDA、MPS、MTIA 或 XPU 等加速器。如果当前加速器可用,我们将使用它。否则,我们使用 CPU。__init__forward

#使用加速器,并打印当前使用的加速器(当前加速器可用则使用当前的,否则使用cpu)
# device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu" # torch2.4.2并没有accelerator这个属性,2.6的才有,所以注释掉
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")# 检查 CUDA 是否可用
print("CUDA available:", torch.cuda.is_available())# Define model
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device) #torch2.4.2并没有accelerator这个属性,2.6的才有,所以注释掉不用
# model = NeuralNetwork()
print(model)

优化模型参数

要训练模型,我们需要一个损失函数和一个优化器:

loss_fn = nn.CrossEntropyLoss() # 损失函数,nn.CrossEntropyLoss()用于多分类
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) # 优化器,用于更新模型的参数,以最小化损失函数
'''
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
优化器用PyTorch 提供的随机梯度下降(Stochastic Gradient Descent, SGD)优化器
model.parameters():将模型的参数传递给优化器,优化器会根据这些参数计算梯度并更新它们
lr=1e-3:学习率(learning rate),控制每次参数更新的步长
(较大的学习率可能导致训练不稳定,较小的学习率可能导致训练速度变慢)
'''

在单个训练循环中,模型对训练集进行预测(分批提供给它),并且 反向传播预测误差以调整模型的参数:

'''
训练模型(单个epoch)
dataloader:数据加载器,用于按批次加载训练数据
model     :神经网络模型
loss_fn   :损失函数,用于计算预测值与真实值之间的误差
optimizer :优化器,用于更新模型参数
'''
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train() # 将模型设置为训练模式(启用 dropout 和 batch normalization 的训练行为)for batch, (X, y) in enumerate(dataloader): # 遍历 dataloader 中的每个批次,获取输入 X 和标签 yX, y = X.to(device), y.to(device) # 将数据移动到指定设备(如 GPU 或 CPU)# Compute prediction error# 计算预测损失,同时也是前向传播pred = model(X)         # 模型的预测值,即模型的输出loss = loss_fn(pred, y) # 计算损失:y为实际的类别标签# Backpropagation 反向传播和优化# 梯度清零应在每次反向传播之前执行,以避免梯度累积(先用optimizer.zero_grad())loss.backward()       # 计算梯度optimizer.step()      # 使用优化器更新模型参数optimizer.zero_grad() # 清除之前的梯度(清零梯度,为下一轮计算做准备)# 梯度清零应在每次反向传播之前执行,以避免梯度累积(在计算模型预测值前先用optimizer.zero_grad())if batch % 100 == 0: # 每 100 个批次打印一次损失值和当前处理的样本数量loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

 进度条显示

  • 如果数据集较大,训练过程可能较慢。可以使用 tqdm 库添加进度条,提升用户体验。例如:
from tqdm import tqdm
for batch, (X, y) in enumerate(tqdm(dataloader, desc="Training")):...

我们还根据测试集检查模型的性能,以确保它正在学习:

# 测试模型
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 测试集的总样本数num_batches = len(dataloader)   # 测试数据加载器(dataloader)的总批次数model.eval()                    # 设置为评估模式,这会关闭 dropout 和 batch normalization 的训练行为test_loss, correct = 0, 0       # 累积测试损失和正确预测的样本数with torch.no_grad(): # 禁用梯度计算,使用 torch.no_grad() 上下文管理器,避免计算梯度,从而节省内存并加速计算for X, y in dataloader:X, y = X.to(device), y.to(device) # 将数据加载到指定设备pred = model(X) # 模型预测test_loss += loss_fn(pred, y).item() # 累积损失correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 累积正确预测数# correct += (pred.argmax(1) == y).float().sum().item()  # 可以直接使用 .float(),更简洁test_loss /= num_batches    # 平均损失correct /= size             # 准确率print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 累积正确预测数
# correct += (pred.argmax(1) == y).float().sum().item()  # 可以直接使用 .float(),更简洁
'''pred.argmax(1):
pred 是模型的输出(通常是未经过 softmax 的 logits,形状为 [batch_size, num_classes])。
argmax(1) 表示在第二个维度(即类别维度)上找到最大值的索引,返回一个形状为 [batch_size] 的张量,表示每个样本的预测类别。pred.argmax(1) == y:
y 是真实标签(形状为 [batch_size]),表示每个样本的真实类别。
这一步会比较预测的类别和真实类别,返回一个布尔张量,形状为 [batch_size],其中每个元素表示对应样本的预测是否正确。.type(torch.float):
将布尔张量转换为浮点数张量(True 转为 1.0,False 转为 0.0).sum():
对浮点数张量求和,得到预测正确的样本总数.item():
将结果从张量转换为 Python 的标量(整数)
'''
  • 举例一:

pred 是模型的输出:torch.tensor([[2.5, 0.3, 0.2], [0.1, 3.2, 0.7]])

y 是真实标签:torch.tensor([0, 1])

import torchpred = torch.tensor([[2.5, 0.3, 0.2], [0.1, 3.2, 0.7]])
y = torch.tensor([0, 1])correct = (pred.argmax(1) == y).type(torch.float).sum().item()
print(correct)  # 输出: 2.0 -> 转换为整数后为 2
  • 举例二:
import torch# 模型输出(未经过 softmax 的 logits)
pred = torch.tensor([[2.0, 1.0, 0.1],  # 第一个样本的预测分数[0.5, 3.0, 0.2],  # 第二个样本的预测分数[1.2, 0.3, 2.5]]) # 第三个样本的预测分数# 真实标签
y = torch.tensor([0, 1, 2])  # 第一个样本的真实类别是 0,第二个是 1,第三个是 2# 计算预测正确的样本数
correct = (pred.argmax(1) == y).type(torch.float).sum().item()
print(f"预测正确的样本数: {correct}") # 预测正确的样本数: 3'''逐步分析
对每个样本的预测分数取最大值的索引,得到预测类别:
pred.argmax(1)  # 输出: tensor([0, 1, 2])比较预测类别和真实标签,得到布尔张量:
pred.argmax(1) == y  # 输出: tensor([True, True, True]).type(torch.float): 将布尔张量转换为浮点数张量:
(pred.argmax(1) == y).type(torch.float)  # 输出: tensor([1.0, 1.0, 1.0]).sum(): 对浮点数张量求和,得到预测正确的样本总数:
(pred.argmax(1) == y).type(torch.float).sum()  # 输出: tensor(3.0).item():将结果从张量转换为 Python 标量:
(pred.argmax(1) == y).type(torch.float).sum().item()  # 输出: 3在这个例子中,模型对所有 3 个样本的预测都正确,因此预测正确的样本数为 3。
'''
# 如果知道总样本数,可以进一步计算准确率:# 总样本数
total = len(y)# 准确率
accuracy = correct / total
print(f"准确率: {accuracy * 100:.2f}%")

训练过程分多次迭代 (epoch) 进行。在每个 epoch 中,模型会学习 参数进行更好的预测。

然后打印模型在每个 epoch 的准确率和损失,

期望看到 准确率Accuracy增加,损失Avg loss随着每个 epoch 的减少而减少:

# 跑5轮,每轮皆是先训练,然后测试
epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

输出:

Epoch 1
-------------------------------
loss: 2.308106  [   64/60000]
loss: 2.292096  [ 6464/60000]
loss: 2.280747  [12864/60000]
loss: 2.273108  [19264/60000]
loss: 2.256617  [25664/60000]
loss: 2.240094  [32064/60000]
loss: 2.229981  [38464/60000]
loss: 2.204926  [44864/60000]
loss: 2.201917  [51264/60000]
loss: 2.178733  [57664/60000]
Test Error: Accuracy: 46.1%, Avg loss: 2.164820 Epoch 2
-------------------------------
loss: 2.178193  [   64/60000]
loss: 2.160645  [ 6464/60000]
loss: 2.110801  [12864/60000]
loss: 2.129119  [19264/60000]
loss: 2.078400  [25664/60000]
loss: 2.029629  [32064/60000]
loss: 2.044328  [38464/60000]
loss: 1.972220  [44864/60000]
loss: 1.980023  [51264/60000]
loss: 1.920835  [57664/60000]
Test Error: Accuracy: 56.2%, Avg loss: 1.906657 Epoch 3
-------------------------------
loss: 1.938616  [   64/60000]
loss: 1.902610  [ 6464/60000]
loss: 1.797264  [12864/60000]
loss: 1.844325  [19264/60000]
loss: 1.726765  [25664/60000]
loss: 1.688332  [32064/60000]
loss: 1.695883  [38464/60000]
loss: 1.605903  [44864/60000]
loss: 1.628846  [51264/60000]
loss: 1.532240  [57664/60000]
Test Error: Accuracy: 59.8%, Avg loss: 1.541237 Epoch 4
-------------------------------
loss: 1.604458  [   64/60000]
loss: 1.563167  [ 6464/60000]
loss: 1.426733  [12864/60000]
loss: 1.503305  [19264/60000]
loss: 1.376496  [25664/60000]
loss: 1.381424  [32064/60000]
loss: 1.371971  [38464/60000]
loss: 1.312882  [44864/60000]
loss: 1.342990  [51264/60000]
loss: 1.244696  [57664/60000]
Test Error: Accuracy: 62.7%, Avg loss: 1.268371 Epoch 5
-------------------------------
loss: 1.344515  [   64/60000]
loss: 1.318664  [ 6464/60000]
loss: 1.166471  [12864/60000]
loss: 1.275481  [19264/60000]
loss: 1.146058  [25664/60000]
loss: 1.179018  [32064/60000]
loss: 1.171105  [38464/60000]
loss: 1.129168  [44864/60000]
loss: 1.163182  [51264/60000]
loss: 1.077062  [57664/60000]
Test Error: Accuracy: 64.7%, Avg loss: 1.097442 Done!

保存模型

保存模型的常用方法是序列化内部状态字典(包含模型参数):

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

加载模型

加载模型的过程包括重新创建模型结构和加载 state 字典放入其中。

model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth", weights_only=True))

查看安装的PyTorch版本

方法一:cmd终端查看

终端中输入:

>>>python
>>>import torch
>>>torch.__version__  //注意version前后是两个下划线

方法二:PyCharm查看

打开Pycharm,在Python控制台中输入:

或者在Pycharm的“Python软件包”中查看:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/900604.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

windows开启wsl与轻量级虚拟机管理

基于win 10 打造K8S应用开发环境(wsl & kind) 一、wsl子系统安装 1.1 确认windows系统版本 cmd/powershell 或者win r 运行winver 操作系统要> 19044 1.2 开启wsl功能 控制面板 -> 程序 -> 启用或关闭Windows功能 开启适用于Linux的…

C++ -异常之除以 0 问题(整数除以 0 编译时检测、整数除以 0 运行时检测、浮点数除以 0 编译时检测、浮点数除以 0 运行时检测)

一、整数除以 0&#xff08;编译时检测&#xff09; 1、演示 #include <iostream>using namespace std;int main() {int result 10 / 0;cout << result << endl;return 0; }程序无法运行&#xff0c;输出结果 error C2124: 被零除或对零求模2、演示解读 …

【蓝桥杯】搜索算法:剪枝技巧+记忆化搜索

1. 可行性剪枝应用 1.1. 题目 题目描述: 给定一个正整数n和一个正整数目标值target,以及一个由不同正整数组成的数组nums。要求从nums中选出若干个数,每个数可以被选多次,使得这些数的和恰好等于target。问有多少种不同的组合方式? 输入: 第一行:n和target,表示数组…

Uniapp 集成极光推送(JPush)完整指南

文章目录 前言一、准备工作1. 注册极光开发者账号2. 创建应用3. Uniapp项目准备 二、集成极光推送插件方法一&#xff1a;使用UniPush&#xff08;推荐&#xff09;方法二&#xff1a;手动集成极光推送SDK 三、配置原生平台参数四、核心功能实现1. 获取RegistrationID2. 设置别…

Linux中进程

一、认识进程 进程(PCB)内核数据结构(task_struct)程序的代码和数据 每一个进程都有其独立的task_struct,OS对众多的task_struct进行管理&#xff0c;如何管理&#xff1f;先描述再组织&#xff0c;所有运⾏在系统⾥的进程都以task_struct链表的形式存在内核⾥&#xff0c;而…

国外的AI工具

一 OpenAI &#xff1a; &#x1f4a1; 总览&#xff1a; 名称全称/代号简介GPT-4o“o” omniOpenAI 最新的旗舰多模态模型&#xff08;文字、图像、音频三模态&#xff09;&#xff0c;比 GPT-4 更强、更快、更便宜。GPT-4o-mini精简版 GPT-4o轻量级版本&#xff0c;推测为性…

企业级Java开发工具MyEclipse v2025.1——支持AI编码辅助

MyEclipse一次性提供了巨量的Eclipse插件库&#xff0c;无需学习任何新的开发语言和工具&#xff0c;便可在一体化的IDE下进行Java EE、Web和PhoneGap移动应用的开发&#xff1b;强大的智能代码补齐功能&#xff0c;让企业开发化繁为简。 立即获取MyEclipse v2025.1正式版 具…

按键长按代码

这些代码都存放在定时器中断中。中断为100ms中断一次。 数据判断&#xff0c;看的懂就看吧

在 macOS 上连接 PostgreSQL 数据库(pgAdmin、DBeaver)

在 macOS 上连接 PostgreSQL 数据库 pgAdmin 官方提供的图形化管理工具&#xff0c;支持 macOS。 下载地址&#xff1a;https://www.pgadmin.org/ pgAdmin 4 是对 pgAdmin 的完全重写&#xff0c;使用 Python、ReactJs 和 Javascript 构建。一个用 Electron 编写的桌面运行时…

FTP协议和win server2022安装ftp

FTP协议简介 FTP&#xff08;File Transfer Protocol&#xff0c;文件传输协议&#xff09;是一种用于在网络上的计算机之间传输文件的标准网络协议。它被广泛应用于服务器与客户端之间的文件上传、下载以及管理操作。FTP支持多种文件类型和结构&#xff0c;并提供了相对简单的…

人工智能——AdaBoost算法

目录 摘要 13 AdaBoost算法 13.1 本章工作任务 13.2 本章技能目标 13.3 本章简介 13.4 编程实战 13.5 本章总结 13.6 本章作业 本章已完结! 摘要 本章实现的工作是:首先采用Python语言读取数据并构造训练集和测试集。然后建立AdaBoost模型,利用训练集训练该模型,…

DFS 蓝桥杯

最大数字 问题描述 给定一个正整数 NN 。你可以对 NN 的任意一位数字执行任意次以下 2 种操 作&#xff1a; 将该位数字加 1 。如果该位数字已经是 9 , 加 1 之后变成 0 。 将该位数字减 1 。如果该位数字已经是 0 , 减 1 之后变成 9 。 你现在总共可以执行 1 号操作不超过 A…

【开发经验】调试OpenBMC Redfish EventService功能

EventService功能是Redfish规范中定义的一种事件日志的发送方式。用户可以设置订阅者信息(通常是一个web服务器)&#xff0c;当产生事件日志时&#xff0c;OpenBMC可以根据用户设置的订阅者信息与对日志的筛选设置&#xff0c;将事件日志发送到订阅者。 相比于传统的SNMPTrap日…

中断嵌套、中断咬尾、中断晚到

中断咬尾&#xff08;Tail-Chaining&#xff09;是一种通过减少上下文切换开销来实现中断连续响应的高效机制&#xff0c;其核心在于避免重复的出栈和入栈操作&#xff0c;从而显著降低中断延迟。以下是具体原理及实现方式&#xff1a; 中断咬尾的运作机制 当多个中断请求连续…

Vue2下载二进制文件

后端&#xff1a; controller: GetMapping(value "/get-import-template")public void problemTemplate(HttpServletRequest request, HttpServletResponse response) throws Exception {iUserService.problemTemplate(request, response);} service: void probl…

Ubuntu小练习

文章目录 一、远程连接1、通过putty连接2、查看putty运行状态3、通过Puuty远程登录Ubuntu4、添加新用户查看是否添加成功 5、用新用户登录远程Ubuntu6、使用VNC远程登录树莓派 二、虚拟机上talk聊天三、Opencv1、简单安装版&#xff08;适合新手安装&#xff09;2、打开VScode特…

996引擎-疑难杂症:Ctrl + F9 编辑好的UI进入游戏查看却是歪的

Ctrl F9 编辑好UI后&#xff0c;进入游戏查看却是歪的。 检查Ctrl F10 是否有做过编辑。可以找到对应界面执行【清空】

WinForm真入门(5)——控件的基类Control

控件的基类–Control 用于 Windows 窗体应用程序的控件都派生自 Control类并继承了许多通用成员,这些成员都是平时使用控件的过程最常用到的。无论要学习哪个控件的使用&#xff0c;都离不开这些基本成员&#xff0c;尤其是一些公共属性。由于 Conlrol 类规范了控件的基本特征…

RAG(检索增强生成)系统,提示词(Prompt)表现测试(数据说话)

在RAG(检索增强生成)系统中,评价提示词(Prompt)设计是否优秀,必须通过量化测试数据来验证,而非主观判断。以下是系统化的评估方法、测试指标和具体实现方案: 一、提示词优秀的核心标准 优秀的提示词应显著提升以下指标: 维度量化指标测试方法事实一致性Faithfulness …

Appium的学习总结-Inspector参数设置和界面使用(5)

环境搭建好后&#xff0c;怎么使用呢&#xff1f; 环境这里使用的是&#xff1a; Appium的Server端GUI 22版本 Inspector需要单独下载安装&#xff0c;GUI里并没有集成。 &#xff08;使用Appium v1.22.0,查看元素信息需要另外安装下载Appium Inspector&#xff09; 操作&…