Pytorch 自定义激活函数前向与反向传播 Tanh

看完这篇,你基本上可以自定义前向与反向传播,可以自己定义自己的算子

文章目录

    • Tanh
      • 公式
      • 求导过程
      • 优点:
      • 缺点:
      • 自定义Tanh
      • 与Torch定义的比较
      • 可视化

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F%matplotlib inlineplt.rcParams['figure.figsize'] = (7, 3.5)
plt.rcParams['figure.dpi'] = 150
plt.rcParams['axes.unicode_minus'] = False  #解决坐标轴负数的铅显示问题

Tanh

公式

tanh⁡(x)=sinh⁡(x)cosh⁡(x)=ex−e−xex+e−x\tanh(x) = \frac{\sinh(x)}{\cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}tanh(x)=cosh(x)sinh(x)=ex+exexex

tanh⁡(x)=2σ(2x)−1\tanh(x) = 2 \sigma(2x) - 1 tanh(x)=2σ(2x)1

求导过程

tanh⁡′(x)=(ex−e−xex+e−x)′=[(ex−e−x)(ex+e−x)−1]′=(ex+e−x)(ex+e−x)−1+(ex−e−x)(−1)(ex+e−x)−2(ex−e−x)=1−(ex−e−x)2(ex+e−x)−2=1−(ex−e−x)2(ex+e−x)2=1−tanh⁡2(x)\begin{aligned} \tanh'(x) =& \big(\frac{e^x - e^{-x}}{e^x + e^{-x}}\big)' \\ =& \big[(e^x - e^{-x})(e^x + e^{-x})^{-1}\big]' \\ =& (e^x + e^{-x})(e^x + e^{-x})^{-1} + (e^x - e^{-x})(-1)(e^x + e^{-x})^{-2} (e^x - e^{-x}) \\ =& 1-(e^x - e^{-x})^2(e^x + e^{-x})^{-2} \\ =& 1 - \frac{(e^x - e^{-x})^2}{(e^x + e^{-x})^2} \\ =& 1- \tanh^2(x) \\ \end{aligned}tanh(x)======(ex+exexex)[(exex)(ex+ex)1](ex+ex)(ex+ex)1+(exex)(1)(ex+ex)2(exex)1(exex)2(ex+ex)21(ex+ex)2(exex)21tanh2(x)

优点:

Tanh也称为双切正切函数,取值范围为[-1,1]。tanh在特征相差明显时的效果会很好,在循环过程中会不断扩大特征效果。与 sigmoid 的区别是,tanh 是 0 均值的,因此实际应用中 tanh 会比 sigmoid 更好。文献 [LeCun, Y., et al., Backpropagation applied to handwritten zip code recognition. Neural computation, 1989. 1(4): p. 541-551.] 中提到tanh 网络的收敛速度要比sigmoid快,因为tanh 的输出均值比 sigmoid 更接近 0,SGD会更接近 natural gradient[4](一种二次优化技术),从而降低所需的迭代次数。非常优秀,几乎适合所有的场景

缺点:

  • 该导数在正负饱和区的梯度都会接近于0值,会造成梯度消失。还有其更复杂的幂运算。

自定义Tanh

class SelfDefinedTanh(torch.autograd.Function):@staticmethoddef forward(ctx, inp):exp_x = torch.exp(inp)exp_x_ = torch.exp(-inp)result = torch.divide((exp_x - exp_x_), (exp_x + exp_x_))ctx.save_for_backward(result)return result@staticmethoddef backward(ctx, grad_output):# ctx.saved_tensors is tuple (tensors, grad_fn)result, = ctx.saved_tensorsreturn grad_output * (1 - result.pow(2))class Tanh(nn.Module):def __init__(self):super().__init__()def forward(self, x):out = SelfDefinedTanh.apply(x)return out
def tanh_sigmoid(x):"""according to the equation"""# 2 * torch.sigmoid(2 * x) -1 return torch.mul(torch.sigmoid(torch.mul(x, 2)), 2) - 1

与Torch定义的比较

# self defined
torch.manual_seed(0)tanh = Tanh()  # SelfDefinedTanh
inp = torch.randn(5, requires_grad=True)
out = tanh((inp + 1).pow(2))print(f'Out is\n{out}')out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nFirst call\n{inp.grad}")out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
Out is
tensor([1.0000, 0.4615, 0.8831, 0.9855, 0.0071],grad_fn=<SelfDefinedTanhBackward>)First call
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])Second call
tensor([ 1.0178e-04,  2.2243e+00, -1.0382e+00,  1.8053e-01, -3.3807e-01])Call after zeroing gradients
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])
# self defined tanh_sigmoid
torch.manual_seed(0)inp = torch.randn(5, requires_grad=True)
out = tanh_sigmoid((inp + 1).pow(2))print(f'Out is\n{out}')out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nFirst call\n{inp.grad}")out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
Out is
tensor([1.0000, 0.4615, 0.8831, 0.9855, 0.0071], grad_fn=<SubBackward0>)First call
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])Second call
tensor([ 1.0178e-04,  2.2243e+00, -1.0382e+00,  1.8053e-01, -3.3807e-01])Call after zeroing gradients
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])
# torch defined
torch.manual_seed(0)inp = torch.randn(5, requires_grad=True)
out = torch.tanh((inp + 1).pow(2))print(f'Out is\n{out}')out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nFirst call\n{inp.grad}")out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
Out is
tensor([1.0000, 0.4615, 0.8831, 0.9855, 0.0071], grad_fn=<TanhBackward>)First call
tensor([ 5.0283e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])Second call
tensor([ 1.0057e-04,  2.2243e+00, -1.0382e+00,  1.8053e-01, -3.3807e-01])Call after zeroing gradients
tensor([ 5.0283e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])

从上3个结果,可以看出,不管是经过sigmoid来计算,还是公式定义都可以得到一样的output与gradient。但在输入的值较大时,torch应该是减去一个小值,使得梯度更小。

可视化

# visualization
inp = torch.arange(-8, 8, 0.1, requires_grad=True)
out = tanh(inp)
out.sum().backward()inp_grad = inp.gradplt.plot(inp.detach().numpy(),out.detach().numpy(),label=r"$\tanh(x)$",alpha=0.7)
plt.plot(inp.detach().numpy(),inp_grad.numpy(),label=r"$\tanh'(x)$",alpha=0.5)
plt.grid()
plt.legend()
plt.show()

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/260477.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HDU ACM 1181 变形课 (广搜BFS + 动态数组vector)-------第一次使用动态数组vector

http://acm.hdu.edu.cn/showproblem.php?pid1181 题意&#xff1a;给我若干个单词,若单词A的结尾与单词B的开头相同,则表示A能变成B,判断能不能从b开头变成m结尾. 如: big-got-them 第一次使用动态数组vector View Code 1 #include <iostream>2 #include <vector>…

Max Sum 杭电 1003

2019独角兽企业重金招聘Python工程师标准>>> #题目概述 题目的意思是给你一个数列&#xff0c;找到一个子数列&#xff0c;这个子数列的和是所有子数列中和最大的。 当然把数列的所有数都列出来肯定不现实。 黑黑&#xff0c;不知道正不正确&#xff0c;我是先从第一…

shiro反序列化工具_Apache Shiro 1.2.4反序列化漏洞(CVE-2016-4437)源码解析

Apache ShiroApache Shiro是一个功能强大且灵活的开源安全框架,主要功能包括用户认证、授权、会话管理以及加密。在了解该漏洞之前,建议学习下Apache Shiro是怎么使用.debug环境jdk1.8Apache Shiro 1.2.4测试demo本地debug需要以下maven依赖<!-- https://mvnrepository.com/…

window 下的mysql_Windows下MySQL下载安装、配置与使用

用过MySQL之后&#xff0c;不论容量的话&#xff0c;发现比其他两个(sql server 、oracle)好用的多&#xff0c;一下子就喜欢上了。下面给那些还不知道怎么弄的童鞋们写下具体的方法步骤。(我这个写得有点太详细了&#xff0c;甚至有些繁琐&#xff0c;有很多步骤在其他的教程文…

Pytorch 自定义激活函数前向与反向传播 ReLu系列 含优点与缺点

文章目录ReLu公式求导过程优点&#xff1a;缺点&#xff1a;自定义ReLu与Torch定义的比较可视化Leaky ReLu PReLu公式求导过程优点&#xff1a;缺点&#xff1a;自定义LeakyReLu与Torch定义的比较可视化自定义PReLuELU公式求导过程优点缺点自定义LeakyReLu与Torch定义的比较可视…

mybatis select count(*) 一直返回0 mysql_Mybatis教程1:MyBatis快速入门

点击上方“Java技术前线”&#xff0c;选择“置顶或者星标”与你一起成长一、Mybatis介绍MyBatis是一个支持普通*SQL*查询&#xff0c;存储过程和高级映射的优秀持久层框架。MyBatis消除了几乎所有的JDBC代码和参数的手工设置以及对结果集的检索封装。MyBatis可以使用简单的XML…

css预处理器sass使用教程(多图预警)

css预处理器赋予了css动态语言的特性&#xff0c;如变量、函数、运算、继承、嵌套等&#xff0c;有助于更好地组织管理样式文件&#xff0c;以及更高效地开发项目。css预处理器可以更方便的维护和管理css代码&#xff0c;让整个网页变得更加灵活可变。对于预处理器&#xff0c;…

Sharepoint学习笔记—Site Definition系列-- 2、创建Content Type

Sharepoint本身就是一个丰富的大容器&#xff0c;里面存储的所有信息我们可以称其为“内容(Content)”&#xff0c;为了便于管理这些Conent&#xff0c;按照人类的正常逻辑就必然想到的是对此进行“分类”。分类所涉及到的层面又必然包括: 1、分类的标准或特征描述{即&#xf…

arduino byte转string_Java数组转List集合的三驾马车

点击上方 蓝字关注我们来源&#xff1a;blog.csdn.net/x541211190/article/details/79597236前言本文中的代码命名有的可能不太规范&#xff0c;是因为没法排版的问题&#xff0c;小仙已经很努力去解决了&#xff0c;希望各位能多多点赞、分享。好了&#xff0c;不多bb了(不要让…

ES6笔记(4)-- Symbol类型

系列文章 -- ES6笔记系列 Symbol是什么&#xff1f;中文意思是标志、记号&#xff0c;顾名思义&#xff0c;它可以用了做记号。 是的&#xff0c;它是一种标记的方法&#xff0c;被ES6引入作为一种新的数据类型&#xff0c;表示独一无二的值。 由此&#xff0c;JS的数据类型多了…

手把手教你如下在Linux下如何写一个C语言代码,编译并运行

文章目录手把手教你如下在Linux下如何写一个C语言代码&#xff0c;编译并运行打开Ubuntu终端创建 helloworld.c编译C文件手把手教你如下在Linux下如何写一个C语言代码&#xff0c;编译并运行 打开Ubuntu终端 我这里的终端是Windows下的WSL&#xff0c;如果有疑问&#xff0c;…

邮件群发工具的编写(二)数据的保存

数据的保存与读取 人类是在不断探索与改进中进步的 上一篇&#xff0c;邮件群发工具的编写&#xff08;一&#xff09;邮件地址提取&#xff0c;我们讲到了邮箱的提取。 那么这一篇&#xff0c;讲一下提取完的邮箱信息的保存和读取。 首先&#xff0c;我希望对上一篇邮箱提取类…

c++ lambda函数_C++11 之 lambda函数的详细使用

1. lambda 函数概述lambda 表达式是一种匿名函数&#xff0c;即没有函数名的函数&#xff1b;该匿名函数是由数学中的λ演算而来的。通常情况下&#xff0c;lambda函数的语法定义为&#xff1a;[capture] (parameters) mutable ->return-type {statement}其中&#xff1a;[c…

pytorch 正向与反向传播的过程 获取模型的梯度(gradient),并绘制梯度的直方图

记录一下怎样pytorch框架下怎样获得模型的梯度 文章目录引入所需要的库一个简单的函数模型梯度获取先定义一个model如下定义两个获取梯度的函数定义一些过程与调用上述函数的方法可视化一下梯度的histogram引入所需要的库 import os import torch import torch.nn as nn impor…

ubuntu升级python_Ubuntu 升级python3为更高版本【已实测】

2020-04-13 更新安装步骤&#xff1a; 1. 先update一下 sudo apt update 2. 安装依赖库 sudo apt-get install zlib1g-dev libbz2-dev libssl-dev libncurses5-dev libsqlite3-dev libreadline-dev tk-dev libgdbm-dev libdb-dev libpcap-dev xz-utils libexpat1-dev liblzma-d…

Framework打包

2019独角兽企业重金招聘Python工程师标准>>> iOS app需要在许多不同的CPU架构下运行&#xff1a; arm7: 在最老的支持iOS7的设备上使用 arm7s: 在iPhone5和5C上使用 arm64: 运行于iPhone5S的64位 ARM 处理器 上 i386: 32位模拟器上使用 x86_64: 64为模拟器上使用…

windows 10 下利用WSL的Linux环境实现vscode C/C++环境的配置

本文主要结合二个工具&#xff0c;介绍如何在windows搭建Linux开发环境&#xff1a; WSL(Windows Subsystem for Linux)VSCode(Visual Studio Code) 文章目录WSL安装VSCode安装配置Linux下的C/C环境1. 打开WSL的控制台2. 更新ubuntu软件3. 安装GCC和GDB4. 配置VSCode(1). 打开…

Windows 11下 WSL使用 jupyter notebook

这里写目录标题前言在WSL下的配置测试运行更优雅的启动方法配置jupyter生成默认配置文件生成秘钥修改配置文件nohup启动前言 一直都使用jupyter notebook&#xff0c;不管做数据分析&#xff0c;还是调试代码&#xff0c;还有写文章都是。但是好像在WSL下又不好使。看了网上有…

sql2000导出mysql_如何将sql2000的数据库导入到mysql中?

展开全部先用SQl2000导出e68a843231313335323631343130323136353331333262373366文本文件&#xff0c;把后缀名改为CSv&#xff0c;再从Mysql中一导入OK参考&#xff1a;第一种是安装mysql ODBC&#xff0c;利用sql server的导出功能&#xff0c;选择mysql数据源&#xff0c;进…

实现日、周、月排行统计 sql

在如今很多系统中&#xff0c;都需要进行日、周、月排行统计&#xff0c;但是在网上寻找 了一番&#xff0c;发现很多都是相对的周、月排行&#xff0c;即周排行则用当前时间减去7天。这样我个人认为并不恰当。如月排行中&#xff0c;假设今天是4月22日,则从3月22日至4月22日之…