pytorch中nn.Sequential详解

1 nn.Sequential概述

1.1 nn.Sequential介绍

nn.Sequential是一个序列容器,用于搭建神经网络的模块被按照被传入构造器的顺序添加到容器中。除此之外,一个包含神经网络模块的OrderedDict也可以被传入nn.Sequential()容器中。利用nn.Sequential()搭建好模型架构,模型前向传播时调用forward()方法,模型接收的输入首先被传入nn.Sequential()包含的第一个网络模块中。然后,第一个网络模块的输出传入第二个网络模块作为输入,按照顺序依次计算并传播,直到nn.Sequential()里的最后一个模块输出结果。

因此,Sequential可以看成是有多个函数运算对象,串联成的神经网络,其返回的是Module类型的神经网络对象。

1.2 nn.Sequential的本质作用

与一层一层的单独调用模块组成序列相比,nn.Sequential() 可以允许将整个容器视为单个模块(即相当于把多个模块封装成一个模块),forward()方法接收输入之后,nn.Sequential()按照内部模块的顺序自动依次计算并输出结果。这就意味着我们可以利用nn.Sequential() 自定义自己的网络层。

示例代码:

from torch import nnclass net(nn.Module):def __init__(self, in_channel, out_channel):super(net, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(in_channel, in_channel / 4, kernel_size=1),nn.BatchNorm2d(in_channel / 4),nn.ReLU())self.layer2 = nn.Sequential(nn.Conv2d(in_channel / 4, in_channel / 4),nn.BatchNorm2d(in_channel / 4),nn.ReLU())self.layer3 = nn.Sequential(nn.Conv2d(in_channel / 4, out_channel, kernel_size=1),nn.BatchNorm2d(out_channel),nn.ReLU())def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return x

上边的代码,我们通过nn.Sequential()将卷积层,BN层和激活函数层封装在一个层中,输入x经过卷积、BN和ReLU后直接输出激活函数作用之后的结果。

1.3 nn.Sequential源码

def __init__(self, *args):super(Sequential, self).__init__()if len(args) == 1 and isinstance(args[0], OrderedDict):for key, module in args[0].items():self.add_module(key, module)else:for idx, module in enumerate(args):self.add_module(str(idx), module)

nn.Sequential()首先判断接收的参数是否为OrderedDict类型,如果是的话,分别取出OrderedDict内每个元素的key(自定义的网络模块名)和value(网络模块),然后将其通过add_module方法添加到nn.Sequrntial()中。

    # NB: We can't really type check this function as the type of input# may change dynamically (as is tested in# TestScript.test_sequential_intermediary_types).  Cannot annotate# with Any as TorchScript expects a more precise typedef forward(self, input):for module in self:input = module(input)return input

 调用forward()方法进行前向传播时,for循环按照顺序遍历nn.Sequential()中存储的网络模块,并以此计算输出结果,并返回最终的计算结果。

1.3 nn.Sequential与其它容器的区别

2 使用nn.Sequential定义网络

2.1 顺序添加网络模块到容器中

import torch
import torch.nn as nnmodel = nn.Sequential(nn.Linear(28 * 28, 32),nn.ReLU(),nn.Linear(32, 10),nn.Softmax(dim=1)
)
print("model:", model)
print("model.parameters:", model.parameters)x_input = torch.randn(2, 28, 28, 1)
print("x_input:", x_input)
print("x_input.shape:", x_input.shape)y_pred = model.forward(x_input.view(x_input.size()[0], -1))
print("y_pred:", y_pred)

运行代码显示:

model: Sequential((0): Linear(in_features=784, out_features=32, bias=True)(1): ReLU()(2): Linear(in_features=32, out_features=10, bias=True)(3): Softmax(dim=1)
)
model.parameters: <bound method Module.parameters of Sequential((0): Linear(in_features=784, out_features=32, bias=True)(1): ReLU()(2): Linear(in_features=32, out_features=10, bias=True)(3): Softmax(dim=1)
)>
x_input.shape: torch.Size([2, 28, 28, 1])
y_pred: tensor([[0.1127, 0.0652, 0.1399, 0.0973, 0.1085, 0.0859, 0.1193, 0.1048, 0.0865,0.0800],[0.0986, 0.0955, 0.0927, 0.0765, 0.0782, 0.1004, 0.1171, 0.1605, 0.0883,0.0922]], grad_fn=<SoftmaxBackward0>)

2.2 包含神经网络模块的OrderedDict传入容器中

import torch
import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
print("model:", model)
print("model.parameters:", model.parameters)x_input = torch.randn(2, 28, 28, 1)
print("x_input.shape:", x_input.shape)y_pred = model.forward(x_input.view(x_input.size()[0], -1))
print("y_pred:", y_pred)

运行代码显示:

model: Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)
)
model.parameters: <bound method Module.parameters of Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)
)>
x_input.shape: torch.Size([2, 28, 28, 1])
y_pred: tensor([[0.0836, 0.1185, 0.1422, 0.0801, 0.0817, 0.0870, 0.0948, 0.1099, 0.1131,0.0892],[0.0772, 0.0933, 0.1312, 0.1135, 0.1214, 0.0736, 0.1461, 0.0711, 0.0908,0.0818]], grad_fn=<SoftmaxBackward0>)

3 nn.Sequential网络操作

3.1 索引查看子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
print("index0:", model[0])
print("index1:", model[1])
print("index2:", model[2])

运行代码显示:

index0: Linear(in_features=784, out_features=32, bias=True)
index1: ReLU()
index2: Linear(in_features=32, out_features=10, bias=True)

3.2 修改子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
model[1] = nn.Sigmoid()
print(model)

运行代码显示:

Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): Sigmoid()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)
)

3.3 添加子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
model.append(nn.Linear(10, 2))
print(model)

运行代码显示:

Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(out): Linear(in_features=32, out_features=10, bias=True)(softmax): Softmax(dim=1)(4): Linear(in_features=10, out_features=2, bias=True)
)

3.4 删除子模块

import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),('relu1', nn.ReLU()),('out', nn.Linear(32, 10)),('softmax', nn.Softmax(dim=1))]))
del model[2]
print(model)

运行代码显示:

Sequential((h1): Linear(in_features=784, out_features=32, bias=True)(relu1): ReLU()(softmax): Softmax(dim=1)
)

3.5 嵌套子模块

import torch.nn as nnseq_1 = nn.Sequential(nn.Linear(15, 10), nn.ReLU(), nn.Linear(10, 5))
seq_2 = nn.Sequential(nn.Linear(25, 15), nn.Sigmoid(), nn.Linear(15, 10))
seq_3 = nn.Sequential(seq_1, seq_2)
print(seq_3)

运行代码显示:

Sequential((0): Sequential((0): Linear(in_features=15, out_features=10, bias=True)(1): ReLU()(2): Linear(in_features=10, out_features=5, bias=True))(1): Sequential((0): Linear(in_features=25, out_features=15, bias=True)(1): Sigmoid()(2): Linear(in_features=15, out_features=10, bias=True))
)

4 Pytorch框架介绍

4.1 什么是Pytorch

PyTorch是一个开源的机器学习库,用于各种计算密集型任务,从基本的线性代数和优化问题到复杂的机器学习(深度学习)应用。它最初是由Facebook的AI研究实验室(FAIR)开发的,现在已经成为一个广泛使用的库,拥有庞大的社群和生态系统。

4.2 Pytorch的主要特点

  • 张量计算能力 :PyTorch提供了一个多维数组(也称为张量)的数据结构,该数据结构可用于执行各种数学运算。它也提供了用于张量计算的丰富库。

  • 自动微分:PyTorch通过其Autograd模块提供自动微分功能,这对于梯度下降和优化非常有用。

  • 动态计算图:与其他深度学习框架(如TensorFlow的早期版本)使用静态计算图不同,PyTorch使用动态计算图。这意味着图在运行时构建,这使得更灵活的模型构建成为可能。

  • 简洁的API:PyTorch的API设计得直观和易于使用,这使得开发和调试模型变得更加简单。

  • Python集成:由于PyTorch紧密集成了Python,因此它可以轻松地与Python生态系统(包括NumPy、SciPy和Matplotlib)协同工作。

  • 社群和生态系统:由于其灵活性和易用性,PyTorch赢得了大量开发者和研究人员的喜爱。这导致了一个活跃的社群以及大量的第三方库和工具。

  • 多平台和多后端支持:PyTorch不仅支持CPU,还支持NVIDIA和AMD的GPU。它也有一个生产就绪的部署解决方案——TorchServe。

  • 丰富的预训练模型和工具箱:通过torchvision、torchaudio和torchtext等库,PyTorch提供了丰富的预训练模型和数据加载工具。

4.3 PyTorch常用的工具包

● torch:类似于Numpy的通用数组库,可以在将张量类型转换为(torch.cuda.TensorFloat)并支持在GPU上进行计算。

● torch.autograd:主要用于构建计算图形并自动获取渐变的包

● torch.nn:具有共同层和成本函数的神经网络库

● torch.optim:具有通用优化算法(如SGD,Adam等)的优化包

● torch.utils:数据载入器。具有训练器和其他便利功能

● torch.legacy(.nn/.optim) :处于向后兼容性考虑,从 Torch 移植来的 legacy 代码

● torch.multiprocessing:python 多进程并发,实现进程之间 torch Tensors 的内存共享

4.4 pytroch深度学习流程

名称内容
1. 准备数据数据几乎可以是任何东西,在本文中,我们将创建一条简单的直线。
2. 建立模型创建一个模型来学习数据中的模式,将选择 损失函数、 优化器 来构建 训练过程。
3. 将模型拟合到数据(训练)已经有了数据和模型,现在让模型尝试在(训练)数据中找到模式。
4. 做出预测和评估模型(推理)模型在数据中找到了模式,将其预测与实际(测试)数据进行比较。
5. 保存和加载模型当想在其他地方使用模型,或者稍后再回来使用它时需要保存和加载模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/238921.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝牙耳机编码方式

蓝牙耳机编码方式 蓝牙耳机的编码方式指的是蓝牙耳机如何处理和传输音频数据。主要的蓝牙编码方式包括&#xff1a; SBC (Subband Coding)&#xff1a;这是蓝牙音频的标准编码方式&#xff0c;所有蓝牙音频设备都支持。虽然它的音质不是最佳&#xff0c;但兼容性很好。 AAC (A…

【重点】【DP】5.最长回文子串|516.最长回文子序列

两个求解目标类似的题目&#xff0c;对比记忆&#xff01; 5.最长回文子串 题目 法1&#xff1a;二维DP 最基础方法&#xff01;必须掌握&#xff01; O(N^2) O(N^2) class Solution {public String longestPalindrome(String s) {int n s.length();if (n 1) {return s…

webpack之介绍

学习webpack之前&#xff0c;请先让我们大家了解一下什么是webpack&#xff1f;为什么要用webpack&#xff1f; Webpack是一个现代化的JavaScript应用程序的静态模块打包工具。它可以将多个模块打包成一个或多个静态资源文件&#xff0c;以便在浏览器中使用。 Webpack的主要功…

连几句恶语都容它不下,那是鸡肠鼠肚,有大度才能成大器。

连几句恶语都容它不下&#xff0c;那是鸡肠鼠肚&#xff0c;有大度才能成大器。

Spring Boot测试 - JUnit整合及模拟Mvc

概述 在现代软件开发中&#xff0c;测试是确保应用程序质量和稳定性的关键步骤。Spring Boot框架为开发人员提供了丰富的测试工具和集成&#xff0c;其中JUnit是最常用的测试框架之一。本文将介绍如何在Spring Boot项目中集成JUnit测试&#xff0c;以及如何使用模拟Mvc来进行W…

csrf自动化检测调研

https://github.com/pillarjs/understanding-csrf/blob/master/README_zh.md CSRF 攻击者在钓鱼站点&#xff0c;可以通过创建一个AJAX按钮或者表单来针对你的网站创建一个请求&#xff1a; <form action"https://my.site.com/me/something-destructive" metho…

一些问题/技巧的集合(仅个人使用)

目录 第一章、1.1&#xff09;前端找不到图片1.2&#xff09;1.3&#xff09;1.4&#xff09; 第二章、2.1&#xff09;2.2&#xff09;2.3&#xff09; 第三章、3.1&#xff09;3.2&#xff09;3.3&#xff09; 第四章、4.1&#xff09;4.2&#xff09;4.3&#xff09; 友情提…

系列一、GitHub搜索技巧

一、GitHub搜索技巧 1.1、概述 作为程序员&#xff0c;GitHub大家应该都再熟悉不过了&#xff0c;很多时候当我们需要使用某一项技能而又无从下手时&#xff0c;通常会在百度&#xff08;面向百度编程&#xff09;或者在GitHub上通过关键字寻找相关案例&#xff0c;比如我想学…

IU5070E线性单节锂电池充电管理IC

IU5070E是一款具有太阳能板最大功率点跟踪MPPT功能&#xff0c;单节锂离子电池线性充电器&#xff0c;最高支持1.5A的充电电流&#xff0c;支持非稳压适配器。同时输入电流限制精度和启动序列使得这款芯片能够符合USB-IF涌入电流规范。 IU5070E具有动态电源路径管理(DPPM)功能&…

如果你带着热爱专注地做些事,很多有趣的事就会随之而来。

如果你带着热爱专注地做些事&#xff0c;很多有趣的事就会随之而来。

第11章 GUI Page403~405 步骤三 设置滚动范围

运行效果&#xff1a; 源代码&#xff1a; /**************************************************************** Name: wxMyPainterApp.h* Purpose: Defines Application Class* Author: yanzhenxi (3065598272qq.com)* Created: 2023-12-21* Copyright: yanzhen…

一款外置MOS开关降压型 LED 恒流控制器应用方案

一、基本概述 TX6121 是一款高效率、高精度的降压型大功率 LED 恒流驱动控制器芯片。芯片采用固定关断时间的峰值电流控制方式&#xff0c;关断时间可通过外部电容进行调节&#xff0c;工作频率可根据用户要求而改变。 通过调节外置的电流采样电阻&#xff0c;能控制高亮度 LE…

火力发电厂电气一次部分初步设计(论文+图纸)

1 原始资料 设计电厂为中型是凝汽式发电厂&#xff0c;共4台发电机组&#xff0c;2台75MW机组&#xff0c;2台50MW机组&#xff0c;总的装机容量为250MW&#xff0c;占系统容量的比例为&#xff1a; 250/(3500250)100%6.7%<15%&#xff0c;未超过电力系统的检修备用容量和…

Jwt 如何在 springboot 项目中进行接口访问鉴权

文章目录 1 springboot 框架负责接口的拦截和放行1.1 原理1.2 思路1.3 坑: Springboot 访问了错误处理路径 /error 2 jwt token 负责携带数据和签名的生成及校验2.1 初始化2.2 设置 Header2.3 携带数据 payload2.4 签名 sign 后, 生成 token2.5 校验2.6 获取信息2.7 字段说明 3…

WebGL在教育和培训的应用

WebGL在教育和培训领域具有广泛的应用&#xff0c;其强大的图形渲染能力和跨平台性使得它成为创建交互式、视觉化的数字内容的理想选择。以下是一些WebGL在教育和培训上的应用示例&#xff0c;希望对大家有所帮助。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司…

PlatEMO UI 界面

&#x1f389; 博主相信&#xff1a; 有足够的积累&#xff0c;并且一直在路上&#xff0c;就有无限的可能&#xff01;&#xff01;&#xff01; &#x1f468;‍&#x1f393;个人主页&#xff1a; 青年有志的博客 &#x1f4af; Github 源码下载&#xff1a;https://github.…

xpath 解析(基础)

解析xml 首先要下载包&#xff1a;pip install lxml 基本使用如下代码所示&#xff1a; # xpath 解析&#xff1a;先安装lxml:pip install lxml from lxml import etreexml """ <book><id>1</id><name>山花遍地开</name><pr…

nmap端口扫描工具安装和使用方法

nmap&#xff08;Network Mapper&#xff09;是一款开源免费的针对大型网络的端口扫描工具&#xff0c;nmap可以检测目标主机是否在线、主机端口开放情况、检测主机运行的服务类型及版本信息、检测操作系统与设备类型等信息。本文主要介绍nmap工具安装和基本使用方法。 nmap主…

【Java】编写一个简单的Servlet程序

Java Servlet 是运行在 Web 服务器或应用服务器上的程序&#xff0c;它是作为来自 Web 浏览器或其他 HTTP 客户端的请求和 HTTP 服务器上的数据库或应用程序之间的中间层。 使用 Servlet&#xff0c;可以收集来自网页表单的用户输入&#xff0c;呈现来自数据库或者其他源的记录…