动手学深度学习——层和块

1. 层

层是一个将输入数据转换为输出数据的神经网络组件。每个层都会对输入数据进行一定的操作,例如线性变换、非线性激活函数等,以产生输出数据。

torch.nn模块提供了各种预定义的层,如线性层、卷积层、池化层等,

  • nn.Linear:线性层
  • nn.MaxPool2d:二维池化层
  • nn.Conv2d:二维卷积层
  • nn.ReLu:激活函数层

也支持基于nn.Module自定义层。

1.1 自定义简单层

import torch
import torch.nn.functional as F
from torch import nnclass CenteredLayer(nn.Module):def __init__(self):super().__init__()def forward(self, X):return X - X.mean()

这个层的功能是对每个输入减去均值,运行示例:

layer = CenteredLayer()
layer(torch.FloatTensor([1, 2, 3, 4, 5]))> tensor([-2., -1.,  0.,  1.,  2.])

这个层没有定义需要训练的参数,这一类的层往往用于特定的功能转换,例如数据重排、裁剪、归一化等。

1.2 定义带参数的层

使用nn.Parameter来创建需要训练的参数,以线性全连接层为例

  • 需要两个参数:权重和偏置
  • 需要两个参数in_units和units来指明输入维度和输出维度
class MyLinear(nn.Module):def __init__(self, in_units, units):super().__init__()self.weight = nn.Parameter(torch.randn(in_units, units))self.bias = nn.Parameter(torch.randn(units,))def forward(self, X):return torch.matmul(X, self.weight.data) + self.bias.data

实例化MyLinear类实例并用其进行前向传播计算:

linear = MyLinear(5, 3)
linear(torch.rand(2, 5))> tensor([[ 1.9813, -0.1214,  0.1627],[ 2.6518, -0.8198,  0.6513]])

2. 块

在神经网络中,块可以表示为多个层组成的组件,将多个块组合能构成复杂的网络模型。

在这里插入图片描述
从编程的角度(以pytorch为例),块也是由继承nn.Module的类来表示,它必须具有的组成部分:

  1. 组成块的层
  2. 前向传播函数forward,用于将输入转换为输出;
  3. 反向传播函数backward,用于计算梯度;
  4. 待训练的参数;

由于pytorch支持反向传播自动求导,已经由pytorch内部封装了反向传播函数的实现,另外pytorch会自动根据层的大小来初始化模型参数w和b,所以我们在定义块时只需要考虑前向传播函数和组成块的层。

2.1 自定义块

以前一篇文章中的多层感知机为例,可以封装为一个块:

class MLP(nn.Module):# 用模型参数声明层。这里,我们声明两个全连接的层def __init__(self):# 调用MLP的父类Module的构造函数来执行必要的初始化。# 这样,在类实例化时也可以指定其他函数参数,例如模型参数params(稍后将介绍)super().__init__()self.hidden = nn.Linear(20, 256)  # 隐藏层self.out = nn.Linear(256, 10)  # 输出层# 定义模型的前向传播,即如何根据输入X返回所需的模型输出def forward(self, X):# 注意,这里我们使用ReLU的函数版本,其在nn.functional模块中定义。return self.out(F.relu(self.hidden(X)))

2.2 组合块

前篇文章用到的nn.Sequential其实也是一个块,只不过它的作用是将其它块按顺序组合到一起,形成一个串行执行的有序列表。

class MySequential(nn.Module):def __init__(self, *args):super().__init__()for idx, module in enumerate(args):# module是Module子类的一个实例,这里将它保存在'Module'类型的成员变量_modules中,_modules的类型是OrderedDictself._modules[str(idx)] = moduledef forward(self, X):# OrderedDict保证了按照成员添加的顺序遍历它们for block in self._modules.values():X = block(X)return X

每个nn.Module都有一个内置的_modules属性,目的是方便系统查找需要初始化参数的子块

3. 参数管理

训练模型的目的是为了找到使损失函数最小化的模型参数值,这个训练过程就如同我们调试程序一样,有时候需要打印中间结果以辅助我们进行问题的分析和诊断,所以我们有必要知道如何访问参数。

3.1 参数访问

当通过Sequential类定义模型时, 我们可以通过索引来访问模型的任意层,通过每层的state_dict()来获取该层的参数。

print(net[2].state_dict())
OrderedDict([('weight', tensor([[-0.0427, -0.2939, -0.1894,  0.0220, -0.1709, -0.1522, -0.0334, -0.2263]])), ('bias', tensor([0.0887]))])

可以看出,该层包含权重weight和偏置bias两个参数。

我们还可以直接访问权重或偏置。

print(type(net[2].weight)) # 类型
print(net[2].weight)       # 直接访问参数,包含参数值和梯度信息
print(net[2].weight.data)  # 访问参数值
# print(net[2].weight.grid)  # 访问参数的梯度,还没有训练,所以梯度还没值
<class 'torch.nn.parameter.Parameter'>
Parameter containing:
tensor([[ 0.0986,  0.2894,  0.3461,  0.2734, -0.3395, -0.0719, -0.3348, -0.0305]],requires_grad=True)
tensor([[ 0.0986,  0.2894,  0.3461,  0.2734, -0.3395, -0.0719, -0.3348, -0.0305]])

3.2 嵌套块参数访问

我们可以将多个块相互嵌套,组成更大的块。

  • block1有4层, linear, relu, linear, relu
  • block2嵌套了3个block1块
  • 最后将block2与一个线性输出层组合,构成一个网络
def block1():return nn.Sequential(nn.Linear(4, 2), nn.ReLU(),nn.Linear(2, 4), nn.ReLU())def block2():net = nn.Sequential()# block2中嵌套3个block1for i in range(3):net.add_module(f'block {i}', block1())return netrgnet = nn.Sequential(block2(), nn.Linear(4, 1))
print(rgnet)

这个包含嵌套块的网络结构如下:

Sequential((0): Sequential((block 0): Sequential((0): Linear(in_features=4, out_features=2, bias=True)(1): ReLU()(2): Linear(in_features=2, out_features=4, bias=True)(3): ReLU())(block 1): Sequential((0): Linear(in_features=4, out_features=2, bias=True)(1): ReLU()(2): Linear(in_features=2, out_features=4, bias=True)(3): ReLU())(block 2): Sequential((0): Linear(in_features=4, out_features=2, bias=True)(1): ReLU()(2): Linear(in_features=2, out_features=4, bias=True)(3): ReLU()))(1): Linear(in_features=4, out_features=1, bias=True)
)

嵌套块的参数访问:

rgnet[0][1][0].bias.data> tensor([ 0.4917, -0.3920])

3.3 参数初始化

对于参数的初始化,不明确指定时,pytorch会使用默认的随机初始化方法。PyTorch的nn.init模块也提供了多种可供选择的预置初始化方法。

  1. 指定使用正态分布的随机变量来初始化:
# 定义初始化函数
def init_normal(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)nn.init.zeros_(m.bias)
# 使用指定函数对整个网络的参数进行初始化
net.apply(init_normal)  
net[0].weight.data[0], net[0].bias.data[0]> (tensor([-0.0101, -0.0117, -0.0116, -0.0016]), tensor(0.))
  1. 使用常数进行初始化:
def init_constant(m):if type(m) == nn.Linear:nn.init.constant_(m.weight, 1)nn.init.zeros_(m.bias)
net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]> (tensor([1., 1., 1., 1.]), tensor(0.))
  1. 对不同的块使用不同的初始化:
net[0].apply(init_normal)
net[2].apply(init_constant)
print(net[0].weight.data[0])
print(net[2].weight.data)> tensor([ 0.0054, -0.0188, -0.0112,  0.0097])
tensor([[1., 1., 1., 1., 1., 1., 1., 1.]])

4.4 参数绑定

含义:通过将一个层共享,可以实现相同的参数权重用于神经网络中的多个层。目的在于两方面:

  1. 减少模型的参数量:通过共享参数,可以大大减少需要学习的参数数量,从而减小模型的复杂度。
  2. 加速训练:参数共享可以减少内存占用和计算量,特别是在具有大量参数的深层网络中,可以显著提高计算效率。

下面是一个参数共享的代码示例:

# 给共享层一个名称,以便可以引用它的参数
shared = nn.Linear(8, 8)
# 第二层和第四层共享shared层的参数
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),shared, nn.ReLU(),shared, nn.ReLU(),nn.Linear(8, 1))

输出第二层和第四层指定位置的初始参数,两者是相同的。

print(net[2].weight.data[0, 0])
print(net[4].weight.data[0, 0])> tensor(-0.0253)
tensor(-0.0253)  

修改第二层的参数:

net[2].weight.data[0, 0] = 100

再次输出第二层和第四层指定位置的参数,两者都变成了修改后的参数:

print(net[2].weight.data[0, 0])
print(net[4].weight.data[0, 0])> tensor(100.)
tensor(100.)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/14821.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BLE学习笔记(0.0) —— 基础概念(0)

前言 &#xff08;1&#xff09;本章节主要是对BLE技术进行简单的介绍&#xff0c;熟悉蓝牙技术的发展过程&#xff0c;了解相关术语方便后续的学习。 &#xff08;2&#xff09;为了防止单篇博客太长以至于看不下去&#xff0c;因此我基础概念章节分为两篇来写。 &#xff08;…

Mysql教程(0):学习框架

1、Mysql简介 MySQL 是一个开放源代码的、免费的关系型数据库管理系统。在 Web 开发领域&#xff0c;MySQL 是最流行、使用最广泛的关系数据库。MySql 分为社区版和商业版&#xff0c;社区版完全免费&#xff0c;并且几乎能满足全部的使用场景。由于 MySQL 是开源的&#xff0…

1075: 求最小生成树(Prim算法)

解法&#xff1a; 总结起来&#xff0c;Prim算法的核心思想是从一个顶点开始&#xff0c;一步一步地选择与当前最小生成树相邻的且权值最小的边&#xff0c;直到覆盖所有的顶点&#xff0c;形成一个最小生成树。 #include<iostream> #include<vector> using names…

springboot基于Web前端技术的java养老院管理系统_utbl7

3.普通用户模块包括&#xff1a;普通会员的注册、养老院客房查询、养老院留言查询、预约老人基本信息登记、选择房间、用户缴费的功能。 4.数据信息能够及时进行动态更新&#xff0c;增删&#xff0c;用户搜素方便&#xff0c;使用户可以直接浏览相关信息&#xff0c;要考虑便于…

Vue3实战笔记(35)—集成炫酷的粒子特效

文章目录 前言一、vue3使用tsparticles二、使用步骤总结 前言 学习一个有趣炫酷的玩意开心一下。 tsparticles&#xff0c;可以方便的实现各种粒子特效。支持的语言框架也是相当的丰富. 官网&#xff1a;https://particles.js.org/ 一、vue3使用tsparticles 先来个vue3使用…

代码随想录训练营打卡第36天:动态规划解决子序列问题

1.300最长递增子序列 1.问题描述 找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列&#xff0c;删除&#xff08;或不删除&#xff09;数组中的元素而不改变其余元素的顺序。 2.问题转换 从nums[0...i]的最长的递增的子序列 3.解题思路 每一个位置的n…

C++多态详解

目录 一、多态的概念 二、多态的定义及实现 1.多态的构成条件 2.虚函数 3.虚函数的重写 4.例题理解&#xff08;超级重要&#xff0c;强烈建议做一下&#xff09; 5.C11 override和 final 6.重载、覆盖&#xff08;重写&#xff09;、隐藏&#xff08;重定义&#xff0…

零基础代码随想录【Day42】|| 1049. 最后一块石头的重量 II,494. 目标和,474.一和零

目录 DAY42 1049.最后一块石头的重量II 解题思路&代码 494.目标和 解题思路&代码 474.一和零 解题思路&代码 DAY42 1049.最后一块石头的重量II 力扣题目链接(opens new window) 题目难度&#xff1a;中等 有一堆石头&#xff0c;每块石头的重量都是正整…

(Qt) 默认QtWidget应用包含什么?

文章目录 ⭐前言⭐创建&#x1f6e0;️选择一个模板&#x1f6e0;️Location&#x1f6e0;️构建系统&#x1f6e0;️Details&#x1f6e0;️Translation&#x1f6e0;️构建套件(Kit)&#x1f6e0;️汇总 ⭐项目⚒️概要⚒️构建步骤⚒️清除步骤 ⭐Code&#x1f526;untitled…

【EasyX】快速入门——消息处理,音频

1.消息处理 我们先看看什么是消息 1.1.获取消息 想要获取消息,就必须学会getmessage函数 1.1.1.getmessage函数 有两个重载版本,它们的作用是一样的 参数filter可以筛选我们需要的消息类型 我们看看参数filter的取值 当然我们可以使用位运算组合这些值 例如,我们…

华为CE6851-48S6Q-HI升级设备版本及补丁

文章目录 升级前准备工作笔记本和交换机设备配置互联地址启用FTP设备访问FTP设备升级系统版本及补丁 升级前准备工作 使用MobaXterm远程工具连接设备&#xff0c;并作为FTP服务器准备升级所需的版本文件及补丁文件 笔记本和交换机设备配置互联地址 在交换机接口配置IP&#…

Facebook隐私保护:数据安全的前沿挑战

在数字化时代&#xff0c;随着社交媒体的普及和应用&#xff0c;个人数据的隐私保护问题日益受到关注。作为全球最大的社交平台之一&#xff0c;Facebook承载了数十亿用户的社交活动和信息交流&#xff0c;但与此同时&#xff0c;也面临着来自内外部的数据安全挑战。本文将深入…

AWS Elastic Beanstalk 监控可观测最佳实践

一、概述 Amazon Web Services (AWS) 包含一百多种服务&#xff0c;每项服务都针对一个功能领域。服务的多样性可让您灵活地管理 AWS 基础设施&#xff0c;然而&#xff0c;判断应使用哪些服务以及如何进行预配置可能会非常困难。借助 Elastic Beanstalk&#xff0c;可以在 AW…

AWTK实现汽车仪表Cluster/DashBoard嵌入式GUI开发(七):快启

前言: 汽车仪表是人们了解汽车状况的窗口,而仪表中的大部分信息都是以指示灯形式显示给驾驶者。仪表指示灯图案都较为抽象,对驾驶不熟悉的人在理解仪表指示灯含义方面存在不同程度的困难,尤其对于驾驶新手,如果对指示灯的含义不求甚解,有可能影响驾驶的安全性。即使是对…

Pytest框架实战二

在Pytest框架实战一中详细地介绍了Pytest测试框架在参数化以及Fixture函数在API测试领域的实战案例以及具体的应用。本文章接着上个文章的内容继续阐述Pytest测试框架优秀的特性以及在自动化测试领域的实战。 conftest.py 在上一篇文章中阐述到Fixture函数的特性&#xff0c;第…

shell循环

一、for循环 用法&#xff1a; for 变量 in 取值列表 do 命令序列 done 例1&#xff1a;打印1到10的数字列表 #!/bin/bashfor i in {1..10} do echo $i done 例2&#xff1a;#批量添加用户,用户名存放在users.txt文件中&#xff0c;每行一个,初始密码均设为123456 #!/bin/bas…

KMP算法【C++】

KMP算法测试 KMP 算法详解 根据解释写出对应的C代码进行测试&#xff0c;也可以再整理成一个函数 #include <iostream> #include <vector>class KMP { private:std::string m_pat;//被匹配的字符串std::vector<std::vector<int>> m_dp;//状态二维数组…

哪些类型的产品适合用3D形式展示?

随着3D技术的蓬勃发展&#xff0c;众多品牌和企业纷纷投身3D数字化浪潮&#xff0c;将产品打造成逼真的3D模型进行展示&#xff0c;消费者可以更加直观地了解产品的特点和优势&#xff0c;从而做出更明智的购买决策。 哪些产品适合3D交互展示&#xff1f; 产品3D交互展示具有直…

自己手写一个字符串【C风格】

//字符串的常见操作 #include <iostream>#define MAX_SIZE 15 #define OK 1 #define ERROR 0 #define TRUE 1 #define FALSE 0 typedef int Status;//状态类型 typedef char ElemType;//元素类型typedef ElemType String[MAX_SIZE 1];//第一个字节记录长度//***tring是数…

c#自动生成缺陷图像-添加新功能(可从xml直接提取目标数据,然后进行数据离线增强)--20240524

在进行深度学习时,数据集十分重要,尤其是负样本数据。 故设计该软件进行深度学习数据预处理,最大可能性获取较多的模拟工业现场负样本数据集。 该软件基于VS2015、.NETFrameWork4.7.2、OpenCvSharp1.0.0.0、netstandard2.0.0.0、SunnyUI3.2.9.0、SunnyUI.Common3.2.9.0及Ope…