PyTorch神经网络-激励函数

在PyTorch 神经网络当中,使用激励函数处理非线性的问题,普通的神经网络出来的数据一般是线性的关系,但是遇到比较复杂的数据的话,需要激励函数处理一些比较难以处理的问题,非线性结果就是其中的情况之一。


FAQ:为什么要使用激励函数?

  • 在简易乃至复杂的神经网络当中,对于神经元的数据结果为线性的时候,可能就会导致后面预测无法达到预期的效果,线性可能是一个不具体的一个范围,所以需要加入激励函数(一般是非线性的)去将数据进行处理

在少量的神经网络当中,一般使用 RELU激励函数,而在比较复杂的神经网络(包含循环神经网络RNN)当中,一般会使用 RELU或者 TANH函数进行处理
在这里插入图片描述

激励函数 Activation

下图为 Kaggle 编写的四种激励函数,分别是(relu,sigmoid,tanh和softplus)
在这里插入图片描述
可以看到上面四个激励函数当x轴为输出的数据,可以理解成为神经网络的输出结果,然后将神经网络素输出结果加入激活函数,也就四上图四个激励函数对应的y轴数据,经过了激励函数处理之后,结果就有了限制,这个就是不同的激励函数带来的不同效果

程序源码:

import torch
import torch.nn.functional as Ffrom torch.autograd import Variable
import matplotlib.pyplot as plt# fake
x = torch.linspace(-10,10,200)
x = Variable(x)
x_np = x.data.numpy()  # change the value to tensory_relu = F.relu(x).data.numpy()
y_sigmoid = F.sigmoid(x).data.numpy()
y_tanh = F.tanh(x).data.numpy()
y_softplus = F.softplus(x).data.numpy()plt.figure(1,figsize=(8,6))
plt.subplot(221)
plt.plot(x_np,y_relu,c='red',label='relu')
plt.ylim((-1,11))
plt.legend(loc='best')plt.subplot(222)
plt.plot(x_np,y_sigmoid,c='green',label='sigmoid')
plt.ylim((-0.2,1.2))
plt.legend(loc='best')plt.subplot(223)
plt.plot(x_np,y_tanh,c='blue',label='y_tanh')
plt.ylim((-1.2,2.2))
plt.legend(loc='best')plt.subplot(224)
plt.plot(x_np,y_softplus,c='yellow',label='y_softplus')
plt.ylim((-0.2,11))
plt.legend(loc='best')

🐱神经网络浅试

当了解了神经网络相关的原理之后,可以尝试着结合激励函数进行一个简单的Demo编写
主要氛围以下几个步骤:
(1)创建数据集->(2)建立神经网络->(3)训练数据->预测(显示验证可选)

1. 创建数据集

建立一些数据,去模拟真实的情况比如一个一元二次函数: y = a * x^2 + b, 我们给 y 数据加上一点噪声来更加真实的展示它。

import torch
import matplotlib.pyplot as pltx = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())                 # noisy y data (tensor), shape=(100, 1)# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()
2. 建立一个神经网络:

我们可以直接运用 torch 中的体系. 先定义所有的层属性(init()), 然后再一层层搭建(forward(x))层于层的关系链接. 建立关系的时候, 我们会用到激励函数。

import torch
import torch.nn.functional as F     # 激励函数都在这class Net(torch.nn.Module):  # 继承 torch 的 Moduledef __init__(self, n_feature, n_hidden, n_output):super(Net, self).__init__()     # 继承 __init__ 功能# 定义每层用什么样的形式self.hidden = torch.nn.Linear(n_feature, n_hidden)   # 隐藏层线性输出self.predict = torch.nn.Linear(n_hidden, n_output)   # 输出层线性输出def forward(self, x):   # 这同时也是 Module 中的 forward 功能# 正向传播输入值, 神经网络分析出输出值x = F.relu(self.hidden(x))      # 激励函数(隐藏层的线性值)x = self.predict(x)             # 输出值return xnet = Net(n_feature=1, n_hidden=10, n_output=1)print(net)  # net 的结构
"""
Net ((hidden): Linear (1 -> 10)(predict): Linear (10 -> 1)
)
"""
3.训练网络

训练的步骤如下:

# optimizer 是训练的工具
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss()      # 预测值和真实值的误差计算公式 (均方差)for t in range(100):prediction = net(x)     # 喂给 net 训练数据 x, 输出预测值loss = loss_func(prediction, y)     # 计算两者的误差optimizer.zero_grad()   # 清空上一步的残余更新参数值loss.backward()         # 误差反向传播, 计算参数更新值optimizer.step()        # 将参数更新值施加到 net 的 parameters 上
4.可视化训练的过程
import matplotlib.pyplot as pltplt.ion()   # 画图
plt.show()for t in range(200):...loss.backward()optimizer.step()# 接着上面来if t % 5 == 0:# plot and show learning processplt.cla()plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color':  'red'})plt.pause(0.1)

结合上述步骤,总结出代码如下,在这里的程序更改的训练次数为 100次

import torch 
import torch.nn.functional as Ffrom torch.autograd import Variable
import matplotlib.pyplot as plt  # 数据可视化处理工具# 输出数据
x = torch.unsqueeze(torch.linspace(-1,1,100),dim  =1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # x 的二次方 + 一些随机噪点plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()# 定义神经网络
class Net(torch.nn.Module):def __init__(self,n_features,n_hidden,n_output):# 搭建层所需要的信息super(Net,self).__init__()  # 初始化函数继承self.hidden = torch.nn.Linear(n_features,n_hidden)  # 隐藏层包含了,少的输出和输出self.predict = torch.nn.Linear(n_hidden,1)  # 输出值为 1,只是一个值def forward(self,x):# 前一层的信息,也就是xx = F.relu(self.hidden(x))  # 激励函数激活x = self.predict(x)  # 输出xreturn x# Net(1,100,1)中的1表示输出数据为1个,100 为神经元个数,最后的1表示输出的数据,也是为1
net = Net(1,100,1)
print(net)# 优化
optimizer = torch.optim.SGD(net.parameters(),lr = 0.5) # lr < 1
loss_function = torch.nn.MSELoss()plt.ion()
plt.show()for t in range(100):prediction = net(x)loss = loss_function(prediction,y)optimizer.zero_grad()  # 将传进来的参数的地图设置为零loss.backward()  # 反向传递过程optimizer.step()  # 优化梯度if t % 5 == 0:plt.cla()plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)plt.text(0.5,0,'Loss%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'red'})plt.pause(0.1)
plt.ioff()
plt.show()

请添加图片描述

请添加图片描述

请添加图片描述
请添加图片描述请添加图片描述请添加图片描述
请添加图片描述

请添加图片描述请添加图片描述

请添加图片描述请添加图片描述

请添加图片描述

请添加图片描述
请添加图片描述

请添加图片描述

最后一张图为训练拟合的结果图,可以看到,红色的预测线将传入的随机点有了很接近的拟合,说明神经网络内的训练和优化有了很大的效果。


🌸🌸🌸完结撒花🌸🌸🌸


🌈🌈Redamancy🌈🌈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/151618.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Java 进阶篇】Ajax 实现——JQuery 实现方式 `get` 与 `post`

嗨&#xff0c;亲爱的小白们&#xff01;欢迎来到这篇关于使用 jQuery 实现 Ajax 请求的博客。在前端开发中&#xff0c;Ajax 是一项非常重要的技术&#xff0c;它使我们能够在不刷新整个页面的情况下与服务器进行数据交互。而在 jQuery 中&#xff0c;get 和 post 方法提供了简…

全流量分析应用运行和访问情况

在当今数字化时代&#xff0c;应用程序的运行和访问情况对于企业和组织来说至关重要。无论是在线销售平台、移动应用还是企业内部系统&#xff0c;应用的性能和可用性直接影响着用户体验、业务流程以及组织效率。因此&#xff0c;对应用的运行和访问情况进行全面分析和评估&…

JZM-D30室温探针台技术参数

概况&#xff1a; JZM-D30室温探针台的诸多设计都是专用的&#xff0c;探针台的配置主要是根据用户的需求进行选配及设计。例如&#xff0c;要求的磁场型号&#xff0c;电源型号&#xff0c;磁场值&#xff0c;样品台的尺寸等&#xff0c;除此之外&#xff0c;该探针台和我司自…

Go 语言中的map和内存泄漏

map在内存中总是会增长&#xff1b;它不会收缩。因此&#xff0c;如果map导致了一些内存问题&#xff0c;你可以尝试不同的选项&#xff0c;比如强制 Go 重新创建map或使用指针。 在 Go 中使用map时&#xff0c;我们需要了解map增长和收缩的一些重要特性。让我们深入探讨这一点…

架构开发与优化咨询和实施服务

服务概述 得益于硬件平台算力的提升&#xff0c;汽车电子电气架构的集成度逐渐提高&#xff0c;从单体ECU、到功能域集成控制器、到区域集成控制器&#xff0c;多域融合成为了目前行业中软件工程的重要工作内容。同时&#xff0c;在传统控制器C代码开发的基础上&#xff0c;C、…

手把手从零开始训练YOLOv8改进项目(官方ultralytics版本)教程

手把手从零开始训练 YOLOv8 改进项目 (Ultralytics版本) 教程,改进 YOLOv8 算法 本文以Windows服务器为例:从零开始使用Windows训练 YOLOv8 算法项目 《芒果 YOLOv8 目标检测算法 改进》 适用于芒果专栏改进 YOLOv8 算法 文章目录 官方 YOLOv8 算法介绍改进网络代码汇总第…

ROS参数服务器(Param):通信模型、Hello World与拓展

参数服务器在ROS中主要用于实现不同节点之间的数据共享。 参数服务器相当于是独立于所有节点的一个公共容器&#xff0c;可以将数据存储在该容器中&#xff0c;被不同的节点调用&#xff0c;当然不同的节点也可以往其中存储数据。 使用场景一般存储一些机器人的固有参数&…

AIGC 技术在淘淘秀场景的探索与实践

本文介绍了AIGC相关领域的爆发式增长&#xff0c;并探讨了淘宝秀秀(AI买家秀)的设计思路和技术方案。文章涵盖了图像生成、仿真形象生成和换背景方案&#xff0c;以及模型流程串联等关键技术。 文章还介绍了淘淘秀的使用流程和遇到的问题及处理方法。最后&#xff0c;文章展望…

安全项目简介

安全项目 基线检查 密码 复杂度有效期 用户访问和身份验证 禁用administrator禁用guest认证失败锁定 安全防护软件操作系统安全配置 关闭自动播放 文件和目录权限端口限制安全审计… 等保测评 是否举办了安全意识培训是否有应急响应预案有无第一负责人 工作内容 测评准备…

【VRTK】【VR开发】【Unity】7-配置交互能力和向量追踪

【前情提要】 目前为止,我们虽然设定了手模型和动画,还能够正确根据输入触发动作,不过还未能与任何物体互动。要互动,需要给手部设定相应的Interactor能力。 【配置Interactor的抓取功能】 在Hierarchy中选中[VRTK_CAMERA_RIGS_SETUP] ➤ Camera Rigs, Tracked Alias ➤ …

Attingo:西部数据部分SSD存在硬件设计制造缺陷

今年5月&#xff0c;西部数据SanDisk Extreme Pro硬盘陆续有用户反馈有故障发生&#xff0c;用户反馈最多的问题是数据丢失和硬件损坏。8月份&#xff0c;因为这个事情&#xff0c;还被爆出&#xff0c;西部数据面临用户的集体诉讼。 近期&#xff0c;有一个专门从事数据恢复的…

高防CDN的需求分析:社会与企业发展的推动力

在当今数字化飞速发展的时代&#xff0c;网络安全成为社会和企业发展的关键因素之一。随着网络攻击手段的不断升级&#xff0c;企业对于高防CDN&#xff08;内容分发网络&#xff09;的需求逐渐成为保障业务稳健运行的重要部分。从社会和企业发展的角度来看&#xff0c;高防CDN…

【Java 进阶篇】Ajax 实现——原生JS方式

大家好&#xff0c;欢迎来到这篇关于原生 JavaScript 中使用 Ajax 实现的博客&#xff01;在前端开发中&#xff0c;我们经常需要与服务器进行数据交互&#xff0c;而 Ajax&#xff08;Asynchronous JavaScript and XML&#xff09;是一种用于创建异步请求的技术&#xff0c;它…

Javaweb之Vue生命周期的详细解析

2.4 生命周期 vue的生命周期&#xff1a;指的是vue对象从创建到销毁的过程。vue的生命周期包含8个阶段&#xff1a;每触发一个生命周期事件&#xff0c;会自动执行一个生命周期方法&#xff0c;这些生命周期方法也被称为钩子方法。其完整的生命周期如下图所示&#xff1a; 状…

代码随想录算法训练营第四十九天| 123.买卖股票的最佳时机III 188.买卖股票的最佳时机IV

文档讲解&#xff1a;代码随想录 视频讲解&#xff1a;代码随想录B站账号 状态&#xff1a;看了视频题解和文章解析后做出来了 123.买卖股票的最佳时机III class Solution:def maxProfit(self, prices: List[int]) -> int:if len(prices) 0:return 0dp [[0] * 5 for _ in…

安装2023最新版PyCharm来开发Python应用程序

安装2023最新版PyCharm来开发Python应用程序 Install the Latest JetBrains PyCharm Community to Develop Python Applications Python 3.12.0最新版已经由其官网python.org发布&#xff0c;这也是2023年底的最新的版本。 0. PyCharm与Python 自从1991年2月20日&#xff0…

【Java】抽象类和接口

文章目录 一、抽象类1.抽象类的概念2.抽象类的语法3.抽象类的特性4.抽象类的作用 二、接口1.接口的概念2.语法规则3.接口的使用4.接口的特性5.实现多个接口6.接口间的继承7.接口的使用实例8.Clonable 接口和深拷贝9.抽象类和接口的区别 三、Object类1.获取对象信息2.对象的比较…

Python基础入门----如何通过conda搭建Python开发环境

文章目录 使用 conda 搭建Python开发环境是非常方便的,它可以帮助你管理Python版本、依赖库、虚拟环境等。以下是一个简单的步骤,演示如何通过 conda 搭建Python开发环境: 安装conda: 如果你还没有安装 conda,首先需要安装Anaconda或Miniconda。Anaconda是一个包含很多数据…

pythom导出mysql指定binlog文件

要求 要求本地有py环境和全局环境变量 先测试直接执行binlog命令执行命令 Windows 本地直接执行命令 # E:\output>E:\phpstudy_pro\Extensions\MySQL5.7.26\bin\mysqlbinlog binglog文件地址 # --no-defaults 不限制编码 # -h mysql链接地址 # -u mysql 链接名称 # -p m…