保存模型
使用TensorFlow的saver()类先实例化一个saver对象,然后在session中通过saver的save方法将模型保存起来。代码示例如下:
#初始化所有变量
init = tf.global_variable_initializer()#定义saver和保存路径
saver = tf.train.Saver()
saverdir = "save_path"#启动Session
with tf.Session() as sess:sess.run(init)#使用saver的save方法保存saver.save(sess,saverdir + "file_name")
其中,filename如果不存在,程序会自动创建。
打印模型中的内容
使用inspect_checkpoint包中的print_tensors_in_checkpoint_file方法将模型中的具体内容打印出来。代码示例如下:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
form tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_filesaverdir = "log/"
print_tensors_in_checkpoint_file(savedir + "linearmodel.cpkt",None,True)
保存模型的其他方法
使用saver()类保存模型时,可以在函数中放入参数来实现更高级的功能,如指定存储变量名字与变量的对应关系。代码示例如下:
W = tf.Variable(1.0,name = "weight")
b = tf.Variable(2.0,name = "bias")saver = tf.train.Saver({'weight':W,'bias':b})
with tf.Session() as sess:tf.global_variables_initializer().run()saver.save(sess,savedir + "linearmodel.cpkt")
print_tensors_in_checkpoint_file(savedir + "linearmodel.cpkt",None,True)
载入模型
通过调用saver的restore()函数,从指定的路径找到模型文件,并覆盖到相关参数中。代码示例如下:
#初始化所有变量
init = tf.global_variable_initializer()#定义saver和保存路径
saver = tf.train.Saver()
saverdir = "save_path"#启动Session
with tf.Session() as sess:sess.run(init)#使用saver的restore方法载入模型print("x=0.2,z=",sess.run(z,feed_dict = {X:0.2}))