和这篇文章对比：https://blog.csdn.net/fanzonghao/article/details/81023730
本文不希望重复定义图上的运算，也就是说，在模型恢复过程中不想再执行 sess.run(init)。首先看一下路径：
lineRegulation_model.py：定义线性回归类
import tensorflow as tf
"""
类定义一些公共量,方便模型载入用
"""
class LineRegModel:def __init__(self):with tf.variable_scope('var'):self.a_val=tf.Variable(tf.random_normal(shape=[1]),name='a_val')self.b_val = tf.Variable(tf.random_normal(shape=[1]),name='b_val')self.x_input=tf.placeholder(dtype=tf.float32,name='input_placeholder')self.y_label = tf.placeholder(dtype=tf.float32,name='result_placeholder')self.y_output = tf.add(tf.multiply(self.x_input,self.a_val),self.b_val,name='output')self.loss=tf.reduce_mean(tf.pow(self.y_output-self.y_label,2))def get_saver(self):return tf.train.Saver()def get_op(self):return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)
model_train.py：定义模型训练过程
import tensorflow as tf
import numpy as np
from save_and_restore2 import global_variable
from save_and_restore2 import lineRegulation_model as model
import os
# Make sure the checkpoint directory exists; exist_ok avoids the race between
# a separate existence check and the create call.
os.makedirs('./model', exist_ok=True)
"""
训练模型
"""
train_x=np.random.rand(5)
train_y=train_x*5+3
model=model.LineRegModel()#类要加括号
a_val=model.a_val
b_val=model.b_val
x_input=model.x_input
y_label=model.y_label
y_output=model.y_output
loss=model.loss
optimizer=model.get_op()
saver=model.get_saver()
if __name__ == '__main__':
    # Upper bound on training steps so the loop cannot spin forever when the
    # loss never reaches the target (a NaN loss never compares < 1e-6).
    max_epochs = 1000000
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        # Fresh training run: variables must be initialized before use.
        sess.run(init)
        epoch = 0
        converged = False
        while epoch < max_epochs:
            epoch += 1
            cost, _ = sess.run([loss, optimizer],
                               feed_dict={x_input: train_x, y_label: train_y})
            if cost < 1e-6:
                converged = True
                break
            if not np.isfinite(cost):
                # Loss blew up to NaN/inf; further gradient steps cannot recover.
                break
        if not converged:
            print('warning: loss did not reach 1e-6 (last cost={})'.format(cost))
        print('a={},b={}'.format(a_val.eval(sess), b_val.eval(sess)))
        print('epoch={}'.format(epoch))
        print(a_val)
        # print(a_val.op)
        saver.save(sess, global_variable.save_path)
        print('model save finish')
print(a_val) 的输出形式：
print(a_val.op) 的输出形式：
model_restore.py：恢复模型。采用先恢复图、再恢复权重的方式，可以更细致地控制模型的恢复过程。
import tensorflow as tf
from save_and_restore import global_variable,lineRegulation_model as model
"""
恢复模型图文件
"""
saver=tf.train.import_meta_graph('./model/weight.meta')
#读取placeholder和最终的输出结果
graph=tf.get_default_graph()
a_val=graph.get_tensor_by_name('var/a_val:0')
b_val=graph.get_tensor_by_name('var/b_val:0')input_placeholder=graph.get_tensor_by_name('input_placeholder:0')
labels_placeholder=graph.get_tensor_by_name('result_placeholder:0')
y_output=graph.get_tensor_by_name('output:0')with tf.Session() as sess:#具体权重的恢复saver.restore(sess,'./model/weight')result=sess.run(y_output,feed_dict={input_placeholder:[1]})print(result)print(sess.run(a_val))print(sess.run(b_val))