tensorflow 保存训练loss_tensorflow2.0保存和加载模型 (tensorflow2.0官方教程翻译)

633fe6d9560fb5fed5a44e35bd1b3b78.png
最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html
英文版本:https://tensorflow.google.cn/alpha/tutorials/keras/save_and_restore_models
翻译建议PR:https://github.com/mashangxue/tensorflow2-zh/edit/master/r2/tutorials/keras/save_and_restore_models.md

模型进度可以在训练期间和训练后保存。这意味着模型可以在它停止的地方继续,并避免长时间的训练。保存还意味着您可以共享您的模型,其他人可以重新创建您的工作。当发布研究模型和技术时,大多数机器学习实践者共享: 用于创建模型的代码 以及模型的训练权重或参数

共享此数据有助于其他人了解模型的工作原理,并使用新数据自行尝试。

注意:小心不受信任的代码(TensorFlow模型是代码)。有关详细信息,请参阅安全使用TensorFlow 。

选项

保存TensorFlow模型有多种方法,具体取决于你使用的API。本章节使用tf.keras(一个高级API,用于TensorFlow中构建和训练模型),有关其他方法,请参阅TensorFlow保存和还原指南或保存在eager中。

1. 设置

1.1. 安装和导入

需要安装和导入TensorFlow和依赖项

pip install h5py pyyaml

1.2. 获取样本数据集

我们将使用MNIST数据集来训练我们的模型以演示保存权重,要加速这些演示运行,请只使用前1000个样本数据:

from __future__ import absolute_import, division, print_function, unicode_literalsimport os!pip install tensorflow==2.0.0-alpha0
import tensorflow as tf
from tensorflow import kerastf.__version__
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()train_labels = train_labels[:1000]
test_labels = test_labels[:1000]train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

1.3. 定义模型

让我们构建一个简单的模型,我们将用它来演示保存和加载权重。

# 返回一个简短的序列模型 
def create_model():model = tf.keras.models.Sequential([keras.layers.Dense(512, activation='relu', input_shape=(784,)),keras.layers.Dropout(0.2),keras.layers.Dense(10, activation='softmax')])model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])return model# 创建基本模型实例
model = create_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

2. 在训练期间保存检查点

主要用例是在训练期间和训练结束时自动保存检查点,通过这种方式,您可以使用训练有素的模型,而无需重新训练,或者在您离开的地方继续训练,以防止训练过程中断。

tf.keras.callbacks.ModelCheckpoint是执行此任务的回调,回调需要几个参数来配置检查点。

2.1. 检查点回调使用情况

训练模型并将其传递给 ModelCheckpoint回调

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)# 创建一个检查点回调
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,save_weights_only=True,verbose=1)model = create_model()model.fit(train_images, train_labels,  epochs = 10,validation_data = (test_images,test_labels),callbacks = [cp_callback])  # pass callback to training# 这可能会生成与保存优化程序状态相关的警告。
# 这些警告(以及整个笔记本中的类似警告)是为了阻止过时使用的,可以忽略。
Train on 1000 samples, validate on 1000 samples......Epoch 10/10960/1000 [===========================>..] - ETA: 0s - loss: 0.0392 - accuracy: 1.0000Epoch 00010: saving model to training_1/cp.ckpt1000/1000 [==============================] - 0s 207us/sample - loss: 0.0393 - accuracy: 1.0000 - val_loss: 0.3976 - val_accuracy: 0.8750<tensorflow.python.keras.callbacks.History at 0x7efc3eba7358>

这将创建一个TensorFlow检查点文件集合,这些文件在每个周期结束时更新。 文件夹checkpoint_dir下的内容如下:(Linux系统使用 ls命令查看)

checkpoint  cp.ckpt.data-00000-of-00001  cp.ckpt.index

创建一个新的未经训练的模型,仅从权重恢复模型时,必须具有与原始模型具有相同体系结构的模型,由于它是相同的模型架构,我们可以共享权重,尽管它是模型的不同示例。

现在重建一个新的,未经训练的模型,并在测试集中评估它。未经训练的模型将在随机水平(约10%的准确率):

model = create_model()loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 107us/sample - loss: 2.3224 - accuracy: 0.1230
Untrained model, accuracy: 12.30%

然后从检查点加载权重,并重新评估:

model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 48us/sample - loss: 0.3976 - accuracy: 0.8750
Restored model, accuracy: 87.50%

2.2. 检查点选项

回调提供了几个选项,可以为生成的检查点提供唯一的名称,并调整检查点频率。

训练一个新模型,每5个周期保存一次唯一命名的检查点:

# 在文件名中包含周期数. (使用 `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, verbose=1, save_weights_only=True,# 每5个周期保存一次权重period=5)model = create_model()
model.save_weights(checkpoint_path.format(epoch=0))
model.fit(train_images, train_labels,epochs = 50, callbacks = [cp_callback],validation_data = (test_images,test_labels),verbose=0)
Epoch 00005: saving model to training_2/cp-0005.ckpt
......
Epoch 00050: saving model to training_2/cp-0050.ckpt
<tensorflow.python.keras.callbacks.History at 0x7efc7c3bbd30>

现在,查看生成的检查点并选择最新的检查点:

latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'

注意:默认的tensorflow格式仅保存最近的5个检查点。

要测试,请重置模型并加载最新的检查点:

model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 84us/sample - loss: 0.4695 - accuracy: 0.8810Restored model, accuracy: 88.10%

3. 这些文件是什么?

上述代码将权重存储到检查点格式的文件集合中,这些文件仅包含二进制格式的训练权重. 检查点包含: 一个或多个包含模型权重的分片; 索引文件,指示哪些权重存储在哪个分片。

如果您只在一台机器上训练模型,那么您将有一个带有后缀的分片:.data-00000-of-00001

4. 手动保存权重

上面你看到了如何将权重加载到模型中。手动保存权重同样简单,使用Model.save_weights方法。

# 保存权重
model.save_weights('./checkpoints/my_checkpoint')# 加载权重
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

5. 保存整个模型

模型和优化器可以保存到包含其状态(权重和变量)和模型配置的文件中,这允许您导出模型,以便可以在不访问原始python代码的情况下使用它。由于恢复了优化器状态,您甚至可以从中断的位置恢复训练。

保存完整的模型非常有用,您可以在TensorFlow.js(HDF5, Saved Model) 中加载它们,然后在Web浏览器中训练和运行它们,或者使用TensorFlow Lite(HDF5, Saved Model)将它们转换为在移动设备上运行。

5.1. 作为HDF5文件

Keras使用HDF5标准提供基本保存格式,出于我们的目的,可以将保存的模型视为单个二进制blob。

model = create_model()model.fit(train_images, train_labels, epochs=5)# 保存整个模型到HDF5文件 
model.save('my_model.h5')

现在从该文件重新创建模型:

# 重新创建完全相同的模型,包括权重和优化器
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_12 (Dense)             (None, 512)               401920    
_________________________________________________________________
dropout_6 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_13 (Dense)             (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

检查模型的准确率:

loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 94us/sample - loss: 0.4137 - accuracy: 0.8540
Restored model, accuracy: 85.40%

此方法可保存模型的所有东西: 权重值 模型的配置(架构) * 优化器配置

Keras通过检查架构来保存模型,目前它无法保存TensorFlow优化器(来自tf.train)。使用这些时,您需要在加载后重新编译模型,否则您将失去优化程序的状态。

5.2. 作为 saved_model

注意:这种保存tf.keras模型的方法是实验性的,在将来的版本中可能会有所改变。

创建一个新的模型:

model = create_model()model.fit(train_images, train_labels, epochs=5)

创建saved_model,并将其放在带时间戳的目录中:

import time
saved_model_path = "./saved_models/{}".format(int(time.time()))tf.keras.experimental.export_saved_model(model, saved_model_path)
saved_model_path
'./saved_models/1555630614'

从保存的模型重新加载新的keras模型:

new_model = tf.keras.experimental.load_from_saved_model(saved_model_path)
new_model.summary()
Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_14 (Dense)             (None, 512)               401920    
_________________________________________________________________
dropout_7 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_15 (Dense)             (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

运行加载的模型进行预测:

model.predict(test_images).shape
(1000, 10)
# 必须要在评估之前编译模型
# 如果仅部署已保存的模型,则不需要此步骤 new_model.compile(optimizer=model.optimizer,  # keep the optimizer that was loadedloss='sparse_categorical_crossentropy',metrics=['accuracy'])# 评估加载后的模型 
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 102us/sample - loss: 0.4367 - accuracy: 0.8570Restored model, accuracy: 85.70%

6. 下一步是什么

这是使用tf.keras保存和加载的快速指南。

  • tf.keras指南显示了有关使用tf.keras保存和加载模型的更多信息。
  • 在eager execution期间保存,请参阅在Saving in eager。
  • 保存和还原指南包含有关TensorFlow保存的低阶详细信息。
最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html
英文版本:https://tensorflow.google.cn/alpha/tutorials/keras/save_and_restore_models
翻译建议PR:https://github.com/mashangxue/tensorflow2-zh/edit/master/r2/tutorials/keras/save_and_restore_models.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/527652.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

layui导入模板数据_数据可视化图表 教程echarts,第一讲

1我们写web项目&#xff0c;展示数据的地方&#xff0c;可能会使用到图表。今天就讲这个玩意。本教程暂时定为 三讲&#xff1a;(随后情况&#xff0c;如果有新的研究&#xff0c;会有所更新&#xff01;)第一讲 饼图的使用第二讲 柱状图的使用第三讲 拆线图的使用此教程希望…

出发a标签_以用户标签为例,复盘B端产品的需求挖掘方法论

阅读指南受众人群&#xff1a;B端初级产品经理阅读收获&#xff1a;B端产品需求挖掘的一些技巧&#xff1b;了解用户标签/画像的一些业务知识。手上负责一个和数据方面有关的B端系统&#xff0c;在日常的产品规划当中&#xff0c;没有关于“用户标签”方面的规划&#xff0c;突…

字符ascii码值转换_没想到 Unicode 字符还能这样玩?

脚本之家你与百万开发者在一起来源 | 程序通事(ID&#xff1a;US_stocks)如若转载请联系原公众号上周的时候&#xff0c;朋友圈的直升飞机不知道为什么就火了&#xff0c;很多朋友开着各种花式飞机带着起飞。图片来自网络还没来得及了解咋回事来着&#xff0c;这个直升飞机就?…

右键菜单无响应_被流氓软件玩坏了?这两个清理工具拯救你凌乱的右键菜单。...

Hello 这里是一周进步我们写了四年近2000篇的干货文章&#xff0c;还分享了许多实用的神器工具&#xff0c;一路以来&#xff0c;感谢大家的支持与陪伴~文 / 一周进步 安哥拉如果你和我们一样&#xff0c;是一个喜欢在电脑上安装各种各样的软件的人&#xff0c;你的电脑右键菜…

jsp mysql源码_jsp+servlet+mysql员工管理系统源代码下载

jspservletmysql员工管理系统项目截图注册页面登录页面添加员工编辑员工员工列表数据库建表语句/*Navicat MySQL Data TransferSource Server : localhostSource Server Version : 50509Source Host : localhost:3306Source Database : wdhdbTarget Server Type : MYSQLTarget …

vs里安装了mysql吗_vs2017安装 MySQL for Visual Studio 1.2.

vs2017安装想在win7EF6 VS2017 MySQL 但是安装MySQL for Visual Studio 1.2.7 时一直安装不上去&#xff0c;如下&#xff1a;Action 9:40:05: InstallFinalize.1: Action 9:40:05: DeleteRegKeyAndExtensionsFile_VS2013.1: Action 9:40:06: DeleteRegKeyAndExtensionsFile_…

mysql数据库优化语句_mysql数据库优化语句

mysql优化语句数据库语句&#xff1a; Ddl(数据定义语言) alter create drop Dml(数据操作语言) inset delete update www.2cto.com Dtl(数据事务语言) conmmit rollback savepoint Select Dcl(数据控制语句) grant赋权限 revoke回收 Mysql数据库优化&#xff1a; 1、 数据库表…

json模拟数据怎么用_在使用axios获取自己模拟的json数据是踩到的坑

最近在使用Vue仿写一个网易云音乐的单页面应用&#xff0c;当页面布局什么的写完后&#xff0c;然后就准备用axios获取后台数据渲染页面了&#xff0c;当然&#xff0c;我自己写的&#xff0c;并没有后台&#xff0c;所以&#xff0c;我就自己写json文件&#xff0c;然后弄prox…

mysql架构深入_mysql性能优化2:深入认识mysql体系架构

前言本文将重点梳理mysql的体系架构&#xff0c;便于了解mysql的实现原理。Mysql体系结构Client Connectors 接入方 支持协议很多Management Serveices & Utilities 系统管理和控制工具&#xff0c;mysqldump、 mysql复制集群、分区管理等Connection Pool 连接池&#xff1…

mysql租车管理系统_基于java实现租车管理系统

概述基于java swing JFrame 的图书馆管理系统&#xff0c;租车&#xff0c;还车&#xff0c;管理员管理用户&#xff0c;付款等。部分代码public class Login extends JFrame {private static final long serialVersionUID 1L;/*** 登录窗体*/public Login() {setDefaultClo…

java 1的阶乘之和_1-20的阶乘之和(java)

import java.math.BigInteger;public class Factorial {//2)求1&#xff01;2&#xff01;……20&#xff01;public static void main(String[] args){BigInteger sumBigInteger.ZERO;for(BigInteger iBigInteger.ONE;i.intValue()<20;){ii.add(BigInteger.ONE);sumsum.add…

java构建json_Java构造和解析Json数据的两种方法详解一

在www.json.org上公布了很多JAVA下的json构造和解析工具&#xff0c;其中org.json和json-lib比较简单&#xff0c;两者使用上差不多但还是有些区别。下面首先介绍用json-lib构造和解析Json数据的方法示例。用org.son构造和解析Json数据的方法详解请参见我下一篇博文&#xff1a…

java final被覆盖_java中的final的使用

1、final类不能被继承&#xff0c;因此final类的成员方法没有机会被覆盖&#xff0c;默认都是final的。在设计类时候&#xff0c;如果这个类不需要有子类&#xff0c;类的实现细节不允许改变&#xff0c;并且确信这个类不会再被扩展&#xff0c;那么就设计为final类。(什么时候…

wordcount.java_mapreduce中wordcount的java实现

用java模拟词频统计。有3个文件&#xff1a;text1: hello worldtext2:hello hadooptext3:hello mapreduce对上面的文件进行词频统计&#xff1a;结果应该是&#xff1a;hello:3; hadoop:1; world:1; mapreduce:1代码实现如下&#xff1a;package count;import java.ut…

java程序回滚之后在哪看_Java在触发事务回滚之后为什么会再一次回到Servlet开始的地方重新走一次流程?...

代码流程前台点击"提交订单"进入BaseServlet.classBaseServlet.class分发至子类OrderServlet.class的submitOrder()方法submitOrder()调用Service层的submitOrder()方法.关键是Service层submitOrder()中使用了事务回滚. 这里调用了Dao层两个方法: fun01()和fun02(), …

java不进入for_为什么阿里巴巴Java开发手册中强制要求不要在foreach循环里进行元素的remove和add操作?...

在阅读《阿里巴巴Java开发手册》时&#xff0c;发现有一条关于在 foreach 循环里进行元素的 remove/add 操作的规约&#xff0c;具体内容如下&#xff1a;错误演示我们首先在 IDEA 中编写一个在 foreach 循环里进行 remove 操作的代码&#xff1a;import java.util.ArrayList;i…

8086汇编4位bcd码_二进制格雷码与自然二进制码的互换分析

在精确定位控制系统中&#xff0c;为了提高控制精度&#xff0c;准确测量控制对象的位置是十分重要的。目前&#xff0c;检测位置的办法有两种&#xff1a;其一是使用位置传感器&#xff0c;测量到的位移量由变送器经A/D转换成数字量送至系统进行进一步处理。此方法精度高&…

软件工程结构化建模的方法和工具_软件工程系列-结构化设计方法2

本系列文章为笔记&#xff0c;内容根据北京大学《软件工程》MOOC 初始化模块结构图精化的启发式规则常见的启发式规则什么叫做“启发式”根据设计准则&#xff0c;从长期的软件开发实践中&#xff0c;总结出来的规则既不是设计目标&#xff0c;也不是设计时应该普遍遵循的原理常…

java四种权限的高低_Java(四种权限修饰符)

/*Java中有四种权限修饰符&#xff1a;public > protected > (default) > private同一个类(我自己) YES YES YES YES同一个包(我邻居) YES YES YES NO不同包子类(我儿子) YES YES NO NO不同包非子类(陌生人) YES NO NO NO注意事项&#xff1a;(default)并不是关键字“…

安全扫描失败无法上传_Apache Solr 未授权上传(RCE)漏洞的原理分析与验证

漏洞简介Apache Solr 发布公告&#xff0c;旧版本的ConfigSet API 中存在未授权上传漏洞风险&#xff0c;被利用可能导致 RCE (远程代码执行)。受影响的版本&#xff1a;Apache Solr6.6.0 -6.6.5Apache Solr7.0.0 -7.7.3Apache Solr8.0.0 -8.6.2安全专家建议用户尽快升级到安全…