深度學習筆記14-CIFAR10彩色圖片識別(Pytorch)

  • 🍨 本文為🔗365天深度學習訓練營 中的學習紀錄博客
  • 🍖 原作者:K同学啊 | 接輔導、項目定制

一、我的環境

  • 電腦系統:Windows 10

  • 顯卡:NVIDIA GeForce GTX 1060 6GB

  • 語言環境:Python 3.7.0

  • 開發工具:Sublime Text,Command Line(CMD)

  • 深度學習環境:1.12.1+cu113


二、準備套件

# PyTorch 的核心模組,包含了張量操作、自動微分、神經網絡構建、優化器等
import torch# PyTorch 的神經網絡模組,包含了各種神經網絡層和相關操作的類別和函數
import torch.nn as nn# Matplotlib 的繪圖模組,用於創建各種圖表和視覺化數據
import matplotlib.pyplot as plt# PyTorch 的計算機視覺工具包,包含了常用的數據集、模型和圖像轉換操作
import torchvision# 一個用於數值計算的 Python 庫,提供了高效的數組和矩陣操作功能
import numpy as np# PyTorch 的函數式神經網絡操作模組,包含了神經網絡中常用的操作,例如激活函數、損失函數等
import torch.nn.functional as F# 提供 PyTorch 模型的詳細摘要信息,包括層數、參數數量和輸出形狀等,類似於 Keras 的 model.summary()
from torchinfo import summary# 隱藏警告
import warnings

三、環境準備

# 忽略警告訊息
warnings.filterwarnings("ignore")  # 輸出 PyTorch 的版本號
print(torch.__version__)# 檢查是否有可用的CUDA設備,否則使用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 打印出當前使用的設備
print(device)


四、載入數據

# 載入CIFAR-10訓練數據集
train_ds = torchvision.datasets.CIFAR10('data',  # 數據下載後保存的目錄train=True,   # 指定載入訓練數據集transform=torchvision.transforms.ToTensor(), # 將圖像轉換為Tensordownload=True # 如果數據集不存在,則從網絡下載
)test_ds  = torchvision.datasets.CIFAR10('data',   # 數據下載後保存的目錄train=False,   # 指定載入測試數據集transform=torchvision.transforms.ToTensor(),   # 將圖像轉換為Tensordownload=True  # 如果數據集不存在,則從網絡下載
)


五、數據預處理

# 定義每個批次的大小為32
batch_size = 32# 創建訓練數據的DataLoader
train_dl = torch.utils.data.DataLoader(train_ds,   # 訓練數據集batch_size=batch_size,   # 每個批次包含32個樣本shuffle=True # 在每個epoch開始時打亂數據
)# 創建測試數據的DataLoader
test_dl  = torch.utils.data.DataLoader(test_ds, # 測試數據集batch_size=batch_size # 每個批次包含32個樣本# 測試數據集不需要shuffle,默認為False
)# 從訓練數據加載器中取出一個批次的圖像和標籤
imgs, labels = next(iter(train_dl))
# 打印圖像的形狀 (batch_size, channels, height, width)
print(imgs.shape)


六、圖片可視化

# 創建一個大小為 (20, 5) 的圖形
plt.figure(figsize=(20, 5)) 
# 遍歷前20個圖像
for i, imgs in enumerate(imgs[:20]):# 將圖像從 (channels, height, width) 轉換為 (height, width, channels) 以便於顯示npimg = imgs.numpy().transpose((1, 2, 0))# 在2行10列的子圖中繪製圖像plt.subplot(2, 10, i+1)# 顯示圖像,使用灰度色彩映射plt.imshow(npimg, cmap=plt.cm.binary)# 隱藏坐標軸plt.axis('off')
# 顯示圖形
plt.show()


七、定義模型

# 定義分類數量(CIFAR-10有10個類別)
num_classes = 10 # 定義模型類
class Model(nn.Module):def __init__(self):super().__init__()# 定義第一個卷積層,輸入通道為3(CIFAR-10圖像的RGB通道),輸出通道為64,卷積核大小為3x3self.conv1 = nn.Conv2d(3, 64, kernel_size=3)  # 定義第一個池化層,使用2x2的最大池化self.pool1 = nn.MaxPool2d(kernel_size=2)       # 定義第二個卷積層,輸入通道為64,輸出通道為64,卷積核大小為3x3self.conv2 = nn.Conv2d(64, 64, kernel_size=3)   # 定義第二個池化層,使用2x2的最大池化self.pool2 = nn.MaxPool2d(kernel_size=2) # 定義第三個卷積層,輸入通道為64,輸出通道為128,卷積核大小為3x3self.conv3 = nn.Conv2d(64, 128, kernel_size=3)    # 定義第三個池化層,使用2x2的最大池化self.pool3 = nn.MaxPool2d(kernel_size=2) # 定義第一個全連接層,輸入大小為512,輸出大小為256self.fc1 = nn.Linear(512, 256)          # 定義第二個全連接層,輸出大小為 num_classes(分類的數量)self.fc2 = nn.Linear(256, num_classes)# 定義前向傳播def forward(self, x):# 通過第一個卷積層、ReLU激活函數和池化層x = self.pool1(F.relu(self.conv1(x)))     # 通過第二個卷積層、ReLU激活函數和池化層x = self.pool2(F.relu(self.conv2(x)))# 通過第三個卷積層、ReLU激活函數和池化層x = self.pool3(F.relu(self.conv3(x)))# 將特徵圖展平為一維向量x = torch.flatten(x, start_dim=1)# 通過第一個全連接層和ReLU激活函數x = F.relu(self.fc1(x))# 通過第二個全連接層,輸出為分類結果x = self.fc2(x)# 返回輸出return x# 將模型移動到指定設備(GPU或CPU)
model = Model().to(device)# 打印模型結構摘要
summary(model)


八、定義訓練函數

# 定義損失函數為交叉熵損失
loss_fn = nn.CrossEntropyLoss() # 設定學習率為0.01
learn_rate = 1e-2 # 使用隨機梯度下降(SGD)優化器
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)# 定義訓練函數
def train(dataloader, model, loss_fn, optimizer):# 獲取訓練集的大小size = len(dataloader.dataset)  # 獲取批次數量num_batches = len(dataloader) # 初始化訓練損失和準確率train_loss, train_acc = 0, 0  # 遍歷訓練數據for X, y in dataloader: # 將輸入和標籤移動到指定設備(GPU或CPU)X, y = X.to(device), y.to(device)# 前向傳播:計算模型預測pred = model(X)         # 計算損失loss = loss_fn(pred, y)# 反向傳播前清零梯度optimizer.zero_grad()  # 反向傳播:計算梯度loss.backward()        # 更新模型參數optimizer.step()     # 計算訓練準確率train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()# 累加訓練損失train_loss += loss.item()# 計算平均訓練準確率train_acc /= size# 計算平均訓練損失train_loss /= num_batchesreturn train_acc, train_loss

九、定義測試函數

def test(dataloader, model, loss_fn):# 獲取測試集的大小size = len(dataloader.dataset) # 獲取批次數量num_batches = len(dataloader)  # 初始化測試損失和準確率test_loss, test_acc = 0, 0# 禁用梯度計算(加速推理過程)with torch.no_grad():# 遍歷測試數據for imgs, target in dataloader:# 將輸入和標籤移動到指定設備(GPU或CPU)imgs, target = imgs.to(device), target.to(device)# 前向傳播:計算模型預測target_pred = model(imgs)# 計算損失loss = loss_fn(target_pred, target)# 累加測試損失test_loss += loss.item()# 計算測試準確率test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()# 計算平均測試準確率test_acc /= size# 計算平均測試損失test_loss /= num_batchesreturn test_acc, test_loss

十、模型訓練

epochs     = 10   # 訓練的回合數
train_loss = []   # 存儲每個回合的訓練損失
train_acc  = []   # 存儲每個回合的訓練準確率
test_loss  = []   # 存儲每個回合的測試損失
test_acc   = []   # 存儲每個回合的測試準確率# 訓練和測試循環
for epoch in range(epochs):model.train()  # 將模型設置為訓練模式epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)  # 訓練模型並返回訓練準確率和損失model.eval()  # 將模型設置為評估模式epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)  # 測試模型並返回測試準確率和損失train_acc.append(epoch_train_acc)  # 存儲訓練準確率train_loss.append(epoch_train_loss)  # 存儲訓練損失test_acc.append(epoch_test_acc)  # 存儲測試準確率test_loss.append(epoch_test_loss)  # 存儲測試損失# 格式化並輸出當前回合的訓練和測試結果template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')


十一、結果可視化

# 設定 Matplotlib 的參數以支援中文和負號的顯示
plt.rcParams['font.sans-serif']    = ['SimHei']  # 用來正常顯示中文標籤
plt.rcParams['axes.unicode_minus'] = False       # 用來正常顯示負號
plt.rcParams['figure.dpi']         = 100         # 設定圖表的解析度epochs_range = range(epochs)  # 訓練回合的範圍plt.figure(figsize=(12, 3))  # 設置圖表大小
plt.subplot(1, 2, 1)  # 創建一個2行1列的子圖佈局,這裡畫第一個圖# 畫出訓練和測試的準確率曲線
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')  # 顯示圖例
plt.title('Training and Validation Accuracy')  # 設定標題plt.subplot(1, 2, 2)  # 畫第二個圖
# 畫出訓練和測試的損失曲線
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')  # 顯示圖例
plt.title('Training and Validation Loss')  # 設定標題plt.show()  # 顯示圖表


十二、心得

最近開始使用 PyTorch 框架來訓練模型,以下是 PyTorch 和 TensorFlow 2 的差異說明

  • 動態圖 vs 靜態圖

    • PyTorch 使用動態計算圖:這意味著計算圖是即時定義的,每次迭代可以根據需要改變結構,更靈活,更容易進行調試和編程
    • TensorFlow 2 則引入了即時執行(Eager Execution),類似於 PyTorch 的動態圖模式,使得模型的建立和調試更加直觀和靈活,此外,TensorFlow 2 也可以使用靜態圖來進行更高效的低級優化和部署
  • API 設計

    • PyTorch 的 API 設計更直觀和簡潔,更貼近 Python 編程風格,使得學習曲線較平緩,特別適合研究和實驗
    • TensorFlow 2 則採用了 Keras 作為其主要高級 API,提供了更高層次的抽象和簡化,使得模型的定義和訓練更加容易,特別適合生產環境和大規模部署
  • 模型部署

    • TensorFlow 2 在模型訓練後的部署和生產環境中表現更加優異,支持較多的低級優化和部署工具(如 TensorFlow Serving)
    • PyTorch 雖然近年來在這方面有所改進,但相對而言仍有一定的差距,部署需要更多的自定義和額外的工作
  • 社區和生態系統

    • TensorFlow 擁有更大的社區支持和更成熟的生態系統,有更多的文檔、教程和預訓練模型可用
    • PyTorch 的社區雖然較小,但在學術界和研究領域中得到了廣泛的應用和支持,並且快速增長

這次作業中,我學到如何使用 PyTorch 的 torchvision 模組來加載 CIFAR-10 數據集,這個過程包括對圖片進行標準化和轉換,使其適合訓練模型,載入數據後使用 PyTorch 定義一個簡單的卷積神經網絡(CNN)來處理圖像分類任務,這裡使用了幾個卷積層和池化層,以及全連接層,之後定義損失函數和優化器,然後訓練模型,這裡使用交叉熵損失和隨機梯度下降(SGD)優化器,最後使用測試集來評估模型的表現,計算準確率和其他指標

通過調整不同的超參數和嘗試不同的模型架構,也意識到了如何優化模型以達到更好的性能,這個過程不僅加深了我對深度學習的理解,還增強了我解決實際問題的能力

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/36134.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ThreadX简介

文章目录 1. 摘要2. ThreadX的特性2.1 免费开源2.2 安全认证级别高2.3 组件完善2.4 实时性高2.5 支持多核2.6 支持应用动态加载2.7 代码符合MISAR规范2.8 文档全面,例程丰富2.9 集成方便3. 移植示例4. 产品应用示例1. 摘要 在嵌入式系统领域,实时性能、系统稳定性以及广泛的…

Camera开发-相机输出常用数据格式

作者简介: 一个平凡而乐于分享的小比特,中南民族大学通信工程专业研究生在读,研究方向无线联邦学习 擅长领域:驱动开发,嵌入式软件开发,BSP开发 作者主页:一个平凡而乐于分享的小比特的个人主页…

【鸿蒙培训】第一天环境安装

目录标题 安装DevEco Studio 【IDE】配置开发环境配置离线SDK创建工程配置离线插件 hvigor配置模拟器模拟器执行代码 安装DevEco Studio 【IDE】 1・解压 devecostudio-windows-4.1.3.500.zip。 2・执行 deveco-studio-4.1.3.500.exe 安装IDE。 配置开…

29. 深度学习中的损失函数及其数学性质详解

在深度学习中,优化算法的研究对象是损失函数。损失函数的数学性质对最优化求解过程至关重要。本文将详细介绍深度学习中的损失函数应具备的特性,帮助大家在后续的学习中避免概念上的误解。 函数的可微性和可导性 学过高等数学的同学对可微性和可导性已…

void * 返回类型 与 void *arg 参数的区别

void * 返回类型:void * 作为函数的返回类型,表示该函数可以返回任何类型的指针。void * 是一种特殊的指针类型,称为“无类型指针”或“泛型指针”,因为它可以指向任何类型的数据。函数通过返回 void * 类型的指针,提供…

JMeter安装与使用

安装包下载:https://pan.xunlei.com/s/VNigSM9IEjqNBVkw8by6i-LoA1?pwdu6gq# 也可以官网下载: 1.解压安装包 2.打开方式 (1)bin->ApacheJMeter.jar->打开界面 (2)如果(1)打…

ruby面试题

ruby 基础 1、each、map、collect的区别 each: 仅遍历数组,并做相应操作,数组本身不发生改变。 map:遍历数组,并做相应操作后,返回新数组(处理),原数组不变。 collect: 跟map作用一样。 collect! map!: 多了一个作…

LoadBalance 负载均衡

什么是负载均衡 负载均衡(Load Balance,简称 LB),是⾼并发,⾼可⽤系统必不可少的关键组件. 当服务流量增⼤时,通常会采⽤增加机器的⽅式进⾏扩容,负载均衡就是⽤来在多个机器或者其他资源中,按照⼀定的规则合理分配负载. 负载均衡的⼀些实现 服务多机部署时,开发⼈…

专业软件测试公司分享:安全测评对于软件产品的重要性

在互联网普及的今天,随着各类软件的大规模使用,安全问题也变得愈发突出。因此,对软件进行全面的安全测评,不仅可以有效保障用户的信息安全,还能提升软件产品的信任度和市场竞争力。 安全测评对于软件产品的重要性就如…

LLDB 详解

LLDB 详解 LLDB 详解编译器集成优势LLDB 的主要功能命令格式原始(raw)命令选项终止符: -- LLDB 中的变量唯一匹配原则helpexpressionprint、call、po控制流程:continue、next、step、finishregister read / writethread backtracethread retu…

线性代数|机器学习-P19SVDLUQR分解自由参数计算和鞍点

文章目录 1. 矩阵A分解1.1 A L U ALU ALU 1. 矩阵A分解 对于矩阵A来说,我们有常见矩阵分解: A L U , A Q R , A X Λ X − 1 , A Q Λ Q T ; A Q S , A S V D \begin{equation} ALU,AQR,AX\Lambda X^{-1},AQ\Lambda Q^T;AQS,ASVD \end{equatio…

React Native优质开源项目推荐与解析

目录 2. React Native的优势 2.1. 跨平台开发 2.2. 热更新 2.3. 丰富的社区资源 2.4. 优秀的性能 3. 优质开源项目推荐 3.1. React Navigation 3.1.1 项目简介 3.1.2 特点和优势 3.1.3 应用场景 3.2. Redux 3.2.1 项目简介 3.2.2 特点和优势 3.2.3 应用场景 3.3…

精彩回顾 | 2024高通边缘智能创新应用大赛系列公开课

5月29日-6月6日,我们陆续开展了四场精彩绝伦的2024高通边缘智能创新应用大赛直播公开课。高通、阿加犀、广翼智联以及美格智能的业务领袖和行业大咖齐聚一堂,聚焦边缘智能,分享前沿技术、探讨创新应用,抢先揭秘比赛设备的核心特性…

MIT6.s081 2021 Lab System calls

xv6系统调用实现 不同于 Lab1 利用已实现的系统调用来实现一些用户态下的命令行程序,本 Lab 是要在内核层面实现一些系统调用。这其中难免涉及到一些对内核数据结构的操作,以及处理器体系结构(本系列 Lab 基于 RISCV)相关的内容&…

什么是慢查询——Java全栈知识(26)

1、什么是慢查询 慢查询:也就是接口压测响应时间过长,页面加载时间过长的查询 原因可能如下: 1、聚合查询 2、多表查询 3、单表数据量过大 4、深度分页查询(limit) 如何定位慢查询? 1、Skywalking 我们…

js url参数转对象类型(对象类型转url参数)支持中文解码编码

先上代码 后面上函数参数说明以及调用返回结果 /** Author: 夏林* Date: 24.6.27* desc 时间差算法* params params -> 传入数据 String | Object* params _needEncode -> 是否需要编码 默认 true*/ export function dealUrlSearchParams(_params , _needEncode tr…

IND83081芯片介绍(一)

一、芯片介绍 IND83081是indiemicro推出的一款高性能的汽车矩阵LED照明控制器,集成了四个子模块,每个子模块包含三个串联的MOSFET开关,每个开关均可通过12位PWM内部信号控制,可配置的上升和下降速率及相位移以实现精确控制&#x…

JOSEF约瑟 JOXL-J拉绳开关 整定范围宽

用途 双向拉绳开关的壳体采用金属材料铸造,具有足够的机械强度,抵抗并下工作时脱落的岩石,爆块等物体的撞击不被破坏,当胶带输送机发生紧急事故时,启动拉绳开关,可立即停机报警,防止事故的扩大,保证工作现场的人身安全…

常用的通信协议有哪些

常用的通信协议有很多种,主要根据其应用领域和通信需求可以分为几类: 网络通信协议: TCP/IP:传输控制协议/互联网协议,用于互联网及局域网通信。 UDP:用户数据报协议,用于实时数据传输&#…

java 操作 milvus 2.1.4

1. 确认 docker 运行的 milvus容器镜像版本情况&#xff1a; 2. pom 依赖&#xff1a; <dependency><groupId>io.milvus</groupId><artifactId>milvus-sdk-java</artifactId><version>2.1.0</version><exclusions><exclusi…