【深度学习笔记】4_4 自定义层

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

4.4 自定义层

深度学习的一个魅力在于神经网络中各式各样的层,例如全连接层和后面章节中将要介绍的卷积层、池化层与循环层。虽然PyTorch提供了大量常用的层,但有时候我们依然希望自定义层。本节将介绍如何使用Module来自定义层,从而可以被重复调用。

4.4.1 不含模型参数的自定义层

我们先介绍如何定义一个不含模型参数的自定义层。事实上,这和4.1节(模型构造)中介绍的使用Module类构造模型类似。下面的CenteredLayer类通过继承Module类自定义了一个将输入减掉均值后输出的层,并将层的计算定义在了forward函数里。这个层里不含模型参数。

import torch
from torch import nnclass CenteredLayer(nn.Module):def __init__(self, **kwargs):super(CenteredLayer, self).__init__(**kwargs)def forward(self, x):return x - x.mean()

我们可以实例化这个层,然后做前向计算。

layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))

输出:

tensor([-2., -1.,  0.,  1.,  2.])

我们也可以用它来构造更复杂的模型。

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())

下面打印自定义层各个输出的均值。因为均值是浮点数,所以它的值是一个很接近0的数。

y = net(torch.rand(4, 8))
y.mean().item()

输出:

0.0

4.4.2 含模型参数的自定义层

我们还可以自定义含模型参数的自定义层。其中的模型参数可以通过训练学出。

在4.2节(模型参数的访问、初始化和共享)中介绍了Parameter类其实是Tensor的子类,如果一个TensorParameter,那么它会自动被添加到模型的参数列表里。所以在自定义含模型参数的层时,我们应该将参数定义成Parameter,除了像4.2.1节那样直接定义成Parameter类外,还可以使用ParameterListParameterDict分别定义参数的列表和字典。

ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用appendextend在列表后面新增参数。

class MyDense(nn.Module):def __init__(self):super(MyDense, self).__init__()self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])self.params.append(nn.Parameter(torch.randn(4, 1)))def forward(self, x):for i in range(len(self.params)):x = torch.mm(x, self.params[i])return x
net = MyDense()
print(net)

输出:

MyDense((params): ParameterList((0): Parameter containing: [torch.FloatTensor of size 4x4](1): Parameter containing: [torch.FloatTensor of size 4x4](2): Parameter containing: [torch.FloatTensor of size 4x4](3): Parameter containing: [torch.FloatTensor of size 4x1])
)

ParameterDict接收一个Parameter实例的字典作为输入然后得到一个参数字典,然后可以按照字典的规则使用了。例如使用update()新增参数,使用keys()返回所有键值,使用items()返回所有键值对等等,可参考官方文档。

class MyDictDense(nn.Module):def __init__(self):super(MyDictDense, self).__init__()self.params = nn.ParameterDict({'linear1': nn.Parameter(torch.randn(4, 4)),'linear2': nn.Parameter(torch.randn(4, 1))})self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增def forward(self, x, choice='linear1'):return torch.mm(x, self.params[choice])net = MyDictDense()
print(net)

输出:

MyDictDense((params): ParameterDict((linear1): Parameter containing: [torch.FloatTensor of size 4x4](linear2): Parameter containing: [torch.FloatTensor of size 4x1](linear3): Parameter containing: [torch.FloatTensor of size 4x2])
)

这样就可以根据传入的键值来进行不同的前向传播:

x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))

输出:

tensor([[1.5082, 1.5574, 2.1651, 1.2409]], grad_fn=<MmBackward>)
tensor([[-0.8783]], grad_fn=<MmBackward>)
tensor([[ 2.2193, -1.6539]], grad_fn=<MmBackward>)

我们也可以使用自定义层构造模型。它和PyTorch的其他层在使用上很类似。

net = nn.Sequential(MyDictDense(),MyListDense(),
)
print(net)
print(net(x))

输出:

Sequential((0): MyDictDense((params): ParameterDict((linear1): Parameter containing: [torch.FloatTensor of size 4x4](linear2): Parameter containing: [torch.FloatTensor of size 4x1](linear3): Parameter containing: [torch.FloatTensor of size 4x2]))(1): MyListDense((params): ParameterList((0): Parameter containing: [torch.FloatTensor of size 4x4](1): Parameter containing: [torch.FloatTensor of size 4x4](2): Parameter containing: [torch.FloatTensor of size 4x4](3): Parameter containing: [torch.FloatTensor of size 4x1]))
)
tensor([[-101.2394]], grad_fn=<MmBackward>)

小结

  • 可以通过Module类自定义神经网络中的层,从而可以被重复调用。

注:本节与原书此节有一些不同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/705033.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

240Hz高刷电竞显示器 - HKC VG253KM

&#x1f389;&#x1f389;&#x1f389; 各位电竞爱好者们&#xff0c;今天给大家带来一款神秘武器&#xff0c;一款能够让你在游戏中大展拳脚的高刷电竞显示器 - HKC VG253KM&#xff01;&#x1f525;&#x1f525;&#x1f525; 这款显示器&#xff0c;哎呀&#xff0c;真…

10分钟快速开始SkyWalking结合Springboot项目

10分钟快速开始SkyWalking结合Springboot项目 实习期间&#xff0c;公司让我去学习一下链路追踪如何集成到Springboot项目中。 为此有两个方案&#xff1a; 1.opentelementryjaegerprometheus opentelementry 收集器收集线上的metrics和traces&#xff0c;然后发送给jaeger和p…

IP对讲终端SV-6002(防水)

SV-6002&#xff08;防水&#xff09;是一款IP对讲终端&#xff0c;具有10/100M以太网接口&#xff0c;其接收网络的音频数据&#xff0c;解码后播放&#xff0c;外部DC12~24V电源供电端子&#xff0c;提供单路2W的音频输出。基于TCP/IP网络通信协议和数字音频技术&#xff0c;…

低代码开发如何助力数字化企业管理系统平台构建

随着数字化时代的到来&#xff0c;企业对于管理系统的需求日益增长。高效的管理系统可以提高企业的运作效率&#xff0c;降低成本&#xff0c;提升竞争力。然而&#xff0c;传统的开发方式在应对日益复杂的管理系统需求时&#xff0c;显得力不从心。低代码开发作为一种新兴的开…

Vue笔记(一)

常用指令 1.v-show与v-if底层原理的区别 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>创建一个V…

OpenGL ES 3.0 从入门到精通系统性学习教程

为什么要写这个教程 因为在工作中频繁使用 OpenGL ES 做一些特效、滤镜之类的效果&#xff0c;加上平时学到的的知识点也比较细碎&#xff0c;就想着去系统地学习下 OpenGL ES 相关开发知识&#xff0c;并将学习过程记录下来。 准备知识 一些同学反映&#xff0c;学习这个教…

STM32存储左右互搏 QSPI总线FATS文件读写FLASH W25QXX

STM32存储左右互搏 QSPI总线FATS文件读写FLASH W25QXX FLASH是常用的一种非易失存储单元&#xff0c;W25QXX系列Flash有不同容量的型号&#xff0c;如W25Q64的容量为64Mbit&#xff0c;也就是8MByte。这里介绍STM32CUBEIDE开发平台HAL库Quad SPI总线实现FATS文件操作W25Q各型号…

redis持久化失败问题(MISCONF Redis is configured to save RDB snapshots, but ......)问题解决

今天同事反应测试环境业务一直报错&#xff0c;好像是redis持久化出现了问题&#xff0c;并给出了错误信息&#xff0c;让我帮忙看一下&#xff0c;说明明还有2G内存为何还会报错 MISCONF Redis is configured to save RDB snapshots, but it is currently not able to persis…

mysql 安装 与 使用

1.安装地址&#xff08;社区免费版本&#xff09; https://dev.mysql.com/downloads/mysql/ 2.查看端口 ****是否被占用&#xff08;例子 3306端口&#xff09; netstat -an | find "3306" 3.配置环境 系统变量名 变量名&#xff1a;MYSQL_HOME 变量值&#…

第十三天-mysql交互

目录 1.安装MySQL connector 方式1&#xff1a;直接安装 方式2&#xff1a;下载 2.创建链接 3.游标Cursor 4.事务控制 5. 数据库连接池 1. 使用 6.循环执行SQL语句 不了解mysql的可以先了解mysql基础 1.安装MySQL connector 1. MySQL connector 是MySQL官方驱动模块…

jmeter 按线程数阶梯式压测数据库

当前版本&#xff1a; jmeter 5.6.3mysql 5.7.39 简介 JMeter 通过 bzm - Concurrency Thread Group 来实现阶梯式压测&#xff0c;它并不是JMeter的官方插件&#xff0c;而是一种由Blazemeter提供的高级线程组插件。可以在不同的时间内并发执行不同数量的线程&#xff0c;模拟…

音频常用测试参数

一、总谐波失真&#xff08;THDN&#xff09; 总谐波失真指音频信号源通过功率放大器时&#xff0c;由于非线性元件所引起的输出信号比输入信号多出的额外谐波成份。谐波失真是由于系统不是完全线性造成的&#xff0c;我们用新增加总谐波成份的均方根与原来信号有效值的百分比来…

MySQL之Pt-kill工具

工具下载 [rootlocalhost1 bin]# wget percona.com/get/percona-toolkit.tar.gz [rootlocalhost1 bin]# yum install perl-DBI [rootlocalhost1 bin]# yum install perl-DBD-MySQL [rootlocalhost1 bin]# ./pt-kill --help1、每10秒检查一次&#xff0c;发现有 Query 的进程就…

3D生成式AI模型与工具

当谈到技术炒作时&#xff0c;人工智能正在超越虚拟世界&#xff0c;吸引世界各地企业和消费者的注意力。 但人工智能可以进一步增强虚拟世界&#xff0c;至少在某种意义上&#xff1a;资产创造。 AI 有潜力扩大用于虚拟环境的 3D 资产的创建。 AI 3D生成使用人工智能生成3D模…

开发知识点-.netC#图形用户界面开发之WPF

C#图形用户界面开发 NuGet框架简介WinForms(Windows Forms):WPF(Windows Presentation Foundation):UWP(Universal Windows Platform):MAUI(Multi-platform App UI):选择控件参考文章随笔分类 - WPF入门基础教程系列

什么时候要用到Reflect API?

参考文档 https://www.zhihu.com/question/460133198 https://cn.vuejs.org/guide/extras/reactivity-in-depth.html https://juejin.cn/post/7103764386220769311 Reflect API 一般搭配 Proxy API 一起使用。什么是 Proxy API 呢&#xff1f; 先回顾下 vue 的数据响应性是如何…

投票项目_注册功能版本迭代

V0版本: 简单的注册,前端先进行初步核对,两次输入的密码是否相等?用户账号和密码是否符合要求?核对成功后前端传来账号密码,拿到用户名去数据库核对,如果没有找到相同的用户名就插入到数据库,找到相同的用户名就返回”用户已存在” V1版本: 加入了uuid 1. 导入依赖 <!-- …

《银幕上的编码传奇:计算机科学与科技精神的光影盛宴》

目录 1.在电影的世界里&#xff0c;计算机科学不仅是一门严谨的学科&#xff0c;更是一种富有戏剧张力和人文思考的艺术载体。 2.电影作为现代文化的重要载体&#xff0c;常常以其丰富的想象力和视觉表现力来探讨计算机科学和技术的各种前沿主题。 3.电影中的程序员角色往往…

GDB之(3)加载指定动态库文件

GDB之(3)加载指定动态库文件 Author&#xff1a;Once Day Date&#xff1a;2024年2月26日 漫漫长路&#xff0c;才刚刚开始… 全系列文章请查看专栏: Linux实践记录_Once-Day的博客-CSDN博客 推荐参考文档&#xff1a; gdb 查找动态库方法_info sharedlibrary-CSDN博客GDB…

swift -- 系统语音识别(转文字)

文章目录 一、系统类1. 导入系统库2. SFSpeechRecognizer声音处理器3. SFSpeechAudioBufferRecognitionRequest 语音识别器4. AVAudioEngine 处理声音的数据5. SFSpeechRecognitionTask 语言识别任务管理器 二、代码整理1. 初始化属性2. 判断权限3. 开始语音识别4. 停止语音识别…