pytorch dropout_PyTorch初探MNIST数据集

7cd6fe6b5ffe437db957fa7e8b843d66.png

前言

本文主要描述了如何使用现在热度和关注度比较高的Pytorch(深度学习框架)构建一个简单的卷积神经网络,并对MNIST数据集进行了训练和测试。MNIST数据集是一个28*28的手写数字图片集合,使用测试集来验证训练出的模型对手写数字的识别准确率。

PyTorch资料

PyTorch的官方文档链接:PyTorch documentation,在这里不仅有 API的说明还有一些经典的实例可供参考。

PyTorch官网论坛:vision,里面会有很大资料分享和一些热门问题的解答。

PyTorch搭建神经网络实践

在一开始导入需要导入PyTorch的两个核心库文件torch和torchvision,这两个库基本包含了PyTorch会用到的许多方法和函数

import 

其中值得一提的是torchvision的datasets可以很方便的自动下载数据集,这里使用的是MNIST数据集。另外的COCO,ImageNet,CIFCAR等数据集也可以很方的下载并使用,导入命令也非常简单

data_train = datasets.MNIST(root = "./data/",transform=transform,train = True,download = True)data_test = datasets.MNIST(root="./data/",transform = transform,train = False)

root指定了数据集存放的路径,transform指定导入数据集时需要进行何种变换操作,train设置为True说明导入的是训练集合,否则为测试集合。

transform里面还有很多好的方法,可以用在图片资源较少的数据集做Data Argumentation操作,这里只是做了个简单的Tensor格式转换和Batch Normalize

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])

数据下载完成后还需要做数据装载操作

data_loader_train = torch.utils.data.DataLoader(dataset=data_train,batch_size = 64,shuffle = True)data_loader_test = torch.utils.data.DataLoader(dataset=data_test,batch_size = 64,shuffle = True)

batch_size设置了每批装载的数据图片为64个,shuffle设置为True在装载过程中为随机乱序

下图为一个batch数据集(64张图片)的显示,可以看出来都为28*28的1维图片

d98973570b58f059854b032f575cb8ec.png
MNIST数据集图片预览

完成数据装载后就可以构建核心程序了,这里构建的是一个包含了卷积层和全连接层的神经网络,其中卷积层使用torch.nn.Conv2d来构建,激活层使用torch.nn.ReLU来构建,池化层使用torch.nn.MaxPool2d来构建,全连接层使用torch.nn.Linear来构建

class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),torch.nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),torch.nn.MaxPool2d(stride=2,kernel_size=2))self.dense = torch.nn.Sequential(torch.nn.Linear(14*14*128,1024),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(1024, 10))def forward(self, x):x = self.conv1(x)x = x.view(-1, 14*14*128)x = self.dense(x)return x

其中定义了torch.nn.Dropout(p=0.5)防止模型的过拟合

forward函数定义了前向传播,其实就是正常卷积路径。首先经过self.conv1(x)卷积处理,然后进行x.view(-1, 14*14*128)压缩扁平化处理,最后通过self.dense(x)全连接进行分类

之后就是对Model对象进行调用,然后定义loss计算使用交叉熵,优化计算使用Adam自动化方式,最后就可以开始训练了

model = Model()
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

在训练前可以查看神经网络架构了,print输出显示如下

Model ((conv1): Sequential ((0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU ()(2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU ()(4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)))(dense): Sequential ((0): Linear (25088 -> 1024)(1): ReLU ()(2): Dropout (p = 0.5)(3): Linear (1024 -> 10))
)

定义训练次数为5次,开始跑神经网络,训练完成后输入测试集合得到的结果如下

Epoch 0/5
----------
Loss is:0.0003, Train Accuracy is:99.4167%, Test Accuracy is:98.6600
Epoch 1/5
----------
Loss is:0.0002, Train Accuracy is:99.5967%, Test Accuracy is:98.9200
Epoch 2/5
----------
Loss is:0.0002, Train Accuracy is:99.6667%, Test Accuracy is:98.7700
Epoch 3/5
----------
Loss is:0.0002, Train Accuracy is:99.7133%, Test Accuracy is:98.9600
Epoch 4/5
----------
Loss is:0.0001, Train Accuracy is:99.7317%, Test Accuracy is:98.7300

从结果上看还不错,训练准确率最高达到了99.73%,测试最高准确率为98.96%。结果有轻微的过拟合迹象,如果使用更加健壮的卷积模型测试集会取得更加好的结果。

随机对几张测试集的图片进行预测,并做可视化展示

Predict Label is: [3, 4, 9, 3]
Real Label is: [3, 4, 9, 3]

e1ccbcaaa5619230f5649016df26370f.png

训练完成后还可以保存训练得到的参数,方便下次导入后可供直接使用

torch.save(model.state_dict(), "model_parameter.pkl")

完整代码链接:JaimeTang/Pytorch-and-mnist(model_parameter.pkl文件较大未做上传)


微信公众号:PyMachine

6198995399de3ce12432bb1573de9734.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/431192.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DQL查询语句内容整理

select * from t_hq_ryxx;select bianh,xingm from t_hq_ryxx;--为字段名定义别名 select bianh as 编号,xingm as 姓名 from t_hq_ryxx;select bianh 编号 from t_hq_ryxx;select bianh || xingm as 编号和姓名 from t_hq_ryxx;select bianh as bh, t.* from t_hq_ryxx t ord…

saphana服务器硬件评估,华为SAP HANA一体机:你身边的数据计算专家

​ 华为服务器高级营销经理 谭鑫/文​在当今数字经济时代,如何将数据快速变现为价值资产是很多企业追求的目标。借助SAPHANA平台方案与技术,可以将数据处理化繁为简,及时汲取价值信息,为企业的业务决策提供数据参考,从…

verilog设置24进制计数器_阅读笔记:《Verilog HDL入门》第3章 Verilog语言要素

3.1标识符1.Verilog中的Identifier是由任意字母、数字、下划线和$符号组成的,第一个字符必须是字母或者下划线。区分大小写。2.Escaped Identifier是为了解决简单标识符不能以数字和$符号开头的缺点。如下所示:3.关键字。我的理解是保留字包括关键字&…

linux中python安装_linux环境下的python安装过程图解(含setuptools)

这里我不想采用诸如ubuntu下的apt-get install方式进行python的安装,而是在linux下采用源码包的方式进行python的安装。一、下载python源码包打开ubuntu下的shell终端,通过wget命令下载python源码包,如下图所示:将python-2.7.3.tgz下载至/opt…

锋利的jQuery--jQuery与DOM对象的互相转换,DOM的三种操作(读书笔记一)

1.jQuery对象就是通过jQuery包装DOM对象后产生的对象。2.jQuery对象和DOM对象的相互转换。良好的书写风格&#xff1a;var $input$("input")jQuery获取的对象在变量前面加上$。<1>jQUery对象转成DOM对象,两种方法&#xff1a;[index]和get(index)a:var $cr$(&q…

网站显示不正常服务器怎么弄,你真的知道网站出现收录不正常的原因是什么吗...

当一个新网站构建起来时&#xff0c;每天所担心的就是一个收录量。当你偶然看到收录减少时&#xff0c;不免心有所寒。不知道怎么“得罪”蜘蛛大哥了&#xff0c;发生了什么事把收录量给“没收”了。作为SEOer&#xff0c;我们知道在搜索引擎蜘蛛的心里内容的质量占有很大比重&…

自定义标签 (choose)

因为有3个标签,所以写3个标签处理器类 1.ChooseTag public class ChooseTag extends SimpleTagSupport {  //定义一个标记,用于控制跳转去向.public Boolean tag true;public Boolean getTag() {return tag;}public void setTag(Boolean tag) {this.tag tag;}Overridepubli…

计算机考上研究生暑假去哪里实习_浅谈化工与计算机行业

本人化工专业&#xff0c;毕业转行计算机。结合四年化工的学习以及互联网行业实习工作的经历&#xff0c;浅谈一下这两个行业。大二暑假到连云港一家民营炼化企业参观实习。该家企业给出的工资是三千五&#xff0c;四班两倒&#xff0c;可以住连云港的人才公寓&#xff0c;上下…

ajax status php,解决laravel 出现ajax请求419(unknown status)的问题

如下所示&#xff1a;这个是因为laravel自带csrf验证的问题解决方法方法一&#xff1a;去关掉laravel的csrf验证&#xff0c;但这个人不建议&#xff0c;方法也不写出来了。方法二&#xff1a;把该接口写到api.php上就好了方法三&#xff1a;首先在页面加上然后请求的在header里…

string 转比较运算符_运算符

1、概述算术运算符 - * /基本运算算术运算符%取模&#xff0c;取余数&#xff0c;计算整除算术运算符 --自增 自减比较运算符 !相等比较 不等比较逻辑运算符&& &逻辑与 短路与&#xff08;同真为真&#xff09;逻辑运算符|| |逻辑或 短路或&#xff08;一真则真&am…

神舟战神换cpu教程_神舟将十代i5称为“神U出世”?聊聊到底有哪些优势

在各个品牌大力的宣传之下&#xff0c;消费者对于笔记本电脑乃至各种数码硬件的要求都越来越高。既要好的处理器、显卡等性能配置&#xff0c;又要好的屏幕&#xff0c;甚至还得低定价&#xff0c;这就产生一种鱼与熊掌不可兼得的感觉了。就在今年的表白日&#xff0c;神舟电脑…

Oracle数据库更新时间的SQL语句

---Oracle数据库更新时间字段数据时的sql语句---格式化时间插入update t_user u set u.namepipi,u.modifytimeto_date(2015-10-07 00:00:00,YYYY-MM-DD HH24:MI:SS) where u.uid 11111---使用数据库系统当前时间update t_user u set u.namepipi,u.modifytimesysdate where u.u…

服务器系统杀毒系统崩溃怎么恢复,系统崩溃是什么原因导致的

大家在使用电脑的时候&#xff0c;经常都是需要安装一些软件和其他东西的。但是在安装软件的时候&#xff0c;很容易让一些病毒侵入电脑。一旦病毒侵入了电脑&#xff0c;就很容易让电脑系统崩溃。那么系统崩溃是什么原因导致的呢&#xff1f;下面就来告诉大家系统崩溃的原因及…

python argument list too long_间歇“OSError:[Errno 7]参数列表太长”,命令短(~125个字符)...

在Linux上的apache2mod_wsgi下运行的代码有时会产生以下输出。在notes.pycmd_list [abc_generate_pdf,--cdb-url-prefix, model.config(cdb_url_prefix),--request-cid, request_cid,]log.info("About to run: {!r}".format(cmd_list))subprocess.Popen(cmd_list)..…

atom配置python环境_python与excel有段情之二:python的安装和环境配置

索引python与excel有段情之一&#xff1a;前述python与excel有段情之二&#xff1a;python的安装和环境配置python与excel有段情之三&#xff1a;python编程前的准备工作和基本概念python与excel有段情之四&#xff1a;案例1.把多excel表抽数生成新excel表python与excel有段情之…

Scrapy--1安装和运行

1.Scrapy安装问题 一开始是按照官方文档上直接用pip安装的&#xff0c;创建项目的时候并没有报错&#xff0c; 然而在运行 scrapy crawl dmoz 的时候错误百粗/(ㄒoㄒ)/~~比如&#xff1a; ImportError: No module named _cffi_backend Unhandled error in Deferred 等等,发现是…

x86服务器当虚拟化的存储,龙存科技-软件定义数据中心产品提供商

一、应用背景服务器虚拟化技术是云计算的核心技术&#xff0c;是将系统进行虚拟化应用于服务器之上的技术。面向应用集中化处理&#xff0c;能最大的程度上利用硬件资源&#xff0c;并且实现灵活分配。虚拟化技术是将计算机底层的硬件功能的模拟&#xff0c;需要复杂的语句和机…

mysql 参照完整性规则_mysql参照完整性

该楼层疑似违规已被系统折叠 隐藏此楼查看此楼MYSQL支持数据库的参照完整性约束吗&#xff1f;有四个表&#xff1a;表一的主键是的表二外键表二的主键是表三的外键表三的主键是表四的外键请问&#xff1a;如果&#xff1a;删除表一表2 3 4 会自动删除吗&#xff1f;----------…

使用python开发网页游戏_不敢想!不敢想!我用Python自动玩转2048游戏

近来在折腾selenium自动化, 感觉配合爬虫很有意思, 大多数以前难以模拟登录的网站都可以爬了&#xff0c;折腾了这么久,于是想自动玩个2048游戏&#xff01;嘿嘿, 我是一个不擅长玩游戏的人, 以前玩2048就经常得了很低的分&#xff0c;每每想起都”痛心疾首”, 所以我打算拿204…

【飞谷六期】爬虫项目4

经过了几天的摸索&#xff0c;照猫画虎的把爬虫的部分做完了。 但是很多原理性的东西都不是很理解&#xff0c;就是照着抄的&#xff0c;还需要继续学习。 看这个目录结构&#xff0c;只看.py的文件&#xff0c;.pyc的文件是运行的时候生成的不管它。 items.py:定义想要导出的数…