【踩坑记录】pytorch 自定义嵌套网络时部分网络有梯度但参数不更新

问题描述

使用如下的自定义的多层嵌套网络进行训练:

class FC1_bot(nn.Module):def __init__(self):super(FC1_bot, self).__init__()self.embeddings = nn.Sequential(nn.Linear(10, 10))def forward(self, x):emb = self.embeddings(x)return embclass FC1_top(nn.Module):def __init__(self):super(FC1_top, self).__init__()self.prediction = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(10, 10))def forward(self, x):logit = self.prediction(x)return logitclass FC1(nn.Module):def __init__(self, num):super(FC1, self).__init__()self.num = numself.bot = []for _ in range(num):self.bot.append(FC1_bot())self.top = FC1_top()self.softmax = nn.Softmax(dim=1)def forward(self, x):x = list(x)emb = []for i in range(self.num):emb.append(self.bot[i](x[i]))agg_emb = self._aggregate(emb)logit = self.top(agg_emb)pred = self.softmax(logit)return emb, preddef _aggregate(self, x):# Note: x is a list of tensors.return torch.cat(x, dim=1)

训练的代码如下:

num = 4
model = FC1(num)
optimizer_entire = torch.optim.SGD(model.parameters(), lr=0.01)def train(self):# train entire modelself.model.train()for epoch in range(self.args.epochs):pred = self.model(data)loss = torch.nn.CrossEntropyLoss(pred, labels)# zero grad for all optimizersoptimizer_entire.zero_grad()loss.backward()# update parameters for all optimizersoptimizer_entire.step()

解决办法

需要给所有用到的模型参数都设置optimizer,否则只有top部分的参数在训练,底层的会得到gradient,但parameter不会更新。

num = 4
model = FC1(num)
optimizer_entire = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer_top = torch.optim.SGD(model.top.parameters(), lr=0.01)
optimizer_bot = []
for i in range(num):optimizer_passive.append(torch.optim.SGD(model.passive[i].parameters(), lr=0.01))def train(self):# train entire modelself.model.train()self.model.top.train()for i in range(self.args.num):self.model.bot[i].train()for epoch in range(self.args.epochs):pred = self.model(data)loss = torch.nn.CrossEntropyLoss(pred, labels)# zero grad for all optimizersoptimizer_entire.zero_grad()optimizer_top.zero_grad()for i in range(num):optimizer_bot[i].zero_grad()loss.backward()# update parameters for all optimizersoptimizer_entire.step()optimizer_top.step()for i in range(num):optimizer_bot[i].step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/231729.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

强化产品联动:网关V7独家解决方案的三重优势

客户背景 某央企单位汇聚了众多业内优秀的工程师和科研人员,拥有先进的研发设施和丰富的研发经验,专注于为全球汽车行业提供创新和实用的解决方案。其研发成果不仅在国内市场上得到了广泛应用,也在国际市场上赢得了广泛的认可和赞誉。 客户需…

jconsole与jvisualvm

jconsole 环境变量配置好后 直接输入在cmd 输入jconsole 即可 jvisualvm cmd 输入jvisualvm jvisualvm 能干什么 监控内存泄露,跟踪垃圾回收,执行时内存、cpu 分析,线程分析… 运行:正在运行的 休眠:sleep 等待…

接口测试的工具(3)----postman+node.js+newman

1.安装newman:输入命令之后 一定注意 什么都不要操作 静静的等待结束就行了。 2.安装失败的对此尝试不行 在用下面的方法 解压一下就行了 3.验证是否成功 多次尝试是可以在线安装成功的

测试进程监控:确保产品质量的关键

引言: 在软件开发过程中,测试是确保产品质量的重要环节。为了提高测试效率和准确性,测试进程监控成为了不可或缺的工具。本文将介绍测试进程监控的各个方面,包括产品风险度量、缺陷度量源、测试用例(或规程&#xff09…

Unity中Shader URP最简Shader框架(ShaderGraph 转 URP Shader)

文章目录 前言一、 我们先了解一下 Shader Graph 怎么操作1、了解一下 Shader Graph 的面板信息2、修改Shader路径3、鼠标中键 或 Alt 鼠标左键 移动画布4、鼠标右键 打开创建节点菜单5、把ShaderGraph节点转化为 Shader 代码6、可以看出 URP 和 BuildIn RP 大体框架一致 二、…

【Docker-2】在 Debian 上安装 Docker 引擎

在 Debian 上安装 Docker 引擎 要开始在 Debian 上使用 Docker 引擎,请确保满足先决条件,然后按照安装步骤操作。 先决条件 操作系统要求 要安装 Docker Engine,您需要以下 Debian 之一的 64 位版本 版本: Debian Bookworm 12…

隐私计算介绍

这里只对隐私计算做一些概念性的浅显介绍,作为入门了解即可 目录 隐私计算概述隐私计算概念隐私计算背景国外各个国家和地区纷纷出台了围绕数据使用和保护的公共政策国内近年来也出台了数据安全、隐私和使用相关的政策法规 隐私计算技术发展 隐私计算技术安全多方计…

C# WPF上位机开发(usb设备访问)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 目前很多嵌入式设备都支持usb访问,特别是很多mcu都支持高速usb访问。和232、485下个比较,usb的访问速度和它们基本不在一个…

C语言求n的阶乘(n!)

从键盘输入一个数,求出这个数的阶乘,即 n!。 1、算法思想 首先要清楚阶乘定义,所谓 n 的阶乘,就是从 1 开始乘以比前一个数大 1 的数,一直乘到 n,用公式表示就是:1234…(n-2)(n-1)nn! 具体的操…

unittest自动化测试框架讲解以及实战

为什么要学习unittest 按照测试阶段来划分,可以将测试分为单元测试、集成测试、系统测试和验收测试。单元测试是指对软件中的最小可测试单元在与程序其他部分相隔离的情况下进行检查和验证的工作,通常指函数或者类,一般是开发完成的。 单元…

进程间通讯-消息队列

介绍 消息队列是一种存放在内核中的数据结构,用于在不同进程之间传递消息。它基于先进先出(FIFO)的原则,进程可以将消息发送到队列中,在需要的时候从队列中接收消息。消息队列提供了一种异步通信的方式,使…

❤Mac上后端环境工具安装使用

❤Mac上后端环境工具安装使用 Cornerstone 使用 (最好的SVN Mac软键) 使用教程 安装 由于Cornerstone是收费的,因此你可以去网上下载破解版,直接安装即可。 配置远程仓库 首先,打开CornerStone,在界面…

工业数据的特殊性和安全防护体系探索思考

随着工业互联网的发展,工业企业在生产运营管理过程中会产生各式各样数据,主要有研发设计数据、用户数据、生产运营数据、物流供应链数据等等,这样就形成了工业大数据,这些数据需要依赖企业的网络环境和应用系统进行内外部流通才能…

19、商城系统(一):项目架构图,配置前端后台开发环境,构建git项目,导入 人人开源框架并前端后台启动

目录​​​​​​​ 一、项目架构图 二、配置环境 1.配置linux (1)复制linux环境

【Python】—— NumPy基础及取值操作

NumPy基础及取值操作 第1关:ndarray对象第2关:形状操作第3关:基础操作第4关:随机数生成第5关:索引与切片 第1关:ndarray对象 任务描述 本关任务:根据本关所学知识,补全代码编辑器中…

react基于antd二次封装spin组件

目录 react基于antd二次封装spin组件组件使用组件效果 react基于antd二次封装spin组件 组件 import { Spin } from antd; import propTypes from "prop-types"; import React from react; import styleId from "styled-components"; // 使用 父div必须加…

【爬虫课堂】如何高效使用短效代理IP进行网络爬虫

目录 一、前言 二、代理IP的基本知识 三、短效代理IP的优势 四、高效使用短效代理IP的技巧 1. 多源获取代理IP 2. 质量筛选代理IP 3. 使用代理池 4. 定时更换代理IP 5. 失败重试机制 6. 监控和自动化 五、示例代码 六、结语 一、前言 网络爬虫是一种自动化程序&am…

Windbg 常用命令

Windbg 是微软开发的一款强大的调试工具,用于调试 Windows 操作系统和应用程序。它支持各种调试技术,包括用户模式和内核模式调试、本地和远程调试、源代码和汇编级别调试等。以下是 Windbg 中一些常用的命令: 标准命令: g - 继…

MongoDB中的关系

本文主要介绍MongoDB中的关系。 目录 MongoDB的关系嵌入关系引用关系 MongoDB的关系 MongoDB是一个非关系型数据库,它使用了键值对的方式来存储数据。因此,MongoDB没有像传统关系型数据库中那样的表、行和列的概念。相反,MongoDB中的关系是通…

LLM之RAG实战(五)| 高级RAG 01:使用小块检索,小块所属的大块喂给LLM,可以提高RAG性能

RAG(Retrieval Augmented Generation,检索增强生成)系统从给定的知识库中检索相关信息,从而使其能够生成事实信息、上下文相关信息和特定领域的信息。然而,在有效检索相关信息和生成高质量响应方面,RAG面临…