深度学习笔记(九)——tf模型导出保存、模型加载、常用模型导出tflite、权重量化、模型部署

文中程序以Tensorflow-2.6.0为例
部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。
本篇博客主要是工具性介绍,可能由于软件版本问题导致的部分内容无法使用。

首先介绍tflite: TensorFlow Lite 是一组工具,可帮助开发者在移动设备、嵌入式设备和 loT 设备上运行模型,以便实现设备端机器学习。
框架具有的主要特性:

  • 延时(数据无需往返服务器)
  • 隐私(没有任何个人数据离开设备)
  • 连接性(无需连接互联网)
  • 大小(缩减了模型和二进制文件的大小)
  • 功耗(高效推断,且无需网络连接)

官方目前支持了大约130中可以量化的算子,在查阅大量资料后目前自定义的算子使用tflite导出任然存在较多问题。就解决常见的算法,使用支持的算子基本可以覆盖。tflite的压缩能力极强:使用官方算子构建的模型,导出TensorFlow Lite 二进制文件的大小约为 1 MB(针对 32 位 ARM build);如果仅使用支持常见图像分类模型(InceptionV3 和 MobileNet)所需的运算符,TensorFlow Lite 二进制文件的大小不到 300 KB。在后文的实例中我们用iris数据集的分类演示,可以导出一个仅仅只有2kb大小的模型权重相比未压缩的70kb模型缩小了30多倍。
同时tflite还实验性的在支持导出极轻量化的TFLM模型(TensorFlow Lite for Microcontrollers),这些模型可以直接在嵌入式单片机上进行推理,不过现阶段支持的算子还很少,简单的可以利用全连接和低向量卷积实现一些传感器参数的识别任务。现在主要的实例场景是MCU+IMU组合,识别IMU连续数据,来判断人体特定动作。同时开可以在MCU上离线运行语音命令识别,可以实现一个关键字的识别。

好了那我们继续看一下怎么保存模型,加载模型,保存tflite,加载tflite

保存权重或TF格式标准模型

通常情况下当完成了网络结构设计,数据处理,网络训练和评价之后需要及时的保存数据。先看到前面博客中已经介绍过的iris数据集实现网络分类任务。当时通过添加保存回调函数实现了网络权重的保存,这样保存下的是网络权重模型,需要配合网络结构的实例化使用。当然tf还提供了很多种模型的保存方式,tf2官方推荐使用tf形式保存,通过这种方式相关文件会保存到一个指定文件夹中,包含模型的权重参数模型结构信息。

ckpt格式

通过回调函数实现动态权重的保存。

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as np
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)# 定义网络结构
class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return y
# 实例化化模型
model = IrisModel()
# 定义保存和记录数据的回调器
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath="./checkpoint/iris/iris.ckpt",  # 保存模型权重参数save_weights_only=True,save_best_only=True)
# 初始化模型
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'],callbacks=[cp_callback])
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()

在上面的代码中通过tf.keras.callbacks.ModelCheckpoint设置了一个回调器,会动态的在网络训练的过程中保存下参数表现效果最好的权重参数。这里主要保存网络中各个可变参数的值和网络的当前图数据。这个模型无法直接用于推理,应为其不包含网络完整的图信息。所以我们需要在训练结束时保存网络整体的图信息。

.pd格式

pd格式是tf保存静态模型的专用权重文件。在训练完成后直接执行:

model.save('./yor_save_path/model', save_format='tf')    # 保存模型为静态权重

这样就可以把model的全部图信息保存下来了那么怎么保存最好的呢?可以结合上一个.ckpt文件使用。
对比两个的保存方式差距,前者在动态的训练过程中存储数据,后者针对某一个节点的网络状态完整保存。所以可以在训练过程中保存下最好的参数,当训练结束后再加载回最好的动态权重,然后再保存为.pd文件。

加载权重或TF格式标准模型

动态权重和静态图模型的保存不同,加载也不同。加载.ckpt时,需要先实例化网络结构,然后再读取权重参数给实例化的模型赋值。对于静态模型文件则不需要实例化模型,也就是无需关注网络的内部,直接读取加载模型就会完成网络构建和参数赋值两个任务,在部署时明显静态模型模型的程序文件会更加简单。

model = userModer()
model.compile()
checkpoint_save_path = "./yor_file_path.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):     #print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)

上面的程序展示从动态图加载权重。下面的程序则直接从静态图加载模型。

model_path = './yor_model_path'      
new_model = tf.keras.models.load_model(model_path)  # 从tf模型加载,无需重新实例化网络

从静态图文件加载模型有便捷之处,但是也需要注意模型的输入和输出结构,要保证预测时输入网络的数据维度是符合要求的,同时根据网络输出的模式接收输出数据做相应处理。

转化模型到tflite

转化模型主要有三种方式:

  • 使用现有的 TensorFlow Lite 模型
  • 创建 TensorFlow Lite 模型
  • 将 TensorFlow 模型转换为TensorFlow Lite 模型

模型的保存就分别对应三个主要函数:
在这里插入图片描述

后续主要介绍使用tf构建网络后从tf模型保存到tflite,并以keras model为主。
首先我们需要上面iris数据集分类的例子,当网络训练结束后,可以使用如下的程序导出:

tflite_save_path = './your_file_path'
os.makedirs(tflite_save_path, exist_ok=True)
# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Save the model.
with open(tflite_save_path+'/model.tflite', 'wb') as f:f.write(tflite_model)

导出后可以看到model.tflite的模型文件。可以比较上面直接导出的完整模型,这个模型的体积小了很多,更加适合在低算力和存储能力的设备上运行。

从tflite加载模型并执行推理

从tflite上加载模型并推理主要有两个手段:使用完整tf框架加载tf.lite读取;或使用tflite_runtime,这是 TensorFlow Lite 解释器,无需安装所有 TensorFlow 软件包,但是对python版本和系统,硬件有一定的要求。目前tf-runtime支持的平台有:
在这里插入图片描述
在此以外的模型需要拉取完整源码在本地设备上编译执行。
安装了相关的软件环境后,可以使用如下的代码来加载模型并推理:

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=tflite_save_path+'/model.tflite')  # 加载模型
interpreter.allocate_tensors()  # 为模型分配张量参数# Get input and output tensors.
input_details = interpreter.get_input_details()     # 设置网络输入
output_details = interpreter.get_output_details()   # 设置网络输出# Test the model on set input data.
input_shape = input_details[0]['shape'] # 获取输入层(第一层)的数据维度
print(input_shape)	# 输出维度结构,便于调试input_data = np.array([6.0,3.4,4.5,1.3], dtype=np.float32)	# 手动给一组鸢尾花数据
input_data = input_data.reshape([1,4])  # 确保维度相同
print(input_data.shape)
interpreter.set_tensor(input_details[0]['index'], input_data)   # 将数据输入到网络中
interpreter.invoke()  # 运行推理
output_data = interpreter.get_tensor(output_details[0]['index'])    # 获得网络输出print(output_data)pred = tf.argmax(output_data, axis=1)   # 网络输出层是softmax,需要找到最大值
print(int(pred))    # 输出最大位置的index

通过上面几行简单的代码就可以在终端设备实现预测推理。使用tf-runtim时只需要做简单修改,将包名替换即可。例如:

import tensorflow as tf  改为:import tflite_runtime.interpreter as tflite
interpreter = tf.lite.Interpreter(model_path=args.model_file) 改为: interpreter = tflite.Interpreter(model_path=args.model_file)

模型量化

对模型执行量化可以进一步解决嵌入式终端设备的痛点。量化模型可以实现:

  • 较小的存储大小:小模型在用户设备上占用的存储空间更少
  • 较小的下载大小:小模型下载到用户设备所需的时间和带宽较少
  • 更少的内存用量:小模型在运行时使用的内存更少,从而释放内存供应用的其他部分使用,并可以转化为更好的性能和稳定性
    tflite支持的量化形式有:
    在这里插入图片描述

训练后量化

训练后量化是一种转换技术,它可以在改善 CPU 和硬件加速器延迟的同时缩减模型大小,且几乎不会降低模型准确率。使用 TensorFlow Lite 转换器将已训练的浮点 TensorFlow 模型转换为 TensorFlow Lite 格式后,可以对该模型进行量化。

动态范围量化

动态范围量化能使模型大小缩减至原来的四分之一,在量化时激活函数始终以浮点格式保存,其它支持的算子会根据损失动态保存为8位整形,以此减小模型体积。在导出时量化模型,设置 optimizations 标记以优化大小:

# 上续训练后的模型
tflite_save_path = './your_file_path'
os.makedirs(tflite_save_path, exist_ok=True)
# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
tflite_model_quant_file = tflite_models_dir/"opt_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_quant_model)

加载模型时使用相同的方式加载即可

全整数量化

全整型量化相对更加复杂一些。
在上面导出tflite的过程中,实际是将tf默认的协议缓冲区模型压缩为FlatBuffers的格式,这种格式具有多种优势,例如可缩减大小(代码占用的空间较小)以及提高推断速度(可直接访问数据,无需执行额外的解析/解压缩步骤)。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

上面的两行代码实质是做了模型格式的转换和压缩,并没有调整权重参数和计算格式。在上面的基础上,可以进一步使用动态范围量化:

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model_quant = converter.convert()

通过动态范围量化之后模型已经缩小了,但是任然有一部分模型参数是浮点格式,这对存储有效,计算能力有限的设备还是存在限制。
在进行整形量化时需要量化模型内部层和输入输出层。tflite给出了两种量化的方式,第一种量化兼容性相对广泛,但是需要输入一组足够大的代表数据集用来推理量化。这样得到的模型任然有小部分参数会是浮点,这无法支持纯整形计算的硬件。
这里转述官方给出的第二种整形量化方式:
为了量化输入和输出张量,并让转换器在遇到无法量化的运算时引发错误,使用一些附加参数再次转换模型:

def representative_data_gen():for input_value in tf.data.Dataset.from_tensor_slices(train_data).batch(yordatabatch).take(100):yield [input_value]converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8tflite_model_quant = converter.convert()

上面的第一个函数是量化的一个必要步骤,要量化可变数据(例如模型输入/输出和层之间的中间体),需要提供 RepresentativeDataset。这是一个生成器函数,它提供一组足够大的输入数据来代表典型值。转换器可以通过该函数估算所有可变数据的动态范围。(相比训练或评估数据集,此数据集不必唯一。)为了支持多个输入,每个代表性数据点都是一个列表,并且列表中的元素会根据其索引被馈送到模型。
通过转化给定数据推理量化,现在模型的输入层和输出层数据已经是整形格式:

interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)

此时模型已经完全支持全整形设备的计算。
那继续的,将模型文件保存下来:

import pathlib
tflite_models_dir = pathlib.Path("/tmp/user_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)
# Save the unquantized/float model:
tflite_model_file = tflite_models_dir/"user_model.tflite"
tflite_model_file.write_bytes(tflite_model)
# Save the quantized model:
tflite_model_quant_file = tflite_models_dir/"user_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_model_quant)

执行推理时使用的程序结构和上文介绍的从tflite加载模型并执行推理的内容相同。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/639942.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[足式机器人]Part2 Dr. CAN学习笔记- 最优控制Optimal Control Ch07-1最优控制问题与性能指标

本文仅供学习使用 本文参考: B站:DR_CAN Dr. CAN学习笔记 - 最优控制Optimal Control Ch07-1最优控制问题与性能指标

Spring+SprinMVC+MyBatis注解方式简易模板

SpringSprinMVCMyBatis注解方式简易模板代码Demo GitHub访问 ssm-tpl-anno 一、数据准备 创建数据库test,执行下方SQL创建表ssm-tpl-cfg /*Navicat Premium Data TransferSource Server : 127.0.0.1Source Server Type : MySQLSource Server Version :…

【优化技术专题】「性能优化系列」针对Java对象压缩及序列化技术的探索之路

针对Java对象压缩及序列化技术的探索之路 序列化和反序列化为何需要有序列化呢?Java实现序列化的方式二进制格式 指定语言层级二进制格式 跨语言层级JSON 格式化类JSON格式化:XML文件格式化 序列化的分类在速度的对比上一般有如下规律:Java…

选择排序(二)——堆排序(性能)与直接选择排序

目录 一.前言 二.选择排序 2.1 堆排序 2.2选择排序 2.2.1 基本思想 2.2.2直接选择排序 三.结语 一.前言 本文给大家带来的是选择排序,其中选择排序中的堆排序在之前我们已经有过详解所以本次主要是对比排序性能,感兴趣的友友可移步观看堆排&#…

世微AP5179 60V高端电流采样降压恒流驱动器 LED车灯备用灯信号灯

产品描述 AP5179是一款连续电感电流导通模式的降压恒流源,用于驱动一颗或多颗串联LED输入电压范围从 5 V 到 60V,输出电流 可达 2.0A 。根据不同的输入电压和 外部器件, 可以驱动高达数十瓦的 LED。 内置功率开关,采用高端电流采样…

python实现带刷新的文本进度条

进度条已执行的部分使用“**”,未执行的部分使用“--”,用print()来完成 使用到的函数: time.sleep(),作用是在程序执行过程中暂停一段时间,即会使程序暂停指定的秒数,然后再继续执行后面的代…

【Unity学习笔记】Unity TestRunner使用

转载请注明出处:🔗https://blog.csdn.net/weixin_44013533/article/details/135733479 作者:CSDN|Ringleader| 参考: Input testingGetting started with Unity Test FrameworkHowToRunUnityUnitTest如果对Unity的newInputSystem感…

HTML JavaScript 数字变化特效

效果 案例一&#xff1a;上下滚动 案例二&#xff1a;本身变化 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><met…

初识 JVM

什么是JVM JVM 全称是 J ava V irtual M achine&#xff0c;中文译名 Java虚拟机 。 JVM 本质上是一个运行在计算机上的程序&#xff0c;他的职责是运行 Java字节码文件 。 JVM的功能 Java语言如果不做任何优化&#xff0c;性能不如C、C等语言。 Java需要实时解释&…

【主题广范|见刊快】2024年生物信息学与智能系统国际学术会议(IACBIS 2024)

【主题广范|见刊快】2024年生物信息学与智能系统国际学术会议(IACBIS 2024) 2024 International Conference Bioinformatics and Intelligent Systems(IACBIS 2024) 一、【会议简介】 在2024年&#xff0c;一场全球瞩目的学术盛会将在某个繁华的都市中心举行。这次会议的主题是…

Android学习之路(22) ARouter原理解析

1.ARouter认知 首先我们从命名来看:ARouter翻译过来就是一个路由器。 官方定义&#xff1a; 一个用于帮助 Android App 进行组件化改造的框架 —— 支持模块间的路由、通信、解耦 那么什么是路由呢&#xff1f; 简单理解就是&#xff1a;一个公共平台转发系统 工作方式&…

软件系统测试方案-word

2. 测试策略 2.1. 测试完成标准 2.2. 测试类型 2.2.1. 功能测试 2.2.2. 性能测试 2.2.3. 安全性与访问控制测试 2.3. 测试工具 3. 测试技术 4. 测试资源 4.1. 人员安排 4.2. 测试环境 4.2.1. 硬件环境 4.2.2. 软件环境 4.3. 进度安排 5. 功能测试 6. 性能测试 7. 安全性与访问控…

[Unity] Tilemap瓦片左右翻转(上下翻转)

Tile&#xff08;瓦片&#xff09;左右翻转感觉是很常用的一个功能啊&#xff01;看了一些教程都没有提及&#xff0c;心想难道要把每张Sprite再做一张对称的、再做成瓦片吗&#xff1f; 图片量x2 、瓦片量x2、不现实&#xff01;一定有方法&#xff01; 搜索了了半天没找到方…

架构篇04:复杂度来源 - 高性能

文章目录 单机复杂度集群的复杂度小结 从本篇开始&#xff0c;我们一起深入分析架构设计复杂度的 6 个来源&#xff0c;先来聊聊复杂度的来源之一高性能。 对性能孜孜不倦的追求是整个人类技术不断发展的根本驱动力。例如计算机&#xff0c;从电子管计算机到晶体管计算机再到集…

springsecurity集成kaptcha功能

前端代码 本次采用简单的html静态页面作为演示&#xff0c;也可结合vue前后端分离开发&#xff0c;复制就可运行测试 项目目录 登录界面 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</…

【设计模式-9】装饰模式的代码实现及使用场景

装饰器模式类比生活中房屋装修的场景&#xff0c;可以在毛坯房的基础上加以各种装饰&#xff0c;使得房屋的居住属性增强。装饰器模式能够在运行期间&#xff0c;动态地为原始对象增加一些额外的功能&#xff0c;使其功能更为丰富。 1. 概述 装饰模式 可以动态的为某些对象增…

深入Matplotlib:画布分区与高级图形展示【第33篇—python:Matplotlib】

文章目录 Matplotlib画布分区技术详解引言方法一&#xff1a;plt.subplot()方法二&#xff1a;简略写法方法三&#xff1a;plt.subplots()实例展示添加更多元素 进一步探索Matplotlib画布分区自定义子图布局3D子图结语 Matplotlib画布分区技术详解 引言 Matplotlib是一个强大…

代码随想录27期|Python|Day35|435. 无重叠区间|763.划分字母区间|56. 合并区间

435. 无重叠区间 和昨天的射爆气球是一样的处理方式&#xff1a; 由于不需要进行不重合的时候的计算&#xff0c;只需要对重合进行处理&#xff0c;所以反而更加简单。 1、按照区间左边界从小到大排序&#xff1b; 2、从索引1开始遍历&#xff0c;对于i-1的右边界大于i的左边…

网页无法访问但是有网什么原因

目录 1.运行网络诊断&#xff0c;确认原因 原因A.远程计算机或设备将不接受连接(该设备或资源(Web 代理)未设置为接受端口“7890”上的连接 原因B.DNS服务器未响应 场景A.其他的浏览器可以打开网页&#xff0c;自带的Edge却不行 方法A&#xff1a;关闭代理 Google自带翻译…

用户头像上传

将用户上传的头像存储在腾讯云存储桶里 注册腾讯云 https://cloud.tencent.com/login 创建存储桶 配置跨域 来源 * (任何都可以访问) put get post 请求都可以 点击概览&#xff0c;查看存储桶基本信息 记录保存存储桶名称和地域 找到api密钥管理&#xff0c;新建密钥 ht…