Pytorch:模块(Module类)

文章目录

  • 一、Module类介绍
    • 1、主要功能
    • 2、神经网络模型使用理解
      • a.前向传播示例代码
      • b.关键点


在 PyTorch 中,Module 是一个非常核心的概念,它是所有神经网络层和模型的基础类。torch.nn.Module 是构建所有神经网络的基类,在 PyTorch 中非常重要,因为它提供了网络的组织架构,并封装了权重、梯度的管理、模型参数的更新等功能。

PyTorch 中的 Linear 层ReLU 激活函数以及大多数其他神经网络层和函数都返回 torch.Tensor 类型的对象。这些返回的张量包含了经过相应层或函数处理后的数据。在神经网络中,数据通常以张量的形式在各个层之间流动。

一、Module类介绍

所有神经网络层和模型的基础类,自定义神经网络时对其继承。

1、主要功能

  1. 封装参数

    • Module 类在内部自动管理 层的参数。每当你在 Module 中定义一个层对象,如 self.conv1 = nn.Conv2d(...), PyTorch 自动将这些层的参数加入到模型的参数列表中。这些参数通过 module.parameters() 方法访问。模型参数(定义在模型内部的层的权重和偏置)默认 requires_grad=True
  2. 自动梯度计算

    • 每个 Module 可以使用 PyTorch 的自动微分(autograd)系统来自动计算和存储梯度。在 forward 方法执行运算时,PyTorch 会跟踪这些运算产生的所有张量,对应的梯度在调用张量的 backward()方法后自动计算。由于模型参数默认requires_grad=True,因此对这些参数的所有操作都将被进行自动梯度计算。
  3. 前向传播定义

    • 在定义自己的网络时,需要覆盖 Moduleforward() 方法。这是模型接收输入数据并返回输出的地方forward() 方法定义了模型的前向传播路径
  4. 模型保存和加载
    模型的保存和加载是在 PyTorch 中进行模型持久化和迁移学习的常用操作。模型可以保存为 .pt.pth 文件,包括其参数、优化器状态和其他任何相关的信息。

  • 保存模型:

    • 最简单的保存方法是使用 torch.save 来保存模型的 state_dict,这是一个包含模型参数的字典。
    torch.save(model.state_dict(), 'model_path.pth')
    
    • 值得注意的是, 这条命令只是用来保存模型参数的,因此在加载参数时,需要使用同样的模型使用load_state_dict()才可。
  • 加载模型:

    • 加载模型时,首先需要实例化模型对象,然后使用 load_state_dict() 方法加载参数。
    model = MyModel()
    model.load_state_dict(torch.load('model_path.pth'))
    
  1. 将模型移动到指定的设备:
    在 PyTorch 中,可以将模型和数据移动到不同的设备上(如 CPU 或 GPU),以支持不同的计算需求。
  • 使用 .to() 方法可以将模型移动到指定的设备:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
  1. 切换模型的训练和评估方式
    torch.nn.Module 提供了 .train().eval() 方法,用于切换模型的训练和评估模式。
  • 训练模式 (train):

    • 在训练模式下,所有的层都被通知模型正在训练,这对于某些特定层(如 DropoutBatchNorm)非常重要,因为它们在训练和评估时的行为不同。
    model.train()
    
  • 评估模式 (eval):

    • 评估模式用于模型测试或验证阶段,确保所有层都处于评估状态。
    model.eval()
    
  1. .parameters()方法
  • parameters() 方法返回一个迭代器,包含模型中所有的参数(通常用于传递给优化器)。
    for param in model.parameters():print(param.size())
    
  1. .modules()方法
  • modules() 方法返回一个迭代器,遍历模型中的所有模块(层)。这在分析模型结构或应用特定操作到每一层时非常有用。
    for module in model.modules():print(module)
    

2、神经网络模型使用理解

白话:
  损失函数和优化器都不是module类中的方法,而是外部的方法,但是他们都能够作用于模型的权重:由于自动微分,损失函数接收的是结果张量,因此损失函数带来的梯度会被更新给权重的梯度。而优化器接受的是module对象参数的迭代器,它能根据参数的梯度对参数进行更新。

  自定义神经网络,实际上就是定义一个,该类继承自torch.nn.Module。在对这个类进行实例化时,是使用__init__默认构造函数实例化的。实例化后得到一个神经网络对象,对该对象输入数据会被重载为输入forward函数,而forward函数就是对输入数据进行一层一层的网络层结构处理。forward函数的输出一般是对输入进行了前向传播后的结果,为了对模型参数进行训练更新,我们一般还需要定义一个损失函数;这个损失函数是torch.tensor类型的,可以调用其backward函数,进行反向传播梯度;最后定义一个优化器,进行参数权重更新(实际上这里反向传播梯度 就 相当于损失函数对权重进行求导了,改变权重的方向就是让损失更小的方向。)

反向传播并不更新参数,优化器才是用来更新参数的,反向传播只是更新梯度。 这也是为什么优化器有一个学习率。教程:张量的梯度计算


非白话 自定义神经网络的流程:

  1. 定义一个类,继承自 torch.nn.Module

    • 这个类是您自定义神经网络的基础。通过继承 torch.nn.Module,您的网络能够利用 PyTorch 提供的模块化、参数管理、梯度计算等强大功能。
  2. __init__ 方法中初始化网络层

    • 这是定义神经网络结构的地方。您可以添加诸如全连接层 (nn.Linear), 卷积层 (nn.Conv2d), 激活函数 (nn.ReLU) 等。这些层将被自动注册为模块的子项,使其参数也自动成为模型的一部分。
  3. 定义 forward 方法

    • forward 方法描述了输入数据如何通过定义的层传播。这个方法是在模型训练和评估时自动被调用的,用于前向传播计算输出。
  4. 损失函数和反向传播

    • 在训练阶段,网络输出通过一个损失函数 (loss function) 评估其与真实标签的差异。常用的损失函数有 nn.CrossEntropyLoss(用于分类任务)和 nn.MSELoss(用于回归任务)。
    • 调用损失张量的 .backward() 方法启动自动梯度计算,即反向传播。在这一过程中,PyTorch 根据损失函数自动计算每个参数的梯度,并存储在参数的 .grad 属性中。
    • 在每次迭代后,需要手动清空梯度,以便下一次迭代。如果不清空梯度,梯度会累积,导致不正确的参数更新。清空梯度:optimizer.zero_grad()
  5. 参数更新

    • 使用一个优化器(如 torch.optim.SGDtorch.optim.Adam)来调整网络参数,基于计算的梯度进行更新,以减少损失函数的值。这通常在调用 .backward() 后进行。

a.前向传播示例代码

损失函数和优化器的例子请看:神经网络训练过程代码详解

下面是一个简单的自定义 Module 的例子,定义了一个包含两个全连接层的简单神经网络。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()# 定义第一个全连接层self.fc1 = nn.Linear(16, 12)# 定义第二个全连接层self.fc2 = nn.Linear(12, 10)def forward(self, x):# 第一个全连接层的激活函数使用ReLUx = F.relu(self.fc1(x))# 第二个全连接层的输出x = self.fc2(x)return x# 实例化网络
net = SimpleNet()#__init__()里并不需要参数。默认构造函数不需要参数,net就是一个实例化对象。
# 创建一些随机输入数据
input = torch.randn(1, 16)# 通过网络进行前向传播
output = net(input)#实际上直接使用对象名(),重载为:调用forward函数。
#input先经过一个nn.Linear(16,12),然后进行一次relu(),然后经过一个nn.Linear(12,10)

b.关键点

  • 继承:自定义的模型需要继承自 nn.Module
  • 超类初始化:使用 super() 初始化基类,这是在 Python 类中常见的做法,确保正确初始化父类部分。
  • 定义层:在构造函数中定义网络所需的各种层。
  • 前向传播:在 forward 方法中定义数据如何通过网络。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/3585.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

抖音视频笔记

文章目录 手机录屏如何录入麦克风声音变声 一直不太用抖音等交圈软件。 但是有时想记录下生活中的点滴,比较简单的方式实际就是app,那么了解下吧。 制作完毕后可以保存为草稿,不一定发布的。 手机录屏如何录入麦克风声音 毫无疑问&#xff…

图神经网络 | Pytorch图神经网络ST-GNN

时空图神经网络(Spatio-temporal Graph Neural Network)是一种用于处理时空数据的神经网络模型。它结合了图神经网络(Graph Neural Network)和时空数据的特性,能够对时空关系进行建模和预测。 在时空图神经网络中,数据被组织成图的形式,其中节点表示特定的时空位置,边…

Java NIO概念

Java NIO是什么? Java NIO,全称为Java Non-blocking Input/Output或New IO,是Java平台从JDK 1.4版本开始引入的一套新的输入/输出API。它旨在提供一种更高效、可扩展性更强的IO操作方式,特别适合构建高性能的网络应用和进行大容量…

决策树分析及其在项目管理中的应用

决策树分析是一种分类学习方法,其主要用于解决分类和回归问题。在决策树中,每个内部节点表示一个属性上的测试,每个分支代表一个属性输出,而每个叶节点则代表类或类分布。通过从根节点到内部节点的路径,可以构建一系列…

uniapp制作安卓原生插件踩坑

1.uniapp和Android工程互相引用讲解 uniapp原生Android插件开发入门教程 (最新版)_uniapp android 插件开发-CSDN博客 2.uniapp引用原生aar目录结构 详细尝试步骤1完成后生成的aar使用,需要新建nativeplugins然后丢进去 3.package.json示例…

深度学习--RNN循环神经网络和LSTM

深度学习中的循环神经网络(RNN)以及其中的一个变种长短期记忆网络(LSTM)是在序列数据处理方面非常重要的模型。下面我将详细介绍这两种网络的原理和应用。 循环神经网络(RNN) 循环神经网络是一类专门用于…

pytest数据驱动DDT(数据库/execl/yaml)

常见的DDT技术 数据结构: 列表、字典、json串 文件: txt、csv、excel 数据库: 数据库链接 数据库提取 参数化: pytest.mark.parametrize() pytest.fixture() …

Java集合框架-Collection-List-vector(遗留类)

目录 一、vector层次结构图二、概述三、底层数据结构四、常用方法五、和ArrayList的对比 一、vector层次结构图 二、概述 Vector类是单列集合List接口的一个实现类。与ArrayList类似,Vector也实现了一个可以动态修改的数组,两者最本质的区别在于——Vec…

有哪些人工智能/数据分析领域可以考取的证书?

一、TensorFlow谷歌开发者认证 TensorFlow面向学生、开发者、数据科学家等人群,帮助他们展示自己在用 TensorFlow 构建、训练模型的过程中所学到的实用机器学习技能。 添加图片注释,不超过 140 字(可选) TensorFlow 的产品总监 …

SQL中的锁

一、概述 介绍 锁是计算机协调多个进程或线程并发访问某一资源的机制。在数据库中,除传统的计算资(CPU、RAM、I/0)的争用以外,数据也是一种供许多用户共享的资源。如何保证数据并发访问的一致性、有效性是所有数据库必须解决的一个问题,锁冲…

SpringBoot Redis使用篇

引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId…

keep-alive的理解和使用方法(使用时的生命周期)

文章目录 一、Keep-alive 是什么二、使用场景三、原理分析四、思考题&#xff1a;缓存后如何获取数据beforeRouteEnteractived 参考文献 一、Keep-alive 是什么 keep-alive是vue中的内置组件&#xff0c;能在组件切换过程中将状态保留在内存中&#xff0c;防止重复渲染DOM ke…

PostgreSQL的扩展(extensions)-常用的扩展之pg_stat_statements

PostgreSQL的扩展&#xff08;extensions&#xff09;-常用的扩展之pg_stat_statements 基础信息 OS版本&#xff1a;Red Hat Enterprise Linux Server release 7.9 (Maipo) DB版本&#xff1a;16.2 pg软件目录&#xff1a;/home/pg16/soft pg数据目录&#xff1a;/home/pg16/…

java解析PDF、WORD获取中的表格以及文本内容

近期因工作需要需要解析PDF&#xff0c;需要把PDF中的文本和表格分离&#xff0c;最终要实现的目标是PDF中的文本内容放一块&#xff0c;表格内容放一块&#xff0c;以list的形式存储。解析PDF的技术有很多&#xff0c;经过多次尝试发现使用AdobeAcrobat可以实现表格和文本分离…

TensorFlow轻松入门(一)(更新中)

常见模块 tf. &#xff1a;包含了张量定义&#xff0c;变换等常用函数和类&#xff1b;tf.data&#xff1a;输入数据处理模块&#xff0c;提供了像tf.data.Dataset等类用于封装输入数据&#xff0c;指定批量大小等&#xff1b;tf.image&#xff1a;图像处理模块&#xff0c;提…

el-form 表单设置某个参数非必填验证

html <el-form ref"form" :rules"rules"><el-form-item prop"tiktokEmail" label"邮箱" ><el-input v-model"form.tiktokEmail" placeholder"邮箱" ></el-input></el-form-item&…

mybatis中foreach使用

一、foreach 属性使用 <foreach collection"list" index"index" item"mchntCd" open"(" close")" separator",">#{mchntCd} </foreach>item&#xff1a; 集合中元素迭代时的别名&#xff0c;该参数为…

项目实战:Qt获取CTP量化交易接口测试数据工具 v1.0.0(获取深度行情数据、订阅取消订阅)

若该文为原创文章&#xff0c;转载请注明出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/137937666 红胖子(红模仿)的博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软硬结…

VSCODE自定义代码片段简述与基础使用

目录 一、 简述二 、 基础使用说明2.1 新建一个代码块工作区间2.2 语法 三、 示例四、 参考链接 一、 简述 VSCode的自定义代码片段功能允许开发者根据自己的需求定义和使用自己的代码片段&#xff0c;从而提高编码效率。 优点: 提高效率&#xff1a; 自定义代码片段能够减少…

springboot+Vue实现分页

文章目录 一、后端二、前端 今天开发的有一个场景就是需要从远程ssh服务器上加载一个文件展示到前端&#xff0c;但是一次性拉过来有几万条数据&#xff0c;一下载加载整个文件会导致前端非常非常的卡&#xff0c;于是要使用分页解决&#xff0c;我之前看过的有mybatis的分页查…