动手学深度学习——层和块

1. 层

层是一个将输入数据转换为输出数据的神经网络组件。每个层都会对输入数据进行一定的操作,例如线性变换、非线性激活函数等,以产生输出数据。

torch.nn模块提供了各种预定义的层,如线性层、卷积层、池化层等,

  • nn.Linear:线性层
  • nn.MaxPool2d:二维池化层
  • nn.Conv2d:二维卷积层
  • nn.ReLu:激活函数层

也支持基于nn.Module自定义层。

1.1 自定义简单层

import torch
import torch.nn.functional as F
from torch import nnclass CenteredLayer(nn.Module):def __init__(self):super().__init__()def forward(self, X):return X - X.mean()

这个层的功能是对每个输入减去均值,运行示例:

layer = CenteredLayer()
layer(torch.FloatTensor([1, 2, 3, 4, 5]))> tensor([-2., -1.,  0.,  1.,  2.])

这个层没有定义需要训练的参数,这一类的层往往用于特定的功能转换,例如数据重排、裁剪、归一化等。

1.2 定义带参数的层

使用nn.Parameter来创建需要训练的参数,以线性全连接层为例

  • 需要两个参数:权重和偏置
  • 需要两个参数in_units和units来指明输入维度和输出维度
class MyLinear(nn.Module):def __init__(self, in_units, units):super().__init__()self.weight = nn.Parameter(torch.randn(in_units, units))self.bias = nn.Parameter(torch.randn(units,))def forward(self, X):return torch.matmul(X, self.weight.data) + self.bias.data

实例化MyLinear类实例并用其进行前向传播计算:

linear = MyLinear(5, 3)
linear(torch.rand(2, 5))> tensor([[ 1.9813, -0.1214,  0.1627],[ 2.6518, -0.8198,  0.6513]])

2. 块

在神经网络中,块可以表示为多个层组成的组件,将多个块组合能构成复杂的网络模型。

在这里插入图片描述
从编程的角度(以pytorch为例),块也是由继承nn.Module的类来表示,它必须具有的组成部分:

  1. 组成块的层
  2. 前向传播函数forward,用于将输入转换为输出;
  3. 反向传播函数backward,用于计算梯度;
  4. 待训练的参数;

由于pytorch支持反向传播自动求导,已经由pytorch内部封装了反向传播函数的实现,另外pytorch会自动根据层的大小来初始化模型参数w和b,所以我们在定义块时只需要考虑前向传播函数和组成块的层。

2.1 自定义块

以前一篇文章中的多层感知机为例,可以封装为一个块:

class MLP(nn.Module):# 用模型参数声明层。这里,我们声明两个全连接的层def __init__(self):# 调用MLP的父类Module的构造函数来执行必要的初始化。# 这样,在类实例化时也可以指定其他函数参数,例如模型参数params(稍后将介绍)super().__init__()self.hidden = nn.Linear(20, 256)  # 隐藏层self.out = nn.Linear(256, 10)  # 输出层# 定义模型的前向传播,即如何根据输入X返回所需的模型输出def forward(self, X):# 注意,这里我们使用ReLU的函数版本,其在nn.functional模块中定义。return self.out(F.relu(self.hidden(X)))

2.2 组合块

前篇文章用到的nn.Sequential其实也是一个块,只不过它的作用是将其它块按顺序组合到一起,形成一个串行执行的有序列表。

class MySequential(nn.Module):def __init__(self, *args):super().__init__()for idx, module in enumerate(args):# module是Module子类的一个实例,这里将它保存在'Module'类型的成员变量_modules中,_modules的类型是OrderedDictself._modules[str(idx)] = moduledef forward(self, X):# OrderedDict保证了按照成员添加的顺序遍历它们for block in self._modules.values():X = block(X)return X

每个nn.Module都有一个内置的_modules属性,目的是方便系统查找需要初始化参数的子块

3. 参数管理

训练模型的目的是为了找到使损失函数最小化的模型参数值,这个训练过程就如同我们调试程序一样,有时候需要打印中间结果以辅助我们进行问题的分析和诊断,所以我们有必要知道如何访问参数。

3.1 参数访问

当通过Sequential类定义模型时, 我们可以通过索引来访问模型的任意层,通过每层的state_dict()来获取该层的参数。

print(net[2].state_dict())
OrderedDict([('weight', tensor([[-0.0427, -0.2939, -0.1894,  0.0220, -0.1709, -0.1522, -0.0334, -0.2263]])), ('bias', tensor([0.0887]))])

可以看出,该层包含权重weight和偏置bias两个参数。

我们还可以直接访问权重或偏置。

print(type(net[2].weight)) # 类型
print(net[2].weight)       # 直接访问参数,包含参数值和梯度信息
print(net[2].weight.data)  # 访问参数值
# print(net[2].weight.grid)  # 访问参数的梯度,还没有训练,所以梯度还没值
<class 'torch.nn.parameter.Parameter'>
Parameter containing:
tensor([[ 0.0986,  0.2894,  0.3461,  0.2734, -0.3395, -0.0719, -0.3348, -0.0305]],requires_grad=True)
tensor([[ 0.0986,  0.2894,  0.3461,  0.2734, -0.3395, -0.0719, -0.3348, -0.0305]])

3.2 嵌套块参数访问

我们可以将多个块相互嵌套,组成更大的块。

  • block1有4层, linear, relu, linear, relu
  • block2嵌套了3个block1块
  • 最后将block2与一个线性输出层组合,构成一个网络
def block1():return nn.Sequential(nn.Linear(4, 2), nn.ReLU(),nn.Linear(2, 4), nn.ReLU())def block2():net = nn.Sequential()# block2中嵌套3个block1for i in range(3):net.add_module(f'block {i}', block1())return netrgnet = nn.Sequential(block2(), nn.Linear(4, 1))
print(rgnet)

这个包含嵌套块的网络结构如下:

Sequential((0): Sequential((block 0): Sequential((0): Linear(in_features=4, out_features=2, bias=True)(1): ReLU()(2): Linear(in_features=2, out_features=4, bias=True)(3): ReLU())(block 1): Sequential((0): Linear(in_features=4, out_features=2, bias=True)(1): ReLU()(2): Linear(in_features=2, out_features=4, bias=True)(3): ReLU())(block 2): Sequential((0): Linear(in_features=4, out_features=2, bias=True)(1): ReLU()(2): Linear(in_features=2, out_features=4, bias=True)(3): ReLU()))(1): Linear(in_features=4, out_features=1, bias=True)
)

嵌套块的参数访问:

rgnet[0][1][0].bias.data> tensor([ 0.4917, -0.3920])

3.3 参数初始化

对于参数的初始化,不明确指定时,pytorch会使用默认的随机初始化方法。PyTorch的nn.init模块也提供了多种可供选择的预置初始化方法。

  1. 指定使用正态分布的随机变量来初始化:
# 定义初始化函数
def init_normal(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)nn.init.zeros_(m.bias)
# 使用指定函数对整个网络的参数进行初始化
net.apply(init_normal)  
net[0].weight.data[0], net[0].bias.data[0]> (tensor([-0.0101, -0.0117, -0.0116, -0.0016]), tensor(0.))
  1. 使用常数进行初始化:
def init_constant(m):if type(m) == nn.Linear:nn.init.constant_(m.weight, 1)nn.init.zeros_(m.bias)
net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]> (tensor([1., 1., 1., 1.]), tensor(0.))
  1. 对不同的块使用不同的初始化:
net[0].apply(init_normal)
net[2].apply(init_constant)
print(net[0].weight.data[0])
print(net[2].weight.data)> tensor([ 0.0054, -0.0188, -0.0112,  0.0097])
tensor([[1., 1., 1., 1., 1., 1., 1., 1.]])

4.4 参数绑定

含义:通过将一个层共享,可以实现相同的参数权重用于神经网络中的多个层。目的在于两方面:

  1. 减少模型的参数量:通过共享参数,可以大大减少需要学习的参数数量,从而减小模型的复杂度。
  2. 加速训练:参数共享可以减少内存占用和计算量,特别是在具有大量参数的深层网络中,可以显著提高计算效率。

下面是一个参数共享的代码示例:

# 给共享层一个名称,以便可以引用它的参数
shared = nn.Linear(8, 8)
# 第二层和第四层共享shared层的参数
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),shared, nn.ReLU(),shared, nn.ReLU(),nn.Linear(8, 1))

输出第二层和第四层指定位置的初始参数,两者是相同的。

print(net[2].weight.data[0, 0])
print(net[4].weight.data[0, 0])> tensor(-0.0253)
tensor(-0.0253)  

修改第二层的参数:

net[2].weight.data[0, 0] = 100

再次输出第二层和第四层指定位置的参数,两者都变成了修改后的参数:

print(net[2].weight.data[0, 0])
print(net[4].weight.data[0, 0])> tensor(100.)
tensor(100.)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/14821.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BLE学习笔记(0.0) —— 基础概念(0)

前言 &#xff08;1&#xff09;本章节主要是对BLE技术进行简单的介绍&#xff0c;熟悉蓝牙技术的发展过程&#xff0c;了解相关术语方便后续的学习。 &#xff08;2&#xff09;为了防止单篇博客太长以至于看不下去&#xff0c;因此我基础概念章节分为两篇来写。 &#xff08;…

直播回放| 机器人任务挑战赛线上培训资料合集

大赛培训回顾 5月22日&#xff0c;卓翼飞思实验室为全国各赛区精心组织的机器人任务挑战赛&#xff08;无人协同系统&#xff09;线上培训第三期顺利落下帷幕&#xff0c;吸引300余人参与。本次培训主要针对仿真平台的基本使用&#xff0c;从仿真平台获取激光雷达/视觉数据&am…

Mysql教程(0):学习框架

1、Mysql简介 MySQL 是一个开放源代码的、免费的关系型数据库管理系统。在 Web 开发领域&#xff0c;MySQL 是最流行、使用最广泛的关系数据库。MySql 分为社区版和商业版&#xff0c;社区版完全免费&#xff0c;并且几乎能满足全部的使用场景。由于 MySQL 是开源的&#xff0…

选择排序,改进冒泡排序,快速排序的查找和计数排序

简单选择排序 数据结构:单链表 实现方法:n为链表长度, 第1趟先选出1到n-1个元素中的最小值和0号元素交换, 第2趟从2到n-1号元素选出最小值和1号元素交换, … 第n-2趟从n-2到n-1号元素中选出最小值和n-2号元素交换. 第n-1趟n-1号元素即为最小值。比较结束。 代码:…

1075: 求最小生成树(Prim算法)

解法&#xff1a; 总结起来&#xff0c;Prim算法的核心思想是从一个顶点开始&#xff0c;一步一步地选择与当前最小生成树相邻的且权值最小的边&#xff0c;直到覆盖所有的顶点&#xff0c;形成一个最小生成树。 #include<iostream> #include<vector> using names…

算法-跳马

bfs类的应用题。 解法&#xff1a; 每一个点都可能作为汇集的那个点&#xff0c;因此采用遍历的方式&#xff0c;对每个点进行处理&#xff0c;得出每个点的“所有马跳到本点的最小步数和“&#xff0c;取最小值即可。 逻辑1&#xff1a;以该点作为源点出发&#xff0c;求处…

springboot基于Web前端技术的java养老院管理系统_utbl7

3.普通用户模块包括&#xff1a;普通会员的注册、养老院客房查询、养老院留言查询、预约老人基本信息登记、选择房间、用户缴费的功能。 4.数据信息能够及时进行动态更新&#xff0c;增删&#xff0c;用户搜素方便&#xff0c;使用户可以直接浏览相关信息&#xff0c;要考虑便于…

Vue3实战笔记(35)—集成炫酷的粒子特效

文章目录 前言一、vue3使用tsparticles二、使用步骤总结 前言 学习一个有趣炫酷的玩意开心一下。 tsparticles&#xff0c;可以方便的实现各种粒子特效。支持的语言框架也是相当的丰富. 官网&#xff1a;https://particles.js.org/ 一、vue3使用tsparticles 先来个vue3使用…

Go 语言逃逸分析:内存管理的关键

文章目录 前言1 逃逸分析是什么&#xff1f;2 逃逸分析的基本思想是什么&#xff1f;3 逃逸分析的分配原则是什么&#xff1f;4 如何进行逃逸分析&#xff1f;5 逃逸分析案例5.1 变量在函数外存在引用5.2 引用类型的逃逸5.3 闭包捕获变量5.4 变量占用内存较大 6 变量会逃逸到堆…

代码随想录训练营打卡第36天:动态规划解决子序列问题

1.300最长递增子序列 1.问题描述 找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列&#xff0c;删除&#xff08;或不删除&#xff09;数组中的元素而不改变其余元素的顺序。 2.问题转换 从nums[0...i]的最长的递增的子序列 3.解题思路 每一个位置的n…

经济学问题

问题1 1916年&#xff0c;福特汽车公司以440美元的价格生产了50万辆T型福特汽车。该公司当年盈利6000万美元。亨利福特告诉一位报纸记者&#xff0c;他打算把T型车的价格降至360美元&#xff0c;他希望在这个价格上能卖出80万辆汽车。福特说&#xff1a;“每辆车的利润减少&am…

Flutter 中的 CupertinoPicker 小部件:全面指南

Flutter 中的 CupertinoPicker 小部件&#xff1a;全面指南 在Flutter中&#xff0c;CupertinoPicker是一个用于创建iOS风格的选择器的组件&#xff0c;它允许用户通过滚动来选择一个值。CupertinoPicker可以用于选择日期、时间或者任何可枚举的值。本文将详细介绍CupertinoPi…

C++多态详解

目录 一、多态的概念 二、多态的定义及实现 1.多态的构成条件 2.虚函数 3.虚函数的重写 4.例题理解&#xff08;超级重要&#xff0c;强烈建议做一下&#xff09; 5.C11 override和 final 6.重载、覆盖&#xff08;重写&#xff09;、隐藏&#xff08;重定义&#xff0…

【yijiej】mysql报错 之 报错:Duplicate entry 字段 for key ‘表名.idx_字段’

一、问题操作 Mysql 进行insert 操作&#xff0c;报错&#xff1a;Duplicate entry 字段 for key ‘表名.idx_字段’ 原因解析&#xff1a;idx 是做的索引键&#xff0c;是具有唯一性二、问题原因&#xff08;三种情况&#xff0c;当前我遇到的情况是第一种&#xff09; 1、当 …

零基础代码随想录【Day42】|| 1049. 最后一块石头的重量 II,494. 目标和,474.一和零

目录 DAY42 1049.最后一块石头的重量II 解题思路&代码 494.目标和 解题思路&代码 474.一和零 解题思路&代码 DAY42 1049.最后一块石头的重量II 力扣题目链接(opens new window) 题目难度&#xff1a;中等 有一堆石头&#xff0c;每块石头的重量都是正整…

(Qt) 默认QtWidget应用包含什么?

文章目录 ⭐前言⭐创建&#x1f6e0;️选择一个模板&#x1f6e0;️Location&#x1f6e0;️构建系统&#x1f6e0;️Details&#x1f6e0;️Translation&#x1f6e0;️构建套件(Kit)&#x1f6e0;️汇总 ⭐项目⚒️概要⚒️构建步骤⚒️清除步骤 ⭐Code&#x1f526;untitled…

【EasyX】快速入门——消息处理,音频

1.消息处理 我们先看看什么是消息 1.1.获取消息 想要获取消息,就必须学会getmessage函数 1.1.1.getmessage函数 有两个重载版本,它们的作用是一样的 参数filter可以筛选我们需要的消息类型 我们看看参数filter的取值 当然我们可以使用位运算组合这些值 例如,我们…

华为CE6851-48S6Q-HI升级设备版本及补丁

文章目录 升级前准备工作笔记本和交换机设备配置互联地址启用FTP设备访问FTP设备升级系统版本及补丁 升级前准备工作 使用MobaXterm远程工具连接设备&#xff0c;并作为FTP服务器准备升级所需的版本文件及补丁文件 笔记本和交换机设备配置互联地址 在交换机接口配置IP&#…

Facebook隐私保护:数据安全的前沿挑战

在数字化时代&#xff0c;随着社交媒体的普及和应用&#xff0c;个人数据的隐私保护问题日益受到关注。作为全球最大的社交平台之一&#xff0c;Facebook承载了数十亿用户的社交活动和信息交流&#xff0c;但与此同时&#xff0c;也面临着来自内外部的数据安全挑战。本文将深入…

AWS Elastic Beanstalk 监控可观测最佳实践

一、概述 Amazon Web Services (AWS) 包含一百多种服务&#xff0c;每项服务都针对一个功能领域。服务的多样性可让您灵活地管理 AWS 基础设施&#xff0c;然而&#xff0c;判断应使用哪些服务以及如何进行预配置可能会非常困难。借助 Elastic Beanstalk&#xff0c;可以在 AW…