Tensorflow入门实战 T06-Vgg16 明星识别

目录

1、前言

2、 完整代码

3、运行过程+结果

4、遇到的问题

5、小结


  • 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

1、前言

这周主要是使用VGG16模型,完成明星照片识别。

2、 完整代码

from keras.utils import losses_utils
from tensorflow import keras
from keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from keras.callbacks import ModelCheckpoint, EarlyStoppinggpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0], "GPU")# 导入数据
data_dir = "/Users/MsLiang/Documents/mySelf_project/pythonProject_pytorch/learn_demo/P_model/p06_vgg16/data"
data_dir = pathlib.Path(data_dir)# 查看数据
image_count = len(list(data_dir.glob('*/*.jpg')))
print("图片总数为:",image_count)  # 1800roses = list(data_dir.glob('Jennifer Lawrence/*.jpg'))
img = PIL.Image.open(str(roses[0]))
# img.show()  # 查看图片# 数据预处理
# 1、加载数据
batch_size = 32
img_height = 224
img_width = 224print('data_dir======>',data_dir)
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.1,subset="training",label_mode="categorical",seed=123,image_size=(img_height, img_width),batch_size=batch_size)"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.1,subset="validation",label_mode="categorical",seed=123,image_size=(img_height, img_width),batch_size=batch_size)class_names = train_ds.class_names
print(class_names)# 可视化数据
plt.figure(figsize=(20, 10))for images, labels in train_ds.take(1):for i in range(20):ax = plt.subplot(5, 10, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[np.argmax(labels[i])])plt.axis("off")
plt.show()# 再次检查数据
for image_batch, labels_batch in train_ds:print(image_batch.shape)   # (32, 224, 224, 3)print(labels_batch.shape)   # (32, 17)break# 配置数据集
AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)# 构建CNN网络
"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。
关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""model = models.Sequential([keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),  # 卷积层1,卷积核3*3layers.AveragePooling2D((2, 2)),  # 池化层1,2*2采样layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3layers.AveragePooling2D((2, 2)),  # 池化层2,2*2采样layers.Dropout(0.5),layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3layers.AveragePooling2D((2, 2)),layers.Dropout(0.5),layers.Conv2D(128, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3layers.Dropout(0.5),layers.Flatten(),  # Flatten层,连接卷积层与全连接层layers.Dense(128, activation='relu'),  # 全连接层,特征进一步提取layers.Dense(len(class_names))  # 输出层,输出预期结果
])# model.summary()  # 打印网络结构# 训练模型
# 1、设置动态学习率
# 设置初始学习率
initial_learning_rate = 1e-4lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate,decay_steps=60,      # 敲黑板!!!这里是指 steps,不是指epochsdecay_rate=0.96,     # lr经过一次衰减就会变成 decay_rate*lrstaircase=True)# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer,loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 损失函数
# 调用方式1:
model.compile(optimizer="adam",loss='categorical_crossentropy',metrics=['accuracy'])# 调用方式2:
# model.compile(optimizer="adam",
#               loss=tf.keras.losses.CategoricalCrossentropy(),
#               metrics=['accuracy'])# sparse_categorical_crossentropy(稀疏性多分类的对数损失函数)
# 调用方式1:
model.compile(optimizer="adam",loss='categorical_crossentropy',metrics=['accuracy'])
# ↑↑↑↑这里出现报错,需要将 sparse_categorical_crossentropy  改成→  categorical_crossentropy↑↑
# 调用方式2:
# model.compile(optimizer="adam",
#               loss=tf.keras.losses.SparseCategoricalCrossentropy(),
#               metrics=['accuracy'])# 函数原型
tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False,reduction=losses_utils.ReductionV2.AUTO,name='sparse_categorical_crossentropy'
)epochs = 100# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=1,save_best_only=True,save_weights_only=True)# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy',min_delta=0.001,patience=20,verbose=1)# 网络模型训练
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs,callbacks=[checkpointer, earlystopper])# 模型评估
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(len(loss))plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()# 指定图片进行预测
# 加载效果最好的模型权重
model.load_weights('best_model.h5')from PIL import Image
import numpy as npimg = Image.open("/Users/MsLiang/Documents/mySelf_project/pythonProject_pytorch/learn_demo/P_model/p06_vgg16/data/Jennifer Lawrence/003_963a3627.jpg")  #这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])img_array = tf.expand_dims(image, 0)predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

3、运行过程+结果

【查看图片】

【模型运行过程---第21epoch就早停了】

【训练精度、损失-----显然结果很很差】

4、遇到的问题

① 在运行代码的时候遇到报错:

错误:Graph execution error: Detected at node 'sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits' defined at (most recent call last):

出现这个问题来自我们使用的损失函数。

model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy'])

解决办法:

将损失函数里面的loss='sparse_categorical_crossentropy' 改成 'categorical_crossentropy',即可解决报错问题。

关于sparse_categorical_crossentropy和categorical_crossentropy的更多细节,详细参考这篇博文:交叉熵损失_多分类交叉熵损失函数-CSDN博客

5、小结

原始模型,跑出来效果很差很差!!!

(1)将原来的Adam优化器换成SGD优化器,效果如下:

(2)后续再补充,最近在写结课论文,有些忙。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/34726.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决uni-popup禁止滚动穿透

解决uni-popup禁止滚动穿透 uni-popup弹窗内内容有滚轮会带动外部页面滚动 在弹框内容页面中修改样式 .list {overflow: auto;height: calc(100% - 280rpx);overscroll-behavior: none;}overscroll-behavior: none; 表示禁止滚动溢出 发现移动遮罩层底部也会跟着变动 则在全…

奇点临近:人类与智能时代的未来

在信息爆炸的时代,我们每天都被海量的信息所淹没,如何才能在这个嘈杂的世界中找到真正有价值的信息?如何才能利用信息的力量,提升我们的认知水平,重塑我们的未来? 这些问题的答案,或许都能在雷…

无需高配置 怎么获得超流畅的VR体验?

传统VR眼镜在使用中存在一些显著不足,而实时渲染技术又是如何解决的?接下来与大家共同探讨遇到的问题以及实时渲染在VR眼镜中的实际应用。 1、高配置要求 目前主流VR一体机的眼镜需要较高配置才能运行普通VR内容,且受限于VR眼镜的算力限制&…

【小程序】聊天功能

文章目录 聊天功能实现功能实现思路后端前端效果展示 聊天功能 实现功能 要实现一个聊天机器人,它能够解答用户疑问,并且能够识别到用户聊天的主题,涉及到饮食方面时,会自动决定是否要去数据库中读取用户的相关喜好信息&#xf…

【ARM】MDK自动备份源文件

【更多软件使用问题请点击亿道电子官方网站】 1、 文档目标 解决MDK在编写文档的时候需要找回上一版代码的问题。 2、 问题场景 目前大部分情况下对于源代码的管理都是使用的Git等第三方的代码管理平台。这样的第三方代码管理平台都是针对与代码的版本更新进行管理。对于本地…

2024年6月上半月30篇大语言模型的论文推荐

大语言模型(LLMs)在近年来取得了快速发展。本文总结了2024年6月上半月发布的一些最重要的LLM论文,可以让你及时了解最新进展。 LLM进展与基准测试 1、WildBench: Benchmarking LLMs with Challenging Tasks from Real Users in the Wild Wi…

数字心动+华为运动健康服务 使用体验指导

一、应用介绍 “数字心动”是一个体育生态平台APP,践行“体育大健康娱乐数字营销”模式,打造深度融合体育平台。APP集跑步运动记录、赛事活动报名、成绩/大众等级证书查询等多功能于一体,采取“线上线下”模式,结合协会、行业、品…

【CT】LeetCode手撕—56. 合并区间

目录 题目1- 思路2- 实现⭐56. 合并区间——题解思路 3- ACM 实现 题目 原题连接:56. 合并区间 1- 思路 模式识别:合并区间 ——> 数组先排序 思路 1.先对数组内容进行排序 ——> 定义 left、right 根据排序后的结果,更新 right2.遍…

高性能的多媒体播放器(提供补帧功能)

一、简介 1、一款高性能的多媒体播放器,支持几乎所有主流和部分罕见的音视频格式。无需额外安装coder插件,即可顺利播放各种媒体文件。此外,它还提供补帧功能,显著提升了视频播放的流畅性和视觉效果 二、下载 1、文末有下载链接,不…

Shopee API接口:一键获取商品买家评论数据,赋能电商运营新智慧

一、核心功能介绍——一键获取商品买家评论数据 在电商领域,买家评论是反映商品质量和市场反馈的重要指标。为了帮助商家更好地了解买家需求,优化产品和服务,Shopee接口特别推出了获取商品买家评论数据的功能。以下是该功能的核心介绍&#…

数据库设计文档编写

PS:建议使用第三种方法 方法1:使用 Navicat 生成数据库设计文档 效果 先看简单的效果图,如果效果合适,大家在进行测试使用,不合适直接撤退,也不浪费时间。 随后在docx文档中生成目标字段的表格&#xf…

人工智能赋能数据资产分析:借助先进的人工智能技术,优化数据处理流程,显著提升数据资产分析的准确性和效率,为企业决策提供强大支撑,推动业务快速发展

一、引言 在数字化浪潮席卷全球的今天,数据已经成为企业最宝贵的资产之一。如何有效地分析这些数据,挖掘其中的价值,为企业决策提供有力支持,是每个企业都面临的挑战。近年来,人工智能技术的快速发展,为数…

【面试干货】Java中的++操作符与线程安全性

【面试干货】Java中的操作符与线程安全性 1、什么是线程安全性?2、 操作符的工作原理3、 操作符与线程安全性4、如何确保线程安全?5、 结论 💖The Begin💖点点关注,收藏不迷路💖 在Java编程中,操…

Java 8新特性全面解读

Java 8新特性全面解读 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! Java 8引入了许多令人兴奋的新特性,为开发者提供了更强大的工具和更高效的编…

非遗!四川省21市非遗大师工作室申报认定条件程序和认定补贴经费支持(管理办法)

第一章总则 第一条贯彻落实中共中央办公厅、国务院办公厅《关于进一步加强非物质文化遗产保护工作的意见》(厅字〔2021〕31号)、四川省文化和旅游厅等12部门《关于进一步加强非物质文化遗产保护工作的实施意见》(川文旅发〔2022〕25号&#…

SpringCloud是什么?它解决了什么问题?

Spring Cloud是一个基于Spring Boot提供的一系列框架的集合,它利用Spring Boot的开发便利性简化了分布式系统(例如微服务架构下的应用程序)的开发。Spring Cloud为开发者提供了在分布式系统中快速实现和采用模式(pattern&#xff…

Kendryte K210 固件烧录

本章将为读者介绍 Kendryte K210 的固件烧录,以及 Kendryte K210 外部 NOR Flash 的空间 分布。 本章分为如下几个小节: 6.1 外部 NOR Flash 的空间分布 6.2 Ubuntu 下的固件烧录 6.3 Windows 下的固件烧录 外部 NOR Flash 的空间分布 Kendryte K210 的…

Python信息处理问题精选及参考答案

目录 使用MATLAB或Python实现一个简单的FIR滤波器,并测试其性能。 介绍一个常用的信号处理库(如SciPy, OpenCV等),并演示其在特定问题上的应用。 使用Simulink搭建一个数字通信系统的模型,并分析其性能指标。 实现…

mac 常用工具快捷键集合

一、vim 快捷键 1、移动光标 h j k l 左 下 上 右 箭头上 上移一行 箭头下 下移一行 0 跳至行首,不管有无缩进,就是跳到第0个字符 ^ 跳至行首的第一个字符 $ 跳至行尾 gg 跳至文首 G 调至文尾 5gg/5G 调至第5行w 跳到下一个字首,按标点或…

51单片机最火型号大比拼:性能、应用与选型指南

51单片机作为经典的微控制器架构,凭借其易于学习、价格低廉、应用广泛等优势,一直活跃在嵌入式开发领域。面对市场上琳琅满目的51单片机型号,初学者和开发者常常感到眼花缭乱。本文将对几款最火的51单片机型号进行深度剖析,从性能…