3.pytorch cifar10

数据集

CIFAR10 是由 Hinton 的学生 Alex Krizhevsky、Ilya Sutskever 收集的一个用于普适物体识别的计算机视觉数据集,它包含 60000 张 32 X 32 的 RGB 彩色图片,总共 10 个分类。
这些类别分别是飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。其中,包括 50000 张用于训练集,10000 张用于测试集。

run

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import time
import os# transform 的作用主要是用来对数据进行预处理。
transform = transforms.Compose([transforms.RandomHorizontalFlip(), # 随机翻转图片 , 数据增强transforms.RandomGrayscale(), # 随机调整图片的亮度transforms.ToTensor(), # 数据集加载时,默认的图片格式是numpy,所以通过transforms转换成 Tensor。然后再对输入图片进行标准化。transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 给定均值:(R,G,B) 方差:(R,G,B),将会把Tensor正则化
])transform1 = transforms.Compose([transforms.ToTensor(), # 测试的时候,并不需要对数据进行增强transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,shuffle=True # shuffle = True 表明提取数据时,随机打乱顺序)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform1)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,shuffle=False)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.conv1 = nn.Conv2d(3,64,3,padding=1)self.conv2 = nn.Conv2d(64,64,3,padding=1)self.pool1 = nn.MaxPool2d(2, 2)self.bn1 = nn.BatchNorm2d(64)self.relu1 = nn.ReLU()self.conv3 = nn.Conv2d(64,128,3,padding=1)self.conv4 = nn.Conv2d(128, 128, 3,padding=1)self.pool2 = nn.MaxPool2d(2, 2, padding=1)self.bn2 = nn.BatchNorm2d(128)self.relu2 = nn.ReLU()self.conv5 = nn.Conv2d(128,128, 3,padding=1)self.conv6 = nn.Conv2d(128, 128, 3,padding=1)self.conv7 = nn.Conv2d(128, 128, 1,padding=1)self.pool3 = nn.MaxPool2d(2, 2, padding=1)self.bn3 = nn.BatchNorm2d(128)self.relu3 = nn.ReLU()self.conv8 = nn.Conv2d(128, 256, 3,padding=1)self.conv9 = nn.Conv2d(256, 256, 3, padding=1)self.conv10 = nn.Conv2d(256, 256, 1, padding=1)self.pool4 = nn.MaxPool2d(2, 2, padding=1)self.bn4 = nn.BatchNorm2d(256)self.relu4 = nn.ReLU()self.conv11 = nn.Conv2d(256, 512, 3, padding=1)self.conv12 = nn.Conv2d(512, 512, 3, padding=1)self.conv13 = nn.Conv2d(512, 512, 1, padding=1)self.pool5 = nn.MaxPool2d(2, 2, padding=1)self.bn5 = nn.BatchNorm2d(512)self.relu5 = nn.ReLU()self.fc14 = nn.Linear(512*4*4,1024)self.drop1 = nn.Dropout2d()self.fc15 = nn.Linear(1024,1024)self.drop2 = nn.Dropout2d()self.fc16 = nn.Linear(1024,10)def forward(self,x):x = self.conv1(x)x = self.conv2(x)x = self.pool1(x)x = self.bn1(x)x = self.relu1(x)x = self.conv3(x)x = self.conv4(x)x = self.pool2(x)x = self.bn2(x)x = self.relu2(x)x = self.conv5(x)x = self.conv6(x)x = self.conv7(x)x = self.pool3(x)x = self.bn3(x)x = self.relu3(x)x = self.conv8(x)x = self.conv9(x)x = self.conv10(x)x = self.pool4(x)x = self.bn4(x)x = self.relu4(x)x = self.conv11(x)x = self.conv12(x)x = self.conv13(x)x = self.pool5(x)x = self.bn5(x)x = self.relu5(x)# print(" x shape ",x.size())x = x.view(-1,512*4*4)x = F.relu(self.fc14(x))x = self.drop1(x)x = F.relu(self.fc15(x))x = self.drop2(x)x = self.fc16(x)return xdef train_sgd(self,device):optimizer = optim.Adam(self.parameters(), lr=0.0001)path = 'weights.tar'initepoch = 0if os.path.exists(path) is not True:loss = nn.CrossEntropyLoss()# optimizer = optim.SGD(self.parameters(),lr=0.01)else:checkpoint = torch.load(path)self.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])initepoch = checkpoint['epoch']loss = checkpoint['loss']for epoch in range(initepoch,100):  # loop over the dataset multiple timestimestart = time.time()running_loss = 0.0total = 0correct = 0for i, data in enumerate(trainloader, 0):# get the inputsinputs, labels = datainputs, labels = inputs.to(device),labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = self(inputs)l = loss(outputs, labels)l.backward()optimizer.step()# print statisticsrunning_loss += l.item()# print("i ",i)if i % 500 == 499:  # print every 500 mini-batchesprint('[%d, %5d] loss: %.4f' %(epoch, i, running_loss / 500))running_loss = 0.0_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the %d tran images: %.3f %%' % (total,100.0 * correct / total))total = 0correct = 0torch.save({'epoch':epoch,'model_state_dict':net.state_dict(),'optimizer_state_dict':optimizer.state_dict(),'loss':loss},path)print('epoch %d cost %3f sec' %(epoch,time.time()-timestart))print('Finished Training')def test(self,device):correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = self(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %.3f %%' % (100.0 * correct / total))device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = Net()
net = net.to(device)
net.train_sgd(device)
net.test(device)

总结

  • 下载的数据是numpy格式,shape:HWC, 会转换成tensor,shape:CHW
  • torchvision 下载不是图像原始数据,是经过处理转换的numpy
  • plt.imshow(),输出的是HWC 格式图像信息

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/586354.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

服务号和订阅号哪个好

服务号和订阅号有什么区别?服务号转为订阅号有哪些作用?在推送频率上来看,服务号每月能推送四条消息,而订阅号可以每天(24小时)推送一条消息。如果企业开通公众号的目的是提供服务,例如售前资讯…

动态规划 | 鸡蛋问题 | 元旦假期来点“蛋”题

文章目录 鸡蛋掉落 - 两枚鸡蛋题目描述动态规划解法问题分析程序代码 鸡蛋掉落题目描述问题分析程序代码复杂度分析 鸡蛋掉落 - 两枚鸡蛋 题目描述 原题链接 给你 2 枚相同 的鸡蛋,和一栋从第 1 层到第 n 层共有 n 层楼的建筑。 已知存在楼层 f ,满足 …

我的学习C#回炉学习日志——Lua热更新06_模块

模块 个人理解:lua的包比起C#,基本上就是一个table记录所有信息,包括变量、常量等 module {} module.constant "一个常量" function module.func1()io.write("一个共有函数\n") endlocal function func2()-- bodypr…

C语言注释的使用与理解

什么是注释? 在编程中,注释(Comment)是一种非执行文本,它用于为代码提供解释、说明和文档。注释的内容不参与程序的实际编译和运行过程,其主要目的是提高代码的可读性和可维护性,方便开发者以及…

MySQL:排序和分组

1、排序 order by 用于对结果集按照一个列或者多个列进行排序。默认按照升序对记录进行排序,如果需要按照降序对记录进行排序,可以使用 desc 关键字。 order by 对多列排序的时候,先排序的列放前面,后排序的列放后面。并且&…

Python字典类型key找value或者value找key方法汇总

字典中,如何通过唯一的value获取key 如果传入的值在字典的值中不存在,可以返回一个特定的默认值或者抛出一个异常来表示该情况。以下是两种处理方式的示例: 返回默认值: def get_key_by_value(dictionary, value, defaultNone)…

JavaScript:函数隐含对象arguments/剩余参数. . .c/解构赋值

除了this,在函数内部还存在着一个隐含的参数arguments arguments 是一个类数组对象(伪数组) 调用函数时传递的所有实参,都被存储在arguments中 arguments[0] 表示的是第一个实参 arguments[1] 表示的是第二个实参 以此类推..…

2022年全球运维大会(GOPS深圳站)-核心PPT资料下载

一、峰会简介 GOPS 主要面向运维行业的中高端技术人员,包括运维、开发、测试、架构师等群体。目的在于帮助IT技术从业者系统学习了解相关知识体系,让创新技术推动社会进步。您将会看到国内外知名企业的相关技术案例,也能与国内顶尖的技术专家…

【数据结构】链式家族的成员——循环链表与静态链表

循环链表与静态链表 导言一、循环链表1.1 循环单链表1.2 循环双链表 二、静态链表2.1 静态链表的创建2.2 静态链表的初始化2.3 小结 结语 导言 大家好!很高兴又和大家见面啦!!! 经过前面的介绍,相信大家对链式家族的…

软件测试/测试开发丨Mac Appium环境搭建

Mac 上 Appium 环境搭建 安装 nodejs 与 npm 安装方式与 windows 类似 ,官网下载对应的 mac 版本的安装包,双击即可安装,无须配置环境变量。官方下载地址:https://nodejs.org/en/download/ 安装 appium Appium 分为两个版本&a…

【Transformer】深入理解Transformer模型1——初步认识了解

前言 Transformer模型出自论文:《Attention is All You Need》 2017年 近年来,在自然语言处理领域和图像处理领域,Transformer模型都受到了极为广泛的关注,很多模型中都用到了Transformer或者是Transformer模型的变体&#xff0…

ElasticSearch--基本操作

ElasticSearch 完成ES安装 http://101.42.93.208:5601/app/dev_tools#/console 库的操作 创建索引库 请求方式:PUT 请求路径:/索引库名,可以自定义 请求参数:mapping映射 PUT /test {"mappings": {"propertie…

计算机硬件 4.3显示器

第三节 显示器 一、基本概念 1.定义:将电信号转换为可以直接看到的图像的最基本输出设备。 2.分类:按显示色彩分:单色显示器、彩色显示器。 按显示原理分:CRT显示器、LCD显示器、LED显示器、OLED显示器。 3.原理结构&#xff…

【Oracle】 Oracle Sequence 性能优化

Sequence是很简单的,如果最大程度利用默认值的话,我们只需要定义sequence对象的名字即可。在序列Sequence对象的定义中,Cache是一个可选择的参数。默认的Sequence对象是有cache选项的,默认取值为20。这个默认值对于大多数情况下都…

云原生|kubernetes|kubernetes资源备份和集群迁移神器velero的部署和使用

前言: kubernetes集群需要灾备吗?kubernetes需要迁移吗? 答案肯定是需要的 那么,如何做kubernetes灾备和迁移呢?当然了,有很多的方法,例如,自己编写shell脚本,或者使用…

2023年江苏省职业院校技能大赛高职组“软件测试”赛项接口测试答案报告(含术语)

2023年江苏省职业院校技能大赛高职组“软件测试”赛项接口测试答案报告 接口测试要求: 1、执行接口测试 本部分按照软件接口测试文档要求,执行接口测试;使用接口测试工具PostMan,编写脚本、配置参数、执行接口测试并且截图。截图需粘贴在接口测试总结报告中。接口测试具体…

06-C++ 模板

模板、类型转换 模板 1. 简介 一种用于实现 通用编程 的机制。 将 数据类型 可以作为参数进行传递 。 通过使用模板&#xff0c;我们可以编写可复用的代码&#xff0c;可以适用于多种数据类型。 c模板的语法使用尖括号 < > 来表示泛型类型&#xff0c;并使用关键字…

啊?这也算事务?!

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 学习必须往深处挖&…

Java中的Optional类使用技巧

在Java中&#xff0c;Optional 是一个可以为null的容器对象。如果值存在则isPresent()方法返回true&#xff0c;调用get()方法会返回该对象。 使用Optional可以有效地防止NullPointerException。 下面是一些使用Optional的技巧&#xff1a; 创建Optional对象&#xff1a; Opt…

php-m和phpinfo之间不一致的问题的可能原因和解决办法

1.不同的 PHP配置文件: php -m 和 phpinfo 可能会使用不同的 PHP 配置文件。确保它们都使用相同的配置文件。你可以在命令行中使用 php --ini 来查找当前使用的配置文件位置&#xff0c;并在 phpinfo 中查看 Loaded Configuration File 来确保它们相同。 2.不同的 PHP 版本:确…