ChatGPT对话:关于训练模型h5格式和SavedModel格式的问题

【编者按】有几天没有更新,一直在学习和调试程序,其中有大量与ChatGPT的对话,遇到的每一个问题都找ChatGPT,再也没有用到百度和google。在ChatGPT的帮助下已完成Python下的花卉识别训练与预测,部署到网站,也可以通过网页识别花卉图片。移植到微信小程序中还有一点问题(在模型加载时,遇到微信小程序不支持“fetch”的问题)。在编写与调试过程中,遇到的坑太多,与ChatGPT对话太多,无法一一呈现,与大家分享,后续会选择有代表性的知识点来分享。针对踩坑的解决方法,后续再总结为一篇文章来呈现,不以围观ChatGPT的方式呈现。
下面是有关h5格式与SavedModel格式的问题。


1问:model.predict()与 model.signatures[‘serving_default’]、infer(),这两种方法有什么区别?(说明:在ChatGPT推荐的代码中涉及两种方法,所以有此一问,同时引出了模型格式的问题)

ChatGPT

在TensorFlow中,model.predict()model.signatures['serving_default'](或infer())是两种不同的方法来进行推理(推断)或预测。它们的区别主要在于使用场景和底层实现。

model.predict()

  • 适用场景: 这是Keras模型的一个方法,主要用于在本地对数据进行预测。
  • 使用简便: 该方法非常直观且易于使用,特别适合在训练和评估模型后进行批量预测。
  • 输入数据格式: 需要符合模型的输入形状和类型,通常是一个NumPy数组或一个tf.data.Dataset对象。
  • 底层实现: 这是一个高层次的Keras API,内部处理了很多细节,如数据预处理和批处理。

示例:

import tensorflow as tf
import numpy as np# 假设model已经定义和训练好
# 输入数据
data = np.random.random((10, 224, 224, 3))
predictions = model.predict(data)

model.signatures['serving_default']infer()

  • 适用场景: 这种方法主要用于TensorFlow Serving或导出SavedModel后进行推理。
  • 灵活性: 提供了更多的控制和灵活性,可以自定义输入和输出的签名。
  • 输入数据格式: 输入是一个字典,键为签名中定义的输入张量名称,值为相应的张量。
  • 底层实现: 这是一个低层次的API,通常在模型导出后用于TensorFlow Serving时使用。

示例:

import tensorflow as tf
import numpy as np# 假设model已经定义和训练好,并导出为SavedModel
# 加载模型
loaded_model = tf.saved_model.load('path_to_saved_model')# 获取签名
infer = loaded_model.signatures['serving_default']# 输入数据
input_data = np.random.random((10, 224, 224, 3)).astype(np.float32)
tensor_input_data = tf.convert_to_tensor(input_data)# 构建输入字典
input_dict = {'input_1': tensor_input_data}# 推理
output_dict = infer(**input_dict)
predictions = output_dict['output_1'].numpy()

主要区别

  1. 接口和使用简便性:

    • model.predict() 是一个高层次的API,适合快速和方便地进行预测。
    • model.signatures['serving_default'] 提供了更低层次的控制,适合在模型部署或需要自定义签名的情况下使用。
  2. 输入和输出的处理:

    • model.predict() 处理输入和输出非常方便,用户只需提供符合要求的输入数据格式。
    • model.signatures['serving_default'] 需要用户手动构建输入和输出字典,这使得其更灵活但也稍微复杂一些。
  3. 使用场景:

    • model.predict() 主要用于本地模型预测,训练和评估后的快速推理。
    • model.signatures['serving_default'] 主要用于TensorFlow Serving或模型导出后的远程推理。
  4. 性能和优化:

    • 在某些情况下,model.signatures['serving_default'] 可能会提供更好的性能,因为它可以利用TensorFlow Serving的优化和批处理能力。

总之,选择哪种方法取决于你的使用场景。如果是在本地进行快速预测,model.predict() 是更好的选择。如果是需要部署模型并在生产环境中进行推理,使用 model.signatures['serving_default'] 会更合适。


2问:与训练模型格式有关吗?如h5,或SavedModel

ChatGPT

是的,model.predict()model.signatures['serving_default'] 在某种程度上与模型的保存格式有关。具体来说,它们对 h5SavedModel 格式的支持有所不同:

model.predict()

  • 适用格式: 主要适用于 h5SavedModel 格式的 Keras 模型。
  • 加载模型: 无论是 h5 还是 SavedModel 格式的模型,都可以使用 tf.keras.models.load_model() 来加载。
  • 使用方式: 无论模型以何种格式保存,加载后都可以直接调用 model.predict() 进行预测。

示例:

import tensorflow as tf
import numpy as np# 加载h5格式的模型
model_h5 = tf.keras.models.load_model('model.h5')# 加载SavedModel格式的模型
model_saved = tf.keras.models.load_model('saved_model_directory')# 输入数据
data = np.random.random((10, 224, 224, 3))# 进行预测
predictions_h5 = model_h5.predict(data)
predictions_saved = model_saved.predict(data)

model.signatures['serving_default']

  • 适用格式: 主要适用于 SavedModel 格式的模型。
  • 加载模型: 使用 tf.saved_model.load() 来加载 SavedModel 格式的模型。
  • 使用方式: 需要访问模型的签名,并根据签名的要求准备输入和输出。

示例:

import tensorflow as tf
import numpy as np# 加载SavedModel格式的模型
loaded_model = tf.saved_model.load('saved_model_directory')# 获取签名
infer = loaded_model.signatures['serving_default']# 输入数据
input_data = np.random.random((10, 224, 224, 3)).astype(np.float32)
tensor_input_data = tf.convert_to_tensor(input_data)# 构建输入字典
input_dict = {'input_1': tensor_input_data}  # 注意这里的 'input_1' 需要与模型签名匹配# 推理
output_dict = infer(**input_dict)
predictions = output_dict['output_1'].numpy()  # 注意这里的 'output_1' 需要与模型签名匹配

总结

  • model.predict():

    • 适用于 h5SavedModel 格式的 Keras 模型。
    • 加载模型后直接使用,接口简单方便。
    • 主要用于本地推理和评估。
  • model.signatures['serving_default']:

    • 主要适用于 SavedModel 格式的模型。
    • 需要通过签名访问和使用,更灵活但也更复杂。
    • 主要用于部署和远程推理,如在 TensorFlow Serving 中使用。

因此,模型保存格式确实会影响你选择哪种方法来进行推理。对于 h5 格式的模型,只能使用 model.predict() 方法,而对于 SavedModel 格式的模型,你可以根据需要选择使用 model.predict()model.signatures['serving_default'] 方法。


3问:把h5与SavedModel模型转换为tensorflowjs后,是否还有model.predict()与model.signatures[‘serving_default’]区别呢?在web中使用哪个更好呢?

ChatGPT

将 Keras 模型(无论是 h5 格式还是 SavedModel 格式)转换为 TensorFlow.js 格式后,主要使用 TensorFlow.js 提供的 API 进行推理。在 TensorFlow.js 中,使用模型进行推理的方式和在 TensorFlow(Python)中有所不同。

TensorFlow.js 中的模型推理

在 TensorFlow.js 中,加载和使用模型的方式主要有以下几种:

  1. 使用 tf.loadLayersModel()model.predict()
  2. 使用 tf.loadGraphModel()model.execute()

tf.loadLayersModel()model.predict()

  • 适用模型: 主要适用于 Keras h5 和通过 tfjs.converters.save_keras_model 转换的模型。
  • 使用方式: 类似于在 Python 中的 Keras 模型,加载后可以直接调用 model.predict() 进行推理。

示例:

// 加载Keras模型
const model = await tf.loadLayersModel('path/to/model.json');// 准备输入数据
const input = tf.tensor4d(data, [1, 224, 224, 3]);// 进行预测
const predictions = model.predict(input);
predictions.print();

tf.loadGraphModel()model.execute()

  • 适用模型: 主要适用于 TensorFlow SavedModel 和通过 tfjs.converters.convert_tf_saved_model 转换的模型。
  • 使用方式: 使用 model.execute() 进行推理,类似于在 Python 中使用 model.signatures['serving_default']

示例:

// 加载Graph模型
const model = await tf.loadGraphModel('path/to/model.json');// 准备输入数据
const input = tf.tensor4d(data, [1, 224, 224, 3]);// 构建输入字典
const inputDict = { 'input_tensor_name': input }; // 'input_tensor_name'需要与模型的输入张量名称匹配// 进行推理
const output = model.execute(inputDict, ['output_tensor_name']); // 'output_tensor_name'需要与模型的输出张量名称匹配
output.print();

选择哪个更好

在 Web 中选择哪个方法更好取决于以下因素:

  1. 模型类型:

    • 如果模型是 Keras h5 模型或通过 tfjs.converters.save_keras_model 转换的模型,使用 tf.loadLayersModel()model.predict() 更加简便。
    • 如果模型是 TensorFlow SavedModel 或通过 tfjs.converters.convert_tf_saved_model 转换的模型,使用 tf.loadGraphModel()model.execute() 更合适。
  2. 使用场景:

    • 简单推理: 对于简单的推理任务,使用 tf.loadLayersModel()model.predict() 更加直观和方便。
    • 复杂推理: 如果需要自定义输入和输出张量,或者使用特定的签名,tf.loadGraphModel()model.execute() 提供了更多的灵活性。
  3. 性能:

    • 一般来说,两者在性能上没有显著差异。选择时更多考虑的是模型类型和使用的简便性。

总结

在 Web 中使用 TensorFlow.js 进行模型推理时:

  • 如果模型是 Keras h5 或通过 tfjs.converters.save_keras_model 转换的模型,建议使用 tf.loadLayersModel()model.predict()
  • 如果模型是 TensorFlow SavedModel 或通过 tfjs.converters.convert_tf_saved_model 转换的模型,建议使用 tf.loadGraphModel()model.execute()

无论选择哪种方式,都要确保输入和输出的张量名称与模型签名一致。这样可以确保模型能够正确地进行推理。


4问:我是使用tensorflowjs_converter命令把h5和SavedModel转换为tensorflowjs模型的,没有涉及你提到的tfjs.converters.save_keras_model和tfjs.converters.convert_tf_saved_model。阐述一下这2个模块

ChatGPT

在 TensorFlow.js 中,tfjs.converters 模块提供了用于在 TensorFlow 和 TensorFlow.js 之间进行模型转换的工具。以下是对 tfjs.converters.save_keras_modeltfjs.converters.convert_tf_saved_model 两个函数的详细说明:

tfjs.converters.save_keras_model

这个函数用于将一个 Keras 模型保存为 TensorFlow.js 可用的格式。它在 Python 环境中运行,将现有的 Keras 模型(无论是 .h5 格式还是其他格式的 Keras 模型)转换为 TensorFlow.js 格式。

使用方法:

import tensorflow as tf
import tensorflowjs as tfjs# 假设model是一个已经定义和训练好的Keras模型
model = tf.keras.models.load_model('model.h5')# 将模型转换并保存为TensorFlow.js格式
tfjs.converters.save_keras_model(model, 'path/to/save/model')

示例:

import tensorflow as tf
import tensorflowjs as tfjs# 创建一个简单的Keras模型
model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation='relu', input_shape=(784,)),tf.keras.layers.Dense(10, activation='softmax')
])# 编译和训练模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# model.fit(...)# 将模型转换为TensorFlow.js格式
tfjs.converters.save_keras_model(model, 'path/to/save/model')

tfjs.converters.convert_tf_saved_model

这个函数用于将一个 TensorFlow SavedModel 格式的模型转换为 TensorFlow.js 格式。它同样在 Python 环境中运行,将 TensorFlow SavedModel 模型转换为 TensorFlow.js 格式。

使用方法:

import tensorflowjs as tfjs# 将SavedModel转换并保存为TensorFlow.js格式
tfjs.converters.convert_tf_saved_model('path/to/saved_model', 'path/to/save/model')

示例:

import tensorflow as tf
import tensorflowjs as tfjs# 假设已经有一个保存的SavedModel
# 将SavedModel加载
model = tf.saved_model.load('path/to/saved_model')# 将SavedModel转换为TensorFlow.js格式
tfjs.converters.convert_tf_saved_model('path/to/saved_model', 'path/to/save/model')

tensorflowjs_converter 命令行工具

你提到使用 tensorflowjs_converter 命令行工具将 Keras .h5 模型和 SavedModel 转换为 TensorFlow.js 格式。这实际上是利用了 tfjs.converters 模块的功能,但以命令行的方式进行操作。

转换 Keras .h5 模型:

tensorflowjs_converter --input_format keras path/to/model.h5 path/to/save/model

转换 TensorFlow SavedModel:

tensorflowjs_converter --input_format=tf_saved_model --saved_model_tags=serve path/to/saved_model path/to/save/model

在 Web 中使用转换后的模型

不管你是使用 tfjs.converters.save_keras_modeltfjs.converters.convert_tf_saved_model,还是 tensorflowjs_converter 命令行工具来进行转换,转换后的模型在 Web 中的使用方式是相同的。

加载和使用 Keras 模型:

// 使用 tf.loadLayersModel() 加载模型
const model = await tf.loadLayersModel('path/to/model.json');// 准备输入数据
const input = tf.tensor4d(data, [1, 224, 224, 3]);// 进行预测
const predictions = model.predict(input);
predictions.print();

加载和使用 SavedModel:

// 使用 tf.loadGraphModel() 加载模型
const model = await tf.loadGraphModel('path/to/model.json');// 准备输入数据
const input = tf.tensor4d(data, [1, 224, 224, 3]);// 构建输入字典
const inputDict = { 'input_tensor_name': input }; // 'input_tensor_name'需要与模型的输入张量名称匹配// 进行推理
const output = model.execute(inputDict, ['output_tensor_name']); // 'output_tensor_name'需要与模型的输出张量名称匹配
output.print();

总结

  • tfjs.converters.save_keras_model: 用于在 Python 环境中将 Keras 模型转换为 TensorFlow.js 格式。
  • tfjs.converters.convert_tf_saved_model: 用于在 Python 环境中将 TensorFlow SavedModel 转换为 TensorFlow.js 格式。
  • tensorflowjs_converter: 命令行工具,提供了将 Keras .h5 模型和 TensorFlow SavedModel 转换为 TensorFlow.js 格式的便捷方法。

在 Web 中,使用 tf.loadLayersModel() 进行预测更适合简单的 Keras 模型,而使用 tf.loadGraphModel() 则更适合复杂的 TensorFlow SavedModel。选择哪种方法取决于你转换的模型类型以及使用场景。


至此,清楚了模型格式的一些问题。最后,编者选择了在训练的Python代码中直接保存为Tensorflow.js格式,而不需要在使用tensorflowjs_converter来转换,这样,部署到网站或微信小程序就更方便了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/51001.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计划外投料

计划外领料通过A07 R10发料/其他,进行计划外投料。移动类型仍然是261。 在科目分配页签。维护计划外投料的工单号。 在实际业务中,有的企业也会单独复制出一个移动类型。进行区分。 在COOIS中,有张表可以看出物料是计划内投料还是计划外。 …

MySQL查询执行(三):显示随机消息

假设有如下表结构&#xff1a; -- 创建表words CREATE TABLE words (id int(11) NOT NULL AUTO_INCREMENT,word varchar(64) DEFAULT NULL,PRIMARY KEY (id) ) ENGINEInnoDB;--数据生成存储过程 delimiter ;; create procedure idata() begindeclare i int;set i0;while i<…

[工具]GitHub + PicGo 搭建免费博客图床

文章目录 起因GitHub新建GitHub仓库新建token授予picgo权限 PicGOPicGO上传失败原因 起因 还是觉得个人博客记录最好还是不要money&#x1f625;&#xff0c;所以还是想白嫖&#xff0c;找到了GitHub PicGO的方式&#xff0c;记录一下。 GitHub 过程和搭建博客链接类似&…

DDOS攻击学习 - kali初学

文章目录 本地ssh配置nmap(网络连接的工具)nmap -sP IP地址nmap -p 1-65535 -A IP地址主机发现Ping扫描端口扫描时序扫描常用扫描方式指纹识别与探测全端口版本探测防火墙/IDS逃逸报文分段信息收集IP信息收集WHOIS查询数据库渗透测试MySQL列举数据库列举MySQL变量发起请求目录扫…

PostgreSQL的pg-collector工具

PostgreSQL的pg-collector工具 pg-collector 是一个用于 PostgreSQL 数据库的监控和数据收集工具。它主要用于收集 PostgreSQL 实例的性能指标、查询统计和日志信息&#xff0c;以便进行数据库性能分析和故障排查。通过收集这些数据&#xff0c;管理员可以更好地了解数据库的运…

day3 测试基础知识

1. 你认为性能测试的目的是什么&#xff1f;做好性能测试的工作的关键是什么&#xff1f; 性能测试工作的目的是检查系统是否满足在需求说明书中规定的性能&#xff0c;性能测试常常需要和强度测试结合起来&#xff0c;并常常要求同时进行软件和硬件的检测。 性能测试主要的关…

Laravel视图渲染封装

第一种 app/Helpers/ViewHelper.php 创建一个辅助函数&#xff0c;用于动态确定视图路径&#xff1a; <?php if (!function_exists(fetchView)) {function fetchView($data []){$currentAction \Route::currentRouteAction();list($controller, $method) explode(, $c…

关于SpringBoot项目利用阿里EasyExcel快捷导入Excel文件入库初始化数据的简单实现

一、问题描述 无论新项目还是旧项目&#xff0c;都会出现数据维护、数据初始化等操作&#xff0c;手动录显然很low(领导会骂你)&#xff0c;所以一般采用批量导入导出。这里你还在用原始读取excel逐行逐列去读取吗&#xff1f;2024了ok&#xff1f;利用工具是我们cv大师的一贯…

在 OpenEuler24.03 源码安装 PG16.3

在ANOLIS 23上源码安装了16.1&#xff0c;在OpenEuler24.03上PG16.3&#xff0c;安装也是一样的吗&#xff1f; 抱着这样的态度&#xff0c;我试​着去安装&#xff0c;如果不关闭SELINUX&#xff0c;还是有一个差异的&#xff0c;同时&#xff0c;发现即使是最小安装&#xf…

ElasticSearch学习篇15_《检索技术核心20讲》进阶篇之TopK检索

背景 学习极客实践课程《检索技术核心20讲》https://time.geekbang.org/column/article/215243&#xff0c;文档形式记录笔记。 相关问题&#xff1a; ES全文检索是如何进行相关性打分的&#xff1f;ES中计算相关性得分的时机?如何加速TopK检索&#xff1f;三种思路 精准To…

GEE:设置ui.Map.Layer上交互矢量边界填充颜色为空,只显示边界

一、目标 最近在GEE的交互功能鼓捣一些事情&#xff0c;在利用buffer功能实现了通过选点建立一个矩形后&#xff0c;需要将该矩形填充颜色设为空&#xff0c;只留边界。 然而通过正常设置layer的可视化参数并不能实现这一目的。因此只能另辟蹊径&#xff0c;改为定义矢量边界…

【JavaEE精炼宝库】 网络编程套接字——初识网络编程 | UDP数据报套接字编程

文章目录 一、网络编程基础1.1 网络编程的意义&#xff1a;1.2 网络编程的概念&#xff1a;1.3 网络编程的术语解释&#xff1a;1.4 常见的客户端服务端模型&#xff1a; 二、Socket 套接字2.1 Socket 套接字的概念&#xff1a;2.2 Socket 套接字的分类&#xff1a; 三、UDP数据…

fetchApi === 入门篇

目录 fetch 基本认知 fetch 如何使用 Response对象&#xff08;了解&#xff09; 常见属性 常见方法 fetch 配置参数 fetch发送post请求 fetch 函数封装 fetch 实战 - 图书管理案例 渲染功能 添加功能 删除数据 完整代码 fetch 基本认知 思考&#xff1a; 以前开发…

NumpyPandas:Pandas库(25%-50%)

目录 前言 一、列操作 1.修改变量列 2.筛选变量列 3.删除变量列 4.添加变量列 二、数据类型的转换 1.查看数据类型 2.将 ok的int类型转换成float类型 3.将ar的float类型转换成int类型 三、建立索引 1.建立DataFrame时建立索引 2.在读入数据时建立索引 3.指定某列或…

URL参数中携带中文?分享 1 段优质 JS 代码片段!

本内容首发于工粽号&#xff1a;程序员大澈&#xff0c;每日分享一段优质代码片段&#xff0c;欢迎关注和投稿&#xff01; 大家好&#xff0c;我是大澈&#xff01; 本文约 800 字&#xff0c;整篇阅读约需 1 分钟。 今天分享一段优质 JS 代码片段&#xff0c;在发送 ajax 请…

virtualbox ubuntu扩充磁盘大小

首先在虚拟存储管理里面修改磁盘大小 然后安装gparted sudo gparted 打开管理工具 选中要调整的区域右键选择调整区域大小 拖动上述位置就可以实现扩容。完成后点击应用 然后重启虚拟机即可。

基于深度学习技术及强大的专家团队,针对多个工业垂类场景进行算法优化的智慧城管开源了。

智慧城管视觉监控平台是一款功能强大且简单易用的实时算法视频监控系统。它的愿景是最底层打通各大芯片厂商相互间的壁垒&#xff0c;省去繁琐重复的适配流程&#xff0c;实现芯片、算法、应用的全流程组合&#xff0c;从而大大减少企业级应用约95%的开发成本。 基于深度学习技…

IMU 科 普 o(* ̄︶ ̄*)o

惯性测量单元(Inertial measurement unit&#xff0c;IMU)&#xff0c;是测量物体三轴姿态角速率以及加速度的装置。 IMU通常包含陀螺仪 (Gyroscope)、加速度计 (Accelermeters)&#xff0c;有的还包含磁力计(Magnetometers)和 气压计(barometer)。 几种陀螺仪的类型 1 机械…

微服务的入门

带着问题进行学习&#xff1a; 1. 对服务进行拆分后&#xff0c;物理上是隔离的&#xff0c;数据上也是隔离的&#xff0c;如何进行不同服务之间进行访问呢&#xff1f; 2.前端是怎么样向后端发送请求的&#xff1f; 通过http请求&#xff0c;通过url&#xff0c;请求的…

wireshark--流量分析利器

&#x1f3bc;个人主页&#xff1a;金灰 &#x1f60e;作者简介:一名简单的大一学生;易编橙终身成长社群的嘉宾.✨ 专注网络空间安全服务,期待与您的交流分享~ 感谢您的点赞、关注、评论、收藏、是对我最大的认可和支持&#xff01;❤️ &#x1f34a;易编橙终身成长社群&#…