基于卷积神经网络实现手写数字识别

基于卷积神经网络实现手写数字识别

基于卷积神经网络实现手写数字识别。具体过程如下:

(1) 定义ConvNet结构类及其前向传播方式

(2) 设置超参数以及导入相关的包。

(3) 定义训练网络函数和绘图函数,并在main函数中完成调用过程

程序
import os 
import numpy as np 
#from sklearn.datasets import fetch_openml # 引入openml数据源
from matplotlib import pyplot as plt # 引入绘图工具
import torch
from torchvision.datasets import mnist
#from mnist_models import AlexNet, ConvNet
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import VariableBASE_PATH = os.path.dirname(__file__)# 设置模型超参数
EPOCHS = 50
SAVE_PATH = './models''''
# 载入MNIST数据集并显示部分样本
def load_mnist():# 从openml源载入MNIST数据集mnist = fetch_openml('mnist_784', version=1, data_home=os.path.join(BASE_PATH, './dataset'))X, y = mnist['data'], mnist['target']#X = mnist['data']#.astype(np.float32)#y = mnist['target']#.astype(np.int32)print('MNIST数据集大小:{}'.format(X.shape))# 显示其中25张样本图片for i in range(25):#print(i)digit = X.iloc[i * 2500]# 将图片恢复到28*28大小digit_image = digit.values.reshape(28, 28)# 绘制图片plt.subplot(5, 5, i + 1)# 隐藏坐标轴plt.axis('off')# 按灰度图绘制图片plt.imshow(digit_image, cmap='gray')# 显示图片plt.show()return X, y
'''# 定义卷积网络结构
class ConvNet(torch.nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1, 10, 5, 1, 1),torch.nn.MaxPool2d(2),torch.nn.ReLU(),torch.nn.BatchNorm2d(10))self.conv2 = torch.nn.Sequential(torch.nn.Conv2d(10, 20, 5, 1, 1),torch.nn.MaxPool2d(2),torch.nn.ReLU(),torch.nn.BatchNorm2d(20))self.fc1 = torch.nn.Sequential(torch.nn.Linear(500, 60),torch.nn.Dropout(0.5),torch.nn.ReLU())self.fc2 = torch.nn.Sequential(torch.nn.Linear(60, 20),torch.nn.Dropout(0.5),torch.nn.ReLU())self.fc3 = torch.nn.Linear(20, 10)# 定义网络前向传播方式def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(-1, 500)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x# 定义AlexNet结构
class AlexNet(torch.nn.Module):def __init__(self, num_classes=10):super(AlexNet, self).__init__()self.features = torch.nn.Sequential(torch.nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2),torch.nn.ReLU(inplace=True),torch.nn.MaxPool2d(kernel_size=3, stride=1),torch.nn.Conv2d(64, 192, kernel_size=3, padding=2),torch.nn.ReLU(inplace=True),torch.nn.MaxPool2d(kernel_size=3, stride=2),torch.nn.Conv2d(192, 384, kernel_size=3, padding=1),torch.nn.ReLU(inplace=True),torch.nn.Conv2d(384, 256, kernel_size=3, padding=1),torch.nn.ReLU(inplace=True),torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),torch.nn.ReLU(inplace=True),torch.nn.MaxPool2d(kernel_size=3, stride=2))self.classifier = torch.nn.Sequential(torch.nn.Dropout(),torch.nn.Linear(256 * 6 * 6, 4096),torch.nn.ReLU(inplace=True),torch.nn.Dropout(),torch.nn.Linear(4096, 4096),torch.nn.ReLU(inplace=True),torch.nn.Linear(4096, num_classes))# 定义AlexNet前向传播过程def forward(self, x):x = self.features(x)x = x.view(x.size(0), 256 * 6 * 6)x = self.classifier(x)return x    # 训练网络函数
def train_net(net, train_data, test_data):losses = []acces = []# 测试集上Loss变化情况eval_losses = []eval_acces = []# 损失函数设置为交叉熵函数criterion = torch.nn.CrossEntropyLoss()# 优化方法选用SGD,初始学习率为1e-2optimizer = torch.optim.SGD(net.parameters(), 1e-2)for e in range(EPOCHS):train_loss = 0train_acc = 0# 将网络设置为训练模型net.train()for image, label in train_data:image = Variable(image)label = Variable(label)# 前向传播out = net(image)loss = criterion(out, label)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 记录误差train_loss += loss.data# 计算分类的准确率_, pred = out.max(1)num_correct = (np.array(pred, dtype=np.int32) == np.array(label, dtype=np.int32)).sum()acc = num_correct / image.shape[0]train_acc += acctrain_loss_rate = train_loss / len(train_data)train_acc_rate = train_acc / len(train_data)losses.append(train_loss_rate)acces.append(train_acc_rate)# 在测试集上检验效果eval_loss = 0eval_acc = 0net.eval() # 将模型改为预测模式for image, label in test_data:image = Variable(image)label = Variable(label)out = net(image)loss = criterion(out, label)# 记录误差eval_loss += loss.data# 记录准确率_, pred = out.max(1)num_correct = (np.array(pred, dtype=np.int32) == np.array(label, dtype=np.int32)).sum()acc = num_correct / image.shape[0]eval_acc += acceval_loss_rate = eval_loss / len(test_data)eval_acc_rate = eval_acc / len(test_data)eval_losses.append(eval_loss_rate)eval_acces.append(eval_acc_rate)print('epoch:{}, Train Loss: {:.6f}, Train Acc:{:.6f}, Eval Loss:{:.6f}, Eval Acc:{:.6f}'.format(e, train_loss_rate, train_acc_rate, eval_loss_rate, eval_acc_rate))torch.save(net.state_dict(), os.path.join(BASE_PATH, SAVE_PATH, 'Alex_model_epoch' + str(e) + '.pkl'))return eval_losses, eval_accesdef draw_result(eval_losses, eval_acces):x = range(1, EPOCHS + 1)fig, left_axis = plt.subplots()p1, = left_axis.plot(x, eval_losses, 'ro-')right_axis = left_axis.twinx()p2, = right_axis.plot(x, eval_acces, 'bo-')plt.xticks(x, rotation=0)# 设置左坐标轴以及右坐标轴的范围、精度left_axis.set_ylim(0, 0.5)left_axis.set_yticks(np.arange(0, 0.5, 0.1))right_axis.set_ylim(0.9, 1.01)right_axis.set_yticks(np.arange(0.9, 1.01, 0.02))# 设置坐标及标题的大小、颜色left_axis.set_xlabel('Labels')left_axis.set_ylabel('Loss', color='r')left_axis.tick_params(axis='y', colors='r')right_axis.set_ylabel('Accuracy', color='b')right_axis.tick_params(axis='y', colors='b')plt.show()if __name__ == '__main__':#x, y = load_mnist()print("基于卷积神经网络实现手写数字识别")train_set = mnist.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())//需要转化成tensor数据格式test_set = mnist.MNIST('./data', train=False, download=True, transform=transforms.ToTensor())train_data = DataLoader(train_set, batch_size=64, shuffle=True)test_data = DataLoader(test_set, batch_size=64, shuffle=False)a, a_label = next(iter(train_data))#net = AlexNet()net = ConvNet()eval_losses, eval_acces = train_net(net, train_data, test_data)draw_result(eval_losses, eval_acces)
结果:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/749720.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

备战蓝桥杯Day28 - 拼接最大数字问题

问题描述 有n个非负整数,将其按照字符串拼接的方式拼接为一个整数如何拼接可以使得得到的整数最大? 例: 32,94,128,1286,6,71可以拼接除的最大整数为 94716321286128。 问题思路 1.比较两个字符串的第一个数字,数值大的在前面,数值小的在…

MATLAB中的数据类型,例如double,char,logical等。

在MATLAB中创建和操作矩阵是MATLAB编程的基础,因为MATLAB本身就是一个以矩阵为基本数据单位的编程环境。下面将详细解释如何在MATLAB中创建和操作矩阵。 创建矩阵 在MATLAB中,创建矩阵的基本方法是将数据按行输入,元素之间用空格或逗号分隔…

桌面待办,电脑桌面待办事项便利贴怎么搞的

电脑桌面待办事项贴便利,是一款非常实用的小工具,可以帮助我们高效管理工作和生活中的各种任务。通过简单的操作,你可以在电脑桌面上添加待办事项,随时提醒自己完成任务,提高工作效率。那么,桌面待办&#…

Hack The Box-Jab

目录 信息收集 nmap enum4linux 服务信息收集 Pidgin kerbrute hashcat 反弹shell & get user 提权 系统信息收集 端口转发 漏洞利用 get root 信息收集 nmap 端口探测┌──(root㉿ru)-[~/kali/hackthebox] └─# nmap -p- 10.10.11.4 --min-rate 10000 -oA…

vitepress里使用gitalk(图文教程)

vitepress里使用gitalk Gitalk 是一个基于 GitHub Issue 和 Preact 开发的评论插件 生成client配置 创建OAuth application 填写完毕,点击 Register application 即可 生成client secrets 一开始没有自动生成 Client secrets,需要手动生成&#xff…

Day17 深入类加载机制

Day17 深入类加载机制 文章目录 Day17 深入类加载机制一、初识类加载过程二、深入类加载过程三、利用类加载过程理解面试题四、类加载器五、类加载器分类六、类加载器之间的层次关系七、双亲委派模型 - 概念八、双亲委派模型 - 工作过程九、双亲委派模型 - 好处十、双亲委派原则…

MySQL:视图

1. 概述 在MySQL中,视图(View)是一个虚拟存在的表,其内容是由查询定义的。视图本身并不包含数据,它只包含一条SQL查询语句(即定义视图的SELECT语句)。当通过视图访问数据时,MySQL会执…

【软考高项】八、信息技术发展之新一代信息技术及应用

1、物联网 定义:通过信息传感设备,按约定的协议将任何物品与互联网相连接,进行信息交换和通信,以实现智能化识别、定位、跟踪、监控和管理的网络 分层: 感知层---各种传感器构成 网络层---物联网的中枢&#xff0c…

西门子TIA中配置Anybus PROFINET IO Slave 模块

1、所需产品 Siemens S7 PLC CPU 315-2 PN/DP 6ES7 315-2EH-0AB0 Siemens PLC 编程电缆 n.a. n.a. PC ,并安装Siemens PLC编程软件 TIA Portal V11 X-gateway Slave 接口的GSDML文件 根据网关的软件版本而定 Anybus Communicator GSD文件 GSDML-V1.0-HMS-ABCPRT-20050317.xl…

win下 VirtualBox 自动启动脚本脚本

文章目录 一、找到VBoxManage二、测试脚本1、打开cmd2、输入命令 (直接把上面找到的VBoxManage.exe 拖入到cmd中,这样就不用输入路径了)3、效果展示 比如虚拟机中的系统名称叫“centos-mini” 三、设置自动启动脚本1、复制刚才测试好的命令到新建文本中2、修改文本名…

Golang实现Redis分布式锁(Lua脚本+可重入+自动续期)

Golang实现Redis分布式锁(Lua脚本可重入自动续期) 1 概念 应用场景 Golang自带的Lock锁单机版OK(存储在程序的内存中),分布式不行 分布式锁: 简单版:redis setnx》加锁设置过期时间需要保证原…

P8706 [蓝桥杯 2020 省 AB1] 解码 Python

[蓝桥杯 2020 省 AB1] 解码 题目描述 小明有一串很长的英文字母,可能包含大写和小写。 在这串字母中,有很多连续的是重复的。小明想了一个办法将这串字母表达得更短:将连续的几个相同字母写成字母 出现次数的形式。 例如,连续…

React Hooks、useState、useEffect 、react函数状态

Hooks Hooks 概念理解 学习目标: 理解 Hooks 的概念及解决的问题 什么是 hooks hooks 的本质: 一套能够使函数组件更强大、更灵活的(钩子) React 体系里组件分为类组件和函数组件 多年使用发现,函数组件是一个更加匹…

Unity3d版白银城地图

将老外之前拼接的Unity3d版白银城地图,导入到国内某手游里,改成它的客户端地图模式,可以体验一把手游的快乐。 人物角色用的是它原版的手游默认的,城内显示效果很好,大家可以仔细看看。 由于前期在导入时遇到重大挫折&…

PMP的学习方法

PMBOK编撰了管理项目需要的49个过程(输入、工具技术、输出)。工具技术文件,林林总总百余个。第一部分,按照十大知识领域顺序从前到后编排;第二部分,按照五大过程组顺序重新编排了一遍。 一,PMB…

xray问题排查,curl: (35) Encountered end of file(已解决)

经过了好几次排查,都没找到问题,先说问题的排查过程,多次确认了user信息,包括用户id和alterid,都没问题,头大的一逼 问题排查过程 确保本地的xray服务是正常的 [rootk8s-master01 xray]# systemctl stat…

StarRocks面试题及答案整理,最新面试题

StarRocks 的 MV(物化视图)机制是如何工作的? StarRocks 的物化视图(MV)机制通过预先计算和存储数据的聚合结果或者转换结果来提高查询性能。其工作原理如下: 1、数据预处理: 在创建物化视图时…

2024年3月环境管理体系基础考试真题

2024年3月环境管理体系基础考试真题 一、单项选择题(每题1.5分,共60分) 1.依据GB/T24001-2016标准,6.1.1中要求应确定需应对的风险和机遇,以确保组织能够实现其环境管理体系的预期结果,预防或减少&#x…

开发指南005-前端配置文件

平台要求无论前端还是后端,修改配置可以直接用记事本修改,无需重新打包或修改压缩包里文件。就前端而言,很多系统修改配置是在代码里修改,然后打包或者是修改编译环境来重新编译。 平台前端的配置文件为/static/js/下qlm_config.j…

算法打卡day19|二叉树篇08|Leetcode 235. 二叉搜索树的最近公共祖先、701.二叉搜索树中的插入操作、450.删除二叉搜索树中的节点

算法题 Leetcode 235. 二叉搜索树的最近公共祖先 题目链接:235. 二叉搜索树的最近公共祖先 大佬视频讲解:二叉搜索树的最近公共祖先视频讲解 个人思路 昨天做过一道二叉树的最近公共祖先,而这道是二叉搜索树,那就要好好利用这个有序的特点…