TensorFlow案例学习:对服装图像进行分类

前言

官方为我们提供了一个 对服装图像进行分类 的案例,方便我们快速学习

学习

预处理数据

案例中有下面这段代码

# 预处理数据,检查训练集中的第一个图像可以看到像素值处于0~255之间
plt.figure() # 创建图像窗口
plt.imshow(train_images[0]) # 显示图片
plt.colorbar()  # 在图像旁边添加颜色条
plt.grid(False) # 取消网格线
plt.show() # 显示图形窗口# 将值缩小至0~1之间,然后将其反馈到神经网络模型。训练集和测试集都需要处理
train_images = train_images / 255.0
test_images = test_images / 255.0

在这里插入图片描述

百度查了一下,将值缩小至0~1之间是为了

将训练集和测试集数据的值缩小到0~1之间是为了进行数据归一化(Normalization)。这是一个常见的预处理步骤,对于图像分类任务特别重要。
将图像的像素值缩放到0~1之间有几个好处:

  • 数值范围一致性:将所有像素值限制在0~1范围内可以确保不同样本的特征具有一致的数值区间。这有助于避免某些特征对模型训练产生过大的影响。
  • 梯度下降稳定性:在深度学习中,常用的优化算法如梯度下降依赖于权重的更新和损失函数的梯度计算。将像素值缩小到较小的范围可以使这些计算更加稳定,有助于加速模型的收敛。
  • 避免数值溢出:在一些激活函数和优化算法中,如果输入值太大,可能会导致数值溢出或不稳定的情况。将像素值限制在0~1之间可以减少这种情况的发生。

以后再遇见处理255时就明白这样做的目的了

构建模型

构建神经网络需要先配置模型的层,然后再编译模型。

设置层
神经网络的基本组成部分是层。层会从向其馈送的数据中提取表示形式。希望这些表示形式有助于解决手头上的问题。

大多数深度学习都包括将简单的层链接在一起。大多数层(如 tf.keras.layers.Dense)都具有在训练期间才会学习的参数

# 1、设置层
# tf.keras是TensorFlow中的高级API,用于构建和训练神经网络模型。它是一个基于Keras库的接口,提供了更简单、更高级的方式来定义、配置和训练神经网络模型。
# tf.keras.Sequential 用于按顺序堆叠各个神经网络层来构建模型,是一种简单的模型类型
model = tf.keras.Sequential([# 将图像格式从二维数组(28*28像素),转化为一维数组(28*28 = 784像素)。将该层视为图像中未堆叠的像素行并将其排列起来。该层没有要学习的参数,它只会重新格式化数据。tf.keras.layers.Flatten(input_shape=(28,28)), # 第二层,是一个具有128个神经元的全连接神经层tf.keras.layers.Dense(128,activation='relu'),# 第三层会返回一个长度为10的数组,每个都包含一个得分来表示当前图像属于10个类中的哪一个tf.keras.layers.Dense(10)
])

这段代码我相信很多人跟我一样都有些疑问,还好现在有gpt,不然都不知道上哪里去找答案。下面是我的一些疑问及gpt的回答:

  • 为什么只有三层。答:在神经网络中,层数的选择是一个灵活的设计选择,取决于特定问题的复杂性和数据集的特征。选择三层可能是为了简化模型或者问题本身不需要更多层
  • 第二层为什么是tf.keras.layers.Dense(128)。答:选择128个神经元是基于对问题复杂性的估计和经验。如果问题比较复杂或数据集较大,增加神经元数量可以增加模型的容量,提高模型的表示能力。
  • 第三层为什么是tf.keras.layers.Dense(10)。答:因为这是一个分类问题,这个案例中有10个分类。每个神经元对应一个类别,并输出相应类别的预测概率。
  • tf.keras.layers.Dense(128)是计算的来的吗。答:通常需要根据实际问题和数据集来进行调整。增加神经元的数量可以增加模型的容量和学习能力,但也可能导致过拟合。过拟合是指模型在训练数据上表现良好,但在新数据上表现较差。建议先从较小的数量开始,然后逐渐增加,直到模型的性能不再提高或开始出现过拟合为止。
  • 模型的最后一层是输出层吗。答:模型的最后一层通常是输出层。输出层的神经元数量通常与你要解决的问题相关。对于分类任务,输出层的神经元数量应该等于类别的数量。对于二分类任务,可以使用一个神经元来表示两个类别的概率。对于多分类任务,可以使用多个神经元,每个神经元表示一个类别的概率。在使用tf.keras``构建模型时,你可以使用tf.keras.layers.Dense`来定义输出层,并使用适当的激活函数来产生输出。

编译模型

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数 - 测量模型在训练期间的准确程度。你希望最小化此函数,以便将模型“引导”到正确的方向上。
  • 优化器 - 决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标 - 用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 2、编译模型
model.compile(optimizer='adam', # 指定优化器,adam是常用的优化器,可以自适应的调整学习率loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 指定损失函数,这里使用了稀疏分类交叉熵损失函数metrics=['accuracy'] # 指定评估模型性能的指标,这里使用准确率
)

训练模型

训练神经网络模型需要执行以下步骤:

  • 将训练数据馈送给模型。在本例中,训练数据位于 train_images 和 train_labels 数组中。
  • 模型学习将图像和标签关联起来。
  • 要求模型对测试集(在本例中为 test_images 数组)进行预测。
  • 验证预测是否与 test_labels 数组中的标签相匹配。
# 1、将训练数据反馈给模型
# model.fit用于将模型与训练数据进行拟合,这里是将所有样本迭代10次
model.fit(train_images,train_labels,epochs=10)

如下图:
在这里插入图片描述

# 2、在测试数据集上评估准确率,verbose=2参数表示以详细模式输出评估过程
test_loss,test_acc = model.evaluate(test_images,test_labels,verbose=2)
print("损失率:",test_loss,"准确率:",test_acc)

如下图:
在这里插入图片描述

进行预测

# 进行预测
# 模型经过训练后,您可以使用它对一些图像进行预测。附加一个 Softmax 层,将模型的线性输出 logits 转换成更容易理解的概率
probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()])
# 预测图片
predictions = probability_model.predict(test_images)print("第一个预测结果:",predictions[0])

预测结果是一个包含 10 个数字的数组。它们代表模型对 10 种不同服装中每种服装的“置信度”。您可以看到哪个标签的置信度值最大:

np.argmax(predictions[0])

使用训练好的模型

现在模型已经训练好了,我们可以基于模型对单个图像进行预测

# 使用训练好的模型
# 加载图片
img = Image.open('pics/shirt.png') 
# 调整大小
img = img.resize((28,28))
# 将彩色图片转为灰度图片
img_gray = img.convert('L')
# 将图像转换为 NumPy 数组,并反转颜色
img_arr = np.array(img_gray)
img_arr = 255 - img_arr
# 将图像像素值归一化到0~1
img_arr = img_arr / 255.0
# 将图像形状调整为(128288)
img_arr = img_arr.reshape(1,28,28)
# 可以保存处理后的文件,也可以进行预测
# np.save('abc.npy',img_arr)
# tf.keras 模型经过了优化,可同时对一个批或一组样本进行预测。因此,即便您只使用一个图像,您也需要将其添加到列表中
#img_arr = tf.keras.preprocessing.image.img_to_array(img)res = probability_model.predict(img_arr)
print("预测结果是:",res,class_names[np.argmax(res[0])])# 可视化显示
font = FontProperties()
font.set_family('Microsoft YaHei')
plt.figure() # 创建图像窗口
plt.xticks([])
plt.yticks([])
plt.grid(False) # 取消网格线
plt.imshow(img_arr[0]) # 显示图片
plt.xlabel(class_names[np.argmax(res[0])],fontproperties=font)
plt.show() # 显示图形窗口

这块是最复杂的,搞了好久才成功。你加载的图片是彩色的,你必须将图片变成灰度的,并且是28*28像素的图片,也就是你的图片要处理成符合这个模型的图片才行。

但是最终结果其实也不是很准确,根本原因是你的图片处理后,能够获取的特征就很少了,这样会导致判断错误。

结果
在这里插入图片描述

遇到的问题

问题1
在执行(train_images, train_labels), (test_images,test_labels) = fashion_mnist.load_data()时提示

Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz: None – [WinError 10054] 远程主机强迫关闭了 一个现有的连接。

这是加载数据集时失败了,国内访问下载谷歌的数据总会出现这样的问题。

解决:
1、打开数据集官方网站 https://github.com/zalandoresearch/fashion-mnist,将下面这4个数据下载到本地放到项目里

在这里插入图片描述
2、加载本地数据

import gzip
import numpy as npdef load_data():# 加载训练集图像数据with gzip.open('train-images-idx3-ubyte.gz', 'rb') as f:train_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载训练集标签数据with gzip.open('train-labels-idx1-ubyte.gz', 'rb') as f:train_labels = np.frombuffer(f.read(), np.uint8, offset=8)# 加载测试集图像数据with gzip.open('t10k-images-idx3-ubyte.gz', 'rb') as f:test_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载测试集标签数据with gzip.open('t10k-labels-idx1-ubyte.gz', 'rb') as f:test_labels = np.frombuffer(f.read(), np.uint8, offset=8)return (train_images, train_labels), (test_images, test_labels)# 调用加载数据函数
(train_images, train_labels), (test_images, test_labels) = load_data()

问题2
验证前25个图像,设置中文乱码。教程中的使用的是英文,我这里尝试了一下中文,中文乱码
在这里插入图片描述
解决:设置中文字体

# 字体属性
from matplotlib.font_manager import FontProperties# 验证训练集中的前25个图像,并显示其名称
font = FontProperties()
font.set_family('Microsoft YaHei')
plt.figure(figsize=(10,10))
for i in range(25):plt.subplot(5,5,i+1) # 按照 5*5进行显示plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i]],fontproperties=font)
plt.show()

在这里插入图片描述

完整代码

# 导入 TensorFlow 重命名
import tensorflow as tf# numpy是科学计算库,matplotlib是用于绘制图表和可视化数据的库
import numpy as np
import matplotlib.pylab as plt
# 字体属性
from matplotlib.font_manager import FontProperties# 用于加载文件
import gzip# 用于处理图片
from PIL import Image# 用于加载数据集的函数
def load_data():# 加载训练集图像数据with gzip.open('train-images-idx3-ubyte.gz', 'rb') as f:train_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载训练集标签数据with gzip.open('train-labels-idx1-ubyte.gz', 'rb') as f:train_labels = np.frombuffer(f.read(), np.uint8, offset=8)# 加载测试集图像数据with gzip.open('t10k-images-idx3-ubyte.gz', 'rb') as f:test_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载测试集标签数据with gzip.open('t10k-labels-idx1-ubyte.gz', 'rb') as f:test_labels = np.frombuffer(f.read(), np.uint8, offset=8)return (train_images, train_labels), (test_images, test_labels)print("tf版本:",tf.__version__)# 导入数据集,TensorFlow 内置的数据集
fashion_mnist = tf.keras.datasets.fashion_mnist
# 将训练数据、测试数据取出,保存的元组里
(train_images, train_labels), (test_images,test_labels) = load_data()# 映射标签类,用于后面绘制图像使用
class_names = ['T恤/上衣', '裤子', '套头衫', '连衣裙', '外套', '凉鞋', '衬衫', '运动鞋', '包', '短靴']# 会打印出(60000, 28, 28),官方文档解释为训练集中有 60,000 个图像,每个图像由 28 x 28 的像素表示
print("训练数据集数据:",train_images.shape)# 预处理数据,检查训练集中的第一个图像可以看到像素值处于0~255之间
# plt.figure() # 创建图像窗口
# plt.imshow(train_images[0]) # 显示图片
# plt.colorbar()  # 在图像旁边添加颜色条
# plt.grid(False) # 取消网格线
# plt.show() # 显示图形窗口# 将值缩小至0~1之间,然后将其反馈到神经网络模型。训练集和测试集都需要处理
train_images = train_images / 255.0
test_images = test_images / 255.0# 验证训练集中的前25个图像,并显示其名称
# font = FontProperties()
# font.set_family('Microsoft YaHei')
# plt.figure(figsize=(10,10))
# for i in range(25):
#     plt.subplot(5,5,i+1) # 按照 5*5进行显示
#     plt.xticks([])
#     plt.yticks([])
#     plt.grid(False)
#     plt.imshow(train_images[i], cmap=plt.cm.binary)
#     plt.xlabel(class_names[train_labels[i]],fontproperties=font)
# plt.show()# 构建模型# 1、设置层
# tf.keras是TensorFlow中的高级API,用于构建和训练神经网络模型。它是一个基于Keras库的接口,提供了更简单、更高级的方式来定义、配置和训练神经网络模型。
# tf.keras.Sequential 用于按顺序堆叠各个神经网络层来构建模型,是一种简单的模型类型
model = tf.keras.Sequential([# 将图像格式从二维数组(28*28像素),转化为一维数组(28*28 = 784像素)。将该层视为图像中未堆叠的像素行并将其排列起来。该层没有要学习的参数,它只会重新格式化数据。tf.keras.layers.Flatten(input_shape=(28,28)), # 第二层,是一个具有128个神经元的全连接神经层tf.keras.layers.Dense(128,activation='relu'),# 第三层会返回一个长度为10的数组,每个都包含一个得分来表示当前图像属于10个类中的哪一个tf.keras.layers.Dense(10)
])# 2、编译模型
model.compile(optimizer='adam', # 指定优化器,adam是常用的优化器,可以自适应的调整学习率loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 指定损失函数,这里使用了稀疏分类交叉熵损失函数metrics=['accuracy'] # 指定评估模型性能的指标,这里使用准确率
)# 训练模型# 1、将训练数据反馈给模型
# model.fit用于将模型与训练数据进行拟合,这里是将所有样本迭代10次
model.fit(train_images,train_labels,epochs=10)# 2、在测试数据集上评估准确率,verbose=2参数表示以详细模式输出评估过程
test_loss,test_acc = model.evaluate(test_images,test_labels,verbose=2)
print("损失率:",test_loss,"准确率:",test_acc)# 进行预测
# 模型经过训练后,您可以使用它对一些图像进行预测。附加一个 Softmax 层,将模型的线性输出 logits 转换成更容易理解的概率
probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()])
# 预测图片
predictions = probability_model.predict(test_images)print("第一个预测结果:",predictions[0],'类别是:',class_names[np.argmax(predictions[0])])# 使用训练好的模型
# 加载图片
img = Image.open('pics/shirt.png') 
# 调整大小
img = img.resize((28,28))
# 将彩色图片转为灰度图片
img_gray = img.convert('L')
# 将图像转换为 NumPy 数组,并反转颜色
img_arr = np.array(img_gray)
img_arr = 255 - img_arr
# 将图像像素值归一化到0~1
img_arr = img_arr / 255.0
# 将图像形状调整为(128288)
img_arr = img_arr.reshape(1,28,28)
# 可以保存处理后的文件,也可以进行预测
# np.save('abc.npy',img_arr)res = probability_model.predict(img_arr)
print("预测结果是:",res,class_names[np.argmax(res[0])])# 可视化显示
font = FontProperties()
font.set_family('Microsoft YaHei')
plt.figure() # 创建图像窗口
plt.xticks([])
plt.yticks([])
plt.grid(False) # 取消网格线
plt.imshow(img_arr[0]) # 显示图片
plt.xlabel(class_names[np.argmax(res[0])],fontproperties=font)
plt.show() # 显示图形窗口

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/99063.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开发过程教学——交友小程序

交友小程序 1. 我的基本信息2. 我的人脉2.1 我的关注2.2 我的粉丝 3. 我的视频4. 我的相册 特别注意:由于小程序分包限制2M以内,所以要注意图片和视频的处理。 1. 我的基本信息 数据库表: 我的基本信息我的登录退出记录我的登录状态&#x…

Godot 官方2D游戏笔记(1):导入动画资源和添加节点

文章目录 前言2D官方游戏案例资源下载项目配置添加角色节点模拟运行移动根节点 结束 Godot专栏地址 前言 Godot 官方给了我们2D游戏和3D游戏的案例,不过如果是独立开发者只用考虑2D游戏就可以了,因为2D游戏纯粹,我们只需要关注游戏的玩法即可…

蓝桥杯---第二讲---二分与前缀和

文章目录 前言Ⅰ. 数的范围0x00 算法思路0x00 代码书写 Ⅱ. 数的三次方根0x00 算法思路0x01代码书写 Ⅲ. 前缀和0x00 算法思路0x01 代码书写 Ⅳ. 子矩阵的和0x00 算法思路0x01 代码书写 Ⅴ. 机器人跳跃问题0x00 算法思路0x01 代码书写 Ⅵ. 四平方和0x00 算法思路0x01 代码书写 …

SpringCloud学习笔记-注册微服务到Eureka注册中心

目录 1.在该Module的pom文件中引入eureka依赖2.在该module的src/main/resources/application.yml配置文件3.启动对应的微服务4.查看微服务是否启动成功 假如我有一个微服务名字叫user-service,我需要把它注册到Eureka注册中心,则具体步骤如下: 1.在该Module的pom文件中引入eure…

Flink的处理函数——processFunction

目录 一、处理函数概述 二、Process函数分类——8个 (1)ProcessFunction (2)KeyedProcessFunction (3)ProcessWindowFunction (4)ProcessAllWindowFunction &#xff…

真香!Jenkins 主从模式解决问题So Easy~

01.Jenkins 能干什么 Jenkins 是一个开源软件项目,是基于 Java 开发的一种持续集成工具,用于监控持续重复的工作,旨在提供一个开放易用的软件平台,使软件项目可以进行持续集成。 中文官网:https://jenkins.io/zh/ 0…

好消息:用 vue3+layui 共同铸造我们新的项目

前言: layui这个框架不知道多少人还在关注着,记得第一次接触它是在18年,后来随着vue,react的盛行,jquerylayui的模式受到了特别大的冲击,后来作者都放弃维护他的官方网站,转而在github/gitee上做…

SLAM从入门到精通(ROS和底盘Stm32的关系)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 学过Ros的同学,一般对subscribe、publish、话题、服务这些内容都比较熟悉。如果再熟悉一点的话,还会知道slam、move_base、…

NLP - 数据预处理 - 文本按句子进行切分

NLP - 数据预处理 - 文本按句子进行切分 文章目录 NLP - 数据预处理 - 文本按句子进行切分一、前言二、环境配置1、安装nltk库2、下载punkt分句器 三、运行程序四、额外补充 一、前言 在学习对数据训练的预处理的时候遇到了一个问题,就是如何将文本按句子切分&#…

ChainForge:衡量Prompt性能和模型稳健性的GUI工具包

ChainForge是一个用于构建评估逻辑来衡量模型选择,提示模板和执行生成过程的GUI工具包。ChainForge可以安装在本地,也可以从chrome浏览器运行。 ChainForge可以通过聊天节点对多个对话可以使用不同的llm并行运行。可以对聊天消息进行模板化,并…

基于复旦微JFM7K325T FPGA的高性能PCIe总线数据预处理载板(100%国产化)

PCIE711是一款基于PCIE总线架构的高性能数据预处理FMC载板,板卡采用复旦微的JFM7K325T FPGA作为实时处理器,实现各个接口之间的互联。该板卡可以实现100%国产化。 板卡具有1个FMC(HPC)接口,1路PCIe x8主机接口&#x…

【HomeKit】HAT User Manual教程

前言:这篇文章是对于苹果协议文件《HomeKit Accessory Tester (HAT) User Manual》的学习,即 HomeKit配件测试仪(HAT) 用户手册,该版本是第11次修订 第一章 概述 本文档介绍了Apple HomeKit配件测试仪(HAT)的配置和使用方法。HAT是一个Mac应…

Redis作为缓存,mysql的数据如何与redis进行同步?

Redis作为缓存,mysql的数据如何与redis进行同步? 一定要设置前提,先介绍业务背景 延时双删 双写一致性:当修改了数据库的数据也要同时更新缓存的数据,缓存和数据库的数据要保持一致 读操作:缓存命中,直接返回;缓存未…

【Spring Cloud】深入探索统一网关 Gateway 的搭建,断言工厂,过滤器工厂,全局过滤器以及跨域问题

文章目录 前言为什么需要网关以及网关的作用网关的技术实现 一、Gateway 网关的搭建1.1 创建 Gateway 模块1.2 引入依赖1.3 配置网关1.4 验证网关是否搭建成功1.5 微服务结构分析 二、Gateway 断言工厂2.1 Spring 提供的断言工厂2.2 示例:设置断言工厂 三、Gateway …

Spring的事务控制

目录 基于AOP的声明事务控制 Spring事务编程概述 搭建测试环境 基于xml声明式事务控制 详解 事务增强的AOP 平台事务管理器 Spring提供的Advice(重点介绍) 原理 (源码没有翻太明白) 基于注解声明式事务控制 基于AOP的声明…

小视频APP源码选择指南:挑选最适合你的开发框架

在如今蓬勃发展的小视频APP行业中,源码的选择是打造一款成功应用的关键步骤。然而,面对众多开发框架的选择,如何挑选最适合你的小视频APP源码呢?作为这一领域的专家,我将为你提供一份详尽的指南,助你在源码…

Windows10打开应用总是会弹出提示窗口的解决方法

用户们在Windows10电脑中打开应用程序,遇到了总是会弹出提示窗口的烦人问题。这样的情况会干扰到用户的正常操作,给用户带来不好的操作体验,接下来小编给大家详细介绍关闭这个提示窗口的方法,让大家可以在Windows10电脑中舒心操作…

智能工厂MES系统,终端设备支持手机、PDA、工业平板、PC

一、开源项目简介 源计划智能工厂MES系统(开源版) 功能包括销售管理,仓库管理,生产管理,质量管理,设备管理,条码追溯,财务管理,系统集成,移动端APP。 二、开源协议 使用GPL-3.0开…

010:连续跌3天,同时这三天收盘价都在20日均线下,第四天上涨的概率--以京泉华为例

对于《连续跌三天,压第四天上涨的盈利计算》,我们可以继续优化这个策略,增加条件:同时三天都收盘在20日均线下。 因为我们上一篇《获取20日均线数据到excel表中》获得了20日均线数据,我们可以利用均线数据来编写新的脚…

基于SpringBooy的安康旅游网站的设计与实现

目录 前言 一、技术栈 二、系统功能介绍 登录模块的实现 景点信息管理界面 酒店信息管理界面 特产管理界面 游客管理界面 景点购票订单管理界面 系统主界面 游客注册界面 景点信息详情界面 酒店详情界面 特产详情界面 三、核心代码 1、登录模块 2、文件上传模块…