Pytorch 自定义激活函数前向与反向传播 Tanh

看完这篇,你基本上可以自定义前向与反向传播,可以自己定义自己的算子

文章目录

    • Tanh
      • 公式
      • 求导过程
      • 优点:
      • 缺点:
      • 自定义Tanh
      • 与Torch定义的比较
      • 可视化

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F%matplotlib inlineplt.rcParams['figure.figsize'] = (7, 3.5)
plt.rcParams['figure.dpi'] = 150
plt.rcParams['axes.unicode_minus'] = False  #解决坐标轴负数的铅显示问题

Tanh

公式

tanh⁡(x)=sinh⁡(x)cosh⁡(x)=ex−e−xex+e−x\tanh(x) = \frac{\sinh(x)}{\cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}tanh(x)=cosh(x)sinh(x)=ex+exexex

tanh⁡(x)=2σ(2x)−1\tanh(x) = 2 \sigma(2x) - 1 tanh(x)=2σ(2x)1

求导过程

tanh⁡′(x)=(ex−e−xex+e−x)′=[(ex−e−x)(ex+e−x)−1]′=(ex+e−x)(ex+e−x)−1+(ex−e−x)(−1)(ex+e−x)−2(ex−e−x)=1−(ex−e−x)2(ex+e−x)−2=1−(ex−e−x)2(ex+e−x)2=1−tanh⁡2(x)\begin{aligned} \tanh'(x) =& \big(\frac{e^x - e^{-x}}{e^x + e^{-x}}\big)' \\ =& \big[(e^x - e^{-x})(e^x + e^{-x})^{-1}\big]' \\ =& (e^x + e^{-x})(e^x + e^{-x})^{-1} + (e^x - e^{-x})(-1)(e^x + e^{-x})^{-2} (e^x - e^{-x}) \\ =& 1-(e^x - e^{-x})^2(e^x + e^{-x})^{-2} \\ =& 1 - \frac{(e^x - e^{-x})^2}{(e^x + e^{-x})^2} \\ =& 1- \tanh^2(x) \\ \end{aligned}tanh(x)======(ex+exexex)[(exex)(ex+ex)1](ex+ex)(ex+ex)1+(exex)(1)(ex+ex)2(exex)1(exex)2(ex+ex)21(ex+ex)2(exex)21tanh2(x)

优点:

Tanh也称为双切正切函数,取值范围为[-1,1]。tanh在特征相差明显时的效果会很好,在循环过程中会不断扩大特征效果。与 sigmoid 的区别是,tanh 是 0 均值的,因此实际应用中 tanh 会比 sigmoid 更好。文献 [LeCun, Y., et al., Backpropagation applied to handwritten zip code recognition. Neural computation, 1989. 1(4): p. 541-551.] 中提到tanh 网络的收敛速度要比sigmoid快,因为tanh 的输出均值比 sigmoid 更接近 0,SGD会更接近 natural gradient[4](一种二次优化技术),从而降低所需的迭代次数。非常优秀,几乎适合所有的场景

缺点:

  • 该导数在正负饱和区的梯度都会接近于0值,会造成梯度消失。还有其更复杂的幂运算。

自定义Tanh

class SelfDefinedTanh(torch.autograd.Function):@staticmethoddef forward(ctx, inp):exp_x = torch.exp(inp)exp_x_ = torch.exp(-inp)result = torch.divide((exp_x - exp_x_), (exp_x + exp_x_))ctx.save_for_backward(result)return result@staticmethoddef backward(ctx, grad_output):# ctx.saved_tensors is tuple (tensors, grad_fn)result, = ctx.saved_tensorsreturn grad_output * (1 - result.pow(2))class Tanh(nn.Module):def __init__(self):super().__init__()def forward(self, x):out = SelfDefinedTanh.apply(x)return out
def tanh_sigmoid(x):"""according to the equation"""# 2 * torch.sigmoid(2 * x) -1 return torch.mul(torch.sigmoid(torch.mul(x, 2)), 2) - 1

与Torch定义的比较

# self defined
torch.manual_seed(0)tanh = Tanh()  # SelfDefinedTanh
inp = torch.randn(5, requires_grad=True)
out = tanh((inp + 1).pow(2))print(f'Out is\n{out}')out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nFirst call\n{inp.grad}")out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
Out is
tensor([1.0000, 0.4615, 0.8831, 0.9855, 0.0071],grad_fn=<SelfDefinedTanhBackward>)First call
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])Second call
tensor([ 1.0178e-04,  2.2243e+00, -1.0382e+00,  1.8053e-01, -3.3807e-01])Call after zeroing gradients
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])
# self defined tanh_sigmoid
torch.manual_seed(0)inp = torch.randn(5, requires_grad=True)
out = tanh_sigmoid((inp + 1).pow(2))print(f'Out is\n{out}')out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nFirst call\n{inp.grad}")out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
Out is
tensor([1.0000, 0.4615, 0.8831, 0.9855, 0.0071], grad_fn=<SubBackward0>)First call
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])Second call
tensor([ 1.0178e-04,  2.2243e+00, -1.0382e+00,  1.8053e-01, -3.3807e-01])Call after zeroing gradients
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])
# torch defined
torch.manual_seed(0)inp = torch.randn(5, requires_grad=True)
out = torch.tanh((inp + 1).pow(2))print(f'Out is\n{out}')out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nFirst call\n{inp.grad}")out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
Out is
tensor([1.0000, 0.4615, 0.8831, 0.9855, 0.0071], grad_fn=<TanhBackward>)First call
tensor([ 5.0283e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])Second call
tensor([ 1.0057e-04,  2.2243e+00, -1.0382e+00,  1.8053e-01, -3.3807e-01])Call after zeroing gradients
tensor([ 5.0283e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])

从上3个结果,可以看出,不管是经过sigmoid来计算,还是公式定义都可以得到一样的output与gradient。但在输入的值较大时,torch应该是减去一个小值,使得梯度更小。

可视化

# visualization
inp = torch.arange(-8, 8, 0.1, requires_grad=True)
out = tanh(inp)
out.sum().backward()inp_grad = inp.gradplt.plot(inp.detach().numpy(),out.detach().numpy(),label=r"$\tanh(x)$",alpha=0.7)
plt.plot(inp.detach().numpy(),inp_grad.numpy(),label=r"$\tanh'(x)$",alpha=0.5)
plt.grid()
plt.legend()
plt.show()

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/260477.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

multi mysql_mysqld_multi 的使用方法

mysqld_multi 的使用方法:官方文档&#xff1a;https://dev.mysql.com/doc/refman/5.7/en/mysqld-multi.html 【文档有些问题&#xff0c;按照它的这个配置&#xff0c;mysqld_multi无法关闭实例】mysqld_multi无法关闭实例的解决方法&#xff1a;https://bugs.mysql.com/bug…

vsftp 无法启动,500 OOPS: bad bool value in config file for: anonymous_enable

朋友的FTP启动不了&#xff0c;叫我帮他看&#xff0c;启动时出现以下错误信息&#xff1a; 500 OOPS: bad bool value in config file for: anonymous_enable 看似配置文件错误&#xff0c;看了一下配置相应的行&#xff1a; anonymous_enableNO 语句没什么错误&#xff0c;不…

HDU ACM 1181 变形课 (广搜BFS + 动态数组vector)-------第一次使用动态数组vector

http://acm.hdu.edu.cn/showproblem.php?pid1181 题意&#xff1a;给我若干个单词,若单词A的结尾与单词B的开头相同,则表示A能变成B,判断能不能从b开头变成m结尾. 如: big-got-them 第一次使用动态数组vector View Code 1 #include <iostream>2 #include <vector>…

Max Sum 杭电 1003

2019独角兽企业重金招聘Python工程师标准>>> #题目概述 题目的意思是给你一个数列&#xff0c;找到一个子数列&#xff0c;这个子数列的和是所有子数列中和最大的。 当然把数列的所有数都列出来肯定不现实。 黑黑&#xff0c;不知道正不正确&#xff0c;我是先从第一…

shiro反序列化工具_Apache Shiro 1.2.4反序列化漏洞(CVE-2016-4437)源码解析

Apache ShiroApache Shiro是一个功能强大且灵活的开源安全框架,主要功能包括用户认证、授权、会话管理以及加密。在了解该漏洞之前,建议学习下Apache Shiro是怎么使用.debug环境jdk1.8Apache Shiro 1.2.4测试demo本地debug需要以下maven依赖<!-- https://mvnrepository.com/…

window 下的mysql_Windows下MySQL下载安装、配置与使用

用过MySQL之后&#xff0c;不论容量的话&#xff0c;发现比其他两个(sql server 、oracle)好用的多&#xff0c;一下子就喜欢上了。下面给那些还不知道怎么弄的童鞋们写下具体的方法步骤。(我这个写得有点太详细了&#xff0c;甚至有些繁琐&#xff0c;有很多步骤在其他的教程文…

H264视频通过RTMP直播

http://blog.csdn.net/firehood_/article/details/8783589 前面的文章中提到了通过RTSP&#xff08;Real Time Streaming Protocol&#xff09;的方式来实现视频的直播&#xff0c;但RTSP方式的一个弊端是如果需要支持客户端通过网页来访问&#xff0c;就需要在在页面中嵌入一个…

Pytorch 自定义激活函数前向与反向传播 ReLu系列 含优点与缺点

文章目录ReLu公式求导过程优点&#xff1a;缺点&#xff1a;自定义ReLu与Torch定义的比较可视化Leaky ReLu PReLu公式求导过程优点&#xff1a;缺点&#xff1a;自定义LeakyReLu与Torch定义的比较可视化自定义PReLuELU公式求导过程优点缺点自定义LeakyReLu与Torch定义的比较可视…

手势处理

在ios开发中&#xff0c;需用到对于手指的不同操作&#xff0c;以手指点击为例&#xff1a;分为单指单击、单指多击、多指单击、多指多击。对于这些事件进行不同的操作处理&#xff0c;由于使用系统自带的方法通过判断touches不太容易处理&#xff0c;而且会有事件之间的冲突。…

mybatis select count(*) 一直返回0 mysql_Mybatis教程1:MyBatis快速入门

点击上方“Java技术前线”&#xff0c;选择“置顶或者星标”与你一起成长一、Mybatis介绍MyBatis是一个支持普通*SQL*查询&#xff0c;存储过程和高级映射的优秀持久层框架。MyBatis消除了几乎所有的JDBC代码和参数的手工设置以及对结果集的检索封装。MyBatis可以使用简单的XML…

css预处理器sass使用教程(多图预警)

css预处理器赋予了css动态语言的特性&#xff0c;如变量、函数、运算、继承、嵌套等&#xff0c;有助于更好地组织管理样式文件&#xff0c;以及更高效地开发项目。css预处理器可以更方便的维护和管理css代码&#xff0c;让整个网页变得更加灵活可变。对于预处理器&#xff0c;…

mysql 主从优点_MySql主从配置实践及其优势浅谈

1、增加两个MySQL,我将C:\xampp\mysql下的MYSQL复制了一份&#xff0c;放到D:\Mysql2\Mysql5.1修改my.ini(linux下应该是my.cnf)&#xff1a;[client]port 3307[mysqld]port 3307basedirD:/Mysql2/Mysql5.1/mysqldatadirD:/Mysql2/Mysql5.1/mysql/data/之后&#xff0c;再增加…

python 多线程并发编程(生产者、消费者模式),边读图像,边处理图像,处理完后保存图像实现提高处理效率

文章目录需求实现先导入本次需要用到的包一些辅助函数如下函数是得到指定后缀的文件如下的函数一个是读图像&#xff0c;一个是把RGB转成BGR下面是主要的几个处理函数在上面几个函数构建对应的处理函数main函数按顺序执行结果需求 本次的需求是边读图像&#xff0c;边处理图像…

Sharepoint学习笔记—Site Definition系列-- 2、创建Content Type

Sharepoint本身就是一个丰富的大容器&#xff0c;里面存储的所有信息我们可以称其为“内容(Content)”&#xff0c;为了便于管理这些Conent&#xff0c;按照人类的正常逻辑就必然想到的是对此进行“分类”。分类所涉及到的层面又必然包括: 1、分类的标准或特征描述{即&#xf…

arduino byte转string_Java数组转List集合的三驾马车

点击上方 蓝字关注我们来源&#xff1a;blog.csdn.net/x541211190/article/details/79597236前言本文中的代码命名有的可能不太规范&#xff0c;是因为没法排版的问题&#xff0c;小仙已经很努力去解决了&#xff0c;希望各位能多多点赞、分享。好了&#xff0c;不多bb了(不要让…

ES6笔记(4)-- Symbol类型

系列文章 -- ES6笔记系列 Symbol是什么&#xff1f;中文意思是标志、记号&#xff0c;顾名思义&#xff0c;它可以用了做记号。 是的&#xff0c;它是一种标记的方法&#xff0c;被ES6引入作为一种新的数据类型&#xff0c;表示独一无二的值。 由此&#xff0c;JS的数据类型多了…

mysql类型说明_MYSQL 数据类型说明

MySQL支持大量的列类型&#xff0c;它可以被分为3类&#xff1a;数字类型、日期和时间类型以及字符串(字符)类型。本节首先给出可用类型的一个概述&#xff0c;并且总结每个列类型的存储需求&#xff0c;然后提供每个类中的类型性质的更详细的描述。概述有意简化&#xff0c;更…

LeetCode OJ - Convert Sorted List to Binary Search Tree

题目&#xff1a; Given a singly linked list where elements are sorted in ascending order, convert it to a height balanced BST. 解题思路&#xff1a; 注意是让构造平衡二叉搜索树。 每次将链表从中间断开&#xff0c;分成左右两部分。左边部分用来构造左子树&#xff…

手把手教你如下在Linux下如何写一个C语言代码,编译并运行

文章目录手把手教你如下在Linux下如何写一个C语言代码&#xff0c;编译并运行打开Ubuntu终端创建 helloworld.c编译C文件手把手教你如下在Linux下如何写一个C语言代码&#xff0c;编译并运行 打开Ubuntu终端 我这里的终端是Windows下的WSL&#xff0c;如果有疑问&#xff0c;…

邮件群发工具的编写(二)数据的保存

数据的保存与读取 人类是在不断探索与改进中进步的 上一篇&#xff0c;邮件群发工具的编写&#xff08;一&#xff09;邮件地址提取&#xff0c;我们讲到了邮箱的提取。 那么这一篇&#xff0c;讲一下提取完的邮箱信息的保存和读取。 首先&#xff0c;我希望对上一篇邮箱提取类…