深度学习笔记11-优化器对比实验(Tensorflow)

  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

目录

一、导入数据并检查

二、配置数据集

三、数据可视化

四、构建模型

五、训练模型

六、模型对比评估

七、总结


一、导入数据并检查

import pathlib,PIL
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签data_dir    = pathlib.Path("./T6")
image_count = len(list(data_dir.glob('*/*')))
batch_size = 16
img_height = 336
img_width  = 336
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

二、配置数据集

AUTOTUNE = tf.data.AUTOTUNE
#归一化处理
def train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

三、数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

四、构建模型

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Modeldef create_model(optimizer='adam'):# 加载预训练模型vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',include_top=False,#不包含顶层的全连接层input_shape=(img_width, img_height, 3),pooling='avg')#平均池化层替代顶层的全连接层for layer in vgg16_base_model.layers:layer.trainable = False  #将 trainable属性设置为 False 意味着在训练过程中,这些层的权重不会更新X = vgg16_base_model.outputX = Dense(170, activation='relu')(X)X = BatchNormalization()(X)X = Dropout(0.5)(X)output = Dense(len(class_names), activation='softmax')(X)#神经元数量等于类别数vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)vgg16_model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])return vgg16_modelmodel1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())#随机梯度下降(SGD)优化器的
model2.summary()

五、训练模型

NO_EPOCHS = 20history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

六、模型对比评估

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']epochs_range = range(len(acc1))plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.show()

可以看出,在这个实例中,Adam优化器的效果优于SGD优化器

七、总结

      通过本次实验,学会了比较不同优化器(Adam和SGD)在训练过程中的性能表现,可视化训练过程的损失曲线和准确率等指标。这是一项非常重要的技能,在研究论文中,可以通过这些优化方法可以提高工作量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/65673.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FFmpeg Muxer HLS

使用FFmpeg命令来研究它对HLS协议的支持程度是最好的方法: ffmpeg -h muxerhls Muxer HLS Muxer hls [Apple HTTP Live Streaming]:Common extensions: m3u8.Default video codec: h264.Default audio codec: aac.Default subtitle codec: webvtt. 这里面告诉我…

Docker Desktop 构建java8基础镜像jdk安装配置失效解决

Docker Desktop 构建java8基础镜像jdk安装配置失效解决 文章目录 1.问题2.解决方法3.总结 1.问题 之前的好几篇文章中分享了在Linux(centOs上)和windows10上使用docker和docker Desktop环境构建java8的最小jre基础镜像,前几天我使用Docker Desktop环境重新构建了一个…

Node.js——fs(文件系统)模块

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 📃个人状态: 研发工程师,现效力于中国工业软件事业 🚀人生格言: 积跬步…

Microsoft Azure Cosmos DB:全球分布式、多模型数据库服务

目录 前言1. Azure Cosmos DB 简介1.1 什么是 Azure Cosmos DB?1.2 核心技术特点 2. 数据模型与 API 支持2.1 文档存储(Document Store)2.2 图数据库(Graph DBMS)2.3 键值存储(Key-Value Store)…

2025年华为OD上机考试真题(Java)——整数对最小和

题目: 给定两个整数数组array1、array2,数组元素按升序排列。假设从array1、array2中分别取出一个元素可构成一对元素,现在需要取出k对元素,并对取出的所有元素求和,计算和的最小值。 注意:两对元素如果对应…

7 分布式定时任务调度框架

先简单介绍下分布式定时任务调度框架的使用场景和功能和架构,然后再介绍世面上常见的产品 我们在大型的复杂的系统下,会有大量的跑批,定时任务的功能,如果在独立的子项目中单独去处理这些任务,随着业务的复杂度的提高…

网络安全 | 网络安全法规:GDPR、CCPA与中国网络安全法

网络安全 | 网络安全法规:GDPR、CCPA与中国网络安全法 一、前言二、欧盟《通用数据保护条例》(GDPR)2.1 背景2.2 主要内容2.3 特点2.4 实施效果与影响 三、美国《加利福尼亚州消费者隐私法案》(CCPA)3.1 背景3.2 主要内…

“AI智能陪练培训服务系统,让学习更轻松、更高效

大家好,我是资深产品经理小李,今天咱们来侃侃一个新兴的教育辅助工具——AI智能陪练培训服务系统。这个系统可谓是教育培训行业的一股新势力,它究竟有什么神奇之处呢?下面我就跟大家伙儿好好聊聊。 一、什么是AI智能陪练培训服务系…

notebook主目录及pip镜像源修改

目录 一、notebook主目录修改二、pip镜像源修改 一、notebook主目录修改 在使用Jupyter Notebook进行数据分析时,生成的.ipynb文件默认会保存在Jupyter的主目录中。通常情况下,系统会将Jupyter的主目录设置为系统的文档目录,而文档目录通常位…

如何利用百炼智能体编排应用轻松搭建智能AI旅游助手?

各位小伙伴儿,好哈! 在上一篇文章《5分钟基于阿里云百炼平台搭建专属智能AI机器人》中我们体验了如何利用阿里云百炼平台的智能体应用搭建专属智能机器人。 它的配置过程相对简单,其“对话式”的输出形式也十分直观,非常适合初学…

计算机视觉目标检测-DETR网络

目录 摘要abstractDETR目标检测网络详解二分图匹配和损失函数 DETR总结总结 摘要 DETR(DEtection TRansformer)是由Facebook AI提出的一种基于Transformer架构的端到端目标检测方法。它通过将目标检测建模为集合预测问题,摒弃了锚框设计和非…

【Vim Masterclass 笔记09】S06L22:Vim 核心操作训练之 —— 文本的搜索、查找与替换操作(第一部分)

文章目录 S06L22 Search, Find, and Replace - Part One1 从光标位置起,正向定位到当前行的首个字符 b2 从光标位置起,反向查找某个字符3 重复上一次字符查找操作4 定位到目标字符的前一个字符5 单字符查找与 Vim 命令的组合6 跨行查找某字符串7 Vim 的增…

springboot 默认的 mysql 驱动版本

本案例以 springboot 3.1.12 版本为例 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>3.1.12</version><relativePath/> </parent> 点击 spring-…

计算机网络(二)——物理层和数据链路层

一、物理层 1.作用 实现相信计算机节点之间比特流的透明传输&#xff0c;尽可能屏蔽具体传输介质和物理设备的差异。 2.数据传输单位 比特。 3.相关通信概念 ①信源和信宿&#xff1a;即信号的发送方和接收方。 ②数据&#xff1a;即信息的实体&#xff0c;比如图像、视频等&am…

sql server cdc漏扫数据

SQL Server的CDC指的是“变更数据捕获”&#xff08;Change Data Capture&#xff09;。这是SQL Server数据库提供的一项功能&#xff0c;能够跟踪并记录对数据库表中数据所做的更改。这些更改包括插入、更新和删除操作。CDC可以捕获这些变更的详细信息&#xff0c;并使这些信息…

AI数字人+文旅:打造数字文旅新名片

在数字化浪潮的推动下&#xff0c;人工智能技术正以前所未有的速度渗透到我们生活的每一个角落。特别是在文化和旅游领域&#xff0c;AI数字人的出现&#xff0c;不仅为传统文旅产业注入了新的活力&#xff0c;也为游客带来了全新的体验。 肇庆AI数字人——星湖 “星湖”是肇…

做一个 简单的Django 《股票自选助手》显示 用akshare 库(A股数据获取)

图&#xff1a; 股票自选助手 这是一个基于 Django 开发的 A 股自选股票信息查看系统。系统使用 akshare 库获取实时股票数据&#xff0c;支持添加、删除和更新股票信息。 功能特点 支持添加自选股票实时显示股票价格和涨跌幅一键更新所有股票数据支持删除不需要的股票使用中…

Protobuf编码规则详解

Protobuf编码规则详解 1 Message 结构1.1 tag1.1.1 字段编号(field_num)1.1.2 传输类型(wire_type) 1.2 字段顺序1.3 默认值 2 编码2.1 Varint编码2.1.1 Varint编码过程2.1.2解码过程2.1.3 存储2.1.4 小结2.2 有符号整数(sint32和sint64)编码的问题与zigzag优化 3 编码实践3.1测…

【docker】exec /entrypoint.sh: no such file or directory

dockerfile生成的image 报错内容&#xff1a; exec /entrypoint.sh: no such file or directory查看文件正常在此路径&#xff0c;但是就是报错没找到。 可能是因为sh文件的换行符使用了win的。

计算机的错误计算(二百零七)

摘要 利用两个数学大模型计算 arccot(0.125664e2)的值&#xff0c;结果保留16位有效数字。 实验表明&#xff0c;它们的输出中分别仅含有3位和1位正确数字。 例1. 计算 arccot(0.125664e2)的值&#xff0c;结果保留16位有效数字。 下面是与一个数学解题器的对话。 以上为与…