PyTorch神经网络-激励函数

在PyTorch 神经网络当中,使用激励函数处理非线性的问题,普通的神经网络出来的数据一般是线性的关系,但是遇到比较复杂的数据的话,需要激励函数处理一些比较难以处理的问题,非线性结果就是其中的情况之一。


FAQ:为什么要使用激励函数?

  • 在简易乃至复杂的神经网络当中,对于神经元的数据结果为线性的时候,可能就会导致后面预测无法达到预期的效果,线性可能是一个不具体的一个范围,所以需要加入激励函数(一般是非线性的)去将数据进行处理

在少量的神经网络当中,一般使用 RELU激励函数,而在比较复杂的神经网络(包含循环神经网络RNN)当中,一般会使用 RELU或者 TANH函数进行处理
在这里插入图片描述

激励函数 Activation

下图为 Kaggle 编写的四种激励函数,分别是(relu,sigmoid,tanh和softplus)
在这里插入图片描述
可以看到上面四个激励函数当x轴为输出的数据,可以理解成为神经网络的输出结果,然后将神经网络素输出结果加入激活函数,也就四上图四个激励函数对应的y轴数据,经过了激励函数处理之后,结果就有了限制,这个就是不同的激励函数带来的不同效果

程序源码:

import torch
import torch.nn.functional as Ffrom torch.autograd import Variable
import matplotlib.pyplot as plt# fake
x = torch.linspace(-10,10,200)
x = Variable(x)
x_np = x.data.numpy()  # change the value to tensory_relu = F.relu(x).data.numpy()
y_sigmoid = F.sigmoid(x).data.numpy()
y_tanh = F.tanh(x).data.numpy()
y_softplus = F.softplus(x).data.numpy()plt.figure(1,figsize=(8,6))
plt.subplot(221)
plt.plot(x_np,y_relu,c='red',label='relu')
plt.ylim((-1,11))
plt.legend(loc='best')plt.subplot(222)
plt.plot(x_np,y_sigmoid,c='green',label='sigmoid')
plt.ylim((-0.2,1.2))
plt.legend(loc='best')plt.subplot(223)
plt.plot(x_np,y_tanh,c='blue',label='y_tanh')
plt.ylim((-1.2,2.2))
plt.legend(loc='best')plt.subplot(224)
plt.plot(x_np,y_softplus,c='yellow',label='y_softplus')
plt.ylim((-0.2,11))
plt.legend(loc='best')

🐱神经网络浅试

当了解了神经网络相关的原理之后,可以尝试着结合激励函数进行一个简单的Demo编写
主要氛围以下几个步骤:
(1)创建数据集->(2)建立神经网络->(3)训练数据->预测(显示验证可选)

1. 创建数据集

建立一些数据,去模拟真实的情况比如一个一元二次函数: y = a * x^2 + b, 我们给 y 数据加上一点噪声来更加真实的展示它。

import torch
import matplotlib.pyplot as pltx = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())                 # noisy y data (tensor), shape=(100, 1)# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()
2. 建立一个神经网络:

我们可以直接运用 torch 中的体系. 先定义所有的层属性(init()), 然后再一层层搭建(forward(x))层于层的关系链接. 建立关系的时候, 我们会用到激励函数。

import torch
import torch.nn.functional as F     # 激励函数都在这class Net(torch.nn.Module):  # 继承 torch 的 Moduledef __init__(self, n_feature, n_hidden, n_output):super(Net, self).__init__()     # 继承 __init__ 功能# 定义每层用什么样的形式self.hidden = torch.nn.Linear(n_feature, n_hidden)   # 隐藏层线性输出self.predict = torch.nn.Linear(n_hidden, n_output)   # 输出层线性输出def forward(self, x):   # 这同时也是 Module 中的 forward 功能# 正向传播输入值, 神经网络分析出输出值x = F.relu(self.hidden(x))      # 激励函数(隐藏层的线性值)x = self.predict(x)             # 输出值return xnet = Net(n_feature=1, n_hidden=10, n_output=1)print(net)  # net 的结构
"""
Net ((hidden): Linear (1 -> 10)(predict): Linear (10 -> 1)
)
"""
3.训练网络

训练的步骤如下:

# optimizer 是训练的工具
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss()      # 预测值和真实值的误差计算公式 (均方差)for t in range(100):prediction = net(x)     # 喂给 net 训练数据 x, 输出预测值loss = loss_func(prediction, y)     # 计算两者的误差optimizer.zero_grad()   # 清空上一步的残余更新参数值loss.backward()         # 误差反向传播, 计算参数更新值optimizer.step()        # 将参数更新值施加到 net 的 parameters 上
4.可视化训练的过程
import matplotlib.pyplot as pltplt.ion()   # 画图
plt.show()for t in range(200):...loss.backward()optimizer.step()# 接着上面来if t % 5 == 0:# plot and show learning processplt.cla()plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color':  'red'})plt.pause(0.1)

结合上述步骤,总结出代码如下,在这里的程序更改的训练次数为 100次

import torch 
import torch.nn.functional as Ffrom torch.autograd import Variable
import matplotlib.pyplot as plt  # 数据可视化处理工具# 输出数据
x = torch.unsqueeze(torch.linspace(-1,1,100),dim  =1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # x 的二次方 + 一些随机噪点plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()# 定义神经网络
class Net(torch.nn.Module):def __init__(self,n_features,n_hidden,n_output):# 搭建层所需要的信息super(Net,self).__init__()  # 初始化函数继承self.hidden = torch.nn.Linear(n_features,n_hidden)  # 隐藏层包含了,少的输出和输出self.predict = torch.nn.Linear(n_hidden,1)  # 输出值为 1,只是一个值def forward(self,x):# 前一层的信息,也就是xx = F.relu(self.hidden(x))  # 激励函数激活x = self.predict(x)  # 输出xreturn x# Net(1,100,1)中的1表示输出数据为1个,100 为神经元个数,最后的1表示输出的数据,也是为1
net = Net(1,100,1)
print(net)# 优化
optimizer = torch.optim.SGD(net.parameters(),lr = 0.5) # lr < 1
loss_function = torch.nn.MSELoss()plt.ion()
plt.show()for t in range(100):prediction = net(x)loss = loss_function(prediction,y)optimizer.zero_grad()  # 将传进来的参数的地图设置为零loss.backward()  # 反向传递过程optimizer.step()  # 优化梯度if t % 5 == 0:plt.cla()plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)plt.text(0.5,0,'Loss%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'red'})plt.pause(0.1)
plt.ioff()
plt.show()

请添加图片描述

请添加图片描述

请添加图片描述
请添加图片描述请添加图片描述请添加图片描述
请添加图片描述

请添加图片描述请添加图片描述

请添加图片描述请添加图片描述

请添加图片描述

请添加图片描述
请添加图片描述

请添加图片描述

最后一张图为训练拟合的结果图,可以看到,红色的预测线将传入的随机点有了很接近的拟合,说明神经网络内的训练和优化有了很大的效果。


🌸🌸🌸完结撒花🌸🌸🌸


🌈🌈Redamancy🌈🌈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/151618.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C# 32应用程序获取64位操作系统注册表

若C#的程序都是32位的&#xff0c;访问注册表的时候&#xff0c;会访问HKEY_LOCAL_MACHINE\SOFTWARE\Wow6432Node\&#xff0c; 而访问不到HKEY_LOCAL_MACHINE\SOFTWARE 适用版本&#xff1a;.NET 4.0及更高版本 public static Dictionary<string, string> GetInstalled…

【Java 进阶篇】Ajax 实现——JQuery 实现方式 `get` 与 `post`

嗨&#xff0c;亲爱的小白们&#xff01;欢迎来到这篇关于使用 jQuery 实现 Ajax 请求的博客。在前端开发中&#xff0c;Ajax 是一项非常重要的技术&#xff0c;它使我们能够在不刷新整个页面的情况下与服务器进行数据交互。而在 jQuery 中&#xff0c;get 和 post 方法提供了简…

vue之Error: Unknown option: .devServer.

背景 在使用内网穿透工具时&#xff0c;加入对应的配置&#xff0c;启动出现报错。 一、遇到的问题 报错&#xff1a; Error: Unknown option: .devServer. Check out https://babeljs.io/docs/en/babel-core/#options for more information about options. Error: Unknown …

全流量分析应用运行和访问情况

在当今数字化时代&#xff0c;应用程序的运行和访问情况对于企业和组织来说至关重要。无论是在线销售平台、移动应用还是企业内部系统&#xff0c;应用的性能和可用性直接影响着用户体验、业务流程以及组织效率。因此&#xff0c;对应用的运行和访问情况进行全面分析和评估&…

JZM-D30室温探针台技术参数

概况&#xff1a; JZM-D30室温探针台的诸多设计都是专用的&#xff0c;探针台的配置主要是根据用户的需求进行选配及设计。例如&#xff0c;要求的磁场型号&#xff0c;电源型号&#xff0c;磁场值&#xff0c;样品台的尺寸等&#xff0c;除此之外&#xff0c;该探针台和我司自…

Go 语言中的map和内存泄漏

map在内存中总是会增长&#xff1b;它不会收缩。因此&#xff0c;如果map导致了一些内存问题&#xff0c;你可以尝试不同的选项&#xff0c;比如强制 Go 重新创建map或使用指针。 在 Go 中使用map时&#xff0c;我们需要了解map增长和收缩的一些重要特性。让我们深入探讨这一点…

架构开发与优化咨询和实施服务

服务概述 得益于硬件平台算力的提升&#xff0c;汽车电子电气架构的集成度逐渐提高&#xff0c;从单体ECU、到功能域集成控制器、到区域集成控制器&#xff0c;多域融合成为了目前行业中软件工程的重要工作内容。同时&#xff0c;在传统控制器C代码开发的基础上&#xff0c;C、…

手把手从零开始训练YOLOv8改进项目(官方ultralytics版本)教程

手把手从零开始训练 YOLOv8 改进项目 (Ultralytics版本) 教程,改进 YOLOv8 算法 本文以Windows服务器为例:从零开始使用Windows训练 YOLOv8 算法项目 《芒果 YOLOv8 目标检测算法 改进》 适用于芒果专栏改进 YOLOv8 算法 文章目录 官方 YOLOv8 算法介绍改进网络代码汇总第…

CISP模拟试题(一)

免责声明 文章仅做经验分享用途,利用本文章所提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,作者不为此承担任何责任,一旦造成后果请自行承担!!! 1.下面关于信息安全保障的说法错误的是:C A.信息安全保障的概念是与信息安全的概念同时产生的 …

ROS参数服务器(Param):通信模型、Hello World与拓展

参数服务器在ROS中主要用于实现不同节点之间的数据共享。 参数服务器相当于是独立于所有节点的一个公共容器&#xff0c;可以将数据存储在该容器中&#xff0c;被不同的节点调用&#xff0c;当然不同的节点也可以往其中存储数据。 使用场景一般存储一些机器人的固有参数&…

20、动态路由_下滑线为前缀的目录

创建文件 pages_question\index.vue pages_question\detail.vue 生成的对应路由&#xff1a; const _6bf6ece8 () > interopDefault(import(..\\pages\\_question\\index.vue /* webpackChunkName: "pages/_question/index" */)) const _a98c80aa () > in…

AIGC 技术在淘淘秀场景的探索与实践

本文介绍了AIGC相关领域的爆发式增长&#xff0c;并探讨了淘宝秀秀(AI买家秀)的设计思路和技术方案。文章涵盖了图像生成、仿真形象生成和换背景方案&#xff0c;以及模型流程串联等关键技术。 文章还介绍了淘淘秀的使用流程和遇到的问题及处理方法。最后&#xff0c;文章展望…

安全项目简介

安全项目 基线检查 密码 复杂度有效期 用户访问和身份验证 禁用administrator禁用guest认证失败锁定 安全防护软件操作系统安全配置 关闭自动播放 文件和目录权限端口限制安全审计… 等保测评 是否举办了安全意识培训是否有应急响应预案有无第一负责人 工作内容 测评准备…

Python实现精确控制asyncio并发过程中的多个任务(1)

前言 本文是该专栏的第37篇,后面会持续分享python的各种干货知识,值得关注。 asyncio是Python中并发编程的一种实现方式,它是Python3.4版本引入的标准库,直接内置了对异步IO的支持。异步,就是多个任务之间执行没有先后顺序,可以同时运行,执行的先后顺序不会有什么影响,…

【VRTK】【VR开发】【Unity】7-配置交互能力和向量追踪

【前情提要】 目前为止,我们虽然设定了手模型和动画,还能够正确根据输入触发动作,不过还未能与任何物体互动。要互动,需要给手部设定相应的Interactor能力。 【配置Interactor的抓取功能】 在Hierarchy中选中[VRTK_CAMERA_RIGS_SETUP] ➤ Camera Rigs, Tracked Alias ➤ …

(BMS)电池管理系统技术研究与仿真

目录 简介 1、 建立电池模型 1.1 、脉冲放电实验 1.2、 离线参数辨识方法优化

Attingo:西部数据部分SSD存在硬件设计制造缺陷

今年5月&#xff0c;西部数据SanDisk Extreme Pro硬盘陆续有用户反馈有故障发生&#xff0c;用户反馈最多的问题是数据丢失和硬件损坏。8月份&#xff0c;因为这个事情&#xff0c;还被爆出&#xff0c;西部数据面临用户的集体诉讼。 近期&#xff0c;有一个专门从事数据恢复的…

高防CDN的需求分析:社会与企业发展的推动力

在当今数字化飞速发展的时代&#xff0c;网络安全成为社会和企业发展的关键因素之一。随着网络攻击手段的不断升级&#xff0c;企业对于高防CDN&#xff08;内容分发网络&#xff09;的需求逐渐成为保障业务稳健运行的重要部分。从社会和企业发展的角度来看&#xff0c;高防CDN…

【Java 进阶篇】Ajax 实现——原生JS方式

大家好&#xff0c;欢迎来到这篇关于原生 JavaScript 中使用 Ajax 实现的博客&#xff01;在前端开发中&#xff0c;我们经常需要与服务器进行数据交互&#xff0c;而 Ajax&#xff08;Asynchronous JavaScript and XML&#xff09;是一种用于创建异步请求的技术&#xff0c;它…

getchar函数的功能有哪些

getchar函数是C语言标准库中的一个函数&#xff0c;主要用于从标准输入&#xff08;通常是键盘&#xff09;获取一个字符。它的功能包括&#xff1a; 从标准输入获取一个字符&#xff1a;getchar函数会等待用户输入一个字符&#xff0c;然后将其返回给程序。可以通过控制台输入…