keras保存模型_TF2 8.模型保存与加载

83b58c97ed5c720b41dd1a1ec0b0e72a.png

0351108cd1f48bbb8fbed16bc86ab5ff.png

a98a7ffbe52bc4c72219ad0b043cbbb7.png

举个例子:先训练出一个模型

import 

接下来第一种方法:只保留模型的参数:这个有2种方法:

model.save_weights("adasd.h5")model.load_weights("adasd.h5")
model.predict(x_test)
model.save_weights('./checkpoints/mannul_checkpoint')
model.load_weights('./checkpoints/mannul_checkpoint')
model.predict(x_test)

因为这种方法只保留了参数,而并没有保留整个模型,所以说在加载的时候需要使用model.load_weights。这个函数只会保留模型的权重,它不包含模型的结构,所以当我们加载权重文件时候,需要先输入网络结构


再第2种方法:保留h5文件,保留整个模型

这种方法已经保存了模型的结构和权重,以及损失函数和优化器

model.save('keras_model_hdf5_version.h5')new_model = tf.keras.models.load_model('keras_model_hdf5_version.h5')
new_model.predict(x_test)

注意:这种方法只可以使用在keras的顺序模型和函数式模型中,不能使用在子类模型和自定义模型中,否则会报错。


再第3种方法:保留pb文件,保留整个模型

# Export the model to a SavedModel
model.save('keras_model_tf_version', save_format='tf')# Recreate the exact same model
new_model = tf.keras.models.load_model('keras_model_tf_version')
new_model.predict(x_test)

这个方法没有保留优化器配置。

保留pb文件还有另一种方法:

tf.saved_model.save(model,'文件夹名') tf.saved_model.load('文件夹名')

注意:这里是文件夹名称!!

当我们使用这个方法后,对应目录下会出现一个文件夹,文件夹下有两个子文件夹和一个子文件:assets、variables、save_model.pb

TensorFlow 为我们提供的SavedModel这一格式可在不同的平台上部署模型文件,当模型导出为 SavedModel 文件时,无需建立。

模型的源代码即可再次运行模型,这使得SavedModel尤其适用于模型的分享和部署

tf.saved_model.save(model,'tf_saved_model_version')
restored_saved_model = tf.saved_model.load('tf_saved_model_version')
f = restored_saved_model.signatures["serving_default"]
注意这里加载好了以后不能直接用predict进行预测哦。

我们这里看一下保存信息:

!saved_model_cli show --dir tf_saved_model_version --all

0802c911d1a87d2cd4bc69953242b488.png

使用模型的命令是:

f(digits = tf.constant(x_test.tolist()) )

输出为:

da74212f3b290b15abce068fa0a05ea0.png

关键是f = restored_saved_model.signatures["serving_default"]。


最后我们看看自定义模型的保存与加载:

注意这里要用以下命令

@tf.function(input_signature=[tf.TensorSpec([None,32], tf.float32,name='digits')])

把动态图变成静态图:

class MyModel(tf.keras.Model):def __init__(self, num_classes=10):super(MyModel, self).__init__(name='my_model')self.num_classes = num_classes# 定义自己需要的层self.dense_1 = tf.keras.layers.Dense(32, activation='relu')self.dense_2 = tf.keras.layers.Dense(num_classes)@tf.function(input_signature=[tf.TensorSpec([None,32], tf.float32,name='digits')])def call(self, inputs):#定义前向传播# 使用在 (in `__init__`)定义的层x = self.dense_1(inputs)return self.dense_2(x)
import numpy as np
x_train = np.random.random((1000, 32))
y_train = np.random.random((1000, 10))
x_val = np.random.random((200, 32))
y_val = np.random.random((200, 10))
x_test = np.random.random((200, 32))
y_test = np.random.random((200, 10))# 优化器
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
# 损失函数
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)# 准备metrics函数
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()# 准备训练数据集
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)# 准备测试数据集
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)model = MyModel(num_classes=10)
epochs = 3
for epoch in range(epochs):print('Start of epoch %d' % (epoch,))# 遍历数据集的batch_sizefor step, (x_batch_train, y_batch_train) in enumerate(train_dataset):with tf.GradientTape() as tape:logits = model(x_batch_train)loss_value = loss_fn(y_batch_train, logits)grads = tape.gradient(loss_value, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))# 更新训练集的metricstrain_acc_metric(y_batch_train, logits)# 每200 batches打印一次.if step % 200 == 0:print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))print('Seen so far: %s samples' % ((step + 1) * 64))# 在每个epoch结束时显示metrics。train_acc = train_acc_metric.result()print('Training acc over epoch: %s' % (float(train_acc),))# 在每个epoch结束时重置训练指标train_acc_metric.reset_states()# 在每个epoch结束时运行一个验证集。for x_batch_val, y_batch_val in val_dataset:val_logits = model(x_batch_val)# 更新验证集mericsval_acc_metric(y_batch_val, val_logits)val_acc = val_acc_metric.result()val_acc_metric.reset_states()print('Validation acc: %s' % (float(val_acc),))

模型保存方法一:保存weight:

model.save_weights("adasd.h5")
model.load_weights("adasd.h5")
model.predict(x_test)
model.save_weights('./checkpoints/mannul_checkpoint')
model.load_weights('./checkpoints/mannul_checkpoint')
model.predict(x_test)

模型保存方法二:保留h5,方法失败,因为自定义模型无法保留h5:

#model.save('my_saved_model.h5')

模型保存方法三:pb格式:

model.save('path_to_my_model',save_format='tf')
new_model = tf.keras.models.load_model('path_to_my_model')
new_model.predict(x_test)

输出为:

11b887f9f7c9562a7ee2d78697f53a7b.png

或者:

tf.saved_model.save(model,'my_saved_model')
restored_saved_model = tf.saved_model.load('my_saved_model')
f = restored_saved_model.signatures["serving_default"]f(digits = tf.constant(x_test.tolist()) )

887eb6333f0030540340a189bacb3a36.png
!saved_model_cli show --dir my_saved_model --all

0b127b83da91de25e6e053e9699e9668.png

注意前面模型定义时候:

@tf.function(input_signature=[tf.TensorSpec([None,32], tf.float32,name='digits')])

这个digits,对应的就是input['digits'],也对应的是f函数中的自变量digits。


总结:

注意第2种方法h5格式只可以使用在keras的顺序模型和函数式模型中,不能使用在子类模型和自定义模型中。

ce97c308b0a9d7df610c9a0ca5cdd2d3.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/571765.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第一章 Burp Suite 安装和环境配置

Burp Suite是一个集成化的渗透测试工具,它集合了多种渗透测试组件,使我们自动化地或手工地能更好的完成对web应用的渗透测试和攻击。在渗透测试中,我们使用Burp Suite将使得测试工作变得更加容易和方便,即使在不需要娴熟的技巧的情…

mysql57服务无法启动_将mysqld.service服务加入到systemctl

在开始安装二进制MySQL的时候感觉都还挺好,就是在启动服务的时候比较麻烦,一开始是在Centos6下的感觉也没有什么费劲的;但是在Centos7下面还是有点不太适应,不过还好用用就熟悉了;说明一下,我的安装目录在/usr/local/m…

linux raid autodetect,软raid的建立

1 增加磁盘并分区(修改id)fdisk /dev/sdbCommand (m for help): pDisk /dev/sdb: 8589 MB, 8589934592 bytes255 heads, 63 sectors/track, 1044 cylindersUnits cylinders of 16065 * 512 8225280 bytesDevice Boot Start End Blocks Id System/dev/sd…

c语言for循环的省略写法,C语言两种for循环写法分析

每个C程序员都知道同一个for循环语句可以有两种写法:A: for (i 0; i B: for (i cnt; i > 0; i--){ }前几天,DEBUG的时候, 发现采用A写法的代码反汇编出来有BUG.当时没有时间记录,环境也没有保存下来.今天尝试重现,又没来出现上次的问题...很奇怪.很久很久以前也听说过这两…

python文字游戏 生成数字菜单_pygame游戏之旅 游戏中添加显示文字

本文为大家分享了pygame游戏之旅的第5篇,供大家参考,具体内容如下 在游戏中添加显示文字: 这里自己定义一个crash函数接口: def crash(): message_diaplay(You Crashed) 然后实现接口函数message_display(text) def message_diapl…

springboot netty给特定客户端推送_Spring Boot 又升级了?2.0 你搞懂了吗?!

【小宅按】作为知名互联网公司都在用的技术,Spring Boot 2.0 的更新引起了很大的关注,本文将分为三部分解读 2.0 的更新:第一类,基础环境升级;第二类,默认软件替换和优化;第三类,新技…

OSI七层模型与TCP/IP五层模型详解

博主是搞是个FPGA的,一直没有真正的研究过以太网相关的技术,现在终于能静下心学习一下,希望自己能更深入的掌握这项最基本的通信接口技术。下面就开始搞了。 一、OSI参考模型 今天我们先学习一下以太网最基本也是重要的知识——OSI参考模型。…

android 自定义表情包,android基于环信的聊天和表情自定义

环信sdk的导入自定义聊天界面此处只有静态图,请谅解。自定义表情发送自定义聊天界面简单说下自定义的聊天界面,一个带有recyclerview和的xml文件,和对应的adapter即可。recyclerview为展示聊天信息。通过EMClient.getInstance().chatManager(…

erlang安装_RabbitMQ的使用(一)- RabbitMQ服务安装

作者:markjiang7m2博客园地址:https://www.cnblogs.com/markjiang7m2/p/12769627.html官网地址:http://letyouknow.netRabbitMQ,消息队列的一个中间件,这里不打算展开介绍了。此文意在记录工作中使用RabbitMQ时的过程及…

NodeJS React 开发环境搭建

1、首先需要安装NodeJS环境,下载NodeJS安装程序安装即可。 NodeJS下载地址: https://nodejs.org/en/download/ 2、安装NodeJS的web框架express npm install express-generator -g 3、创建项目 express studyReact 4、添加jsx引擎支持 npm install ex…

dreamweaver 正则表达式为属性值加上双引号_Python正则表达式(一)

Python正则表达式正则表达式是处理字符串的强大工具,拥有独特的语法和独立的处理引擎。我们在大文本中匹配字符串时,有些情况用str自带的函数(比如find, in)可能可以完成,有些情况会稍稍复杂一些(比如说找出所有“像邮箱”的字符串&#xff0…

mapperscan注解_SpringBoot 遗忘后的简单快速回忆之环境搭建与常见注解

原文作者:笑而抿之乎搭建SpringBoot环境,创建maven 项目后1,创建入口类:MapperScan(basePackages "com.baizhi.dao" ) //把dao层交给工厂管理SpringBootApplication//标识入口类的注解public class Applincation { …

Android插件丢失怎么办,Android studio推荐插件以及升级后插件丢失问题解决

1、android-butterknife-zeleznyandroid-butterknife-zelezny 是根据butterknife定制的一款插件,能够方便快速初始化,对于我来说是开发必备,本人也对此插件进行了一些优化,个人感觉用起来更爽 ,博客地址:Bu…

软工团队 - 系统设计

软工团队 - 系统设计 修改完善需求规格说明书 针对栋哥在上周答辩中主要提到问题的相应改动 管理员层面没有在需求中得到很好的体现。没有手机号验证。那时候回答的比较含糊orz,所以在这里说明一下对此作出的解释和修改。 对于第一点,我们讨论的结果是至…

python decimal_python学习笔记一

1、~4不太明白、右移、左移整体移动添加零2、注意运算符&#xff0c;3、1<<5&716&704、set中的pop() 方法用于随机移除一个元素。字典中&#xff1a;list中5、Python dir() 函数dir()函数不带参数时&#xff0c;返回当前范围内的变量、方法和定义的类型列表&…

java基础基础总结----- Date

前言&#xff1a;其实在学习这个的时候&#xff0c;自我感觉学到什么直接查询API就可以了&#xff0c;没有必要再去研究某个方法怎么使用&#xff0c; 重点学习一下经常用到的方法。感觉自己的写的博客&#xff0c;就跟自己的笔记一样&#xff0c;用的是时候&#xff0c;就能快…

pandas object转float_Pandas中文官档~基础用法6

呆鸟云&#xff1a;“这一系列长篇终于连载完了&#xff0c;还请大家关注 Python 大咖谈&#xff0c;这里专注 Python 数据分析&#xff0c;后期呆鸟还会给大家分享更多 Pandas 好文。”数据类型大多数情况下&#xff0c;pandas 使用 Numpy 数组、Series 或 DataFrame 里某列的…

android studio 拉取分支,AndroidStudio中使用Git-高级篇(二)——新建分支(branch)和拉取请求(Pull request)...

前段时间写过一篇文章介绍如何在AndroidStudio使用上传项目到github&#xff0c;今天接着给大家带来了他的高级篇——新建分支(branch)和拉取请求(Pull request)。在真正的开发中我们很少写完代码commit后直接push代码上去&#xff0c;因为这样做没有经过第二个人的审核&#x…

collection转换为list_JAVA 集合 接口继承关系和实现,List,Set,Map(总结)

一. JAVA 集合1.接口继承关系和实现集合类存放于 Java.util 包中&#xff0c;主要有 3 种&#xff1a;set(集&#xff09;、list(列表包含 Queue&#xff09;和 map(映射)。1. Collection&#xff1a;Collection 是集合 List、Set、Queue 的最基本的接口。2. Iterator&#xff…

lazarus开发android应用程序指南,Lazarus开发Android应用程序指南(2)

本指南版权由delphicn所有&#xff0c;QQ&#xff1a;1339838080(tom)&#xff0c;转载请保留版权信息。文中难免有错&#xff0c;欢迎指正。2&#xff0e;编译运行lazarus中的Android示例程序。lazarus安装包中自带示例是在lazarus/examples/androidlcl/androidlcltest.lpi 。…