【TensorFlow2 之013】TensorFlow-Lite

一、说明

        在这篇文章中,我们将展示如何构建计算机视觉模型并准备将其部署在移动和嵌入式设备上。有了这些知识,您就可以真正将脚本部署到日常使用或移动应用程序中。

        教程概述:

  1. 介绍
  2. 在 TensorFlow 中构建模型
  3. 将模型转换为 TensorFlow Lite
  4. 训练后量化

二、前提知识

      上次,我们展示了如何使用迁移学习来提高模型性能。但是,当我们可以在智能手机或其他嵌入式设备上使用我们的模型时,为什么我们只使用我们的模型来预测计算机上猫或狗的图像呢?

        TensorFlow Lite 是 TensorFlow 针对移动和嵌入式设备的轻量级解决方案。它使我们能够在移动设备上以低延迟快速运行机器学习模型,而无需访问服务器。

        它可以通过 C++ API 在 Android 和 iOS 设备上使用,也可以与 Android 开发人员的 Java 包装器一起使用。

三、 在 TensorFlow 中构建模型

        在开始使用 TensorFlow Lite 之前,我们需要训练一个模型。我们将使用台式计算机或云平台等更强大的机器从一组数据训练模型。然后,可以导出该模型并在移动设备上使用。

        让我们首先准备训练数据集。我们将使用 wget .download 命令下载数据集。之后,我们需要解压缩它并合并训练和测试部分的路径。

import os
import wget
import zipfile wget.download("https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip")
100% [........................................................................] 68606236 / 68606236

Out[2]:

'cats_and_dogs_filtered.zip'
with zipfile.ZipFile("cats_and_dogs_filtered.zip","r") as zip_ref:zip_ref.extractall()base_dir = 'cats_and_dogs_filtered'train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

        

        接下来,让我们导入所有必需的库并构建我们的模型。我们将使用一个名为MobileNetV2的预训练网络 ,该网络是在 ImageNet 数据集上进行训练的。之后,我们需要冻结预训练层并添加一个名为GlobalAveragePooling2D 的新层,然后是具有 sigmoid 激活函数的密集层。我们将使用图像数据生成器来迭代图像。

base_model = MobileNetV2(input_shape=(224, 224, 3),include_top=False,weights='imagenet')base_model.trainable = False
model = Sequential([base_model,GlobalAveragePooling2D(),Dense(1, activation='sigmoid')])
model.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=0.001),metrics=['accuracy'])train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(224, 224),batch_size=32,class_mode='binary')validation_generator = val_datagen.flow_from_directory(validation_dir,target_size=(224, 224),batch_size=32,class_mode='binary')
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
history = model.fit(train_generator,epochs=6,validation_data=validation_generator,verbose=2)
Train for 63 steps, validate for 32 steps
Epoch 1/6
63/63 - 470s - loss: 0.3610 - accuracy: 0.8515 - val_loss: 0.1664 - val_accuracy: 0.9460
Epoch 2/6
63/63 - 139s - loss: 0.2055 - accuracy: 0.9210 - val_loss: 0.2065 - val_accuracy: 0.9150
Epoch 3/6
63/63 - 89s - loss: 0.1507 - accuracy: 0.9470 - val_loss: 0.1294 - val_accuracy: 0.9560
Epoch 4/6
63/63 - 81s - loss: 0.1342 - accuracy: 0.9510 - val_loss: 0.1733 - val_accuracy: 0.9350
Epoch 5/6
63/63 - 80s - loss: 0.1205 - accuracy: 0.9555 - val_loss: 0.1684 - val_accuracy: 0.9390
Epoch 6/6
63/63 - 82s - loss: 0.1079 - accuracy: 0.9620 - val_loss: 0.1418 - val_accuracy: 0.9450

四、 将模型转换为 TensorFlow Lite

        在 TensorFlow 中训练深度神经网络后,我们可以将其转换为 TensorFlow Lite。

        这里的想法是使用提供的 TensorFlow Lite 转换器来转换模型并生成 TensorFlow Lite FlatBuffer文件。此类文件的扩展名是.tflite

        可以直接从经过训练的 tf.keras 模型、保存的模型或具体的 TensorFlow 函数进行转换,这也是可能的。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()open("converted_model.tflite", "wb").write(tflite_model)

        转换后,FlatBuffer 文件可以部署到客户端设备。在客户端设备上,它可以使用 TensorFlow Lite 解释器在本地运行。

https://www.tensorflow.org/lite/convert

        让我们看看如何使用TensorFlow Lite Interpreter直接从 Python 运行该模型。

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

        我们可以通过运行get_input_detailsget_output_details函数来获取模型的输入和输出详细信息。我们还可以调整输入和输出的大小,这样我们就可以对整批图像进行预测。

input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])
== Input details ==
shape: [  1 224 224   3]
type: <class 'numpy.float32'>== Output details ==
shape: [1 1]
type: <class 'numpy.float32'>
interpreter.resize_tensor_input(input_details[0]['index'], (32, 224, 224, 3))
interpreter.resize_tensor_input(output_details[0]['index'], (32, 5))
interpreter.allocate_tensors()input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])
== Input details ==
shape: [ 32 224 224   3]
type: <class 'numpy.float32'>== Output details ==
shape: [32  1]
type: <class 'numpy.float32'>

        接下来,我们可以生成一批图像,在我们的例子中为 32 个,因为这是之前图像数据生成器中使用的数量。

        使用 TensorFlow Lite 进行预测的第一步是设置模型的输入,在我们的示例中是一批 32 张图像。设置输入张量后,我们可以调用invoke 命令来调用解释器。输出预测将通过调用get_tensor命令获得。

        我们来对比一下 TensorFlow Lite 模型和 TensorFlow 模型的结果。它们应该是相同的。我们将使用np.testing.assert_almost_equal函数来完成此操作。如果结果在一定的小数位数内不相同,则会引发错误。

mage_batch, label_batch = next(iter(validation_generator))
interpreter.set_tensor(input_details[0]['index'], image_batch)
interpreter.invoke()
tflite_results = interpreter.get_tensor(output_details[0]['index'])tf_results = model.predict(image_batch)for tf_result, tflite_result in zip(tf_results, tflite_results):np.testing.assert_almost_equal(tf_result, tflite_result, decimal=5)

        到目前为止没有错误,所以它向我们表明此转换已成功完成。但这是有阶级概率的。让我们编写一个函数将其转换为类。该函数将执行与 TensorFlow 中内置函数相同的操作。此外,我们还将展示这些预测。红色将显示与主 TensorFlow 模型不同的预测,蓝色图像将显示两个模型做出相同的预测。作为标签,我们将显示 TensorFlow Lite 模型的预测。

def predict_classes(tf_model, tf_lite_interpreter, x):tf_lite_interpreter.set_tensor(input_details[0]['index'], x)tf_lite_interpreter.invoke()tflite_results = tf_lite_interpreter.get_tensor(output_details[0]['index'])tf_results = model.predict_classes(x)if tflite_results.shape[-1] > 1:tflite_results = tflite_results.argmax(axis=-1)else:tflite_results = (tflite_results > 0.5).astype('int32')plt.figure(figsize=(12,10))for n in range(32):plt.subplot(6,6,n+1)plt.subplots_adjust(hspace = 0.3)plt.imshow(image_batch[n])color = "blue" if tf_results[n] == tflite_results[n] else "red"plt.title(tflite_results[n], color=color)plt.axis('off')_ = plt.suptitle("Model predictions (blue: same, red: different)")predict_classes(model, interpreter, image_batch)

模型预测 - 猫与狗

五、训练后量化

        这非常简单,但某些模型并未针对在低功耗设备上运行进行优化。因此,为了解决这个问题,我们需要进行量化。

        训练后量化是一种转换技术,可以减小模型大小,同时改善 CPU 和硬件加速器延迟,而模型精度几乎没有下降。

        下表显示了三种技术及其优点,以及可以运行该技术的硬件。

技术好处硬件
权重量化缩小 4 倍,加速 2-3 倍,准确度提高中央处理器
全整数量化缩小 4 倍,加速 3 倍以上CPU、边缘TPU等
Float16 量化体积缩小 2 倍,具有潜在的 GPU 加速能力中央处理器/图形处理器

        让我们看看它是如何进行权重量化的,这是默认的。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()open("converted_model.tflite", "wb").write(tflite_quant_model)interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()interpreter.resize_tensor_input(input_details[0]['index'], (32, 224, 224, 3))
interpreter.resize_tensor_input(output_details[0]['index'], (32, 5))
interpreter.allocate_tensors()input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()
print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])
== Input details ==
shape: [ 32 224 224   3]
type: <class 'numpy.float32'>== Output details ==
shape: [32  1]
type: <class 'numpy.float32'>

以下决策树可以帮助确定哪种训练后量化方法最适合您的用例。

决策树

        在我们的例子中,我们使用名为 MobileNetV2 的模型,该模型已经针对低功耗设备进行了优化。所以本例中的权重量化会降低模型精度。

        和之前一样,让我们​​看看这次模型的表现如何。红色将显示与主 TensorFlow 模型不同的预测,蓝色图像将显示两个模型做出相同的预测。作为标签,我们将显示 TensorFlow Lite 模型的预测。

predict_classes(model, interpreter, image_batch)

模型预测 2 - 猫与狗

六、总结

        总而言之,在这篇文章中,我们讨论了从 TensorFlow 到 TensorFlow Lite 的转换以及低功耗设备的优化。在下一篇文章中,我们将展示如何在 TensorFlow 2.0 中实现 LeNet-5。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/103502.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Jenkins 结合 ANT 发送测试报告

目录 全局变量配置 新建任务 插件安装 HTML 报告配置 邮件配置 全局变量配置 点击 ManageJenkins进入Jenkins 管理 点击 Global Tool Configuration 进入全局变量配置 配置 Ant &#xff0c;Name 自己定义一个比较好理解的名称。 去掉 Install automatically 勾选&#x…

华为OD机考算法题:找终点

目录 题目部分 解读与分析 代码实现 题目部分 题目找终点难度易题目说明给定一个正整数数组&#xff0c;设为nums&#xff0c;最大为100个成员&#xff0c;求从第一个成员开始&#xff0c;正好走到数组最后一个成员&#xff0c;所使用的最少步骤数。 要求&#xff1a; 1.第…

skywalking动态配置[集成nacos/apollo/consul]

说明:以下配置仅关于的阈值规则的动态配置,其他参数也可以进行配置。 1,skywalking动态配置集成nacos 编辑application.yml nacos配置参数如下: nacos:# Nacos Server HostserverAddr: 10.10.5.145# Nacos Server Portport: 8848# Nacos Configuration Groupgroup: skywal…

2.1、如何在FlinkSQL中读取写出到Kafka

目录 1、环境设置 方式1&#xff1a;在Maven工程中添加pom依赖 方式2&#xff1a;在 sql-client.sh 中添加 jar包依赖 2、读取Kafka 2.1 创建 kafka表 2.2 读取 kafka消息体&#xff08;Value&#xff09; 使用 format json 解析json格式的消息 使用 format csv 解析…

快速学习微服务保护框架--Sentinel

学习一个框架最好的方式就是查看官方地址,sentinel是国内阿里巴巴公司的,官网更方便官网 官网 微服务保护框架 Sentinel 1.初识Sentinel 1.1.雪崩问题及解决方案 1.1.1.雪崩问题 微服务中&#xff0c;服务间调用关系错综复杂&#xff0c;一个微服务往往依赖于多个其它微…

全力以赴,火山引擎边缘云代表团出战亚运会

END 未来&#xff0c;火山引擎边缘云赛事阵容将继续全力以赴&#xff0c;通过领先、可信赖的云和智能技术&#xff0c;助力游戏行业呈现更加精彩的竞技赛事。

光耦合器继电器与传统继电器:哪种最适合您的项目?

在电子和电气工程领域&#xff0c;继电器的选择可以显着影响项目的性能和安全性。两种常见类型的继电器是光耦合器继电器和传统机电继电器。每个都有其优点和缺点&#xff0c;因此选择过程对于项目的成功结果至关重要。 光耦合器继电器&#xff1a;基础知识 光耦合器继电器&…

linux环境下使用lighthouse与selenium

一、安装谷歌浏览器、谷歌浏览器驱动、lighthouse shell脚本 apt update && apt -y upgrade apt install -y curl curl -fsSL https://deb.nodesource.com/setup_18.x | bash apt install -y nodejs apt install -y npm npm install -g lighthouse apt-get install -y …

AutoRunner自动化测试工具

AutoRunner自动化测试工具(简称AR&#xff09;是泽众软件自主研发的自动化测试工具&#xff0c;也是一个自动测试框架&#xff0c;加载不同的测试组件&#xff0c;能够实现面向不同应用的测试。通过录制和编写测试脚本&#xff0c;实现功能测试、回归测试的自动化&#xff0c;自…

记录一个@Transaction注解引发的bug

记录一个Transactional(readOnly true)注解引发的bug 一、问题代码和报错 1-1 问题代码模拟 引发这个问题的三大要素分别是&#xff1a; 事务注解任意数据库操作数据库操作后执行耗时业务&#xff08;耗时超过数据库配置的超时时间&#xff09; //1.这里是问题的核心之一…

“过度炒作”的大模型巨亏,Copilot每月收10刀,倒赔20刀

大模型无论是训练还是使用&#xff0c;都比较“烧钱”&#xff0c;只是其背后的成本究竟高到何处&#xff1f;已经推出大模型商用产品的公司到底有没有赚到钱&#xff1f;事实上&#xff0c;即使微软、亚马逊、Adobe 这些大厂&#xff0c;距离盈利之路还有很远&#xff01;同时…

微信自动批量添加好友的方法

在现在的营销中微信已成为一种重要的沟通方式。微信目前是没有自动批量添加好友的功能&#xff0c;需要运营者一个一个手动去添加&#xff0c;这样太过于浪费时间&#xff0c;并且加频繁了还容易被封号&#xff0c;今天给大家介绍几种手动批量加好友的方式以及怎么借助第三方软…

国窖1573持续演绎共生魅力,携手马岩松个展感知建筑艺术之美

执笔 | 洪大大 编辑 | 萧 萧 从艺术到文化、从需求到场景、从体验到消费&#xff0c;国窖1573正通过一次次尝试与探索实现与多元文化的共创与共生。 10月12日&#xff0c;国窖1573品牌挚友马岩松举办的“流动的大地”展览在深圳当代艺术与城市规划馆正式开幕。泸州老窖股份…

什么是MTU(Maximum Transmission Unit)?

最大传输单元MTU&#xff08;Maximum Transmission Unit&#xff0c;MTU&#xff09;&#xff0c;是指网络能够传输的最大数据包大小&#xff0c;以字节为单位。MTU的大小决定了发送端一次能够发送报文的最大字节数。如果MTU超过了接收端所能够承受的最大值&#xff0c;或者是超…

2023年09月 C/C++(五级)真题解析#中国电子学会#全国青少年软件编程等级考试

C/C++编程(1~8级)全部真题・点这里 Python编程(1~6级)全部真题・点这里 第1题:红与黑 有一间长方形的房子,地上铺了红色、黑色两种颜色的正方形瓷砖。你站在其中一块黑色的瓷砖上,只能向相邻的黑色瓷砖移动。请写一个程序,计算你总共能够到达多少块黑色的瓷砖。 时间限…

[42000][923] ORA-00923: 未找到要求的 FROM 关键字

在oracle数据库写分页查询&#xff0c;使用 rownum时候出错&#xff0c; 代码&#xff1a; SELECT *FROM (SELECT *, ROWNUM AS rnumFROM test t ) WHERE rnum BETWEEN 1 AND 5; 报错&#xff1a; [42000][923] ORA-00923: 未找到要求的 FROM 关键字 Position: 31 问题原因…

vue3前端开发系列 - electron开发桌面程序(2023-10月最新版)

文章目录 1. 说明2. 创建项目3. 创建文件夹electron3.1 编写脚本electron.js3.2 编写脚本proload.js 4. 修改package.json4.1 删除type4.2 修改scripts4.3 完整的配置如下 5. 修改App.vue6. 修改vite.config.ts7. 启动8. 打包安装9. 项目公开地址 1. 说明 本次安装使用的环境版…

kantts docker化

kan-tts docker本地化 环境安装 下载docker镜像(python3.8的) registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.8.0-py38-torch2.0.1-tf2.13.0-1.9.2 安装基础模型 pip install modelscope 安装语音模型 pip install "modelscope…

RunnerGo测试平台,无代码玩转UI自动化测试

首先需要进入官网&#xff0c;RunnerGo支持开源&#xff0c;可以自行下载安装&#xff0c;也可以点击右上角体验企业版按钮快速体验 点击体验企业版进入工作台后可以点击页面上方的UI自动化 进入到测试页面 创建元素 我们可以在元素管理中创建我们测试时需要的元素 这里我们以…