register_parameter和register_buffer 详解

在参考yolo系列代码或其他开源代码,经常看到register_buffer register_parameter的使用,接下来将详细对他们进行介绍。

1. 前沿

在搭建网络时,我们 自定义的参数,往往不会保存到模型权重文件中,或者成为模型可学习的参数。即我们通过 net.named_parameters() (模型可学习参数)或 net.state_dict().items()(保存模型权重值)方法都无法遍历输出。那如何解决呢,这就需要用到本文讲的register_parameterregister_buffer方法。

2. register_parameter

register_parameter() 是 torch.nn.Module 类中的一个方法。

2.1 主要作用

  • 用于定义可学习参数
  • 定义的参数可被保存到网络对象的参数中,可使用 net.parameters()net.named_parameters() 查看
  • 定义的参数可用 net.state_dict() 转换到字典中,进而 保存到网络文件 / 网络参数文件

2.2 函数说明

register_parameter(name,param)

参数:

  • name:参数名称

  • param:参数张量, 须是torch.nn.Parameter()对象 或 None ,否则报错如下
    TypeError: cannot assign 'torch.FloatTensor' object to parameter 'xx' (torch.nn.Parameter or None required)

2.3 举例说明

(1)自定义的参数未使用register_parameter

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.weight = torch.ones(10,10)self.bias = torch.zeros(10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()print('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

输出:
在这里插入图片描述
在网络搭建的代码中,我们自定义了self.weightself.bias参数。我们思考下2个问题:1. 我们定义的self.weightself.bias参数是否会保存到网络的参数中,是否能在优化器的作用下进行学习。2. 这些参数是否能够保存到模型文件中,从而可以利用state_dict中遍历出来。通过上面的打印信息我们发现:

  • 使用net.named_parameters()迭代网络中可学习的参数,发现输出的参数只有conv1conv2的weight参数,并没有输出我们定义的self.weightself.bias
  • 接下来使用net.state_dict()方法迭代保存的参数,同样发现self.weightself.bias参数也没有被输出出来。

(2)通过register_parameter方法来定义参数

  • 接下来我们使用register_parameter来定义weight和bias参数,看看会有啥效果。代码修改如下:
self.register_parameter('weight',torch.nn.Parameter(torch.ones(10,10)))
self.register_parameter('bias',torch.nn.Parameter(torch.zeros(10)))

完整代码

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_parameter('weight',torch.nn.Parameter(torch.ones(10,10)))self.register_parameter('bias',torch.nn.Parameter(torch.zeros(10)))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()print('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

在这里插入图片描述

  • 可以看到,使用了register_parameter定义的参数weight和bias,可以通过net.named_parameters或者net.parameters迭代出来的,这说明weight和bias已经存到了网络的参数中,他们是可学习的参数
  • 同时,通过state_dict()也能将参数和值给迭代出来,就说明如果要保存模型权重或网络参数时,这两个参数时可以被保存起来的。

3 register_buffer()

register_buffer()是 torch.nn.Module() 类中的一个方法

3.1 作用

  • 用于定义不可学习的参数
  • 定义的参数不会被保存到网络对象的参数中,使用 net.parameters() 或 net.named_parameters() 查看不到
  • 定义的参数可用 net.state_dict() 转换到字典中,进而 保存到网络文件 / 网络参数文件中

register_buffer() 用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数),它与参数的区别为:

  • 参数:可以被优化器更新 (requires_grad=False / True)

  • buffer 中的数据 (不可学习): 不会被优化器更新

3.2、举例说明

将定义的weight和bias,通过register_buffer来定义。

self.register_buffer('weight',torch.ones(10,10))
self.register_buffer('bias',torch.zeros(10))

运行完整代码看看效果:

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_buffer('weight',torch.ones(10,10))self.register_buffer('bias',torch.zeros(10))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()zprint('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

在这里插入图片描述
我们可以看到:

  • 通过register_buffer定义的参数weight和bias,它是没有被named_parameter给迭代出来的,也就是说weight和bias不是网络的可学习参数,无法通过优化器来迭代更新,我们把它叫做buffer,而不是参数
  • 然而我们使用net.state_dict去迭代的话,weight和bias事可以被迭代出来的,这就说明使用register_buffer定义的数据,可以保持到模型或者权重文件中。

注意:

  • 在使用register_parameter定义参数时,必须定义为可学习的参数,因此需要通过torch.nn.Parameter去定义为一个可学习的参数
  • 而我们使用register_buffer定义参数时,是不需要通过torch.nn.Parameter去定义为可学习的参数的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/130170.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ElasticSearch使用

Java API操作ES 相关依赖&#xff1a; <dependencies><!-- ES的高阶的客户端API --><dependency><groupId>org.elasticsearch.client</groupId><artifactId>elasticsearch-rest-high-level-client</artifactId><version>7.6…

软件测试之BUG篇(定义,创建,等级,生命周期)

目录 1. BUG 的定义 2. 如何创建 BUG 3. BUG 等级 4. BUG 生命周期 高频面试题&#xff1a; 1. BUG 的定义 当且仅当产品规格书存在且正确时&#xff0c;程序的实现和规格书的要求不匹配时&#xff0c;那就是软件错误。当产品规格说明书没有提到的功能时&#xff0c;以用户…

国家统计局教育部各级各类学历教育学生情况数据爬取

教育部数据爬取 1、数据来源2、爬取目标3、网页分析4、爬取与解析5、如何使用Excel打开CSV1、数据来源 国家统计局:http://www.stats.gov.cn/sj/ 教育部:http://www.moe.gov.cn/jyb_sjzl/ 数据来源:国家统计局教育部文献教育统计数据2021年全国基本情况(各级各类学历教育学…

编写shell脚本,利用mysqldump实现MySQL数据库分库分表备份

查看数据和数据表 mysql -uroot -p123456 -e show databases mysql -uroot -p123456 -e show tables from cb_d 删除头部Database和数据库自带的表 mysql -uroot -p123456 -e show databases -N | egrep -v "information_schema|mysql|performance_schema|sys"编写…

HTML和CSS的基础-前端扫盲

想要写出一个网页&#xff0c;就需要学习前端开发&#xff08;写网页代码&#xff09;和后端开发&#xff08;服务器代码&#xff09;。 对于前端的要求&#xff0c;我们不需要了解很深&#xff0c;仅仅需要做到扫盲的程度就可以了。 写前端&#xff0c;主要用到的有&#xf…

蓝鹏测控测宽仪系列又添一员大将——双目测宽仪

轧钢过程中钢板的宽度是一个重要的参数&#xff0c;它直接决定了成材率。同时&#xff0c;随着高新科技越来越广泛的应用到工程实际中&#xff0c;许多控制系统需要钢板实时宽度值作为模型参数。 当前&#xff0c;相当一部分宽厚板厂还在采用人工检测的方法&#xff0c;检测环境…

代码随想录算法训练营第23期day42|1049. 最后一块石头的重量II、494. 目标和、474.一和零

目录 一、&#xff08;leetcode 1049&#xff09;最后一块石头的重量II 二、&#xff08;leetcode 494&#xff09;目标和 三、&#xff08;leetcode 474&#xff09;一和零 一、&#xff08;leetcode 1049&#xff09;最后一块石头的重量II 力扣题目链接 状态&#xff1a;…

【漏洞复现】Drupal XSS漏洞复现

感谢互联网提供分享知识与智慧&#xff0c;在法治的社会里&#xff0c;请遵守有关法律法规 复现环境&#xff1a;Vulhub 环境启动后&#xff0c;访问 http://192.168.80.141:8080/ 将会看到drupal的安装页面&#xff0c;一路默认配置下一步安装。因为没有mysql环境&#xff0c;…

Mac下使用nvm,执行微信小程序自定义处理命令失败

环境 系统&#xff1a;Mac OS 终端&#xff1a;zsh CPU&#xff1a;M1/ARM架构 node环境&#xff1a;nvm&#xff0c;node20 node目录&#xff1a;/Users/laoxu/.nvm/versions/node/v20.1.0/bin/ 问题 在使用微信小程序的自定义处理命令时&#xff0c;启动失败 提示找不…

【音视频 | Ogg】libogg库详细介绍以及使用——附带libogg库解析.opus文件的C源码

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

全国大学生GIS应用技能大赛2023-12

一、题目背景 为了计算不同高程区间范围内流域的面积&#xff0c;要求根据提供的DEM数据&#xff0c;按照要求&#xff0c;计算不同高程区间范围内流域的面积。 二、数据说明 1、DEM&#xff1a;某地区的数字高程模型&#xff1b; 三、题目要求 根据提供的数字高程模型&am…

算法随想录算法训练营第四十九天| 503.下一个更大元素II 42. 接雨水

503.下一个更大元素II 题目&#xff1a;给定一个循环数组 nums &#xff08; nums[nums.length - 1] 的下一个元素是 nums[0] &#xff09;&#xff0c;返回 nums 中每个元素的 下一个更大元素 。数字 x 的 下一个更大的元素 是按数组遍历顺序&#xff0c;这个数字之后的第一个…

layer.open再次渲染html,子页面调用在父页面打开弹出层,渲染html

使用的版本 layui-v2.5.6是在父页面弹出层&#xff0c;显示&#xff1b;调用的是父页面的layer.open(); 父页面&#xff1a; <link href"/layui/css/layui.css" rel"stylesheet" /> <script src"/layui/layui.all.js"></script…

算法题:16. 最接近的三数之和(Python Java 详解)

解题思路 Step1&#xff1a;先对数组排序&#xff0c;然后设置3个指针&#xff0c;指针1遍历范围为&#xff08;0~数组长度减2&#xff09;。 Step2&#xff1a;指针1位置确定时&#xff0c;指针1后面的数组元素首位各放置一个指针&#xff08;指针2、指针3&#xff09;。 S…

项目中用到的git指令合集

目录 前言一、删除分支本地远程 二、不小心删除未合并成功的分支总结 前言 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; 做了一个git的常用指令合集&#xff0c;包含具体场景介绍 提示&#xff1a;以下是本篇文章正文内容&#xff0c;下面案例可供参考 一…

Python笔记——linux/ubuntu下安装mamba,安装bob.learn库

Python笔记——linux/ubuntu下安装mamba&#xff0c;安装bob.learn库 一、安装/卸载anaconda二、安装mamba1. 命令行安装&#xff08;大坑&#xff0c;不推荐&#xff09;2. 命令行下载guihub上的安装包并安装&#xff08;推荐&#xff09;3. 网站下载安装包并安装&#xff08;…

ubuntu外接显示器、不识别笔记本显示器

如题&#xff1a;ubuntu外接显示器、不识别笔记本显示器 双屏幕&#xff0c;笔记本外接显示器HDMI&#xff0c;然后安装Nvidia显卡驱动&#xff0c;之后重启笔记本显示器无法识别&#xff0c;只能使用外接显示器了。 中文网站找遍了都没有解决方案&#xff0c;然后用英文搜索&a…

电路正负反馈,电压电流反馈,串并联反馈详细判别方法

正/负反馈&#xff1a;假设输出升高&#xff0c;转一圈回来仍使其升高就是正反馈&#xff0c;反之就是负反馈。作图法&#xff1a;在RL的信号端画一个向上的小箭头&#xff0c;沿着反馈环路&#xff0c;每经过一个元器件就画一个相应的箭头&#xff0c;一直画到放大器的输出端&…

用C++QT实现一个modbus rtu通讯程序框架

下面是一个简单的Modbus RTU通讯程序框架的示例&#xff0c;使用C和QT来实现&#xff1a; #include <QCoreApplication> #include <QSerialPort> #include <QModbusDataUnit> #include <QModbusRtuSerialMaster>int main(int argc, char *argv[]) {QC…

代理模式(静态代理、JDK代理、CGLIB代理)

简介 代理模式有三种不同的形式&#xff1a;静态代理、动态代理&#xff08;JDK代理、接口代理&#xff09;、CGLIB代理 目标&#xff1a;在不修改目标对象的前提下&#xff0c;对目标对象进行扩展。 静态代理 需要定义接口或父类对象&#xff0c;被代理对象和代理对象通过实…