PyTorch 的 torch.nn 模块学习

torch.nn 是 PyTorch 中专门用于构建和训练神经网络的模块。它的整体架构分为几个主要部分,每部分的原理、要点和使用场景如下:

1. nn.Module

  • 原理和要点nn.Module 是所有神经网络组件的基类。任何神经网络模型都应该继承 nn.Module,并实现其 forward 方法。
  • 使用场景:用于定义和管理神经网络模型,包括层、损失函数和自定义的前向传播逻辑。
import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = nn.Linear(10, 1)def forward(self, x):return self.linear(x)model = MyModel()
print(model)

2. Layers(层)

  • 原理和要点:层是神经网络的基本构建块,包括全连接层、卷积层、池化层等。每种层执行特定类型的操作,并包含可学习的参数。
  • 使用场景:用于构建神经网络的各个组成部分,如特征提取、降维等。
2.1 nn.Linear(全连接层)
linear = nn.Linear(10, 5)
input = torch.randn(1, 10)
output = linear(input)
print(output)
2.2 nn.Conv2d(二维卷积层)
conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
input = torch.randn(1, 1, 5, 5)
output = conv(input)
print(output)
2.3 nn.MaxPool2d(二维最大池化层)
maxpool = nn.MaxPool2d(kernel_size=2)
input = torch.randn(1, 1, 4, 4)
output = maxpool(input)
print(output)

3. Loss Functions(损失函数)

  • 原理和要点:损失函数用于衡量模型预测与真实值之间的差异,指导模型优化过程。
  • 使用场景:用于计算训练过程中需要最小化的误差。
3.1 nn.MSELoss(均方误差损失)
mse_loss = nn.MSELoss()
input = torch.randn(3, 5)
target = torch.randn(3, 5)
loss = mse_loss(input, target)
print(loss)
3.2 nn.CrossEntropyLoss(交叉熵损失)
cross_entropy_loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
loss = cross_entropy_loss(input, target)
print(loss)

4. Optimizers(优化器)

  • 原理和要点:优化器用于调整模型参数,以最小化损失函数。
  • 使用场景:用于训练模型,通过反向传播更新参数。
4.1 torch.optim.SGD(随机梯度下降)
import torch.optim as optimmodel = MyModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()# Training loop
for epoch in range(100):optimizer.zero_grad()output = model(torch.randn(1, 10))loss = criterion(output, torch.randn(1, 1))loss.backward()optimizer.step()
4.2 torch.optim.Adam(自适应矩估计)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training loop
for epoch in range(100):optimizer.zero_grad()output = model(torch.randn(1, 10))loss = criterion(output, torch.randn(1, 1))loss.backward()optimizer.step()

5. Activation Functions(激活函数)

  • 原理和要点:激活函数引入非线性,使模型能够拟合复杂的函数。
  • 使用场景:用于激活输入,增加模型表达能力。
5.1 nn.ReLU(修正线性单元)
relu = nn.ReLU()
input = torch.randn(2)
output = relu(input)
print(output)

6. Normalization Layers(归一化层)

  • 原理和要点:归一化层用于标准化输入,改善训练的稳定性和速度。
  • 使用场景:用于标准化激活值,防止梯度爆炸或消失。
6.1 nn.BatchNorm2d(二维批量归一化)
batch_norm = nn.BatchNorm2d(3)
input = torch.randn(1, 3, 5, 5)
output = batch_norm(input)
print(output)

7. Dropout Layers(丢弃层)

  • 原理和要点:Dropout 层通过在训练过程中随机丢弃一部分神经元来防止过拟合。
  • 使用场景:用于防止模型过拟合,增加模型的泛化能力。
7.1 nn.Dropout
dropout = nn.Dropout(p=0.5)
input = torch.randn(2, 3)
output = dropout(input)
print(output)

8. Container Modules(容器模块)

  • 原理和要点:容器模块用于组合多个层,构建复杂的神经网络结构。
  • 使用场景:用于组合多个层,形成更复杂的网络结构。
8.1 nn.Sequential(顺序容器)
model = nn.Sequential(nn.Linear(10, 20),nn.ReLU(),nn.Linear(20, 5)
)
input = torch.randn(1, 10)
output = model(input)
print(output)
8.2 nn.ModuleList(模块列表)
layers = nn.ModuleList([nn.Linear(10, 20),nn.ReLU(),nn.Linear(20, 5)
])input = torch.randn(1, 10)
for layer in layers:input = layer(input)
print(input)

9. Functional API (torch.nn.functional)

  • 原理和要点:包含大量用于深度学习的无状态函数,这些函数通常是操作层的底层实现。
  • 使用场景:用于在前向传播中灵活调用函数。
9.1 F.relu(ReLU 激活函数)
import torch.nn.functional as Finput = torch.randn(2)
output = F.relu(input)
print(output)
9.2 F.cross_entropy(交叉熵损失函数)
input = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
loss = F.cross_entropy(input, target)
print(loss)
9.3 F.conv2d(二维卷积)
input = torch.randn(1, 1, 5, 5)
weight = torch.randn(3, 1, 3, 3)  # Manually defined weights
output = F.conv2d(input, weight)
print(output)

10. Parameter (torch.nn.Parameter)

  • 原理和要点torch.nn.Parametertorch.Tensor 的一种特殊子类,用于表示模型的可学习参数。它们在 nn.Module 中会自动注册为参数。
  • 使用场景:用于定义模型中的可学习参数。
示例代码:
class MyModelWithParam(nn.Module):def __init__(self):super(MyModelWithParam, self).__init__()self.my_param = nn.Parameter(torch.randn(10, 10))def forward(self, x):return x @ self.my_parammodel = MyModelWithParam()
input = torch.randn(1, 10)
output = model(input)
print(output)# 查看模型参数
for name, param in model.named_parameters():print(name, param.size())

综合示例

下面是一个结合上述各个部分的综合示例:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass MyComplexModel(nn.Module):def __init__(self):super(MyComplexModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3)self.bn1 = nn.BatchNorm2d(32)self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.bn2 = nn.BatchNorm2d(64)self.dropout = nn.Dropout(0.25)self.fc1 = nn.Linear(64*12*12, 128)self.fc2 = nn.Linear(128, 10)self.custom_param = nn.Parameter(torch.randn(128, 128))def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = F.max_pool2d(x, 2)x = F.relu(self.bn2(self.conv2(x)))x = F.max_pool2d(x, 2)x = self.dropout(x)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = x @ self.custom_paramx = self.fc2(x)return F.log_softmax(x, dim=1)model = MyComplexModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(10):optimizer.zero_grad()input = torch.randn(64, 1, 28, 28)target = torch.randint(0, 10, (64,))output = model(input)loss = criterion(output, target)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')

通过以上示例,可以更清晰地理解 torch.nn 模块的整体架构、原理、要点及其具体使用场景。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/20213.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

compare_exchange 基本使用

参考博客:C内存模型 compare_exchange_weak基本使用 bool compare_exchange_weak( T& expected, T desired,std::memory_order success,std::memory_order failure );expected:期望的值desired 想要写入的值 如果obj和期望的值相同,则写入desired并…

如何培养技术人员的管理能力?

随着企业发展的需求不断增长,对于一专多能的复合型人才的需求也日益增加。这种人才既拥有技术实力,又具备出色的管理能力。尤其对于高新技术企业而言,技术骨干往往更有机会成为管理人员。一方面是因为技术骨干在自己的岗位上展现出了核心技术…

「C系列」C 简介

文章目录 一、C 简介1. C语言的主要特点:2. C语言的应用领域:3. 学习C语言的建议: 二、C 环境设置、编辑器1. C环境设置2. 编辑器选择3. 总结 三、C第一个案例四、相关链接 一、C 简介 C语言是一种通用的、过程式的计算机编程语言&#xff0…

通用代码生成器应用场景四,跨编程语言翻译

通用代码生成器应用场景四,跨编程语言翻译 如果您有一个Java工程,想把它移植到Rust或Golang语言中去,希望尽可能加快研发速度。 如果您的系统是通用代码生成器开发的,保留了系统的SGS源文件或者SGS2的Excel模板,您可…

探索未来电商视觉革命:Doly,AI驱动的3D产品宣传短片一键生成器

在数字化营销日新月异的今天,产品展示的视觉冲击力已成为电商平台吸引消费者的关键。Doly,由法国创新先驱AniML匠心打造,正引领一场AI与3D技术融合的电商内容创新风暴,让每一位电商卖家都能轻松拥有好莱坞级别的产品宣传短片,只需简单几步,即可在激烈的市场竞争中脱颖而出…

结构体 基础知识

本笔记为观看64 结构体-结构体定义和使用_哔哩哔哩_bilibili 的学习笔记 1.结构体概念 结构体属于用户自定义的数据类型,允许用户存储不同的数据类型。 2.结构体定义和使用 ​ 结构体定义 ​ 通过结构体创建变量的方式 2.1 Struct 结构体名 变量名 ​ 2…

Springboot 开发 -- 跨域问题技术详解

一、跨域的概念 跨域访问问题指的是在客户端浏览器中,由于安全策略的限制,不允许从一个源(域名、协议、端口)直接访问另一个源的资源。当浏览器发起一个跨域请求时,会被浏览器拦截,并阻止数据的传输。 这…

【算法】MT2 棋子翻转

✨题目链接: MT2 棋子翻转 ✨题目描述 在 4x4 的棋盘上摆满了黑白棋子,黑白两色棋子的位置和数目随机,其中0代表白色,1代表黑色;左上角坐标为 (1,1) ,右下角坐标为 (4,4) 。 现在依次有一些翻转操作&#…

“迎七一、学党史、祭英烈”活动在孙善师孙善帅烈士故居启动

临沂信息联播讯(张春兄、冯爱云) 5月30日,山东省著名烈士孙善师孙善帅故居迎来了山东全味时间企业管理咨询服务有限公司、志林丽虹沂蒙文化传播(临沂)有限公司、山东志林搏击健身有限公司的参观团队,标志着…

【WEEK14】 【DAY4】Swagger第二部分【中文版】

2024.5.30 Thursday 接上文【WEEK14】 【DAY3】Swagger第一部分【中文版】 目录 16.4.配置扫描接口16.4.1.修改SwaggerConfig.java16.4.1.1.使用.basePackage()方法指定扫描的包路径16.4.1.2.其他扫描方式均可在RequestHandlerSelectors.class中查看源码 16.4.2.仍然是修改Swag…

这个夏天,凶险如昨?

回望2023年三季度的“美债风暴”,当时美债收益率狂飙突破5%,阴霾笼罩下全球风险资产一片惨淡,这一场景会在今夏再度上演吗? 本周美债遭遇抛售,10年期收益率上破4.6%,2年期收益率逼近5%关口,收益…

mongodb 增删改查

使用MongoTemplate的updateFirst()或updateMulti()方法 MongoTemplate提供了更底层的访问MongoDB的API,允许你执行更复杂的更新操作。updateFirst()方法会更新找到的第一个匹配的文档,而updateMulti()会更新所有匹配的文档。 javaimport org.springfram…

【Android】点击图片获取点击位置在图片中的位置

需求 在一个页面中,有一张图片展示,这个页面是一个可滑动页面,但是当点击到这个图片里面的位置的时候,我们需要获取到这个点击位置在图片的哪个位置,即获取到点击点与图片当前的相对位置。 分析 我们在屏幕上可以通…

linux磁盘满了,如何查找大文件清除?

将整个Linux中文件按照文件大小排序,从大到小排序 只显示前100条数据 命令: find / -type f -exec du -h {} | sort -rh | head -n 100结果:

全栈工程师需要具备哪些技能?

概论: 全栈工程师是一位能够从头到尾构建 Web 应用程序的工程师,能独立完成产品。技术包括前端部分、后端部分和应用程序所在的基础架构。他们在整个技术栈中工作,并了解其中的每个部分。从需求分析开始,到概要设计,详…

HarmonyOS鸿蒙学习笔记(25)相对布局 RelativeContainer详细说明

RelativeContainer 简介 前言核心概念官方实例官方实例改造蓝色方块改造center 属性说明参考资料 前言 RelativeContainer是鸿蒙的相对布局组件,它的布局很灵活,可以很方便的控制各个子UI 组件的相对位置,其布局理念有点类似于android的约束…

270 基于matlab的模糊自适应PID控制

基于matlab的模糊自适应PID控制,具有10页报告。传统PID在对象变化时,控制器的参数难以自动调整。将模糊控制与PID控制结合,利用模糊推理方法实现对PID参数的在线自整定。使控制器具有较好的自适应性。使用MATLAB对系统进行仿真,结…

如何配置云WAF以实现更有效的流量分发

云WAF流量分发功能介绍 云WAF(Web Application Firewall)是一种基于云计算环境的Web应用安全防护服务,主要用于保护Web应用程序免受各种网络攻击,如SQL注入、跨站脚本(XSS)、分布式拒绝服务(DD…

前后端交互:axios 和 json;springboot 和 vue

vue 准备的 <template><div><button click"sendData">发送数据</button><button click"getData">接收</button><button click"refresh">刷新</button><br><ul v-if"questions&…

win10系统下WPS工具显示灰色全部用不了,提示登录

如果你在写文档或使用excel时发现导航栏的工具全部使用不了&#xff0c;弹出是需要您登录&#xff0c;可以通过以下操作不用登录。 按照 1&#xff08;搜索框&#xff09;—> 2&#xff08;应用&#xff09;—> 3&#xff08;WPS Office&#xff09;点鼠标左键—> 4&a…