Pytorch 反向传播 计算图被修改的报错

先看看报错的内容

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

报错中说,一个需要梯度计算的变量已经被原地修改了,这引发了报错。

torch.set_grad_enabled(True)

然后我使用上述语句开启了梯度跟踪,发现问题出在我的标签计算函数:

def get_label(net, X):return net(X).reshape((-1, 1))

为什么会出错呢?在这种情况下,由于 label 是从网络输出直接计算得到的,它与网络的计算图相连接。如果在 label 上进行了原地操作(上述的修改形状操作),就可能破坏计算图,使其不可导或其他,总之是导致反向传播时无法正确计算梯度,从而引发报错。

那怎么解决这个问题?将该结果与计算图进行分离就行了,此刻如果再进行反向传播,梯度就不会传播到此处。修改后,代码如下;

def get_label(net, X):return net(X).detach().reshape((-1, 1))

detach()函数的作用是将数据和计算图分离开来,得到数据部分,与计算图再无瓜葛。

举一个更形象的例子,看下面的代码:

label = net(X)  # 计算标签
# 对 label 或 label 的某个部分进行了原地操作,比如:
# label[0, 0] = label[0, 0] * 2
# 或
# label += 1
loss = Loss(label, y)  # 计算损失

在这个例子中,label由第一条语句前向传播得到,是直接与网络的输出连在一起,后面我却对label的值进行了手动修改。

这些操作可能导致计算图的结构不完整或不可导,从而影响反向传播的计算。为了避免这样的问题,一般建议避免在计算标签或损失时对张量进行原地操作。如果需要修改张量的值,最好创建一个新的张量,而不是直接在原有张量上进行修改。

下面是我的整个程序,大家也可以调试代码来理解其中的含义:

import torch.nn as nn
import matplotlib.pyplot as plt
import torch
from torch.utils import data
def get_label(net, X):#计算标签,计算完后必须要使用detach()分离计算图,否则代码将报计算图被修改的错误return net(X).detach().reshape((-1, 1))def train(net, trainer, Loss, train_data, train_label, epochs, batch_size):#将训练数据和标签捆在一起,便于后面一起便利data_iter = data.DataLoader(list(zip(train_data, train_label)), batch_size=batch_size)#用来存储数据的变化值,前者为训练轮次,后者为每一轮训练平均损失draw_x, draw_y = [], []for epoch in range(epochs):#每次处理一个批次的数据for X, y in data_iter:trainer.zero_grad()  # 清除梯度pre_y = net(X)  # 前向传播loss = Loss(pre_y, y)  # 计算损失loss.backward()  # 反向传播,计算梯度trainer.step()  # 更新权重,进行优化#添加绘图需要的数据draw_x.append(epoch)draw_y.append(torch.mean(Loss(net(train_data),train_label)).data)#设置绘图参数plt.figure(figsize=(5, 4), dpi=150)#设置图像大小和分辨率plt.plot(draw_x, draw_y, label='train_loss')#设置要绘制的数据,被给出图例plt.xlabel('epoch')#设置X轴标题plt.ylabel('loss')#设置y轴标题plt.legend()#显示图例#显示最终图像plt.show()def test(net, Loss, test_data, test_label):loss_sum = torch.zeros_like(test_label)data_iter = data.DataLoader(list(zip(test_data, test_label)), batch_size=batch_size, shuffle=False)for X, y in data_iter:pre_y = net(X)  # 前向传播loss = Loss(pre_y, y)  # 计算损失loss_sum += loss  # 累加损失return torch.sum(loss_sum) / len(loss_sum)  # 返回平均损失def init_weight(m):if type(m) == nn.Linear:#权重使用何凯明正态初始化方法进行初始化nn.init.kaiming_normal_(m.weight)#偏置使用0偏置nn.init.zeros_(m.bias)lr = 0.01  # 学习率
epochs = 100  # 训练轮数
batch_size = 5  # 批大小
shared = nn.Linear(5, 5)  # 共享层
net = nn.Sequential(nn.Linear(10, 5), nn.ReLU(),  # 输入层到隐藏层1的线性层,ReLU激活函数shared, nn.ReLU(),  # 共享层,ReLU激活函数shared, nn.ReLU(),  # 共享层,ReLU激活函数nn.Linear(5, 1))  # 从隐藏层到输出层的线性层,无激活函数(线性回归)#显示真实参数(我们的标签就是用这个参数跑出来的),这也是我们最终需要拟合的参数
for name, param in net.named_parameters():print(name, param)#获取随机数作为样本
X = torch.randn((200, 10))
# 通过网络得到真实标签
True_label = get_label(net, X)
#一开始自动随机生成了参数已经被我当作真实参数了,此刻我需要另重新初始化参数
net.apply(init_weight)
#获取训练器
trainer = torch.optim.SGD(net.parameters(), lr=lr)
#获取损失函数
Loss = nn.MSELoss()  # 定义损失函数,使用均方误差。#开始训练模型发
train(net, trainer, Loss, X[:50], True_label[:50], epochs, batch_size=batch_size)
#打印测试损失
print(f'测试损失{test(net, Loss, X[50:], True_label[50:])}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/607952.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【linux】更改infiniband卡在Debian系统的网络接口名

在Debian或任何其他基于Linux的系统中,网络接口的名称由udev系统管理。通过创建udev规则,可以修改网络接口名称。以下是更改InfiniBand卡接口名称的一般步骤: 1. 找到网络接口的属性,以编写匹配的udev规则 可以使用udevadm命令查…

4.6 BOUNDARY CHECKS

我们现在扩展了tile矩阵乘法内核,以处理具有任意宽度的矩阵。扩展必须允许内核正确处理宽度不是tile宽度倍数的矩阵。通过更改图4.14中的示例至33 M、N和P矩阵,图4.18创建了矩阵的宽度为3,不是tile宽度(2)的倍数。图4.…

Spring事务控制

1.事务介绍 1.1什么是事务? 当你需要一次执行多条SQL语句时,可以使用事务。通俗一点说,如果这几条SQL语句全部执行成功,则才对数据库进行一次更新,如果有一条SQL语句执行失败,则这几条SQL语句全部不进行执…

window mysql5.7 搭建主从同步环境

window 搭建mysql5.7数据库 主从同步 主节点 配置文件my3308.cnf [mysql] # 设置mysql客户端默认字符集 default-character-setutf8mb4[mysqld] server-id8 #server-uuidbc701be9-ac71-11ee-9e35-b06ebf511956 log-binD:\mysql_5.7.19\mysql-5.7.19-winx64\mysql-bin binlog-…

如何在 Umi /Umi 4.0 中配置自动删除 console.log 语句?

背景,开发时需要console.log 日志,再生产、uat 、sit不想看到日志打印信息 方案1、代码规范eslint校验"no-console": true, //console.log 方案2、bable 插件 babel-plugin-transform-remove-console 配置在.umirx.ts/js中 export default…

一篇文章足以让你掌握蓝牙协议栈基本架构(蓝牙核心文档、HCI架构解读等)

目录 1. 蓝牙核心文档介绍 1.1 架构 1.2 BR/EDR 控制器 1.3 主机 1.4 主机控制器接口

【笔记】用Python做手机多平台UI应用

最近一直在找一个简单的基于Python的多平台UI实现,特别是希望能比较好地支持手机端。 总结一下标准是: Python,最好能支持Numpy、Pandas等库无缝集成简单,不要考虑过多的实现细节,如html、css、qt等多端支持&#xff…

性能分析与调优: Linux 内存观测工具

目录 一、实验 1.环境 2.vmstat 3.PSI 4.swapon 5.sar 6.slabtop 7.numstat 8.ps 9.top 10.pmap 11.perf 12.bpftrace 二、问题 1.接口读写报错 2.slabtop如何安装 3.numactl如何安装 4.numad启动服务与关闭NUMA 5. perf如何安装 6. kernel-lt-doc与kern…

go-carbon v2.3.4 发布,轻量级、语义化、对开发者友好的 Golang 时间处理库

carbon 是一个轻量级、语义化、对开发者友好的 golang 时间处理库,支持链式调用。 目前已被 awesome-go 收录,如果您觉得不错,请给个 star 吧 github.com/golang-module/carbon gitee.com/golang-module/carbon 安装使用 Golang 版本大于…

Vue3+Vite打包跨平台(七牛、阿里OSS)上传部署前端项目

1、业务场景 阅读之前,想了解一下各位观众老爷们,你们公司的项目是怎么部署的: 1.本地打包手动上传服务器; 2.本地打包自动上传服务器; 3.代码仓库流水线自动构建; 4.其他…; 我们用的第3种部…

EasyExcel 不使用科学计数发并以千分位展示

EasyExcel 不使用科学计数发并以千分位展示 不使用科学计数法 不使用科学计数法 BigDecimalStringConverter 将 BigDecimal 类型的数值转换为字符串类型,并将其导出到 Excel 文件中。在 convertToExcelData 方法中,我们将 BigDecimal 转换为字符串&…

线程|死锁条件及实现

死锁(Deadlock)是指两个或多个进程在执行过程中因争夺资源而造成的一种互相等待的现象 死锁通常发生在多任务系统中,其中进程通过竞争有限的资源来完成任务 死锁通常涉及互斥、持有和等待三个条件。 死锁的原因 互斥条件(Mutual…

前端中什么是DOM对象

DOM(文档对象模型)是一种编程接口,用于HTML和XML文档。它提供了一种将文档结构表示为树结构的方式,这使得程序和脚本能够动态地访问和更新文档的内容、结构和样式。 在前端开发中,DOM是非常重要的概念。当浏览器加载网…

认知能力测验,⑥如何破解逻辑判断类测试题?

逻辑思维,是一个比较大的范围,在绝大多数的招聘中,认知能力测评形式多样,难度也较大,其中逻辑判断题型所涉及到的分类为:概念类、条件类、矛盾类、数字类、图形类等知识。比如奥数就是个好东西.....如果经历…

Go语言日志美化库,slog使用指南

Go语言日志美化库,slog使用指南 1.slog2.快速开始3.使用JSON格式4.Text格式化formatter 1.slog slog是Go 实现的一个易于使用的,易扩展、可配置的日志库 slog - github 控制台效果: 安装方式: go get github.com/gookit/slog2…

RAG 最新最全资料整理

最近在做RAG方面的工作。它山之石可以攻玉,做了一些调研,包含了OpenAi,百川,iki.ai为我们提供的一些实现方案。 本文以时间顺序,整理了最近最新最全的和RAG相关的资料。都是满满的干货,包含了RAG评测工具、…

同步流复制过程

同步流复制过程 第一步:主库(primary端)第二步:备库(standby端)第三步:主库(primary端)其他1.主库IP变更2.主库上做回归测试时会卡住并出现以下提示 参考链接&#xff1a…

TOPS、MIPS、DMIPS、MFLOPS、吞吐量与推理效率

1.概述 在深度学习对应的神经推理中经常涉及几个重要概念,TOPS、MIPS、DMIPS,MFLOPS,下文对其做对比说明。 2.概念对比 2.1 MIPS Million Instructions Per Second的缩写,每秒处理的百万级的机器语言instructions。这是衡量处…

【单片机】四种烧写方式简介

目录 单片机的四种烧写方式简介 1.使用JTAG接口实现2.SWD接口烧录方式3.ISP烧写方式,用UART实现(常用)4.SWIM单总线下载方式 烧录方式基本介绍烧录方式详述 1、ISP:In System Programming2、IAP:In Applicatin Program…

部分城市公交站点数据,Shp+excel格式数据,2020年,几何类型为点

随着城市的发展和人口的增长,公共交通成为了人们出行的重要方式之一。而公交站点作为公共交通的重要组成部分,其数据信息的获取和分析对于城市规划和管理具有重要意义。 今天来分享一下部分城市公交站点数据: 首先先了解下该数据的基本信息 …