TensorFlow系列:第四讲:MobileNetV2实战

一. 加载数据集

编写工具类,实现数据集的加载

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/808da38d6ad74628b869c28e937b02d9.png


import keras"""
加载数据集工具类
"""class DatasetLoader:def __init__(self, path_url, image_size=(224, 224), batch_size=32, class_mode='categorical'):self.path_url = path_urlself.image_size = image_sizeself.batch_size = batch_sizeself.class_mode = class_mode# 不使用图像增强def load_data(self):# 加载训练数据集train_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/train',  # 训练数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode,  # 类别模式:返回one-hot编码的标签)# 加载验证数据集val_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/validation',  # 验证数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签)# 加载测试数据集test_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/test',  # 验证数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签)class_names = train_data.class_namesreturn train_data, val_data, test_data, class_names

二. 训练模型完整代码

import keras
from keras import layersfrom utils.dataset_loader import DatasetLoader"""
使用MobileNetV2,实现图像多分类
"""# 模型训练地址
PATH_URL = '../data/fruits'
# 训练曲线图
RESULT_URL = '../results/fruits'
# 模型保存地址
SAVED_MODEL_DIR = '../saved_model/fruits'#  图片大小
IMG_SIZE = (224, 224)
# 定义图像的输入形状
IMG_SHAPE = IMG_SIZE + (3,)
# 数据加载批次,训练轮数
BATCH_SIZE, EPOCH = 32, 16# 训练模型
def train():# 实例化数据集加载工具类dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()# 构建 MobileNet 模型base_model = keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False)# 将模型的主干参数进行冻结base_model.trainable = Falsemodel = keras.Sequential([layers.Rescaling(1. / 127.5, offset=-1, input_shape=IMG_SHAPE),# 设置主干模型base_model,# 对主干模型的输出进行全局平均池化layers.GlobalAveragePooling2D(),# 通过全连接层映射到最后的分类数目上layers.Dense(len(class_total), activation='softmax')])# 编译模型model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 模型结构model.summary()# 指明训练的轮数epoch,开始训练model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 测试loss, accuracy = model.evaluate(test_ds)# 输出结果print('Mobilenet test accuracy :', accuracy, ',loss :', loss)# 保存模型 savedModel格式model.export(filepath=SAVED_MODEL_DIR)if __name__ == '__main__':train()

训练模型输出如下:

模型结构:

在这里插入图片描述
训练进度:主要看最下边一行输出,一轮训练完成会显示训练集和验证集的正确率。
在这里插入图片描述
验证正确率:

在这里插入图片描述
保存的模型:

在这里插入图片描述

三. 函数式调用方式

以后的所有讲解,都基于函数式方式进行,因为函数式调用比较灵活。

# 函数式调用方式
def train1():# 实例化数据集加载工具类dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()inputs = keras.Input(shape=IMG_SHAPE)# 加载预训练的 MobileNetV2 模型,不包括顶层分类器,并在 Rescaling 层之后连接base_model = keras.applications.MobileNetV3Large(weights='imagenet', include_top=False, input_tensor=inputs)# 冻结 MobileNetV2 的所有层,以防止在初始阶段进行权重更新for layer in base_model.layers:layer.trainable = False# 在 MobileNetV2 之后添加自定义的顶层分类器x = layers.GlobalAveragePooling2D()(base_model.output)predictions = layers.Dense(len(class_total), activation='softmax')(x)# 构建最终模型model = keras.Model(inputs=base_model.input, outputs=predictions)# 编译模型model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 查看模型结构model.summary()model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 测试loss, accuracy = model.evaluate(test_ds)# 输出结果print('Mobilenet test accuracy :', accuracy, ',loss :', loss)# 保存模型 savedModel格式model.export(filepath=SAVED_MODEL_DIR)

四. 保存训练过程曲线图

在训练模型时,我们不可能时时盯着训练数据结果,如果把训练过程曲线保存成图片,这样就比较方便查看。

在项目中编写一个工具类如下:
在这里插入图片描述
上边代码简单改造:

    # 训练模型history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 保存曲线图Utils.trainResult(history, RESULT_URL)

曲线图如下:训练集和验证集准确率上升,损失率下降,这是完美的表现。

在这里插入图片描述

五. 模型可视化批量测试

在这里插入图片描述
编写可视化批量测试工具类:

import keras
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import FancyBboxPatchfrom utils.dataset_loader import DatasetLoader"""
模型工具类
"""class ModelUtil:def __init__(self, saved_model_dir, path_url):self.save_model_dir = saved_model_dir  # savedModel 模型保存地址self.path_url = path_url  # 模型训练数据地址# 批量识别 进行可视化显示def batch_evaluation(self, class_mode='categorical', image_size=(224, 224), num_images=25):dataset_loader = DatasetLoader(self.path_url, image_size=image_size, class_mode=class_mode)train_ds, val_ds, test_ds, class_names = dataset_loader.load_data()# 加载savedModel模型tfs_layer = keras.layers.TFSMLayer(self.save_model_dir)# 创建一个新的 Keras 模型,包含 TFSMLayermodel = keras.Sequential([keras.Input(shape=image_size + (3,)),  # 根据你的模型的输入形状tfs_layer])plt.figure(figsize=(10, 10))for images, labels in test_ds.take(1):# 使用模型进行预测outputs = model.predict(images)for i in range(num_images):plt.subplot(5, 5, i + 1)image = np.array(images[i]).astype("uint8")plt.imshow(image)index = int(np.argmax(outputs[i]))prediction = outputs[i][index]percentage_str = "{:.2f}%".format(prediction * 100)plt.title(f"{class_names[index]}: {percentage_str}")plt.axis("off")plt.subplots_adjust(hspace=0.5, wspace=0.5)plt.show()

使用工具类:

if __name__ == '__main__':# train()model_util = ModelUtil(SAVED_MODEL_DIR, PATH_URL)model_util.batch_evaluation()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44717.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

物联网系统中市电电量计量方案(一)

为什么要进行电量计量? 节约资源:电量计量可以帮助人们控制用电量,从而达到节约资源的目的。在当前严峻的资源供应形势下,节约能源是我们应该重视的问题。合理计费:电表可以帮助公共事业单位进行合理计费,…

3.相机标定原理及代码实现(opencv)

1.相机标定原理 相机参数的确定过程就叫做相机标定。 1.1 四大坐标系及关系 (1)像素坐标系(单位:像素(pixel)) 像素坐标系是指相机拍到的图片的坐标系,以图片的左上角为坐标原点&a…

为校园后勤注入智慧:收件登记功能驱动全新体验

在智慧校园的后勤管理体系中,收件登记服务是一项旨在提升快递接收体验的创新举措,它无缝融合了现代科技与日常校园生活,为师生带来便捷与安心。 为应对日益增长的快递需求,师生可事先通过校园网平台或特制的移动应用预报快递信息&…

光学传感器图像处理流程(二)

光学传感器图像处理流程(二) 2.4. 图像增强2.4.1. 彩色合成2.4.2 直方图变换2.4.3. 密度分割2.4.4. 图像间运算2.4.5. 邻域增强2.4.6. 主成分分析2.4.7. 图像融合 2.5. 裁剪与镶嵌2.5.1. 图像裁剪2.5.2. 图像镶嵌 2.6. 遥感信息提取2.6.1. 目视解译2.6.2…

数字化时代的供应链管理综合解决方案

目录 引言背景与意义供应链管理综合解决方案的目标 📄供应链管理系统主要功能系统优势 📄物流管理系统主要功能系统优势 📄订单管理系统主要功能应用场景 📄仓储管理系统系统亮点主要功能系统优势 📄商城管理系统主要功…

【python】QWidget父子关系,控件显示优先级原理剖析与应用实战演练

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

又是三道简单的web题(2)

一、cookie 1.打开后是如下页面,抓包,关注cookie 2.发现cookie中有一个文件 3.直接访问这个文件,得到flag 二、employeeswork 打开后页面如下: 点击后出现一串php代码 审一下这个代码,需要添加参数work并且赋值work…

Linux笔记之使用系统调用sendfile高速拷贝文件

Linux笔记之使用系统调用sendfile高速拷贝文件 code review! 文章目录 Linux笔记之使用系统调用sendfile高速拷贝文件sendfile 性能优势sendfile 系统调用优点:缺点: cp 命令优点:缺点: 实际测试:拷贝5.8个G的文件&a…

合合信息大模型加速器亮相WAIC大会:文档解析与文本识别新突破

合合信息大模型加速器亮相WAIC大会:文档解析与文本识别新突破 文章目录 合合信息大模型加速器亮相WAIC大会:文档解析与文本识别新突破前言合合信息TextIn平台:智能文档处理的领军者文档解析引擎:百页文档秒级处理大模型的发展背景…

【漏洞复现】Crocus系统——Download——文件读取

声明:本文档或演示材料仅供教育和教学目的使用,任何个人或组织使用本文档中的信息进行非法活动,均与本文档的作者或发布者无关。 文章目录 漏洞描述漏洞复现测试工具 漏洞描述 Crocus系统旨在利用人工智能、高清视频、大数据和自动驾驶技术&…

工程化-vue3+ts:代码检测工具 ESLint

一、理解ESLint ESLint是一个开源的JavaScript代码检查工具,用于帮助开发人员规范和统一编码风格。它可以检查代码中的潜在错误、不一致的编码习惯以及一些常见的代码问题。 ESLint使用基于规则的插件体系,可以根据项目的需求和个人的偏好配置不同的规…

数据库数据恢复—SQL Server数据库由于存放空间不足报错的数据恢复案例

SQL Server数据库数据恢复环境: 某品牌服务器存储中有两组raid5磁盘阵列。操作系统层面跑着SQL Server数据库,SQL Server数据库存放在D盘分区中。 SQL Server数据库故障: 存放SQL Server数据库的D盘分区容量不足,管理员在E盘中生…

MacOS如何切换shell类型

切换 shell 类型 如果你想在不同的 shell 之间切换,以探索它们的不同之处,或者因为你知道自己需要其中的一个或另一个,可以使用如下命令: 切换到 bash chsh -s $(which bash)切换到 zsh chsh -s $(which zsh)$()语法的作用是运…

FastGPT:给 GPT 插上知识库的翅膀!0基础搭建本地私有知识库,有手就行

写在前面 上一篇,我们部署了接口管理和分发神器-OneAPI,将所有大模型一键封装成OpenAI协议。见:[OneAPI)。 基于此,本篇继续带领大家搭建一个基于本地知识库检索的问答系统。 有同学说 Coze 不也可以实现同样功能么&#xff1f…

51单片机:电脑通过串口控制LED亮灭(附溢出率和波特率详解)

一、功能实现 1.电脑通过串口发送数据:0F 2.点亮4个LED 二、注意事项 1.发送和接受数据的文本模式 2.串口要对应 3.注意串口的波特率要和程序中的波特率保持一致 4.有无校验位和停止位 三、如何使用串口波特率计算器 1.以本程序为例 2.生成代码如下 void Uar…

[论文笔记]涨点近5%! 以内容中心的检索增强生成可扩展的级联框架:Pistis-RAG

引言 今天带来一篇较新RAG的论文笔记:Pistis-RAG: A Scalable Cascading Framework Towards Content-Centric Retrieval-Augmented Generation。 在希腊神话中,Pistis象征着诚信、信任和可靠性。受到这些原则的启发,Pistis-RAG是一个可扩展…

windows远程桌面到 Linux系统(Ubuntu:22.04)—— 安装xrdp软件

1、在Linux系统上安装xrdp软件 sudo apt update sudo apt install xrdp2、安装完成后,需要开启xrdp服务 sudo systemctl start xrdp sudo systemctl enable xrdp打印返回 Synchronizing state of xrdp.service with SysV service script with /lib/systemd/system…

一键叫车|开发打车小程序,随时随地便利出行!

随着移动互联网的普及,人们出行的方式也在不断发生变化。对于出行多样化和便捷化的需求,一款打车小程序可以方便人们的出行,提高出行效率和便捷性。打车小程序能够根据用户的出行需求为其打造个性化的出行方案,从而让用户的出行生…

【DevOps】在云原生时代的角色与重要性探索

🐇明明跟你说过:个人主页 🏅个人专栏:《未来已来:云原生之旅》🏅 🔖行路有良友,便是天堂🔖 目录 一、引言 1、什么是云原生 2、云原生的核心特性 3、什么是DevOps…