深度学习(4):torch.nn.Module

文章目录

  • 一、是什么
  • 二、`nn.Module` 的核心功能
  • 三、`nn.Module` 的基本用法
    • 1. 定义自定义模型
    • 2. 初始化模型
      • 3. 模型的使用
  • 四、`nn.Module` 的关键特性
    • 1. 自动注册子模块和参数
    • 2. `forward` 方法
    • 3. 不需要定义反向传播
  • 五、常用的内置模块
  • 六、示例:创建一个简单的神经网络
    • 1. 问题描述
    • 2. 模型定义
    • 3. 训练过程
  • 七、深入理解 `nn.Module` 的一些重要概念
    • 1. 参数访问
    • 2. 模块访问
    • 3. 保存和加载模型
    • 4. 自定义层和模块
  • 八、`nn.Module` 的实践技巧
    • 1. 使用 `Sequential` 快速构建模型
    • 2. 模型的嵌套
  • 九、总结
    • 十、参考示例:完整的训练脚本

一、是什么

torch.nn.Module 是 PyTorch 中所有神经网络模块的基类,是构建神经网络模型的核心组件。

二、nn.Module 的核心功能

  1. 参数管理:自动管理模型的可训练参数(parameters),方便参数的访问和更新。

  2. 子模块管理:支持将模型分解为多个子模块,便于组织复杂的网络结构。

  3. 前向计算(forward):定义模型的前向传播逻辑。


三、nn.Module 的基本用法

1. 定义自定义模型

要创建自定义的神经网络模型,需要继承 nn.Module,并实现以下内容:

  • 构造函数 __init__:在这里定义网络的层和子模块。
  • 前向方法 forward:定义数据如何经过网络进行前向传播。
import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 定义网络层self.layer1 = nn.Linear(10, 20)self.relu = nn.ReLU()self.layer2 = nn.Linear(20, 1)def forward(self, x):# 定义前向传播过程out = self.layer1(x)out = self.relu(out)out = self.layer2(out)return out

2. 初始化模型

model = MyModel()

3. 模型的使用

  • 前向传播

    output = model(input_data)
    
  • 获取模型参数

    for name, param in model.named_parameters():print(name, param.size())
    

四、nn.Module 的关键特性

1. 自动注册子模块和参数

__init__ 方法中,当你将 nn.Module 的实例(如 nn.Linearnn.Conv2d 等)赋值给模型的属性时,nn.Module 会自动将这些子模块注册到模型中。这意味着:

  • 参数管理:模型的所有参数都会被自动收集,存储在 model.parameters() 中。
  • 子模块管理:可以通过 model.children()model.modules() 访问子模块。
class MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.fc = nn.Linear(10, 5)self.conv = nn.Conv2d(3, 16, kernel_size=3)model = MyModule()
print(list(model.parameters()))  # 自动包含了 fc 和 conv 的参数

2. forward 方法

forward 方法定义了模型的前向传播逻辑。在调用模型实例时,会自动调用 forward 方法。

output = model(input_data)  # 等价于 output = model.forward(input_data)

3. 不需要定义反向传播

在大多数情况下,不需要手动实现反向传播函数。PyTorch 的自动求导机制(autograd)会根据前向传播中的操作,自动计算梯度。

五、常用的内置模块

PyTorch 提供了大量的内置模块,继承自 nn.Module,可以直接使用:

  • 线性层nn.Linear
  • 卷积层nn.Conv1dnn.Conv2dnn.Conv3d
  • 循环神经网络nn.RNNnn.LSTMnn.GRU
  • 归一化层nn.BatchNorm1dnn.BatchNorm2d
  • 激活函数nn.ReLUnn.Sigmoidnn.Softmax
  • 损失函数nn.MSELossnn.CrossEntropyLoss

六、示例:创建一个简单的神经网络

1. 问题描述

创建一个多层感知机(MLP),用于对 MNIST 手写数字进行分类。

2. 模型定义

class MNISTClassifier(nn.Module):def __init__(self):super(MNISTClassifier, self).__init__()self.flatten = nn.Flatten()  # 将输入展开为一维self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 64)self.relu2 = nn.ReLU()self.fc3 = nn.Linear(64, 10)  # 输出10个类别的分数def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu2(self.fc2(x))x = self.fc3(x)return x

3. 训练过程

import torch.optim as optim# 初始化模型、损失函数和优化器
model = MNISTClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 假设有数据加载器 data_loader
for epoch in range(num_epochs):for images, labels in data_loader:# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

七、深入理解 nn.Module 的一些重要概念

1. 参数访问

  • parameters():返回一个生成器,包含模型所有可训练的参数。
  • named_parameters():返回一个生成器,生成 (name, parameter) 对,方便查看参数名称和形状。
for name, param in model.named_parameters():print(f'Parameter {name}: shape {param.shape}')

2. 模块访问

  • children():返回直接子模块的迭代器。
  • modules():返回自身及所有子模块的迭代器。
for child in model.children():print(child)for module in model.modules():print(module)

3. 保存和加载模型

  • 保存模型状态

    torch.save(model.state_dict(), 'model.pth')
    
  • 加载模型状态

    model = MNISTClassifier()
    model.load_state_dict(torch.load('model.pth'))
    

4. 自定义层和模块

通过继承 nn.Module,可以创建自定义的层或模块。

class CustomLayer(nn.Module):def __init__(self, in_features, out_features):super(CustomLayer, self).__init__()self.weight = nn.Parameter(torch.randn(in_features, out_features))self.bias = nn.Parameter(torch.zeros(out_features))def forward(self, x):return torch.matmul(x, self.weight) + self.bias

八、nn.Module 的实践技巧

1. 使用 Sequential 快速构建模型

对于简单的模型,可以使用 nn.Sequential 将多个层按顺序组合。

model = nn.Sequential(nn.Flatten(),nn.Linear(28 * 28, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, 10)
)

2. 模型的嵌套

可以将模块嵌套使用,构建复杂的网络结构。

class ComplexModel(nn.Module):def __init__(self):super(ComplexModel, self).__init__()self.block1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3),nn.ReLU())self.block2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3),nn.ReLU())self.fc = nn.Linear(64 * 24 * 24, 10)def forward(self, x):x = self.block1(x)x = self.block2(x)x = x.view(x.size(0), -1)  # 展平x = self.fc(x)return x

九、总结

  • nn.Module 是 PyTorch 构建神经网络的基础,提供了参数管理、子模块管理和前向传播等功能。
  • 通过继承 nn.Module,可以方便地创建自定义模型或层,满足各种复杂的需求。
  • 在使用 nn.Module 时,注意正确地定义 __init__forward 方法,并确保在 forward 方法中定义前向计算逻辑。
  • PyTorch 提供了大量的内置模块,可以直接使用或作为自定义模块的基石。
  • 善于利用 nn.Module 的特性和工具,可以大大提高模型开发的效率和代码的可读性。

十、参考示例:完整的训练脚本

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 定义超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据集和数据加载器
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)# 定义模型
class MNISTClassifier(nn.Module):def __init__(self):super(MNISTClassifier, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 64)self.relu2 = nn.ReLU()self.fc3 = nn.Linear(64, 10)def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu2(self.fc2(x))x = self.fc3(x)return x# 初始化模型、损失函数和优化器
model = MNISTClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):for images, labels in train_loader:# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), 'mnist_classifier.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/55003.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

全网最全软件测试面试题(含答案解析+文档)

一、软件测试基础面试题 1、阐述软件生命周期都有哪些阶段? 常见的软件生命周期模型有哪些? 软件生命周期是指一个计算机软件从功能确定设计,到开发成功投入使用,并在使用中不断地修改、增补和完善,直到停止该软件的使用的全过程(从酝酿到…

YOLO V8半自动标注工具设计

前提: 对于某些边界不明确的小目标,要是目标由比较多的话,标注起来就会非常麻烦。 如何利用已有训练模型,生成框,进行预标注。再通过调节预标注框的方式,提高标注的效率。 1 通过预先训练的模型生成yolo 格…

一文上手SpringSecurity【七】

之前我们在测试的时候,都是使用的字符串充当用户名称和密码,本篇将其换成MySQL数据库. 一、替换为真实的MySQL 1.1 引入依赖 <dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId><version>8.0.33</v…

Jenkins Pipeline 中通过勾选参数来控制是否构建 Docker 镜像

1.定义参数&#xff1a; 使用 booleanParam 定义一个布尔参数&#xff0c;示例如下 booleanParam(name: BUILD_DOCKER, description: 是否构建Docker镜像, defaultValue: false)2.使用参数&#xff1a; 在 stage 中&#xff0c;根据参数的值决定构建方式&#xff1a; stage(编…

python基础库

文章目录 1.研究目的2.platform库介绍3.代码4.结果展示 1.研究目的 最近项目中需要利用python获取计算机硬件的一些基本信息,查阅资料,.于是写下这篇简短的博客,有问题烦请提出,谢谢-_- 2.platform库介绍 platform 库是 Python 的一个内置库&#xff0c;可以让我们轻松地获取…

spring boot 项目中redis的使用,key=value值 如何用命令行来查询并设置值。

1、有一个老项目&#xff0c;用到了网易云信&#xff0c;然后这里面有一个AppKey&#xff0c;然后调用的时候要在header中加入这些标识&#xff0c;进行与服务器进行交互。 2、开发将其存在了redis中&#xff0c;一开始的时候&#xff0c;我们测试用的老的key&#xff0c;然后提…

ValueError: Out of range float values are not JSON compliant

可能原因一 可能原因二 数据里面有NaN

算法: 滑动窗口题目练习

文章目录 滑动窗口长度最小的子数组无重复字符的最长子串最大连续1个个数 III将x减到0的最小操作数水果成篮找到字符串中所有字母异位词串联所有单词的子串最小覆盖子串 总结 滑动窗口 长度最小的子数组 做这道题时,脑子里大概有个印象,知道要用滑动窗口,但是对于滑动窗口为什…

2016年国赛高教杯数学建模D题风电场运行状况分析及优化解题全过程文档及程序

2016年国赛高教杯数学建模 D题风电场运行状况分析及优化 风能是一种最具活力的可再生能源&#xff0c;风力发电是风能最主要的应用形式。我国某风电场已先后进行了一、二期建设&#xff0c;现有风机124台&#xff0c;总装机容量约20万千瓦。请建立数学模型&#xff0c;解决以下…

探索私有化聊天软件:即时通讯与音视频技术的结合

在数字化转型的浪潮中&#xff0c;企业对于高效、安全、定制化的通讯解决方案的需求日益迫切。鲸信&#xff0c;作为音视频通信技术的佼佼者&#xff0c;凭借其强大的即时通讯与音视频SDK&#xff08;软件开发工具包&#xff09;结合能力&#xff0c;为企业量身打造了私有化聊天…

MySQL Mail服务器集成:如何配置发送邮件?

MySQL Mail插件使用指南&#xff1f;怎么优化 MySQL发邮件性能&#xff1f; MySQL Mail服务器的集成&#xff0c;使得数据库可以直接触发邮件发送&#xff0c;极大地简化了应用架构。AokSend将详细介绍如何配置MySQL Mail服务器&#xff0c;以实现邮件发送功能。 MySQL Mail&…

【YashanDB知识库】如何配置jdbc驱动使getDatabaseProductName()返回Oracle

本文转自YashanDB官网&#xff0c;具体内容请见https://www.yashandb.com/newsinfo/7352676.html?templateId1718516 问题现象 某些三方件&#xff0c;例如 工作流引擎activiti&#xff0c;暂未适配yashandb&#xff0c;使用中会出现如下异常&#xff1a; 问题的风险及影响 …

【STM32】江科大STM32笔记汇总(已完结)

STM32江科大笔记汇总 STM32学习笔记课程简介(01)STM32简介(02)软件安装(03)新建工程(04)GPIO输出(05)LED闪烁& LED流水灯& 蜂鸣器(06)GPIO输入(07)按键控制LED 光敏传感器控制蜂鸣器(08)OLED调试工具(09)OLED显示屏(10)EXTI外部中断(11)对射式红外传感器计次 旋转编码器…

K8S服务发布

一 、服务发布方式对比 二者主要区别在于&#xff1a; 1、部署复杂性&#xff1a;传统的服务发布方式通常涉及手动配置 和管理服务器、网络设置、负载均衡等&#xff0c;过程相对复 杂且容易出错。相比之下&#xff0c;Kubernetes服务发布方式 通过使用容器编排和自动化部署工…

QT----Creater14.0,qt5.15无法启动调试,Launching GDB Debugger报红

问题描述 使用QT Creater 14.0 和qt5.15,无法启动调试也没有报错,加载debugger报红 相关文件都有 解决方案 尝试重装QT,更换版本5.15.2,下载到文件夹,shift鼠标右键打开powershell输入 .\qt-online-installer-windows-x64-4.8.0.exe --mirror http://mirrors.ustc.edu.cn…

解决fatal: unable to access ‘https://........git/‘: Recv failure: Operation time

目录 前言 解决方法一 解决方法二 解决方法三 解决方法四 总结 前言 在使用 Git 进行代码拉取时&#xff0c;可能会遇到连接超时的问题&#xff0c;特别是在某些网络环境下&#xff0c;例如公司网络或防火墙严格的环境中。这种情况下&#xff0c;Git 无法访问远程仓…

OpenHarmony(鸿蒙南向)——平台驱动指南【DAC】

往期知识点记录&#xff1a; 鸿蒙&#xff08;HarmonyOS&#xff09;应用层开发&#xff08;北向&#xff09;知识点汇总 鸿蒙&#xff08;OpenHarmony&#xff09;南向开发保姆级知识点汇总~ 持续更新中…… 概述 功能简介 DAC&#xff08;Digital to Analog Converter&…

LLM - 使用 RAG (检索增强生成) 多路召回 实现 精准知识问答 教程

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/142629289 免责声明&#xff1a;本文来源于个人知识与公开资料&#xff0c;仅用于学术交流&#xff0c;欢迎讨论&#xff0c;不支持转载。 RAG (R…

windows下 Winobj.exe工具使用说明c++

1、winobj.exe工具下载地址 WinObj - Sysinternals | Microsoft Learn 2、接下来用winobj.exe查看全局互斥&#xff0c;先写一个小例子 #include <iostream> #include <stdlib.h> #include <tchar.h> #include <string> #include <windows.h>…

VS2017安装Installer Projects制作Setup包

下载安装扩展包 VS2017默认未安装Installer Projects Package&#xff0c;需要联机下载&#xff1a; 也可网页上下载离线InstallerProjects.vsix文件&#xff1a; https://visualstudioclient.gallerycdn.vsassets.io/extensions/visualstudioclient/microsoftvisualstudio20…