【现代深度学习技术】深度学习计算 | 参数管理

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ PyTorch深度学习 ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目录

    • 一、参数访问
      • (一)目标参数
      • (二)一次性访问所有参数
      • (三)从嵌套块收集参数
    • 二、参数初始化
      • (一)内置初始化
      • (二)自定义初始化
    • 三、参数绑定
    • 小结


  在选择了架构并设置了超参数后,我们就进入了训练阶段。此时,我们的目标是找到使损失函数最小化的模型参数值。经过训练后,我们将需要使用这些参数来做出未来的预测。此外,有时我们希望提取参数,以便在其他环境中复用它们,将模型保存下来,以便它可以在其他软件中执行,或者为了获得科学的理解而进行检查。

  之前的介绍中,我们只依靠深度学习框架来完成训练的工作,而忽略了操作参数的具体细节。本节,我们将介绍以下内容:

  • 访问参数,用于调试、诊断和可视化;
  • 参数初始化;
  • 在不同模型组件间共享参数。

  我们首先看一下具有单隐藏层的多层感知机。

import torch
from torch import nnnet = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
net(X)

在这里插入图片描述

一、参数访问

  我们从已有模型中访问参数。当通过Sequential类定义模型时,我们可以通过索引来访问模型的任意层。这就像模型是一个列表一样,每层的参数都在其属性中。如下所示,我们可以检查第二个全连接层的参数。

print(net[2].state_dict())

在这里插入图片描述

  输出的结果告诉我们一些重要的事情:首先,这个全连接层包含两个参数,分别是该层的权重和偏置。两者都存储为单精度浮点数(float32)。注意,参数名称允许唯一标识每个参数,即使在包含数百个层的网络中也是如此。

(一)目标参数

  注意,每个参数都表示为参数类的一个实例。要对参数执行任何操作,首先我们需要访问底层的数值。有几种方法可以做到这一点。有些比较简单,而另一些则比较通用。下面的代码从第二个全连接层(即第三个神经网络层)提取偏置,提取后返回的是一个参数类实例,并进一步访问该参数的值。

print(type(net[2].bias))
print(net[2].bias)
print(net[2].bias.data)

在这里插入图片描述

  参数是复合的对象,包含值、梯度和额外信息。这就是我们需要显式参数值的原因。除了值之外,我们还可以访问每个参数的梯度。在上面这个网络中,由于我们还没有调用反向传播,所以参数的梯度处于初始状态。

net[2].weight.grad == None

在这里插入图片描述

(二)一次性访问所有参数

  当我们需要对所有参数执行操作时,逐个访问它们可能会很麻烦。当我们处理更复杂的块(例如,嵌套块)时,情况可能会变得特别复杂,因为我们需要递归整个树来提取每个子块的参数。下面,我们将通过演示来比较访问第一个全连接层的参数和访问所有层。

print(*[(name, param.shape) for name, param in net[0].named_parameters()])
print(*[(name, param.shape) for name, param in net.named_parameters()])

在这里插入图片描述

  这为我们提供了另一种访问网络参数的方式,如下所示。

net.state_dict()['2.bias'].data

在这里插入图片描述

(三)从嵌套块收集参数

  让我们看看,如果我们将多个块相互嵌套,参数命名约定是如何工作的。我们首先定义一个生成块的函数(可以说是“块工厂”),然后将这些块组合到更大的块中。

def block1():return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU())def block2():net = nn.Sequential()for i in range(4):# 在这里嵌套net.add_module(f'block {i}', block1())return netrgnet = nn.Sequential(block2(), nn.Linear(4, 1))
rgnet(X)

在这里插入图片描述

  设计了网络后,我们看看它是如何工作的。

print(rgnet)

在这里插入图片描述

  因为层是分层嵌套的,所以我们也可以像通过嵌套列表索引一样访问它们。下面,我们访问第一个主要的块中、第二个子块的第一层的偏置项。

rgnet[0][1][0].bias.data

在这里插入图片描述

二、参数初始化

  知道了如何访问参数后,现在我们看看如何正确地初始化参数。我们在【深度学习基础】多层感知机 | 数值稳定性和模型初始化 中讨论了良好初始化的必要性。深度学习框架提供默认随机初始化,也允许我们创建自定义初始化方法,满足我们通过其他规则实现初始化权重。

  默认情况下,PyTorch会根据一个范围均匀地初始化权重和偏置矩阵,这个范围是根据输入和输出维度计算出的。PyTorch的nn.init模块提供了多种预置初始化方法。

(一)内置初始化

  让我们首先调用内置的初始化器。下面的代码将所有权重参数初始化为标准差为0.01的高斯随机变量,且将偏置参数设置为0。

def init_normal(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)nn.init.zeros_(m.bias)
net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]

在这里插入图片描述

  我们还可以将所有参数初始化为给定的常数,比如初始化为1。

def init_constant(m):if type(m) == nn.Linear:nn.init.constant_(m.weight, 1)nn.init.zeros_(m.bias)
net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]

在这里插入图片描述

  我们还可以对某些块应用不同的初始化方法。例如,下面我们使用Xavier初始化方法初始化第一个神经网络层,然后将第三个神经网络层初始化为常量值42。

def init_xavier(m):if type(m) == nn.Linear:nn.init.xavier_uniform_(m.weight)
def init_42(m):if type(m) == nn.Linear:nn.init.constant_(m.weight, 42)net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)

在这里插入图片描述

(二)自定义初始化

  有时,深度学习框架没有提供我们需要的初始化方法。在下面的例子中,我们使用以下的分布为任意权重参数 w w w定义初始化方法:

w ∼ { U ( 5 , 10 ) 可能性  1 4 0 可能性  1 2 U ( − 10 , − 5 ) 可能性  1 4 (1) \begin{aligned} w \sim \begin{cases} U(5, 10) & \text{ 可能性 } \frac{1}{4} \\ 0 & \text{ 可能性 } \frac{1}{2} \\ U(-10, -5) & \text{ 可能性 } \frac{1}{4} \end{cases} \end{aligned} \tag{1} w U(5,10)0U(10,5) 可能性 41 可能性 21 可能性 41(1)

  同样,我们实现了一个my_init函数来应用到net

def my_init(m):if type(m) == nn.Linear:print("Init", *[(name, param.shape) for name, param in m.named_parameters()][0])nn.init.uniform_(m.weight, -10, 10)m.weight.data *= m.weight.data.abs() >= 5net.apply(my_init)
net[0].weight[:2]

在这里插入图片描述

  注意,我们始终可以直接设置参数。

net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]

在这里插入图片描述

三、参数绑定

  有时我们希望在多个层间共享参数:我们可以定义一个稠密层,然后使用它的参数来设置另一个层的参数。

# 我们需要给共享层一个名称,以便可以引用它的参数
shared = nn.Linear(8, 8)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), shared, nn.ReLU(),shared, nn.ReLU(), nn.Linear(8, 1))
net(X)
# 检查参数是否相同
print(net[2].weight.data[0] == net[4].weight.data[0])
net[2].weight.data[0, 0] = 100
# 确保它们实际上是同一个对象,而不只是有相同的值
print(net[2].weight.data[0] == net[4].weight.data[0])

在这里插入图片描述

  这个例子表明第三个和第五个神经网络层的参数是绑定的。它们不仅值相等,而且由相同的张量表示。因此,如果我们改变其中一个参数,另一个参数也会改变。这里有一个问题:当参数绑定时,梯度会发生什么情况?答案是由于模型参数包含梯度,因此在反向传播期间第二个隐藏层(即第三个神经网络层)和第三个隐藏层(即第五个神经网络层)的梯度会加在一起。

小结

  • 我们有几种方法可以访问、初始化和绑定模型参数。
  • 我们可以使用自定义初始化方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/893971.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言,无法正常释放char*的空间

问题描述 #include <stdio.h> #include <stdio.h>const int STRSIZR 10;int main() {char *str (char *)malloc(STRSIZR*sizeof(char));str "string";printf("%s\n", str);free(str); } 乍一看&#xff0c;这块代码没有什么问题。直接书写…

2025春招 SpringCloud 面试题汇总

大家好&#xff0c;我是 V 哥。SpringCloud 在面试中属于重灾区&#xff0c;不仅是基础概念、组件细节&#xff0c;还有高级特性、性能优化&#xff0c;关键是项目实践经验的解决方案&#xff0c;都是需要掌握的内容&#xff0c;正所谓打有准备的仗&#xff0c;秒杀面试官&…

C#面试常考随笔6:ArrayList和 List的主要区别?

在 C# 中&#xff0c;ArrayList和List<T>&#xff08;泛型列表&#xff09;都可用于存储一组对象。推荐优先使用List<T>&#xff0c;因为它具有更好的类型安全性、性能和语法简洁性&#xff0c;并且提供了更丰富的功能。只有在需要与旧代码兼容或存储不同类型对象的…

用C++编写一个2048的小游戏

以下是一个简单的2048游戏的实现。这个实现使用了控制台输入和输出&#xff0c;适合在终端或命令行环境中运行。 2048游戏的实现 1.游戏逻辑 2048游戏的核心逻辑包括&#xff1a; • 初始化一个4x4的网格。 • 随机生成2或4。 • 处理玩家的移动操作&#xff08;上、下、左、…

Julia Distributed(分布式计算)详解

分布式计算是现代计算机科学中日益重要的一个分支&#xff0c;特别是在处理大规模数据和计算密集型应用时变得尤为关键。Julia 语言作为一种高性能的开源编程语言&#xff0c;其内置的分布式计算功能强大且易于使用。本篇博客将深入探讨 Julia Distributed 的基础概念、使用方法…

【开源免费】基于Vue和SpringBoot的在线文档管理系统(附论文)

本文项目编号 T 038 &#xff0c;文末自助获取源码 \color{red}{T038&#xff0c;文末自助获取源码} T038&#xff0c;文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…

忘记宝塔的访问地址怎么找

在linux中安装宝塔面板后会生成网址、账号和密码 如果网址忘记了那将进不去宝塔面板该怎么办呢&#xff1f; bt命令 我们输入 bt 命令的时候&#xff0c;是在根目录里面进行操作的。 / bt 我们根据自己的需要&#xff0c;选择对应的数字就可以了。 bt 14 输入 14 查看面板默…

FLTK - FLTK1.4.1 - demo - animgifimage

文章目录 FLTK - FLTK1.4.1 - demo - animgifimage概述笔记END FLTK - FLTK1.4.1 - demo - animgifimage 概述 知识点: 注册图像文件类型判断回调 FLTK支持的图像格式 GIF, BMP, ICO, PNM, PNG, jpg, svg 事件回调的注册 GIF图像显示为图片或动画的标志设置 // 超时回调的设置…

力扣hot100-->滑动窗口、贪心

你好呀&#xff0c;欢迎来到 Dong雨 的技术小栈 &#x1f331; 在这里&#xff0c;我们一同探索代码的奥秘&#xff0c;感受技术的魅力 ✨。 &#x1f449; 我的小世界&#xff1a;Dong雨 &#x1f4cc; 分享我的学习旅程 &#x1f6e0;️ 提供贴心的实用工具 &#x1f4a1; 记…

【蓝桥杯嵌入式入门与进阶】2.与开发板之间破冰:初始开发板和原理图2

个人主页&#xff1a;Icomi 专栏地址&#xff1a;蓝桥杯嵌入式组入门与进阶 大家好&#xff0c;我是一颗米&#xff0c;本篇专栏旨在帮助大家从0开始入门蓝桥杯并且进阶&#xff0c;若对本系列文章感兴趣&#xff0c;欢迎订阅我的专栏&#xff0c;我将持续更新&#xff0c;祝你…

Spring Boot - 数据库集成02 - 集成JPA

集成JPA 文章目录 集成JPA一&#xff1a;JPA概述1&#xff1a;JPA & JDBC2&#xff1a;JPA规范3&#xff1a;JPA的状态和转换关系 二&#xff1a;Spring data JPA1&#xff1a;JPA_repository1.1&#xff1a;CurdRepostory<T, ID>1.2&#xff1a;PagingAndSortingRep…

从ai产品推荐到利用cursor快速掌握一个开源项目再到langchain手搓一个Text2Sql agent

目录 0. 经验分享&#xff1a;产品推荐 1. 经验分享&#xff1a;提示词优化 2. 经验分享&#xff1a;使用cursor 阅读一篇文章 3. 经验分享&#xff1a;使用cursor 阅读一个完全陌生的开源项目 4. 经验分享&#xff1a;手搓一个text2sql agent &#xff08;使用langchain l…

基于DeepSeek在藏语学习推广和藏语信息化方面可以做哪些工作?

基于DeepSeek对藏语的技术优势&#xff0c;您可在以下三大方向开展创新性工作&#xff0c;以下是20具体落地方案&#xff1a; 一、藏语智能教育工具开发 《三十颂》AI语法教练 开发虚拟助教自动解析藏文句子结构&#xff08;标注格助词/时态变化&#xff09;错误检测系统&…

【Java-数据结构】Java 链表面试题下 “最后一公里”:解决复杂链表问题的致胜法宝

我的个人主页 我的专栏&#xff1a;Java-数据结构&#xff0c;希望能帮助到大家&#xff01;&#xff01;&#xff01;点赞❤ 收藏❤ 引言&#xff1a; Java链表&#xff0c;看似简单的链式结构&#xff0c;却蕴含着诸多有趣的特性与奥秘&#xff0c;等待我们去挖掘。它就像一…

深入探索 HTML5 拖拽效果 API:打造流畅交互体验

在现代的 Web 开发中&#xff0c;交互性和用户体验一直是开发者关注的重点。HTML5 的拖拽效果 API (Drag and Drop API) 提供了一种非常直观的方式来让网页元素或文件能够被拖动并放置到页面的指定位置&#xff0c;极大提升了用户的交互体验。本篇文章将深入探讨如何使用 HTML5…

智慧园区系统的类型及其在企业管理效率提升中的关键作用解析

内容概要 在智慧园区的建设中&#xff0c;各类系统的采用是提升管理效率的关键所在。快鲸智慧园区(楼宇)管理系统&#xff0c;通过其全面数字化的管理手段&#xff0c;已经成为了企业管理的新标杆。这一系统能够有效整合租赁管理、资产管理、招商管理和物业管理等功能&#xf…

微信小程序压缩图片

由于wx.compressImage(Object object) iOS 仅支持压缩 JPG 格式图片。所以我们需要做一下特殊的处理&#xff1a; 1.获取文件&#xff0c;判断文件是否大于设定的大小 2.如果大于则使用canvas进行绘制&#xff0c;并生成新的图片路径 3.上传图片 async chooseImage() {let …

国内flutter环境部署(记录篇)

设置系统环境变量 export PUB_HOSTED_URLhttps://pub.flutter-io.cn export FLUTTER_STORAGE_BASE_URLhttps://storage.flutter-io.cn使用以下命令下载flutter镜像 git clone -b stable https://mirror.ghproxy.com/https://github.com/<github仓库地址>#例如flutter仓…

【uniapp】uniapp使用java线程池

标题 由于js是性能孱弱的单线程语言&#xff0c;只要在渲染中执行了一些其他操作&#xff0c;会中断渲染&#xff0c;导致页面卡死&#xff0c;卡顿&#xff0c;吐司不消失等问题。在安卓端可以调用java线程池&#xff0c;把耗时操作写入线程池里面&#xff0c;优化性能。 实…

多级缓存(亿级并发解决方案)

多级缓存&#xff08;亿级流量&#xff08;并发&#xff09;的缓存方案&#xff09; 传统缓存的问题 传统缓存是请求到达tomcat后&#xff0c;先查询redis&#xff0c;如果未命中则查询数据库&#xff0c;问题如下&#xff1a; &#xff08;1&#xff09;请求要经过tomcat处…