Training a classifier

你已经学习了如何定义神经网络,计算损失和执行网络权重的更新。

现在你或许在思考。

What about data?

通常当你需要处理图像,文本,音频,视频数据,你能够使用标准的python包将数据加载进numpy数组。之后你能够转换这些数组到torch.*Tensor。

  • 对于图片,类似于Pillow,OPenCV的包很有用
  • 对于音频,类似于scipy和librosa的包
  • 对于文字,无论是基于原生python和是Cython的加载,或者NLTK和SpaCy都有效

对于视觉,我们特意创建了一个包叫做torchvision,它有常见数据集的数据加载,比如ImageNet,CIFAR10,MNIST等,还有图片的数据转换,torchvision.datasets和torch.utils.data.Dataloader。

这提供了很方便的实现,避免了写样板代码。

对于这一文章,我们将使用CIFAR10数据集。它拥有飞机,汽车,鸟,猫,鹿,狗,雾,马,船,卡车等类别。CIFAR-10的图片尺寸为3*32*32,也就是3个颜色通道和32*32个像素。

 

Training  an image classifier

 我们将按照顺序执行如下步骤:

  1. 使用torchvision加载并且标准化CIFAR10训练和测试数据集
  2. 定义一个卷积神经网络
  3. 定义损失函数
  4. 使用训练数据训练网络
  5. 使用测试数据测试网络

 

1.加载并标准化CIFAR10

使用torchvision,加载CIFAR10非常简单

import torch
import torchvision
import torchvision.transforms as transforms

torchvision数据集的输出是PIL图片库图片,范围为[0,1]。我们将它们转换为tensor并标准化为[-1,1]。

import torch
import torchvision
import torchvision.transforms as transformstransform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset
=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform) trainloader=torch.data.Dataloader(trainset,batch_size=4,shuffle=True,num_workers=2)
testset
=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform) testloader=torch.utils.data.Dataloader(testset,batch_size=4,shuffle=False,num_workers=2)
classes
= ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
out:
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Files already downloaded and verified

 我们来观察一下训练集图片

import matplotlib.pyplot as plt
import numpy as npdef imshow(img):img=img/2+0.5npimg=img.numpy()plt.imshow(np.transpose(npimg,(1,2,0)))dataiter=iter(trainloader)
images,labels=dataiter.next()imshow(torchvision.utils.make_grid(images))
plt.show()
print(''.join('%5s'%classes[labels[j]] for j in range(4)))

 

out:
truck truck  dog truck

 

 2.定义卷积神经网络

从前面神经网络章节复制神经网络,并把它改成接受3维图片输入(而不是之前定义的一维图片)。

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net,self).__init__()self.conv1=nn.Conv2d(3,6,5)self.pool=nn.MaxPool2d(2,2)self.conv2=nn.Conv2d(6,16,5)self.fc1=nn.Linear(16*5*5,120)self.fc2=nn.Linear(120,84)self.fc3=nn.Linear(84,10)def forward(self,x):x=self.pool(F.relu(self.conv1(x)))x=self.pool(F.relu(self.conv2(x)))x=x.view(-1,16*5*5)x=F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)return xnet=Net()

 

 3.定义损失函数和优化器

我们使用分类交叉熵损失和带有动量的SGD

import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

 

4.训练网络

我们只需要简单地迭代数据,把输入喂进网络并优化。

import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)for epoch in range(2):running_loss=0.0for i,data in enumerate(trainloader,0):inputs,labels=dataoptimizer.zero_grad()outputs=net(inputs)loss=criterion(outputs,labels)loss.backward()optimizer.step()running_loss+=loss.item()if i%2000==1999:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss=0
print('Finished Training')
out:
[1,  2000] loss: 2.208
[1,  4000] loss: 1.797
[1,  6000] loss: 1.627
[1,  8000] loss: 1.534
[1, 10000] loss: 1.508
[1, 12000] loss: 1.453
[2,  2000] loss: 1.378
[2,  4000] loss: 1.365
[2,  6000] loss: 1.326
[2,  8000] loss: 1.309
[2, 10000] loss: 1.290
[2, 12000] loss: 1.262
Finished Training

 

 

4.在测试数据集上测试网络

我们已经遍历了两遍训练集来训练网络。需要检查下网络是不是已经学习到了什么。

我们将检查神经网络输出的预测标签是否与真实标签相同。如果预测是正确的,我们将这一样本加入到正确预测的列表。

我们先来熟悉一下训练图片。

dataiter=iter(testloader)
images,labes=dataiter.next()imshow(torchvision.utils.make_grid(images))
plt.show()
print('GroundTruth: ',' '.join('%5s' % classes[labels[j]] for j in range(4)))

 

out:
GroundTruth:  plane  deer   dog horse

 ok,现在让我们看一下神经网络认为这些样本是什么。

outputs=net(images)

 输出是10个类别的量值,大的值代表网络认为某一类的可能性更大。所以我们来获得最大值得索引:

_,predicted=torch.max(outputs,1)
print("Predicted: ",' '.join('%5s' %classes[predicted[j]] for j in range(4)))
out:
Predicted:   bird   dog  deer horse

 让我们看看整个数据集上的模型表现。

out:
Accuracy of the network on the 10000 test images: 54 %

 这看起来要好过瞎猜,随机的话只要10%的准确率(因为是10类)。看来网络是学习到了一些东西。

我们来继续看看在哪些类上的效果好,在哪些类上的效果比较差:

out:
Accuracy of plane : 56 %
Accuracy of   car : 70 %
Accuracy of  bird : 27 %
Accuracy of   cat : 16 %
Accuracy of  deer : 44 %
Accuracy of   dog : 64 %
Accuracy of  frog : 61 %
Accuracy of horse : 73 %
Accuracy of  ship : 68 %
Accuracy of truck : 61 %

好了,接下来该干点啥?

我们怎样将这个神经网络运行在GPU上呢?

Trainning on GPU

就像你怎么把一个Tensor转移到GPU上一样,现在把神经网络转移到GPU上。

如果我们有一个可用的CUDA,首先将我们的设备定义为第一个可见的cuda设备:

device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
out:
cuda:0

 剩下的章节我们假定我们的设备是CUDA。

之后这些方法将递归到所有模块,将其参数和缓冲区转换为CUDA张量:

net.to(device)

 记得你还需要在每步循环里将数据转移到GPU上:

inputs,labels=inputs.to(device),labels.to(device)

为什么没注意到相对于CPU巨大的速度提升?这是因为你的网络还非常小。

 

练习:尝试增加你网络的宽度(第一个nn.Conv2d的参数2应该与第二个nn.Conv2d的参数1是相等的数字),观察你得到的速度提升。

达成目标:

  • 更深一步理解Pytorch的Tensor库和神经网络
  • 训练一个小神经网络来分类图片

 Trainning on multiple GPUs

如果你想看到更加显著的GPU加速,请移步:https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html

 

转载于:https://www.cnblogs.com/Thinker-pcw/p/9637411.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/249527.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ListableBeanFactory接口

ListableBeanFactory获取bean时,Spring 鼓励使用这个接口定义的api. 还有个Beanfactory方便使用.其他的4个接口都是不鼓励使用的. 提供容器中bean迭代的功能,不再需要一个个bean地查找.比如可以一次获取全部的bean(太暴力了),根据类型获取bean.在看SpringMVC时,扫描包路径下的…

面向对象之三大特性:继承,封装,多态

python面向对象的三大特性:继承,封装,多态。 1. 封装: 把很多数据封装到⼀个对象中. 把固定功能的代码封装到⼀个代码块, 函数, 对象, 打包成模块. 这都属于封装的思想. 具体的情况具体分析. 比如. 你写了⼀个很⽜B的函数. 那这个也可以被称为…

configurablebeanfactory

ConfigurableBeanFactory定义BeanFactory的配置.ConfigurableBeanFactory中定义了太多太多的api,比如类加载器,类型转化,属性编辑器,BeanPostProcessor,作用域,bean定义,处理bean依赖关系,合并其他ConfigurableBeanFactory,bean如何销毁. ConfigurableBeanFactory同时继承了Hi…

外观模式

一、什么是外观模式   有些人可能炒过股票,但其实大部分人都不太懂,这种没有足够了解证券知识的情况下做股票是很容易亏钱的,刚开始炒股肯定都会想,如果有个懂行的帮帮手就好,其实基金就是个好帮手,支付宝…

OC内存管理

OC内存管理 一、基本原理 (一)为什么要进行内存管理。 由于移动设备的内存极其有限,所以每个APP所占的内存也是有限制的,当app所占用的内存较多时,系统就会发出内存警告,这时需要回收一些不需要再继续使用的…

面试题集锦

1. L1范式和L2范式的区别 (1) L1范式是对应参数向量绝对值之和 (2) L1范式具有稀疏性 (3) L1范式可以用来作为特征选择,并且可解释性较强(这里的原理是在实际Loss function 中都需要求最小值,根据L1的定义可知L1最小值只有0,故可以…

Spring注解配置工作原理源码解析

一、背景知识 在【Spring实战】Spring容器初始化完成后执行初始化数据方法一文中说要分析其实现原理,于是就从源码中寻找答案,看源码容易跑偏,因此应当有个主线,或者带着问题、目标去看,这样才能最大限度的提升自身代…

Spring--Context

应用上下文 Spring通过应用上下文(Application Context)装载bean的定义并把它们组装起来。Spring应用上下文全权负责对象的创建和组装。Spring自带了多种应用上下文的实现,它们之间主要的区别仅仅在于如何加载配置。 1.AnnotationConfigApp…

了解PID控制

2019-03-07 【小记】 了解PID控制 比例 - 积分 - 微分 积分 --- 记忆过去 比例 --- 了解现在 微分 --- 预测未来 转载于:https://www.cnblogs.com/skullboyer/p/10487884.html

program collections

Java byte & 0xff byte[] b new byte[1];b[0] -127;System.out.println("b[0]:"b[0]"; b[0]&0xff:"(b[0] & 0xff));//output:b[0]:-127; b[0]&0xff:129计算机内二进制都是补码形式存储: b[0]: 补码,10000001&…

Spring ConfigurationClassPostProcessor Bean解析及自注册过程

一bean的自注册过程 二,自注册过程说明 1 configurationclassparser解析流程 1、处理PropertySources注解,配置信息的解析 2、处理ComponentScan注解:使用ComponentScanAnnotationParser扫描basePackage下的需要解析的类(SpringBootApplication注解也包…

2019第二周作业

基础作业 实验代码 #include<stdlib.h> int main(void) {FILE*fp;int num[4],i,b,max;char op;if((fpfopen("c:\\tmj.txt","r"))NULL){ printf("File open error!\n"); exit(0);}for(i0;i<4;i){fscanf(fp,"%d%c",&nu…

实验一(高见老师收)

学 号201521450016 中国人民公安大学 Chinese people’ public security university 网络对抗技术 实验报告 实验一 网络侦查与网络扫描 学生姓名 陈璪琛 年级 2015 区队 五 指导教师 高见 信息技术与网络安全学院 2018年9月18日 实验任务总纲 2018—2019学年…

Spring 钩子之BeanFactoryPostProcessor和BeanPostProcessor

BeanFactoryPostProcessor和BeanPostProcessor这两个接口都是初始化bean时对外暴露的入口之一&#xff0c;和Aware类似&#xff08;PS:关于spring的hook可以看看Spring钩子方法和钩子接口的使用详解讲的蛮详细&#xff09;本文也主要是学习具体的钩子的细节&#xff0c;以便于实…

什么是HTML DOM对象

HTML DOM 对象 HTML DOM Document 对象 Document 对象 每个载入浏览器的 HTML 文档都会成为 Document 对象。 Document 对象使我们可以从脚本中对 HTML 页面中的所有元素进行访问。 提示&#xff1a;Document 对象是 Window 对象的一部分&#xff0c;可通过 window.document 属…

Python3 matplotlib的绘图函数subplot()简介

Python3 matplotlib的绘图函数subplot()简介 一、简介 matplotlib下, 一个 Figure 对象可以包含多个子图(Axes), 可以使用 subplot() 快速绘制, 其调用形式如下 : subplot(numRows, numCols, plotNum) 图表的整个绘图区域被分成 numRows 行和 numCols 列 然后按照从左到右&…

signal(SIGHUP, SIG_IGN);

signal(SIGHUP, SIG_IGN); 的理解转载于:https://www.cnblogs.com/lanjiangzhou/p/10505653.html

spring钩子

Spring钩子方法和钩子接口的使用详解 前言 SpringFramework其实具有很高的扩展性&#xff0c;只是很少人喜欢挖掘那些扩展点&#xff0c;而且官方的Refrence也很少提到那些Hook类或Hook接口&#xff0c;至于是不是Spring官方有意为之就不得而知。本文浅析一下笔者目前看到的S…

day 012 生成器 与 列表推导式

生成器的本质就是迭代器&#xff0c;写法和迭代器不一样&#xff0c;用法一样。 获取方法&#xff1a; 1、通过生成器函数 2、通过各种推导式来实现生成器 3、通过数据的转换也可以获取生成器 例如&#xff1a; 更改return 为 yield 即成为生成器 该函数就成为了一个生成器函数…

20172325 2018-2019-1 《Java程序设计》第二周学习总结

20172325 2018-2019-1 《Java程序设计》第二周学习总结 教材学习内容总结 3.1集合 集合是一种聚集、组织了其他对象的对象。集合可以分为两大类&#xff1a;线性集合和非线性集合。线性集合&#xff1a;一种其元素按照直线方式组织的集合。非线性集合&#xff1a;一种其元素按某…