register_parameter和register_buffer 详解

在参考yolo系列代码或其他开源代码,经常看到register_buffer register_parameter的使用,接下来将详细对他们进行介绍。

1. 前沿

在搭建网络时,我们 自定义的参数,往往不会保存到模型权重文件中,或者成为模型可学习的参数。即我们通过 net.named_parameters() (模型可学习参数)或 net.state_dict().items()(保存模型权重值)方法都无法遍历输出。那如何解决呢,这就需要用到本文讲的register_parameterregister_buffer方法。

2. register_parameter

register_parameter() 是 torch.nn.Module 类中的一个方法。

2.1 主要作用

  • 用于定义可学习参数
  • 定义的参数可被保存到网络对象的参数中,可使用 net.parameters()net.named_parameters() 查看
  • 定义的参数可用 net.state_dict() 转换到字典中,进而 保存到网络文件 / 网络参数文件

2.2 函数说明

register_parameter(name,param)

参数:

  • name:参数名称

  • param:参数张量, 须是torch.nn.Parameter()对象 或 None ,否则报错如下
    TypeError: cannot assign 'torch.FloatTensor' object to parameter 'xx' (torch.nn.Parameter or None required)

2.3 举例说明

(1)自定义的参数未使用register_parameter

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.weight = torch.ones(10,10)self.bias = torch.zeros(10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()print('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

输出:
在这里插入图片描述
在网络搭建的代码中,我们自定义了self.weightself.bias参数。我们思考下2个问题:1. 我们定义的self.weightself.bias参数是否会保存到网络的参数中,是否能在优化器的作用下进行学习。2. 这些参数是否能够保存到模型文件中,从而可以利用state_dict中遍历出来。通过上面的打印信息我们发现:

  • 使用net.named_parameters()迭代网络中可学习的参数,发现输出的参数只有conv1conv2的weight参数,并没有输出我们定义的self.weightself.bias
  • 接下来使用net.state_dict()方法迭代保存的参数,同样发现self.weightself.bias参数也没有被输出出来。

(2)通过register_parameter方法来定义参数

  • 接下来我们使用register_parameter来定义weight和bias参数,看看会有啥效果。代码修改如下:
self.register_parameter('weight',torch.nn.Parameter(torch.ones(10,10)))
self.register_parameter('bias',torch.nn.Parameter(torch.zeros(10)))

完整代码

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_parameter('weight',torch.nn.Parameter(torch.ones(10,10)))self.register_parameter('bias',torch.nn.Parameter(torch.zeros(10)))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()print('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

在这里插入图片描述

  • 可以看到,使用了register_parameter定义的参数weight和bias,可以通过net.named_parameters或者net.parameters迭代出来的,这说明weight和bias已经存到了网络的参数中,他们是可学习的参数
  • 同时,通过state_dict()也能将参数和值给迭代出来,就说明如果要保存模型权重或网络参数时,这两个参数时可以被保存起来的。

3 register_buffer()

register_buffer()是 torch.nn.Module() 类中的一个方法

3.1 作用

  • 用于定义不可学习的参数
  • 定义的参数不会被保存到网络对象的参数中,使用 net.parameters() 或 net.named_parameters() 查看不到
  • 定义的参数可用 net.state_dict() 转换到字典中,进而 保存到网络文件 / 网络参数文件中

register_buffer() 用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数),它与参数的区别为:

  • 参数:可以被优化器更新 (requires_grad=False / True)

  • buffer 中的数据 (不可学习): 不会被优化器更新

3.2、举例说明

将定义的weight和bias,通过register_buffer来定义。

self.register_buffer('weight',torch.ones(10,10))
self.register_buffer('bias',torch.zeros(10))

运行完整代码看看效果:

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_buffer('weight',torch.ones(10,10))self.register_buffer('bias',torch.zeros(10))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()zprint('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

在这里插入图片描述
我们可以看到:

  • 通过register_buffer定义的参数weight和bias,它是没有被named_parameter给迭代出来的,也就是说weight和bias不是网络的可学习参数,无法通过优化器来迭代更新,我们把它叫做buffer,而不是参数
  • 然而我们使用net.state_dict去迭代的话,weight和bias事可以被迭代出来的,这就说明使用register_buffer定义的数据,可以保持到模型或者权重文件中。

注意:

  • 在使用register_parameter定义参数时,必须定义为可学习的参数,因此需要通过torch.nn.Parameter去定义为一个可学习的参数
  • 而我们使用register_buffer定义参数时,是不需要通过torch.nn.Parameter去定义为可学习的参数的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/130170.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

软件测试之BUG篇(定义,创建,等级,生命周期)

目录 1. BUG 的定义 2. 如何创建 BUG 3. BUG 等级 4. BUG 生命周期 高频面试题: 1. BUG 的定义 当且仅当产品规格书存在且正确时,程序的实现和规格书的要求不匹配时,那就是软件错误。当产品规格说明书没有提到的功能时,以用户…

国家统计局教育部各级各类学历教育学生情况数据爬取

教育部数据爬取 1、数据来源2、爬取目标3、网页分析4、爬取与解析5、如何使用Excel打开CSV1、数据来源 国家统计局:http://www.stats.gov.cn/sj/ 教育部:http://www.moe.gov.cn/jyb_sjzl/ 数据来源:国家统计局教育部文献教育统计数据2021年全国基本情况(各级各类学历教育学…

编写shell脚本,利用mysqldump实现MySQL数据库分库分表备份

查看数据和数据表 mysql -uroot -p123456 -e show databases mysql -uroot -p123456 -e show tables from cb_d 删除头部Database和数据库自带的表 mysql -uroot -p123456 -e show databases -N | egrep -v "information_schema|mysql|performance_schema|sys"编写…

HTML和CSS的基础-前端扫盲

想要写出一个网页,就需要学习前端开发(写网页代码)和后端开发(服务器代码)。 对于前端的要求,我们不需要了解很深,仅仅需要做到扫盲的程度就可以了。 写前端,主要用到的有&#xf…

蓝鹏测控测宽仪系列又添一员大将——双目测宽仪

轧钢过程中钢板的宽度是一个重要的参数,它直接决定了成材率。同时,随着高新科技越来越广泛的应用到工程实际中,许多控制系统需要钢板实时宽度值作为模型参数。 当前,相当一部分宽厚板厂还在采用人工检测的方法,检测环境…

【漏洞复现】Drupal XSS漏洞复现

感谢互联网提供分享知识与智慧,在法治的社会里,请遵守有关法律法规 复现环境:Vulhub 环境启动后,访问 http://192.168.80.141:8080/ 将会看到drupal的安装页面,一路默认配置下一步安装。因为没有mysql环境,…

Mac下使用nvm,执行微信小程序自定义处理命令失败

环境 系统:Mac OS 终端:zsh CPU:M1/ARM架构 node环境:nvm,node20 node目录:/Users/laoxu/.nvm/versions/node/v20.1.0/bin/ 问题 在使用微信小程序的自定义处理命令时,启动失败 提示找不…

【音视频 | Ogg】libogg库详细介绍以及使用——附带libogg库解析.opus文件的C源码

😁博客主页😁:🚀https://blog.csdn.net/wkd_007🚀 🤑博客内容🤑:🍭嵌入式开发、Linux、C语言、C、数据结构、音视频🍭 🤣本文内容🤣&a…

算法随想录算法训练营第四十九天| 503.下一个更大元素II 42. 接雨水

503.下一个更大元素II 题目:给定一个循环数组 nums ( nums[nums.length - 1] 的下一个元素是 nums[0] ),返回 nums 中每个元素的 下一个更大元素 。数字 x 的 下一个更大的元素 是按数组遍历顺序,这个数字之后的第一个…

算法题:16. 最接近的三数之和(Python Java 详解)

解题思路 Step1:先对数组排序,然后设置3个指针,指针1遍历范围为(0~数组长度减2)。 Step2:指针1位置确定时,指针1后面的数组元素首位各放置一个指针(指针2、指针3)。 S…

Python笔记——linux/ubuntu下安装mamba,安装bob.learn库

Python笔记——linux/ubuntu下安装mamba,安装bob.learn库 一、安装/卸载anaconda二、安装mamba1. 命令行安装(大坑,不推荐)2. 命令行下载guihub上的安装包并安装(推荐)3. 网站下载安装包并安装(…

电路正负反馈,电压电流反馈,串并联反馈详细判别方法

正/负反馈:假设输出升高,转一圈回来仍使其升高就是正反馈,反之就是负反馈。作图法:在RL的信号端画一个向上的小箭头,沿着反馈环路,每经过一个元器件就画一个相应的箭头,一直画到放大器的输出端&…

代理模式(静态代理、JDK代理、CGLIB代理)

简介 代理模式有三种不同的形式:静态代理、动态代理(JDK代理、接口代理)、CGLIB代理 目标:在不修改目标对象的前提下,对目标对象进行扩展。 静态代理 需要定义接口或父类对象,被代理对象和代理对象通过实…

asp.net docker-compose添加dapr配置

docker-compose.yml添加配置 webapplication1-dapr:image: "daprio/daprd:1.9.6"network_mode: "service:webapplication1"depends_on:- webapplication1 docker-compose.override.yml中添加 dapr-placement:command: ["./placement", "-po…

【数据结构】顺序表的学习

前言:在之前我们学习了C语言的各种各样的语法,因此我们今天开始学习数据结构这一个模块,因此我们就从第一个部分来开始学习"顺序表"。 💖 博主CSDN主页:卫卫卫的个人主页 💞 👉 专栏分类:数据结构 &#x1f…

windows和docker环境下springboot整合gdal3.x

链接: gdal官网地址 gdal gdal的一个用c语言编写的库,用于处理地理信息相关的数据包括转换,识别数据,格式化数据以及解析 同时提供第三方语言的SDK包括python,java上述需要编译后使用 java是需要使用jni接口调用实现方法在wind…

mysql---存储引擎

目录 mysql---存储引擎 功能: mysql的存储引擎分类 MYISAM和INNODB做个对比 MYISAM 在磁盘上有三个文件: MYISAM的特点: 支持的存储格式: INNODB innodb的特点 使用场景: 三个文件: 行锁 表锁 排他锁 …

uniapp原生插件之安卓文件操作原生插件

插件介绍 安卓文件操作原生插件,读写文件,文件下载等,支持读取移动设备路径等外部存储设备路径,如U盘路径 插件地址 安卓文件操作原生插件 - DCloud 插件市场 超级福利 uniapp 插件购买超级福利 详细使用文档 uniapp 安卓文…

互联网医院|湖南互联网医院|智慧医疗改善就医服务

互联网医院系统,是指利用互联网技术和远程医疗技术,提供在线就诊、咨询、诊断和治疗等医疗服务的一种医疗模式。互联网医院系统实际上与医院的HIS系统很相似,是侧重服务于线上问诊的专业HIS,包含传统HIS的基本模块,如挂…

VS Code 开发 Spring Boot 类型的项目

在VS Code中开发Spring Boot的项目, 可以导入如下的扩展: Spring Boot ToolsSpring InitializrSpring Boot Dashboard 比较建议的方式是安装Spring Boot Extension Pack, 这里面就包含了上面的扩展。 安装方式就是在扩展查找 “Spring Boot…