深度学习:使用全连接神经网络FCN实现MNIST手写数字识别

1 引言

本项目构建了一个全连接神经网络(FCN),实现对MINST数据集手写数字的识别,没有借助任何深度学习算法库,从原理上理解手写数字识别的全过程,包括反向传播,梯度下降等。

2 全连接神经网络介绍

2.1 什么是全连接神经网络

全连接网络(Fully-Connected Network,简称FCN),即在多层神经网络中,第nN层的每个神经元都分别与第N-1层的神经元相互连接。如下图便是一个简单的全连接网络:

2.2 损失函数

损失函数(loss function)在深度学习领域是用来计算搭建模型预测的输出值和真实值之间的误差,是一种衡量模型与数据吻合程度的算法。损失函数的值越高预测就越错误,损失函数值越低则预测越接近真实值。对每个单独的观测(数据点)计算损失函数。将所有损失函数(loss function)的值取平均值的函数称为代价函数(cost function),更简单的理解就是损失函数是针对单个样本的,而代价函数是针对所有样本的。

  • 损失函数越小越好
  • 计算实际输出与目标之间的差距
  • 为更新输出提供依据(反向传播)

常见的损失函数

(1)均方误差损失(Mean Squared Error,MSE)

均方误差损失MSE,又称L2 Loss,用于计算模型输出y_hat 和目标值y 之差的均方差。一般用在线性回归中,可以理解为最小二乘法。均方差损失是机器学习、深度学习回归任务中最常用的一种损失函数

(2)平均绝对误差(Mean Absolute Error,MAE)

平均绝对误差MAE,又称L1 Loss,是另一种用于回归模型的损失函数。和 MSE 一样,这种度量方法也是在不考虑方向(如果考虑方向,那将被称为平均偏差(Mean Bias Error, MBE),它是残差或误差之和)的情况下衡量误差大小。但和 MSE 的不同之处在于,MAE 需要像线性规划这样更复杂的工具来计算梯度。此外,MAE 对异常值更加稳健,因为它不使用平方。损失范围也是 0 到 ∞。

(3)交叉熵损失函数(Cross Entropy Loss)

交叉熵(Cross Entropy)是Shannon信息论中一个重要概念,主要用于度量两个概率分布间的差异性信息。语言模型的性能通常用交叉熵和复杂度(perplexity)来衡量。交叉熵的意义是用该模型对文本识别的难度,或者从压缩的角度来看,每个词平均要用几个位来编码。Cross Entropy损失函数是分类问题中最常见的损失函数。

2.3 反向传播

误差反向传播(Back-propagation, BP)算法的出现是神经网络发展的重大突破,也是现在众多深度学习训练方法的基础。该方法会计算神经网络中损失函数对各参数的梯度,配合优化方法更新参数,降低损失函数。BP本来只指损失函数对参数的梯度通过网络反向流动的过程,但现在也常被理解成神经网络整个的训练方法由误差传播、参数更新两个环节循环迭代组成。

神经网络的训练过程中,前向传播和反向传播交替进行,前向传播通过训练数据和权重参数计算输出结果;反向传播通过导数链式法则计算损失函数对各参数的梯度,并根据梯度进行参数的更新

 

3 使用FCN实现MNIST手写数字识别

3.1 MINIST数据集介绍

MNIST数据集是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集。其中的图像的尺寸为28*28。采样数据显示如下:

3.2 FCN识别MINIST数据集代码实现

import torch
from torch import nn
from torchvision import datasets
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import numpy as npclass MnistNet(nn.Module):def __init__(self):super().__init__()self.layer = nn.Sequential(# 图片的原尺寸为28*28,转化为784,输入层为784,输出层为256nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 64),nn.ReLU(),nn.Linear(64, 16),nn.ReLU(),nn.Linear(16, 10),nn.Softmax(dim=1))def forward(self, x):x = x.view(-1, 28*28*1)return self.layer(x)batchsize = 32
lr = 0.01transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307, ), (0.3081,))])data_train = datasets.MNIST(root="./data/", transform=transform, train=True, download=True)
data_test = datasets.MNIST(root="./data/", transform=transform, train=False)train_loader = torch.utils.data.DataLoader(data_train, batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(data_test, batch_size=batchsize, shuffle=False)if __name__ == '__main__':model = MnistNet()criterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.5)for i in range(5):plt.subplot(1, 5, i + 1)plt.xticks([])plt.yticks([])plt.imshow(data_train.data[i], cmap=plt.cm.binary)plt.show()lepoch = []llost = []lacc = []epochs = 30for epoch in range(epochs):lost = 0count = 0for num, (x, y) in enumerate(train_loader, 1):y_h = model(x)loss = criterion(y_h, y)optimizer.zero_grad()loss.backward()optimizer.step()lost += loss.item()count += batchsizeprint('epoch:', epoch + 1, 'loss:', lost / count, end=' ')lepoch.append(epoch + 1)llost.append(lost / count)with torch.no_grad():acc = 0count = 0for num, (x, y) in enumerate(test_loader, 1):y_h = model(x)_, y_h = torch.max(y_h.data, dim=1)acc += (y_h == y).sum().item()count += x.size(0)test_acc = acc / count * 100lacc.append(test_acc)print('acc:', test_acc)plt.plot(lepoch, llost, label='loss')plt.plot(lepoch, lacc, label='acc')plt.legend()plt.show()

3.3 结果输出

经过30个epoch后,在测试集上的准确率达到了97.3%

epoch: 1 loss: 0.0697015597740809 acc: 56.120000000000005
epoch: 2 loss: 0.0542279725531737 acc: 81.2
epoch: 3 loss: 0.051337766939401626 acc: 83.53
epoch: 4 loss: 0.05083678769866626 acc: 84.49
epoch: 5 loss: 0.05052243163983027 acc: 85.09
epoch: 6 loss: 0.05029139596422513 acc: 85.65
epoch: 7 loss: 0.050102355525890985 acc: 86.14
epoch: 8 loss: 0.04994755889574687 acc: 86.02
epoch: 9 loss: 0.0498184863169988 acc: 86.71
epoch: 10 loss: 0.04970114469528198 acc: 86.81
epoch: 11 loss: 0.04792855019172033 acc: 94.86
epoch: 12 loss: 0.047099880089362466 acc: 95.64
epoch: 13 loss: 0.04690476657748222 acc: 96.04
epoch: 14 loss: 0.04677621142864227 acc: 96.32
epoch: 15 loss: 0.046683601369460426 acc: 96.52
epoch: 16 loss: 0.04659009942809741 acc: 96.69
epoch: 17 loss: 0.04652327968676885 acc: 96.72
epoch: 18 loss: 0.04646410925189654 acc: 96.81
epoch: 19 loss: 0.0464125766257445 acc: 96.75
epoch: 20 loss: 0.04636456128358841 acc: 97.07000000000001
epoch: 21 loss: 0.046326734560728076 acc: 96.85000000000001
epoch: 22 loss: 0.04628034559885661 acc: 96.91
epoch: 23 loss: 0.04625135076443354 acc: 97.0
epoch: 24 loss: 0.046217381453514096 acc: 97.14
epoch: 25 loss: 0.046193461724122364 acc: 97.03
epoch: 26 loss: 0.046168098962306975 acc: 97.16
epoch: 27 loss: 0.0461397964378198 acc: 97.27
epoch: 28 loss: 0.0461252645790577 acc: 97.22
epoch: 29 loss: 0.04609716224273046 acc: 97.19
epoch: 30 loss: 0.04608173056840897 acc: 97.3

准确率变化曲线如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/18421.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

maven引入本地jar包的简单方式【IDEA】【SpringBoot】

前言 想必点进来看这篇文章的各位,都是已经习惯了Maven从中央仓库或者阿里仓库直接拉取jar包进行使用。我也是🤡🤡。 前两天遇到一个工作场景,对接三方平台,结果对方就是提供的一个jar包下载链接,可给我整…

SpringBoot使用MyBatis Plus + 自动更新数据表

1、Mybatis Plus介绍 Mybatis,用过的都知道,这里不介绍,mybatis plus只是在mybatis原来的基础上做了些改进,增强了些功能,增强的功能主要为增加更多常用接口方法调用,减少xml内sql语句编写,也可…

python使用selenium 打开谷歌浏览器闪退, 怎么解决

问题描述: 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 使用 Selenium 操作 Chrome 浏览器, Chrome 浏览器闪退 问题解决: 可能是以下几个方面出现了问题: 1. Chromedriver 版本与 Chrome 浏览器版本不匹配 你需要确保你正在…

安卓:JzvdStd——网络视频播放器

目录 一、JzvdStd介绍 JzvdStd的特点和功能: JzvdStd常用方法: 二、JzvdStd使用 1、补充知识: 例子: MainActivity : VideoPageAdapter : activity_main: video_page: …

第十次CCF计算机软件能力认证

第一题:分蛋糕 小明今天生日,他有 n 块蛋糕要分给朋友们吃,这 n 块蛋糕(编号为 1 到 n)的重量分别为 a1,a2,…,an。 小明想分给每个朋友至少重量为 k 的蛋糕。 小明的朋友们已经排好队准备领蛋糕,对于每个朋…

Blazor前后端框架Known-V1.2.9

V1.2.9 Known是基于C#和Blazor开发的前后端分离快速开发框架,开箱即用,跨平台,一处代码,多处运行。 Gitee: https://gitee.com/known/KnownGithub:https://github.com/known/Known 概述 基于C#和Blazor…

UE4 unlua学习笔记

将这三个插件放入Plugins内并重新编译 创建一个BlueprintLibrary,声明一个全局函数 在这里声明路径 点击Create Lua Template 在Content的Script即可生成对应的lua文件打开它! 显示以上lua代码 打印Hello Unlua 创建该UI,就会在创建UI的Con…

Flutter-基础Widget

Flutter页面-基础Widget 文章目录 Flutter页面-基础WidgetWidgetStateless WidgetStateful WidgetState生命周期 基础widget文本显示TextRichTextDefaultTextStyle 图片显示FlutterLogoIconImageIamge.assetImage.fileImage.networkImage.memory CircleAvatarFadeInImage 按钮R…

火山引擎DataLeap如何解决SLA治理难题(二):申报签署流程与复盘详解

申报签署流程详解 火山引擎DataLeap SLA保障的前提是先达成SLA协议。在SLA保障平台中,以 申报单签署的形式达成SLA协议。平台核心特点是 优化了SLA达成的流程,先通过 “系统卡点计算”减少待签署任务的数量,再通过 “SLA推荐计算”自动签署部…

【Linux】网络基础

🍎作者:阿润菜菜 📖专栏:Linux系统网络编程 文章目录 一、协议初识和网络协议分层(TCP/IP四层模型)认识协议TCP/IP五层(或四层)模型 二、认识MAC地址和IP地址认识MAC地址认识IP地址认…

基于Java的闲置物品管理系统(源码+文档+数据库)

很多在校学生经常因为冲动或者因为图一时的新鲜,购买了很多可能只是偶尔用一下的物品,大量物品将会闲置,因此,构建一个资源共享平台,将会极大满足师院学生的需求,可以将其闲置物品挂在资源共享平台上让有需要的学生浏览&#xff0…

Linux【网络基础】数据链路层IP协议技术补充DNSDHCP

文章目录 一、数据链路层(1)数据链路层与网络层的关联(2)局域网通信原理(3)以太网协议(4)ARP协议 二、NAT协议三、NAPT协议四、ICMP协议五、DNS六、DHCP 一、数据链路层 &#xff0…

二、JVM-深入运行时数据区

深入运行时数据区 计算机体系结构 JVM的设计实际上遵循了遵循冯诺依曼计算机结构 CPU与内存交互图: 硬件一致性协议: MSI、MESI、MOSI、Synapse、Firely、DragonProtocol 摩尔定律 摩尔定律是由英特尔(Intel)创始人之一戈登摩尔(Gordon Moore)提出来…

配置GIt账号、配置公钥

1.设置账号和邮箱 打开终端输入以下命令: git config --global --unset-all user.name git config --global --unset-all user.email然后输入以下命令来设置新的账号和邮箱: git config --global user.name "your_username" git config --glo…

与“云”共舞,联想凌拓的新科技与新突破

伴随着数字经济的高速发展,IT信息技术在数字中国建设中起到的驱动和支撑作用也愈发凸显。特别是2023年人工智能和ChatGPT在全球的持续火爆,更是为整个IT产业注入了澎湃动力。那么面对日新月异的IT信息技术,再结合疫情之后截然不同的经济环境和…

效率提升丨大学必看校园安全实用技巧

在当今社会,教育是培养人才、传承文明的重要场所。然而,教学楼作为学生、教师和员工活动的核心区域,也存在着潜在的安全隐患,其中最为突出的风险之一是火灾。火灾不仅危及生命财产,还可能给整个学校带来不可估量的损失…

vue3中使用原始标签制作一个拖拽和点击上传组件上传成功后展示

在Vue3中&#xff0c;可以使用<input type"file">标签来实现上传文件的功能&#xff0c;同时可以通过<div>标签来实现拖拽上传的功能。 首先&#xff0c;在template中定义一个包含<input>和<div>标签的组件&#xff1a; <template>&…

【C++】模板学习(二)

模板学习 非类型模板参数模板特化函数模板特化类模板特化全特化偏特化 模板分离编译模板总结 非类型模板参数 模板参数除了类型形参&#xff0c;还可以是非类型的形参。 非类型形参要求用一个常量作为类(函数)模板的一个参数。这个参数必须是整形家族的。浮点数&#xff0c;字…

pytorch学习——正则化技术——丢弃法(dropout)

一、概念介绍 在多层感知机&#xff08;MLP&#xff09;中&#xff0c;丢弃法&#xff08;Dropout&#xff09;是一种常用的正则化技术&#xff0c;旨在防止过拟合。&#xff08;效果一般比前面的权重衰退好&#xff09; 在丢弃法中&#xff0c;随机选择一部分神经元并将其输出…