PyTorch梯度直通反传

有时我们想在层的输出端放置一个阈值函数。这可能出于多种原因。其中之一是我们想将激活总结为二进制值。这种激活的二值化在自编码器中很有用。

然而,阈值化在反向传播过程中会带来问题:阈值函数的导数为零。这种梯度的缺乏导致我们的网络无法学习任何东西。为了解决这个问题,我们可以使用直通估计器 (STE:Straight Through Estimator)。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

1、什么是直通估计器?

假设我们想使用以下函数将层的激活二值化:

此函数将为每个大于 0 的值返回 1,否则将返回 0。

如前所述,此函数的问题在于其梯度为零。为了解决这个问题,我们将在反向传递中使用直通估计器。

直通估计器顾名思义就是它估计函数的梯度。具体来说,它忽略阈值函数的导数,并将传入的梯度传递,就好像该函数是恒等函数一样。下图有助于更好地解释它:

你可以看到在反向传递中如何绕过阈值函数。就是这样,这就是直通式估计器的作用。它使阈值函数的梯度看起来像恒等函数的梯度。

2、直通估计器的PyTorch 实现

截至目前,PyTorch 的 API 中尚未包含 STE 的实现。因此,我们必须自己实现它。为此,我们需要创建一个 Function 类和一个 Module 类。Function 类将包含 STE 的前向和后向功能。Module 类是创建和使用 STE Function 对象的地方。我们将在我们的神经网络中使用 STE Module。

以下是 STE Function 类的实现:

class STEFunction(torch.autograd.Function):@staticmethoddef forward(ctx, input):return (input > 0).float()@staticmethoddef backward(ctx, grad_output):return F.hardtanh(grad_output)

PyTorch 让我们可以定义具有前向和后向功能的自定义自动求导函数。这里我们为直通式估算器定义了一个自动求导函数。在前向传递中,我们希望将输入张量中的所有值从浮点转换为二进制。在后向传递中,我们希望传递传入的梯度而不对其进行修改。这是为了模仿恒等函数。不过,这里我们对传入的梯度执行 F.hardtanh 操作。此操作将梯度限制在 -1 和 1 之间。我们这样做是为了让梯度不会变得太大。

现在,让我们实现 STE 模块类:

class StraightThroughEstimator(nn.Module):def __init__(self):super(StraightThroughEstimator, self).__init__()def forward(self, x):x = STEFunction.apply(x)return x

你可以看到,我们在 forward 函数中使用了我们定义的 STE 函数类。要使用 autograd 函数,我们必须将输入传递给 apply 方法。现在,我们可以在神经网络中使用此模块。

使用 STE 的常见方法是在自编码器的瓶颈层内。以下是此类自编码器的实现:

class Autoencoder(nn.Module):def __init__(self):super(Autoencoder, self).__init__()self.encoder = nn.Sequential(nn.Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.ReLU(),nn.Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(512),nn.ReLU(),StraightThroughEstimator(),)self.decoder = nn.Sequential(nn.ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(256),nn.ReLU(),nn.ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(128),nn.ReLU(),nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.ReLU(),nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.Tanh(),)def forward(self, x, encode=False, decode=False):if encode:x = self.encoder(x)elif decode:x = self.decoder(x)else:encoding = self.encoder(x)x = self.decoder(encoding)return x

这个自编码器是为 MNIST 数据集制作的。它将 28x28 图像压缩为具有 512 个通道的 1x1 图像。然后将其解码回 28x28 图像。

我将 STE 放在编码器的末尾。它将把接收到的张量的所有值转换为二进制。你可能已经注意到我使用了一个非常规的前向函数。我添加了两个新参数 encode 和 decrypt,它们要么是 True,要么是 False。如果 encode 设置为 True,网络将返回编码器的输出。同样,如果 decrypt 设置为 True,网络需要有效的编码并将其解码回图像。

我在 MNIST 数据集上对自动编码器进行了 5 个 epoch 的训练,并带有 MSE 损失。以下是测试集上的重建:

如你所见,重建效果非常好。STE 可用于神经网络,且性能不会有太大损失。

完整代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdmdevice = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# dataset preparation
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, ), (0.5, ))
])
trainset = datasets.MNIST('dataset/', train=True, download=True, transform=transform)
testset = datasets.MNIST('dataset/', train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
# defining networks
class STEFunction(autograd.Function):@staticmethoddef forward(ctx, input):return (input > 0).float()@staticmethoddef backward(ctx, grad_output):return F.hardtanh(grad_output)
class StraightThroughEstimator(nn.Module):def __init__(self):super(StraightThroughEstimator, self).__init__()def forward(self, x):x = STEFunction.apply(x)return x
class Autoencoder(nn.Module):def __init__(self):super(Autoencoder, self).__init__()self.encoder = nn.Sequential(nn.Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.ReLU(),nn.Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(512),nn.ReLU(),StraightThroughEstimator(),)self.decoder = nn.Sequential(nn.ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(256),nn.ReLU(),nn.ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(128),nn.ReLU(),nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.ReLU(),nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.Tanh(),)def forward(self, x, encode=False, decode=False):if encode:x = self.encoder(x)elif decode:x = self.decoder(x)else:encoding = self.encoder(x)x = self.decoder(encoding)return x
net = Autoencoder().to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.5, 0.999))
criterion_MSE = nn.MSELoss().to(device)
# train loop
epoch = 5
for e in range(epoch):print(f'Starting epoch {e} of {epoch}')for X, y in tqdm(trainloader):optimizer.zero_grad()X = X.to(device)reconstruction = net(X)loss = criterion_MSE(reconstruction, X)loss.backward()optimizer.step()print(f'Loss: {loss.item()}')
# test loop
i = 1
fig = plt.figure(figsize=(10, 10))
for X, y in testloader:X_in = X.to(device)recon = net(X_in).detach().cpu().numpy()if i >= 10:breakfig.add_subplot(5, 2, i).set_title('Original')plt.imshow(X[0].reshape((28, 28)), cmap="gray")fig.add_subplot(5, 2, i+1).set_title('Reconstruction')plt.imshow(recon[0].reshape((28, 28)), cmap="gray")i += 2
fig.tight_layout()
plt.show()

原文链接:梯度反传直通图解 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/31116.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

运动想象 (MI) 分类学习系列 (16) :LMDA-Net

运动想象分类学习系列:基于滑动窗口的通用空间模式 0. 引言1. 主要贡献2. 提出的方法2.1 LMDA-Net架构2.2 通道注意力2.3 深度注意力3. 结果3.1 实验结果3.2 消融实验4. 总结欢迎来稿论文地址:https://www.sciencedirect.com/science/article/pii/S1053811923003609 论文题目:…

MySQL:表的增删查改

文章目录 1.Create(创建)2.Retrieve(读取、查询)2.1 SELECT 列2.2 WHERE 子句2.3 结果排序(order by)2.4 筛选分页结果(limit、offset)2.5 Update更新2.6 Delete删除2.7 去重 3.聚合函数3.1 聚合函数的基本使用3.2group by子句的使用(分组查询) 增删查改:: Create(创…

Tailwindcss 扩展默认配置来自定义颜色

背景 项目里多个Tab标签都需要设置同样的背景颜色#F1F5FF,在集成tailwindcss之前就是重复该样式,如下图: .body {background-color: #f1f5ff; }集成tailwindcss时,我们希望在class中直接设置该背景色,但是默认的tai…

docker 安装与常用指令

1. docker 安装 sudo yum install -y yum-utilssudo yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.reposudo yum install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-pluginsudo systemctl enable do…

Windows11平台C++在VS2022中安装和使用Matplot++绘图库的时候出现的问题和解决方法

Matplot 是一个基于 C 的绘图库,专门用于绘制高质量的数据图表。它提供了一个简洁而强大的接口,使得用户能够轻松地创建各种类型的图表,包括线图、散点图、柱状图、饼图等。Matplot 的设计目标是提供与 MATLAB 相似的绘图体验,同时…

在编译内核时添加驱动的固件

最近调驱动时,无法正常加载引导。 使用的内核5.10 内核启动先于文件系统,内核启动时驱动无法访问固件文件,所以无法加载驱动。 有2个办法,可以解决,一是驱动编译KO模块,系统启动后,再动态加载…

Spring Boot 3 整合 SpringDoc OpenAPI 生成接口文档

😄 19年之后由于某些原因断更了三年,23年重新扬帆起航,推出更多优质博文,希望大家多多支持~ 🌷 古之立大事者,不惟有超世之才,亦必有坚忍不拔之志 🎐 个人CSND主页——Mi…

Flutter知识点

Dart语言基础知识 Dart特性: Dart 是少数同时支持 JIT(Just In Time,即时编译)和 AOT(Ahead of Time,运行前编译)的语言之一。语言在运行之前通常都需要编译,JIT 和 AOT 则是最常见…

HCIP-HarmonyOS Device Developer 课程大纲

一:系统及应用场景介绍 1 -(3 课时) - HarmonyOS 系统介绍;HarmonyOs 定义;HarmonyOS 特征; - 统一 OS,弹性部署;硬件互助,资源共享;一次开发,多…

vue3插槽slot的使用

一&#xff0c;默认插槽 父组件页面&#xff1a;使用子组件标签 <template><div>我是父组件自己的内容</div><ComTest></ComTest> // 这里使用子组件的内容<!-- <ComTest>我要替换默认插槽的内容</ComTest> // 这里替换子组件…

Unity 工具 之 Azure 微软 【GPT4o】HttpClient 异步流式请求的简单封装

Unity 工具 之 Azure 微软 【GPT4o】HttpClient 异步流式请求的简单封装 目录 Unity 工具 之 Azure 微软 【GPT4o】HttpClient 异步流式请求的简单封装 一、简单介绍 二、实现原理 三、注意实现 四、简单效果预览 五、案例简单实现步骤 六、关键代码 一、简单介绍 Unit…

使用Python进行数据可视化:从基础到高级

使用Python进行数据可视化:从基础到高级 数据可视化是数据分析过程中不可或缺的一部分,通过图形化的方式展示数据,可以更直观地发现数据中的趋势和模式。Python凭借其丰富的库和强大的功能,成为数据可视化的首选编程语言。本文将介绍数据可视化的基础概念、常用的Python库…

网络安全:Web 安全 面试题.(XSS)

网络安全&#xff1a;Web 安全 面试题.&#xff08;XSS&#xff09; 网络安全面试是指在招聘过程中,面试官会针对应聘者的网络安全相关知识和技能进行评估和考察。这种面试通常包括以下几个方面&#xff1a; &#xff08;1&#xff09;基础知识:包括网络基础知识、操作系统知…

超级好用的JSON格式化可视化在线工具

JSON是开发非常常用的一种报文格式&#xff0c;最常见的需求就是将JSON进行格式化&#xff0c;最好是有图形化界面显示结构关系&#xff0c;以便进行数据分析。 理想的在线JSON工具&#xff0c;应该支持快速格式化、可压缩、快捷复制、可下载导出&#xff0c;对存在语法错误的地…

Python之三大基本库——Numpy(1)

最近呢学了一些关于python的一些功能&#xff0c;为了更方便快捷高效的实现项目&#xff0c;我们要熟知python的三个基本库&#xff1a;numpy、pandas、matplotlib的功能。由于我也是入门新手&#xff0c;所以先做一些基本的总结&#xff0c;后续有进阶的话会再来更新。 一、Nu…

POI导入带有合并单元格的excel,demo实例,直接可以运行

直接可以运行 import org.apache.poi.hssf.usermodel.HSSFWorkbook; import org.apache.poi.ss.usermodel.Cell; import org.apache.poi.ss.usermodel.Row; import org.apache.poi.ss.usermodel.Sheet; import org.apache.poi.ss.usermodel.Workbook; import org.apache.poi.s…

网络与协议安全复习 - 系统安全部分

文章目录 恶意软件什么是恶意软件传播机制和载荷传播载荷 DDoS 攻击和防范 防火墙什么是防火墙防火墙类型防火墙载体 入侵检测入侵者入侵检测蜜罐技术 口令管理基于Bloom过滤器的口令检查技术 恶意软件 什么是恶意软件 恶意软件定义为&#xff1a;隐蔽植入另一段程序的程序&a…

使用 XML 配置定义和管理 Spring Bean

Spring 框架提供了多种方式来定义和管理 Bean&#xff0c;XML 配置是其中一种传统且强大的方式。尽管现在更多的项目使用基于注解的配置&#xff0c;但了解 XML 配置在理解 Spring 的工作原理和处理遗留系统时仍然非常重要。本文将详细介绍如何使用 XML 配置来定义和管理 Sprin…

数据赋能(125)——体系:数据格式化——实施过程、应用特点

实施过程 数据格式化的实施过程通常涉及以下几个关键步骤&#xff1a; 需求分析&#xff1a; 明确数据格式化的目标和需求&#xff0c;例如是为了数据展示、存储、传输还是其他目的。确定需要格式化的数据类型和格式&#xff0c;例如日期、数字、文本等。数据准备&#xff1a…

Node.js单点登录SSO详解:Session、JWT、CORS让登录更简单

文章目录 一、SSO介绍1、使用SSO的好处 二、中间件介绍1、Express安装导入使用 2、cors安装导入配置 3、express-session安装导入配置使用 4、jsonwebtoken安装导入使用 5、jwt和session对比 三、SSO实现方案1、安装依赖2、结构3、实现原理 三、示例代码1、nodejs端 server/ind…