【代码】python实现一个BP神经网络-原理讲解与代码展示

本文来自《老饼讲解-BP神经网络》https://www.bbbdata.com/

目录

  • 一、BP神经网络原理回顾
    • 1.1 BP神经网络的结构简单回顾
    • 1.2.BP神经网络的训练算法流程
  • 二、python实现BP神经网络代码
    • 2.1.数据介绍
    • 2.2.pytorch实现BP神经网络代码

在python中要如何使用代码实现一个BP神经网络呢?
在python中可以利用pytorch来实现BP神经网络,这是最简洁也是最常用的方法。
通过本文可以详细掌握怎么使用python的pytorch来实现一个BP神经网络。

一、BP神经网络原理回顾

1.1 BP神经网络的结构简单回顾

BP神经网络的结构如下:
BP神经网络结构图
BP神经网络由输入层、隐层、输出层组成,其中隐层可以是有多层的,整个网络以前馈式进行计算,也就是每层的输出作为下层的输入,不断套娃,直到输出层

每层的计算公式如下:
y = T ( W X + B ) y=T(WX+B) y=T(WX+B)
其中,
X:该层的输入
W:该层的权重
B:该层的阈值
T:该层的激活函数

1.2.BP神经网络的训练算法流程

梯度下降算法求解BP神经网络的流程如下:
梯度下降算法求解BP神经网络

一、先初始化一个解                                                 
二、迭代                                                                  
1. 计算所有w,b在当前处的梯度dw,db           
2. 将w,b往负梯度方向更新:                       w = w-lr*dw                       b = b-lr*db       
3. 判断是否满足退出条件,如果满足,则退出迭代

二、python实现BP神经网络代码

在python中只需要使用pytorch就可以简单实现BP神经网络,而且提供了丰富的训练算法。

2.1.数据介绍

为方便理解,不妨采用以下的简单数据:
在这里插入图片描述
上述即为sin函数在[-5,5]之间的20个采样数据

2.2.pytorch实现BP神经网络代码

下面展示在pytorch中实现BP神经网络的代码
特别说明:需要先安装pytorch包

import torch
import matplotlib.pyplot as plt 
torch.manual_seed(99)# -----------计算网络输出:前馈式计算---------------
def forward(w1,b1,w2,b2,x):                                   return w2@torch.tanh(w1@x+b1)+b2# -----------计算损失函数: 使用均方差--------------
def loss(y,py):return ((y-py)**2).mean()# ------训练数据----------------
x = torch.linspace(-5,5,20).reshape(1,20)                      # 在[-5,5]之间生成20个数作为x
y = torch.sin(x)                                               # 模型的输出值y#-----------训练模型------------------------
in_num  = x.shape[0]                                            # 输入个数
out_num = y.shape[0]                                            # 输出个数
hn  = 4                                                         # 隐节点个数
w1  = torch.randn([hn,in_num],requires_grad=True)               # 初始化输入层到隐层的权重w1
b1  = torch.randn([hn,1],requires_grad=True)                    # 初始化隐层的阈值b1
w2  = torch.randn([out_num,hn],requires_grad=True)              # 初始化隐层到输出层的权重w2
b2  = torch.randn([out_num,1],requires_grad=True)               # 初始化输出层的阈值b2lr = 0.01                                                       # 学习率
for i in range(5000):                                           # 训练5000步py = forward(w1,b1,w2,b2,x)                                 # 计算网络的输出L = loss(y,py)                                              # 计算损失函数print('第',str(i),'轮:',L)                                 # 打印当前损失函数值L.backward()                                                # 用损失函数更新模型参数的梯度w1.data=w1.data-w1.grad*lr                                  # 更新模型系数w1b1.data=b1.data-b1.grad*lr                                  # 更新模型系数b1w2.data=w2.data-w2.grad*lr                                  # 更新模型系数w2b2.data=b2.data-b2.grad*lr                                  # 更新模型系数b2w1.grad.zero_()                                             # 清空w1梯度,以便下次backwardb1.grad.zero_()                                             # 清空b1梯度,以便下次backwardw2.grad.zero_()                                             # 清空w2梯度,以便下次backwardb2.grad.zero_()                                             # 清空b2梯度,以便下次backward
px = torch.linspace(-5,5,100).reshape(1,100)                    # 测试数据,用于绘制网络的拟合曲线    
py = forward(w1,b1,w2,b2,px).detach().numpy()                   # 网络的预测值
plt.scatter(x, y)                                               # 绘制样本
plt.plot(px[0,:],py[0,:])                                       # 绘制拟合曲线  
print('w1:',w1)
print('b1:',b1)
print('w2:',w2)
print('b2:',b2)

运行结果如下:

.....                                            
第 4996 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
第 4997 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
第 4998 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
第 4999 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
w1: tensor([[ 0.1742],[-0.8133],[-0.6450],[-0.4054]],requires_grad=True)
b1: tensor([[ 0.8125],[0.0593],[-1.8776],[1.1220]],requires_grad=True)
w2: tensor([[-0.7753,-2.0142,1.1161,1.9635]],requires_grad=True)
b2: tensor([[0.1094]], requires_grad=True)   

运行结果
可以看到,模型根据训练数据,已经较好地拟合出sin函数曲线


相关链接:

《老饼讲解-机器学习》:老饼讲解-机器学习教程-通俗易懂
《老饼讲解-神经网络》:老饼讲解-matlab神经网络-通俗易懂
《老饼讲解-神经网络》:老饼讲解-深度学习-通俗易懂

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/31944.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Rsbuild构建基于Vue3+Vant4开发h5应用

目录 一、介绍 1.1 Vant介绍 1.2 Rsbuild介绍 1.3 Vue介绍 二、构建应用 1.第一步 2.第二步 3.第三步 4.第四步 5.第五步 6.在项目中使用 Vant4 组件 7.移动端适配Rem 8. 执行 cnpm run dev 启动项目 一、介绍 1.1 Vant介绍 Vant 是一个轻量、可定制的移动端组…

单机小游戏好上架的应用市场有哪些?

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…

Vue3中的常见组件通信(超详细版)

Vue3中的常见组件通信 概述 ​ 在vue3中常见的组件通信有props、mitt、v-model、 r e f s 、 refs、 refs、parent、provide、inject、pinia、slot等。不同的组件关系用不同的传递方式。常见的撘配形式如下表所示。 组件关系传递方式父传子1. props2. v-model3. $refs4. 默认…

Mac电脑FTP客户端推荐:Transmit 5 for Mac 中文版

Transmit 5是一款专为macOS平台设计的功能强大的FTP&#xff08;文件传输协议&#xff09;客户端软件。Transmit 5凭借其强大的功能、直观易用的界面和高效的性能&#xff0c;成为需要频繁进行文件传输和管理的个人用户和专业用户的理想选择。无论是对于新手还是经验丰富的用户…

Starlink全系卫星详细介绍,波段频谱、激光星间链路技术、数据传输速率等等

Starlink全系卫星详细介绍&#xff0c;波段频谱、激光星间链路技术、数据传输速率等等。 Starlink是SpaceX公司开发的一个低轨道&#xff08;LEO&#xff09;卫星网络系统&#xff0c;旨在为全球用户提供高速宽带互联网服务。截至2024年6月&#xff0c;Starlink已经发射并运行…

终于找到了免费的云服务器

今天朋友推荐了一个免费的云服务器&#xff1a;“阿贝云” 我最喜欢的是它的"免费虚拟主机"“免费云服务器”&#xff0c;省了我好多钱&#xff0c;我的使用感受是用起来经济实惠省心&#xff0c;不要钱的东西谁不喜欢呢&#xff0c;对于普通开发者来说&#xff0c;…

长尾式差分放大电路调零

长尾式放大电路用了两个参数相同的三极管&#xff0c;但实际上并没有完全相同的三极管&#xff0c;所以为了提高差分放大电路的对称性(一边电流增加多少&#xff0c;另一边电流减小多少&#xff0c;即能在电阻Re上产生的压降不变(后面做虚地处理))&#xff0c;在下图中加入可调…

【Linux 杂记】TOP命令

top命令用于动态显示系统中正在运行的进程的详细信息&#xff0c;以及系统的整体资源使用情况。以下是其主要输出解释&#xff1a; Header 表头信息&#xff1a; top&#xff1a;当前时间和运行时间。Tasks&#xff1a;进程统计信息&#xff0c;如总进程数、运行中、睡眠中等。…

xocde编辑器支持修改为中文吗?不支持

xocde编辑器支持修改为中文吗&#xff1f; 不支持

rttys服务器和客户端

rttys服务器 1.下载 https://github.com/zhaojh329/rttys/releases2.解压运行 libev交叉编译 cd libev ./configure --hostarm-linux CCaarch64-poky-linux-gcc --prefix/home/michael/rtty_install make install DESTDIR/home/michael/rtty_installrtty客户端 1.git地…

RabbitMQ —— 理解及应用场景

一、MQ相关的概念 RabbitMQ 是一种分布式消息中间件&#xff0c;消息中间件也称消息队列MQ&#xff0c;那么什么是MQ呢&#xff1f;请继续阅读下文。 1.1、MQ的基本概念 什么是MQ MQ(message queue)&#xff0c;从字面意思上看就个 FIFO 先入先出的队列&#xff0c;只不过队列…

2024 年解锁 Android 手机的 7 种简便方法

您是否忘记了 Android 手机的 Android 锁屏密码&#xff0c;并且您的手机已被锁定&#xff1f;您需要使用锁屏解锁 Android 手机&#xff1f;别担心&#xff0c;您不是唯一一个忘记密码的人。我将向您展示如何解锁 Android 手机的锁屏。 密码 PIN 可保护您的 Android 手机和 G…

Node.js中基于node-schedule实现定时任务之详解

文章目录 一、定时任务二、node-schedule、1、安装2、引入3、基于Cron表达式的规则4、基于Date的规则5、基于RecurrenceRule的规则6、API7、状态监听 一、定时任务 实际工作中&#xff0c;可能会遇到定时清除某个文件夹内容&#xff0c;定时发送消息或发送邮件给指定用户&…

Django集成OpenAI

Django集成OpenAI 通过前面 django 框架的基本开发知识&#xff0c;我们现在可以开始在 django 上做稍微深一点当然应用开发了。 这一章开始编写怎么集成调用 openai &#xff0c;设置环境以及 openai 的基础知识。 大家都知道 ai 的多模态逐渐扩大&#xff0c;各种应用层出…

怎么采集阿里巴巴1688的商品或商家数据?

怎么使用简数采集器批量采集阿里巴巴1688的商品或商家相关信息呢&#xff1f; 简数采集器暂时不支持采集阿里巴巴1688的相关数据&#xff0c;谢谢。 简数采集器采集网络网页数据非常简单高效&#xff1a;输入要采集的网址&#xff0c;简数智能算法会自动提取出网页上的关键信…

Charles抓取安卓应用https包演示

一、准备软件 夜神安卓模拟器 (yeshen.com) Charles (charlesproxy.com) 二、配置抓包 2.1 Charles安装PC根证书 记住这里的ip端口 三、安卓模拟器配置 3.1 配置安卓客户端网络代理 填写上文的ip端口&#xff0c;保存 3.2 安装根证书 3.2.1 导出根证书 linux主机执行 op…

推荐4款实用工具,非常好用,建议收藏

PDFREAL PDFReal 是一个功能强大的在线PDF编辑工具&#xff0c;提供多种实用的PDF处理功能。用户可以在一个网站上完成包括PDF合并、PDF拆分、PDF压缩、PDF保护、PDF解锁等多种操作。此外&#xff0c;PDFReal 还支持将文本转换为PDF、将图片转换为PDF、添加水印、提取页面内容等…

基于Django、Bootstrap的电影推荐系统,算法基于用户的协同过滤算法,有爬虫有可视化后台

背景 基于Django和Bootstrap的电影推荐系统结合了用户协同过滤算法&#xff0c;通过爬虫技术获取电影数据&#xff0c;并在可视化后台展示推荐结果。该系统旨在提供个性化的电影推荐服务&#xff0c;帮助用户发现符合其喜好的电影。 用户协同过滤算法是一种常用的推荐算法&am…

qt开发-09_分裂器

QSplitter 是 Qt 框架中的一个非常实用的控件&#xff0c;用于创建可调整大小的窗格。它允许用户通过拖动子窗口间的边界&#xff08;也称为分割条&#xff09;来动态调整子窗口的尺寸。这在开发需要多个视图同时显示&#xff0c;且用户需要根据需要调整每个视图大小的应用程序…

机器学习python实践——由特征选择引发的关于卡方检验的一些个人思考

最近在用python进行机器学习实践&#xff0c;在做到特征选择这一部分时&#xff0c;对于SelectPercentile和SelectKBest方法有些不理解&#xff0c;所以去了查看了帮助文档&#xff0c;但是在帮助文档的例子中出现了"chi2"&#xff0c;没接触过&#xff0c;看过去就更…