Training a classifier

你已经学习了如何定义神经网络,计算损失和执行网络权重的更新。

现在你或许在思考。

What about data?

通常当你需要处理图像,文本,音频,视频数据,你能够使用标准的python包将数据加载进numpy数组。之后你能够转换这些数组到torch.*Tensor。

  • 对于图片,类似于Pillow,OPenCV的包很有用
  • 对于音频,类似于scipy和librosa的包
  • 对于文字,无论是基于原生python和是Cython的加载,或者NLTK和SpaCy都有效

对于视觉,我们特意创建了一个包叫做torchvision,它有常见数据集的数据加载,比如ImageNet,CIFAR10,MNIST等,还有图片的数据转换,torchvision.datasets和torch.utils.data.Dataloader。

这提供了很方便的实现,避免了写样板代码。

对于这一文章,我们将使用CIFAR10数据集。它拥有飞机,汽车,鸟,猫,鹿,狗,雾,马,船,卡车等类别。CIFAR-10的图片尺寸为3*32*32,也就是3个颜色通道和32*32个像素。

 

Training  an image classifier

 我们将按照顺序执行如下步骤:

  1. 使用torchvision加载并且标准化CIFAR10训练和测试数据集
  2. 定义一个卷积神经网络
  3. 定义损失函数
  4. 使用训练数据训练网络
  5. 使用测试数据测试网络

 

1.加载并标准化CIFAR10

使用torchvision,加载CIFAR10非常简单

import torch
import torchvision
import torchvision.transforms as transforms

torchvision数据集的输出是PIL图片库图片,范围为[0,1]。我们将它们转换为tensor并标准化为[-1,1]。

import torch
import torchvision
import torchvision.transforms as transformstransform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset
=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform) trainloader=torch.data.Dataloader(trainset,batch_size=4,shuffle=True,num_workers=2)
testset
=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform) testloader=torch.utils.data.Dataloader(testset,batch_size=4,shuffle=False,num_workers=2)
classes
= ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
out:
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Files already downloaded and verified

 我们来观察一下训练集图片

import matplotlib.pyplot as plt
import numpy as npdef imshow(img):img=img/2+0.5npimg=img.numpy()plt.imshow(np.transpose(npimg,(1,2,0)))dataiter=iter(trainloader)
images,labels=dataiter.next()imshow(torchvision.utils.make_grid(images))
plt.show()
print(''.join('%5s'%classes[labels[j]] for j in range(4)))

 

out:
truck truck  dog truck

 

 2.定义卷积神经网络

从前面神经网络章节复制神经网络,并把它改成接受3维图片输入(而不是之前定义的一维图片)。

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net,self).__init__()self.conv1=nn.Conv2d(3,6,5)self.pool=nn.MaxPool2d(2,2)self.conv2=nn.Conv2d(6,16,5)self.fc1=nn.Linear(16*5*5,120)self.fc2=nn.Linear(120,84)self.fc3=nn.Linear(84,10)def forward(self,x):x=self.pool(F.relu(self.conv1(x)))x=self.pool(F.relu(self.conv2(x)))x=x.view(-1,16*5*5)x=F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)return xnet=Net()

 

 3.定义损失函数和优化器

我们使用分类交叉熵损失和带有动量的SGD

import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

 

4.训练网络

我们只需要简单地迭代数据,把输入喂进网络并优化。

import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)for epoch in range(2):running_loss=0.0for i,data in enumerate(trainloader,0):inputs,labels=dataoptimizer.zero_grad()outputs=net(inputs)loss=criterion(outputs,labels)loss.backward()optimizer.step()running_loss+=loss.item()if i%2000==1999:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss=0
print('Finished Training')
out:
[1,  2000] loss: 2.208
[1,  4000] loss: 1.797
[1,  6000] loss: 1.627
[1,  8000] loss: 1.534
[1, 10000] loss: 1.508
[1, 12000] loss: 1.453
[2,  2000] loss: 1.378
[2,  4000] loss: 1.365
[2,  6000] loss: 1.326
[2,  8000] loss: 1.309
[2, 10000] loss: 1.290
[2, 12000] loss: 1.262
Finished Training

 

 

4.在测试数据集上测试网络

我们已经遍历了两遍训练集来训练网络。需要检查下网络是不是已经学习到了什么。

我们将检查神经网络输出的预测标签是否与真实标签相同。如果预测是正确的,我们将这一样本加入到正确预测的列表。

我们先来熟悉一下训练图片。

dataiter=iter(testloader)
images,labes=dataiter.next()imshow(torchvision.utils.make_grid(images))
plt.show()
print('GroundTruth: ',' '.join('%5s' % classes[labels[j]] for j in range(4)))

 

out:
GroundTruth:  plane  deer   dog horse

 ok,现在让我们看一下神经网络认为这些样本是什么。

outputs=net(images)

 输出是10个类别的量值,大的值代表网络认为某一类的可能性更大。所以我们来获得最大值得索引:

_,predicted=torch.max(outputs,1)
print("Predicted: ",' '.join('%5s' %classes[predicted[j]] for j in range(4)))
out:
Predicted:   bird   dog  deer horse

 让我们看看整个数据集上的模型表现。

out:
Accuracy of the network on the 10000 test images: 54 %

 这看起来要好过瞎猜,随机的话只要10%的准确率(因为是10类)。看来网络是学习到了一些东西。

我们来继续看看在哪些类上的效果好,在哪些类上的效果比较差:

out:
Accuracy of plane : 56 %
Accuracy of   car : 70 %
Accuracy of  bird : 27 %
Accuracy of   cat : 16 %
Accuracy of  deer : 44 %
Accuracy of   dog : 64 %
Accuracy of  frog : 61 %
Accuracy of horse : 73 %
Accuracy of  ship : 68 %
Accuracy of truck : 61 %

好了,接下来该干点啥?

我们怎样将这个神经网络运行在GPU上呢?

Trainning on GPU

就像你怎么把一个Tensor转移到GPU上一样,现在把神经网络转移到GPU上。

如果我们有一个可用的CUDA,首先将我们的设备定义为第一个可见的cuda设备:

device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
out:
cuda:0

 剩下的章节我们假定我们的设备是CUDA。

之后这些方法将递归到所有模块,将其参数和缓冲区转换为CUDA张量:

net.to(device)

 记得你还需要在每步循环里将数据转移到GPU上:

inputs,labels=inputs.to(device),labels.to(device)

为什么没注意到相对于CPU巨大的速度提升?这是因为你的网络还非常小。

 

练习:尝试增加你网络的宽度(第一个nn.Conv2d的参数2应该与第二个nn.Conv2d的参数1是相等的数字),观察你得到的速度提升。

达成目标:

  • 更深一步理解Pytorch的Tensor库和神经网络
  • 训练一个小神经网络来分类图片

 Trainning on multiple GPUs

如果你想看到更加显著的GPU加速,请移步:https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html

 

转载于:https://www.cnblogs.com/Thinker-pcw/p/9637411.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/249527.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

19岁白帽子通过bug悬赏赚到一百万美元--转

出处:https://news.cnblogs.com/n/620858/ 19 岁的 Santiago Lopez 通过 bug 悬赏平台 HackerOne 报告漏洞,成为第一位通过 bug 悬赏赚到一百万美元的白帽子黑客。他的白帽子生涯始于 2015 年,至今共报告了超过 1600 个安全漏洞。他在 16 岁时…

代码分层的设计

分层思想,是应用系统最常见的一种架构模式,我们会将系统横向切割,根据业务职责划分。MVC 三层架构就是非常典型架构模式,划分的目的是规划软件系统的逻辑结构便于开发维护。MVC:英文即 Model-View-Controller&#xff…

【24小时内第四更】为什么我们要坚持写博客?

前言 从2018年7月份,我开始了写作博客之路。开始之前,我打算分享下之前的经历。去年初公司来了个架构师,内部分享过docker原理,TDD单元测试驱动,并发并行异步编程等内容,让我着实惊呆了,因为确实…

sqoop快速入门

转自http://www.aboutyun.com/thread-22549-1-1.html 转载于:https://www.cnblogs.com/drjava/p/10473297.html

ListableBeanFactory接口

ListableBeanFactory获取bean时,Spring 鼓励使用这个接口定义的api. 还有个Beanfactory方便使用.其他的4个接口都是不鼓励使用的. 提供容器中bean迭代的功能,不再需要一个个bean地查找.比如可以一次获取全部的bean(太暴力了),根据类型获取bean.在看SpringMVC时,扫描包路径下的…

HDU 4035 Maze

Maze http://acm.hdu.edu.cn/showproblem.php?pid4035 分析: 在树上走来走去,然后在一个点可以k的概率回到1,可以e的概率走出去,可以1-k-e的概率走到其他的位置(分为父节点和子节点讨论)。 转移方程就是&a…

面向对象之三大特性:继承,封装,多态

python面向对象的三大特性:继承,封装,多态。 1. 封装: 把很多数据封装到⼀个对象中. 把固定功能的代码封装到⼀个代码块, 函数, 对象, 打包成模块. 这都属于封装的思想. 具体的情况具体分析. 比如. 你写了⼀个很⽜B的函数. 那这个也可以被称为…

configurablebeanfactory

ConfigurableBeanFactory定义BeanFactory的配置.ConfigurableBeanFactory中定义了太多太多的api,比如类加载器,类型转化,属性编辑器,BeanPostProcessor,作用域,bean定义,处理bean依赖关系,合并其他ConfigurableBeanFactory,bean如何销毁. ConfigurableBeanFactory同时继承了Hi…

Xlua文件在热更新中调用方法

Xlua文件在热更新中调用方法 public class news : MonoBehaviour { LuaEnv luaEnv;//定义Lua初始变量 void Awake() { luaEnv new LuaEnv();//new开辟空间 luaEnv.AddLoader(myload);//调用方法地址、返回字节 luaEnv.DoString("requirefish");//更新文件 } void O…

springboot 使用的配置

1,控制台打印sql logging:level:com.sdyy.test.mapper: debug 2,开启驼峰命名 mybatis.configuration.map-underscore-to-camel-casetrue 转载于:https://www.cnblogs.com/xiaohu1218/p/10477318.html

AutowireCapableBeanFactory接口

AutowireCapableBeanFactory在BeanFactory基础上实现了对存在实例的管理.可以使用这个接口集成其它框架,捆绑并填充并不由Spring管理生命周期并已存在的实例.像集成WebWork的Actions 和Tapestry Page就很实用. 一般应用开发者不会使用这个接口,所以像ApplicationContext这样的…

外观模式

一、什么是外观模式   有些人可能炒过股票,但其实大部分人都不太懂,这种没有足够了解证券知识的情况下做股票是很容易亏钱的,刚开始炒股肯定都会想,如果有个懂行的帮帮手就好,其实基金就是个好帮手,支付宝…

OC内存管理

OC内存管理 一、基本原理 (一)为什么要进行内存管理。 由于移动设备的内存极其有限,所以每个APP所占的内存也是有限制的,当app所占用的内存较多时,系统就会发出内存警告,这时需要回收一些不需要再继续使用的…

cf1132E. Knapsack(搜索)

题意 题目链接 Sol 看了status里面最短的代码。。感觉自己真是菜的一批。。直接爆搜居然可以过&#xff1f;。。但是现在还没终测所以可能会fst。。 #include<bits/stdc.h> #define Pair pair<int, int> #define MP(x, y) make_pair(x, y) #define fi first #defi…

ConfigurableListableBeanFactory

ConfigurableListableBeanFactory 提供bean definition的解析,注册功能,再对单例来个预加载(解决循环依赖问题). 貌似我们一般开发就会直接定义这么个接口了事.而不是像Spring这样先根据使用情况细分那么多,到这边再合并 ConfigurableListableBeanFactory具体&#xff1a; 1、…

焦旭超 201771010109《面向对象程序设计课程学习进度条》

《2018面向对象程序设计&#xff08;java&#xff09;课程学习进度条》 周次 &#xff08;阅读/编写&#xff09;代码行数 发布博客量/博客评论量 课堂/课余学习时间&#xff08;小时&#xff09; 最满意的编程任务 第一周 50/20 1/0 6/4 九九乘法表 第二周 90/5…

面试题集锦

1. L1范式和L2范式的区别 (1) L1范式是对应参数向量绝对值之和 (2) L1范式具有稀疏性 (3) L1范式可以用来作为特征选择&#xff0c;并且可解释性较强&#xff08;这里的原理是在实际Loss function 中都需要求最小值&#xff0c;根据L1的定义可知L1最小值只有0&#xff0c;故可以…

Spring注解配置工作原理源码解析

一、背景知识 在【Spring实战】Spring容器初始化完成后执行初始化数据方法一文中说要分析其实现原理&#xff0c;于是就从源码中寻找答案&#xff0c;看源码容易跑偏&#xff0c;因此应当有个主线&#xff0c;或者带着问题、目标去看&#xff0c;这样才能最大限度的提升自身代…

halt

关机 init 0 reboot init6 shutdown -r now 重启 -h now 关机 转载于:https://www.cnblogs.com/todayORtomorrow/p/10486123.html

Spring--Context

应用上下文 Spring通过应用上下文&#xff08;Application Context&#xff09;装载bean的定义并把它们组装起来。Spring应用上下文全权负责对象的创建和组装。Spring自带了多种应用上下文的实现&#xff0c;它们之间主要的区别仅仅在于如何加载配置。 1.AnnotationConfigApp…