PyTorch中定义可学习参数时的坑

当需要在模型运行时定义可学习参数时(常见场景:参数的维度由每一层的维度定),我们就需要用这样的写法来实现:

class model(torch.nn.Module):def __init__(self):super().__init__()self.alpha = Nonedef forward(self, x):if self.alpha is None:self.alpha = torch.nn.Parameter(torch.ones(x.shape[0]), requires_grad=True)...

采用这种写法的话,必须要在正式训练模型之前进行一次预推理,该预推理可以是伪输入数据的推理,目的是预推理时构建好每一层所需要的self.alpha可学习参数。我常用的写法如下:

dummy_input = torch.randn(1, 3, 32, 32)
# 1:batch size为1,只推理单个样本;3:数据集的图像通道数;32:数据集的图像大小
model(dummy_input)

必须要注意的是,新定义的self.alpha必须要放入optimizer中才可以训练,因此,上面这段预推理的代码必须要放在声明optimizer之前!!!原因很简单,声明optimizer时,有个传入参数就是模型参数列表:

optimizer = torch.optim.SGD(model.parameters(), xxx)

但是这里会出现一个问题:由于self.alpha时在模型运行(预推理)时构建的,所以尚未放入cuda中。因此,需要手动将self.alpha放入cuda中。于是,有如下两种可能的写法:

# 写法1(错误)
self.alpha = torch.nn.Parameter(torch.ones(x.shape[0]), requires_grad=True).to(x.device)# 写法2(正确)
self.alpha = torch.nn.Parameter(torch.ones(x.shape[0]).to(x.device), requires_grad=True)

试问这两种写法都正确吗?思考一分钟…
时间到!实际上,只有写法2是正确的!

写法1先定义nn.Parameter,后放入cuda,会导致参数重新变回到tensor,从而不可学习;
写法2先放入cuda,后定义nn.Parameter,可以成功定义参数,可以学习。

总之,记住就好,这确实也是一个一找可以找一整晚的BUG了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/639282.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaEE中什么是Web容器?

Web容器(也称为Servlet引擎)是一个用于执行Java Servlet和JSP的服务器端环境。它负责管理和执行在其上运行的Web应用程序。 Tomcat是Web容器 Apache Tomcat 是一个流行的开源的Web容器,它实现了Java Servlet和JavaServer Pages(…

pinctrl子系统简介

一. 简介 上一章我们编写了基于设备树的 LED 驱动,但是驱动的本质还是没变,都是配置 LED 灯所使用的 GPIO 寄存器,驱动开发方式和裸机基本没啥区别。 Linux 是一个庞大而完善的系统, 尤其是驱动框架,像 GPIO …

【深度学习目标检测】十七、基于深度学习的洋葱检测系统-含GUI和源码(python,yolov8)

使用AI实现洋葱检测对农业具有以下意义: 提高效率:AI技术可以快速、准确地检测出洋葱中的缺陷和问题,从而提高了检测效率,减少了人工检测的时间和人力成本。提高准确性:AI技术通过大量的数据学习和分析,能够…

第五课:MindSpore自动并行

文章目录 第五课:MindSpore自动并行1、学习总结:数据并行模型并行MindSpore算子级并行算子级并行示例 流水线并行GPipe和Micro batch1F1B流水线并行示例 内存优化重计算优化器并行 MindSpore分布式并行模式课程ppt及代码地址 2、学习心得:3、…

如何使用pytorch的Dataset, 来定义自己的Dataset

Dataset与DataLoader的关系 Dataset: 构建一个数据集,其中含有所有的数据样本DataLoader:将构建好的Dataset,通过shuffle、划分batch、多线程num_workers运行的方式,加载到可训练的迭代容器。 import torch from torch.utils.dat…

Qt6入门教程 9:QWidget、QMainWindow和QDialog

目录 一.QWidget 1.窗口和控件 2.事件 二.QMainWindow 三.QDialog 1.模态对话框 1.1模态对话框 1.2.半模态对话框 2.非模态对话框 在用Qt Creator创建Qt Widgets项目时,会默认提供三种基类以供选择,它们分别是QWidget、QMainWIndow和QDialog&am…

SQL 注入总结(详细)

一、前言 这篇文章是最近学习 SQL 注入后的笔记,里面整理了 SQL 常见的注入方式,供大家学习了解 SQL 注入的原理及方法,也方便后续自己回顾,如有什么错误的地方欢迎指出! 二、判断注入类型 按照注入点类型分类 数字型…

外贸自建站如何建立?海洋建站的操作指南?

外贸自建站的建站流程什么?做跨境怎么搭建外贸网站? 外贸自建站成为企业开拓国际市场、提升品牌形象的重要途径。然而,对于许多企业而言,如何高效地进行外贸自建站仍然是一个挑战。海洋建站将带您一步步探讨外贸自建站的关键步骤…

logback排除指定包类方法的日志

logback排除指定包\类\方法的日志 修改logback.xml <!--排除指定包--> <logger name"com.servier" level"OFF"/><!--排除指定类--> <logger name"com.servier.UserServier" level"OFF"/><!--排除指定指定…

计算机网络——面试问题

1 从输⼊ URL 到⻚⾯展示到底发⽣了什么&#xff1f; 1. 先检查浏览器缓存⾥是否有缓存该资源&#xff0c;如果有直接返回&#xff1b;如果没有进⼊下⼀ 步⽹络请求。 2. ⽹络请求前&#xff0c;进⾏ DNS 解析 &#xff0c;以获取请求域名的 IP地址 。 3. 浏览器与服务器…

《WebKit 技术内幕》之七(3): 渲染基础

3 渲染方式 3.1 绘图上下文&#xff08;GraphicsContext&#xff09; 上面介绍了WebKit的内部表示结构&#xff0c;RenderObject对象知道如何绘制自己&#xff0c;但是&#xff0c;问题是RenderObject对象用什么来绘制内容呢&#xff1f;在WebKit中&#xff0c;绘图操作被定…

xcode 设置 ios苹果图标,为Flutter应用程序配置iOS图标

图标设置 1,根据图片构建各类尺寸的图标2.xcode打开ios文件3.xcode设置图标4.打包提交审核,即可(打包教程可通过我的主页查找) 1,根据图片构建各类尺寸的图标 工具网址:https://icon.wuruihong.com/ 下载之后文件目录如下 拷贝到项目的ios\Runner\Assets.xcassets\AppIcon.ap…

java简单的抽奖工具类(含测试方法)

文章目录 结果代码 结果 代码 import lombok.AllArgsConstructor; import lombok.Data; import lombok.ToString;import java.util.ArrayList; import java.util.List;/****/ public class LotteryUtils {public static void main(String[] args) throws InterruptedException…

PythonNet,Csharp如何白嫖Python生态和使用Matplotlib

文章目录 前言PythonNet环境配置Python环境配置Csharp Nuget配置运行代码测试运行结果 总结 前言 我既然用Csharp去尝试学习机器视觉&#xff0c;我就想试试用Csharp去使用Python的库。 这个世界上有没有编程语言既有Python的开发效率&#xff0c;又有C/C/ PythonNet Pythonne…

9 | Tensorflow中的batch批处理

TensorFlow支持批处理(batch processing)。批处理是指同时处理多个样本或数据点而不是单个样本。在深度学习中,批处理通常用于提高训练的效率和稳定性。 在TensorFlow中,可以使用tf.data.Dataset API来设置和处理批处理数据。这允许以批处理的方式加载和处理数据,适用于训…

MySQL in和exists的取舍

in和exists的取舍 之前说过要小表驱动大表&#xff0c;即先遍历小表再遍历大表&#xff0c;接下来看一下in和exists的区别 in 先执行子查询&#xff0c;适合于外表大而内表小的情况 select * from A where id in (select id from B)等价于先遍历表B select id from B再遍历表A …

Android:JNI实战,加载三方库、编译C/C++

一.概述 Android Jni机制让开发者可以在Java端调用到C/C&#xff0c;也是Android应用开发需要掌握的一项重要的基础技能。 计划分两篇博文讲述Jni实战开发。 本篇主要从项目架构上剖析一个Android App如何通过Jni机制加载三方库和C/C文件。 二.Native C Android Studio可…

精准核酸检测 - 华为OD统一考试

OD统一考试&#xff08;C卷&#xff09; 分值&#xff1a; 100分 题解&#xff1a; Java / Python / C 题目描述 为了达到新冠疫情精准防控的需要&#xff0c;为了避免全员核酸检测带来的浪费&#xff0c;需要精准圈定可能被感染的人群。 现在根据传染病流调以及大数据分析&a…

✅枚举类型在技术派中的应用示例

在Java编程中&#xff0c;我们经常会遇到需要表示一组相关常量的情况。为了提高代码的可读性和可维护性&#xff0c;Java引入了枚举类型。本文将介绍枚举类型的基本概念&#xff0c;并通过一个实际的示例来说明如何在Java中使用枚举。 什么是枚举类型&#xff1f; 枚举类型是一…

【代码实战】从0到1实现transformer

获取数据 import pathlibimport tensorflow as tf# download dataset provided by Anki: https://www.manythings.org/anki/ text_file tf.keras.utils.get_file(fname"fra-eng.zip",origin"http://storage.googleapis.com/download.tensorflow.org/data/fra-…