Pytorch:神经网络过程代码详解

文章目录

  • 一、基本概念
    • 1、epoch
    • 2、遍历DataLoader
  • 二、神经网络训练过程代码详解
      • 步骤一:选择并初始化优化器
      • 步骤二:计算损失
      • 步骤三:反向传播
      • 步骤四:更新模型参数
      • 步骤五:清空梯度
      • 组合到训练循环中
      • 步骤六:保存模型
  • 三、神经网络评估过程代码详解
      • 步骤一:加载模型
      • 步骤二:切换至评估模式
      • 步骤三:进行评估(这里计算分类问题的准确率)
  • 四、经典数据集——鸢尾花数据集
  • 五、常用神经网络层计算原理
    • 1、nn.Linear(input_features,output_features)
      • a.基本概念和公式
      • b.详细解释
      • c.代码示例
      • d.原理解释


一、基本概念

for epoch in range(total_epoch):for label_x,label_y in dataloader:pass

1、epoch

  epoch 指的是整个数据集在训练过程中被完整地遍历一次。如果数据集被分成多个批次输入模型,则一个 epoch 完成后意味着所有的批次已被模型处理一次。epoch 的数目通常根据训练数据的大小、模型复杂度和任务需求来决定。每个 epoch 结束后,模型学到的知识会更加深入,但也存在过度学习(过拟合)的风险,特别是当 epoch 数目过多时。
  即每一个epoch会处理所有的batchepoch也被称为训练周期

2、遍历DataLoader

  遍历DataLoader,实际上就是每次取出一个batch的数据。

二、神经网络训练过程代码详解

建议先理解:Module模块

步骤一:选择并初始化优化器

首先,根据模型的需求选择一个合适的优化器。不同的优化器可能适合不同类型的数据和网络架构。一旦选择了优化器,需要将模型的参数传递给它,并设置一些特定的参数,如学习率、权重衰减等。

import torch.optim as optim# 假设 model 是你的网络模型
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

在这个例子中,选择了随机梯度下降(SGD)作为优化器,并设置了学习率和动量。

步骤二:计算损失

在训练循环中,每次迭代都会处理一批数据,模型会根据这些数据进行预测,并计算损失。

criterion = torch.nn.CrossEntropyLoss()  # 选择合适的损失函数
outputs = model(inputs)                 # 前向传播
loss = criterion(outputs, labels)       # 计算损失

步骤三:反向传播

一旦有了损失,就可以使用 .backward() 方法来自动计算模型中所有可训练参数的梯度。

loss.backward()

这一步将计算损失函数相对于每个参数的梯度,并将它们存储在各个参数的 .grad 属性中。

步骤四:更新模型参数

使用优化器的 .step() 方法来根据计算得到的梯度更新参数。

optimizer.step()

这个调用会更新模型的参数,具体的更新方式取决于你选择的优化算法。

步骤五:清空梯度

在每次迭代后,需要手动清空梯度,以便下一次迭代。如果不清空梯度,梯度会累积,导致不正确的参数更新。

optimizer.zero_grad()

组合到训练循环中

将上述步骤组合到一个训练循环中,我们得到了完整的训练过程:

model = MyModel() #实例化神经网络层,调用继承自Module类的MyModel类的构造函数
criterion = torch.nn.CrossEntropyLoss()  # 选择合适的损失函数,这里是交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 定义优化器,传入模型参数
model.train()#切换至训练模式
for epoch in range(total_epochs):for inputs, labels in dataloader:  # 从数据加载器获取数据inputs, labels = inputs.to(device), labels.to(device)# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()  # 清空之前的梯度loss.backward()        # 反向传播# 优化参数optimizer.step()       # 更新参数print(f'Epoch [{epoch+1}/{total_epochs}], Loss: {loss.item()}')
  • loss.item()

    • 在 PyTorch 中,loss 是一个 torch.Tensor 对象。当计算模型的损失时,这个对象通常只包含一个元素(一个标量值),它代表了当前批次数据的损失值。loss.item() 方法是从包含单个元素的张量中提取出那个标量值作为 Python 数值。这是很有用的,因为它允许你将损失值脱离张量的形式进行进一步的处理或输出,比如打印、记录或做条件判断。
  • print(f'Epoch [{epoch+1}/{total_epochs}], Loss: {loss.item()}')

    • 这行代码是用来在训练过程中输出当前 epoch 的编号和该 epoch 的损失值。这对于监控训练进程和调试模型非常有帮助。具体来说:
      • epoch+1:由于计数通常从 0 开始,所以 +1 是为了更自然地显示(从 1 开始而不是从 0 开始)。
      • {total_epochs}:这是训练过程中总的 epoch 数。
      • {loss.item()}:如前所述,这表示当前批次的损失值,作为一个标量数值输出。

步骤六:保存模型

torch.save(model.state_dict(),"model.pth")

三、神经网络评估过程代码详解

步骤一:加载模型

实例化对应模型,使用该模型对象的.load_state_dict()方法导入之前存储的模型参数。

model=MyModel()
model.to(device)
model.load_state_dict(torch.load("model.pth"))

步骤二:切换至评估模式

model.eval()

步骤三:进行评估(这里计算分类问题的准确率)

  • 定义评估时需要的变量
  • 使用torch.no_grad()指定上下文Pytorch不追踪梯度信息
total_correct=0 #正确的样本数
total_samples=0 #样本数
with torch.no_grad():# 该局部化区域内的张量不再计算梯度for batch_x,batch_y in dataloader:batch_x.to(device)batch_y.to(device)batch_x = batch_x.to(torch.float) #转换成浮点output = model(batch_x) # 前向传播,得到分类结果 形状为[batch_size,num_classes]_,predicted = torch.max(output,dim=1)  # 不考虑dim=0的batch_size,从第一维开始考虑,沿着dim=1的方向寻找最大值,实际上就是找分类得分最高的分类,predicted接收的是max的索引,因此predicted的形状是[batch_size,1] 就是每一个样本的分类total_correct += (predicted == batch_y).sum().item()total_samples += predicted.size(dim=0)
accuracy=total_correct / total_samples
print(f"accuracy:{accuracy}")

_,predicted=tensor.max(output,dim=1)

  • output形状是[batch_size,classes_num],沿着列求最大值, 得到列中的最大值索引,相当于得到的predicted的形状是[batch_size,1],这里的1 的数值 是一行中 最大值的索引 也就是预测的类别。然后batch_y 的形状是[batch_size,1]predicted进行对比,就是对比类别是否相同。所以,我们在考虑问题的时候,由于batch_size的存在,第0维忽略掉考虑也行,然后就好理解了

四、经典数据集——鸢尾花数据集

代码来源

import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader# 此函数用于加载鸢尾花数据集
def load_data(shuffle=True):x = torch.tensor(load_iris().data)y = torch.tensor(load_iris().target)# 数据归一化x_min = torch.min(x, dim=0).valuesx_max = torch.max(x, dim=0).valuesx = (x - x_min) / (x_max - x_min)if shuffle:idx = torch.randperm(x.shape[0])x = x[idx]y = y[idx]return x, y# 自定义鸢尾花数据类
class IrisDataset(Dataset):def __init__(self, mode='train', num_train=120, num_dev=15):super(IrisDataset, self).__init__()x, y = load_data(shuffle=True)  # 将x转换为浮点型数据y = y.long()  # 将y转换为长整型数据# x, y = load_data(shuffle=True)if mode == 'train':self.x, self.y = x[:num_train], y[:num_train]elif mode == 'dev':self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]else:self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]def __getitem__(self, idx):return self.x[idx], self.y[idx]def __len__(self):return len(self.x)# 创建一个模型类来定义神经网络模型
class IrisModel(nn.Module):def __init__(self):super(IrisModel, self).__init__()self.fc = nn.Linear(4, 3)def forward(self, x):return self.fc(x)# 加载数据
batch_size = 16train_dataset = IrisDataset(mode='train')
dev_dataset = IrisDataset(mode='dev')
test_dataset = IrisDataset(mode='test')train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)# 实例化神经网络模型
model = IrisModel()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 训练模型
num_epochs = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)for epoch in range(num_epochs):model.train()total_loss = 0for batch_x, batch_y in train_loader:batch_x, batch_y = batch_x.to(device), batch_y.to(device)batch_x = batch_x.to(torch.float)  # 使用float32数据类型# RuntimeError: mat1 and mat2 must have the same dtype, but got Double and Floatoptimizer.zero_grad()output = model(batch_x)loss = criterion(output, batch_y)loss.backward()optimizer.step()total_loss += loss.item()avg_loss = total_loss / len(train_loader)print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {avg_loss}")# 保存模型
torch.save(model.state_dict(), 'iris_model.pth')#%%
# 加载模型
model = IrisModel()  # 先实例化一个模型
model.to(device)
model.load_state_dict(torch.load('iris_model.pth'))# 评估模型
model.eval()
total_correct = 0
total_samples = 0with torch.no_grad():for batch_x, batch_y in test_loader:batch_x, batch_y = batch_x.to(device), batch_y.to(device)batch_x = batch_x.to(torch.float)  # 使用float32数据类型output = model(batch_x)_, predicted = torch.max(output, dim=1)total_correct += (predicted == batch_y).sum().item()total_samples += batch_y.size(0)accuracy = total_correct / total_samples
print(f"Test Accuracy: {accuracy}")

五、常用神经网络层计算原理

1、nn.Linear(input_features,output_features)

在 PyTorch 中,torch.nn.Linear 表示一个全连接层(也称作密集层或线性层)。这个层执行的基本运算是线性变换,将输入数据通过矩阵乘法和偏置加法转换到一个新的空间。 这是许多神经网络架构中最基本的组成部分之一。

a.基本概念和公式

全连接层的基本数学公式可以表示为:
output = input ⋅ W T + b \text{output} = \text{input} \cdot W^T + b output=inputWT+b
其中:

  • input 是输入数据。
  • W 是权重矩阵。使用.weight查看
  • b 是偏置向量。使用.bias查看
  • output 是层的输出。

b.详细解释

  1. 权重矩阵(W):在全连接层中,每个输入特征都通过权重与每个输出节点相连。权重矩阵的形状通常是 [out_features, in_features],其中 in_features 是输入层的特征数,out_features 是输出层的特征数。这种结构确保了每个输入特征都通过权重影响每个输出特征。

  2. 偏置(b):偏置是一个向量,其长度等于输出特征的数量(out_features)。偏置向量在矩阵乘法后添加到每一个输出向量上,可以提供额外的灵活性,使模型能够更好地拟合数据。

  3. 矩阵乘法:输入向量(或批量输入矩阵)与权重矩阵的转置进行矩阵乘法。这一步实现了从输入空间到输出空间的线性变换。

  4. 加偏置:将偏置向量加到矩阵乘法的结果上,以完成从输入到输出的映射。

c.代码示例

假设我们有一个输入特征维度为 in_features=5,输出特征维度为 out_features=4 的全连接层,下面是如何在 PyTorch 中定义和使用这个层的示例:

import torch
import torch.nn as nnx = torch.randn(5,4,3,5)
y = torch.randn(5,4,5,3)
f = nn.Linear(5,4)
output_tensor = f(x)
print(output_tensor.shape)

输出:

torch.Size([5, 4, 3, 4])

在这个例子中,x 经过 f 后,得到了 output_tensor。全连接层的参数(权重和偏置)是在定义层时自动初始化的,并且在训练过程中通过反向传播算法进行优化。

d.原理解释

Linear线性层,实际上是对向量进行线性变换。对于一个张量形状为[a,b,c,d,e,f,g],它的输入特征被认为是g,高维并不变化,通过输入特征长度为g的线性层之后。
比如

形状[3,5]
x1: 1 2 3 4 5
x2: 3 4 5 6 7
x3: 5 6 7 4 6

经过线性层:

nn.Linear(5,4) # 权重矩阵形状为[4,5]

改变形状为(我们这里不关注数值):实际上就是矩阵[3,5]×[5,4]得到[3,4]

形状[3,4]
x1: 1 3 4 5
x2: 4 5 6 6
x3: 6 7 3 1

因此可以理解为,经过线性层实际上就是将向量改变其维度(进行线性变换),比如注意力机制使用的权重矩阵就是将输入 X X X进行线性变换,使用一个Linear层就能做到。如果输入的维度更高,只不过是将向量分成了不同批次罢了,我们关注最后一维即可。 这和矩阵乘法是一样的,形状上只需要关注最后两维,更高维度只不过是把矩阵封装成不同批次罢了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/5515.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

windows和mac 电脑 部署Ollama

官网地址:https://ollama.com/ github地址:https://github.com/ollama/ollama 一、windows下 https://github.com/ollama/ollama 安装大模型 ollama run llama3 下载的大模型地址: C:\Users\dengg\.ollama 4.34G

二维数组-----刷题2

题目不是傻子题目&#xff0c;但很简单&#xff01;定义一个变量k&#xff0c;在嵌套中不断累加输出即可。 #include<cstdio> int k,n; int main(){scanf("%d",&n);for(int i1;i<n;i){for(int j1;j<n;j){k;printf("%d ",k);}printf("…

Python基础学习之记录中间文件

倘若想记录代码运行过程中的结果文件&#xff0c;那么以下函数仅供参考 代码示例&#xff1a; import os import datetime import sys import pandas as pd# 定义总的文件夹路径 base_folder E:\\D\\log\\product_data_compare_log# 定义一个函数来创建带时间戳的文件夹 def…

LoRa模块在智能灌溉系统中的应用特点介绍

LoRa模块在智能灌溉系统中的应用特点主要体现在以下几个方面&#xff1a; 低功耗与长寿命&#xff1a; LoRa模块具有极低的功耗&#xff0c;使其在待机状态下耗电量极低&#xff0c;能够支持长时间连续运行&#xff0c;减少了频繁更换电池或充电的需求&#xff0c;确保了智能灌…

【Godot4.2】有序和无序列表函数库 - myList

概述 在打印输出或其他地方可能需要构建有序或无序列表。本质就是构造和维护一个纯文本数组。并用格式化文本形式&#xff0c;输出带序号或前缀字符的多行文本。 为此我专门设计了一个类myList&#xff0c;来完成这项任务。 代码 以下是myList类的完整代码&#xff1a; # …

SQL Sever无法连接服务器

SQL Sever无法连接服务器&#xff0c;报错证书链是由不受信任的颁发机构颁发的 解决方法&#xff1a;不用ssl方式连接 1、点击弹框中按钮“选项” 2、连接安全加密选择可选 3、不勾选“信任服务器证书” 4、点击“连接”&#xff0c;可连接成功

python安卓自动化pyaibote实践------学习通自动刷课

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 本文是一个完成一个自动播放课程&#xff0c;避免人为频繁点击脚本的构思与源码。 加油&#xff01;为实现全部电脑自动化办公而奋斗&#xff01; 为实现摆烂躺平的人生而奋斗&#xff01;&#xff01;&#xff…

视觉语言模型详解

视觉语言模型可以同时从图像和文本中学习&#xff0c;因此可用于视觉问答、图像描述等多种任务。本文&#xff0c;我们将带大家一览视觉语言模型领域: 作个概述、了解其工作原理、搞清楚如何找到真命天“模”、如何对其进行推理以及如何使用最新版的 trl 轻松对其进行微调。 什…

【C语言】指针篇-精通库中的快速排序算法:巧妙掌握技巧(4/5)

&#x1f308;个人主页&#xff1a;是店小二呀 &#x1f308;C语言笔记专栏&#xff1a;C语言笔记 &#x1f308;C笔记专栏&#xff1a; C笔记 &#x1f308;喜欢的诗句:无人扶我青云志 我自踏雪至山巅 文章目录 一、回调函数二、快速排序(Qsort)2.1 Qsort参数部分介绍2.2 不…

报错“Install Js dependencies failed”【鸿蒙开发Bug已解决】

文章目录 项目场景:问题描述原因分析:解决方案:此Bug解决方案总结Bug解决方案寄语项目场景: 最近也是遇到了这个问题,看到网上也有人在询问这个问题,本文总结了自己和其他人的解决经验,解决了【报错“Install Js dependencies failed”】的问题。 报错如下 问题描述 …

leetcode 92. 反转链表 II

class Solution(object):def reverseBetween(self, head, left, right):""":type head: ListNode:type left: int:type right: int:rtype: ListNode""" right right -1left left -1while( right-left>0 ):print(right-left)# 左侧节点l …

零基础玩转Linux+Ubuntu实战视频课程

零基础玩转LinuxUbuntu实战视频课程 Linux发行版之间的关系jpg 1-1课程简介及Linux学习路线介绍.mp4 1-10什么是环境变量.mp4 1-11文件系统管理.mp4 1-12用户账户管理.mp4 1-13文件的访问权限.mp4 1-14进程管理.mp4 1-15软件源码包的编译、安装与卸裁.mp4 1-16制作自己的deb软…

WRF原理与基本操作

WRF介绍 WPS是三个&#xff0c;它们协同工作&#xff0c;为真实数据模拟的输入准备输出资料,为真实数据模拟做预处理。 geogrid定义模式范围&#xff0c;将静态地形资料插值到格点 ; ungrib将气象数据从GRIB格式解码 提取气象场; metgrid将ungrib解码的气象场水平地插值到geog…

【C++语法练习】计算梯形的面积

题目链接&#xff1a;https://www.starrycoding.com/problem/158 题目描述 已知一个梯形的上底 a a a&#xff0c;下底 b b b和高 h h h&#xff0c;请求出它的面积&#xff08;结果保留两位小数&#xff09;。 输入格式 第一行一个整数 T T T表示测试用例个数。 ( 1 ≤ T …

Linux 的静态库和动态库

本文目录 一、静态库1. 创建静态库2. 静态库的使用 二、动态库1. 为什么要引入动态库呢&#xff1f;2. 创建动态库3. 动态库的使用4. 查看可执行文件依赖的动态库 一、静态库 在编译程序的链接阶段&#xff0c;会将源码汇编生成的目标文件.o与引用到的库&#xff08;包括静态库…

Open CASCADE学习|GeomFill_CurveAndTrihedron

GeomFill_CurveAndTrihedron类是GeomFill_LocationLaw的子类&#xff0c;用于定义一个位置法则&#xff08;Location Law&#xff09;&#xff0c;该法则结合了一个曲线&#xff08;curve&#xff09;和一个三面体法则&#xff08;TrihedronLaw&#xff09;。 类功能&#xff…

关于用户体验和设计思维

介绍 要开发有效的原型并为用户提供出色的体验&#xff0c;了解用户体验 (UX) 和设计思维的原则至关重要。 用户体验是用户与产品、服务或系统交互并获得相应体验的过程。 设计思维是一种解决问题的方法&#xff0c;侧重于创新和创造。 在启动期实现用户体验和设计思维时&#…

JQuery从入门到精通

目录-JQuery 1.概述............................................................. 2 2.简介............................................................. 3 3.安装............................................................. 4 4.语法............................…

大数据分析与内存计算学习笔记

一、Scala编程初级实践 1.计算级数&#xff1a; 请用脚本的方式编程计算并输出下列级数的前n项之和Sn&#xff0c;直到Sn刚好大于或等于q为止&#xff0c;其中q为大于0的整数&#xff0c;其值通过键盘输入。&#xff08;不使用脚本执行方式可写Java代码转换成Scala代码执行&a…

监视器和显示器的区别,普通硬盘和监控硬盘的区别

监视器与显示器的区别&#xff0c;你真的知道吗&#xff1f; 中小型视频监控系统中&#xff0c;显示系统是最能展现效果的一个重要环节&#xff0c;显示系统的优劣将直接影响视频监控系统的用户体验满意度。 中小型视频监控系统中&#xff0c;显示系统是最能展现效果的一个重要…