PyTorch进行CIFAR-10图像分类

本节将通过一个实战案例来详细介绍如何使用PyTorch进行深度学习模型的开发。我们将使用CIFAR-10图像数据集来训练一个卷积神经网络。

神经网络训练的一般步骤如图5-3所示。

(1)加载数据集,并做预处理。

(2)预处理后的数据分为Feature和Label两部分,Feature 送到模型里面,Label被当作ground-truth。

(3)Model接收Feature作为Input,并通过一系列运算,向外输出 predict。

(4)建立一个损失函数 Loss,Loss 的函数值是为了表示 predict 与 ground-truth 之间的差距。

(5)建立 Optimizer 优化器,优化的目标就是 Loss 函数,让它的取值尽可能最小,Loss越小代表 Model 预测的准确率越高。

(6)Optimizer 优化过程中,Model 根据规则改变自身参数的权重,这是一个反复循环和持续的过程,直到Loss值趋于稳定,不能再取得更小的值。

数据集的加载可以自行编写代码,但如果是基于学习目的的话,那么把精力放在编写这个步骤的代码上面会让人十分无聊,好在PyTorch 提供了非常方便的包torchvision。torchvison提供了dataloader来加载常见的MNIST、CIFAR-10、ImageNet 等数据集,也提供了transform对图像进行变换、正则化和可视化。

在本项目中,我们的目的是用 PyTorch 创建基于 CIFAR-10 数据集的图像分类器。CIFAR-10图像数据集共有60 000幅彩色图像,这些图像是32×32的,分为10个类,分别是airplane、automobile、bird、cat等,每类6 000幅图,如图5-4所示。这里面有50 000幅训练图像,10 000幅测试图像。

首先,加载数据并进行预处理。我们将使用torchvision包来下载CIFAR-10数据集,并使用transforms模块对数据进行预处理。主要用来进行数据增强,为了防止训练出现过拟合,通常在小型数据集上,通过随机翻转图片、随机调整图片的亮度来增加训练时数据集的容量。但是,测试的时候,并不需要对数据进行增强。运行代码后,会自动下载数据集。

接下来,定义卷积神经网络模型。在这个网络模型中,我们使用nn.Module来定义网络模型,然后在__init__方法中定义网络的层,最后在forward方法中定义网络的前向传播过程。在PyTorch中可以通过继承nn.Module来自定义神经网络,在init()中设定结构,在forward()中设定前向传播的流程。因为PyTorch可以自动计算梯度,所以不需要特别定义反向传播。

定义好神经网络模型后,还需要定义损失函数(Loss)和优化器(Optimizer)。在这里采用 cross-entropy-loss函数作为损失函数,采用 Adam 作为优化器,当然SGD也可以。

一切准备就绪后,开始训练网络,这里训练10次(可以增加训练次数,提高准确率)。在训练过程中,首先通过网络进行前向传播得到输出,然后计算输出与真实标签的损失,接着通过后向传播计算梯度,最后使用优化器更新模型参数。训练完成后,我们需要在测试集上测试网络的性能。这可以让我们了解模型在未见过的数据上的表现如何,以评估其泛化能力。

完整代码如下:

#############cifar-10-pytorch.py####################
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim# torchvision输出的是PILImage,值的范围是[0, 1]
# 我们将其转换为张量数据,并归一化为[-1, 1]
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5),std=(0.5, 0.5, 0.5)),])# 训练集,将相对目录./data下的cifar-10-batches-py文件夹中的全部数据
# (50 000幅图片作为训练数据)加载到内存中
# 若download为True,则自动从网上下载数据并解压
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)# 将训练集的50 000幅图片划分成12 500份,每份4幅图,用于mini-batch输入
# shffule=True在表示不同批次的数据遍历时,打乱顺序。num_workers=2表示使用两个子进程来加载数据
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 下面的代码只是为了给小伙伴们展示一个图片例子,让大家有个直观感受
# functions to show an image
import matplotlib.pyplot as plt
import numpy as np# matplotlib inline
def imshow(img):img = img / 2 + 0.5  # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()
class Net(nn.Module):# 定义Net的初始化函数,这个函数定义了该神经网络的基本结构def __init__(self):super(Net, self).__init__()# 复制并使用Net的父类的初始化方法,即先运行nn.Module的初始化函数self.conv1 = nn.Conv2d(3, 6, 5)# 定义conv1函数是图像卷积函数:输入为3张特征图# 输出为 6幅特征图, 卷积核为5×5的正方形self.conv2 = nn.Conv2d(6, 16, 5)# 定义conv2函数的是图像卷积函数:输入为6幅特征图,输出为16幅特征图# 卷积核为5×5的正方形self.fc1 = nn.Linear(16 * 5 * 5, 120)# 定义fc1(fullconnect)全连接函数1为线性函数:y = Wx + b# 并将16×5×5个节点连接到120个节点上self.fc2 = nn.Linear(120, 84)# 定义fc2(fullconnect)全连接函数2为线性函数:y = Wx + b# 并将120个节点连接到84个节点上self.fc3 = nn.Linear(84, 10)# 定义fc3(fullconnect)全连接函数3为线性函数:y = Wx + b# 并将84个节点连接到10个节点上# 定义该神经网络的向前传播函数,该函数必须定义# 一旦定义成功,向后传播函数也会自动生成(autograd)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))# 输入x经过卷积conv1之后,经过激活函数ReLU# 使用2×2的窗口进行最大池化,然后更新到xx = F.max_pool2d(F.relu(self.conv2(x)), 2)# 输入x经过卷积conv2之后,经过激活函数ReLU# 使用2×2的窗口进行最大池化,然后更新到xx = x.view(-1, self.num_flat_features(x))# view函数将张量x变形成一维的向量形式# 总特征数并不改变,为接下来的全连接作准备x = F.relu(self.fc1(x))# 输入x经过全连接1,再经过ReLU激活函数,然后更新xx = F.relu(self.fc2(x))# 输入x经过全连接2,再经过ReLU激活函数,然后更新xx = self.fc3(x)# 输入x经过全连接3,然后更新xreturn x# 使用num_flat_features函数计算张量x的总特征量# 把每个数字都作一个特征,即特征总量# 比如x是4×2×2的张量,那么它的特征总量就是16def num_flat_features(self, x):size = x.size()[1:]# 这里为什么要使用[1:],是因为PyTorch只接受批输入# 也就是说一次性输入好几幅图片,那么输入数据张量的维度自然上升到了4维# 【1:】让我们把注意力放在后3维上面# x.size() 会 return [nSamples, nChannels, Height, Width]。# 只需要展开后三项成为一个一维的张量num_features = 1for s in size:num_features *= sreturn num_features
net = Net()
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 使用SGD(随机梯度下降)优化,学习率为0.001,动量为0.9
if __name__ == '__main__':for epoch in range(10):running_loss = 0.0# enumerate(sequence, [start=0]),i是序号,data是数据for i, data in enumerate(trainloader, 0):inputs, labels = data# data的结构是:[4×3×32×32的张量,长度为4的张量]inputs, labels = Variable(inputs), Variable(labels)# 把input数据从tensor转为variableoptimizer.zero_grad()# 将参数的grad值初始化为0# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)# 将output和labels使用交叉熵计算损失loss.backward()  # 反向传播optimizer.step()  # 用SGD更新参数# 每2000批数据打印一次平均loss值running_loss += loss.item()# loss本身为Variable类型# 要使用data获取其张量,因为其为标量,所以取0 或使用loss.item()if i % 2000 == 1999:  # 每2000批打印一次print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')# 测试集,将相对目录./data下的cifar-10-batches-py文件夹中的全部数据# (10 000幅图片作为测试数据)加载到内存中# 若download为True,则自动从网上下载数据并解压testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)# 将测试集的10 000幅图片划分成2500份,每份4幅图,用于mini-batch输入testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(Variable(images))# print outputs.data# print(outputs.data)# print(labels)value, predicted = torch.max(outputs.data,1)# outputs.data是一个4x10张量# 将每一行的最大的那一列的值和序号各自组成一个一维张量返回# 第一个是值的张量,第二个是序号的张量# label.size(0) 是一个数total += labels.size(0)correct += (predicted == labels).sum()# 两个一维张量逐行对比,相同的行记为1,不同的行记为0# 再利用sum()求总和,得到相同的个数print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

运行结果如下:

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
[1,  2000] loss: 2.165
[1,  4000] loss: 1.834
[1,  6000] loss: 1.667
[1,  8000] loss: 1.566
[1, 10000] loss: 1.532
[1, 12000] loss: 1.462
Files already downloaded and verified
Files already downloaded and verified
[2,  2000] loss: 1.403
[2,  4000] loss: 1.380
[2,  6000] loss: 1.325
[2,  8000] loss: 1.281
[2, 10000] loss: 1.304
[2, 12000] loss: 1.262
Files already downloaded and verified
Files already downloaded and verified
[3,  2000] loss: 1.230
[3,  4000] loss: 1.221
[3,  6000] loss: 1.181
[3,  8000] loss: 1.147
[3, 10000] loss: 1.175
[3, 12000] loss: 1.147
Files already downloaded and verified
Files already downloaded and verified
[4,  2000] loss: 1.120
[4,  4000] loss: 1.110
[4,  6000] loss: 1.079
[4,  8000] loss: 1.064
[4, 10000] loss: 1.090
[4, 12000] loss: 1.068
Files already downloaded and verified
Files already downloaded and verified
[5,  2000] loss: 1.039
[5,  4000] loss: 1.030
[5,  6000] loss: 1.009
[5,  8000] loss: 0.990
[5, 10000] loss: 1.021
[5, 12000] loss: 1.007
Files already downloaded and verified
Files already downloaded and verified
[6,  2000] loss: 0.975
[6,  4000] loss: 0.971
[6,  6000] loss: 0.947
[6,  8000] loss: 0.937
[6, 10000] loss: 0.963
[6, 12000] loss: 0.953
Files already downloaded and verified
Files already downloaded and verified
[7,  2000] loss: 0.930
[7,  4000] loss: 0.923
[7,  6000] loss: 0.902
[7,  8000] loss: 0.891
[7, 10000] loss: 0.928
[7, 12000] loss: 0.911
Files already downloaded and verified
Files already downloaded and verified
[8,  2000] loss: 0.881
[8,  4000] loss: 0.890
[8,  6000] loss: 0.864
[8,  8000] loss: 0.868
[8, 10000] loss: 0.896
[8, 12000] loss: 0.875
Files already downloaded and verified
Files already downloaded and verified
[9,  2000] loss: 0.846
[9,  4000] loss: 0.870
[9,  6000] loss: 0.836
[9,  8000] loss: 0.834
[9, 10000] loss: 0.851
[9, 12000] loss: 0.847
Files already downloaded and verified
Files already downloaded and verified
[10,  2000] loss: 0.816
[10,  4000] loss: 0.835
[10,  6000] loss: 0.797
[10,  8000] loss: 0.805
[10, 10000] loss: 0.841
[10, 12000] loss: 0.809
Finished Training
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Accuracy of the network on the 10000 test images: 61 %
Files already downloaded and verified
Files already downloaded and verified
Accuracy of plane : 58 %
Accuracy of   car : 72 %
Accuracy of  bird : 41 %
Accuracy of   cat : 51 %
Accuracy of  deer : 55 %
Accuracy of   dog : 44 %
Accuracy of  frog : 66 %
Accuracy of horse : 72 %
Accuracy of  ship : 80 %
Accuracy of truck : 69 %

在这段代码中,我们在整个测试集上测试网络,并打印出网络在测试集上的准确率。通过这种详细且实践性的方式介绍了PyTorch的使用,包括张量操作、自动求导机制、神经网络创建、数据处理、模型训练和测试。我们利用PyTorch从头到尾完成了一个完整的神经网络训练流程,并在 CIFAR-10数据集上测试了网络的性能。在这个过程中,我们深入了解了PyTorch提供的强大功能。

本文节选自《PyTorch深度学习与企业级项目实战》,获出版社和作者授权发布。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/838065.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sin^2(x) 的图像

[TOC](sin^2(x) 的图像) 正文 这里记录一下 s i n 2 x sin^2{x} sin2x 的图像。 函数值以正弦的形式在 [0, 1] 区间内波动。 如果大家觉得有用,就点个赞让更多的人看到吧~

Ant Design Vue 的组件库的<a-tab-pane>的force-render

在使用类似 Ant Design Vue 的组件库时&#xff0c;force-render 属性通常用于指示是否强制渲染标签页的内容&#xff0c;即使它还没有被显示。通常&#xff0c;在一个标签页组件&#xff08;如 <a-tab-pane>&#xff09;中&#xff0c;内容只有在用户激活该标签页时才会…

嵌入式文件系统

嵌入式文件系统 文件系统简介 在计算机系统中&#xff0c; 需要用到大量的程序和数据&#xff0c; 它们大部分以文件的形式存放在外部存储当中&#xff0c; 根据需要可随时调入内存使用 如果用户直接管理外存文件所面临的问题&#xff1a; 必须熟悉外存的物理特性了解各种存…

英语新概念2-回译法-lesson14

第一次回译 I had a amusing experience. I kept driving to the next town when I left a small village in the south of Franch.A teenager waved to me in the path.I stopped my car, he asked me to give him a lift.As soon as he get my car,I say good morning in Fr…

【Kubernetes】污点、容忍度、亲和性、调度和重启策略

标签、污点、容忍度、亲和性 一、标签1、定义2、给资源打标签【1】给Pod打标签【2】给Service打标签【3】给Node打标签 3、查看资源标签 二、节点选择器1、nodeName2、nodeSelector 三、污点、容忍度、亲和性1、node节点亲和性【1】硬亲和性【2】软亲和性 2、pod节点亲和性【1】…

LeetCode hot100-36-N

94. 二叉树的中序遍历给定一个二叉树的根节点 root &#xff0c;返回 它的 中序 遍历 。这题都写不出来&#xff0c;废了 以前学数据结构的时候&#xff0c;书上老师写伪代码&#xff0c;于是递归版的遍历在伪代码中看起来就很简单。但是怎么输出这个list呢&#xff0c;其实是另…

安泰电子电压放大器应用及示例是什么样的

电压放大器是电子电路中常用的一种器件&#xff0c;用于将输入信号的电压放大至所需的输出电压。它在许多领域中有着重要的应用&#xff0c;包括通信、音频放大、仪器测量等。以下是电压放大器的一些应用及示例&#xff1a; 信号处理&#xff1a;在许多电子系统中&#xff0c;需…

Kasawaki川崎机器人故障维修

在当今的自动化工业领域&#xff0c;川崎工业机器人以其卓越的性能和可靠的工作效率赢得了广泛的赞誉。作为机器人的核心组成部分&#xff0c;伺服电机的作用至关重要。然而&#xff0c;就像所有机械设备一样&#xff0c;也可能会遭遇电机磨损或故障&#xff0c;需要适时的川崎…

vue自定义权限指令

定义v-hasPermi指令 /*** v-hasPermi 操作权限处理*/import useUserStore from /store/modules/userexport default {mounted(el, binding, vnode) {const { value } bindingconst all_permission "*:*:*";const permissions useUserStore().permissions&#xff…

linux - 搭建部署ftp服务器

ftp 服务: 实现ftp功能的一个服务,安装vsftpd软件搭建一台ftp服务器 ftp协议: 文件传输协议 (file transfer protocol),在不同的机器之间实现文件传输功能, 例如 视频文件下载,源代码文件下载 公司内部:弄一个专门的文件服务器,将公司里的文档资料和视频都存放…

基于死区补偿的永磁同步电动机矢量控制系统simulink仿真模型

整理了基于死区补偿的永磁同步电动机矢量控制系统simulink仿真&#xff0c;该模型使用线性死区补偿的PMSM矢量控制算法进行仿真&#xff0c;使用Foc电流双闭环 。 1.模块划分清晰&#xff0c;补偿前后仿真有对比&#xff0c;易于学习; 2.死区补偿算法的线性区区域可调; 3.自…

5.13网络编程

只要在一个电脑中的两个进程之间可以通过网络进行通信那么拥有公网ip的两个计算机的通信是一样的。但是一个局域网中的两台电脑上的虚拟机是不能进行通信的&#xff0c;因为这两个虚拟机在电脑中又有各自的局域网所以通信很难实现。 socket套接字是一种用于网络间进行通信的方…

Python接口自动化测试之动态数据处理

在前面的知识基础上介绍了在接口自动化测试中&#xff0c;如何把数据分离出来&#xff0c;并且找到它的共同点&#xff0c;然后依据这个共同点来找到解决复杂问题的思想。我一直认为&#xff0c;程序是人设计的&#xff0c;它得符合人性&#xff0c;那么自动化测试的&#xff0…

自由职业是种怎样的体验?普通人如何成为一名自由职业者?

自由职业在哪都能办公自由职业在哪都要办公。 放弃幻想&#xff0c;没有不辛苦的工作&#xff0c;5年经验后端开发程序员&#xff0c;已经从事自由职业1年半&#xff0c;今天就来客观分享一下自由职业的利与弊。 时间自由&#xff0c;减少中间商赚差价 自由职业最让人羡慕的就…

React Native 开发心得分享

有一段时间没更新了&#xff0c;花了点时间研究了下 React Native&#xff08;后续用 RN 简称&#xff09;&#xff0c;同时也用该技术作为我的毕设项目(一个校园社交应用&#xff0c;仿小红书)&#xff0c;经过了这段时间的疯狂折腾&#xff0c;对 RN 生态有了一定的了解&…

【实战selenium框架下在爱企查爬取企业的历史变更信息】文末附Google浏览器和驱动的下载

代码如下 # 导入包 import random import time from tkinter import filedialog import tkinter as tk import xlrd import os import datetime import csv from selenium import webdriver from selenium.webdriver import Keys from selenium.webdriver.common.by import By…

图搜索算法-最小生成树问题-普里姆算法(prim)

相关文章&#xff1a; 数据结构–图的概念 图搜索算法 - 深度优先搜索法&#xff08;DFS&#xff09; 图搜索算法 - 广度优先搜索法&#xff08;BFS&#xff09; 图搜索算法 - 拓扑排序 图搜索算法-最短路径算法-戴克斯特拉算法 图搜索算法-最短路径算法-贝尔曼-福特算法 图搜索…

Flutter 中的 AlertDialog 小部件:全面指南

Flutter 中的 AlertDialog 小部件&#xff1a;全面指南 在Flutter中&#xff0c;AlertDialog是一个用于显示警告、错误、信息或者确认消息的模态对话框。它提供了一种简单而直接的方式与用户进行交流&#xff0c;通常用于需要用户注意的重要信息或者需要用户做出决策的场合。本…

【069】基于SpringBoot+Vue实现的企业资产管理系统

系统介绍 基于SpringBootVue实现的企业资产管理系统管理员功能有个人中心&#xff0c;用户管理&#xff0c;资产分类管理&#xff0c;资产信息管理&#xff0c;资产借出管理&#xff0c;资产归还管理&#xff0c;资产维修管理。用户可以对资产进行借出和归还操作。因而具有一定…

计算机类的英语

Algorithm&#xff08;算法&#xff09;Binary code&#xff08;二进制代码&#xff09;Byte&#xff08;字节&#xff09;Cache&#xff08;缓存&#xff09;Database&#xff08;数据库&#xff09;Encryption&#xff08;加密&#xff09;Firewall&#xff08;防火墙&#x…