Week-T11-优化器对比试验

文章目录

  • 一、准备环境
  • 二、准备数据
  • 三、搭建训练网络
  • 三、训练模型
    • (1)VSCode训练情况:
    • (2)`jupyter notebook`训练情况:
  • 四、模型评估 & 模型预测
    • 1、绘制Accuracy-Loss图
    • 2、显示model2的预测效果
  • 五、总结
    • 1、`plt.savefig("./数据展示.jpg")`保存的图片在文件夹内打开是空白的,如下图所示:
    • 2. 优化器是什么?包括哪些?

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

本文主要探究不同优化器、以及不同参数配置对模型的影响,最终对Adam、SGD优化器进行比较,并绘制比较结果。

使用的数据集为咖啡豆数据集,共有四类。

优化器常用的有Adam、SGD。优化器的归纳将放在文末的总结部分。

本文将使用Adam优化器的模型命名为"model1",使用SGD优化器的模型命名为"model2",然后根据模型训练结果绘制各自的Accuracy-Loss图。比较得出,在运行环境、epoch次数相同、模型结构相同等条件下,Adam优化器的整体情况要优于SGD优化器。

一、准备环境

# 1. 设置环境
import sys
import tensorflow as tf
from datetime import datetimefrom tensorflow          import keras
import matplotlib.pyplot as plt
import pandas            as pd
import numpy             as np
import warnings,os,PIL,pathlibprint("---------------------1.配置环境------------------")
print("Start time: ", datetime.today())
print("tensorflow version: " + tf.__version__)
print("Python version: " + sys.version)gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")# 打印显卡信息,确认GPU可用print("GPU: " + gpus)
else:print("Using CPU")warnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号

在这里插入图片描述

Q1: VSCode虚拟环境安装pandas
在这里插入图片描述

二、准备数据

# 2.导入数据
# 本次使用咖啡豆数据集(共4类)
print("---------------------2.1 从本地读取数据------------------")
data_dir    = "D:/jupyter notebook/DL-100-days/datasets/coffebeans-data"
data_dir    = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)batch_size = 16
img_height = 336
img_width  = 336"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
print("---------------------2.2 划分训练数据------------------")
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
print("---------------------2.3 划分验证数据------------------")
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)print("---------------------2.4 打印数据类别 && 数据的shape------------------")
class_names = train_ds.class_names
print(class_names)for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)breakprint("---------------------2.5 配置数据集------------------")
AUTOTUNE = tf.data.AUTOTUNEdef train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)print("---------------------2.6 数据可视化,显示部分样本图片------------------")
plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()
plt.savefig("./数据展示.jpg")

在这里插入图片描述
在这里插入图片描述

Q2:plt.savefig("./数据展示.jpg")保存的图片在文件夹内打开是空白的

三、搭建训练网络

print("---------------------3. 搭建训练网络,此处预训练模型调用VGG-16官方模型------------------")
# 自定义一个创建模型的函数,形参是优化器类型,预训练模型是VGG-16,但屏蔽了自带的训练部分以及顶层,然后对输出进行处理
# 在此处创建了两个网络,拥有不同的优化器类型
from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Modeldef create_model(optimizer='adam'):# 加载预训练模型vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',include_top=False,input_shape=(img_width, img_height, 3),pooling='avg')for layer in vgg16_base_model.layers:layer.trainable = FalseX = vgg16_base_model.outputX = Dense(170, activation='relu')(X)X = BatchNormalization()(X)X = Dropout(0.5)(X)output = Dense(len(class_names), activation='softmax')(X)vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)vgg16_model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])return vgg16_modelmodel1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())
model2.summary()

在这里插入图片描述

三、训练模型

print("---------------------4.启动训练,epoch==50------------------")
# try:加入早停试一下,一个epoch跑完要220s,时间还是有点久
NO_EPOCHS = 50history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

(1)VSCode训练情况:

model1.fit():Adam优化器
在这里插入图片描述
model2.fit():SGD优化器
在这里插入图片描述

(2)jupyter notebook训练情况:

model1.fit():即Adam优化器
在这里插入图片描述
model2.fit():即SGD优化器
在这里插入图片描述

四、模型评估 & 模型预测

1、绘制Accuracy-Loss图

print("---------------------5.1 模型评估,绘制Accuracy-Loss图------------------")
from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']epochs_range = range(len(acc1))plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))
plt.savefig("./Accuracy-Loss图.jpg")
plt.show()

plt.show()显示的图片:
请添加图片描述
比较Accuracy图表,可以看出训练时Adam优化器的表现要稍优于SGD优化器,而验证时则相反。

Q: VSCode绘制出来的图咋这么奇怪?
改变plt.savefig("./Accuracy-Loss图.jpg")的位置后所保存的图片,比直接plt.show()的图片比例要好些。
在这里插入图片描述

2、显示model2的预测效果

print("---------------------5.2 模型预测------------------")
def test_accuracy_report(model):score = model.evaluate(val_ds, verbose=0)print('Loss function: %s, accuracy:' % score[0], score[1])test_accuracy_report(model2)

VSCode环境下的预测结果:
在这里插入图片描述
jupyter notebook环境下的预测结果:
在这里插入图片描述

五、总结

1、plt.savefig("./数据展示.jpg")保存的图片在文件夹内打开是空白的,如下图所示:

在这里插入图片描述
将保存的语句放在plt.show()之前,因为plt.show()之后会默认打开一个空白画板。

2. 优化器是什么?包括哪些?

(参考文章也是来自训练营文章)

优化器是什么?

  • 优化器是一种算法,它在模型优化过程中,动态地调整梯度的大小和方向,使模型能够收敛到更好的位置,或者用更快的速度进行收敛。
  • 各类优化器方法总结如下:
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/179264.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++类与对象(7)—友元、内部类、匿名对象、拷贝对象时编译器优化

目录 一、友元 1、定义 2、友元函数 3、友元类 二、内部类 1、定义 2、特性: 三、匿名对象 四、拷贝对象时的一些编译器优化 1、传值&传引用返回优化对比 2、匿名对象作为函数返回对象 3、接收返回值方式对比 总结: 一、友元 1、定义…

RPC之grpc重试策略

1、grpc重试策略 RPC 调用失败可以分为三种情况: 1、RPC 请求还没有离开客户端; 2、RPC 请求到达服务器,但是服务器的应用逻辑还没有处理该请求; 3、服务器应用逻辑开始处理请求,并且处理失败; 最后一种…

2020年3月2日 Go生态洞察:Go协议缓冲区的新API发布

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

如何轻松将 4K 转换为 1080p 高清视频

由于某些原因,你可能有一些 4K 视频,与1080p、1080i、720p、720i等高清视频相比,4K 视频具有更高的分辨率,可以给您带来更多的视觉和听觉享受。但是,播放4k 视频是不太容易的,因为超高清电视没有高清电视那…

C#面向对象

过程类似函数只能执行没有返回值 函数不仅能执行,还可以返回结果 1、面向过程 a 把完成某一需求的所有步骤 从头到尾 逐步实现 b 根据开发需求,将某些 功能独立 的代码 封装 成一个又一个 函数 c 最后完成的代码就是顺序的调用不同的函数 特点 1、…

【问题系列】消费者与MQ连接断开问题解决方案(二)

1. 问题描述 当使用RabbitMQ作为中间件,而消费者为服务时,可能会出现以下情况:在长时间没有消息传递后,消费者与RabbitMQ之间出现连接断开,导致无法处理新消息。解决这一问题的方法是重启Python消费者服务,…

大数据平台/大数据技术与原理-实验报告--部署ZooKeeper集群和实战ZooKeeper

实验名称 部署ZooKeeper集群和实战ZooKeeper 实验性质 (必修、选修) 必修 实验类型(验证、设计、创新、综合) 综合 实验课时 2 实验日期 2023.11.04-2023.11.05 实验仪器设备以及实验软硬件要求 专业实验室&#xff08…

Spring Boot 3.2.0 Tomcat虚拟线程初体验 (部分装配解析)

写在前面 spring boot 3 已经提供了对虚拟线程的支持。 虚拟线程和平台线程主要区别在于,虚拟线程在运行周期内不依赖操作系统线程:它们与硬件脱钩,因此被称为 “虚拟”。这种解耦是由 JVM 提供的抽象层赋予的。 虚拟线程的运行成本远低于平…

如何使用APP UI自动化测试提高测试效率与质量?

pythonappium自动化测试系列就要告一段落了,本篇博客咱们做个小结。 首先想要说明一下,APP自动化测试可能很多公司不用,但也是大部分自动化测试工程师、高级测试工程师岗位招聘信息上要求的,所以为了更好的待遇,我们还…

C++11『右值引用 ‖ 完美转发 ‖ 新增类功能 ‖ 可变参数模板』

✨个人主页: 北 海 🎉所属专栏: C修行之路 🎃操作环境: Visual Studio 2022 版本 17.6.5 文章目录 🌇前言🏙️正文1.右值引用1.1.什么是右值引用?1.2.move 转移资源1.3.左值引用 vs …

CSS问题:如何实现瀑布流布局?

前端功能问题系列文章,点击上方合集↑ 序言 大家好,我是大澈! 本文约2500字,整篇阅读大约需要4分钟。 本文主要内容分三部分,如果您只需要解决问题,请阅读第一、二部分即可。如果您有更多时间&#xff…

JavaEE进阶学习:Bean 作用域和生命周期

1.Bean 作用域 .通过一个案例来看 Bean 作用域的问题 假设现在有一个公共的 Bean,提供给 A 用户和 B 用户使用,然而在使用的途中 A 用户却“悄悄”地修改了公共 Bean 的数据,导致 B 用户在使用时发生了预期之外的逻辑错误。 我们预期的结果…

colab notebook导出为PDF

目录 方法一:使用浏览器打印功能 方法二:使用nbconvert转换 方法三:在线转换 方法一:使用浏览器打印功能 一般快捷键是CTRLP 然后改变目标打印机为另存为PDF 这样就可以将notebook保存为PDF了 方法二:使用nbconver…

芯片技术前沿:了解构现代集成电路的设计与制造

芯片技术前沿:解构现代集成电路的设计与制造 摘要:本文将深入探讨芯片技术的最新进展,重点关注集成电路的设计与制造。我们将带领读者了解芯片设计的基本流程,包括电路分析、版图设计和验证等步骤,并介绍当前主流的制…

强化学习中的深度Q网络

深度 Q 网络(Deep Q-Network,DQN)是一种结合了深度学习和强化学习的方法,用于解决离散状态和离散动作空间的强化学习问题。DQN 的核心思想是使用深度神经网络来近似 Q 函数,从而学习复杂环境中的最优策略。 以下是 DQN…

从苹果到蔚来,「车手互联」网罗顶级玩家

作者 |Amy 编辑 |德新 汽车作为家之外的第二大移动空间,正与手机这一移动智能终端进行「车手互联」。 车手互联始于十年前的苹果CarPlay,一度成为时代弄潮儿,不断有后继者模仿并超越。时至今日,CarPlay2.0依旧停留在概念阶段&am…

RK3568笔记六:基于Yolov8的训练及部署

若该文为原创文章,转载请注明原文出处。 基于Yolov8的训练及部署,参考鲁班猫的手册训练自己的数据集部署到RK3568,用的是正点的板子。 1、 使用 conda 创建虚拟环境 conda create -n yolov8 python3.8 ​ conda activate yolov8 2、 安装 pytorch 等…

osgFX扩展库-异性光照、贴图、卡通特效(1)

本章将简单介绍 osgFX扩展库及osgSim 扩展库。osgFX库用得比较多,osgSim库不常用,因此,这里只对这个库作简单的说明。 osgFX扩展库 osgFX是一个OpenSceneGraph 的附加库,是一个用于实现一致、完备、可重用的特殊效果的构架工具,其…

UE 事件分发机制 day9

观察者模式原理 观察者模式通常有观察者与被观察者,当被观察者状态发生改变时,它会通知所有的被观察者对象,使他们能够及时做出响应,所以也被称作“发布-订阅模式”。总得来说就是你关注了一个主播,主播的状态改变会通…