深度学习基础知识 register_buffer 与 register_parameter用法分析

深度学习基础知识 register_buffer 与 register_parameter用法分析

  • 1、问题引入
  • 2、register_parameter()
    • 2.1 作用
    • 2.2 用法
  • 3、register_buffer()
    • 3.1 作用
    • 3.2 用法

1、问题引入

思考问题:定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习

验证方案:定义网络模型,设置weigut与bias,遍历网络结构参数net.named_parameters(),如果定义的weight与bias在里面,则说明是可学习参数;否则,是不可学习参数

import torch
import torch.nn as nn# 思考两个问题,定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习class MyModule(nn.Module):def __init__(self):super(MyModule,self).__init__()self.conv1=nn.Conv2d(in_channels= 3,out_channels= 6,kernel_size=3,stride = 1,padding=1,bias=False)self.conv2=nn.Conv2d(in_channels= 6,out_channels= 9,kernel_size=3,stride = 1,padding=1,bias=False)self.waight=torch.ones(10,10)self.bias=torch.zeros(10)def forward(self,x):x=self.conv1(x)x=self.conv2(x)x = x * self.weight + self.biasreturn xnet=MyModule()for name,param in net.named_parameters():  # 如果weight与bias在里面,说明其是可学习参数;否则,是不可学习参数print(name,param.shape)print("\n","-"*40,"\n")for key,val in net.state_dict().items():  # 说明weight与bias是不会被state_dict转化为字典中的元素的print(key,val.shape)

打印分析结果:
在这里插入图片描述
可以看到,weight与bias不在其中,所以此种定义方式不会是的weight与bias成为可训练参数

2、register_parameter()

register_parameter()是 torch.nn.Module 类中的一个方法

2.1 作用

1、可将 self.weight 和 self.bias 定义为可学习的参数,保存到网络对象的参数中,被优化器作用进行学习
2、self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

2.2 用法

register_parameter(name,param)

  • name:参数名称
  • param:参数张量, 须是 torch.nn.Parameter() 对象 或 None ,

否则报错如下
在这里插入图片描述

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_parameter('weight', torch.nn.Parameter(torch.ones(10, 10)))self.register_parameter('bias', torch.nn.Parameter(torch.zeros(10)))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*40, '\n')for key, val in net.state_dict().items():print(key, val.shape)

结果显示:
在这里插入图片描述

3、register_buffer()

register_buffer()是 torch.nn.Module() 类中的一个方法

3.1 作用

  • 将 self.weight 和 self.bias 定义为不可学习的参数,不会被保存到网络对象的参数中,不会被优化器作用进行学习

  • self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

它用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数)

  • 参数:可以被优化器更新 (requires_grad=False / True)
  • buffer 中的数据 : 不会被优化器更新

3.2 用法

register_buffer(name,tensor)

  • name:参数名称
  • tensor:张量

代码:

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_buffer('weight', torch.ones(10, 10))   # 注意:定义的方式self.register_buffer('bias', torch.zeros(10))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*40, '\n')for key, val in net.state_dict().items():print(key, val.shape)

效果如下所示:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/98936.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《深入理解计算机系统》(2):虚拟内存

虚拟内存是一种对主存的抽象概念。 (1)将主存看作一个存储在磁盘上的地址空间的高速缓存,在主存中只保存活动区域,并根据需要在磁盘和主存之间来回传送数据,通过这种方式高效地使用内存 (2)为每…

GitHub要求开启2FA,否则不让用了。

背景 其实大概在一个多月前,在 GitHub 网页端以及邮箱里都被提示:要求开启 2FA ,即双因子认证;但是当时由于拖延症和侥幸心理作祟,直接忽略了相关信息,毕竟“又不是不能用”。。 只到今天发现 GitHub 直接…

Linux软硬链接和动静态库

本文已收录至《Linux知识与编程》专栏! 作者:ARMCSKGT 演示环境:CentOS 7 软硬链接和动静态库 前言正文软硬链接原理使用 文件时间动静态库库介绍静态库静态库制作静态库的使用关于静态链接 动态库动态库制作动态库的使用关于动态链接 补充 最…

Java练习题-用冒泡排序法实现数组排序

✅作者简介:CSDN内容合伙人、阿里云专家博主、51CTO专家博主、新星计划第三季python赛道Top1🏆 📃个人主页:hacker707的csdn博客 🔥系列专栏:Java练习题 💬个人格言:不断的翻越一座又…

Computer Architecture Subtitle:Engineering And Technology

原文链接:https://www.cs.umd.edu/~meesh/411/CA-online/index.html

基于Springboot实现疫情网课管理系统项目【项目源码+论文说明】

基于Springboot实现疫情网课管理系统演示 摘要 随着科学技术的飞速发展,各行各业都在努力与现代先进技术接轨,通过科技手段提高自身的优势;对于疫情网课管理系统当然也不能排除在外,随着网络技术的不断成熟,带动了疫情…

Windows11 安全中心页面不可用问题(无法打开病毒和威胁防护)解决方案汇总(图文介绍版)

本文目录 Windows版本与报错信息问题详细图片: 解决方案:方案一、管理员权限(若你确定你的电脑只有你一个账户,则此教程无效,若你也不清楚,请阅读后再做打算)方案二、修改注册表(常用方案)方案三、进入开发…

leetcode:2427. 公因子的数目(python3解法)

难度:简单 给你两个正整数 a 和 b ,返回 a 和 b 的 公 因子的数目。 如果 x 可以同时整除 a 和 b ,则认为 x 是 a 和 b 的一个 公因子 。 示例 1: 输入:a 12, b 6 输出:4 解释:12 和 6 的公因…

Meta分析的流程及方法

Meta分析是针对某一科研问题,根据明确的搜索策略、选择筛选文献标准、采用严格的评价方法,对来源不同的研究成果进行收集、合并及定量统计分析的方法,最早出现于“循证医学”,现已广泛应用于农林生态,资源环境等方面。…

linux日志审计常用命令

文章目录 cut参数指定范围命令 awk参数内置变量命令 wc参数命令 uniq参数命令 sort参数命令 head参数 cut 参数 选项含义-b仅显示行中指定直接范围的内容-c仅显示行中指定范围的字符-d指定分割符, 默认为“TAB”制表符-f显示指定字段的内容-n与“-b”连用&#xf…

Prometheus普罗米修斯

什么是Prometheus 官网:Overview | Prometheus 是一个开源的系统监控和警报工具,多数Prometheus组件是Go语言写的 为用户提供可视化仪表板、警报、告警等功能,以帮助用户快速定位和解决问题 现在已经成为一个独立于企业级的开源项目和一个…

供水管网监测系统

随着城市人口的不断增长和经济的快速发展,供水管网的安全和可靠性变得尤为重要。在过去,供水管网的监测往往是依靠人工巡查,这种方式不仅费时费力,而且容易出现疏漏和盲区。然而,随着科技的进步,供水管网监…

大数据集群(Hadoop生态)安装部署

目录 1. 简介 2. 前置要求 3. Hadoop集群角色 4. 角色和节点分配 5. 调整虚拟机内存 6. Zookeeper集群部署 7. Hadoop集群部署 7.1 下载Hadoop安装包、解压、配置软链接 7.2 修改配置文件:hadoop-env.sh 7.3 修改配置文件:core-site…

Vue3目录结构与Yarn.lock 的版本锁定

Vue目录结构与Yarn.lock 的版本锁定 一、Vue3.0目录结构图总览 举个例子看vue的目录,一开始不知道该目录是什么意思目录里各个文件包里安放有什么,程序员在哪里操作该如何操作。 下图目录看Vue新项目 VS Code 打开文件包后出现一列目录 二、目录结构 1…

宝塔面板二次元透明主题美化模板

看惯了宝塔面板默认风格模板,我们可以试试自己美化修改,我的站长站知道一款非常漂亮的宝塔面板二次元透明主题美化模板,美不美大家看下图,分享给大家。 下载:飞猫盘|文件加速传输工具|云盘&…

Vulnhub系列靶机-The Planets Earth

文章目录 Vulnhub系列靶机-The Planets: Earth1. 信息收集1.1 主机扫描1.2 端口扫描1.3 目录爆破 2. 漏洞探测2.1 XOR解密2.2 解码 3. 漏洞利用3.1 反弹Shell 4. 权限提升4.1 NC文件传输 Netcat(nc)文件传输 Vulnhub系列靶机-The Planets: Earth 1. 信息…

软件工程师都应该知道的10个定律

一、海勒姆法则 内容 当一个 API 有足够多的用户,你在契约中承诺了什么并不重要:系统中所有看得见的行为都会有某个人依赖…… 案例 现在有两个系统A和B,B的一个接口返回一个列表。A系统的开发人员发现返回的列表都是按照ID正向排序的。本…

Flink实现kafka到kafka、kafka到doris的精准一次消费

1 流程图 2 Flink来源表建模 --来源-城市topic CREATE TABLE NJ_QL_JC_SSJC_SOURCE ( record string ) WITH (connector = kafka,topic = QL_JC_SSJC,properties.bootstrap.servers = 172.*.*.*:9092,properties.group.id = QL_JC_SSJC_NJ_QL_JC_SSJC_SOURCE,scan.startup.mo…

基于Springboot实现疫情网课管理系统项目【项目源码+论文说明】分享

基于Springboot实现疫情网课管理系统演示 摘要 随着科学技术的飞速发展,各行各业都在努力与现代先进技术接轨,通过科技手段提高自身的优势;对于疫情网课管理系统当然也不能排除在外,随着网络技术的不断成熟,带动了疫情…