keras保存模型_TF2 8.模型保存与加载

83b58c97ed5c720b41dd1a1ec0b0e72a.png

0351108cd1f48bbb8fbed16bc86ab5ff.png

a98a7ffbe52bc4c72219ad0b043cbbb7.png

举个例子:先训练出一个模型

import 

接下来第一种方法:只保留模型的参数:这个有2种方法:

model.save_weights("adasd.h5")model.load_weights("adasd.h5")
model.predict(x_test)
model.save_weights('./checkpoints/mannul_checkpoint')
model.load_weights('./checkpoints/mannul_checkpoint')
model.predict(x_test)

因为这种方法只保留了参数,而并没有保留整个模型,所以说在加载的时候需要使用model.load_weights。这个函数只会保留模型的权重,它不包含模型的结构,所以当我们加载权重文件时候,需要先输入网络结构


再第2种方法:保留h5文件,保留整个模型

这种方法已经保存了模型的结构和权重,以及损失函数和优化器

model.save('keras_model_hdf5_version.h5')new_model = tf.keras.models.load_model('keras_model_hdf5_version.h5')
new_model.predict(x_test)

注意:这种方法只可以使用在keras的顺序模型和函数式模型中,不能使用在子类模型和自定义模型中,否则会报错。


再第3种方法:保留pb文件,保留整个模型

# Export the model to a SavedModel
model.save('keras_model_tf_version', save_format='tf')# Recreate the exact same model
new_model = tf.keras.models.load_model('keras_model_tf_version')
new_model.predict(x_test)

这个方法没有保留优化器配置。

保留pb文件还有另一种方法:

tf.saved_model.save(model,'文件夹名') tf.saved_model.load('文件夹名')

注意:这里是文件夹名称!!

当我们使用这个方法后,对应目录下会出现一个文件夹,文件夹下有两个子文件夹和一个子文件:assets、variables、save_model.pb

TensorFlow 为我们提供的SavedModel这一格式可在不同的平台上部署模型文件,当模型导出为 SavedModel 文件时,无需建立。

模型的源代码即可再次运行模型,这使得SavedModel尤其适用于模型的分享和部署

tf.saved_model.save(model,'tf_saved_model_version')
restored_saved_model = tf.saved_model.load('tf_saved_model_version')
f = restored_saved_model.signatures["serving_default"]
注意这里加载好了以后不能直接用predict进行预测哦。

我们这里看一下保存信息:

!saved_model_cli show --dir tf_saved_model_version --all

0802c911d1a87d2cd4bc69953242b488.png

使用模型的命令是:

f(digits = tf.constant(x_test.tolist()) )

输出为:

da74212f3b290b15abce068fa0a05ea0.png

关键是f = restored_saved_model.signatures["serving_default"]。


最后我们看看自定义模型的保存与加载:

注意这里要用以下命令

@tf.function(input_signature=[tf.TensorSpec([None,32], tf.float32,name='digits')])

把动态图变成静态图:

class MyModel(tf.keras.Model):def __init__(self, num_classes=10):super(MyModel, self).__init__(name='my_model')self.num_classes = num_classes# 定义自己需要的层self.dense_1 = tf.keras.layers.Dense(32, activation='relu')self.dense_2 = tf.keras.layers.Dense(num_classes)@tf.function(input_signature=[tf.TensorSpec([None,32], tf.float32,name='digits')])def call(self, inputs):#定义前向传播# 使用在 (in `__init__`)定义的层x = self.dense_1(inputs)return self.dense_2(x)
import numpy as np
x_train = np.random.random((1000, 32))
y_train = np.random.random((1000, 10))
x_val = np.random.random((200, 32))
y_val = np.random.random((200, 10))
x_test = np.random.random((200, 32))
y_test = np.random.random((200, 10))# 优化器
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
# 损失函数
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)# 准备metrics函数
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()# 准备训练数据集
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)# 准备测试数据集
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)model = MyModel(num_classes=10)
epochs = 3
for epoch in range(epochs):print('Start of epoch %d' % (epoch,))# 遍历数据集的batch_sizefor step, (x_batch_train, y_batch_train) in enumerate(train_dataset):with tf.GradientTape() as tape:logits = model(x_batch_train)loss_value = loss_fn(y_batch_train, logits)grads = tape.gradient(loss_value, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))# 更新训练集的metricstrain_acc_metric(y_batch_train, logits)# 每200 batches打印一次.if step % 200 == 0:print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))print('Seen so far: %s samples' % ((step + 1) * 64))# 在每个epoch结束时显示metrics。train_acc = train_acc_metric.result()print('Training acc over epoch: %s' % (float(train_acc),))# 在每个epoch结束时重置训练指标train_acc_metric.reset_states()# 在每个epoch结束时运行一个验证集。for x_batch_val, y_batch_val in val_dataset:val_logits = model(x_batch_val)# 更新验证集mericsval_acc_metric(y_batch_val, val_logits)val_acc = val_acc_metric.result()val_acc_metric.reset_states()print('Validation acc: %s' % (float(val_acc),))

模型保存方法一:保存weight:

model.save_weights("adasd.h5")
model.load_weights("adasd.h5")
model.predict(x_test)
model.save_weights('./checkpoints/mannul_checkpoint')
model.load_weights('./checkpoints/mannul_checkpoint')
model.predict(x_test)

模型保存方法二:保留h5,方法失败,因为自定义模型无法保留h5:

#model.save('my_saved_model.h5')

模型保存方法三:pb格式:

model.save('path_to_my_model',save_format='tf')
new_model = tf.keras.models.load_model('path_to_my_model')
new_model.predict(x_test)

输出为:

11b887f9f7c9562a7ee2d78697f53a7b.png

或者:

tf.saved_model.save(model,'my_saved_model')
restored_saved_model = tf.saved_model.load('my_saved_model')
f = restored_saved_model.signatures["serving_default"]f(digits = tf.constant(x_test.tolist()) )

887eb6333f0030540340a189bacb3a36.png
!saved_model_cli show --dir my_saved_model --all

0b127b83da91de25e6e053e9699e9668.png

注意前面模型定义时候:

@tf.function(input_signature=[tf.TensorSpec([None,32], tf.float32,name='digits')])

这个digits,对应的就是input['digits'],也对应的是f函数中的自变量digits。


总结:

注意第2种方法h5格式只可以使用在keras的顺序模型和函数式模型中,不能使用在子类模型和自定义模型中。

ce97c308b0a9d7df610c9a0ca5cdd2d3.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/571765.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第一章 Burp Suite 安装和环境配置

Burp Suite是一个集成化的渗透测试工具,它集合了多种渗透测试组件,使我们自动化地或手工地能更好的完成对web应用的渗透测试和攻击。在渗透测试中,我们使用Burp Suite将使得测试工作变得更加容易和方便,即使在不需要娴熟的技巧的情…

mysql57服务无法启动_将mysqld.service服务加入到systemctl

在开始安装二进制MySQL的时候感觉都还挺好,就是在启动服务的时候比较麻烦,一开始是在Centos6下的感觉也没有什么费劲的;但是在Centos7下面还是有点不太适应,不过还好用用就熟悉了;说明一下,我的安装目录在/usr/local/m…

linux raid autodetect,软raid的建立

1 增加磁盘并分区(修改id)fdisk /dev/sdbCommand (m for help): pDisk /dev/sdb: 8589 MB, 8589934592 bytes255 heads, 63 sectors/track, 1044 cylindersUnits cylinders of 16065 * 512 8225280 bytesDevice Boot Start End Blocks Id System/dev/sd…

input readonly 光标显示问题

input readonly模式下在ie跟火狐访问的时候会有光标会出现&#xff0c;以下方法可解决这个问题 <input type"text" readonly unselectableon onfocus"this.blur()"> 1.unselectableon 是解决ie下光标出现的问题 2.οnfοcus"this.blur() 是解决…

c语言for循环的省略写法,C语言两种for循环写法分析

每个C程序员都知道同一个for循环语句可以有两种写法:A: for (i 0; i B: for (i cnt; i > 0; i--){ }前几天,DEBUG的时候, 发现采用A写法的代码反汇编出来有BUG.当时没有时间记录,环境也没有保存下来.今天尝试重现,又没来出现上次的问题...很奇怪.很久很久以前也听说过这两…

python文字游戏 生成数字菜单_pygame游戏之旅 游戏中添加显示文字

本文为大家分享了pygame游戏之旅的第5篇&#xff0c;供大家参考&#xff0c;具体内容如下 在游戏中添加显示文字&#xff1a; 这里自己定义一个crash函数接口&#xff1a; def crash(): message_diaplay(You Crashed) 然后实现接口函数message_display(text) def message_diapl…

快速排序的改进

package com.txq.test; /*** quicksort,三方面改进&#xff1a;①三数中值选择枢纽元②容量小的时候使用插入排序③重复元素的处理* author XueQiang Tong* date 2017/10/25*/ public class QS {public void quicksort(int []arr,int low,int high){int first low;int last h…

23根火柴游戏 c语言,23 根火柴游戏

#includegt;int main(){int g 23;int k 3;int b, c;printf("这里是23 根火柴游戏&#xff01;&#xff01;\n");printf("注意&#xff1a;最大移动火柴数目为三根\n");do{printf("请输入移动的火柴数目&#xff1a;\n");scanf("%d",…

springboot netty给特定客户端推送_Spring Boot 又升级了?2.0 你搞懂了吗?!

【小宅按】作为知名互联网公司都在用的技术&#xff0c;Spring Boot 2.0 的更新引起了很大的关注&#xff0c;本文将分为三部分解读 2.0 的更新&#xff1a;第一类&#xff0c;基础环境升级&#xff1b;第二类&#xff0c;默认软件替换和优化&#xff1b;第三类&#xff0c;新技…

OSI七层模型与TCP/IP五层模型详解

博主是搞是个FPGA的&#xff0c;一直没有真正的研究过以太网相关的技术&#xff0c;现在终于能静下心学习一下&#xff0c;希望自己能更深入的掌握这项最基本的通信接口技术。下面就开始搞了。 一、OSI参考模型 今天我们先学习一下以太网最基本也是重要的知识——OSI参考模型。…

c是过程化语言吗数据库,关于SQL错误的是()A、所有数据库的公共语言B、非过程化的C、统一的语言D、所有用SQL缩写的程序都...

关于SQL错误的是()A、所有数据库的公共语言B、非过程化的C、统一的语言D、所有用SQL缩写的程序都更多相关问题[多选] 在彩色电视机遥控系统中&#xff0c;属于模拟量控制的有()等几种。[多选] 在色度信号记录处理中&#xff0c;家用录像机一般都要对色度信号经过()等处理。[多选…

python建立数据库和基本表_python基础 — 链接 Mysql 创建 数据库和创表

重点&#xff1a; &#xff11;. 链接服务器的数据库 &#xff12;. 创建表和格式 &#xff13;. 插入多行数据 import pymysql try: hostxxx userxxx passwdxxx dbtest01 port3306 Table_namekaka5 # 链接到服务器 db pymysql.connect(host, user, passwd, db, port) # 创…

c语言陷阱试题,超级经典计算机二级C语言陷阱考试题.doc

超级经典计算机二级C语言陷阱考试题超级经典计算机二级C语言陷阱考试题若有定义&#xff1a;int a[2][3],则对a数组的第i行j列元素地址的正确引用为___d___.a)*(a[i]j) b)(ai) c)*(aj) d)a[i]j以下正确的程序段是_________.a)char str[20]; b)char *p;scanf("%s",&am…

python开发跟淘宝有关联微_Python_淘宝用户行为分析

一、数据导入与清洗 源数据量有1亿余条&#xff0c;为减轻计算量&#xff0c;抽样总量的20%用于计算分析 #codinggbk import numpy as py import pandas as pd import datetime import os os.chdir(D:/pythonlily/test1) datapd.read_csv(UserBehavior.csv,headerNone) data.co…

android 自定义表情包,android基于环信的聊天和表情自定义

环信sdk的导入自定义聊天界面此处只有静态图&#xff0c;请谅解。自定义表情发送自定义聊天界面简单说下自定义的聊天界面&#xff0c;一个带有recyclerview和的xml文件&#xff0c;和对应的adapter即可。recyclerview为展示聊天信息。通过EMClient.getInstance().chatManager(…

如何快速获取properties中的配置属性值

本文为博主原创&#xff0c;未经博主允许&#xff0c;不得转载&#xff1a; 在项目中&#xff0c;经常需要将一些配置的常量信息放到properties文件中&#xff0c;这样在项目的配置变动的时候&#xff0c;只需要修改配置文件中 对应的配置常量即可。 在项目应用中&#xff0c;如…

erlang安装_RabbitMQ的使用(一)- RabbitMQ服务安装

作者&#xff1a;markjiang7m2博客园地址&#xff1a;https://www.cnblogs.com/markjiang7m2/p/12769627.html官网地址&#xff1a;http://letyouknow.netRabbitMQ&#xff0c;消息队列的一个中间件&#xff0c;这里不打算展开介绍了。此文意在记录工作中使用RabbitMQ时的过程及…

android 本地资源 uri,Android 本地文件选择

打开系统文件&#xff1a;Intent intent new Intent(Intent.ACTION_GET_CONTENT);intent.setType("*/*");intent.addCategory(Intent.CATEGORY_OPENABLE);try {startActivityForResult(Intent.createChooser(intent, getString(R.string.im_text_select_file)), SEN…

NodeJS React 开发环境搭建

1、首先需要安装NodeJS环境&#xff0c;下载NodeJS安装程序安装即可。 NodeJS下载地址&#xff1a; https://nodejs.org/en/download/ 2、安装NodeJS的web框架express npm install express-generator -g 3、创建项目 express studyReact 4、添加jsx引擎支持 npm install ex…

dreamweaver 正则表达式为属性值加上双引号_Python正则表达式(一)

Python正则表达式正则表达式是处理字符串的强大工具&#xff0c;拥有独特的语法和独立的处理引擎。我们在大文本中匹配字符串时&#xff0c;有些情况用str自带的函数(比如find, in)可能可以完成&#xff0c;有些情况会稍稍复杂一些(比如说找出所有“像邮箱”的字符串&#xff0…