神经网络构建原理(以MINIST为例)

神经网络构建原理(以MINIST为例)

在 MNIST 手写数字识别任务中,构建神经网络并训练模型来进行分类是经典的深度学习应用。MNIST 数据集包含 28x28 像素的手写数字图像(0-9),任务是构建一个神经网络,能够根据输入的图像预测对应的数字。本文将通过该案例详细介绍神经网络的逻辑框架和具体的计算流程。

神经网络构建框架

1.数据预处理

  • 将输入数据进行标准化处理(归一化),并将标签转换为适合模型的格式( one-hot 编码)。

2.模型构建

  • a.输入层:定义输入的大小(将 28x28 的图像展平为 784 维向量)。
  • b.隐藏层:添加一个或多个隐藏层,每层包含一定数量的神经元,并应用激活函数(如 ReLU)[实际上神经元相当于维数,该过程即是对原特征维数进行扩维或降维]。
  • c.输出层:定义输出层的神经元数量(与分类类别一致),通常使用 Softmax 函数将输出转换为概率。

3.前向传播

  • 执行输入层到隐藏层、再到输出层的矩阵乘法和激活函数,计算输出值。

4.损失函数计算

  • 使用交叉熵等损失函数,计算预测输出与真实标签之间的误差。

5.反向传播

  • 通过链式法则,计算损失函数对每个参数的偏导数,并更新权重和偏置项。

6.优化器更新

  • 使用优化器(如 SGD、Adam)基于计算的梯度更新模型参数,降低损失值。

7.迭代训练

  • 不断重复前向传播、损失计算和反向传播,直到损失收敛或达到设定的训练轮次。

8.模型评估与预测

  • 训练完成后,用测试数据评估模型性能,并进行新数据的预测。
    在这里插入图片描述
    在这里插入图片描述

神经网络的具体计算流程

接下来以MINIST手写数字识别为例,模拟神经网络构建的具体计算过程。

假设该网络包含 两个隐藏层,每个隐藏层有 25 个神经元,最后的输出层为 10 个神经元

1. 前向传播(Forward Propagation)

1.1输入层到第一个隐藏层:
  • 输入大小:假设输入图像是 28x28 的像素矩阵,展平成 784 维的向量。 x ∈ R 784 × 1 x \in \mathbb{R}^{784 \times 1} xR784×1

  • 权重矩阵:连接输入层到第一个隐藏层的权重矩阵,大小为 25x784,[因为特征向量是列向量,所以需要转置]。 W 1 ∈ R 25 × 784 W_1 \in \mathbb{R}^{25 \times 784} W1R25×784

  • 偏置项:第一个隐藏层的偏置项,大小为 25x1。 b 1 ∈ R 25 × 1 b_1 \in \mathbb{R}^{25 \times 1} b1R25×1

  • 激活函数:使用 ReLU 激活函数。

计算步骤

  • 执行矩阵乘法 z 1 = W 1 ⋅ x + b 1 z_1 = W_1 \cdot x + b_1 z1=W1x+b1

z 1 z_1 z1 的维度是 25x1。

  • 应用 ReLU 激活函数: h 1 = ReLU ( z 1 ) h_1 = \text{ReLU}(z_1) h1=ReLU(z1)

其中: ReLU ( z 1 ) = max ⁡ ( 0 , z 1 ) \text{ReLU}(z_1) = \max(0, z_1) ReLU(z1)=max(0,z1)

  • 结果:第一个隐藏层的输出 h 1 h_1 h1 是 25 维的向量。
    在这里插入图片描述
1.2第一个隐藏层到第二个隐藏层:
  • 权重矩阵:连接第一个隐藏层到第二个隐藏层的权重矩阵,大小为 25x25。 W 2 ∈ R 25 × 25 W_2 \in \mathbb{R}^{25 \times 25} W2R25×25

  • 偏置项:第二个隐藏层的偏置项,大小为 25x1。 b 2 ∈ R 25 × 1 b_2 \in \mathbb{R}^{25 \times 1} b2R25×1

计算步骤

  • 执行矩阵乘法 z 2 = W 2 ⋅ h 1 + b 2 z_2 = W_2 \cdot h_1 + b_2 z2=W2h1+b2

z 2 z_2 z2的维度是 25x1。

  • 应用 ReLU 激活函数: h 2 = ReLU ( z 2 ) h_2 = \text{ReLU}(z_2) h2=ReLU(z2)

  • 结果:第二个隐藏层的输出 h 2 h_2 h2 仍然是 25 维的向量。

1.3第二个隐藏层到输出层:
  • 权重矩阵:连接第二个隐藏层到输出层的权重矩阵,大小为 10x25。 W 3 ∈ R 10 × 25 W_3 \in \mathbb{R}^{10 \times 25} W3R10×25

  • 偏置项:输出层的偏置项,大小为 10x1。 b 3 ∈ R 10 × 1 b_3 \in \mathbb{R}^{10 \times 1} b3R10×1

计算步骤

  • 执行矩阵乘法 z 3 = W 3 ⋅ h 2 + b 3 z_3 = W_3 \cdot h_2 + b_3 z3=W3h2+b3

z 3 z_3 z3的维度是 10x1。

  • 应用 Softmax 函数将输出转换为概率: Softmax ( z 3 ) i = e z 3 i ∑ j = 1 10 e z 3 j \text{Softmax}(z_3)_i = \frac{e^{z_{3i}}}{\sum_{j=1}^{10} e^{z_{3j}}} Softmax(z3)i=j=110ez3jez3i

Softmax 输出是 10 维的概率向量,表示输入属于 0-9 的概率。

2. 损失计算

使用交叉熵损失函数来计算预测输出与真实标签之间的误差,假设真实标签是 one-hot 编码的向量 y ∈ R 10 y \in \mathbb{R}^{10} yR10,其中,
y i = 1 y_i = 1 yi=1 表示真实类别, p i = Softmax ( z 3 ) i p_i = \text{Softmax}(z_3)_i pi=Softmax(z3)i 表示模型对类别 i i i 的预测概率。
在这里插入图片描述

交叉熵损失公式 L = − ∑ i = 1 10 y i log ⁡ ( p i ) L = -\sum_{i=1}^{10} y_i \log(p_i) L=i=110yilog(pi)

损失计算步骤

  • 对于每一个样本,计算预测类别对应的概率 p i p_i pi 的对数,然后计算损失 L L L

3. 反向传播(Backward Propagation)

反向传播的目标是通过链式法则计算损失函数对每层权重的偏导数,并更新权重矩阵。

3.1输出层到第二个隐藏层:
  • 计算损失对输出层的导数 ∂ L ∂ z 3 = Softmax ( z 3 ) − y \frac{\partial L}{\partial z_3} = \text{Softmax}(z_3) - y z3L=Softmax(z3)y

  • 计算损失对 W 3 W_3 W3的导数 ∂ L ∂ W 3 = ∂ L ∂ z 3 ⋅ h 2 T \frac{\partial L}{\partial W_3} = \frac{\partial L}{\partial z_3} \cdot h_2^T W3L=z3Lh2T

  • 计算损失对 b 3 b_3 b3的导数 ∂ L ∂ b 3 = ∂ L ∂ z 3 \frac{\partial L}{\partial b_3} = \frac{\partial L}{\partial z_3} b3L=z3L

3.2第二个隐藏层到第一个隐藏层:
  • 损失传播到第二层的输出 h 2 h_2 h2 ∂ L ∂ h 2 = W 3 T ⋅ ∂ L ∂ z 3 \frac{\partial L}{\partial h_2} = W_3^T \cdot \frac{\partial L}{\partial z_3} h2L=W3Tz3L

  • 计算 ReLU 激活函数的导数 ∂ L ∂ z 2 = ∂ L ∂ h 2 ⋅ ReLU ′ ( z 2 ) \frac{\partial L}{\partial z_2} = \frac{\partial L}{\partial h_2} \cdot \text{ReLU}'(z_2) z2L=h2LReLU(z2)

其中: ReLU ′ ( z 2 ) = { 1 if  z 2 > 0 0 if  z 2 ≤ 0 \text{ReLU}'(z_2) = \begin{cases} 1 & \text{if } z_2 > 0 \\ 0 & \text{if } z_2 \leq 0 \end{cases} ReLU(z2)={10if z2>0if z20

  • 计算损失对 W 2 W_2 W2的导数 ∂ L ∂ W 2 = ∂ L ∂ z 2 ⋅ h 1 T \frac{\partial L}{\partial W_2} = \frac{\partial L}{\partial z_2} \cdot h_1^T W2L=z2Lh1T

  • 计算损失对 b 2 b_2 b2的导数 ∂ L ∂ b 2 = ∂ L ∂ z 2 \frac{\partial L}{\partial b_2} = \frac{\partial L}{\partial z_2} b2L=z2L

3.3第一个隐藏层到输入层:
  • 损失传播到第一层的输出 h 1 h_1 h1 ∂ L ∂ h 1 = W 2 T ⋅ ∂ L ∂ z 2 \frac{\partial L}{\partial h_1} = W_2^T \cdot \frac{\partial L}{\partial z_2} h1L=W2Tz2L

  • 计算 ReLU 激活函数的导数 ∂ L ∂ z 1 = ∂ L ∂ h 1 ⋅ ReLU ′ ( z 1 ) \frac{\partial L}{\partial z_1} = \frac{\partial L}{\partial h_1} \cdot \text{ReLU}'(z_1) z1L=h1LReLU(z1)

  • 计算损失对 h 1 h_1 h1的导数 ∂ L ∂ W 1 = ∂ L ∂ z 1 ⋅ x T \frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial z_1} \cdot x^T W1L=z1LxT

  • 计算损失对 b 1 b_1 b1的导数 ∂ L ∂ b 1 = ∂ L ∂ z 1 \frac{\partial L}{\partial b_1} = \frac{\partial L}{\partial z_1} b1L=z1L

4. 权重更新

使用梯度下降算法或 Adam 优化器来更新权重。

更新公式 W = W − η ⋅ ∂ L ∂ W W = W - \eta \cdot \frac{\partial L}{\partial W} W=WηWL

  • W W W是权重矩阵, η η η是学习率, ∂ L ∂ W \frac{\partial L}{\partial W} WL 是损失函数关于权重的梯度。

权重更新步骤在每一层执行:

  • 更新 W 1 W_1 W1, W 2 W_2 W2, W 3 W_3 W3和对应的偏置项 b 1 b_1 b1, b 2 b_2 b2, b 3 b_3 b3

5. 优化器(Adam)介绍

Adam 优化器通过结合动量和自适应学习率进行参数更新。详细的更新公式在上面的回答中已经给出。

1.一阶动量估计:
计算当前梯度 ∇ θ L \nabla_{\theta}L θL的加权平均,用来估计梯度的期望。这个一阶动量主要是累积之前的梯度,使得更新方向更加平滑。

m t = β 1 m t − 1 + ( 1 − β 1 ) ∇ θ L m_t = \beta_1 m_{t-1} + (1 - \beta_1) \nabla_{\theta}L mt=β1mt1+(1β1)θL
β 1 \beta_1 β1是一阶动量的衰减率,通常取值为 0.9。
m t m_t mt是当前的动量(梯度的指数加权平均)。

2.二阶矩估计:
计算当前梯度平方的加权平均,估计梯度的方差,用来调节学习率,避免更新步长过大。
v t = β 2 v t − 1 + ( 1 − β 2 ) ( ∇ θ L ) 2 v_t = \beta_2 v_{t-1} + (1 - \beta_2) (\nabla_{\theta}L)^2 vt=β2vt1+(1β2)(θL)2
β 2 \beta_2 β2是二阶动量的衰减率,通常取值为 0.999。
v t v_t vt是梯度平方的指数加权平均。

3.偏差修正:
由于 m ^ t \hat{m}_t m^t v ^ t \hat{v}_t v^t在前几步可能会有较大的偏差,Adam 引入了偏差修正,减少估计的偏差。
m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} m^t=1β1tmt,v^t=1β2tvt
m ^ t \hat{m}_t m^t v ^ t \hat{v}_t v^t是偏差修正后的动量和二阶矩估计.

4.参数更新:
使用修正后的动量和方差来更新参数。Adam 的更新方式是自适应的,能根据梯度的历史动态调整学习率。
W t + 1 = W t − η m ^ t v ^ t + ϵ W_{t+1} = W_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} Wt+1=Wtηv^t +ϵm^t
η η η是学习率,通常取值在 0.001 左右。
$ \epsilon$是一个小的平滑项,避免除以零,通常为 1 0 − 8 10^{-8} 108


Reference:

  1. TensorFlow Documentation
  2. CS231n Convolutional Neural Networks for Visual Recognition
  3. Adam Optimizer Paper
  4. Gradient Descent and Backpropagation Overview
  5. https://www.deeplearningbook.org/
  6. http://cs231n.github.io/optimization-2/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/54545.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

吉首大学--23级题目讲解

7-1 单链表基本操作 在 C/C 中,.(点)和 ->(箭头)运算符用于访问结构体或类的成员,但它们的使用场景不同。 1. . 运算符 . 运算符用于访问结构体或类的成员,通过对象或结构体变量直接访问。…

JS函数部分

函数调用 无参数 var fun function() {console.log(被调用) //不区分单引号双引号}fun () //有无分号都可有参数 var fun function(a, b) {var sum abconsole.log(sum) }fun (10,20) 立即执行函数 被定义完立即调用,且执行一次 (function (){alert(ni);})()创建…

es的封装

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、类和接口介绍0.封装思想1.es的操作分类 二、创建索引1.成员变量2.构造函数2.添加字段3.发送请求4.创建索引总体代码 三.插入数据四.删除数据五.查询数据 前…

Element Plus 中Input输入框

通过鼠标或键盘输入字符 input为受控组件,他总会显示Vue绑定值,正常情况下,input的输入事件会正常被响应,他的处理程序应该更新组件的绑定值(或使用v-model)。否则,输入框的值将不会改变 不支…

windows环境下配置MySQL主从启动失败 查看data文件夹中.err发现报错unknown variable ‘log‐bin=mysql‐bin‘

文章目录 问题解决方法 问题 今天在windows环境下配置MySQL主从同步,在修改my.ini文件后发现MySQL启动失败了 打开my.ini检查参数发现没有问题 [mysqld] #开启二进制日志,记录了所有更改数据库数据的SQL语句 log‐bin mysql‐bin #设置服务id&#x…

梧桐数据库(WuTongDB):Volcano/Cascades 优化器框架简介

Volcano/Cascades 是现代关系数据库系统中使用的两种重要的查询优化器框架,它们用于将高层 SQL 查询转换为高效的执行计划。它们采用了一种基于规则的方式来探索各种可能的查询执行计划,目的是选择一个代价最小的计划。以下是对这两种框架的详细讲解&…

[数据集][目标检测]不同颜色的安全帽检测数据集VOC+YOLO格式7574张5类别

重要说明:数据集里面有2/3是增强数据集,请仔细查看图片预览,确认符合要求在下载,分辨率均为640x640 数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件…

Python 二级考试

易错点 电脑基础知识 定义学生关系模式如下:Student (S#, Sn, Ssex,class,monitorS#)(其属性分别为学号、学生名、性别、班级和班长学号) 在关系模式中,如果…

全志A523 系统篇(一) 获取vmlinux

通过固件获取 longan/build/getvmlinux.sh ./getvmlinux.sh <aw-format-firmware> 其中<aw-format-firmware>为全志格式的包含vmlinux的固件 运行成功后&#xff0c;会在脚本目录下生成output目录&#xff0c;目录里面包含vmlinux.fex&#xff08;vmlinux的.ta…

python-SZ斐波那契数列/更相减损数

一&#xff1a;SZ斐波那契数列题目描述 你应该很熟悉斐波那契数列&#xff0c;不是吗&#xff1f;现在小理不知在哪里搞了个山寨版斐波拉契数列&#xff0c;如下公式&#xff1a; F(n) { $\ \ \ \ \ \ \ \ \ \ \ \ $ a,( n1) $\ \ \ \ \ \ \ \ \ \ \ \ $ b,( n2) $\ \ \ \ \ \ …

【优选算法之双指针】No.2--- 经典双指针算法(下)

文章目录 前言一、双指针示例&#xff1a;1.1 ⽔果成篮1.2 和为s的两个数字1.3 三数之和1.4 四数之和 二、双指针总结&#xff1a; 前言 &#x1f467;个人主页&#xff1a;小沈YO. &#x1f61a;小编介绍&#xff1a;欢迎来到我的乱七八糟小星球&#x1f31d; &#x1f4cb;专…

git-fork操作指南

git-fork操作指南 1.fork github仓库2. clone fork仓库3. 分支修改4.与原始仓库保持修改同步4.1添加上游仓库4.2 拉取上游分支4.3 合并更改4.4 推送更改 参考&#xff1a; 有时候我们需要将github的项目fork到自己名下&#xff0c;然后修改并提交pull request&#xff0c;这里将…

安装黑群晖系统,并使用NAS公网助手访问教程(好文)

由于正版群晖系统的价格不菲&#xff0c;对于预算有限的用户来说&#xff0c;安装黑群晖系统成为了一个不错的选择&#xff08;如果您预算充足&#xff0c;建议选择白群晖&#xff09;。如您对宅系科技比较感兴趣&#xff0c;欢迎查看本文&#xff0c;将详细介绍如何安装黑群晖…

reg和wire的区别 HDL语言

文章目录 数据类型根本区别什么时候要定义wire小结 数据类型 HDL语言有三种数据类型&#xff1a;寄存器数据类型&#xff08;reg&#xff09;、线网数据类型&#xff08;wire&#xff09;、参数数据类型&#xff08;parameter&#xff09;。 根本区别 reg&#xff1a; 寄存器…

【算法题】53. 最大子数组和-力扣(LeetCode)

【算法题】53. 最大子数组和-力扣(LeetCode) 1.题目 下方是力扣官方题目的地址 53. 最大子数组和 给你一个整数数组 nums &#xff0c;请你找出一个具有最大和的连续子数组&#xff08;子数组最少包含一个元素&#xff09;&#xff0c;返回其最大和。 子数组 是数组中的一…

allWebPlugin中间件自定义alert、confirm及prompt使用

allWebPlugin简介 allWebPlugin中间件是一款为用户提供安全、可靠、便捷的浏览器插件服务的中间件产品&#xff0c;致力于将浏览器插件重新应用到所有浏览器。它将现有ActiveX控件直接嵌入浏览器&#xff0c;实现插件加载、界面显示、接口调用、事件回调等。支持Chrome、Firefo…

跨游戏引擎的H5渲染解决方案(腾讯)

本文是腾讯的一篇H5 跨引擎解决方案的精炼。 介绍 本文通过实现基于精简版的HTML5&#xff08;HyperText Mark Language 5&#xff09;来屏蔽不同引擎&#xff0c;平台底层的差异。 好处&#xff1a; 采用H5的开发方式&#xff0c;可以将开发和运营分离&#xff0c;运营部门自…

pip install、yum install和conda install三者技术区分

pip install、yum install和conda install在安装系统环境时可以从以下几个方面进行区分选择&#xff1a; 一、适用范围 pip install 主要用于安装 Python 包。适用于 Python 项目中特定的库和工具的安装。如果你的项目是纯 Python 开发&#xff0c;并且需要安装各种 Python 库&…

Great_data=>Copy_Data=>Chart_RealTime=>UI_All

Great_data -------------------- import csv import os import random from datetime import datetime import logging import time # 配置日志记录 logging.basicConfig(filename=D:/_Study/Case/Great_Data/log.txt, level=logging.INFO, …

代码随想录Day 51|题目:99.岛屿数量、100.岛屿的最大面积

提示&#xff1a;DDU&#xff0c;供自己复习使用。欢迎大家前来讨论~ 文章目录 题目一&#xff1a;99. 岛屿数量思路深度优先搜索DFS广度优先搜索BFS 题目二&#xff1a;100. 岛屿的最大面积DFSBFS 总结 题目一&#xff1a;99. 岛屿数量 99. 岛屿数量 (kamacoder.com) 思路 …