Tensorflow入门实战 T06-Vgg16 明星识别

目录

1、前言

2、 完整代码

3、运行过程+结果

4、遇到的问题

5、小结


  • 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

1、前言

这周主要是使用VGG16模型,完成明星照片识别。

2、 完整代码

from keras.utils import losses_utils
from tensorflow import keras
from keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from keras.callbacks import ModelCheckpoint, EarlyStoppinggpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0], "GPU")# 导入数据
data_dir = "/Users/MsLiang/Documents/mySelf_project/pythonProject_pytorch/learn_demo/P_model/p06_vgg16/data"
data_dir = pathlib.Path(data_dir)# 查看数据
image_count = len(list(data_dir.glob('*/*.jpg')))
print("图片总数为:",image_count)  # 1800roses = list(data_dir.glob('Jennifer Lawrence/*.jpg'))
img = PIL.Image.open(str(roses[0]))
# img.show()  # 查看图片# 数据预处理
# 1、加载数据
batch_size = 32
img_height = 224
img_width = 224print('data_dir======>',data_dir)
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.1,subset="training",label_mode="categorical",seed=123,image_size=(img_height, img_width),batch_size=batch_size)"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.1,subset="validation",label_mode="categorical",seed=123,image_size=(img_height, img_width),batch_size=batch_size)class_names = train_ds.class_names
print(class_names)# 可视化数据
plt.figure(figsize=(20, 10))for images, labels in train_ds.take(1):for i in range(20):ax = plt.subplot(5, 10, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[np.argmax(labels[i])])plt.axis("off")
plt.show()# 再次检查数据
for image_batch, labels_batch in train_ds:print(image_batch.shape)   # (32, 224, 224, 3)print(labels_batch.shape)   # (32, 17)break# 配置数据集
AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)# 构建CNN网络
"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。
关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""model = models.Sequential([keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),  # 卷积层1,卷积核3*3layers.AveragePooling2D((2, 2)),  # 池化层1,2*2采样layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3layers.AveragePooling2D((2, 2)),  # 池化层2,2*2采样layers.Dropout(0.5),layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3layers.AveragePooling2D((2, 2)),layers.Dropout(0.5),layers.Conv2D(128, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3layers.Dropout(0.5),layers.Flatten(),  # Flatten层,连接卷积层与全连接层layers.Dense(128, activation='relu'),  # 全连接层,特征进一步提取layers.Dense(len(class_names))  # 输出层,输出预期结果
])# model.summary()  # 打印网络结构# 训练模型
# 1、设置动态学习率
# 设置初始学习率
initial_learning_rate = 1e-4lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate,decay_steps=60,      # 敲黑板!!!这里是指 steps,不是指epochsdecay_rate=0.96,     # lr经过一次衰减就会变成 decay_rate*lrstaircase=True)# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer,loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 损失函数
# 调用方式1:
model.compile(optimizer="adam",loss='categorical_crossentropy',metrics=['accuracy'])# 调用方式2:
# model.compile(optimizer="adam",
#               loss=tf.keras.losses.CategoricalCrossentropy(),
#               metrics=['accuracy'])# sparse_categorical_crossentropy(稀疏性多分类的对数损失函数)
# 调用方式1:
model.compile(optimizer="adam",loss='categorical_crossentropy',metrics=['accuracy'])
# ↑↑↑↑这里出现报错,需要将 sparse_categorical_crossentropy  改成→  categorical_crossentropy↑↑
# 调用方式2:
# model.compile(optimizer="adam",
#               loss=tf.keras.losses.SparseCategoricalCrossentropy(),
#               metrics=['accuracy'])# 函数原型
tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False,reduction=losses_utils.ReductionV2.AUTO,name='sparse_categorical_crossentropy'
)epochs = 100# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=1,save_best_only=True,save_weights_only=True)# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy',min_delta=0.001,patience=20,verbose=1)# 网络模型训练
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs,callbacks=[checkpointer, earlystopper])# 模型评估
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(len(loss))plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()# 指定图片进行预测
# 加载效果最好的模型权重
model.load_weights('best_model.h5')from PIL import Image
import numpy as npimg = Image.open("/Users/MsLiang/Documents/mySelf_project/pythonProject_pytorch/learn_demo/P_model/p06_vgg16/data/Jennifer Lawrence/003_963a3627.jpg")  #这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])img_array = tf.expand_dims(image, 0)predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

3、运行过程+结果

【查看图片】

【模型运行过程---第21epoch就早停了】

【训练精度、损失-----显然结果很很差】

4、遇到的问题

① 在运行代码的时候遇到报错:

错误:Graph execution error: Detected at node 'sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits' defined at (most recent call last):

出现这个问题来自我们使用的损失函数。

model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy'])

解决办法:

将损失函数里面的loss='sparse_categorical_crossentropy' 改成 'categorical_crossentropy',即可解决报错问题。

关于sparse_categorical_crossentropy和categorical_crossentropy的更多细节,详细参考这篇博文:交叉熵损失_多分类交叉熵损失函数-CSDN博客

5、小结

原始模型,跑出来效果很差很差!!!

(1)将原来的Adam优化器换成SGD优化器,效果如下:

(2)后续再补充,最近在写结课论文,有些忙。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/34726.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

奇点临近:人类与智能时代的未来

在信息爆炸的时代,我们每天都被海量的信息所淹没,如何才能在这个嘈杂的世界中找到真正有价值的信息?如何才能利用信息的力量,提升我们的认知水平,重塑我们的未来? 这些问题的答案,或许都能在雷…

无需高配置 怎么获得超流畅的VR体验?

传统VR眼镜在使用中存在一些显著不足,而实时渲染技术又是如何解决的?接下来与大家共同探讨遇到的问题以及实时渲染在VR眼镜中的实际应用。 1、高配置要求 目前主流VR一体机的眼镜需要较高配置才能运行普通VR内容,且受限于VR眼镜的算力限制&…

【小程序】聊天功能

文章目录 聊天功能实现功能实现思路后端前端效果展示 聊天功能 实现功能 要实现一个聊天机器人,它能够解答用户疑问,并且能够识别到用户聊天的主题,涉及到饮食方面时,会自动决定是否要去数据库中读取用户的相关喜好信息&#xf…

【ARM】MDK自动备份源文件

【更多软件使用问题请点击亿道电子官方网站】 1、 文档目标 解决MDK在编写文档的时候需要找回上一版代码的问题。 2、 问题场景 目前大部分情况下对于源代码的管理都是使用的Git等第三方的代码管理平台。这样的第三方代码管理平台都是针对与代码的版本更新进行管理。对于本地…

2024年6月上半月30篇大语言模型的论文推荐

大语言模型(LLMs)在近年来取得了快速发展。本文总结了2024年6月上半月发布的一些最重要的LLM论文,可以让你及时了解最新进展。 LLM进展与基准测试 1、WildBench: Benchmarking LLMs with Challenging Tasks from Real Users in the Wild Wi…

数字心动+华为运动健康服务 使用体验指导

一、应用介绍 “数字心动”是一个体育生态平台APP,践行“体育大健康娱乐数字营销”模式,打造深度融合体育平台。APP集跑步运动记录、赛事活动报名、成绩/大众等级证书查询等多功能于一体,采取“线上线下”模式,结合协会、行业、品…

【CT】LeetCode手撕—56. 合并区间

目录 题目1- 思路2- 实现⭐56. 合并区间——题解思路 3- ACM 实现 题目 原题连接:56. 合并区间 1- 思路 模式识别:合并区间 ——> 数组先排序 思路 1.先对数组内容进行排序 ——> 定义 left、right 根据排序后的结果,更新 right2.遍…

高性能的多媒体播放器(提供补帧功能)

一、简介 1、一款高性能的多媒体播放器,支持几乎所有主流和部分罕见的音视频格式。无需额外安装coder插件,即可顺利播放各种媒体文件。此外,它还提供补帧功能,显著提升了视频播放的流畅性和视觉效果 二、下载 1、文末有下载链接,不…

Shopee API接口:一键获取商品买家评论数据,赋能电商运营新智慧

一、核心功能介绍——一键获取商品买家评论数据 在电商领域,买家评论是反映商品质量和市场反馈的重要指标。为了帮助商家更好地了解买家需求,优化产品和服务,Shopee接口特别推出了获取商品买家评论数据的功能。以下是该功能的核心介绍&#…

数据库设计文档编写

PS:建议使用第三种方法 方法1:使用 Navicat 生成数据库设计文档 效果 先看简单的效果图,如果效果合适,大家在进行测试使用,不合适直接撤退,也不浪费时间。 随后在docx文档中生成目标字段的表格&#xf…

人工智能赋能数据资产分析:借助先进的人工智能技术,优化数据处理流程,显著提升数据资产分析的准确性和效率,为企业决策提供强大支撑,推动业务快速发展

一、引言 在数字化浪潮席卷全球的今天,数据已经成为企业最宝贵的资产之一。如何有效地分析这些数据,挖掘其中的价值,为企业决策提供有力支持,是每个企业都面临的挑战。近年来,人工智能技术的快速发展,为数…

【面试干货】Java中的++操作符与线程安全性

【面试干货】Java中的操作符与线程安全性 1、什么是线程安全性?2、 操作符的工作原理3、 操作符与线程安全性4、如何确保线程安全?5、 结论 💖The Begin💖点点关注,收藏不迷路💖 在Java编程中,操…

Kendryte K210 固件烧录

本章将为读者介绍 Kendryte K210 的固件烧录,以及 Kendryte K210 外部 NOR Flash 的空间 分布。 本章分为如下几个小节: 6.1 外部 NOR Flash 的空间分布 6.2 Ubuntu 下的固件烧录 6.3 Windows 下的固件烧录 外部 NOR Flash 的空间分布 Kendryte K210 的…

mac 常用工具快捷键集合

一、vim 快捷键 1、移动光标 h j k l 左 下 上 右 箭头上 上移一行 箭头下 下移一行 0 跳至行首,不管有无缩进,就是跳到第0个字符 ^ 跳至行首的第一个字符 $ 跳至行尾 gg 跳至文首 G 调至文尾 5gg/5G 调至第5行w 跳到下一个字首,按标点或…

51单片机最火型号大比拼:性能、应用与选型指南

51单片机作为经典的微控制器架构,凭借其易于学习、价格低廉、应用广泛等优势,一直活跃在嵌入式开发领域。面对市场上琳琅满目的51单片机型号,初学者和开发者常常感到眼花缭乱。本文将对几款最火的51单片机型号进行深度剖析,从性能…

蓝牙透传芯片TD5322A,低功耗ble芯片,蓝牙电表通信方案介绍—拓达半导体

蓝牙透传芯片TD5322A芯片是一款支持蓝牙BLE的纯数传芯片, 蓝牙5.1版本。芯片的亮点在尺寸小( SOP-8封装)、主从切换、性能强、 性价比高。以及简单明了的透传和串口 AT 控制功能。大大降低了嵌入蓝牙在其它产品的开发难度和成本。 蓝牙透传芯…

中国 AGI 市场—4543 亿市场下的新机会

前言 我们正站在一个全新智能纪元的路口,围绕通用人工智能(AGI),在学术界、科技界、产业界的讨论中,一部分 AGI 的神秘面纱已被揭开,但这面纱之后还有更多的未知等待着我们。 InfoQ 研究中心在此背景下&a…

LabVIEW高精度电能质量监测系统

LabVIEW和研华采集卡的高精度电能质量监测系统利用虚拟仪器技术,实时监测电能质量的关键指标,如三相电压、频率和谐波。通过提高监测精度和效率,改善电网的电能质量。系 一、系统背景 电能作为现代社会的关键能源,其质量直接影响…

Casaos之qittorrent设置(没有账号密码)

点击安装只有没有账号密码,只能从运行日志中找密码: # 查看container docker ps -a # 查看container日志 docker logs ae15cb90afbd 进入系统 最下方,保存。

改网络ip地址有什么用

在数字化时代,网络IP地址是每个网络设备和终端在互联网上的唯一标识符。然而,有时出于安全、隐私或网络管理的需要,我们可能需要更改网络IP地址。例如很多小伙伴会选择使用虎观代理IP更改电脑或手机设备上的网络IP地址,那么&#…