运动鞋品牌识别

一、前期工作


1. 设置GPU

from tensorflow       import keras
from tensorflow.keras import layers,models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow        as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")gpus


如果使用的是CPU可以忽略这步



2. 导入数据

data_dir = "./46-data/"data_dir = pathlib.Path(data_dir)




3. 查看数据

 

image_count = len(list(data_dir.glob('*/*/*.jpg')))print("图片总数为:",image_count)


 

图片总数为: 578
roses = list(data_dir.glob('train/nike/*.jpg'))
PIL.Image.open(str(roses[0]))

YAIRI

output_11_0.png



二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset中
●tf.keras.preprocessing.image_dataset_from_directory():是 TensorFlow 的 Keras 模块中的一个函数,用于从目录中创建一个图像数据集(dataset)。这个函数可以以更方便的方式加载图像数据,用于训练和评估神经网络模型。


测试集与验证集的关系:

1验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
2但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。
3因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集

batch_size = 32
img_height = 224
img_width = 224



如果准备尝试 categorical_crossentropy损失函数,下面的代码遇到变动哈,变动细节将在下一周博客内公布。
 

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory("./46-data/train/",seed=123,image_size=(img_height, img_width),batch_size=batch_size)

Found 502 files belonging to 2 classes.
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory("./46-data/test/",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
Found 76 files belonging to 2 classes.




我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。
 

class_names = train_ds.class_names
print(class_names)
['adidas', 'nike']



2. 可视化数据

 

plt.figure(figsize=(20, 10))for images, labels in train_ds.take(1):for i in range(20):ax = plt.subplot(5, 10, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

output_22_0.png



3. 再次检查数据

 

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(32, 224, 224, 3)
(32,)


●Image_batch是形状的张量(32,224,224,3)。这是一批形状224x224x3的32张图片(最后一维指的是彩色通道RGB)。
●Label_batch是形状(32,)的张量,这些标签对应32张图片

4. 配置数据集

●shuffle() :打乱数据,关于此函数的详细介绍可以参考:数据集shuffle方法中buffer_size的理解 - 知乎
●prefetch() :预取数据,加速运行

prefetch()功能详细介绍:CPU 正在准备数据时,加速器处于空闲状态。相反,当加速器正在训练模型时,CPU 处于空闲状态。因此,训练所用的时间是 CPU 预处理时间和加速器训练时间的总和。prefetch()将训练步骤的预处理和模型执行过程重叠到一起。当加速器正在执行第 N 个训练步时,CPU 正在准备第 N+1 步的数据。这样做不仅可以最大限度地缩短训练的单步用时(而不是总用时),而且可以缩短提取和转换数据所需的时间。如果不使用prefetch(),CPU 和 GPU/TPU 在大部分时间都处于空闲状态:

image.png


使用prefetch()可显著减少空闲时间:

image.png


●cache() :将数据集缓存到内存当中,加速运行

AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)



三、构建CNN网络

卷积神经网络(CNN)的输入是张量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。不需要输入batch size。color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道(color channel)。在此示例中,我们的 CNN 输入的形状是 (224, 224, 3)即彩色图像。我们需要在声明第一层时将形状赋值给参数input_shape。

网络结构图(可单击放大查看):

image.png

"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。
关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""model = models.Sequential([layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 卷积层1,卷积核3*3  layers.AveragePooling2D((2, 2)),               # 池化层1,2*2采样layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3layers.AveragePooling2D((2, 2)),               # 池化层2,2*2采样layers.Dropout(0.3),  layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3layers.Dropout(0.3),  layers.Flatten(),                       # Flatten层,连接卷积层与全连接层layers.Dense(128, activation='relu'),   # 全连接层,特征进一步提取layers.Dense(len(class_names))               # 输出层,输出预期结果
])model.summary()  # 打印网络结构
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
rescaling (Rescaling)        (None, 224, 224, 3)       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 222, 222, 16)      448       
_________________________________________________________________
average_pooling2d (AveragePo (None, 111, 111, 16)      0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 109, 109, 32)      4640      
_________________________________________________________________
average_pooling2d_1 (Average (None, 54, 54, 32)        0         
_________________________________________________________________
dropout (Dropout)            (None, 54, 54, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 52, 52, 64)        18496     
_________________________________________________________________
dropout_1 (Dropout)          (None, 52, 52, 64)        0         
_________________________________________________________________
flatten (Flatten)            (None, 173056)            0         
_________________________________________________________________
dense (Dense)                (None, 128)               22151296  
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 258       
=================================================================
Total params: 22,175,138
Trainable params: 22,175,138
Non-trainable params: 0
_________________________________________________________________




四、训练模型

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

●损失函数(loss):用于衡量模型在训练期间的准确率。
●优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
●指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。

1.设置动态学习率

📮 ExponentialDecay函数:
tf.keras.optimizers.schedules.ExponentialDecay是 TensorFlow 中的一个学习率衰减策略,用于在训练神经网络时动态地降低学习率。学习率衰减是一种常用的技巧,可以帮助优化算法更有效地收敛到全局最小值,从而提高模型的性能。

🔎 主要参数:
●initial_learning_rate(初始学习率):初始学习率大小。
●decay_steps(衰减步数):学习率衰减的步数。在经过 decay_steps 步后,学习率将按照指数函数衰减。例如,如果 decay_steps 设置为 10,则每10步衰减一次。
●decay_rate(衰减率):学习率的衰减率。它决定了学习率如何衰减。通常,取值在 0 到 1 之间。
●staircase(阶梯式衰减):一个布尔值,控制学习率的衰减方式。如果设置为 True,则学习率在每个 decay_steps 步之后直接减小,形成阶梯状下降。如果设置为 False,则学习率将连续衰减。

# 设置初始学习率
initial_learning_rate = 0.1lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps=10,      # 敲黑板!!!这里是指 steps,不是指epochsdecay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lrstaircase=True)# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])



注:这里设置的动态学习率为:指数衰减型(ExponentialDecay)。在每一个epoch开始前,学习率(learning_rate)都将会重置为初始学习率(initial_learning_rate),然后再重新开始衰减。计算公式如下:

learning_rate = initial_learning_rate * decay_rate ^ (step / decay_steps)

学习率大与学习率小的优缺点分析:

学习率大

● 优点:
○1、加快学习速率。
○2、有助于跳出局部最优值。
● 缺点:
○1、导致模型训练不收敛。
○2、单单使用大学习率容易导致模型不精确。

学习率小

● 优点:
○1、有助于模型收敛、模型细化。
○2、提高模型精度。
● 缺点:
○1、很难跳出局部最优值。
○2、收敛缓慢。

2.早停与保存最佳模型参数

EarlyStopping()参数说明:

●monitor: 被监测的数据。
●min_delta: 在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。
●patience: 没有进步的训练轮数,在这之后训练就会被停止。
●verbose: 详细信息模式。
●mode: {auto, min, max} 其中之一。 在 min 模式中, 当被监测的数据停止下降,训练就会停止;在 max 模式中,当被监测的数据停止上升,训练就会停止;在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。
●baseline: 要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止。
●estore_best_weights: 是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重。

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStoppingepochs = 50# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=1,save_best_only=True,save_weights_only=True)# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy', min_delta=0.001,patience=20, verbose=1)



3. 模型训练
 

history = model.fit(train_ds,validation_data=val_ds,epochs=epochs,callbacks=[checkpointer, earlystopper])
Epoch 1/50
16/16 [==============================] - 4s 31ms/step - loss: 3.5439 - accuracy: 0.4721 - val_loss: 0.6931 - val_accuracy: 0.5789Epoch 00001: val_accuracy improved from -inf to 0.57895, saving model to best_model.h5
Epoch 2/50
16/16 [==============================] - 0s 12ms/step - loss: 0.6929 - accuracy: 0.5279 - val_loss: 0.6891 - val_accuracy: 0.6447......Epoch 00040: val_accuracy did not improve from 0.89474
Epoch 41/50
16/16 [==============================] - 0s 12ms/step - loss: 0.0931 - accuracy: 0.9841 - val_loss: 0.3837 - val_accuracy: 0.8816Epoch 00041: val_accuracy did not improve from 0.89474
Epoch 42/50
16/16 [==============================] - 0s 12ms/step - loss: 0.0871 - accuracy: 0.9801 - val_loss: 0.3834 - val_accuracy: 0.8816Epoch 00042: val_accuracy did not improve from 0.89474
Epoch 00042: early stopping



五、模型评估

1. Loss与Accuracy图
 

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(len(loss))plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

output_51_0.png


2. 指定图片进行预测
 

from PIL import Image
import numpy as np# img = Image.open("./45-data/Monkeypox/M06_01_04.jpg")  #这里选择你需要预测的图片
img = Image.open("./46-data/test/nike/1.jpg")  #这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])img_array = tf.expand_dims(image, 0) #/255.0  # 记得做归一化处理(与训练集处理方式保持一致)predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

预测结果为: nike


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/178426.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode—18.四数之和【中等】

2023每日刷题&#xff08;四十一&#xff09; Leetcode—18.四数之和 实现代码 class Solution { public:vector<vector<int>> fourSum(vector<int>& nums, int target) {vector<vector<int>> ans;sort(nums.begin(), nums.end());int n …

chatgpt prompt提示词

ChatGPT 最近十分火爆&#xff0c;今天我也来让 ChatGPT 帮我阅读一下 Vue3 的源代码。 都知道 Vue3 组件有一个 setup函数。那么它内部做了什么呢&#xff0c;今天跟随 ChatGPT 来一探究竟。 实战 1.setup setup 函数在什么位置呢&#xff0c;我们不知道他的实现函数名称&…

12 网关实战:Spring Cloud Gateway基础理论

为什么需要网关? 传统的单体架构中只有一个服务开放给客户端调用,但是微服务架构中是将一个系统拆分成多个微服务,那么作为客户端如何去调用这些微服务呢?如果没有网关的存在,只能在本地记录每个微服务的调用地址。 无网关的微服务架构往往存在以下问题: 客户端多次请求…

人机交互3——多主题多轮对话

1.主动切换 2.被动切换 3.多轮状态记忆

3.2 Windows驱动开发:内核CR3切换读写内存

CR3是一种控制寄存器&#xff0c;它是CPU中的一个专用寄存器&#xff0c;用于存储当前进程的页目录表的物理地址。在x86体系结构中&#xff0c;虚拟地址的翻译过程需要借助页表来完成。页表是由页目录表和页表组成的&#xff0c;页目录表存储了页表的物理地址&#xff0c;而页表…

使用Sui天气预言机获取全球实时天气数据

新的Sui天气预言机为全球1000多个城市的建设者提供天气数据&#xff0c;并作为一个独特的随机数生成器&#xff0c;适用于需要可信赖的随机结果的游戏和投注应用。它由基于Sui的智能合约和一个从OpenWeather API获取天气数据的后端服务组成&#xff0c;任何人都可以将天气数据集…

SpringCloudAlibaba之Nacos——详细讲解

目录 一、SpringCloudAlibaba简介 1. spring cloud alibaba 特点 2.springcloud 组件 二、环境搭建 1.构建项目并引入依赖 三、Nacos 1.什么是Nacos 2.安装Nacos 3.启动安装服务 4.访问nacos的web服务管理界面 四、开发服务注册到nacos 1.创建项目并引入依赖 2.配置注册地…

【Linux】了解进程的基础知识

进程 1. 进程的概念1.1 进程的理解1.2 Linux下的进程1.3 查看进程属性1.4 getpid和getppid 2. 创建进程3. 进程状态4. 进程优先级5. 进程切换6. 环境变量7. 本地变量与内建命令 1. 进程的概念 一个已经加载到内存中的程序&#xff0c;叫做进程&#xff08;也叫任务&#xff09…

Python+Selenium WebUI自动化框架 -- 基础操作封装

前言&#xff1a; 封装Selenium基本操作&#xff0c;让所有页面操作一键调用&#xff0c;让UI自动化框架脱离高成本、低效率时代&#xff0c;将用例的重用性贯彻到极致&#xff0c;让烦人的PO模型变得无所谓&#xff0c;让一个测试小白都能编写并实现自动化。 知识储备前提&a…

中小型公司如何搭建运维平台,rancher、kubersphere、rainbond

很多开发人员应该是了解过运维发布相关的平台或实际操作过应用发布&#xff0c;但又通常不是十分熟悉。在一个初创公司&#xff0c;或者没有成熟的运维发布平台的公司&#xff0c;如果让你来搭建一套发布平台&#xff0c;你应该如何去抉择呢&#xff1f; 这里我简单介绍几种。…

【Linux】:信号在内核里的处理

信号的发送和保存 一.内核中的信号处理二.信号集操作函数1.一些信号函数2.sigprocmask3.sigpending4.写代码 三.信号在什么时候处理的四.再谈地址空间 一.内核中的信号处理 1.实际执行信号的处理动作称为信号递达(Delivery )2.信号从产生到递达之间的状态,称为信号未决(Pending…

vue找依赖包的网址

https://www.npmjs.com/ 浅收藏一下

心大数据结构题型

选择题 2021 数据处理的单位&#xff1a;数据元素 矩阵压缩存储 2022 ①单链表头插法选择 ②矩阵压缩存储&#xff0c;行优先 ③删除链表节点的时间复杂度 ④稀疏矩阵存储 ⑤平衡二叉树时间复杂度 ⑥栈和队列的出队&#xff0c;问栈的大小至少多少 ⑦拓扑排序 ⑧参考书 360…

30.0/集合/ArrayList/LinkedList

目录 30.1什么是集合? 30.1.2为什么使用集合 30.1.3自己创建一个集合类 30.1.3 集合框架有哪些? 30.1.2使用ArrayList集合 30.2增加元素 30.3查询的方法 30.4删除 30.5 修改 30.6泛型 30.1什么是集合? 我们之前讲过数组&#xff0c;数组中它也可以存放多个元素。集合…

Cenos7系统通过链接一键安装LAMP项目环境(linux,apache,mysql,php)

前言&#xff1a;嫌装环境麻烦&#xff0c;以下介绍自动安装环境的方法 一.环境配置 根据自己需要选择 操作系统&#xff1a;CenOS 7.x以上Web服务器&#xff1a;Apache 2.4数据库&#xff1a;MySQL 5.7开发框架&#xff1a;ThinkPHP 5.0&#xff08;PHP5.0以上&#xff09;…

【Web】NewStarCtf Week2 个人复现

目录 ①游戏高手 ②include 0。0 ③ez_sql ④Unserialize&#xff1f; ⑤Upload again! ⑥ R!!C!!E!! ①游戏高手 经典前端js小游戏 检索与分数相关的变量 控制台直接修改分数拿到flag ②include 0。0 禁了base64和rot13 尝试过包含/var/log/apache/access.log,ph…

Git 入门指南

什么是 Git&#xff1f; Git 的目前最流行的分布式版本控制软件&#xff0c;可以帮助我们高效敏捷的处理任何项目。 版本管理 要理解 Git 我们首先要理解版本管理。 版本管理就是开发过程中用于管理对文件、目录或者工程等内容的修改历史&#xff0c;可以让我们方便的查看历史…

java学习part20内部类

116-面向对象(高级)-类的成员之五&#xff1a;内部类_哔哩哔哩_bilibili 1.内部类

在Anaconda中用命令行安装环境以及安装包

一、下载Anaconda 下载地址 二、创建环境 1. 打开Anaconda命令行 2.创建环境 conda create -n 环境名称 python3.10(需要的python版本号) 3.激活环境 activate 环境名4.下载安装包 pip install 模块名 -i https://pypi.tuna.tsinghua.edu.cn/simple5.下载torch 官网&…

Python语言学习笔记之三(字符编码)

本课程对于有其它语言基础的开发人员可以参考和学习&#xff0c;同时也是记录下来&#xff0c;为个人学习使用&#xff0c;文档中有此不当之处&#xff0c;请谅解。 什么是字符编码 计算机从本质上来说只认识二进制中的0和1&#xff0c;字符编码(Character Encoding) 是一种将…