pytorch dropout_PyTorch初探MNIST数据集

7cd6fe6b5ffe437db957fa7e8b843d66.png

前言

本文主要描述了如何使用现在热度和关注度比较高的Pytorch(深度学习框架)构建一个简单的卷积神经网络,并对MNIST数据集进行了训练和测试。MNIST数据集是一个28*28的手写数字图片集合,使用测试集来验证训练出的模型对手写数字的识别准确率。

PyTorch资料

PyTorch的官方文档链接:PyTorch documentation,在这里不仅有 API的说明还有一些经典的实例可供参考。

PyTorch官网论坛:vision,里面会有很大资料分享和一些热门问题的解答。

PyTorch搭建神经网络实践

在一开始导入需要导入PyTorch的两个核心库文件torch和torchvision,这两个库基本包含了PyTorch会用到的许多方法和函数

import 

其中值得一提的是torchvision的datasets可以很方便的自动下载数据集,这里使用的是MNIST数据集。另外的COCO,ImageNet,CIFCAR等数据集也可以很方的下载并使用,导入命令也非常简单

data_train = datasets.MNIST(root = "./data/",transform=transform,train = True,download = True)data_test = datasets.MNIST(root="./data/",transform = transform,train = False)

root指定了数据集存放的路径,transform指定导入数据集时需要进行何种变换操作,train设置为True说明导入的是训练集合,否则为测试集合。

transform里面还有很多好的方法,可以用在图片资源较少的数据集做Data Argumentation操作,这里只是做了个简单的Tensor格式转换和Batch Normalize

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])

数据下载完成后还需要做数据装载操作

data_loader_train = torch.utils.data.DataLoader(dataset=data_train,batch_size = 64,shuffle = True)data_loader_test = torch.utils.data.DataLoader(dataset=data_test,batch_size = 64,shuffle = True)

batch_size设置了每批装载的数据图片为64个,shuffle设置为True在装载过程中为随机乱序

下图为一个batch数据集(64张图片)的显示,可以看出来都为28*28的1维图片

d98973570b58f059854b032f575cb8ec.png
MNIST数据集图片预览

完成数据装载后就可以构建核心程序了,这里构建的是一个包含了卷积层和全连接层的神经网络,其中卷积层使用torch.nn.Conv2d来构建,激活层使用torch.nn.ReLU来构建,池化层使用torch.nn.MaxPool2d来构建,全连接层使用torch.nn.Linear来构建

class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),torch.nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),torch.nn.MaxPool2d(stride=2,kernel_size=2))self.dense = torch.nn.Sequential(torch.nn.Linear(14*14*128,1024),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(1024, 10))def forward(self, x):x = self.conv1(x)x = x.view(-1, 14*14*128)x = self.dense(x)return x

其中定义了torch.nn.Dropout(p=0.5)防止模型的过拟合

forward函数定义了前向传播,其实就是正常卷积路径。首先经过self.conv1(x)卷积处理,然后进行x.view(-1, 14*14*128)压缩扁平化处理,最后通过self.dense(x)全连接进行分类

之后就是对Model对象进行调用,然后定义loss计算使用交叉熵,优化计算使用Adam自动化方式,最后就可以开始训练了

model = Model()
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

在训练前可以查看神经网络架构了,print输出显示如下

Model ((conv1): Sequential ((0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU ()(2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU ()(4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)))(dense): Sequential ((0): Linear (25088 -> 1024)(1): ReLU ()(2): Dropout (p = 0.5)(3): Linear (1024 -> 10))
)

定义训练次数为5次,开始跑神经网络,训练完成后输入测试集合得到的结果如下

Epoch 0/5
----------
Loss is:0.0003, Train Accuracy is:99.4167%, Test Accuracy is:98.6600
Epoch 1/5
----------
Loss is:0.0002, Train Accuracy is:99.5967%, Test Accuracy is:98.9200
Epoch 2/5
----------
Loss is:0.0002, Train Accuracy is:99.6667%, Test Accuracy is:98.7700
Epoch 3/5
----------
Loss is:0.0002, Train Accuracy is:99.7133%, Test Accuracy is:98.9600
Epoch 4/5
----------
Loss is:0.0001, Train Accuracy is:99.7317%, Test Accuracy is:98.7300

从结果上看还不错,训练准确率最高达到了99.73%,测试最高准确率为98.96%。结果有轻微的过拟合迹象,如果使用更加健壮的卷积模型测试集会取得更加好的结果。

随机对几张测试集的图片进行预测,并做可视化展示

Predict Label is: [3, 4, 9, 3]
Real Label is: [3, 4, 9, 3]

e1ccbcaaa5619230f5649016df26370f.png

训练完成后还可以保存训练得到的参数,方便下次导入后可供直接使用

torch.save(model.state_dict(), "model_parameter.pkl")

完整代码链接:JaimeTang/Pytorch-and-mnist(model_parameter.pkl文件较大未做上传)


微信公众号:PyMachine

6198995399de3ce12432bb1573de9734.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/431192.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

saphana服务器硬件评估,华为SAP HANA一体机:你身边的数据计算专家

​ 华为服务器高级营销经理 谭鑫/文​在当今数字经济时代,如何将数据快速变现为价值资产是很多企业追求的目标。借助SAPHANA平台方案与技术,可以将数据处理化繁为简,及时汲取价值信息,为企业的业务决策提供数据参考,从…

verilog设置24进制计数器_阅读笔记:《Verilog HDL入门》第3章 Verilog语言要素

3.1标识符1.Verilog中的Identifier是由任意字母、数字、下划线和$符号组成的,第一个字符必须是字母或者下划线。区分大小写。2.Escaped Identifier是为了解决简单标识符不能以数字和$符号开头的缺点。如下所示:3.关键字。我的理解是保留字包括关键字&…

锋利的jQuery--jQuery与DOM对象的互相转换,DOM的三种操作(读书笔记一)

1.jQuery对象就是通过jQuery包装DOM对象后产生的对象。2.jQuery对象和DOM对象的相互转换。良好的书写风格&#xff1a;var $input$("input")jQuery获取的对象在变量前面加上$。<1>jQUery对象转成DOM对象,两种方法&#xff1a;[index]和get(index)a:var $cr$(&q…

网站显示不正常服务器怎么弄,你真的知道网站出现收录不正常的原因是什么吗...

当一个新网站构建起来时&#xff0c;每天所担心的就是一个收录量。当你偶然看到收录减少时&#xff0c;不免心有所寒。不知道怎么“得罪”蜘蛛大哥了&#xff0c;发生了什么事把收录量给“没收”了。作为SEOer&#xff0c;我们知道在搜索引擎蜘蛛的心里内容的质量占有很大比重&…

ajax status php,解决laravel 出现ajax请求419(unknown status)的问题

如下所示&#xff1a;这个是因为laravel自带csrf验证的问题解决方法方法一&#xff1a;去关掉laravel的csrf验证&#xff0c;但这个人不建议&#xff0c;方法也不写出来了。方法二&#xff1a;把该接口写到api.php上就好了方法三&#xff1a;首先在页面加上然后请求的在header里…

string 转比较运算符_运算符

1、概述算术运算符 - * /基本运算算术运算符%取模&#xff0c;取余数&#xff0c;计算整除算术运算符 --自增 自减比较运算符 !相等比较 不等比较逻辑运算符&& &逻辑与 短路与&#xff08;同真为真&#xff09;逻辑运算符|| |逻辑或 短路或&#xff08;一真则真&am…

神舟战神换cpu教程_神舟将十代i5称为“神U出世”?聊聊到底有哪些优势

在各个品牌大力的宣传之下&#xff0c;消费者对于笔记本电脑乃至各种数码硬件的要求都越来越高。既要好的处理器、显卡等性能配置&#xff0c;又要好的屏幕&#xff0c;甚至还得低定价&#xff0c;这就产生一种鱼与熊掌不可兼得的感觉了。就在今年的表白日&#xff0c;神舟电脑…

服务器系统杀毒系统崩溃怎么恢复,系统崩溃是什么原因导致的

大家在使用电脑的时候&#xff0c;经常都是需要安装一些软件和其他东西的。但是在安装软件的时候&#xff0c;很容易让一些病毒侵入电脑。一旦病毒侵入了电脑&#xff0c;就很容易让电脑系统崩溃。那么系统崩溃是什么原因导致的呢&#xff1f;下面就来告诉大家系统崩溃的原因及…

atom配置python环境_python与excel有段情之二:python的安装和环境配置

索引python与excel有段情之一&#xff1a;前述python与excel有段情之二&#xff1a;python的安装和环境配置python与excel有段情之三&#xff1a;python编程前的准备工作和基本概念python与excel有段情之四&#xff1a;案例1.把多excel表抽数生成新excel表python与excel有段情之…

x86服务器当虚拟化的存储,龙存科技-软件定义数据中心产品提供商

一、应用背景服务器虚拟化技术是云计算的核心技术&#xff0c;是将系统进行虚拟化应用于服务器之上的技术。面向应用集中化处理&#xff0c;能最大的程度上利用硬件资源&#xff0c;并且实现灵活分配。虚拟化技术是将计算机底层的硬件功能的模拟&#xff0c;需要复杂的语句和机…

使用python开发网页游戏_不敢想!不敢想!我用Python自动玩转2048游戏

近来在折腾selenium自动化, 感觉配合爬虫很有意思, 大多数以前难以模拟登录的网站都可以爬了&#xff0c;折腾了这么久,于是想自动玩个2048游戏&#xff01;嘿嘿, 我是一个不擅长玩游戏的人, 以前玩2048就经常得了很低的分&#xff0c;每每想起都”痛心疾首”, 所以我打算拿204…

【飞谷六期】爬虫项目4

经过了几天的摸索&#xff0c;照猫画虎的把爬虫的部分做完了。 但是很多原理性的东西都不是很理解&#xff0c;就是照着抄的&#xff0c;还需要继续学习。 看这个目录结构&#xff0c;只看.py的文件&#xff0c;.pyc的文件是运行的时候生成的不管它。 items.py:定义想要导出的数…

activex控件 新对象 ocx 初始化_Office已经支持64位的树控件Treeview了

之前在使用Office365时发现微软其实已经悄悄地开始提供了64位的Treeview树控件&#xff0c;只是并没有公开宣布。当时是在一个网友的电脑上说他可以在64位Excel中可直接使用64位树控件&#xff0c;当时以为他看到的只是一个假的树控件&#xff0c;后来经过远程他的电脑&#xf…

mysql 获取昨天凌晨_MySQL慢日志体系建设

慢查询日志是MySQL提供的一种日志记录&#xff0c;用来记录在MySQL中响应时间超过阈值的SQL语句&#xff0c;在很大程度上会影响数据库整体的性能&#xff0c;是MySQL优化的一个重要方向。在58的云DB平台建设中&#xff0c;慢SQL系统作为一个非常重要功能模块&#xff0c;不仅是…

十进制小数化为二进制小数的方法是什么_十进制转成二进制的两种方式

第一种&#xff1a;用2整除的方式。用2整除十进制整数&#xff0c;得到一个商和余数&#xff1b;再用2去除商&#xff0c;又会得到一个商和余数&#xff0c;如此重复&#xff0c;直到商为小于1时为止&#xff0c;然后把先得到余数作为二进制数的低位有效位&#xff0c;后得到的…

notes邮件正文显示不全_python实现一次性批量发邮件

在上次实现了批量修改文件名后&#xff08;链接&#xff1a;https://zhuanlan.zhihu.com/p/133727520&#xff09;&#xff0c;又拿来了同事编写的一次性批量发邮件小程序&#xff0c;小编每月向分公司发数据任务算是基本上实现了自动化 需要新建2个.py文件实现&#xff0c;一个…

用python画五角星中心颜色不同_画个五角星让它绕中心点旋转

李兴球Python画个五角星绕中心点旋转 画一个五角星让它旋转起来,这在Python中有几个方案可选,这里提供一个不是用自定义形状的方案,以下是部分源代码, 其实关键的代码函数&#xff0c;也就是starpoints这个函数&#xff0c;它是核心。代码已经给你了。接下来就看你的聪明才智了…

oracle 建表id自增长_oracle 左连接、右连接、全外连接、内连接、以及 (+) 号用法...

Oracle中的连接可分为&#xff0c;内连接(INNER JOIN)、外连接(OUTER JOIN)、全连接(FULL JOIN)&#xff0c;不光是 Oracle&#xff0c;其他很多的数据库也都有这3种连接查询方式。Oracle 外连接(OUTER JOIN)&#xff0c;又分为左外连接和右外连接&#xff0c;即左连接和右连接…

匿名函数自我调用_Python中的匿名函数及递归思想简析

匿名函数前言上次咱们基本说了一下函数的定义及简单使用&#xff0c;Python中的基本函数及其常用用法简析&#xff0c;现在咱们整点进阶一些的。同样都是小白&#xff0c;咱也不知道实际需要不&#xff0c;但是对于函数的执行顺序以及装饰器的理解还是很有必要的。首先咱们先简…

java解析dxf文件_浅析JVM方法解析、创建和链接

一&#xff1a;前言上周末写了一篇文章《你知道Java类是如何被加载的吗&#xff1f;》&#xff0c;分析了HotSpot是如何加载Java类的&#xff0c;干脆趁热打铁&#xff0c;本周末再来分析下Hotspot又是如何解析、创建和链接类方法的。二&#xff1a;Class文件中的Java方法Java类…