pytorch学习day3

一、模型创建(Module)

网络创建流程

上面的图表展示了使用PyTorch创建神经网络模型的主要步骤。每个步骤按顺序连接,展示了从导入必要的库到最终测试模型的整个流程:

  1. 导入必要的库:首先导入PyTorch及其相关模块。
  2. 定义网络结构:通过继承 nn.Module 类定义神经网络的层和前向传播过程。
  3. 实例化模型:使用定义的结构实例化模型对象。
  4. 定义损失函数和优化器:选择并定义损失函数和优化器。
  5. 准备数据:加载并预处理数据,创建数据加载器。
  6. 训练模型:通过训练循环进行前向传播、计算误差和反向传播更新权重。
  7. 测试模型:在测试数据上评估模型的性能。

模型构建的两个要素

在PyTorch中,构建神经网络模型的关键在于两个要素:构建子模块拼接子模块。这两个要素分别在模型类的 __init__() 方法和 forward() 方法中实现。

1. 构建子模块

在自定义模型中,通过继承 nn.Module 类,并在 __init__() 方法中定义子模块。这些子模块通常是神经网络的各层,例如卷积层、全连接层、激活函数等。

示例:

import torch.nn as nnclass CustomModel(nn.Module):def __init__(self):super(CustomModel, self).__init__()# 定义子模块self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(in_features=16*16*16, out_features=120)self.fc2 = nn.Linear(in_features=120, out_features=84)self.fc3 = nn.Linear(in_features=84, out_features=10)

在上面的代码中,我们定义了一个卷积层 conv1,一个池化层 pool,以及三个全连接层 fc1fc2fc3。这些子模块是模型的基本组成部分。

2. 拼接子模块

forward() 方法中定义子模块的拼接方式。forward() 方法描述了输入数据如何经过这些子模块的传递过程,最终输出结果。

示例:

class CustomModel(nn.Module):def __init__(self):super(CustomModel, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(in_features=16*16*16, out_features=120)self.fc2 = nn.Linear(in_features=120, out_features=84)self.fc3 = nn.Linear(in_features=84, out_features=10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))  # 拼接卷积层和池化层x = x.view(-1, 16*16*16)              # 展平张量x = F.relu(self.fc1(x))               # 拼接第一个全连接层x = F.relu(self.fc2(x))               # 拼接第二个全连接层x = self.fc3(x)                       # 拼接第三个全连接层return x

forward() 方法中,我们定义了输入数据的传递路径。数据首先通过卷积层 conv1 和池化层 pool,然后展平为一维张量,依次通过三个全连接层 fc1fc2fc3,最后输出结果。

通过这两个步骤,我们可以构建出一个功能齐全的神经网络模型。以下是流程图,帮助理解这两个要素在模型构建中的位置和作用。

当然,以下是关于模型构建的两个要素的表格,可以直接复制使用:

| 模型构建的两个要素 |                  描述                  |
|--------------------|----------------------------------------|
| 构建子模块         | 在自定义模型(继承 nn.Module)的 `__init__()` 方法中定义各个层(卷积层、池化层、全连接层等)|
| 拼接子模块         | 在自定义模型的 `forward()` 方法中定义层的连接方式,描述前向传播过程              |

这个表格简明地展示了模型构建的两个关键步骤和它们分别在哪个方法中实现。希望这能帮助你更好地理解和使用PyTorch进行模型构建。

通过以上两个步骤,我们可以灵活地定义各种复杂的神经网络模型,并通过 forward() 方法灵活地组合这些子模块,实现数据的前向传播过程。

二、nn.Mudule的属性

nn.Module 是 PyTorch 中所有神经网络模块的基类。它提供了一些关键属性和方法,用于构建和管理神经网络模型。以下是 nn.Module 的一些重要属性和方法:

1. parameters()

  • 描述:返回模型所有参数的迭代器。
  • 用途:通常用于优化器来获取模型参数进行训练。
for param in model.parameters():print(param.size())

2. named_parameters()

  • 描述:返回一个包含模型参数名字和参数本身的迭代器。
  • 用途:当你需要获取特定层的参数时特别有用。
for param in model.parameters():print(param.size())

3. children()

  • 描述:返回模型所有子模块的迭代器。
  • 用途:用于递归遍历模型的各个子模块。
for child in model.children():print(child)

4. named_children()

  • 描述:返回一个包含模型子模块名字和子模块本身的迭代器。
  • 用途:用于详细查看每个子模块。
for name, child in model.named_children():print(name, child)

5. modules()

  • 描述:返回模型所有模块(包括模型本身和其子模块)的迭代器。
  • 用途:用于遍历所有模块。
for module in model.modules():print(module)

6. named_modules()

  • 描述:返回一个包含模型模块名字和模块本身的迭代器。
  • 用途:当你需要以层级结构查看所有模块时使用。
  • for name, module in model.named_modules():print(name, module)

7. add_module(name, module)

  • 描述:将一个子模块添加到当前模块。
  • 用途:动态地添加子模块。
model.add_module('extra_layer', nn.Linear(10, 10))

8. forward()

  • 描述:定义前向传播逻辑。用户需要在自己的子类中重载这个方法。
  • 用途:定义输入数据如何通过网络层进行传递。
def forward(self, x):x = self.layer1(x)x = self.layer2(x)return x

9. train(mode=True)

  • 描述:将模块设置为训练模式。
  • 用途:启用或禁用 Dropout 和 BatchNorm。
model.train()  # 设置为训练模式
model.eval()   # 设置为评估模式

10. zero_grad()

  • 描述:将所有模型参数的梯度清零。
  • 用途:在每次反向传播前清除旧的梯度。
model.zero_grad()

这些属性和方法提供了强大的功能,使得 nn.Module 能够灵活且高效地管理神经网络模型。通过这些接口,你可以构建、管理和训练复杂的神经网络。

三、模型容器Containers

模型容器(Containers)

在 PyTorch 中,模型容器(Containers)是用于组织和管理神经网络层的一种方式。通过使用模型容器,可以更方便地构建和管理复杂的神经网络结构。以下是 PyTorch 中常用的几种模型容器:

1. nn.Sequential

描述:

nn.Sequential 是一个按顺序执行子模块的容器。它将子模块按定义顺序串联起来,适合用于简单的前向传播模型。

用途:

用于快速构建按顺序堆叠的网络结构,例如多层感知机(MLP)和简单的卷积神经网络(CNN)。

示例:

import torch.nn as nnmodel = nn.Sequential(nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()
)

在这个例子中,输入数据依次通过两个卷积层和两个 ReLU 激活函数。

2. nn.ModuleList

描述:

nn.ModuleList 是一个存储子模块的有序列表,但并没有定义前向传播的具体顺序。它主要用于需要灵活前向传播定义的模型。

用途:

适用于需要在前向传播过程中动态选择层或者有条件执行层的情况。

示例:

import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.layers = nn.ModuleList([nn.Conv2d(1, 20, 5), nn.Conv2d(20, 64, 5)])def forward(self, x):for layer in self.layers:x = layer(x)return x

在这个例子中,layers 存储了两个卷积层,并在 forward 方法中以循环的方式应用它们。

3. nn.ModuleDict

描述:

nn.ModuleDict 是一个存储子模块的字典,可以使用键来访问子模块。它提供了灵活的模块管理方式,可以通过键值对的方式存取模块。

用途:

适用于需要命名访问子模块,且不需要严格的前向传播顺序的情况,例如多分支的模型结构。

示例:

import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.layers = nn.ModuleDict({'conv1': nn.Conv2d(1, 20, 5),'conv2': nn.Conv2d(20, 64, 5)})def forward(self, x):x = self.layers['conv1'](x)x = self.layers['conv2'](x)return x

在这个例子中,layers 存储了两个卷积层,可以通过键名 'conv1''conv2' 进行访问。

4. nn.ParameterListnn.ParameterDict

描述:

这两个容器分别用于存储参数列表和参数字典,与 ModuleListModuleDict 类似,但它们存储的是参数而不是模块。

用途:

适用于需要灵活管理模型参数的情况。

示例:

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(3)])self.param_dict = nn.ParameterDict({'param1': nn.Parameter(torch.randn(10, 10)),'param2': nn.Parameter(torch.randn(10, 10))})def forward(self, x):# 使用 self.params 和 self.param_dict 进行前向传播pass

在这个例子中,params 存储了三个参数,而 param_dict 则存储了两个命名参数。

4 总结

通过使用这些模型容器,PyTorch 提供了灵活且高效的方式来组织和管理神经网络模型的层和参数。nn.Sequential 适用于简单的顺序结构,nn.ModuleListnn.ModuleDict 提供了更多的灵活性,适用于更复杂的网络结构。nn.ParameterListnn.ParameterDict 则用于更灵活的参数管理。利用这些容器,可以更方便地构建和管理复杂的神经网络模型。

5 实现一个简单VGG网络

创建一个简单的VGG网络

VGG网络是一种深度卷积神经网络,因其简单且具有良好的性能而广泛应用。下面我们利用PyTorch提供的模型容器,构建一个简化版的VGG网络。我们将主要使用nn.Sequential来按顺序堆叠卷积层和全连接层。

相关论文地址:https://arxiv.org/abs/1409.1556

1. 导入必要的库

import torch
import torch.nn as nn
import torch.nn.functional as F

2. 定义VGG块

VGG块由多个卷积层和一个池化层组成。我们定义一个函数来创建这些块。

def vgg_block(num_convs, in_channels, out_channels):layers = []for _ in range(num_convs):layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))layers.append(nn.ReLU(inplace=True))in_channels = out_channelslayers.append(nn.MaxPool2d(kernel_size=2, stride=2))return nn.Sequential(*layers)

3. 定义VGG网络

我们利用nn.Sequential来堆叠多个VGG块,最后添加全连接层。

class SimpleVGG(nn.Module):def __init__(self):super(SimpleVGG, self).__init__()self.features = nn.Sequential(vgg_block(2, 3, 64),vgg_block(2, 64, 128),vgg_block(3, 128, 256),vgg_block(3, 256, 512),vgg_block(3, 512, 512))self.classifier = nn.Sequential(nn.Linear(512*7*7, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, 10))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

4. 实例化和测试模型

我们创建模型实例并打印其结构,确保其正确性。

model = SimpleVGG()
print(model)

5. 测试模型结构

为了确保模型构建正确,我们可以打印模型结构或者传递一个随机张量进行测试。

if __name__ == "__main__":model = SimpleVGG()print(model)# 测试输入数据input_tensor = torch.randn(1, 3, 224, 224)output = model(input_tensor)print(output.shape)  # 应输出 torch.Size([1, 10])

通过这些步骤,我们利用PyTorch提供的模型容器创建了一个简化版的VGG网络。这个网络由五个VGG块和三个全连接层组成,适用于图像分类任务。根据需求可以进一步调整网络结构和参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/845251.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Tlias智能学习辅助系统】03 部门管理 前后端联调

Tlias智能学习辅助系统 03 部门管理 前后端联调 前端环境 前端环境 链接:https://pan.quark.cn/s/8720156ed6bf 提取码:aGeR 解压后放在一个不包含中文的文件夹下,双击 nginx.exe 启动服务 跨域的问题已经被nginx代理转发了,所以…

vs code 中使用SSH 连接远程的Ubuntu系统

如下图,找到对应的位置 在电脑上找到以下位置 打开配置如下,记住,那个root为你的用户名,这个用户名,具体根据你的用户名来设置,对应的密码就是你登录Ubuntu时的密码 Host root192.168.0.64User rootHostNa…

第N3周:Pytorch文本分类入门

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制🚀 文章来源:K同学的学习圈子 这里借用K同学的一张图片大致说明本次任务流程。 1.本次所用AG News数据集介绍 AG…

汽车IVI中控开发入门及进阶(二十三):i.MX8

前言: IVI市场的复杂性急剧增加,而TimeToMarket在几代产品中从5年减少到2-3年。Tier1正在接近开放系统的模型(用户可以安装应用程序),从专有/关闭源代码到标准接口/开放源代码,从软件堆栈对系统体系结构/应用层/系统验证和鉴定的完全所有权,越来越依赖第三方中间件和平…

优思学院|作为质量工程师,需要考哪些证书?别浪费你的气力,一张就够!

质量工程师做什么呢?他们的主要任务就是确保产品和服务的质量,以满足客户需求并超越竞争对手。尽管市场上有各种各样的质量管理认证,但优思学院认为,专注于六西格玛的学习和认证就足够了。 为什么选择六西格玛? 第一…

C++青少年简明教程:While和Do-while循环语句

C青少年简明教程:While和Do-while循环语句 C的while和do-while语句都是循环控制语句,用于重复执行一段代码。while语句在循环开始前检查循环条件,而do-while语句在循环结束后检查循环条件。 使用while循环时,如果需要在每次迭代前…

12V升20V3.5A升压恒压WT3207

12V升20V3.5A升压恒压WT3207 WT3207是一款高效PWM升压控制器,采用SO-8封装设计,专为低输入电压应用优化。该控制器支持5V至36V的宽输入电压范围,使其能够有效提升12V、15V和19V系统的电压水平,特别适合于两节或三节锂离子电池供电…

LabVIEW超声波局部放电检测系统开发

LabVIEW超声波局部放电检测系统开发 在高压电力系统中,局部放电(PD)是导致绝缘失效的主要原因之一。局部放电的检测对于确保电力系统的可靠运行至关重要。开发了一种基于LabVIEW软件的超声波局部放电检测系统的设计与实现。该系统利用数字信号处理技术,…

Python | Leetcode Python题解之第119题杨辉三角II

题目&#xff1a; 题解&#xff1a; class Solution:def getRow(self, rowIndex: int) -> List[int]:row [1, 1]if rowIndex < 1:return row[:rowIndex 1]elif rowIndex > 2:for i in range(rowIndex - 1):row [row[j] row[j 1] for j in range(i 1)]row.inser…

【Python入门学习笔记】Python3超详细的入门学习笔记,非常详细(适合小白入门学习)

Python3基础 想要获取pdf或markdown格式的笔记文件点击以下链接获取 Python入门学习笔记点击我获取 1&#xff0c;Python3 基础语法 1-1 编码 默认情况下&#xff0c;Python 3 源码文件以 UTF-8 编码&#xff0c;所有字符串都是 unicode 字符串。 当然你也可以为源码文件指…

【数据结构】二叉树运用及相关例题

文章目录 前言查第K层的节点个数判断该二叉树是否为完全二叉树例题一 - Leetcode - 226反转二叉树例题一 - Leetcode - 110平衡二叉树 前言 在笔者的前几篇篇博客中介绍了二叉树的基本概念及基本实现方法&#xff0c;有兴趣的朋友自己移步看看。 这篇文章主要介绍一下二叉树的…

iframe内嵌网页自适应缩放 以展示源网页的比例尺寸

需求:这是我最近开发的低代码平台遇到的需求 ,要求将配置好的应用在弹框中预览(将预览网页内嵌入弹框中) 但是内嵌进入后 他会截取一部分(我源网站网页尺寸 是1980x1080 或者 3060X2160等等) 但是我这个dialog弹框只有我自定义的1000多px的宽高 他只会展示我iframe网页的一部分…

Linux - 磁盘管理1

1.磁盘的分区 1.1 磁盘的类型&#xff08;标签&#xff09; MBR&#xff1a; ① 最大支持2T以内的硬盘 ② 有主分区p 拓展分区e 逻辑分区l之分 > 主分区编号1-4&#xff0c;主分区可以格式化使用 拓展分区编号1-4&#xff0c;拓展分区不能格式化 拓展分区最多能有1个&…

C++11中的新特性(2)

C11 1 可变参数模板2 emplace_back函数3 lambda表达式3.1 捕捉列表的作用3.2 lambda表达式底层原理 4 包装器5 bind函数的使用 1 可变参数模板 在C11之前&#xff0c;模板利用class关键字定义了几个参数&#xff0c;那么我们在编译推演中&#xff0c;我们就必须传入对应的参数…

Leetcode:Z 字形变换

题目链接&#xff1a;6. Z 字形变换 - 力扣&#xff08;LeetCode&#xff09; 普通版本&#xff08;二维矩阵的直接读写&#xff09; 解决办法&#xff1a;直接依据题目要求新建并填写一个二维数组&#xff0c;最后再将该二维数组中的有效字符按从左到右、从上到下的顺序读取并…

umijs+react+ts项目代码一片红处处报错解决

报错问题现象 1、在没有 "node" 模块解析策略的情况下&#xff0c;无法指定选项 "-resolveJsonModule"。 2、类型“JSX.IntrinsicElements”上不存在属性“div”。 解决办法 试了很多都没用&#xff0c;最后是参考这位朋友的解决了 vitevue3搭建工程标…

一个HL7的模拟工具

这个模拟器是为了过&#xff08; NIST美国国家标准与技术研究院&#xff08;National Institute of Standards and Technology&#xff0c;NIST&#xff09;的电子病历住院部分的认证而写的。 用途说明 inpatient中的lab order信息通过该工具向实验室转发该信息。并将实验室…

Window系统安装Docker

因为docker只适合在liunx系统上运行&#xff0c;如果在window上安装的话&#xff0c;就需要开启window的虚拟化&#xff0c;打开控制面板&#xff0c;点击程序&#xff0c;在程序和功能中可以看到启动和关闭window功能&#xff0c;点开后&#xff0c;找到Hyper-V&#xff0c;Wi…

Conditional DETR解读---带anchor的DETR

DETR存在的问题 1.收敛速度慢 2.对小目标物体检测效果不好&#xff0c;因为transformer计算量大&#xff0c;受限于计算规模&#xff0c;CNN提取特征时只采取了最后一层特征&#xff0c;没有用FPN等结构。所以对于小目标检测效果不好。 论文主要观点 通过对DETRdecoder中的a…