pytorch中nn.Sequential详解

1 nn.Sequential概述

1.1 nn.Sequential介绍

nn.Sequential是一个序列容器,用于搭建神经网络的模块被按照被传入构造器的顺序添加到容器中。除此之外,一个包含神经网络模块的OrderedDict也可以被传入nn.Sequential()容器中。利用nn.Sequential()搭建好模型架构,模型前向传播时调用forward()方法,模型接收的输入首先被传入nn.Sequential()包含的第一个网络模块中。然后,第一个网络模块的输出传入第二个网络模块作为输入,按照顺序依次计算并传播,直到nn.Sequential()里的最后一个模块输出结果。

因此,Sequential可以看成是有多个函数运算对象,串联成的神经网络,其返回的是Module类型的神经网络对象。

1.2 nn.Sequential的本质作用

与一层一层的单独调用模块组成序列相比,nn.Sequential() 可以允许将整个容器视为单个模块(即相当于把多个模块封装成一个模块),forward()方法接收输入之后,nn.Sequential()按照内部模块的顺序自动依次计算并输出结果。这就意味着我们可以利用nn.Sequential() 自定义自己的网络层。

示例代码:

from torch import nnclass net(nn.Module):def __init__(self, in_channel, out_channel):super(net, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(in_channel, in_channel / 4, kernel_size=1),nn.BatchNorm2d(in_channel / 4),nn.ReLU())self.layer2 = nn.Sequential(nn.Conv2d(in_channel / 4, in_channel / 4),nn.BatchNorm2d(in_channel / 4),nn.ReLU())self.layer3 = nn.Sequential(nn.Conv2d(in_channel / 4, out_channel, kernel_size=1),nn.BatchNorm2d(out_channel),nn.ReLU())def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return x

上边的代码,我们通过nn.Sequential()将卷积层,BN层和激活函数层封装在一个层中,输入x经过卷积、BN和ReLU后直接输出激活函数作用之后的结果。

1.3 nn.Sequential源码

def __init__(self, *args):super(Sequential, self).__init__()if len(args) == 1 and isinstance(args[0], OrderedDict):for key, module in args[0].items():self.add_module(key, module)else:for idx, module in enumerate(args):self.add_module(str(idx), module)

nn.Sequential()首先判断接收的参数是否为OrderedDict类型,如果是的话,分别取出OrderedDict内每个元素的key(自定义的网络模块名)和value(网络模块),然后将其通过add_module方法添加到nn.Sequrntial()中。

    # NB: We can't really type check this function as the type of input# may change dynamically (as is tested in# TestScript.test_sequential_intermediary_types).  Cannot annotate# with Any as TorchScript expects a more precise typedef forward(self, input):for module in self:input = module(input)return input

 调用forward()方法进行前向传播时,for循环按照顺序遍历nn.Sequential()中存储的网络模块,并以此计算输出结果,并返回最终的计算结果。

1.3 nn.Sequential与其它容器的区别

2 使用nn.Sequential定义网络

2.1 顺序添加网络模块到容器中

import torch
import torch.nn as nnmodel = nn.Sequential(nn.Linear(28 * 28, 32),nn.ReLU(),nn.Linear(32, 10),nn.Softmax(dim=1)
)
print("model:", model)
print("model.parameters:", model.parameters)x_input = torch.randn(2, 28, 28, 1)
print("x_input:", x_input)
print("x_input.shape:", x_input.shape)y_pred = model.forward(x_input.view(x_input.size()[0], -1))
print("y_pred:", y_pred)

运行代码显示:

model: Sequential((0): Linear(in_features=784, out_features=32, bias=True)(1): ReLU()(2): Linear(in_features=32, out_features=10, bias=True)(3): Softmax(dim=1)
)
model.parameters: <bound method Module.parameters of Sequential((0): Linear(in_features=784, out_features=32, bias=True)(1): ReLU()(2): Linear(in_features=32, out_features=10, bias=True)(3): Softmax(dim=1)
)>
x_input.shape: torch.Size([2, 28, 28, 1])
y_pred: tensor([[0.1127, 0.0652, 0.1399, 0.0973, 0.1085, 0.0859, 0.1193, 0.1048, 0.0865,0.0800],[0.0986, 0.0955, 0.0927, 0.0765, 0.0782, 0.1004, 0.1171, 0.1605, 0.0883,0.0922]], grad_fn=<SoftmaxBackward0>)

2.2 包含神经网络模块的OrderedDict传入容器中

import torch
import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
print("model:", model)
print("model.parameters:", model.parameters)x_input = torch.randn(2, 28, 28, 1)
print("x_input.shape:", x_input.shape)y_pred = model.forward(x_input.view(x_input.size()[0], -1))
print("y_pred:", y_pred)

运行代码显示:

model: Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)
)
model.parameters: <bound method Module.parameters of Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)
)>
x_input.shape: torch.Size([2, 28, 28, 1])
y_pred: tensor([[0.0836, 0.1185, 0.1422, 0.0801, 0.0817, 0.0870, 0.0948, 0.1099, 0.1131,0.0892],[0.0772, 0.0933, 0.1312, 0.1135, 0.1214, 0.0736, 0.1461, 0.0711, 0.0908,0.0818]], grad_fn=<SoftmaxBackward0>)

3 nn.Sequential网络操作

3.1 索引查看子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
print("index0:", model[0])
print("index1:", model[1])
print("index2:", model[2])

运行代码显示:

index0: Linear(in_features=784, out_features=32, bias=True)
index1: ReLU()
index2: Linear(in_features=32, out_features=10, bias=True)

3.2 修改子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
model[1] = nn.Sigmoid()
print(model)

运行代码显示:

Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): Sigmoid()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)
)

3.3 添加子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
model.append(nn.Linear(10, 2))
print(model)

运行代码显示:

Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)(4): Linear(in_features=10, out_features=2, bias=True)
)

3.4 删除子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
del model[2]
print(model)

运行代码显示:

Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(softmax): Softmax(dim=1)
)

3.5 嵌套子模块

import torch.nn as nnseq_1 = nn.Sequential(nn.Linear(15, 10), nn.ReLU(), nn.Linear(10, 5))
seq_2 = nn.Sequential(nn.Linear(25, 15), nn.Sigmoid(), nn.Linear(15, 10))
seq_3 = nn.Sequential(seq_1, seq_2)
print(seq_3)

运行代码显示:

Sequential((0): Sequential((0): Linear(in_features=15, out_features=10, bias=True)(1): ReLU()(2): Linear(in_features=10, out_features=5, bias=True))(1): Sequential((0): Linear(in_features=25, out_features=15, bias=True)(1): Sigmoid()(2): Linear(in_features=15, out_features=10, bias=True))
)

4 Pytorch框架介绍

4.1 什么是Pytorch

PyTorch是一个开源的机器学习库,用于各种计算密集型任务,从基本的线性代数和优化问题到复杂的机器学习(深度学习)应用。它最初是由Facebook的AI研究实验室(FAIR)开发的,现在已经成为一个广泛使用的库,拥有庞大的社群和生态系统。

4.2 Pytorch的主要特点

  • 张量计算能力 :PyTorch提供了一个多维数组(也称为张量)的数据结构,该数据结构可用于执行各种数学运算。它也提供了用于张量计算的丰富库。

  • 自动微分:PyTorch通过其Autograd模块提供自动微分功能,这对于梯度下降和优化非常有用。

  • 动态计算图:与其他深度学习框架(如TensorFlow的早期版本)使用静态计算图不同,PyTorch使用动态计算图。这意味着图在运行时构建,这使得更灵活的模型构建成为可能。

  • 简洁的API:PyTorch的API设计得直观和易于使用,这使得开发和调试模型变得更加简单。

  • Python集成:由于PyTorch紧密集成了Python,因此它可以轻松地与Python生态系统(包括NumPy、SciPy和Matplotlib)协同工作。

  • 社群和生态系统:由于其灵活性和易用性,PyTorch赢得了大量开发者和研究人员的喜爱。这导致了一个活跃的社群以及大量的第三方库和工具。

  • 多平台和多后端支持:PyTorch不仅支持CPU,还支持NVIDIA和AMD的GPU。它也有一个生产就绪的部署解决方案——TorchServe。

  • 丰富的预训练模型和工具箱:通过torchvision、torchaudio和torchtext等库,PyTorch提供了丰富的预训练模型和数据加载工具。

4.3 PyTorch常用的工具包

● torch:类似于Numpy的通用数组库,可以在将张量类型转换为(torch.cuda.TensorFloat)并支持在GPU上进行计算。

● torch.autograd:主要用于构建计算图形并自动获取渐变的包

● torch.nn:具有共同层和成本函数的神经网络库

● torch.optim:具有通用优化算法(如SGD,Adam等)的优化包

● torch.utils:数据载入器。具有训练器和其他便利功能

● torch.legacy(.nn/.optim) :处于向后兼容性考虑,从 Torch 移植来的 legacy 代码

● torch.multiprocessing:python 多进程并发,实现进程之间 torch Tensors 的内存共享

4.4 pytroch深度学习流程

名称内容
1. 准备数据数据几乎可以是任何东西,在本文中,我们将创建一条简单的直线。
2. 建立模型创建一个模型来学习数据中的模式,将选择 损失函数、 优化器 来构建 训练过程。
3. 将模型拟合到数据(训练)已经有了数据和模型,现在让模型尝试在(训练)数据中找到模式。
4. 做出预测和评估模型(推理)模型在数据中找到了模式,将其预测与实际(测试)数据进行比较。
5. 保存和加载模型当想在其他地方使用模型,或者稍后再回来使用它时需要保存和加载模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/238921.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

csrf自动化检测调研

https://github.com/pillarjs/understanding-csrf/blob/master/README_zh.md CSRF 攻击者在钓鱼站点&#xff0c;可以通过创建一个AJAX按钮或者表单来针对你的网站创建一个请求&#xff1a; <form action"https://my.site.com/me/something-destructive" metho…

一些问题/技巧的集合(仅个人使用)

目录 第一章、1.1&#xff09;前端找不到图片1.2&#xff09;1.3&#xff09;1.4&#xff09; 第二章、2.1&#xff09;2.2&#xff09;2.3&#xff09; 第三章、3.1&#xff09;3.2&#xff09;3.3&#xff09; 第四章、4.1&#xff09;4.2&#xff09;4.3&#xff09; 友情提…

系列一、GitHub搜索技巧

一、GitHub搜索技巧 1.1、概述 作为程序员&#xff0c;GitHub大家应该都再熟悉不过了&#xff0c;很多时候当我们需要使用某一项技能而又无从下手时&#xff0c;通常会在百度&#xff08;面向百度编程&#xff09;或者在GitHub上通过关键字寻找相关案例&#xff0c;比如我想学…

IU5070E线性单节锂电池充电管理IC

IU5070E是一款具有太阳能板最大功率点跟踪MPPT功能&#xff0c;单节锂离子电池线性充电器&#xff0c;最高支持1.5A的充电电流&#xff0c;支持非稳压适配器。同时输入电流限制精度和启动序列使得这款芯片能够符合USB-IF涌入电流规范。 IU5070E具有动态电源路径管理(DPPM)功能&…

第11章 GUI Page403~405 步骤三 设置滚动范围

运行效果&#xff1a; 源代码&#xff1a; /**************************************************************** Name: wxMyPainterApp.h* Purpose: Defines Application Class* Author: yanzhenxi (3065598272qq.com)* Created: 2023-12-21* Copyright: yanzhen…

一款外置MOS开关降压型 LED 恒流控制器应用方案

一、基本概述 TX6121 是一款高效率、高精度的降压型大功率 LED 恒流驱动控制器芯片。芯片采用固定关断时间的峰值电流控制方式&#xff0c;关断时间可通过外部电容进行调节&#xff0c;工作频率可根据用户要求而改变。 通过调节外置的电流采样电阻&#xff0c;能控制高亮度 LE…

火力发电厂电气一次部分初步设计(论文+图纸)

1 原始资料 设计电厂为中型是凝汽式发电厂&#xff0c;共4台发电机组&#xff0c;2台75MW机组&#xff0c;2台50MW机组&#xff0c;总的装机容量为250MW&#xff0c;占系统容量的比例为&#xff1a; 250/(3500250)100%6.7%<15%&#xff0c;未超过电力系统的检修备用容量和…

WebGL在教育和培训的应用

WebGL在教育和培训领域具有广泛的应用&#xff0c;其强大的图形渲染能力和跨平台性使得它成为创建交互式、视觉化的数字内容的理想选择。以下是一些WebGL在教育和培训上的应用示例&#xff0c;希望对大家有所帮助。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司…

PlatEMO UI 界面

&#x1f389; 博主相信&#xff1a; 有足够的积累&#xff0c;并且一直在路上&#xff0c;就有无限的可能&#xff01;&#xff01;&#xff01; &#x1f468;‍&#x1f393;个人主页&#xff1a; 青年有志的博客 &#x1f4af; Github 源码下载&#xff1a;https://github.…

nmap端口扫描工具安装和使用方法

nmap&#xff08;Network Mapper&#xff09;是一款开源免费的针对大型网络的端口扫描工具&#xff0c;nmap可以检测目标主机是否在线、主机端口开放情况、检测主机运行的服务类型及版本信息、检测操作系统与设备类型等信息。本文主要介绍nmap工具安装和基本使用方法。 nmap主…

【Java】编写一个简单的Servlet程序

Java Servlet 是运行在 Web 服务器或应用服务器上的程序&#xff0c;它是作为来自 Web 浏览器或其他 HTTP 客户端的请求和 HTTP 服务器上的数据库或应用程序之间的中间层。 使用 Servlet&#xff0c;可以收集来自网页表单的用户输入&#xff0c;呈现来自数据库或者其他源的记录…

在MongoDB中使用数组字段和子文档字段进行索引

本文主要介绍在MongoDB使用数组字段和子文档字段进行索引。 目录 MongoDB的高级索引一、索引数组字段二、索引子文档字段 MongoDB的高级索引 MongoDB是一个面向文档的NoSQL数据库&#xff0c;它提供了丰富的索引功能来加快查询性能。除了常规的单字段索引之外&#xff0c;Mong…

Ubuntu 常用命令之 shutdown 命令用法介绍

&#x1f4d1;Linux/Ubuntu 常用命令归类整理 shutdown 是 Ubuntu 系统下的一个命令&#xff0c;用于关闭或重启系统。这个命令可以让系统在一个特定的时间点进行关机或者重启&#xff0c;也可以立即执行。 shutdown 命令的基本格式如下 shutdown [选项] 时间 [警告消息]选项…

react当中生命周期(旧生命周期详解)

新生命周期https://blog.csdn.net/kkkys_kkk/article/details/135156102?spm1001.2014.3001.5501 目录 什么是生命周期 react中的生命周期 旧生命周期 生命周期图示 常用的生命周期钩子函数 初始化阶段 挂载阶段 在严格模式下挂载阶段的生命周期函数会执行两次原因 更…

软件渗透测试有哪些测试流程?权威安全测试报告的重要性

软件渗透测试也是安全测试的一种&#xff0c;是通过模拟恶意黑客的攻击方法&#xff0c;来评估计算机网络系统安全的一种评估方法。作为网络安全防范的一种新技术&#xff0c;对于网络安全组织具有实际应用价值。 一、软件渗透测试的过程   软件渗透测试的过程通常包括四个主…

前端学习——vuex的入门

学习一门技术最快捷的方式就是先了解其概念和使用场景&#xff0c;毕竟任何技术的出现都是为了解决某一个场景下的通用解决方案&#xff0c;并且使用最合理的方式去解决问题。 那么什么是vuex&#xff1f; Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式 库。它采用集中…

基于ssm+jsp学生综合测评管理系统源码和论文

网络的广泛应用给生活带来了十分的便利。所以把学生综合测评管理与现在网络相结合&#xff0c;利用java技术建设学生综合测评管理系统&#xff0c;实现学生综合测评的信息化。则对于进一步提高学生综合测评管理发展&#xff0c;丰富学生综合测评管理经验能起到不少的促进作用。…

OPC UA 与PROFINET比较

ROFINET和OPC UA是两种常见的协议&#xff0c;过去这两个协议有两个不同的角色。PROFINET通常用于现场设备和本地控制器之间的实时数据通信。而OPC UA通常用于在本地控制器和更高级别的MES和SCADA系统之间进行通信。 OPC UA 网络架构 PROFINET网络由IO控制器和IO设备组成&…

【数据结构】什么是树?

&#x1f984;个人主页:修修修也 &#x1f38f;所属专栏:数据结构 ⚙️操作环境:Visual Studio 2022 &#x1f4cc;树的定义 树(Tree)是n(n≥0)个结点的有限集.n0时称为空树. 在任意一颗非空树中: 有且仅有一个特定的称为根(Root)的结点;当n>1时,其余结点可分为m(m>0)个互…