深度学习笔记11-优化器对比实验(Tensorflow)

  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

目录

一、导入数据并检查

二、配置数据集

三、数据可视化

四、构建模型

五、训练模型

六、模型对比评估

七、总结


一、导入数据并检查

import pathlib,PIL
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签data_dir    = pathlib.Path("./T6")
image_count = len(list(data_dir.glob('*/*')))
batch_size = 16
img_height = 336
img_width  = 336
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

二、配置数据集

AUTOTUNE = tf.data.AUTOTUNE
#归一化处理
def train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

三、数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

四、构建模型

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Modeldef create_model(optimizer='adam'):# 加载预训练模型vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',include_top=False,#不包含顶层的全连接层input_shape=(img_width, img_height, 3),pooling='avg')#平均池化层替代顶层的全连接层for layer in vgg16_base_model.layers:layer.trainable = False  #将 trainable属性设置为 False 意味着在训练过程中,这些层的权重不会更新X = vgg16_base_model.outputX = Dense(170, activation='relu')(X)X = BatchNormalization()(X)X = Dropout(0.5)(X)output = Dense(len(class_names), activation='softmax')(X)#神经元数量等于类别数vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)vgg16_model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])return vgg16_modelmodel1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())#随机梯度下降(SGD)优化器的
model2.summary()

五、训练模型

NO_EPOCHS = 20history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

六、模型对比评估

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']epochs_range = range(len(acc1))plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.show()

可以看出,在这个实例中,Adam优化器的效果优于SGD优化器

七、总结

      通过本次实验,学会了比较不同优化器(Adam和SGD)在训练过程中的性能表现,可视化训练过程的损失曲线和准确率等指标。这是一项非常重要的技能,在研究论文中,可以通过这些优化方法可以提高工作量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/65673.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FFmpeg Muxer HLS

使用FFmpeg命令来研究它对HLS协议的支持程度是最好的方法: ffmpeg -h muxerhls Muxer HLS Muxer hls [Apple HTTP Live Streaming]:Common extensions: m3u8.Default video codec: h264.Default audio codec: aac.Default subtitle codec: webvtt. 这里面告诉我…

Apache和PHP:构建动态网站的黄金组合

在当今的互联网世界,网站已经成为了企业、个人和机构展示自己、与用户互动的重要平台。而在这些动态网站的背后,Apache和PHP无疑是最受开发者青睐的技术组合之一。这一组合提供了高效、灵活且可扩展的解决方案,帮助您快速搭建出强大的网站&am…

git相关操作笔记

git相关操作笔记 1. git init git init 是一个 Git 命令,用于初始化一个新的 Git 仓库。执行该命令后,Git 会在当前目录创建一个 .git 子目录,这是 Git 用来存储所有版本控制信息的地方。 使用方法如下: (1&#xff…

Docker Desktop 构建java8基础镜像jdk安装配置失效解决

Docker Desktop 构建java8基础镜像jdk安装配置失效解决 文章目录 1.问题2.解决方法3.总结 1.问题 之前的好几篇文章中分享了在Linux(centOs上)和windows10上使用docker和docker Desktop环境构建java8的最小jre基础镜像,前几天我使用Docker Desktop环境重新构建了一个…

VUE + pdfh5 实现pdf 预览,主要用来uniappH5实现嵌套预览PDF

1. 安装依赖 npm install pdfh5 2. pdfh5 预览(移动端,h5) npm install pdfh5 , (会报错,需要其他依赖,不能直接用提示的语句直接npm下载,依旧会报错,npm报错:These dependencies were not fou…

Node.js——fs(文件系统)模块

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 📃个人状态: 研发工程师,现效力于中国工业软件事业 🚀人生格言: 积跬步…

Microsoft Azure Cosmos DB:全球分布式、多模型数据库服务

目录 前言1. Azure Cosmos DB 简介1.1 什么是 Azure Cosmos DB?1.2 核心技术特点 2. 数据模型与 API 支持2.1 文档存储(Document Store)2.2 图数据库(Graph DBMS)2.3 键值存储(Key-Value Store)…

springboot项目读取resources目录下文件

要用以下这种方式读取 classPathResource new ClassPathResource("template/test.docx");不能用以下这种获取绝对路径的方式,idea调试正常,但是部署window和linux的目录结构不一样,部署后会找不到文件,另外window直接…

Ruby语言的软件开发工具

Ruby语言的软件开发工具概述 引言 Ruby是一种简单且功能强大的编程语言,它以优雅的语法和灵活性而闻名。自1995年首次发布以来,Ruby已经被广泛应用于各种开发领域,特别是Web开发。随着Ruby语言的普及,相关的开发工具也日益丰富。…

C++例程:使用I/O模拟IIC接口(6)

完整的STM32F405代码工程I2C驱动源代码跟踪 一)myiic.c #include "myiic.h" #include "delay.h" #include "stm32f4xx_rcc.h" //初始化IIC void IIC_Init(void) { GPIO_InitTypeDef GPIO_InitStructure;RCC_AHB1PeriphCl…

CNN-BiLSTM-Attention模型详解及应用分析

CNN-BiLSTM-Attention结构 CNN-BiLSTM-Attention结构是一种强大的深度学习架构,巧妙地结合了三种不同的技术优势:卷积神经网络(CNN)、双向长短期记忆网络(BiLSTM)和注意力机制(Attention)。这种创新性的组合使得模型能够在处理复杂序列数据时表现出色,尤其适用于自然…

2025年华为OD上机考试真题(Java)——整数对最小和

题目: 给定两个整数数组array1、array2,数组元素按升序排列。假设从array1、array2中分别取出一个元素可构成一对元素,现在需要取出k对元素,并对取出的所有元素求和,计算和的最小值。 注意:两对元素如果对应…

【Java知识】Groovy 一个兼容java的编程语言

groovy语言介绍 概述一、基本特点二、主要特性三、应用领域四、与Java的比较 基本语法特性一、基本语法二、数据类型三、运算符四、字符串五、方法六、闭包七、类与对象八、异常处理九、其他特性 集成到springboot项目1. 创建Spring Boot项目2. 添加Groovy依赖3. 编写Groovy类4…

Python网络爬虫:从入门到实战

Python以其简洁易用和强大的库支持成为网络爬虫开发的首选语言。本文将系统介绍Python网络爬虫的开发方法,包括基础知识、常用工具以及实战案例,帮助读者从入门到精通。 什么是网络爬虫? 网络爬虫(Web Crawler)是一种…

【vLLM 学习】安装

vLLM 是一款专为大语言模型推理加速而设计的框架,实现了 KV 缓存内存几乎零浪费,解决了内存管理瓶颈问题。 更多 vLLM 中文文档及教程可访问 →https://vllm.hyper.ai/ vLLM 是一个 Python 库,包含预编译的 C 和 CUDA (12.1) 二进制文件。 …

npm : 无法加载文件 D:\SoftFile\npm.ps1,因为在此系统上禁止运行脚本。

这个错误是由于 Windows PowerShell 的执行策略禁止执行脚本,导致无法运行 npm 命令。你可以通过以下步骤来解决这个问题: 以管理员身份运行 PowerShell: 点击“开始”菜单,搜索“PowerShell”,然后右键点击“Windows …

7 分布式定时任务调度框架

先简单介绍下分布式定时任务调度框架的使用场景和功能和架构,然后再介绍世面上常见的产品 我们在大型的复杂的系统下,会有大量的跑批,定时任务的功能,如果在独立的子项目中单独去处理这些任务,随着业务的复杂度的提高…

网络安全 | 网络安全法规:GDPR、CCPA与中国网络安全法

网络安全 | 网络安全法规:GDPR、CCPA与中国网络安全法 一、前言二、欧盟《通用数据保护条例》(GDPR)2.1 背景2.2 主要内容2.3 特点2.4 实施效果与影响 三、美国《加利福尼亚州消费者隐私法案》(CCPA)3.1 背景3.2 主要内…

Elixir语言的计算机基础

Elixir语言的计算机基础 引言 在当今这个快速发展的技术时代,编程语言层出不穷。Elixir作为一种较新的编程语言,以其高并发、低延迟和强大的容错能力受到越来越多开发者的青睐。它基于Erlang虚拟机(BEAM),自然继承了…

mysql的mvcc理解

人阅读 一、说到mvcc就少不了事务隔离级别(大白话解释) 序列化(SERIALIZABLE):事务之间完全隔离,当成一个序列,一个一个执行。 1 可重复读(REPEATABLE READ)&#xff…