TensorFlow系列:第四讲:MobileNetV2实战

一. 加载数据集

编写工具类,实现数据集的加载

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/808da38d6ad74628b869c28e937b02d9.png


import keras"""
加载数据集工具类
"""class DatasetLoader:def __init__(self, path_url, image_size=(224, 224), batch_size=32, class_mode='categorical'):self.path_url = path_urlself.image_size = image_sizeself.batch_size = batch_sizeself.class_mode = class_mode# 不使用图像增强def load_data(self):# 加载训练数据集train_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/train',  # 训练数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode,  # 类别模式:返回one-hot编码的标签)# 加载验证数据集val_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/validation',  # 验证数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签)# 加载测试数据集test_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/test',  # 验证数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签)class_names = train_data.class_namesreturn train_data, val_data, test_data, class_names

二. 训练模型完整代码

import keras
from keras import layersfrom utils.dataset_loader import DatasetLoader"""
使用MobileNetV2,实现图像多分类
"""# 模型训练地址
PATH_URL = '../data/fruits'
# 训练曲线图
RESULT_URL = '../results/fruits'
# 模型保存地址
SAVED_MODEL_DIR = '../saved_model/fruits'#  图片大小
IMG_SIZE = (224, 224)
# 定义图像的输入形状
IMG_SHAPE = IMG_SIZE + (3,)
# 数据加载批次,训练轮数
BATCH_SIZE, EPOCH = 32, 16# 训练模型
def train():# 实例化数据集加载工具类dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()# 构建 MobileNet 模型base_model = keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False)# 将模型的主干参数进行冻结base_model.trainable = Falsemodel = keras.Sequential([layers.Rescaling(1. / 127.5, offset=-1, input_shape=IMG_SHAPE),# 设置主干模型base_model,# 对主干模型的输出进行全局平均池化layers.GlobalAveragePooling2D(),# 通过全连接层映射到最后的分类数目上layers.Dense(len(class_total), activation='softmax')])# 编译模型model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 模型结构model.summary()# 指明训练的轮数epoch,开始训练model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 测试loss, accuracy = model.evaluate(test_ds)# 输出结果print('Mobilenet test accuracy :', accuracy, ',loss :', loss)# 保存模型 savedModel格式model.export(filepath=SAVED_MODEL_DIR)if __name__ == '__main__':train()

训练模型输出如下:

模型结构:

在这里插入图片描述
训练进度:主要看最下边一行输出,一轮训练完成会显示训练集和验证集的正确率。
在这里插入图片描述
验证正确率:

在这里插入图片描述
保存的模型:

在这里插入图片描述

三. 函数式调用方式

以后的所有讲解,都基于函数式方式进行,因为函数式调用比较灵活。

# 函数式调用方式
def train1():# 实例化数据集加载工具类dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()inputs = keras.Input(shape=IMG_SHAPE)# 加载预训练的 MobileNetV2 模型,不包括顶层分类器,并在 Rescaling 层之后连接base_model = keras.applications.MobileNetV3Large(weights='imagenet', include_top=False, input_tensor=inputs)# 冻结 MobileNetV2 的所有层,以防止在初始阶段进行权重更新for layer in base_model.layers:layer.trainable = False# 在 MobileNetV2 之后添加自定义的顶层分类器x = layers.GlobalAveragePooling2D()(base_model.output)predictions = layers.Dense(len(class_total), activation='softmax')(x)# 构建最终模型model = keras.Model(inputs=base_model.input, outputs=predictions)# 编译模型model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 查看模型结构model.summary()model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 测试loss, accuracy = model.evaluate(test_ds)# 输出结果print('Mobilenet test accuracy :', accuracy, ',loss :', loss)# 保存模型 savedModel格式model.export(filepath=SAVED_MODEL_DIR)

四. 保存训练过程曲线图

在训练模型时,我们不可能时时盯着训练数据结果,如果把训练过程曲线保存成图片,这样就比较方便查看。

在项目中编写一个工具类如下:
在这里插入图片描述
上边代码简单改造:

    # 训练模型history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 保存曲线图Utils.trainResult(history, RESULT_URL)

曲线图如下:训练集和验证集准确率上升,损失率下降,这是完美的表现。

在这里插入图片描述

五. 模型可视化批量测试

在这里插入图片描述
编写可视化批量测试工具类:

import keras
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import FancyBboxPatchfrom utils.dataset_loader import DatasetLoader"""
模型工具类
"""class ModelUtil:def __init__(self, saved_model_dir, path_url):self.save_model_dir = saved_model_dir  # savedModel 模型保存地址self.path_url = path_url  # 模型训练数据地址# 批量识别 进行可视化显示def batch_evaluation(self, class_mode='categorical', image_size=(224, 224), num_images=25):dataset_loader = DatasetLoader(self.path_url, image_size=image_size, class_mode=class_mode)train_ds, val_ds, test_ds, class_names = dataset_loader.load_data()# 加载savedModel模型tfs_layer = keras.layers.TFSMLayer(self.save_model_dir)# 创建一个新的 Keras 模型,包含 TFSMLayermodel = keras.Sequential([keras.Input(shape=image_size + (3,)),  # 根据你的模型的输入形状tfs_layer])plt.figure(figsize=(10, 10))for images, labels in test_ds.take(1):# 使用模型进行预测outputs = model.predict(images)for i in range(num_images):plt.subplot(5, 5, i + 1)image = np.array(images[i]).astype("uint8")plt.imshow(image)index = int(np.argmax(outputs[i]))prediction = outputs[i][index]percentage_str = "{:.2f}%".format(prediction * 100)plt.title(f"{class_names[index]}: {percentage_str}")plt.axis("off")plt.subplots_adjust(hspace=0.5, wspace=0.5)plt.show()

使用工具类:

if __name__ == '__main__':# train()model_util = ModelUtil(SAVED_MODEL_DIR, PATH_URL)model_util.batch_evaluation()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44717.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

物联网系统中市电电量计量方案(一)

为什么要进行电量计量? 节约资源:电量计量可以帮助人们控制用电量,从而达到节约资源的目的。在当前严峻的资源供应形势下,节约能源是我们应该重视的问题。合理计费:电表可以帮助公共事业单位进行合理计费,…

32. 小批量梯度下降法(Mini-batch Gradient Descent)

在深度学习模型的训练过程中,梯度下降法是最常用的优化算法之一。我们前面介绍了批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent),两者各有优缺点。为了在计算速度和…

QT跨平台开发(windows、mac)中.pro文件设置

方法一: 在配置前面加上平台标识符的前缀 # windows win32:INCLUDEPATH F:/Dev/ffmpeg-4.3.2/include win32:LIBS -LF:/Dev/ffmpeg-4.3.2/lib \-lavcodec \-lavdevice \-lavfilter \-lavformat \-lavutil \-lpostproc \-lswscale \-lswresample# mac macx:INCLUD…

预期功能的必要性与典型案例解析——MUNIK

前言 随着汽车行业的不断发展,人们已经不再满足车辆仅仅作为提高出行效率的简单工具,希望能有有更“聪明的车辆”帮用户解决一部分驾驶带来的困扰。因此,车企们不断探索自动驾驶能够带给人们哪些更便利的解决方案。在这个过程中不可避免地将…

3.相机标定原理及代码实现(opencv)

1.相机标定原理 相机参数的确定过程就叫做相机标定。 1.1 四大坐标系及关系 (1)像素坐标系(单位:像素(pixel)) 像素坐标系是指相机拍到的图片的坐标系,以图片的左上角为坐标原点&a…

为校园后勤注入智慧:收件登记功能驱动全新体验

在智慧校园的后勤管理体系中,收件登记服务是一项旨在提升快递接收体验的创新举措,它无缝融合了现代科技与日常校园生活,为师生带来便捷与安心。 为应对日益增长的快递需求,师生可事先通过校园网平台或特制的移动应用预报快递信息&…

《Linux与Windows文件系统的区别》

Linux与Windows文件系统的区别 在计算机操作系统领域,Linux和Windows是两种广泛使用的操作系统,它们在文件系统方面有许多显著的差异。这篇博客将详细介绍这两种操作系统文件系统的区别,帮助读者更好地理解它们各自的特点和优势。 类别Linu…

光学传感器图像处理流程(二)

光学传感器图像处理流程(二) 2.4. 图像增强2.4.1. 彩色合成2.4.2 直方图变换2.4.3. 密度分割2.4.4. 图像间运算2.4.5. 邻域增强2.4.6. 主成分分析2.4.7. 图像融合 2.5. 裁剪与镶嵌2.5.1. 图像裁剪2.5.2. 图像镶嵌 2.6. 遥感信息提取2.6.1. 目视解译2.6.2…

数字化时代的供应链管理综合解决方案

目录 引言背景与意义供应链管理综合解决方案的目标 📄供应链管理系统主要功能系统优势 📄物流管理系统主要功能系统优势 📄订单管理系统主要功能应用场景 📄仓储管理系统系统亮点主要功能系统优势 📄商城管理系统主要功…

如何安全使用代理ip

1、选择可靠的代理服务提供商:选择知名的、信誉良好的代理服务提供商,避免使用免费的代理服务,因为免费的代理服务可能存在安全隐患。 2、使用HTTPS代理:使用HTTPS代理可以加密你的网络流量,保护你的隐私和安全。 3、…

【python】QWidget父子关系,控件显示优先级原理剖析与应用实战演练

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

socks4 socks4a socks5 socks5h的区别

1、socks4 socks4a socks5 socks5h的区别 代理设置区别curl https://www.google.com -x 127.0.0.1:1080等于http://127.0.0.1:1080curl https://www.google.com -x http://127.0.0.1:1080http代理,代理端服务器完成 DNS 解析curl https://www.google.com -x https:…

又是三道简单的web题(2)

一、cookie 1.打开后是如下页面,抓包,关注cookie 2.发现cookie中有一个文件 3.直接访问这个文件,得到flag 二、employeeswork 打开后页面如下: 点击后出现一串php代码 审一下这个代码,需要添加参数work并且赋值work…

Linux笔记之使用系统调用sendfile高速拷贝文件

Linux笔记之使用系统调用sendfile高速拷贝文件 code review! 文章目录 Linux笔记之使用系统调用sendfile高速拷贝文件sendfile 性能优势sendfile 系统调用优点:缺点: cp 命令优点:缺点: 实际测试:拷贝5.8个G的文件&a…

【Vue3】export, import, export default

export对外输出: export var name "mike"; //导出多个变量 export {name1, name2}import导入: import {name} from "/.a.js" //引入多个变量 import {name1, name2} from "/.a.js"export default为模块指定默认输出&am…

合合信息大模型加速器亮相WAIC大会:文档解析与文本识别新突破

合合信息大模型加速器亮相WAIC大会:文档解析与文本识别新突破 文章目录 合合信息大模型加速器亮相WAIC大会:文档解析与文本识别新突破前言合合信息TextIn平台:智能文档处理的领军者文档解析引擎:百页文档秒级处理大模型的发展背景…

vue vite自动化路由 无需手动配置

vue vite自动化路由 测试某些功能或者框架以及库的时候 需要创建新vue页面 没次都有手动配置 仅仅测试 细化的话根据自己需求配置权限 这里方便点 直接把router文件删掉 直接在main.js 引入所有路由注册 这样 每次在views下创建一个vue文件 直接访即可 不用手动注册了 main.js …

C#的using IDisposable 接口的使用介绍

IDisposable 接口在C#中的主要作用是提供一种用于释放非托管资源的机制。非托管资源包括文件句柄、数据库连接、网络连接、COM组件等,它们不受.NET运行时管理,需要显式释放以避免资源泄漏和提高性能。 使用 IDisposable 接口的主要步骤包括: 实现 IDisposable 接口: 在类中…

【漏洞复现】Crocus系统——Download——文件读取

声明:本文档或演示材料仅供教育和教学目的使用,任何个人或组织使用本文档中的信息进行非法活动,均与本文档的作者或发布者无关。 文章目录 漏洞描述漏洞复现测试工具 漏洞描述 Crocus系统旨在利用人工智能、高清视频、大数据和自动驾驶技术&…