Pytorch:神经网络训练过程代码详解

文章目录

  • 一、基本概念
    • 1、epoch
    • 2、遍历DataLoader
  • 二、神经网络训练过程代码详解
      • 步骤一:选择并初始化优化器
      • 步骤二:计算损失
      • 步骤三:反向传播
      • 步骤四:更新模型参数
      • 步骤五:清空梯度
      • 组合到训练循环中


一、基本概念

for epoch in range(total_epoch):for label_x,label_y in dataloader:pass

1、epoch

  epoch 指的是整个数据集在训练过程中被完整地遍历一次。如果数据集被分成多个批次输入模型,则一个 epoch 完成后意味着所有的批次已被模型处理一次。epoch 的数目通常根据训练数据的大小、模型复杂度和任务需求来决定。每个 epoch 结束后,模型学到的知识会更加深入,但也存在过度学习(过拟合)的风险,特别是当 epoch 数目过多时。
  即每一个epoch会处理所有的batchepoch也被称为训练周期

2、遍历DataLoader

  遍历DataLoader,实际上就是每次取出一个batch的数据。

二、神经网络训练过程代码详解

建议先理解:Module模块

步骤一:选择并初始化优化器

首先,根据模型的需求选择一个合适的优化器。不同的优化器可能适合不同类型的数据和网络架构。一旦选择了优化器,需要将模型的参数传递给它,并设置一些特定的参数,如学习率、权重衰减等。

import torch.optim as optim# 假设 model 是你的网络模型
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

在这个例子中,选择了随机梯度下降(SGD)作为优化器,并设置了学习率和动量。

步骤二:计算损失

在训练循环中,每次迭代都会处理一批数据,模型会根据这些数据进行预测,并计算损失。

criterion = torch.nn.CrossEntropyLoss()  # 选择合适的损失函数
outputs = model(inputs)                 # 前向传播
loss = criterion(outputs, labels)       # 计算损失

步骤三:反向传播

一旦有了损失,就可以使用 .backward() 方法来自动计算模型中所有可训练参数的梯度。

loss.backward()

这一步将计算损失函数相对于每个参数的梯度,并将它们存储在各个参数的 .grad 属性中。

步骤四:更新模型参数

使用优化器的 .step() 方法来根据计算得到的梯度更新参数。

optimizer.step()

这个调用会更新模型的参数,具体的更新方式取决于你选择的优化算法。

步骤五:清空梯度

在每次迭代后,需要手动清空梯度,以便下一次迭代。如果不清空梯度,梯度会累积,导致不正确的参数更新。

optimizer.zero_grad()

组合到训练循环中

将上述步骤组合到一个训练循环中,我们得到了完整的训练过程:

model = MyModel() #实例化神经网络层,调用继承自Module类的MyModel类的构造函数
criterion = torch.nn.CrossEntropyLoss()  # 选择合适的损失函数,这里是交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 定义优化器,传入模型参数
model.train()#切换至训练模式
for epoch in range(total_epochs):for inputs, labels in dataloader:  # 从数据加载器获取数据inputs, labels = inputs.to(device), labels.to(device)# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()  # 清空之前的梯度loss.backward()        # 反向传播# 优化参数optimizer.step()       # 更新参数print(f'Epoch [{epoch+1}/{total_epochs}], Loss: {loss.item()}')
  • loss.item()

    • 在 PyTorch 中,loss 是一个 torch.Tensor 对象。当计算模型的损失时,这个对象通常只包含一个元素(一个标量值),它代表了当前批次数据的损失值。loss.item() 方法是从包含单个元素的张量中提取出那个标量值作为 Python 数值。这是很有用的,因为它允许你将损失值脱离张量的形式进行进一步的处理或输出,比如打印、记录或做条件判断。
  • print(f'Epoch [{epoch+1}/{total_epochs}], Loss: {loss.item()}')

    • 这行代码是用来在训练过程中输出当前 epoch 的编号和该 epoch 的损失值。这对于监控训练进程和调试模型非常有帮助。具体来说:
      • epoch+1:由于计数通常从 0 开始,所以 +1 是为了更自然地显示(从 1 开始而不是从 0 开始)。
      • {total_epochs}:这是训练过程中总的 epoch 数。
      • {loss.item()}:如前所述,这表示当前批次的损失值,作为一个标量数值输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/3049.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jQuery 动画小练习

以下是一个使用 jQuery 实现动画效果的简单示例。这个示例会让一个元素在页面加载时向右移动&#xff0c;并在点击时回到原始位置&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"…

冯唐成事心法笔记 —— 知智慧

系列文章目录 冯唐成事心法笔记 —— 知己 冯唐成事心法笔记 —— 知人 冯唐成事心法笔记 —— 知世 冯唐成事心法笔记 —— 知智慧 文章目录 系列文章目录PART 4 知智慧 知可为&#xff0c;知不可为大势不可为怎么办为什么人是第一位的多谈问题&#xff0c;少谈道理用金字塔…

服用5年份筑基丹 - Vue篇

前言 修仙之道&#xff0c;千回百转&#xff0c;每一步都充满了玄妙与机遇。在这条充满奇幻的修仙之路上&#xff0c;有一物至关重要&#xff0c;那便是筑基丹。此丹&#xff0c;凝聚了修仙者多年的心血与智慧&#xff0c;是修炼道路上的重要助力。 今日&#xff0c;我有幸得…

面试经典150题——路径总和

​ 1. 题目描述 2. 题目分析与解析 2.1 思路一 注意题目的关键点&#xff1a;判断该树中是否存在 根节点到叶子节点 的路径&#xff0c;起点是root&#xff0c;终点是叶子节点。 那么我们就可以从根节点按照层序遍历的方式&#xff0c;从根节点从根到 叶子不断对路径进行加…

前端H5动态背景登录页面(下)

最近正好有点儿时间&#xff0c;把之前没整理完的前端动态背景登录页面给整理一下&#xff01;这是之前的连接前端H5动态背景登录页面&#xff08;上&#xff09;&#xff0c;这主要是两个登陆页面&#xff0c;一个彩色气泡&#xff0c;一个动态云朵&#xff0c;感兴趣的可以点…

Python程序设计教案

文章目录&#xff1a; 一&#xff1a;软件环境安装 第一个软件&#xff1a;pycharm 第二个软件&#xff1a;thonny 第三个软件&#xff1a;IDIE&#xff08;自带的集成开发环境&#xff09; 二&#xff1a;相关 1.规范 2.关键字 3.Ascll码表 三&#xff1a;语法基础…

linux nginx开机自启

安装位置/usr/local/nginx监听端口80配置文件地址/usr/local/nginx/conf/ 注册服务 cd /usr/lib/systemd/system/vim nginx.service nginx.service 内容 [Unit] DescriptionThe NGINX HTTP and reverse proxy server Aftersyslog.target network.target[Service] Typeforki…

离开A页面时,取消A页面的axios接口数据请求

需求&#xff1a;从A页面跳转至B页面时&#xff0c;要取消A页面的axios请求&#xff1b;有时候&#xff0c;我们可能需要在发送请求后取消它&#xff0c;比如用户在请求还未完成时离开了当前页面或者执行了其他操作&#xff0c;本文将介绍如何在使用 Axios 发送请求时取消这些请…

Apache反向代理的功能和設置

Apache反向代理是Apache HTTP伺服器的一種功能&#xff0c;可以讓伺服器接收客戶端的請求並將其轉發到其他伺服器&#xff0c;然後將這些伺服器的回應返回給客戶端。這樣&#xff0c;客戶端就像直接訪問Apache伺服器一樣&#xff0c;而實際上是在訪問其他的伺服器。 Apache反向…

【Altium Designer 22原理图,PCB】

Altium Designer 22-原理图&#xff0c;PCB ■ AD22■ 工程■ 工程之外的文件 ■ AD22-画原理图■ 原理图库的设计■ 操作心得■ 元件库来源■ 检查原理图库的正确性并生成报告 ■ 原理图的设计■ 原理图页的大小设置■ 设置栅格100mil■ 放置元器件■ 元件的复制&#xff0c;剪…

从 MySQL 到 ClickHouse 实时数据同步 —— Debezium + Kafka 表引擎

目录 一、总体架构 二、安装配置 MySQL 主从复制 三、安装配置 ClickHouse 集群 四、安装 JDK 五、安装配置 Zookeeper 集群 六、安装配置 Kafaka 集群 七、安装配置 Debezium-Connector-MySQL 插件 1. 创建插件目录 2. 解压文件到插件目录 3. 配置 Kafka Connector …

常见UI设计模式有哪些?从小白到资深必学

通过了解如何以及何时使用&#xff0c;每种 UI 设计模式都有其特定的目的&#xff0c;可以创建一个一致高效的界面。UI 设计模式为用户界面设计者提供了一种通用语言&#xff0c;并为网站和应用程序的用户提供了一致性。本指南&#xff0c;即时设计总结了 UI 设计模式和 UI 设计…

执法记录仪如何防抖

影像记录发展至今&#xff0c;防抖已是必备要素&#xff0c;实际拍摄过程中&#xff0c;或通过硬件的运动补偿&#xff0c;或通过软件的加工处理&#xff0c;来抵消抖动对拍摄的影响。 到现在为止&#xff0c;已经有哪些防抖技术&#xff0c;它们各有什么优劣呢&#xff1f; …

HTTP协议的总结

参考 https://www.runoob.com/http/http-tutorial.html 1.简介 HTTP&#xff08;超文本传输协议&#xff0c;Hypertext Transfer Protocol&#xff09;是一种用于从网络传输超文本到本地浏览器的传输协议。它定义了客户端与服务器之间请求和响应的格式。HTTP 工作在 TCP/IP 模…

美客多、Lazada商家必须知道的养号技巧,助力打造爆款!

在Lazada平台开店&#xff0c;每个商家都渴望打造出自己的爆款产品。爆款不仅能为店铺带来大量流量&#xff0c;还能显著提升店铺和其他产品的转化率。然而&#xff0c;要想成功打造爆款&#xff0c;并非易事&#xff0c;需要掌握一些关键的小技能。 在Lazada平台&#xff0c;商…

每日OJ题_BFS解决拓扑排序③_力扣LCR 114. 火星词典

目录 力扣LCR 114. 火星词典 解析代码 力扣LCR 114. 火星词典 LCR 114. 火星词典 难度 困难 现有一种使用英语字母的外星文语言&#xff0c;这门语言的字母顺序与英语顺序不同。 给定一个字符串列表 words &#xff0c;作为这门语言的词典&#xff0c;words 中的字符串已…

十五、Java中I/O流

1、流的基本概念 1)流的概念 流:在Java中所有的数据都是使用流读写的。流是一组有顺序的,有起点和终点的字节集合,是对数据传输的总称或抽象。即数据在两设备间的传输称为流,流的本质就是数据传输,根据数据传输特性将流抽象为各种类。 (1)按照流向分:输入流、输出流。…

网络靶场实战-物联网安全qiling框架初探

背景 Qiling Framework是一个基于Python的二进制分析、模拟和虚拟化框架。它可以用于动态分析和仿真运行不同操作系统、处理器和体系结构下的二进制文件。除此之外&#xff0c;Qiling框架还提供了易于使用的API和插件系统&#xff0c;方便使用者进行二进制分析和漏洞挖掘等工作…

【求助】西门子S7-200PLC定时中断+数据归档的使用

前言 已经经历了种种磨难来记录我的数据&#xff08;使用过填表程序、触摸屏的历史记录和数据归档&#xff09;之后&#xff0c;具体可以看看这篇文章&#xff1a;&#x1f6aa;西门子S7-200PLC的数据归档怎么用&#xff1f;&#xff0c;出现了新的问题。 问题的提出 最新的…

网工交换基础——生成树协议(01)

一、生成树的技术概述 1、技术背景 二层交换机网络的冗余性导致出现二层环路&#xff1a; 人为因素导致的二层环路问题&#xff1a; 二层环路带来的网络问题&#xff1a; 生成树协议的概念&#xff1a; STP(Spanning Tree Protocol)是生成树协议的英文缩写。该协议可应用于在网…