PyTorch梯度直通反传

有时我们想在层的输出端放置一个阈值函数。这可能出于多种原因。其中之一是我们想将激活总结为二进制值。这种激活的二值化在自编码器中很有用。

然而,阈值化在反向传播过程中会带来问题:阈值函数的导数为零。这种梯度的缺乏导致我们的网络无法学习任何东西。为了解决这个问题,我们可以使用直通估计器 (STE:Straight Through Estimator)。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

1、什么是直通估计器?

假设我们想使用以下函数将层的激活二值化:

此函数将为每个大于 0 的值返回 1,否则将返回 0。

如前所述,此函数的问题在于其梯度为零。为了解决这个问题,我们将在反向传递中使用直通估计器。

直通估计器顾名思义就是它估计函数的梯度。具体来说,它忽略阈值函数的导数,并将传入的梯度传递,就好像该函数是恒等函数一样。下图有助于更好地解释它:

你可以看到在反向传递中如何绕过阈值函数。就是这样,这就是直通式估计器的作用。它使阈值函数的梯度看起来像恒等函数的梯度。

2、直通估计器的PyTorch 实现

截至目前,PyTorch 的 API 中尚未包含 STE 的实现。因此,我们必须自己实现它。为此,我们需要创建一个 Function 类和一个 Module 类。Function 类将包含 STE 的前向和后向功能。Module 类是创建和使用 STE Function 对象的地方。我们将在我们的神经网络中使用 STE Module。

以下是 STE Function 类的实现:

class STEFunction(torch.autograd.Function):@staticmethoddef forward(ctx, input):return (input > 0).float()@staticmethoddef backward(ctx, grad_output):return F.hardtanh(grad_output)

PyTorch 让我们可以定义具有前向和后向功能的自定义自动求导函数。这里我们为直通式估算器定义了一个自动求导函数。在前向传递中,我们希望将输入张量中的所有值从浮点转换为二进制。在后向传递中,我们希望传递传入的梯度而不对其进行修改。这是为了模仿恒等函数。不过,这里我们对传入的梯度执行 F.hardtanh 操作。此操作将梯度限制在 -1 和 1 之间。我们这样做是为了让梯度不会变得太大。

现在,让我们实现 STE 模块类:

class StraightThroughEstimator(nn.Module):def __init__(self):super(StraightThroughEstimator, self).__init__()def forward(self, x):x = STEFunction.apply(x)return x

你可以看到,我们在 forward 函数中使用了我们定义的 STE 函数类。要使用 autograd 函数,我们必须将输入传递给 apply 方法。现在,我们可以在神经网络中使用此模块。

使用 STE 的常见方法是在自编码器的瓶颈层内。以下是此类自编码器的实现:

class Autoencoder(nn.Module):def __init__(self):super(Autoencoder, self).__init__()self.encoder = nn.Sequential(nn.Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.ReLU(),nn.Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(512),nn.ReLU(),StraightThroughEstimator(),)self.decoder = nn.Sequential(nn.ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(256),nn.ReLU(),nn.ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(128),nn.ReLU(),nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.ReLU(),nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.Tanh(),)def forward(self, x, encode=False, decode=False):if encode:x = self.encoder(x)elif decode:x = self.decoder(x)else:encoding = self.encoder(x)x = self.decoder(encoding)return x

这个自编码器是为 MNIST 数据集制作的。它将 28x28 图像压缩为具有 512 个通道的 1x1 图像。然后将其解码回 28x28 图像。

我将 STE 放在编码器的末尾。它将把接收到的张量的所有值转换为二进制。你可能已经注意到我使用了一个非常规的前向函数。我添加了两个新参数 encode 和 decrypt,它们要么是 True,要么是 False。如果 encode 设置为 True,网络将返回编码器的输出。同样,如果 decrypt 设置为 True,网络需要有效的编码并将其解码回图像。

我在 MNIST 数据集上对自动编码器进行了 5 个 epoch 的训练,并带有 MSE 损失。以下是测试集上的重建:

如你所见,重建效果非常好。STE 可用于神经网络,且性能不会有太大损失。

完整代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdmdevice = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# dataset preparation
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, ), (0.5, ))
])
trainset = datasets.MNIST('dataset/', train=True, download=True, transform=transform)
testset = datasets.MNIST('dataset/', train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
# defining networks
class STEFunction(autograd.Function):@staticmethoddef forward(ctx, input):return (input > 0).float()@staticmethoddef backward(ctx, grad_output):return F.hardtanh(grad_output)
class StraightThroughEstimator(nn.Module):def __init__(self):super(StraightThroughEstimator, self).__init__()def forward(self, x):x = STEFunction.apply(x)return x
class Autoencoder(nn.Module):def __init__(self):super(Autoencoder, self).__init__()self.encoder = nn.Sequential(nn.Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.ReLU(),nn.Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(512),nn.ReLU(),StraightThroughEstimator(),)self.decoder = nn.Sequential(nn.ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(256),nn.ReLU(),nn.ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(128),nn.ReLU(),nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.ReLU(),nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),nn.Tanh(),)def forward(self, x, encode=False, decode=False):if encode:x = self.encoder(x)elif decode:x = self.decoder(x)else:encoding = self.encoder(x)x = self.decoder(encoding)return x
net = Autoencoder().to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.5, 0.999))
criterion_MSE = nn.MSELoss().to(device)
# train loop
epoch = 5
for e in range(epoch):print(f'Starting epoch {e} of {epoch}')for X, y in tqdm(trainloader):optimizer.zero_grad()X = X.to(device)reconstruction = net(X)loss = criterion_MSE(reconstruction, X)loss.backward()optimizer.step()print(f'Loss: {loss.item()}')
# test loop
i = 1
fig = plt.figure(figsize=(10, 10))
for X, y in testloader:X_in = X.to(device)recon = net(X_in).detach().cpu().numpy()if i >= 10:breakfig.add_subplot(5, 2, i).set_title('Original')plt.imshow(X[0].reshape((28, 28)), cmap="gray")fig.add_subplot(5, 2, i+1).set_title('Reconstruction')plt.imshow(recon[0].reshape((28, 28)), cmap="gray")i += 2
fig.tight_layout()
plt.show()

原文链接:梯度反传直通图解 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/31116.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL:表的增删查改

文章目录 1.Create(创建)2.Retrieve(读取、查询)2.1 SELECT 列2.2 WHERE 子句2.3 结果排序(order by)2.4 筛选分页结果(limit、offset)2.5 Update更新2.6 Delete删除2.7 去重 3.聚合函数3.1 聚合函数的基本使用3.2group by子句的使用(分组查询) 增删查改:: Create(创…

Tailwindcss 扩展默认配置来自定义颜色

背景 项目里多个Tab标签都需要设置同样的背景颜色#F1F5FF,在集成tailwindcss之前就是重复该样式,如下图: .body {background-color: #f1f5ff; }集成tailwindcss时,我们希望在class中直接设置该背景色,但是默认的tai…

Windows11平台C++在VS2022中安装和使用Matplot++绘图库的时候出现的问题和解决方法

Matplot 是一个基于 C 的绘图库,专门用于绘制高质量的数据图表。它提供了一个简洁而强大的接口,使得用户能够轻松地创建各种类型的图表,包括线图、散点图、柱状图、饼图等。Matplot 的设计目标是提供与 MATLAB 相似的绘图体验,同时…

在编译内核时添加驱动的固件

最近调驱动时,无法正常加载引导。 使用的内核5.10 内核启动先于文件系统,内核启动时驱动无法访问固件文件,所以无法加载驱动。 有2个办法,可以解决,一是驱动编译KO模块,系统启动后,再动态加载…

Spring Boot 3 整合 SpringDoc OpenAPI 生成接口文档

😄 19年之后由于某些原因断更了三年,23年重新扬帆起航,推出更多优质博文,希望大家多多支持~ 🌷 古之立大事者,不惟有超世之才,亦必有坚忍不拔之志 🎐 个人CSND主页——Mi…

Unity 工具 之 Azure 微软 【GPT4o】HttpClient 异步流式请求的简单封装

Unity 工具 之 Azure 微软 【GPT4o】HttpClient 异步流式请求的简单封装 目录 Unity 工具 之 Azure 微软 【GPT4o】HttpClient 异步流式请求的简单封装 一、简单介绍 二、实现原理 三、注意实现 四、简单效果预览 五、案例简单实现步骤 六、关键代码 一、简单介绍 Unit…

网络安全:Web 安全 面试题.(XSS)

网络安全:Web 安全 面试题.(XSS) 网络安全面试是指在招聘过程中,面试官会针对应聘者的网络安全相关知识和技能进行评估和考察。这种面试通常包括以下几个方面: (1)基础知识:包括网络基础知识、操作系统知…

超级好用的JSON格式化可视化在线工具

JSON是开发非常常用的一种报文格式,最常见的需求就是将JSON进行格式化,最好是有图形化界面显示结构关系,以便进行数据分析。 理想的在线JSON工具,应该支持快速格式化、可压缩、快捷复制、可下载导出,对存在语法错误的地…

网络与协议安全复习 - 系统安全部分

文章目录 恶意软件什么是恶意软件传播机制和载荷传播载荷 DDoS 攻击和防范 防火墙什么是防火墙防火墙类型防火墙载体 入侵检测入侵者入侵检测蜜罐技术 口令管理基于Bloom过滤器的口令检查技术 恶意软件 什么是恶意软件 恶意软件定义为:隐蔽植入另一段程序的程序&a…

Node.js单点登录SSO详解:Session、JWT、CORS让登录更简单

文章目录 一、SSO介绍1、使用SSO的好处 二、中间件介绍1、Express安装导入使用 2、cors安装导入配置 3、express-session安装导入配置使用 4、jsonwebtoken安装导入使用 5、jwt和session对比 三、SSO实现方案1、安装依赖2、结构3、实现原理 三、示例代码1、nodejs端 server/ind…

React是怎么进行事件处理的

什么是事件? 事件是指一些可以通过脚本响应的页面动作。当用户按下鼠标或者提交一个表单等等时候,事件都会出现。事件处理是一段JavaScript代码,总是与页面中的特定部分以及一定的事件相关联。当与页面特定部分相关联的事件发生时&#xff0c…

MDK-ARM 编译后 MAP 文件分析

本文配合 STM32 堆栈空间分布 食用更佳! 一图胜千言。。。

pytorch十大核心操作

PyTorch的十大核心操作涵盖了张量创建、数据转换、操作变换等多个方面。以下是结合参考文章信息整理出的PyTorch十大核心操作的概述: 张量创建: 从Python列表或NumPy数组创建张量。使用特定值创建张量,如全零、全一、指定范围、均匀分布、正…

开发环境安装---Visual Studio Code

开发环境安装---Visual Studio Code 1.官网下载Visual Studio Code2.安装步骤3.安装插件 1.官网下载Visual Studio Code VScode: https://code.visualstudio.com/ Visual Studio Code 简称 VSCode ,2015 年由微软公司发布。可用于 Windows,macOS 和 Li…

HTML(17)——圆角和盒子阴影

盒子模型——圆角 作用:设置元素的外边框为圆角 属性名:border-radius 属性值:数字px/百分比 也可以每个角设置不同的效果,从左上角顺时针开始赋值,没有取值的角与对角取值相同。 正圆 给正方形盒子设置圆角属性…

数据库实战(二)(引言+关系代数)

🌈 个人主页:十二月的猫-CSDN博客 🔥 系列专栏: 🏀数据库 💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 目录 前言 常见概念 一、什么是数据库&#xf…

【鸿蒙】HUAWEI DevEco Studio安装

HUAWEI DevEco Studio介绍 面向HarmonyOS应用及元服务开发者提供的集成开发环境(IDE), 助力高效开发。 DevEco Studio当前最新版本是: 3.1。 DevEco Studio计划里程碑 版本类型说明 下载 下载网址:DevEco Studio安装包官⽅下载 双击运行…

C++ | Leetcode C++题解之第169题多数元素

题目&#xff1a; 题解&#xff1a; class Solution { public:int majorityElement(vector<int>& nums) {int candidate -1;int count 0;for (int num : nums) {if (num candidate)count;else if (--count < 0) {candidate num;count 1;}}return candidate;…

STM32通过Flymcu串口下载程序

文章目录 1. Flymcu 2. 操作流程 2.1 设备准备 2.2 硬件连接 2.3 设置BOOT引脚 2.4 配置 2.5 下载程序 1. Flymcu Flymcu软件可以通过串口给STM32下载程序&#xff0c;如果没有STLINK的时候&#xff0c;就可以使用这个来烧录程序。软件不用安装&#xff0c;直接打开就行…

Windows11+CUDA12.0+RTX4090如何配置安装Tensorflow2-GPU环境?

1 引言 电脑配置 Windows 11 cuda 12.0 RTX4090 由于tensorflow2官网已经不支持cuda11以上的版本了&#xff0c;配置cuda和tensorflow可以通过以下步骤配置实现。 2 步骤 &#xff08;1&#xff09;创建conda环境并安装cuda和cudnn&#xff0c;以及安装tensorflow2.10 con…