pytorch-nn.Module

目录

  • 1. nn.Module
  • 2. nn.Sequential容器
  • 3. 网络参数parameters
  • 4. Modules内部管理
  • 5. checkpoint
  • 6. train/test状态切换
  • 6. 实现自己的网络层
    • 6.1 实现打平操作
    • 6.2 实现自己的线性层
  • 7. 代码

1. nn.Module

是所有nn.类的父类,其中包括nn.Linear nn.BatchNorm2d nn.Conv2d nn.ReLU nn.Sigmoid等等

2. nn.Sequential容器

如下图,定义一个net网络,将所有继承自nn.Module的子类定义的网络层加入到了nn.Sequential容器中,与一层一层的单独调用模块组成序列相比,nn.Sequential() 可以允许将整个容器视为单个模块(即相当于把多个模块封装成一个模块),forward()方法接收输入之后,nn.Sequential()按照内部模块的顺序自动依次计算并输出结果。因此可以利用nn.Sequential()搭建模型架构

在这里插入图片描述

3. 网络参数parameters

如下图,通过net.parameters()可以获取到net的参数,转换成list后,通过index访问第几个参数,比如:图中的list(net.named_parameters())[0]就可以获取到网络的第一个参数,也就是网络第一层的w参数。
通过list(net.named_parameters()).items()获取到所有网络层,从获取结果可以看到,每一层都被pytorch命名了,比如:‘0.weight’,‘0.bias’,即第一层网络的weight和bias.
在这里插入图片描述

4. Modules内部管理

与根节点相连的直系亲属叫children,其他再与children连接的节点都叫modules
如下图,nn.Sequential是Net的children,其他的是modules,包括nn.ReLU、nn.Linear、BasicNet
在这里插入图片描述
从下面这张截图可以看出,Net本身和Children也都是modules
在这里插入图片描述

5. checkpoint

为了防止train过程意外停止,需从头train的问题,train过程需要定期保持checkpoint,而一旦出现train意外停止,就可以从最后一次checkpoint接着训练。
torch.save保存checkpoint
torch.load_state_dict(torch.load(‘chpt.md’))用于load checkpoint
在这里插入图片描述

6. train/test状态切换

所有nn.类都继承自nn.Module,因此在切换train和test状态时,只需要调用一次net.train()或net.eval即可,而不需要那些train和test(dropout)行为不一致的类每个单独去切换.
在这里插入图片描述

6. 实现自己的网络层

6.1 实现打平操作

全连接层层需要打平输入,打平操作通过.view方法实现,由于Flatten继承自nn.Module,因此可以直接放到nn.Sequential中。
在这里插入图片描述

6.2 实现自己的线性层

通过net.parameters()可以将网络参数加到优化器中。
在这里插入图片描述
troch.tensor是不会自动加到nn.parameters中,因此需要使用nn.Parameter将tensor加到nn.parameters,从而才能加到SGD等优化器中。

在这里插入图片描述

7. 代码

import  torch
from    torch import nn
from    torch import optimclass MyLinear(nn.Module):def __init__(self, inp, outp):super(MyLinear, self).__init__()# requires_grad = Trueself.w = nn.Parameter(torch.randn(outp, inp))self.b = nn.Parameter(torch.randn(outp))def forward(self, x):x = x @ self.w.t() + self.breturn xclass Flatten(nn.Module):def __init__(self):super(Flatten, self).__init__()def forward(self, input):return input.view(input.size(0), -1)class TestNet(nn.Module):def __init__(self):super(TestNet, self).__init__()self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),nn.MaxPool2d(2, 2),Flatten(),nn.Linear(1*14*14, 10))def forward(self, x):return self.net(x)class BasicNet(nn.Module):def __init__(self):super(BasicNet, self).__init__()self.net = nn.Linear(4, 3)def forward(self, x):return self.net(x)class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.net = nn.Sequential(BasicNet(),nn.ReLU(),nn.Linear(3, 2))def forward(self, x):return self.net(x)def main():device = torch.device('cuda')net = Net()net.to(device)net.train()net.eval()# net.load_state_dict(torch.load('ckpt.mdl'))### torch.save(net.state_dict(), 'ckpt.mdl')for name, t in net.named_parameters():print('parameters:', name, t.shape)for name, m in net.named_children():print('children:', name, m)for name, m in net.named_modules():print('modules:', name, m)if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/23835.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

每日一练 - OSPF协议验证机制

01 真题题目 OSPF 只有在 Hello 报文中有验证信息,OSPF 支持 MD5 密文验证. A.正确 B.错误 02 真题答案 B 03 答案解析 这个陈述是不完全正确的。首先,OSPF确实使用Hello报文来携带认证信息,但这不意味着只有Hello报文包含验证信息。 OSPF的认证机制可…

政府绩效考核第三方评估的含义

政府绩效考核第三方评估是指由独立于政府的外部机构(如专业评估公司、研究机构或非政府组织)对政府部门或其下属单位的绩效进行客观、公正、系统的评估。其主要目的是通过引入独立的第三方评估机构,对政府绩效进行科学、全面的考核&#xff0…

【AIGC调研系列】Qwen2与llama3对比的优势

Qwen2与Llama3的对比中,Qwen2展现出了多方面的优势。首先,从性能角度来看,Qwen2在多个基准测试中表现出色,尤其是在代码和数学能力上有显著提升[1][9]。此外,Qwen2还在自然语言理解、知识、多语言等多项能力上均显著超…

肺结节14问,查出肺结节怎么办?哪些能用中医调治消散?快来了解一下吧

近些年,随着大众防癌意识的加强,和胸部低剂量CT的普及,肺结节的检出率也逐年升高,不少患者CT报告上,写着“肺小结”“肺部磨玻璃结节”的字样,当你看到这几个字时,会不会瞬间紧张起来&#xff1…

编程规范-代码检测-格式化-规范化提交

适用于vue项目的编程规范 – 在多人开发时统一编程规范至关重要 1、代码检测 --Eslint Eslint:一个插件化的 javascript 代码检测工具 在 .eslintrc.js 文件中进行配置 // ESLint 配置文件遵循 commonJS 的导出规则,所导出的对象就是 ESLint 的配置对…

简化电动汽车充电器和光伏逆变器的高压电流检测

在任何电气系统中,电流都是一个至关重要的参数。电动汽车 (EV) 充电系统和太阳能系统都需要检测电流的大小,以便控制和监测功率转换、充电和放电。电流传感器通过监测分流电阻器上的压降或导体中电流产生的磁场来测量电流。 金属氧化物半导体场效应晶体…

DBeaver连接MySQL提示“Public Key Retrieval is not allowed“问题的解决方式

问题描述 客户端root用户连接数据库出现出现Public Key Retrieval is not allowed 原因分析: 加上allowPublicKeyRetrievalfalse: 解决方案: allowPublicKeyRetrievaltrue:

Java Web学习笔记14——BOM对象

BOM: 概念:浏览器对象模型(Browser Object Model),允许JavaScript与浏览器对话,JavaScript将浏览器的各个组成部分封装为对象。 组成: Window:浏览器窗口对象 介绍:浏览…

opencv锐化卷积核的定义和应用(图像锐化)。

定义锐化卷积核 卷积核(Kernel)是一个小矩阵,它用于在图像处理操作中,比如模糊、锐化、边缘检测等。卷积核通过卷积操作应用于图像像素,产生新的图像。 在锐化操作中,我们通常使用一个 3x3 的卷积核。以下…

注解 - @RestController

注解简介 在今天的每日一注解中,我们将探讨RestController注解。RestController是Spring框架中的一个组合注解,方便创建RESTful Web服务。 注解定义 RestController注解是Controller和ResponseBody注解的组合,用于定义RESTful控制器。以下是…

物联网(IoT)及物联网网络协议面试题及参考答案(2万字长文)

什么是物联网(IoT)? 物联网(Internet of Things,简称IoT)是一个由互联网、传统电信网、传感器网络等多种网络组成的网络概念。它允许物体与物体、物体与人、人与人之间通过智能传感器、软件和网络进行信息交换和通信,实现智能化识别、定位、跟踪、监控和管理。物联网的…

光伏电站鸟害解决方案,列式冲击波声压光伏驱鸟器

光伏电站的运营过程中,最怕遇上鸟粪污染。鸟粪不仅难以清洗,还可能导致光伏组件损坏、降低发电效率。因此,制定并实施有效的驱鸟策略对于光伏电站的稳定运营至关重要。 针对光伏电站的鸟害问题,我们可以从以下几个方面来解决&…

知名优秀定制线缆生产源头工厂推荐-精工电联:全程跟踪监制,打造水下机器人线缆定制新标杆

在科技飞速发展的今天,精工电联作为高科技智能化产品及自动化设备专用连接线束和连接器配套服务商,始终站在行业前沿。我们专注于为高科技行业提供高品质、优匹配的集成线缆和连接器定制服务,特别是在水下机器人线缆定制领域,通过…

CAN的TP模式和COM模式的区别

CAN的TP(传输协议)模式和COM(通信)模式主要涉及汽车网络中的数据传输机制,两者在功能、寻址方式和帧类型等方面有所不同。具体分析如下: 功能 TP模式:TP模式,即传输协议模式&#…

sql死锁分析

一、重要参数 获取事务信息:SELECT * FROM information_schema.INNODB_TRX; 获取锁等待:SELECT * FROM information_schema.INNODB_LOCK_WAITS; 查看锁信息:SELECT * FROM information_schema.INNODB_LOCKS WHERE lock_trx_id IN () 二、case1:间隙锁和x锁互斥导致死锁 1、背景…

安全高效海外仓系统:中小海外仓标准化管理的第一步

在当今全球化的商业背景中,可以说海外仓已经成为跨境电商供应链中不可或缺的一环。 尤其是对于那些处于成长阶段的中小型海外仓来说,选择一款安全高效并且符合其海外仓规模特点的wms管理系统尤其重要。 今天我们就来系统的了解一下,安全高效…

大厂AI团战高考作文,华师一附中特级教师这样打分

在人工智能的浪潮中, 人们不禁疑问: AI真的能超越人类吗? 这究竟是现实还是幻想? 我们将目睹一场前所未有的较量: 百度文心一言、阿里通义千问、 腾讯混元、字节豆包 四家国内顶尖互联网企业 精心打造的AI大模…

HBM简介

1、什么是HBM HBMHigh Bandwidth Memory 是一种用于某些 GPU的 3D 堆叠 DRAM存储器 (动态随机存取存储器)以及服务器、高性能计算 (HPC) 、网络连接的内存接口。其实就是将很多个DDR芯片堆叠在一起后和GPU封装在一起,实…

ROS socketcan_bridge使用说明

ROS socketcan_bridge使用说明(以ubuntu20.04为例) socketcan_bridge是什么 ROS针对socketcan提供了三个层次的驱动库,分别是ros_canopen,socketcan_bridge和socketcan_interface。 socketcan_interface: 功能&#x…

k-means聚类模型的原理和应用

k-means聚类算法是一种迭代求解的聚类分析算法,其步骤是,预将数据分为K组,然后随机选取K个对象作为初始的聚类中心;计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心;聚类中…