深度學習筆記12-優化器對比(Tensorflow)

  • 🍨 本文為🔗365天深度學習訓練營 中的學習紀錄博客
  • 🍖 原作者:K同学啊 | 接輔導、項目定制

一、我的環境

  • 電腦系統:Windows 10

  • 顯卡:NVIDIA Quadro P620

  • 語言環境:Python 3.7.0

  • 開發工具:Sublime Text,Command Line(CMD)

  • 深度學習環境:Tensorflow 2.5.0


二、準備套件

# 提供一些與操作系統交互的功能,例如文件路徑操作等
import os# 用於圖像處理,例如打開、操作、保存圖像文件
import PIL# 用於處理文件路徑的模塊,提供一種更加直觀和面向對象的操作文件路徑方式
import pathlib# 用於繪圖,可以創建各種類型的圖表和圖形
import matplotlib.pyplot as plt# 數值計算庫,用於處理大型多維數組和矩陣的
import numpy as np# 開源的機器學習框架
import tensorflow as tf# 導入 keras 模塊,為 tensorflow 的高級 API 之一,操作起來更加簡單、易用
from tensorflow import keras# layers模組包含了各種類型的神經網絡層
# models模組包含了用於定義神經網絡模型的類
# Input類用於定義模型的輸入
from tensorflow.keras import layers, models, Input# 用於定義自定義的神經網絡模型
from tensorflow.keras.models import Model# 導入Keras API中的一些常用神經網絡層
# 包括卷積層(Conv2D)、池化層(MaxPooling2D)、全連接層(Dense)、展平層(Flatten)、失活層(Dropout)、批量規範(BatchNormalization)
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout, BatchNormalization# 供了一個名為 tqdm 的進度條,可以在迭代過程中顯示進度,讓用戶了解運行的進度
# 它是一個很常用的進度條庫,對於長時間運行的程式非常有用
from tqdm import tqdm# 將 Keras 的後端函數庫引入為 K
# Keras 的後端函數庫提供了一系列與計算圖和張量操作相關的功能,
# 例如張量的數學運算、梯度計算等。通常,我們可以通過 K. 來訪問這些函數和類
import tensorflow.keras.backend as K#隱藏警告
import warnings
warnings.filterwarnings('ignore')# 用來在 Matplotlib 圖表中設置刻度的類
# 通過它,可以指定刻度的位置和間隔,以便更好地控制圖表的顯示效果
from matplotlib.ticker import MultipleLocator

三、設定GPU

# 列出系統中的GPU裝置列表
gpus = tf.config.list_physical_devices("GPU")# 如果有GPU
if gpus:# 挑選第一個 GPUgpu0 = gpus[0] # 僅在需要的時候分配記憶體tf.config.experimental.set_memory_growth(gpu0, True)# 將 GPU0 設置為 TensorFlow 中可見的唯一 GPU ,將運算限制在特定的 GPU 上 tf.config.set_visible_devices([gpu0],"GPU") plt.rcParams['axes.unicode_minus'] = False  # 顯示負號

四、載入資料 

# 設定數據目錄的相對路徑,也可以使用絕對路徑
# D:/AI/ai_note/T6,這邊要注意斜線的方向
data_dir = "T6/"
# 將路徑轉換成 pathlib.Path 對象,更易操作
data_dir = pathlib.Path(data_dir)
# 使用 glob 方法獲取指定目錄下所有以 '.png' 為副檔名的文件迭代器
# '*/*.png 是一個通配符模式,表示所有直接位於子目錄中的以 .png 結尾的文件
# 第一個星號表示所有目錄
# 第二個星號表示所有檔名image_count = len(list(data_dir.glob('*/*')))
# 印出圖片數量
print("圖片總數:",image_count)


五、數據預處理

# 設置批量大小,即每次訓練模型時輸入到模型中的圖像數量
# 在每次訓練跌代時,模型將同時處理16張圖像
# 批量大小的選擇會影響訓練速度和內存需求
batch_size = 16
# 圖像的高度,在加載圖像數據時,將所有的圖像調整為相同的高度,這裡設定為 336 像素
img_height = 336
# 圖像的寬度,在加載圖像數據時,將所有的圖像調整為相同的寬度,這裡設定為 336 像素
img_width = 336# 創建訓練數據集
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,  # 數據集所在目錄validation_split=0.2,  # 將數據集的20%用於驗證subset="training",  # 指定該部分為訓練數據集seed=12,  # 隨機種子,保證數據劃分的可重複性image_size=(img_height, img_width),  # 調整圖像尺寸batch_size=batch_size)  # 每個批次的圖像數量# 創建驗證數據集
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,  # 數據集所在目錄validation_split=0.2,  # 將數據集的20%用於驗證subset="validation",  # 指定該部分為驗證數據集seed=12,  # 隨機種子,保證數據劃分的可重複性image_size=(img_height, img_width),  # 調整圖像尺寸batch_size=batch_size) # 每個批次的圖像數量# 獲取數據集中類別的名稱
class_names = train_ds.class_names
print(class_names) # 輸出類別名稱


六、檢查數據

# 查看一個批次的圖像和標籤的形狀
for image_batch, labels_batch in train_ds:print(image_batch.shape)  # 打印圖像批次的形狀print(labels_batch.shape)  # 打印標籤批次的形狀break  # 只查看第一個批次


七、配置數據集

AUTOTUNE = tf.data.AUTOTUNE# 定義訓練數據預處理函數
def train_preprocessing(image, label):return (image / 255.0, label)  # 將圖像數據歸一化到[0, 1]範圍# 設置訓練數據集的預處理流程
train_ds = (train_ds.cache()  # 將數據集緩存到內存中,提高讀取速度.shuffle(1000)  # 將數據集隨機打亂.map(train_preprocessing)  # 應用預處理函數.prefetch(buffer_size=AUTOTUNE)  # 預取數據以提高性能
)# 設置驗證數據集的預處理流程
val_ds = (val_ds.cache()  # 將數據集緩存到內存中,提高讀取速度.shuffle(1000)  # 將數據集隨機打亂.map(train_preprocessing)  # 應用預處理函數.prefetch(buffer_size=AUTOTUNE)  # 預取數據以提高性能
)

八、數據展示

plt.figure(figsize=(10, 8))  # 設置圖像大小
plt.suptitle("數據展示")  # 設置整體標題# 從訓練數據集中取一個批次的圖像和標籤
for images, labels in train_ds.take(1):for i in range(15):  # 顯示前15張圖像plt.subplot(4, 5, i + 1)  # 創建子圖,4行5列plt.xticks([])  # 隱藏X軸刻度plt.yticks([])  # 隱藏Y軸刻度plt.grid(False)  # 隱藏網格線plt.imshow(images[i])  # 顯示圖像plt.xlabel(class_names[labels[i]])  # 顯示圖像對應的類別名稱plt.show()  # 顯示圖像


 九、建構模型

def create_model(optimizer='adam'):# 加載預訓練模型vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',include_top=False,input_shape=(img_width, img_height, 3),pooling='avg')# 冻結預訓練模型的所有層for layer in vgg16_base_model.layers:layer.trainable = False# 添加自定義的全連接層X = vgg16_base_model.outputX = Dense(170, activation='relu')(X)X = BatchNormalization()(X)X = Dropout(0.5)(X)# 添加輸出層,使用softmax激活函數進行多分類output = Dense(len(class_names), activation='softmax')(X)# 創建完整的模型vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)# 編譯模型vgg16_model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])return vgg16_model# 使用不同的優化器創建模型
model1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())# 打印模型結構
model2.summary()


十、 訓練模型

NO_EPOCHS = 10  # 訓練的輪數# 使用 model1 進行訓練
history_model1 = model1.fit(train_ds,  # 訓練數據集epochs=NO_EPOCHS,  # 訓練輪數verbose=1,  # 顯示訓練過程的詳細信息validation_data=val_ds  # 驗證數據集
)# 使用 model2 進行訓練
history_model2 = model2.fit(train_ds,  # 訓練數據集epochs=NO_EPOCHS,  # 訓練輪數verbose=1,  # 顯示訓練過程的詳細信息validation_data=val_ds  # 驗證數據集
)


十一、模型評估

plt.rcParams['savefig.dpi'] = 300  # 圖片像素
plt.rcParams['figure.dpi'] = 300   # 分辨率# 從訓練歷史中提取準確率和損失
acc1 = history_model1.history['accuracy']
acc2 = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']loss1 = history_model1.history['loss']
loss2 = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']epochs_range = range(len(acc1))  # 訓練的輪數範圍plt.figure(figsize=(16, 4))# 畫出訓練和驗證準確率
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')# 設置刻度間隔,x軸每1一個刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))# 畫出訓練和驗證損失
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')# 設置刻度間隔,x軸每1一個刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.show()  # 顯示圖像def test_accuracy_report(model):# 評估模型在驗證數據集上的性能score = model.evaluate(val_ds, verbose=0)# 打印損失值和準確率print('Loss function:', score[0], ', accuracy:', score[1])# 測試 model2 的準確率報告
test_accuracy_report(model2)

 


十二、總結

在深度學習中,選擇適當的優化器及其相應的參數配置對模型的訓練和性能表現具有重要影響

  1. 優化器的選擇

    • Adam 優化器通常是一個不錯的默認選擇,它結合了動量梯度下降和自適應學習率調整。它對於大多數情況下能夠提供良好的性能表現,並且相對容易調參
    • SGD(隨機梯度下降) 需要精心調參,特別是學習率、動量等參數的設置。在某些情況下,SGD 可以通過仔細調整參數實現更好的性能,特別是在計算資源有限的情況下
  2. 學習率的調整

    • Adam 優化器通常不需要手動調整學習率,因為它會自適應調整。但是,如果遇到訓練過程中性能停滯或不收斂的情況,可以考慮進行小幅度調整
    • SGD 優化器需要仔細調整學習率,通常會隨著訓練進行進行衰減或者動態調整
  3. 批量大小的影響

    • 選擇合適的批量大小對訓練速度和收斂性能至關重要。通常來說,較大的批量大小可以加速訓練,但可能會導致內存壓力或過擬合問題。較小的批量大小則可以提升模型的泛化能力
  4. 其他參數的影響

    • 動量(Momentum):對於 SGD,動量可以幫助加速收斂,特別是在具有高曲率的梯度表面上
    • 權重衰減(Weight Decay):可以用來控制模型的正則化,減少過擬合的風險
    • Dropout:隨機失活在訓練過程中可以有效防止過擬合,通常設置在 0.2 到 0.5 之間

選擇最佳的優化器及參數配置需要透過實驗和觀察來得出,在實際應用中,可以通過監控訓練和驗證的損失與準確率來評估不同設置的效果,並根據實際情況做出調整

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/29218.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于GTX的64B66B编码IP生成(高速收发器二十)

点击进入高速收发器系列文章导航界面 1、配置GTX IP 相关参数 前文讲解了64B66B编码解码原理,以及GTX IP实现64B66B编解码的相关信号组成,本文生成64B66B编码的GTX IP。 首先如下图所示,需要对GTX共享逻辑进行设置,为了便于扩展&a…

【开发工具】git服务器端安装部署+客户端配置

自己安装一个轻量级的git服务端,仅仅作为代码维护,尤其适合个人代码管理。毕竟代码的版本管理是很有必要的。 这里把git服务端部署在centos系统里,部署完成后可以通过命令行推拉代码,进行版本和用户管理。 一、服务端安装配置 …

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 内存访问热度分析(100分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 💻 ACM银牌🥈| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 &#x1f…

windows环境下,怎么查看本机的IP、MAC地址和端口占用情况

1.输入ipconfig,按回车。即查看了IP地址,子码掩码,网关信息。 2.输入ipconfig/all,按回车。即查看了包含IP地址,子码掩码,网关信息以及MAC地址 3.我们有时在启动应用程序的时候提示端口被占用,如何知道谁占有了我们需要…

Vue57-组件的自定义事件_解绑

给谁绑的自定义事件,就找谁去触发;给谁绑的自定义事件,就找谁去解绑; 一、解绑自定义事件 1-1、解绑一个自定义事件 到student.vue组件中去解绑。 1-2、解绑多个自定义事件 使用数组来解绑多个。 1-3、解绑所有的自定义事件 二、…

Android Studio无法连接夜神模拟器的解决方案

一、AS检测不到夜神模拟器 1、问题描述 在按照教程【如何安装和使用Android夜神模拟器】进入夜神的bin目录,输入连接命令回车后,终端显示的already connected to 127.0.0.1:62001,但是AS的Running Devices并没有显示夜神模拟器。 2、解决方…

Arm和高通的法律之争将扰乱人工智能驱动的PC浪潮

Arm和高通的法律之争将扰乱人工智能驱动的PC浪潮 科技行业高管和专家表示,两大科技巨头之间长达两年的法律大战可能会扰乱人工智能驱动的新一代个人电脑浪潮。 上周,来自微软(Microsoft)、华硕(Asus)、宏碁(Acer)、高通(Qualcomm)等公司的高管在台北举行…

计算机毕业设计Python+Vue.js知识图谱音乐推荐系统 音乐爬虫可视化 音乐数据分析 大数据毕设 大数据毕业设计 机器学习 深度学习 人工智能

开发技术 协同过滤算法、机器学习、LSTM、vue.js、echarts、django、Python、MySQL 创新点协同过滤推荐算法、爬虫、数据可视化、LSTM情感分析、短信、身份证识别 补充说明 适合大数据毕业设计、数据分析、爬虫类计算机毕业设计 介绍 音乐数据的爬取:爬取歌曲、…

深度学习推理显卡设置

深度学习推理显卡设置 进入NVIDIA控制面板,选择 “管理3D设置”设置 "低延时模式"为 "“超高”"设置 “电源管理模式” 为 “最高性能优先” 使用锁频来获得稳定的推理 法一:命令行操作 以管理员身份打开CMD查看GPU核心可用频率&…

云计算 | (四)基本云安全

文章目录 📚基本云安全🐇云安全背景🐇基本术语和概念⭐️风险(risk)⭐️安全需求🐇威胁作用者⭐️威胁作用者(threat agent)⭐️匿名攻击者(anonymous attacker)⭐️恶意服务作用者(malicious service agent)⭐️授信的攻击者(trusted attacker)⭐️恶意的内部人员(mal…

有趣且重要的JS知识合集(22)树相关的算法

0、举例&#xff1a;树形结构原始数据 1、序列化树形结构 /*** 平铺序列化树形结构* param tree 树形结构* param result 转化后一维数组* returns Array<TreeNode>*/ export function flattenTree(tree, result []) {if (tree.length 0) {return result}for (const …

开发一个python工具,pdf转图片,并且截成单个图片,然后修整没用的白边

今天推荐一键款本人开发的pdf转单张图片并截取没有用的白边工具 一、开发背景&#xff1a; 业务需要将一个pdf文件展示在前端显示&#xff0c;但是基于各种原因&#xff0c;放弃了h5使用插件展示 原因有多个&#xff0c;文件资源太大加载太慢、pdf展示兼容性问题、pdf展示效果…

CSDN 自动上传图片并优化Markdown的图片显示

文章目录 完整代码一、上传资源二、替换 MD 中的引用文件为在线链接参考 完整代码 完整代码由两个文件组成&#xff0c;upload.py 和 main.py&#xff0c;放在同一目录下运行 main.py 就好&#xff01; # upload.py import requests class UploadPic: def __init__(self, c…

力扣每日一题 6/17 枚举+双指针

博客主页&#xff1a;誓则盟约系列专栏&#xff1a;IT竞赛 专栏关注博主&#xff0c;后期持续更新系列文章如果有错误感谢请大家批评指出&#xff0c;及时修改感谢大家点赞&#x1f44d;收藏⭐评论✍ 522.最长特殊序列II【中等】 题目&#xff1a; 给定字符串列表 strs &…

【Ubuntu通用压力测试】Ubuntu16.04 CPU压力测试

使用 stress 对CPU进行压力测试 我也是一个ubuntu初学者&#xff0c;分享是Linux的优良美德。写的不好请大佬不要喷&#xff0c;多谢支持。 sudo apt-get update 日常先更新再安装东西不容易出错 sudo apt-get upgrade -y 继续升级一波 sudo apt-get install -y linux-tools…

Stable Diffusion文生图模型训练入门实战(完整代码)

Stable Diffusion 1.5&#xff08;SD1.5&#xff09;是由Stability AI在2022年8月22日开源的文生图模型&#xff0c;是SD最经典也是社区最活跃的模型之一。 以SD1.5作为预训练模型&#xff0c;在火影忍者数据集上微调一个火影风格的文生图模型&#xff08;非Lora方式&#xff…

Python | Leetcode Python题解之第162题寻找峰值

题目&#xff1a; 题解&#xff1a; class Solution:def findPeakElement(self, nums: List[int]) -> int:n len(nums)# 辅助函数&#xff0c;输入下标 i&#xff0c;返回 nums[i] 的值# 方便处理 nums[-1] 以及 nums[n] 的边界情况def get(i: int) -> int:if i -1 or…

STM32单片机DMA存储器详解

文章目录 1. DMA概述 2. 存储器映像 3. DMA框架图 4. DMA请求 5. 数据宽度与对齐 6. DMA数据转运 7. ADC扫描模式和DMA 8. 代码示例 1. DMA概述 DMA&#xff08;Direct Memory Access&#xff09;可以直接访问STM32内部的存储器&#xff0c;DMA是一种技术&#xff0c;…

【 ARMv8/ARMv9 硬件加速系列 3.5.1 -- SVE 谓词寄存器有多少位?】

文章目录 SVE 谓词寄存器(predicate registers)简介SVE 谓词寄存器的位数SVE 谓词寄存器对向量寄存器的控制SVE 谓词寄存器位数计算SVE 谓词寄存器小结SVE 谓词寄存器(predicate registers)简介 ARMv9的Scalable Vector Extension (SVE) 引入了谓词寄存器(Predicate Register…

打造工业操作系统开源开放体系

我国制造业具有细分行业、领域众多&#xff0c;产品丰富&#xff0c;制造模式多样等特点&#xff0c;围绕以工业操作系统为核心的工业软件赋能体系建设&#xff0c;离不开平台运营商、工业软件开发商、系统服务商、科研机构、工业企业等多方联合参与。聚众同行、聚力创新&#…