第T6周:Tensorflow实现好莱坞明星识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架:

(二)具体步骤
1.查询TF版本和设置使用GPU
from tensorflow import keras  
from tensorflow.keras import models, layers  
import os, PIL, pathlib  
import matplotlib.pyplot as plt  
import tensorflow as tf  
import numpy as np  # 查询tensorflow版本  
print("Tensorflow Version:", tf.__version__)  # 设置使用GPU  
gpus = tf.config.list_physical_devices("GPU")  
print(gpus)  if gpus:  gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPU  tf.config.experimental.set_memory_growth(gpu0, True)    # 设置GPU显存按需使用  tf.config.set_visible_devices([gpu0], "GPU")
Tensorflow Version: 2.10.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
2.导入数据

目录结构如下:
image.png

# 导入数据  
data_dir = './datasets/hollywood/'  
data_dir = pathlib.Path(data_dir)  # 查看数据  
image_count = len(list(data_dir.glob('*/*.jpg')))  
print('图片总数为:', image_count)
图片总数为: 1800
# 看看Brad Pitt(布拉德·皮特)的图片  
brad_pitt = list(data_dir.glob('Brad Pitt/*.jpg'))  
im = PIL.Image.open(brad_pitt[0])  
im.show()

image.png

3.数据预处理
# 数据预处理  
batch_size = 32  
image_height = 224  
image_width = 224  train_ds = tf.keras.preprocessing.image_dataset_from_directory(  directory=data_dir,  validation_split=0.1,  subset="training",  label_mode="categorical",  seed=123,  image_size=(image_height, image_width),  batch_size=batch_size  
)  
val_ds = tf.keras.preprocessing.image_dataset_from_directory(  directory=data_dir,  validation_split=0.1,  subset="validation",  label_mode="categorical",  seed=123,  image_size=(image_height, image_width),  batch_size=batch_size  
)
Found 1800 files belonging to 17 classes.
Using 1620 files for training.
Found 1800 files belonging to 17 classes.
Using 180 files for validation.
# 输出数据集的标签(标签按字母顺序对应于目录名称)  
class_names = train_ds.class_names  
print(class_names)
['Angelina Jolie', 'Brad Pitt', 'Denzel Washington', 'Hugh Jackman', 'Jennifer Lawrence', 'Johnny Depp', 'Kate Winslet', 'Leonardo DiCaprio', 'Megan Fox', 'Natalie Portman', 'Nicole Kidman', 'Robert Downey Jr', 'Sandra Bullock', 'Scarlett Johansson', 'Tom Cruise', 'Tom Hanks', 'Will Smith']
# 可视化数据  
plt.figure(figsize=(20, 10))  
for images, labels in train_ds.take(1):  for i in range(20):  ax = plt.subplot(5, 10, i + 1)  plt.imshow(images[i].numpy().astype("uint8"))  plt.title(class_names[np.argmax(labels[i])])  plt.axis("off")  
plt.show()

image.png

# 再次检查数据  
for image_batch, labels_batch in train_ds:  print(image_batch.shape)  print(labels_batch.shape)  break
(32, 224, 224, 3)
(32, 17)
4.配置数据集
# 配置数据集  
AUTOTUNE = tf.data.AUTOTUNE  
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)  
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
5.构建CNN网络

image.png

# 构建CNN网络  
model = models.Sequential([  layers.experimental.preprocessing.Rescaling(1./255, input_shape=(image_height, image_width, 3)),  layers.Conv2D(16, (3, 3), 1, activation='relu', input_shape=(image_height, image_width)),  layers.AveragePooling2D((2, 2)),  layers.Conv2D(32, (3, 3), activation='relu'),  layers.AveragePooling2D((2, 2)),  layers.Dropout(0.5),  layers.Conv2D(64, (3, 3), activation='relu'),  layers.AveragePooling2D((2, 2)),  layers.Dropout(0.5),  layers.Conv2D(128, (3, 3), activation='relu'),  layers.Dropout(0.5),  layers.Flatten(),  layers.Dense(128, activation='relu'),  layers.Dense(len(class_names))  
])  
print(model.summary())

image.png

6.训练模型
# 训练模型  
# 1. 设置动态学习率  
initial_learning_rate = 1e-4    # 设置初始学习率  
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(  # 对学习率使用指数衰减  initial_learning_rate=initial_learning_rate,  decay_steps=60,  # 学习率衰减的步数。在经过 decay_steps 步后,学习率将按照指数函数衰减。  decay_rate=0.96,  # 学习率的衰减率。它决定了学习率如何衰减。通常,取值在 0 到 1 之间。  staircase=True  
)  
# 2. 将指数衰减学习率送入优化器  
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)  # 3. 设置损失函数  
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)    # 多分类的对数损失函数  model.compile(  optimizer=optimizer,  loss=loss,  metrics=['accuracy']  
)  epochs = 100  # 保存模型最佳参数  
checkpointer = ModelCheckpoint(  filepath='./models/hollywood-best.h5',  monitor='val_accuracy',  verbose=1,  save_best_only=True,  save_weights_only=True  
)  # 设置早停  
earlystopper = EarlyStopping(  monitor='val_accuracy',  min_delta=0.001,  patience=20,  verbose=1  
)  # 训练  
history = model.fit(  x=train_ds,  validation_data=val_ds,  epochs=epochs,  callbacks=[checkpointer, earlystopper]  
)  # 模型评估  acc = history.history['accuracy']  
val_acc = history.history['val_accuracy']  loss = history.history['loss']  
val_loss = history.history['val_loss']  epochs_range = range(len(loss))  plt.figure(figsize=(12, 4))  
plt.subplot(1, 2, 1)  
plt.plot(epochs_range, acc, label="Training Accuracy")  
plt.plot(epochs_range, val_acc, label="Validation Accuracy")  
plt.legend(loc="lower right")  
plt.title("Training and validation accuracy")  plt.subplot(1, 2, 2)  
plt.plot(epochs_range, loss, label="Loss Accuracy")  
plt.plot(epochs_range, val_loss, label="Validation Loss")  
plt.legend(loc="lower right")  
plt.title("Training and validation Loss")  plt.show()

image.png

7.真实图片预测
# 加载效果最好的模型权重  
model.load_weights('./models/hollywood-best.h5')  
# img = PIL.Image.open("./datasets/hollywood/Jennifer Lawrence/021_2eaafb9f.jpg")  
# image = tf.image.resize(img, [image_height, image_width])  image = tf.keras.utils.load_img('./datasets/hollywood/Jennifer Lawrence/021_2eaafb9f.jpg', target_size=(image_height, image_width))  
image_array = tf.keras.utils.img_to_array(image)  # 将PIL对象转换成numpy数组  
img_array = tf.expand_dims(image_array, 0)    # /255.0  # 记得做归一化处理(与训练集处理方式保持一致)  predictions =model.predict(img_array)  print("预测结果为:", class_names[np.argmax(predictions)])
预测结果为: Jennifer Lawrence
(三)总结

实际预测会发现结果有很大概率会出错的。其次看模型评估,其实训练的结果是比较差的, val_accuracy最好只有40%左右。如何提高准确率,还需要研究,后续出实践总结。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/57650.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring6框架搭建(自用)

一、什么是Spring 众所不周知,Spring就是爪哇人的春天,但是在框架程序设计之前都绕不开javaWeb 1.javaWeb框架发展史 1、ServletJSPJavaBean(跳转页面、业务逻辑判断、数据库查询) 2、MVC三层架构(M Model pojo(User)V-view(USP)C-(controller-servl…

linux-UART

参考博客 https://blog.csdn.net/m0_38106923/article/details/126024970?sharetypeblog&shareId126024970&sharereferAPP&sharesourceweixin_40933496&sharefromlink 1.串口 UART的全称是Universal Asynchronous Receiver and Transmitter,即异步…

大数据治理:策略、技术与挑战

随着信息技术的飞速发展,大数据已经成为现代企业运营和决策的重要基础。然而,大数据的复杂性、多样性和规模性给数据管理带来了前所未有的挑战。因此,大数据治理应运而生,成为确保数据质量、合规性、安全性和可用性的关键手段。本…

vue插件清除 所有console.log()

一、作用 1、提升性能console.log() 语句会消耗一定的性能,尤其是在频繁调用的情况下。在生产环境中移除这些语句可以提高应用的运行效率。 2、减少信息泄露console.log() 可以输出敏感信息(如用户数据、API 响应等)。在生产环境中&#xf…

DAY15|二叉树Part03|LeetCode: 513.找树左下角的值、112. 路径总和、106. 从中序与后序遍历序列构造二叉树

LeetCode: 513.找树左下角的值 力扣代码链接 文字讲解:LeetCode: 513.找树左下角的值 视频讲解:怎么找二叉树的左下角? 递归中又带回溯了,怎么办? 基本思路 对题目进行一下分析,要找二叉树最底层最左边节点…

【记录】Excel 公式|(一)根据某列内容和关键词列,自动生成当前行的关键词分类名称

文章目录 引言公式解析应用场景数据准备公式应用结果分析 结论扩展应用注意事项总结后续学习结语 我的 Excel 版本:2021 引言 在当今数据驱动的世界中,高效的数据处理和分类对于企业和个人来说至关重要。Excel 作为最常用的数据处理工具之一&#xff0c…

【ROS2】hbm_img_msgs/msg/HbmMsg1080P 转 opencv cv::Mat

1、简述 在ROS2中处理图像时,经常会用的OpenCV,因此常常会涉及到ROS2话题和cv::Mat的转换 ROS2内置消息 sensor_msgs::msg::Image 可以使用 cv_bridge 转换成 OpenCV的 cv::Mat。 参见博客:【ROS2】cv_bridge:ROS图像消息和OpenCV的cv::Mat格式转换库 在使用地平线X3派时…

ClkLog企业版(CDP)预售开启,更有鸿蒙SDK前来助力

新版本发布 ClkLog在上线近1年后,获得了客户的一致肯定与好评,并收到了不少客户对功能需求的反馈。根据客户的反馈,我们在今年三季度对ClkLog的版本进行了重新的规划与调整,简化了原有的版本类型,方便客户进行选择。 与…

C++:set和map的使用

目录 序列式容器和关联式容器 set set类的介绍 构造和迭代器 增删查 insert find和erase erase迭代器失效 lower_bound与upper_bound multiset和set的区别 map map类的介绍 pair类型介绍 构造和迭代器 增删查 map数据修改:重载operator[] multimap…

Unix和Linux系统中的文件权限

详细解释Unix和Linux系统中的文件权限设置以及如何使用chmod命令来修改这些权限。 文件权限的详细解释 在Unix和Linux系统中,文件权限是控制谁可以访问和操作文件或目录的重要机制。权限分为三类:所有者(owner)、所属组&#xf…

el-tree展开子节点后宽度没有撑开,溢出内容隐藏了,不显示横向滚动条

html结构如下 <div class"tree-div"><el-tree><template #default"{ node, data }"><div class"node-item">...</div></template></el-tree></div> css代码(scss) .tree-div {width: 300px;…

android定时器循环实现轮播图

说明&#xff1a; android定时器加for循环实现轮播图 效果&#xff1a; step1: package com.example.iosdialogdemo;import android.os.Bundle; import android.os.Handler; import android.widget.ImageView; import android.widget.TextView;import androidx.appcompat.ap…

ChatGPT能预测时间序列?基于大模型的时间序列预测中的迭代事件推理_chatgpt能预测时间序列

引言 时间序列预测&#xff08;Time Series Forecasting&#xff09;是支撑经济、基础设施和社会各领域决策的关键技术。然而&#xff0c;传统的预测方法在面对由外部随机事件引起的突发性变化或异常时&#xff0c;往往表现出局限性。这些方法通常依赖于历史数据的模式识别&am…

计算机网络-传输层提供的服务

传输层在协议栈中的位置 我们可以给应用层的这些应用程序提供我们想要传输的数据&#xff0c;比如说我们想用微信传一张图片&#xff0c;或者想用QQ发一串字符。那这些数据是由我们用户直接提供的&#xff0c;那么我们的数据交给了应用层的某一个进程之后。这个进程可能会在我们…

将Notepad++添加到右键菜单【一招实现】

一键添加注册表 复制以下代码保存为 Notepad.reg&#xff0c;将红框内路径修改为自己电脑的“Notepad.exe路径”后&#xff0c;再双击运行即可。 Windows Registry Editor Version 5.00[HKEY_CLASSES_ROOT\*\shell\NotePad] "Notepad" "Icon""D:\\N…

vue3二次封装UI组件

直接上代码 <template><el-uploadclass"lth_upload":action"${baseUrl}/file/upload":headers"uploadHeader"v-bind"$attr"><template v-for"(_, key) in $slots" #[key]"valueData"><slot…

存储引擎技术进化

B-tree 目前支撑着数据库产业的半壁江山。 50 年来不变而且人们还没有改变它的意向 鉴定一个算法的优劣&#xff0c;有一个学派叫 IO复杂度分析 &#xff0c;简单推演真假便知。 下面就用此法分析下 B-tree(traditional b-tree) 的 IO 复杂度&#xff0c;对读、写 IO 一目了…

vscode | 开发神器vscode快捷键删除和恢复

目录 快捷键不好使了删除快捷键恢复删除的快捷键 在vscode使用的过程中&#xff0c;随着我们自身需求的不断变化&#xff0c;安装的插件将会持续增长&#xff0c;那么随之而来的就会带来一个问题&#xff1a;插件的快捷键重复。快捷键重复导致的问题就是快捷键不好使了&#xf…

mysql如何发现慢查询sql

mysql如何发现慢查询sql tail -n 10 /data/mysql/mysql-slow.log

vm.max_map_count 表示啥意思啊?通俗易懂点,有单位么?262144表示啥意思?

背景&#xff1a;ERROR: [1] bootstrap checks failed. You must address the points described in the following [1] lines-CSDN博客 vm.max_map_count 是一个 Linux 内核参数&#xff0c;用于限制一个进程可以拥有的最大内存映射区域数量。内存映射&#xff08;Memory Mapp…