模型的保存与恢复
我们来简单实现一下模型的保存与恢复
训练完TensorFlow模型后,可将其保存为文件,以便于预测新数据时直接加载使用。
TensorFlow模型主要包含网络的设计或者图以及已经训练好的网络参数的值。
TensorFlow提供的tf.train.Saver()函数可以建立一个saver对象,在会话中调用其save()函数,即可将模型保存起来
save()函数的用法
函数 | 说明 |
save( sess, sace_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True ) | sess:保存模型,要求必须有一个加载了计算图的会话,且所有变量已被初始化。 sace_path:模型保存路径及保存名称 global_step:如果提供,该数字会添加到save_path后,用于区分不同训练阶段的结果 latest_filename:检查点文件的名称,默认是checkpoint meta_graph_suffix= MetaGraphDef元图后缀,默认为meta write_meta_graph=是否要保存元图数据,默认为True write_state:是否要保存CheckpointStateProto,默认为True |
模型保存
import tensorflow as tf
m1 = tf.Variable(tf.constant([[1.0,3.0],[2.0,4.0]],shape=[2,2]),name='m1')
m2 = tf.Variable(tf.constant([[2.0,7.0],[3.0,8.0]],shape=[2,2]),name='m2')
result = m1 + m2
saver = tf.train.Saver()
with tf.Session() as sess:sess.run(tf.global_variables_initializer())print('resulit:',sess.run(result))saver.save(sess,'C:/model/model.ckpt')
运行程序,当前目录的model文件夹下会产生4个文件:checkpoint,data-00000-of-00001,meta和index
checkpoint:保存模型的权重、偏置、梯度以及其他保护变量的二进制文件。
data:保存模型的所有变量的值
meta:保存计算图的结构。当meta文件存在时,不在程序中定义模型,直接加载meta可以直接运行
index:保存string-string的键值对。其中的key值为张量名,value为BundleEntryProto
模型恢复
模型保存好了以后,载入发出方便。
在会话中调用saver的restore()函数,就会从指定的路径找到模型文件,并覆盖相关参数。
saver.restore()函数的形式如表
函数 | 说明 |
saver.restore( sess, save_path ) | 从指定的路径恢复模型。 sess:用于恢复参数模型的会话 save_path:已保存模型的路径,通常包含模型名字 |
import tensorflow as tf
tf.reset_default_graph()
v1 = tf.Variable(tf.constant([[5.0,6.0],[7.0,7.0]],shape=[2,2]),name='m1')
v2 = tf.Variable(tf.constant([[4.0,6.0],[7.0,8.0]],shape=[2,2]),name='m2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:saver.restore(sess,'C:/model/model.ckpt')print(sess.run(result))