PyTorch参数管理详解:从访问到初始化与共享

本文通过实例代码讲解如何在PyTorch中管理神经网络参数,包括参数访问、多种初始化方法、自定义初始化以及参数绑定技术。所有代码可直接运行,适合深度学习初学者进阶学习。


1. 定义网络与参数访问

1.1 定义单隐藏层多层感知机

import torch
from torch import nn# 定义单隐藏层多层感知机
net1 = nn.Sequential(nn.Linear(4, 8),  # 输入层4维,隐藏层8维nn.ReLU(),nn.Linear(8, 1)   # 输出层1维
)
x = torch.rand(2, 4)  # 随机生成2个4维输入向量
net1(x)                # 前向传播

1.2 访问网络参数

# 访问第二层(索引2)的参数(权重和偏置)
print(net1[2].state_dict())# 查看参数类型、数据和梯度
print(type(net1[2].bias))    # 类型:Parameter
print(net1[2].bias)          # 参数值(含梯度信息)
print(net1[2].bias.data)     # 参数数据(张量)
print(net1[2].bias.grad)     # 梯度(未反向传播时为None)

1.3 批量访问参数

# 访问第一层的参数名称和形状
print(*[(name, param.shape) for name, param in net1[0].named_parameters()])# 访问整个网络的参数
print(*[(name, param.shape) for name, param in net1.named_parameters()])# 通过state_dict直接访问参数数据
print(net1.state_dict()['2.bias'].data)

2. 参数初始化方法

2.1 内置初始化

# 正态分布初始化权重,偏置置零
def init_normal(model):if isinstance(model, nn.Linear):nn.init.normal_(model.weight, mean=0, std=0.01)nn.init.zeros_(model.bias)net1.apply(init_normal)
print(net1[0].weight.data[0], net1[0].bias.data[0])# 常数初始化(权重为1,偏置为0)
def init_constant(model):if isinstance(model, nn.Linear):nn.init.constant_(model.weight, 1)nn.init.zeros_(model.bias)net1.apply(init_constant)
print(net1[0].weight.data[0], net1[0].bias.data[0])

2.2 分层初始化

# 对第一层使用Xavier初始化,第二层使用常数42初始化
def xavier(model):if isinstance(model, nn.Linear):nn.init.xavier_uniform_(model.weight)def init_42(model):if isinstance(model, nn.Linear):nn.init.constant_(model.weight, 42)net1[0].apply(xavier)
net1[2].apply(init_42)
print(net1[0].weight.data[0])
print(net1[2].weight.data)

2.3 自定义初始化

# 自定义初始化:权重在[-10,10]均匀分布,并过滤绝对值小于5的值
def my_init(model):if isinstance(model, nn.Linear):print(f'init weight {model.weight.shape}')nn.init.uniform_(model.weight, -10, 10)model.weight.data *= (model.weight.abs() >= 5)net1.apply(my_init)
print(net1[0].weight.data[:2])  # 显示前两行权重

3. 参数绑定与共享

3.1 直接修改参数

# 直接操作参数数据
net1[0].weight.data[:] += 1     # 所有权重+1
net1[0].weight.data[0, 0] = 42  # 修改特定位置权重
print(net1[0].weight.data[0])   # 输出第一行权重

3.2 参数共享

# 共享线性层参数
shared_layer = nn.Linear(8, 8)
net3 = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),shared_layer, nn.ReLU(),     # 第2层shared_layer, nn.ReLU(),     # 第4层(共享参数)nn.Linear(8, 1)
)# 验证参数共享
print(net3[2].weight.data[0] == net3[4].weight.data[0])  # 输出全True
net3[2].weight.data[0, 0] = 100
print(net3[2].weight.data[0] == net3[4].weight.data[0])  # 修改后仍为True

4. 嵌套网络结构

# 构建嵌套网络
def model1():return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),nn.Linear(8, 4), nn.ReLU())def model2():net = nn.Sequential()for i in range(4):net.add_module(f'model{i}', model1())return netrgnet = nn.Sequential(model2(), nn.Linear(4, 1))
print(rgnet)  # 打印网络结构

总结

本文演示了PyTorch中参数管理的核心操作,包括:

  • 通过state_dictnamed_parameters访问参数

  • 使用内置初始化方法(正态分布、常数、Xavier)

  • 自定义初始化逻辑

  • 参数的直接修改与共享

  • 复杂嵌套网络的定义

掌握这些技能可以更灵活地设计和优化神经网络模型。建议读者在实践中结合具体任务调整初始化策略,并注意参数共享时的梯度传播特性。


提示:以上代码需要在PyTorch环境中运行,建议使用Jupyter Notebook逐步调试以观察中间结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/77248.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于springboot+vue的课程管理系统

一、系统架构 前端:vue | element-ui 后端:springboot | mybatis-plus 环境:jdk1.8 | mysql8 | maven | node v16.20.2 | idea 二、代码及数据 三、功能介绍 01. 登录 02. 管理员-首页 03. 管理员-系管理 04. 管理员-专业管理 05. 管…

ssh密钥连接远程服务器并用scp传输文件

ssh密钥连接远程服务器 私钥的权限必须是600chmod 600 id_rsa连接时在命令中加上私钥的地址ssh -i PATH_to_id_rsa usernameip -p port scp -P port -i PATH_to_id_rsa file usernameip:PATH

ElasticSearch迁移数据

一、查询索引 1、查询所有索引 curl --user elastic:123456 -XGET "http://localhost:19200/_cat/indices?v&sindex" 2、查询索引配置 以索引名称hello为例 curl --user elastic:123456 -XGET "http://localhost:19200/hello/_settings?pretty" 3…

【Unity】animator检测某state动画播放完毕方法

博主对动画系统很不熟,可能使用的方法比较曲折,但是我确实没找到更有效的方法了。 unity的这个animator在我看来简直有毛病啊,为什么那么难以获取某状态动画的信息呢??? 想要知道动画播完没有只有用norma…

Jmeter 插件【性能测试监控搭建】

1. 安装Plugins Manager 1.1 下载路径: Install :: JMeter-Plugins.org 1.2 放在lib/ext目录下 1.3 重启Jmeter,会在菜单-选项下多一个 Plugins Manager菜单,打开即可对插件进行安装、升级。 2. 客户端(Jmeter端) 2.1 安装plugins manager…

ollama+open-webui本地部署自己的模型到d盘+两种open-webui部署方式(详细步骤+大量贴图)

一、ollama准备 1.官网下载ollama:https://ollama.com/download 2.在 d 盘创建 ollama 文件夹,把软件包放进去 3.管理员身份运行黑窗口 win r 弹出运行窗口 输入 cmd 后, ctrl shift 回车,以管理员身份打开 3.切换到 d 盘&a…

(学习总结33)Linux Ext2 文件系统与软硬链接

Linux Ext2 文件系统与软硬链接 理解硬件磁盘、服务器、机柜、机房磁盘物理结构磁盘的逻辑结构实际过程 CHS 与 LBA 地址转换 引入文件系统引入 " 块 " 概念引入 " 分区 " 概念引入 " inode " 概念 ext2 文件系统宏观认识Block Group 块组与其内…

Go语言sync.Mutex包源码解读

互斥锁sync.Mutex是在并发程序中对共享资源进行访问控制的主要手段,对此Go语言提供了非常简单易用的机制。sync.Mutex为结构体类型,对外暴露Lock()、Unlock()、TryLock()三种方法,分别用于阻塞加锁、解锁、非阻塞加锁操作(加锁失败…

SQL注入流量分析

免责声明:本文仅作分享 ~ 目录 SQL注入流量分析 特征: sqlmap注入类型 漏洞环境搭建 error_sql: bool_sql: time_sql: union_sql: Stacked Queries: Inline Queries: SQL注入流量分析 https://www.freebuf.com/column/161797.html SQLMAP攻击…

Linux 时间同步工具 Chrony 简介与使用

一、Chrony 是什么? chrony 是一个开源的网络时间同步工具,主要由两个组件组成: chronyd:后台服务进程,负责与时间服务器交互,同步系统时钟。chronyc:命令行工具,用于手动查看或修…

Flutter:Flutter SDK版本控制,fvm安装使用

1、首先已经安装了Dart,cmd中执行 dart pub global activate fvm2、windows配置系统环境变量 fvm --version3、查看本地已安装的 Flutter 版本 fvm releases4、验证当前使用的 Flutter 版本: fvm flutter --version5、切换到特定版本的 Flutter fvm use …

Vue 项目中的package.json各部分的作用和用法的详细说明

1. 基本信息 {"name": "my-vue-app","version": "1.0.0","description": "A Vue.js project","author": "Your Name <your.emailexample.com>","license": "MIT"…

Linux网络编程——TCP通信的四次挥手

一、前言 上篇文章讲到了TCP通信建立连接的“三次握手”的一些细节&#xff0c;本文再对TCP通信断开连接的“四次挥手”的过程做一些分析了解。 二、TCP断开连接的“四次挥手” 我们知道TCP在建立连接的时需要“三次握手”&#xff0c;三次握手完后就可以进行通信了。而在通…

某碰瓷国赛美赛,号称第三赛事的数模竞赛

首先我非常不能理解的就是怎么好意思自称第三赛事的呢&#xff1f;下面我们进行一个简单讨论&#xff0c;当然这里不对国赛和美赛进行讨论。首先我们来明确一点&#xff0c;比赛的含金量由什么来定&#xff1f;这个可能大家的评价指标可能不唯一&#xff0c;我通过DeepSeek选取…

Redis 缓存问题:缓存雪崩、缓存击穿、缓存穿透

文章目录 缓存雪崩缓存击穿缓存穿透在实际的业务场景中,Redis 通常作为缓存和其他数据库(例如 MySQL)搭配使用,用来减轻数据库的压力。但是在使用 Redis 作为缓存数据库的过程中,可能会遇到一些常见问题,例如缓存穿透、缓存击穿和缓存雪崩等。 缓存雪崩 缓存雪崩是指缓存…

Qt 入门 4 之标准对话框

Qt 入门 4 之标准对话框 Qt提供了一些常用的对话框类型,它们全部继承自QDialog类,并增加了自己的特色功能,比如获取颜色、显示特定信息等。下面简单讲解这些对话框,可以在帮助索引中查看Standard Dialogs关键字,也可以直接索引相关类的类名。 本文将以一个新的项目为主介绍不…

买不起了,iPhone 或涨价 40% ?

周知的原因&#xff0c;新关税对 iPhone 的打击&#xff0c;可以说非常严重。 根据 Rosenblatt Securities分析师的预测&#xff0c;若苹果完全把成本转移给消费者。 iPhone 16 标配版的价格&#xff0c;可能上涨43%。 iPhone 16 标配的价格是799美元&#xff0c;上涨43%&am…

软件需求分析习题汇编

需求工程练习题 一、选择题 1. 软件需求规格说明书的内容不应包括对&#xff08; &#xff09;的描述。 A. 主要功能B. 算法的详细过程C. 用户界面及运行环境D. 软件的性能 *正确答案:*B:算法的详细过程; 2. 需求分析最终结果是产生&#xff08; &#xff09; A. 项目开发…

clickhouse注入手法总结

clickhouse 遇到一题clickhouse注入相关的&#xff0c;没有见过&#xff0c;于是来学习clickhouse的使用&#xff0c;并总结相关注入手法。 环境搭建 直接在docker运行 docker pull clickhouse/clickhouse-server docker run -d --name some-clickhouse-server --ulimit n…

智能语音识别工具开发手记

智能语音识别工具开发手记 序言&#xff1a;听见数字化的声音 在县级融媒体中心的日常工作中&#xff0c;我们每天需要处理大量音频素材——从田间地头的采访录音到演播室的节目原声&#xff0c;从紧急会议记录到专题报道素材。二十多年前&#xff0c;笔者刚入职时&#xff0…