python数据集的预处理_关于Pytorch的MNIST数据集的预处理详解

关于Pytorch的MNIST数据集的预处理详解

MNIST的准确率达到99.7%

用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等。

操作系统:ubuntu18.04

显卡:GTX1080ti

python版本:2.7(3.7)

网络架构

具有4层的CNN具有以下架构。

输入层:784个节点(MNIST图像大小)

第一卷积层:5x5x32

第一个最大池层

第二卷积层:5x5x64

第二个最大池层

第三个完全连接层:1024个节点

输出层:10个节点(MNIST的类数)

用于改善CNN性能的工具

采用以下技术来改善CNN的性能。

1. Data augmentation

通过以下方式将列车数据的数量增加到5倍

随机旋转:每个图像在[-15°,+ 15°]范围内随机旋转。

随机移位:每个图像在两个轴上随机移动一个范围为[-2pix,+ 2pix]的值。

零中心归一化:将像素值减去(PIXEL_DEPTH / 2)并除以PIXEL_DEPTH。

2. Parameter initializers

重量初始化器:xaiver初始化器

偏差初始值设定项:常量(零)初始值设定项

3. Batch normalization

所有卷积/完全连接的层都使用批量标准化。

4. Dropout

The third fully-connected layer employes dropout technique.

5. Exponentially decayed learning rate

A learning rate is decayed every after one-epoch.

代码部分

第一步:了解MNIST数据集

MNIST数据集是一个手写体数据集,一共60000张图片,所有的图片都是28×28的,下载数据集的地址:数据集官网。这个数据集由四部分组成,分别是:

train-images-idx3-ubyte.gz: training set images (9912422 bytes)

train-labels-idx1-ubyte.gz: training set labels (28881 bytes)

t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)

t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)

也就是一个训练图片集,一个训练标签集,一个测试图片集,一个测试标签集;我们可以看出这个其实并不是普通的文本文件

或是图片文件,而是一个压缩文件,下载并解压出来,我们看到的是二进制文件。

第二步:加载MNIST数据集

先引入一些库文件

import torchvision,torch

import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

加载MNIST数据集有很多方法:

方法一:在pytorch下可以直接调用torchvision.datasets里面的MNIST数据集(这是官方写好的数据集类)

train = torchvision.datasets.MNIST(root='./mnist/',train=True, transform= transforms.ToTensor())

返回值为一个元组(train_data,train_target)(这个类使用的时候也有坑,必须用train[i]索引才能使用 transform功能)

一般是与torch.utils.data.DataLoader配合使用

dataloader = DataLoader(train, batch_size=50,shuffle=True, num_workers=4)

for step, (x, y) in enumerate(dataloader):

b_x = x.shape

b_y = y.shape

print 'Step: ', step, '| train_data的维度' ,b_x,'| train_target的维度',b_y

如图将60000张图片的数据分为1200份,每份包含50张图像,这样并行处理数据能有效加快计算速度

看个人喜好,本人不太喜欢这种固定的数据类,所以想要灵活多变,可以开始自己写数据集类

方法二:自己设置数据集

使用pytorch相关类,API对数据集进行封装,pytorch中数据集相关的类位于torch.utils.data package中。

本次实验,主要使用以下类:

torch.utils.data.Dataset

torch.utils.data.DataLoader

Dataset类的使用: 所有的类都应该是此类的子类(也就是说应该继承该类)。 所有的子类都要重写(override) len(), getitem() 这两个方法。

使用到的python package

python package

目的

numpy

矩阵操作,对图像进行转置

skimage

图像处理,图像I/O,图像变换

matplotlib

图像的显示,可视化

os

一些文件查找操作

torch

pytorch

torvision

pytorch

导入相关的包

import numpy as np

from skimage import io

from skimage import transform

import matplotlib.pyplot as plt

import os

import torch

import torchvision

from torch.utils.data import Dataset, DataLoader

from torchvision.transforms import transforms

from PIL import Image

第一步:

定义一个子类,继承Dataset类, 重写 __len()__, __getitem()__ 方法。

细节:

1.数据集一个样本的表示:采用字典的形式sample = {'img': img, 'target': target}。

图像的读取:采用torch.load进行读取,读取之后的结果为torch.Tensor形式。

图像变换:transform参数

class MY_MNIST(Dataset):

training_file = 'training.pt'

test_file = 'test.pt'

def __init__(self, root, transform=None):

self.transform = transform

self.data, self.targets = torch.load(root)

def __getitem__(self, index):

img, target = self.data[index], int(self.targets[index])

img = Image.fromarray(img.numpy(), mode='L')

if self.transform is not None:

img = self.transform(img)

img =transforms.ToTensor()(img)

sample = {'img': img, 'target': target}

return sample

def __len__(self):

return len(self.data)

train = MY_MNIST(root='./mnist/MNIST/processed/training.pt', transform= None)

第二步

实例化一个对象,并读取和显示数据集

for (cnt,i) in enumerate(train):

image = i['img']

label = i['target']

ax = plt.subplot(4, 4, cnt+1)

# ax.axis('off')

ax.imshow(image.squeeze(0))

ax.set_title(label)

plt.pause(0.001)

if cnt ==15:

break

输出如下 ,这样就表明,咱们自己写的数据集读取图像,并读取之后的结果为torch.Tensor形式成功啦!

第三步(可选 optional)

对数据集进行变换:一般收集到的图像大小尺寸,亮度等存在差异,变换的目的就是使得数据归一化。另一方面,可以通过变换进行数据增强

关于pytorch中的变换transforms,请参考该系列之前的文章

由于数据集中样本采用字典dicts形式表示。 因此不能直接调用torchvision.transofrms中的方法。

本实验进行了旋转,随机裁剪,调节图像的色彩饱和明暗等操作。

compose = transforms.Compose([

transforms.Resize(20),

transforms.RandomHorizontalFlip(),

transforms.RandomCrop(20),

transforms.ColorJitter(brightness=1, contrast=0.1, hue=0.5),

# transforms.ToTensor(),

# transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

train_transformed = MY_MNIST(root='./mnist/MNIST/processed/training.pt', transform= compose)

#显示变换后的图像

for (cnt,i) in enumerate(train_transformed):

image = i['img']

# print image[0].sum()

# image = compose(image)

print 'sdsdadfasfasfasf',type(image)

label = i['target']

ax = plt.subplot(4, 4, cnt+1)

# ax.axis('off')

ax.imshow(image.squeeze(0))

ax.set_title(label)

plt.pause(0.001)

if cnt ==15:

break

变换后的图像,和之前对比,你发现了什么不同吗?

第四步: 使用DataLoader进行包装

为何要使用DataLoader?

① 深度学习的输入是mini_batch形式

② 样本加载时候可能需要随机打乱顺序,shuffle操作

③ 样本加载需要采用多线程

pytorch提供的DataLoader封装了上述的功能,这样使用起来更方便。

# 使用DataLoader可以利用多线程,batch,shuffle等

trainset_dataloader = DataLoader(dataset=transformed_trainset,

batch_size=4,

shuffle=True,

num_workers=4)

可视化:

dataloader = DataLoader(train, batch_size=50,shuffle=True, num_workers=4)

通过DataLoader包装之后,样本以min_batch形式输出,而且进行了随机打乱顺序。

for step, i in enumerate(dataloader):

b_x = i['img'].shape

b_y = i['target'].shape

print 'Step: ', step, '| train_data的维度' ,b_x,'| train_target的维度',b_y

如图图片大小已经裁剪为20*20,而且并行处理让60000个数据在3秒内就能处理好,效率非常高

Step: 1186 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1187 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1188 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1189 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1190 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1191 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1192 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1193 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1194 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1195 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1196 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1197 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1198 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

Step: 1199 | train_data的维度 (50, 1, 20, 20) | train_target的维度 (50,)

未完待续…

以上这篇关于Pytorch的MNIST数据集的预处理详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/564587.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java集合基础_java常用集合基础知识

【纯出自个人笔记,如有错误,望改正,谢谢哈!学习~】一、Java集合1、集合类:容器类 装对象的(不能存放基本数据类型,但是里面看到的其实是包装类型)java.util包ArrayList底层是一个对象数组----------------------------…

java cookie安全_cookie的安全性问题

HTTP协议:(1)请求组成部分:请求行:(get或者post请求;请求路径(不包括主机) ;http1.1)请求头:请求头是浏览器交给服务器的一些信息(比较cookie啥的)请求体:只有post请求有请求体,get请…

python画环形图_用Python把图做的好看点:用Matplotlib画个环形图

P老板:小Lo啊,你觉得这几个图好看吗我:好看,好看P老板:我也觉得,这个月的月报,就用这个把,你给我几个,我看看我:.....于是乎,我们今天的目标是什么…

Java main()方法

在 Java 中,main() 方法是 Java 应用程序的入口方法,程序在运行的时候,第一个执行的方法就是 main() 方法。main() 方法和其他的方法有很大的不同。 下面先来看最简单的一个 Java 应用程序 HelloWorld,我们将通过这个例子讲解 Ja…

Java方法的可变参数

在具体实际开发过程中,有时方法中参数的个数是不确定的。为了解决这个问题,在 J2SE 5.0 版本中引入了可变参数的概念。 声明可变参数的语法格式如下: methodName({paramList},paramType…paramName)其中,methodName 表示方法名称…

python中ans的用法_python cx_Oracle基础使用方法

问题使用python操作oracle数据库,获取表的某几个字段作为变量值使用。使用Popensqlplus的方法需要对格式进行控制,通过流获取这几个字段值不简洁(个人观点……)。(优点是能够使用sqlplus的方法直接访问sql文件,不需要考虑打开/关闭连接&#…

java gradle 资源访问_java在gradle工程访问src/test/resources目录下的资源配置文件

package com.jiepu;import java.io.File;import java.net.URISyntaxException;import java.util.Map;import java.util.Properties;//java在gradle工程访问src/test/resources或者src/main/resources目录下的资源配置文件public class TestMain{public static void main(String…

python 匹配字符串map lambda函数_Python map amp; reduce 以及lambda匿名函数 - jvisualvm - ITeye博客...

map()map()函数接收两个参数,一个是函数,一个是Iterable,map将传入的函数依次作用到序列的每个元素,并把结果作为新的Iterator返回。使用map实现一个f(x) x * x的功能def f(x):return x * xm map(f, list(range(1, 10)))# [1, 4…

java执行数据库命令行_java程序执行命令行,解锁数据库表

有些表锁的时间长或其他原因,在plsql中不能解锁,只能用命令行解锁。有些功能跨平台系统的交互偶尔会锁表,就需要自动解锁。下面是解锁的代码:package com.lg.BreakOracleUtils;import com.lg.DB.DBProjp;import com.lg.database.D…

python display函数_【python】pandas display选项

import pandas as pd1、pd.set_option(expand_frame_repr, False)True就是可以换行显示。设置成False的时候不允许换行2、pd.set_option(display.max_rows, 10)pd.set_option(display.max_columns, 10)显示的最大行数和列数,如果超额就显示省略号,这个指…

Java查询个人信息

每个员工都会有自己的档案,主管可以查看在职员工的档案。使用 Java 创建一个员工实体类,然后通过构造方法创建一个名为“王洁”的员工,最后打印出员工档案信息。 1 . 创建 Person 类,在该类中定义个人基本信息属性,并…

java幻灯片播放代码_简单常用的幻灯片播放实现代码

幻灯片自动播放图片是当前网站比较流行的一个展示方式。在网上我们能找到各种特效丰富的幻灯片插件和代码。这里项目需要,我自己做了一个简单的,就不详细讲解了,代码很简单。直接看效果图和代码吧。所有代码 ppt.html,需要提供相应…

ssms没有弹出服务器验证_powerbi报表服务器搭建链接

作品展示​www.chinapowerbi.com安装 Power BI 报表服务器所要满足的硬件和软件要求 - Power BI​docs.microsoft.comDownload 用于基于 x64 的系统的 Windows 8.1 更新程序 (KB2919442) from Official Microsoft Download Center​www.microsoft.comDownload Windows Server 2…

Java析构方法

析构方法与构造方法相反,当对象脱离其作用域时,系统自动执行析构方法。析构方法往往用来做清理垃圾碎片的工作,例如在建立对象时用 new 开辟了一片内存空间,应退出前在析构方法中将其释放。 在 Java 的 Object 类中还提供了一个 …

2048java课程设计报告_2048小游戏-Java-课程设计报告书

《2048小游戏-Java-课程设计报告书》由会员分享,可在线阅读,更多相关《2048小游戏-Java-课程设计报告书(31页珍藏版)》请在金锄头文库上搜索。1、JAVA 语言程序设计课程设计报告2048 智力小游戏设计专业班级: 计算机科学与技术嵌入 13-1 学生…

python批量音频转格式_python将mp3格式批量转化为wav格式

语音识别无论是接口还是开源的项目,大多情况下都需要将语音格式转化为wav格式。首先,需要安装pydub库,pip install pydub 就行。接下来将你需要转化的mp3文件放入文件夹,创建好需要存入的wav文件夹。接下来python 代码实现 &#…

Java包详解

Java 引入了包(package)机制,提供了类的多层命名空间,用于解决类的命名冲突、类文件管理等问题。 包允许将类组合成较小的单元(类似文件夹),它基本上隐藏了类,并避免了名称上的冲突…

groovy java_在java中使用groovy怎么搞

临摹微笑一种基于Java虚拟机的动态语言,可以和java无缝集成,正是这个特性,很多时候把二者同时使用,把groovy作为java的有效补充。对于Java程序员来说,学习成本几乎为零。同时支持DSL和其他简介的语法(例如闭包)&#x…

python自动控制库_一个可以自动化控制鼠标键盘的库:PyAUtoGUI

PyAutoGUI 不知道你们有没有用过,它是一款用Python自动化控制键盘、鼠标的库。但凡是你不想手动重复操作的工作都可以用这个库来解决。如果,我想半夜时候定时给发个微信,或者每天自动刷页面等操作,它能完全模拟手动操作&#xff0…

Java使用自定义包

包的声明和使用非常简单,在了解基本语法之后,示例在 Java 程序中声明包,以及不同包之间类的使用。 1 创建一个名为 com.dao 的包。 2 向 com.dao 包中添加一个 Student 类,该类包含一个返回 String 类型数组的 GetAll() 方法。S…