深度學習筆記14-CIFAR10彩色圖片識別(Pytorch)

  • 🍨 本文為🔗365天深度學習訓練營 中的學習紀錄博客
  • 🍖 原作者:K同学啊 | 接輔導、項目定制

一、我的環境

  • 電腦系統:Windows 10

  • 顯卡:NVIDIA GeForce GTX 1060 6GB

  • 語言環境:Python 3.7.0

  • 開發工具:Sublime Text,Command Line(CMD)

  • 深度學習環境:1.12.1+cu113


二、準備套件

# PyTorch 的核心模組,包含了張量操作、自動微分、神經網絡構建、優化器等
import torch# PyTorch 的神經網絡模組,包含了各種神經網絡層和相關操作的類別和函數
import torch.nn as nn# Matplotlib 的繪圖模組,用於創建各種圖表和視覺化數據
import matplotlib.pyplot as plt# PyTorch 的計算機視覺工具包,包含了常用的數據集、模型和圖像轉換操作
import torchvision# 一個用於數值計算的 Python 庫,提供了高效的數組和矩陣操作功能
import numpy as np# PyTorch 的函數式神經網絡操作模組,包含了神經網絡中常用的操作,例如激活函數、損失函數等
import torch.nn.functional as F# 提供 PyTorch 模型的詳細摘要信息,包括層數、參數數量和輸出形狀等,類似於 Keras 的 model.summary()
from torchinfo import summary# 隱藏警告
import warnings

三、環境準備

# 忽略警告訊息
warnings.filterwarnings("ignore")  # 輸出 PyTorch 的版本號
print(torch.__version__)# 檢查是否有可用的CUDA設備,否則使用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 打印出當前使用的設備
print(device)


四、載入數據

# 載入CIFAR-10訓練數據集
train_ds = torchvision.datasets.CIFAR10('data',  # 數據下載後保存的目錄train=True,   # 指定載入訓練數據集transform=torchvision.transforms.ToTensor(), # 將圖像轉換為Tensordownload=True # 如果數據集不存在,則從網絡下載
)test_ds  = torchvision.datasets.CIFAR10('data',   # 數據下載後保存的目錄train=False,   # 指定載入測試數據集transform=torchvision.transforms.ToTensor(),   # 將圖像轉換為Tensordownload=True  # 如果數據集不存在,則從網絡下載
)


五、數據預處理

# 定義每個批次的大小為32
batch_size = 32# 創建訓練數據的DataLoader
train_dl = torch.utils.data.DataLoader(train_ds,   # 訓練數據集batch_size=batch_size,   # 每個批次包含32個樣本shuffle=True # 在每個epoch開始時打亂數據
)# 創建測試數據的DataLoader
test_dl  = torch.utils.data.DataLoader(test_ds, # 測試數據集batch_size=batch_size # 每個批次包含32個樣本# 測試數據集不需要shuffle,默認為False
)# 從訓練數據加載器中取出一個批次的圖像和標籤
imgs, labels = next(iter(train_dl))
# 打印圖像的形狀 (batch_size, channels, height, width)
print(imgs.shape)


六、圖片可視化

# 創建一個大小為 (20, 5) 的圖形
plt.figure(figsize=(20, 5)) 
# 遍歷前20個圖像
for i, imgs in enumerate(imgs[:20]):# 將圖像從 (channels, height, width) 轉換為 (height, width, channels) 以便於顯示npimg = imgs.numpy().transpose((1, 2, 0))# 在2行10列的子圖中繪製圖像plt.subplot(2, 10, i+1)# 顯示圖像,使用灰度色彩映射plt.imshow(npimg, cmap=plt.cm.binary)# 隱藏坐標軸plt.axis('off')
# 顯示圖形
plt.show()


七、定義模型

# 定義分類數量(CIFAR-10有10個類別)
num_classes = 10 # 定義模型類
class Model(nn.Module):def __init__(self):super().__init__()# 定義第一個卷積層,輸入通道為3(CIFAR-10圖像的RGB通道),輸出通道為64,卷積核大小為3x3self.conv1 = nn.Conv2d(3, 64, kernel_size=3)  # 定義第一個池化層,使用2x2的最大池化self.pool1 = nn.MaxPool2d(kernel_size=2)       # 定義第二個卷積層,輸入通道為64,輸出通道為64,卷積核大小為3x3self.conv2 = nn.Conv2d(64, 64, kernel_size=3)   # 定義第二個池化層,使用2x2的最大池化self.pool2 = nn.MaxPool2d(kernel_size=2) # 定義第三個卷積層,輸入通道為64,輸出通道為128,卷積核大小為3x3self.conv3 = nn.Conv2d(64, 128, kernel_size=3)    # 定義第三個池化層,使用2x2的最大池化self.pool3 = nn.MaxPool2d(kernel_size=2) # 定義第一個全連接層,輸入大小為512,輸出大小為256self.fc1 = nn.Linear(512, 256)          # 定義第二個全連接層,輸出大小為 num_classes(分類的數量)self.fc2 = nn.Linear(256, num_classes)# 定義前向傳播def forward(self, x):# 通過第一個卷積層、ReLU激活函數和池化層x = self.pool1(F.relu(self.conv1(x)))     # 通過第二個卷積層、ReLU激活函數和池化層x = self.pool2(F.relu(self.conv2(x)))# 通過第三個卷積層、ReLU激活函數和池化層x = self.pool3(F.relu(self.conv3(x)))# 將特徵圖展平為一維向量x = torch.flatten(x, start_dim=1)# 通過第一個全連接層和ReLU激活函數x = F.relu(self.fc1(x))# 通過第二個全連接層,輸出為分類結果x = self.fc2(x)# 返回輸出return x# 將模型移動到指定設備(GPU或CPU)
model = Model().to(device)# 打印模型結構摘要
summary(model)


八、定義訓練函數

# 定義損失函數為交叉熵損失
loss_fn = nn.CrossEntropyLoss() # 設定學習率為0.01
learn_rate = 1e-2 # 使用隨機梯度下降(SGD)優化器
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)# 定義訓練函數
def train(dataloader, model, loss_fn, optimizer):# 獲取訓練集的大小size = len(dataloader.dataset)  # 獲取批次數量num_batches = len(dataloader) # 初始化訓練損失和準確率train_loss, train_acc = 0, 0  # 遍歷訓練數據for X, y in dataloader: # 將輸入和標籤移動到指定設備(GPU或CPU)X, y = X.to(device), y.to(device)# 前向傳播:計算模型預測pred = model(X)         # 計算損失loss = loss_fn(pred, y)# 反向傳播前清零梯度optimizer.zero_grad()  # 反向傳播:計算梯度loss.backward()        # 更新模型參數optimizer.step()     # 計算訓練準確率train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()# 累加訓練損失train_loss += loss.item()# 計算平均訓練準確率train_acc /= size# 計算平均訓練損失train_loss /= num_batchesreturn train_acc, train_loss

九、定義測試函數

def test(dataloader, model, loss_fn):# 獲取測試集的大小size = len(dataloader.dataset) # 獲取批次數量num_batches = len(dataloader)  # 初始化測試損失和準確率test_loss, test_acc = 0, 0# 禁用梯度計算(加速推理過程)with torch.no_grad():# 遍歷測試數據for imgs, target in dataloader:# 將輸入和標籤移動到指定設備(GPU或CPU)imgs, target = imgs.to(device), target.to(device)# 前向傳播:計算模型預測target_pred = model(imgs)# 計算損失loss = loss_fn(target_pred, target)# 累加測試損失test_loss += loss.item()# 計算測試準確率test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()# 計算平均測試準確率test_acc /= size# 計算平均測試損失test_loss /= num_batchesreturn test_acc, test_loss

十、模型訓練

epochs     = 10   # 訓練的回合數
train_loss = []   # 存儲每個回合的訓練損失
train_acc  = []   # 存儲每個回合的訓練準確率
test_loss  = []   # 存儲每個回合的測試損失
test_acc   = []   # 存儲每個回合的測試準確率# 訓練和測試循環
for epoch in range(epochs):model.train()  # 將模型設置為訓練模式epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)  # 訓練模型並返回訓練準確率和損失model.eval()  # 將模型設置為評估模式epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)  # 測試模型並返回測試準確率和損失train_acc.append(epoch_train_acc)  # 存儲訓練準確率train_loss.append(epoch_train_loss)  # 存儲訓練損失test_acc.append(epoch_test_acc)  # 存儲測試準確率test_loss.append(epoch_test_loss)  # 存儲測試損失# 格式化並輸出當前回合的訓練和測試結果template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')


十一、結果可視化

# 設定 Matplotlib 的參數以支援中文和負號的顯示
plt.rcParams['font.sans-serif']    = ['SimHei']  # 用來正常顯示中文標籤
plt.rcParams['axes.unicode_minus'] = False       # 用來正常顯示負號
plt.rcParams['figure.dpi']         = 100         # 設定圖表的解析度epochs_range = range(epochs)  # 訓練回合的範圍plt.figure(figsize=(12, 3))  # 設置圖表大小
plt.subplot(1, 2, 1)  # 創建一個2行1列的子圖佈局,這裡畫第一個圖# 畫出訓練和測試的準確率曲線
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')  # 顯示圖例
plt.title('Training and Validation Accuracy')  # 設定標題plt.subplot(1, 2, 2)  # 畫第二個圖
# 畫出訓練和測試的損失曲線
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')  # 顯示圖例
plt.title('Training and Validation Loss')  # 設定標題plt.show()  # 顯示圖表


十二、心得

最近開始使用 PyTorch 框架來訓練模型,以下是 PyTorch 和 TensorFlow 2 的差異說明

  • 動態圖 vs 靜態圖

    • PyTorch 使用動態計算圖:這意味著計算圖是即時定義的,每次迭代可以根據需要改變結構,更靈活,更容易進行調試和編程
    • TensorFlow 2 則引入了即時執行(Eager Execution),類似於 PyTorch 的動態圖模式,使得模型的建立和調試更加直觀和靈活,此外,TensorFlow 2 也可以使用靜態圖來進行更高效的低級優化和部署
  • API 設計

    • PyTorch 的 API 設計更直觀和簡潔,更貼近 Python 編程風格,使得學習曲線較平緩,特別適合研究和實驗
    • TensorFlow 2 則採用了 Keras 作為其主要高級 API,提供了更高層次的抽象和簡化,使得模型的定義和訓練更加容易,特別適合生產環境和大規模部署
  • 模型部署

    • TensorFlow 2 在模型訓練後的部署和生產環境中表現更加優異,支持較多的低級優化和部署工具(如 TensorFlow Serving)
    • PyTorch 雖然近年來在這方面有所改進,但相對而言仍有一定的差距,部署需要更多的自定義和額外的工作
  • 社區和生態系統

    • TensorFlow 擁有更大的社區支持和更成熟的生態系統,有更多的文檔、教程和預訓練模型可用
    • PyTorch 的社區雖然較小,但在學術界和研究領域中得到了廣泛的應用和支持,並且快速增長

這次作業中,我學到如何使用 PyTorch 的 torchvision 模組來加載 CIFAR-10 數據集,這個過程包括對圖片進行標準化和轉換,使其適合訓練模型,載入數據後使用 PyTorch 定義一個簡單的卷積神經網絡(CNN)來處理圖像分類任務,這裡使用了幾個卷積層和池化層,以及全連接層,之後定義損失函數和優化器,然後訓練模型,這裡使用交叉熵損失和隨機梯度下降(SGD)優化器,最後使用測試集來評估模型的表現,計算準確率和其他指標

通過調整不同的超參數和嘗試不同的模型架構,也意識到了如何優化模型以達到更好的性能,這個過程不僅加深了我對深度學習的理解,還增強了我解決實際問題的能力

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/36134.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ThreadX简介

文章目录 1. 摘要2. ThreadX的特性2.1 免费开源2.2 安全认证级别高2.3 组件完善2.4 实时性高2.5 支持多核2.6 支持应用动态加载2.7 代码符合MISAR规范2.8 文档全面,例程丰富2.9 集成方便3. 移植示例4. 产品应用示例1. 摘要 在嵌入式系统领域,实时性能、系统稳定性以及广泛的…

Camera开发-相机输出常用数据格式

作者简介: 一个平凡而乐于分享的小比特,中南民族大学通信工程专业研究生在读,研究方向无线联邦学习 擅长领域:驱动开发,嵌入式软件开发,BSP开发 作者主页:一个平凡而乐于分享的小比特的个人主页…

JMeter安装与使用

安装包下载:https://pan.xunlei.com/s/VNigSM9IEjqNBVkw8by6i-LoA1?pwdu6gq# 也可以官网下载: 1.解压安装包 2.打开方式 (1)bin->ApacheJMeter.jar->打开界面 (2)如果(1)打…

LoadBalance 负载均衡

什么是负载均衡 负载均衡(Load Balance,简称 LB),是⾼并发,⾼可⽤系统必不可少的关键组件. 当服务流量增⼤时,通常会采⽤增加机器的⽅式进⾏扩容,负载均衡就是⽤来在多个机器或者其他资源中,按照⼀定的规则合理分配负载. 负载均衡的⼀些实现 服务多机部署时,开发⼈…

专业软件测试公司分享:安全测评对于软件产品的重要性

在互联网普及的今天,随着各类软件的大规模使用,安全问题也变得愈发突出。因此,对软件进行全面的安全测评,不仅可以有效保障用户的信息安全,还能提升软件产品的信任度和市场竞争力。 安全测评对于软件产品的重要性就如…

LLDB 详解

LLDB 详解 LLDB 详解编译器集成优势LLDB 的主要功能命令格式原始(raw)命令选项终止符: -- LLDB 中的变量唯一匹配原则helpexpressionprint、call、po控制流程:continue、next、step、finishregister read / writethread backtracethread retu…

精彩回顾 | 2024高通边缘智能创新应用大赛系列公开课

5月29日-6月6日,我们陆续开展了四场精彩绝伦的2024高通边缘智能创新应用大赛直播公开课。高通、阿加犀、广翼智联以及美格智能的业务领袖和行业大咖齐聚一堂,聚焦边缘智能,分享前沿技术、探讨创新应用,抢先揭秘比赛设备的核心特性…

MIT6.s081 2021 Lab System calls

xv6系统调用实现 不同于 Lab1 利用已实现的系统调用来实现一些用户态下的命令行程序,本 Lab 是要在内核层面实现一些系统调用。这其中难免涉及到一些对内核数据结构的操作,以及处理器体系结构(本系列 Lab 基于 RISCV)相关的内容&…

什么是慢查询——Java全栈知识(26)

1、什么是慢查询 慢查询:也就是接口压测响应时间过长,页面加载时间过长的查询 原因可能如下: 1、聚合查询 2、多表查询 3、单表数据量过大 4、深度分页查询(limit) 如何定位慢查询? 1、Skywalking 我们…

IND83081芯片介绍(一)

一、芯片介绍 IND83081是indiemicro推出的一款高性能的汽车矩阵LED照明控制器,集成了四个子模块,每个子模块包含三个串联的MOSFET开关,每个开关均可通过12位PWM内部信号控制,可配置的上升和下降速率及相位移以实现精确控制&#x…

JOSEF约瑟 JOXL-J拉绳开关 整定范围宽

用途 双向拉绳开关的壳体采用金属材料铸造,具有足够的机械强度,抵抗并下工作时脱落的岩石,爆块等物体的撞击不被破坏,当胶带输送机发生紧急事故时,启动拉绳开关,可立即停机报警,防止事故的扩大,保证工作现场的人身安全…

java 操作 milvus 2.1.4

1. 确认 docker 运行的 milvus容器镜像版本情况&#xff1a; 2. pom 依赖&#xff1a; <dependency><groupId>io.milvus</groupId><artifactId>milvus-sdk-java</artifactId><version>2.1.0</version><exclusions><exclusi…

Java学习 - Redis慢查询与发布订阅与流水线

慢查询 慢查询是什么 慢查询本质是慢查询日志&#xff0c;它记录了一些执行速度很慢的命令 慢查询与生命周期 生命周期 ------- ------------------------------------------ | | 1.发送请求 | redis服务端 …

Simulink缓存文件有什么用?

在使用Simulink进行仿真的过程中&#xff0c;经常会发现目录下存在一些后缀为.slxc的文件&#xff0c;这些其实就是Simulink模型的缓存文件&#xff08;.slx cache&#xff09;。 Simulink缓存文件的主要作用是提高仿真和代码生成的效率。 借助缓存文件&#xff0c;可以避免…

Web浏览器读写NFC Ntag标签

本示例使用的发卡器&#xff1a;RS232串口USB转COM读写器IC卡发卡器WEB浏览器二次开发JS编程SDK-淘宝网 (taobao.com) <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"&g…

不锈钢氩弧焊丝ER316L

说明&#xff1a;TG316L 是超低碳的不锈钢焊丝。熔敷金属耐蚀、耐热、抗裂性能优良。防腐蚀性能良好。 用途:用于石油化工、化肥设备等。也可用于要求焊接后不进行热处理的高Cr钢的焊接。

真实评测:可道云teamOS文件上传功能丝滑到爱不释手

对于每日沉浸在图片与视频海洋中的媒体工作者而言&#xff0c;与海量的多媒体文件打交道几乎成了家常便饭。 文件的上传和存储&#xff0c;对他们而言&#xff0c;不仅仅是工作中的一个环节&#xff0c;更像是将一天的辛勤与付出妥善安置的仪式。无论是突发现场的精彩瞬间&am…

海报在线制作系统源码小程序

轻松设计&#xff0c;创意无限 一款基于ThinkPHPFastAdminUniApp开发的海报在线制作系统&#xff0c; 本系统不包含演示站中的素材模板资源。​ 一、引言&#xff1a;设计新纪元&#xff0c;在线海报制作引领潮流 在数字时代&#xff0c;海报已成为传播信息、展示创意的重要媒…

配音软件哪个好用?推荐5款智能配音软件

随着期末考来袭&#xff0c;校园里的空气似乎都凝固了&#xff0c;每个角落都充满了紧张的气氛。 然而&#xff0c;在这紧张的氛围中&#xff0c;有一群学生却显得格外从容&#xff0c;因为他们掌握了一种秘密武器——配音软件。这些软件就像是他们的个人学习助理&#xff0c;…

git 中有关 old mode 100644、new mode 10075的问题解决小结

问题&#xff1a; 同一个文件被修改后&#xff0c;最后代码没有变&#xff0c;文件变了&#xff0c;导致提交了一个空文件 git diff 提示 filemode 发生改变&#xff08;old mode 100644、new mode 10075&#xff09; 解决办法 &#xff1a; 原来是 filemode 的变化&#xff…