第P2周:Pytorch实现CIFAR10彩色图片识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

  1. 实现CIFAR-10的彩色图片识别
  2. 实现比P1周更复杂一点的CNN网络

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架: Pytorch 2.5.1

(二)具体步骤
1.
import torch  
import torch.nn as nn  
import matplotlib.pyplot as plt  
import torchvision  # 第一步:设置GPU  
def USE_GPU():  if torch.cuda.is_available():  print('CUDA is available, will use GPU')  device = torch.device("cuda")  else:  print('CUDA is not available. Will use CPU')  device = torch.device("cpu")  return device  device = USE_GPU()  

输出:CUDA is available, will use GPU

  
# 第二步:导入数据。同样的CIFAR-10也是torch内置了,可以自动下载  
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,  transform=torchvision.transforms.ToTensor())  
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,  transform=torchvision.transforms.ToTensor())  batch_size = 32  
train_dataload = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  
test_dataload = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)  # 取一个批次查看数据格式  
# 数据的shape为:[batch_size, channel, height, weight]  
# 其中batch_size为自己设定,channel,height和weight分别是图片的通道数,高度和宽度。  
imgs, labels = next(iter(train_dataload))  
print(imgs.shape)  # 查看一下图片  
import numpy as np  
plt.figure(figsize=(20, 5))  
for i, images in enumerate(imgs[:20]):  # 使用numpy的transpose将张量(C,H, W)转换成(H, W, C),便于可视化处理  npimg = imgs.numpy().transpose((1, 2, 0))  # 将整个figure分成2行10列,并绘制第i+1个子图  plt.subplot(2, 10, i+1)  plt.imshow(npimg, cmap=plt.cm.binary)  plt.axis('off')  
plt.show()  

输出:
Files already downloaded and verified
Files already downloaded and verified
torch.Size([32, 3, 32, 32])
image.png

# 第三步,构建CNN网络  
import torch.nn.functional as F  num_classes = 10  # 因为CIFAR-10是10种类型  
class Model(nn.Module):  def __init__(self):  super(Model, self).__init__()  # 提取特征网络  self.conv1 = nn.Conv2d(3, 64, 3)  self.pool1 = nn.MaxPool2d(kernel_size=2)  self.conv2 = nn.Conv2d(64, 64, 3)  self.pool2 = nn.MaxPool2d(kernel_size=2)  self.conv3 = nn.Conv2d(64, 128, 3)  self.pool3 = nn.MaxPool2d(kernel_size=2)  # 分类网络  self.fc1 = nn.Linear(512, 256)  self.fc2 = nn.Linear(256, num_classes)  # 前向传播  def forward(self, x):  x = self.pool1(F.relu(self.conv1(x)))  x = self.pool2(F.relu(self.conv2(x)))  x = self.pool3(F.relu(self.conv3(x)))  x = torch.flatten(x, 1)  x = F.relu(self.fc1(x))  x = self.fc2(x)  return x  from torchinfo import summary  
# 将模型转移到GPU中  
model = Model().to(device)  
summary(model)  

image.png

# 训练模型  
loss_fn = nn.CrossEntropyLoss() # 创建损失函数  
learn_rate = 1e-2   # 设置学习率  
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)    # 设置优化器  # 编写训练函数  
def train(dataloader, model, loss_fn, optimizer):  size = len(dataloader.dataset) # 训练集的大小 ,这里一共是60000张图片  num_batches = len(dataloader)   # 批次大小,这里是1875(60000/32=1875)  train_acc, train_loss = 0, 0    # 初始化训练正确率和损失率都为0  for X, y in dataloader: # 获取图片及标签,X-图片,y-标签(也是实际值)  X, y = X.to(device), y.to(device)  # 计算预测误差  pred = model(X) # 网络输出预测值  loss = loss_fn(pred, y) # 计算网络输出的预测值和实际值之间的差距  # 反向传播  optimizer.zero_grad()   # grad属性归零  loss.backward() # 反向传播  optimizer.step()    # 第一步自动更新  # 记录正确率和损失率  train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()  train_loss += loss.item()  train_acc /= size  train_loss /= num_batches  return train_acc, train_loss  # 测试函数  
def test(dataloader, model, loss_fn):  size = len(dataloader.dataset) # 测试集大小,这里一共是10000张图片  num_batches = len(dataloader)   # 批次大小 ,这里312,即10000/32=312.5,向上取整  test_acc, test_loss = 0, 0  # 因为是测试,因此不用训练,梯度也不用计算不用更新  with torch.no_grad():  for imgs, target in dataloader:  imgs, target = imgs.to(device), target.to(device)  # 计算loss  target_pred = model(imgs)  loss = loss_fn(target_pred, target)  test_loss += loss.item()  test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  test_acc /= size  test_loss /= num_batches  return test_acc, test_loss  # 正式训练  
epochs = 10  
train_acc, train_loss, test_acc, test_loss = [], [], [], []  for epoch in range(epochs):  model.train()  epoch_train_acc, epoch_train_loss = train(train_dataload, model, loss_fn, opt)  model.eval()  epoch_test_acc, epoch_test_loss = test(test_dataload, model, loss_fn)  train_acc.append(epoch_train_acc)  train_loss.append(epoch_train_loss)  test_acc.append(epoch_test_acc)  test_loss.append(epoch_test_loss)  template = 'Epoch:{:2d}, 训练正确率:{:.1f}%, 训练损失率:{:.3f}, 测试正确率:{:.1f}%, 测试损失率:{:.3f}'  print(template.format(epoch+1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))  print('Done')  # 结果可视化  
# 隐藏警告  
import warnings  
warnings.filterwarnings('ignore')   # 忽略警告信息  
plt.rcParams['font.sans-serif'] = ['SimHei']    # 正常显示中文标签  
plt.rcParams['axes.unicode_minus'] = False  # 正常显示+/-号  
plt.rcParams['figure.dpi'] = 100    # 分辨率  epochs_range = range(epochs)  plt.figure(figsize=(12, 3))  plt.subplot(1, 2, 1)    # 第一张子图  
plt.plot(epochs_range, train_acc, label='训练正确率')  
plt.plot(epochs_range, test_acc, label='测试正确率')  
plt.legend(loc='lower right')  
plt.title('训练和测试正确率比较')  plt.subplot(1, 2, 2)    # 第二张子图  
plt.plot(epochs_range, train_loss, label='训练损失率')  
plt.plot(epochs_range, test_loss, label='测试损失率')  
plt.legend(loc='upper right')  
plt.title('训练和测试损失率比较')  plt.show()# 保存模型  
torch.save(model, './models/cnn-cifar10.pth')

image.png
再次设置epochs为50训练结果:
image.png
epochs增加到100,训练结果:
image.png
可以看到训练集和测试集的差距有点大,不太理想。做一下数据增加试试:

data_transforms= {  'train': transforms.Compose([  transforms.RandomHorizontalFlip(),  transforms.ToTensor(),  ]),  'test': transforms.Compose([  transforms.ToTensor(),  ])  
}

在dataset中:

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,  transform=data_transforms['train'])  
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transforms['test'])

运行结果:
image.png
image.png
比较漂亮了,再调整batch_size=16和epochs=20,提高了近6个百分点。
image.png
batch_size=16,epochs=50:有第20轮左右的时候,验证集的确认性基本就没有再提高了。和上面基本一样。
image.png

(三)总结
  1. epochs并不是越多越好。batch_size同样的道理
  2. 数据增强确实可以提高模型训练的准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/62905.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Quant connect的优势和不足,学习曲线难

Quant connect的优势和不足 Quant connect作为一个成熟的算法交易平台,具有许多优势,包括: 强大的回测功能:Quant connect提供了丰富的数据源和回测功能,可以对各种交易策略进行全面的回测和分析。 容易上手&#xf…

深入理解 Ansible Playbook:组件与实战

目录 1 playbook介绍 2 YAML语言 2.1语法简介 2.2数据类型 3 Playbook核心组件 3.1 hosts组件 3.2 remote_user组件 3.3 task列表和action组件 3.4 handlers 3.5 tags组件 3.6 其他组件说明 1 playbook介绍 playbook 剧本是由一个或多个"play"组成的列表。…

2024年食堂采购系统源码技术趋势:如何开发智能的供应链管理APP

本篇文章,小编将与大家一同探讨2024年食堂采购系统的技术趋势,并提供开发更智能的供应链管理APP的策略。 一、2024年食堂采购系统的技术趋势 1.人工智能与机器学习的深度应用 在2024年,AI和机器学习在食堂采购系统中的应用将更加普遍。这些…

代码随想录-算法训练营-番外(图论01:图论理论基础,所有可到达的路径)

day01 图论part01 今日任务:图论理论基础/所有可到达的路径 代码随想录图论视频部分还没更新 https://programmercarl.com/kamacoder/图论理论基础.html#图的基本概念 day01 所有可达路径 邻接矩阵 import java.util.Scanner;import java.util.List;import java.util.ArrayL…

系统架构的演变

什么是系统架构? 系统架构是系统的一种整体的高层次的结构表示,它确定了系统的基本组织、组件之间的关系、组件与环境的关系,以及指导其设计和发展的原则。随着技术的发展和业务需求的增长,系统架构经历了从简单到复杂、从集中到…

c++总复习

C 中多态性在实际项目中的应用场景 图形绘制系统 描述:在一个图形绘制软件中,可能有多种图形,如圆形、矩形、三角形等。这些图形都有一个共同的操作,比如绘制(draw)。通过多态性,可以定义一个基…

pip离线安装一个github仓库

要使用pip安装一个本地Git仓库,你可以按照以下步骤操作: 确保你已经克隆了Git仓库到本地。 进入仓库所在的目录。 使用pip安装。 以下是具体的命令: 克隆Git仓库到本地(替换下面的URL为你的仓库URL) git clone https…

【从零开始入门unity游戏开发之——C#篇04】栈(Stack)和堆(Heap),值类型和引用类型,以及特殊的引用类型string

文章目录 知识回顾一、栈(Stack)和堆(Heap)1、什么是栈和堆2、为什么要分栈和堆3、栈和堆的区别栈堆 4、总结 二、值类型和引用类型1、那么值类型和引用类型到底有什么区别呢?值类型引用类型 2、总结 三、特殊的引用类…

【C语言实现:用队列模拟栈与用栈模拟队列(LeetCode 225 232)】

LeetCode刷题记录 🌐 我的博客主页:iiiiiankor🎯 如果你觉得我的内容对你有帮助,不妨点个赞👍、留个评论✍,或者收藏⭐,让我们一起进步!📝 专栏系列:LeetCode…

【Python】Selenium 爬虫的使用技巧和案例

引言 Selenium 是 Python 中功能强大的自动化测试工具,因其能够操控浏览器进行模拟操作,被广泛应用于网页数据爬取。相比传统的 requests 等库,Selenium 能更好地应对动态加载内容和复杂交互场景。本文将详细介绍 Selenium 爬虫的使用技巧,并提供实际案例来帮助读者快速上…

MySQL SQL语句性能优化

MySQL SQL语句性能优化指南 一、查询设计优化1. 避免 SELECT *2. 使用 WHERE 进行条件过滤3. 避免在索引列上使用函数和表达式4. 使用 LIMIT 限制返回行数5. 避免使用子查询6. 优化 JOIN 操作7. 避免全表扫描 二、索引优化1. 使用合适的索引2. 覆盖索引3. 索引选择性4. 多列索引…

Mybatis动态sql执行过程

动态SQL的执行原理主要涉及到在运行时根据条件动态地生成SQL语句,然后将其发送给数据库执行。以下是动态SQL执行原理的详细解释: 一、接收参数 动态SQL首先会根据用户的输入或系统的条件接收参数。这些参数可以是查询条件、更新数据等,它们…

java jar包加密 jar-protect

介绍 java 本身是开放性极强的语言,代码也容易被反编译,没有语言层面的一些常规保护机制,jar包很容易被反编译和破解。 受classfinal(已停止维护)设计启发,针对springboot日常项目开发,重新编写安全可靠的jar包加壳加密技术,用于保护软件版权。 使用说…

Linux:Git

Git常见指令: git help xx_command git xx_command --help git --version 查看git版本git config --global user.name "xxx_name" 全局级别的签名设置,全局的放在本用 git config --global user.ema…

【WiFi】WiFi中RSSI、SNR、NF之间关系及说明

RSSI(接收信号强度指示) 定义: RSSI 是一个相对值,用于表示接收到的无线信号的强度。它通常由无线设备的硬件(如无线网卡或无线芯片)直接提供。 计算: RSSI 的计算通常是由设备的无线芯片完成的…

提升音频转录准确性:VAD技术的应用与挑战

引言 在音频转录技术飞速发展的今天,我们面临着一个普遍问题:在嘈杂环境中,转录系统常常将非人声误识别为人声,导致转录结果出现错误。例如,在whisper模式下,系统可能会错误地转录出“谢谢大家”。本文将探…

[ZMQ] -- ZMQ通信Protobuf数据结构 1

1、前言背景 工作需要域间实现zmq通信,刚开始需要比较简单的数据结构,比如两个bool,后面可能就需要传输比较大的数据,所以记录下实现流程,至于为啥选择proto数据结构去做大数据传输,可能是地平线也用这个&…

顺序表的使用,对数据的增删改查

主函数: 3.c #include "3.h"//头文件调用 SqlListptr sql_cerate()//创建顺序表函数 {SqlListptr ptr(SqlListptr)malloc(sizeof(SqlList));//在堆区申请连续的空间if(NULLptr){printf("创建失败\n");return NULL;//如果没有申请成功&#xff…

React和Vue中暴露子组件的属性和方法给父组件用,并且控制子组件暴露的颗粒度的做法

React 在 React 中,forwardRef 是一种高级技术,它允许你将 ref 从父组件传递到子组件,从而直接访问子组件的 DOM 节点或公开的方法。这对于需要操作子组件内部状态或 DOM 的场景非常有用。为了使子组件能够暴露其属性和方法给父组件&#xf…

《C++ 实时视频流物体跟踪与行为分析全解析》

在当今科技飞速发展的时代,视频监控与智能分析技术在众多领域发挥着极为重要的作用。从安防监控到智能交通,从工业自动化到人机交互,利用 C 处理实时视频流中的物体跟踪和行为分析成为了热门且极具挑战性的研究与开发方向。本文将深入探讨其中…