PyTorch中的torch.nn.Parameter() 详解

PyTorch中的torch.nn.Parameter() 详解

今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实现原理细节也是云里雾里,在参考了几篇博文,做过几个实验之后算是清晰了,本文在记录的同时希望给后来人一个参考,欢迎留言讨论。

分析

先看其名,parameter,中文意为参数。我们知道,使用PyTorch训练神经网络时,本质上就是训练一个函数,这个函数输入一个数据(如CV中输入一张图像),输出一个预测(如输出这张图像中的物体是属于什么类别)。而在我们给定这个函数的结构(如卷积、全连接等)之后,能学习的就是这个函数的参数了,我们设计一个损失函数,配合梯度下降法,使得我们学习到的函数(神经网络)能够尽量准确地完成预测任务。

通常,我们的参数都是一些常见的结构(卷积、全连接等)里面的计算参数。而当我们的网络有一些其他的设计时,会需要一些额外的参数同样很着整个网络的训练进行学习更新,最后得到最优的值,经典的例子有注意力机制中的权重参数、Vision Transformer中的class token和positional embedding等。

而这里的torch.nn.Parameter()就可以很好地适应这种应用场景。

下面是这篇博客的一个总结,笔者认为讲的比较明白,在这里引用一下:

首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

ViT中nn.Parameter()的实验

看过这个分析后,我们再看一下Vision Transformer中的用法:

...self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
...

我们知道在ViT中,positonal embedding和class token是两个需要随着网络训练学习的参数,但是它们又不属于FC、MLP、MSA等运算的参数,在这时,就可以用nn.Parameter()来将这个随机初始化的Tensor注册为可学习的参数Parameter。

为了确定这两个参数确实是被添加到了net.Parameters()内,笔者稍微改动源码,显式地指定这两个参数的初始数值为0.98,并打印迭代器net.Parameters()。

...self.pos_embedding = nn.Parameter(torch.ones(1, num_patches+1, dim) * 0.98)
self.cls_token = nn.Parameter(torch.ones(1, 1, dim) * 0.98)
...

实例化一个ViT模型并打印net.Parameters():

net_vit = ViT(image_size = 256,patch_size = 32,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1)for para in net_vit.parameters():print(para.data)

输出结果中可以看到,最前两行就是我们显式指定为0.98的两个参数pos_embedding和cls_token:

tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],...,[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[-0.0026, -0.0064,  0.0111,  ...,  0.0091, -0.0041, -0.0060],[ 0.0003,  0.0115,  0.0059,  ..., -0.0052, -0.0056,  0.0010],[ 0.0079,  0.0016, -0.0094,  ...,  0.0174,  0.0065,  0.0001],...,[-0.0110, -0.0137,  0.0102,  ...,  0.0145, -0.0105, -0.0167],[-0.0116, -0.0147,  0.0030,  ...,  0.0087,  0.0022,  0.0108],[-0.0079,  0.0033, -0.0087,  ..., -0.0174,  0.0103,  0.0021]])
...
...

这就可以确定nn.Parameter()添加的参数确实是被添加到了Parameters列表中,会被送入优化器中随训练一起学习更新。

from torch.optim import Adam
opt = Adam(net_vit.parameters(), learning_rate=0.001)

其他解释

以下是国外StackOverflow的一个大佬的解读,笔者自行翻译并放在这里供大家参考,想查看原文的同学请戳这里。

我们知道Tensor相当于是一个高维度的矩阵,它是Variable类的子类。Variable和Parameter之间的差异体现在与Module关联时。当Parameter作为model的属性与module相关联时,它会被自动添加到Parameters列表中,并且可以使用net.Parameters()迭代器进行访问。
最初在Torch中,一个Variable(例如可以是某个中间state)也会在赋值时被添加为模型的Parameter。在某些实例中,需要缓存变量,而不是将它们添加到Parameters列表中。
文档中提到的一种情况是RNN,在这种情况下,您需要保存最后一个hidden state,这样就不必一次又一次地传递它。需要缓存一个Variable,而不是让它自动注册为模型的Parameter,这就是为什么我们有一个显式的方法将参数注册到我们的模型,即nn.Parameter类。

举个例子:

import torch
import torch.nn as nn
from torch.optim import Adamclass NN_Network(nn.Module):def __init__(self,in_dim,hid,out_dim):super(NN_Network, self).__init__()self.linear1 = nn.Linear(in_dim,hid)self.linear2 = nn.Linear(hid,out_dim)self.linear1.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))self.linear1.bias = torch.nn.Parameter(torch.ones(hid))self.linear2.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))self.linear2.bias = torch.nn.Parameter(torch.ones(hid))def forward(self, input_array):h = self.linear1(input_array)y_pred = self.linear2(h)return y_predin_d = 5
hidn = 2
out_d = 3
net = NN_Network(in_d, hidn, out_d)

然后检查一下这个模型的Parameters列表:

for param in net.parameters():print(type(param.data), param.size())""" Output
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
"""

可以轻易地送入到优化器中:

opt = Adam(net.parameters(), learning_rate=0.001)

另外,请注意Parameter的require_grad会自动设定。

各位读者有疑惑或异议的地方,欢迎留言讨论。

参考:

https://www.jianshu.com/p/d8b77cc02410

https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532358.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vision Transformer(ViT)PyTorch代码全解析(附图解)

Vision Transformer&#xff08;ViT&#xff09;PyTorch代码全解析 最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来&#xff0c;屠杀了各大CV榜单。本文将根据最原始的Vision Transformer论文&#xff0c;及其PyTorch实现&#xff0c;将整个ViT的代码做一…

hdfs的副本数为啥增加了_HDFS详解之块大小和副本数

1.HDFSHDFS : 伪分布式(学习)NNDNSNNsbin/start-dfs.sh(开启hdfs使用的脚本)bin/hdfs dfs -ls (输入命令加前缀bin/hdfs dfs)2.block(块)dfs.blocksize &#xff1a; 134217728(字节) / 128M 官网默认一个块的大小128M*举例理解块1个文件 130M&#xff0c;默认一个块的大小128M…

Linux下的ELF文件、链接、加载与库(含大量图文解析及例程)

Linux下的ELF文件、链接、加载与库 链接是将将各种代码和数据片段收集并组合为一个单一文件的过程&#xff0c;这个文件可以被加载到内存并执行。链接可以执行与编译时&#xff0c;也就是在源代码被翻译成机器代码时&#xff1b;也可以执行于加载时&#xff0c;也就是被加载器加…

mysql gender_Mysql第一弹

1、创建数据库pythoncreate database python charsetutf8;2、设计班级表结构为id、name、isdelete&#xff0c;编写创建表的语句create table classes(id int unsigned auto_increment primary key not null,name varchar(10),isdelete bit default 0);向班级表中插入数据pytho…

python virtualenv nginx_Ubuntu下搭建Nginx+supervisor+pypy+virtualenv

系统&#xff1a;Ubuntu 14.04 LTS搭建python的运行环境&#xff1a;NginxSupervisorPypyVirtualenv软件说明&#xff1a;Nginx&#xff1a;通过upstream进行负载均衡Supervisor&#xff1a;管理python进程Pypy&#xff1a;用Python实现的Python解释器PyPy is a fast, complian…

如何设置mysql表中文乱码_php mysql表中文乱码问题如何解决

为避免mysql中出现中文乱码&#xff0c;建议在创建数据库时指定编码格式&#xff1a;复制代码 代码示例:create database zzjz CHARACTER SET gbk COLLATE gbk_chinese_ci;create table zz_employees (employeeid int unsigned not null auto_increment primary key,name varch…

java 按钮 监听_Button的四种监听方式

Button按钮设置点击的四种监听方式注&#xff1a;加粗放大的都是改变的代码1.使用匿名内部类的形式进行设置使用匿名内部类的形式&#xff0c;直接将需要设置的onClickListener接口对象初始化&#xff0c;内部的onClick方法会在按钮被点击的时候执行第一个活动的java代码&#…

java int转bitmap_Java Base64位编码与String字符串的相互转换,Base64与Bitmap的相互转换实例代码...

首先是网上大神给的类package com.duanlian.daimengmusic.utils;public final class Base64Util {private static final int BASELENGTH 128;private static final int LOOKUPLENGTH 64;private static final int TWENTYFOURBITGROUP 24;private static final int EIGHTBIT …

linux查看java虚拟机内存_深入理解java虚拟机(linux与jvm内存关系)

本文转载自美团技术团队发表的同名文章https://tech.meituan.com/linux-jvm-memory.html一, linux与进程内存模型要理解jvm最重要的一点是要知道jvm只是linux的一个进程,把jvm的视野放大,就能很好的理解JVM细分的一些概念下图给出了硬件系统进程三个层面内存之间的关系.从硬件上…

java 循环stringbuffer_java常用类-----StringBuilder和StringBuffer的用法

一、可变字符常用方法package cn.zxg.PackgeUse;/*** 测试StringBuilder,StringBuffer可变字符序列常用方法*/public class TestStringBuilder2 {public static void main(String[] args) {StringBuilder sbnew StringBuilder();for(int i0;i<26;i){char temp(char)(ai);sb.…

java function void_Java8中你可能不知道的一些地方之函数式接口实战

什么时候可以使用 Lambda&#xff1f;通常 Lambda 表达式是用在函数式接口上使用的。从 Java8 开始引入了函数式接口&#xff0c;其说明比较简单&#xff1a;函数式接口(Functional Interface)就是一个有且仅有一个抽象方法&#xff0c;但是可以有多个非抽象方法的接口。 java8…

java jvm内存地址_JVM--Java内存区域

Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同的数据区域&#xff0c;如图&#xff1a;1.程序计数器可以看作是当前线程所执行的字节码的行号指示器&#xff0c;通俗的讲就是用来指示执行哪条指令的。为了线程切换后能恢复到正确的执行位置Java多线程是…

java情人节_情人节写给女朋友Java Swing代码程序

马上又要到情人节了&#xff0c;再不解风情的人也得向女友表示表示。作为一个程序员&#xff0c;示爱的时候自然也要用我们自己的方式。这里给大家上传一段我在今年情人节的时候写给女朋友的一段简单的Java Swing代码&#xff0c;主要定义了一个对话框&#xff0c;让女友选择是…

java web filter链_filter过滤链:Filter链是如何构建的?

在一个Web应用程序中可以注册多个Filter程序&#xff0c;每个Filter程序都可以针对某一个URL进行拦截。如果多个Filter程序都对同一个URL进行拦截&#xff0c;那么这些Filter就会组成一个Filter链(也叫过滤器链)。Filter链用FilterChain对象来表示&#xff0c;FilterChain对象中…

java web 应用技术与案例教程_《Java Web应用开发技术与案例教程》怎么样_目录_pdf在线阅读 - 课课家教育...

出版说明前言第1章 java Web应用开发技术概述1.1 Java Web应用开发技术简介1.1.1 Java Web应用1.1.2 Java Web应用开发技术1.2 Java Web开发环境及开发工具1.2.1 JDK的下载与安装1.2.2 Tomcat服务器的安装和配置1.2.3 MyEclipse集成开发工具的安装与操作1.3 Java Web应用程序的…

java环境变量自动设置_自动设置Java环境变量

echo offSETLOCALENABLEDELAYEDEXPANSIONfor /f "tokens2* delims " %%i in(reg query "HKLM\Software\JavaSoft\Java Development Kit" /s ^|find /I"JavaHome") do (echo 找到目录 %%jset /p isOK该目录是不是JDK^(JavaDevelopment Kit^)的安装…

mysql运行状态监控研究内容_如何监控mysql主从的运行状态shell脚本实例介绍

如何监控mysql主从的运行状态shell脚本实例介绍。#!/bin/bash#define mysql variablemysql_user”root”mysql_pass”123456″email_addr”slavecentos.bz”mysql_statusnetstat -nl | awk ‘NR>2{if ($4 ~ /.*:3306/) {print “Yes”;exit 0}}’if [ "$mysql_status&q…

java 100% cpu_Java服务,CPU 100%问题如何快速定位?

Java服务&#xff0c;有时候会遇到CPU 100%的问题&#xff0c;对于这样的问题&#xff0c;我们如何快速定位并解决呢&#xff1f;一般会有如下三个步骤&#xff1a;1、找到最耗CPU的进程2、找到这个进程中最耗CPU的线程3、查看堆栈信息&#xff0c;定位线程的什么操作消耗了大量…

java 泛型 加_Java泛型并将数字加在一起

为了一般地计算总和,您需要提供两个动作&#xff1a;>一种总计零项的方法>一种总结两个项目的方法在Java中,您可以通过界面完成.这是一个完整的例子&#xff1a;import java.util.*;interface adder {T zero(); // Adding zero itemsT add(T lhs, T rhs); // Adding two …

java 字母金字塔_LeetCode756:金字塔转换矩阵(JAVA题解)

题目描述现在&#xff0c;我们用一些方块来堆砌一个金字塔。 每个方块用仅包含一个字母的字符串表示。使用三元组表示金字塔的堆砌规则如下&#xff1a;对于三元组(A, B, C) &#xff0c;“C”为顶层方块&#xff0c;方块“A”、“B”分别作为方块“C”下一层的的左、右子块。当…