简单的线性模型实现tensorflow权重的生成和调用,并且用类的方式实现参数共享

首先看文件路径,line_regression是总文件夹,model文件夹存放权重文件,

global_variable.py写了一句话.

 

save_path='./model/weight'

权重要存放的路径,以weight命名.

lineRegulation_model.py代码

 

import tensorflow as tf
"""
类定义一些公共量,方便模型载入用
"""
class LineRegModel:def __init__(self):self.a_val=tf.Variable(tf.random_normal(shape=[1]))self.b_val = tf.Variable(tf.random_normal(shape=[1]))self.x_input=tf.placeholder(dtype=tf.float32)self.y_label = tf.placeholder(dtype=tf.float32)self.y_output = tf.multiply(self.x_input,self.a_val)+self.b_valself.loss=tf.reduce_mean(tf.pow(self.y_output-self.y_label,2))def get_op(self):return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)

定义了一个类,方便后面共享权值恢复模型的调用

model_train.py代码:

 

import tensorflow as tf
import numpy as np
from save_and_restore import global_variable
from save_and_restore import  lineRegulation_model as model
"""
训练模型
"""
train_x=np.random.rand(5)
train_y=train_x*5+3
model=model.LineRegModel()#类要加括号
a_val=model.a_val
b_val=model.b_val
x_input=model.x_input
y_label=model.y_label
y_output=model.y_output
loss=model.loss
optimizer=model.get_op()
if __name__ == '__main__':saver = tf.train.Saver()init=tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)flag=Trueepoch=0while flag:epoch+=1cost,_=sess.run([loss,optimizer],feed_dict={x_input:train_x,y_label:train_y})if cost<1e-6:flag=Falseprint('a={},b={}'.format(a_val.eval(sess),b_val.eval(sess)))print('epoch={}'.format(epoch))saver.save(sess,global_variable.save_path)print('model save finish')

训练模型,并且存放模型的目的,这样前面三段代码就可以实现简单的线性模型权重的生成和存放。

其中checkpoint指的是检查点文件,记录存储文件名称,weight.data_00000-of-00001权重存储文件,weight.index存储权重目录

weight.meta模型的全部图文件,所以weight.data_00000-of-00001和weight.meta是最大的。

model_restore.py代码如下:

import tensorflow as tf
from save_and_restore import global_variable,lineRegulation_model as model
"""
加载模型
"""
model=model.LineRegModel()
x_input=model.x_input
y_output=model.y_output
init=tf.global_variables_initializer()
saver=tf.train.Saver()
with tf.Session() as sess:sess.run(init)saver.restore(sess,global_variable.save_path)result=sess.run(y_output,feed_dict={x_input:[1]})print(result)

调用生成的模型打印出预测结果:

结果和8差不多。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493636.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

comparing ORB and AKAZE

文章全称是《Comparing ORB and AKAZE for visual odometry of unmanned aerial vehicles》。这是一篇来自巴西的文章&#xff0c;没有在百度文库中找到&#xff0c;是在其他博客中给出的链接得到的。从链接的URL可以看出这是一篇来自会议CCIS云计算与智能系统国际会议的文章。…

利用 CRISPR 基因编辑技术,人类正在做七件“疯狂”的事

来源&#xff1a;36Kr编译&#xff1a;喜汤很少有哪种现代科学创新能像CRISPR基因编辑技术一样影响深远。有了它&#xff0c;科学家们可以精确地改变任何细胞的DNA。CRISPR技术成为新宠&#xff0c;部分原因是它比早期基因编辑技术更容易使用。尽管CRISPR还没有彻底大展身手&am…

吴恩达作业3:利用深层神经网络实现小猫的分类

利用4层神经网络实现小猫的分类&#xff0c;小猫训练样本是&#xff08;209&#xff0c;64*64*312288&#xff09;,故输入节点是12288个&#xff0c;隐藏层节点依次为20&#xff0c;7&#xff0c;5&#xff0c;输出层为1。 首先看文件路径&#xff0c;dnn_utils_v2.py代码是激活…

A-KAZE论文研读

AKAZE是KAZE的加速版本。KAZE在构建非线性空间的过程中很耗时&#xff0c;在AKAZE中将Fast Explicit Diffusion(FED)加入到金字塔框架可以dramatically speed-up。在描述子方面&#xff0c;AKAZE使用了更高效的Modified Local Difference Binary(M-LDB)&#xff0c;可以从非线性…

和你抢“饭碗”的40家服务机器人企业大盘点!

来源&#xff1a;物联网智库摘要&#xff1a;本文将对国内近40家服务机器人企业进行汇总介绍&#xff0c;所选企业在其相应版块活跃度较高。从三个大类进行了细分盘点。国家机器人联盟&#xff08;IFR&#xff09;根据应用环境将机器人分为了工业机器人和服务机器人。服务机器人…

YOLO9000

YOLO9000是YOLO的第三个版本。前两个版本是YOLO v1&#xff0c;YOLO v2&#xff0c;在CVPR2017的文章《Better,Faster,Stronger》中的前半部分都是对前两个版本的介绍&#xff0c;新的内容主要在Stronger部分。YOLO9000中的9000指的是YOLO可以对超过9000种图像进行分类。 Bett…

吴恩达作业4:权重初始化

权重初始化的 正确选择能够有效的避免多层神经网络传播过程中的梯度消失和梯度爆炸问题&#xff0c;下面通过三个初始化的方法来验证&#xff1a; sigmoid导数函数&#xff1a;最大值小于0.25&#xff0c;故经过多层反向传播以后&#xff0c;会导致最初的层&#xff0c;权重无…

先发制人!Waymo将首推商用载人自动驾驶服务,Uber们怕不怕?

编译&#xff1a;费棋来源&#xff1a;AI科技大本营“真的&#xff0c;真的很难。”11 月举办的一场会议上&#xff0c;Alphabet 旗下 Waymo CEO John Krafcik 对做自动驾驶汽车技术的艰难不无感慨。在他看来&#xff0c;未来几十年内&#xff0c;自动驾驶汽车将一直存在限制&a…

利用ORB/AKAZE特征点进行图像配准

Kp1,kp2都是list类型&#xff0c;两幅图都是500个特征点。这和ORB论文中的数据是一样的。4.4章节 Matches也是list类型&#xff0c;找到325个匹配对。 AKAZE文章中提到一个指标&#xff1a;MS(matching score)# Correct Matches/# Features, 如果overlap area error 小于40%…

吴恩达作业5:正则化和dropout

构建了三层神经网络来验证正则化和dropout对防止过拟合的作用。 首先看数据集&#xff0c;reg_utils.py包含产生数据集函数&#xff0c;前向传播&#xff0c;计算损失值等&#xff0c;代码如下&#xff1a; import numpy as np import matplotlib.pyplot as plt import h5py …

十年之后,数字孪生将这样改变我们的工作与生活

来源&#xff1a;资本实验室数字孪生是近几年兴起的非常前沿的新技术&#xff0c;简单说就是利用物理模型&#xff0c;使用传感器获取数据的仿真过程&#xff0c;在虚拟空间中完成映射&#xff0c;以反映相对应的实体的全生命周期过程。在未来&#xff0c;物理世界中的各种事物…

什么是图像

图像&#xff0c;尤其是数字图像的定义&#xff0c;在冈萨雷斯的书中是一个二维函数f(x,y),x,y是空间平面坐标&#xff0c;幅值f是图像在该点处的灰度或者强度。下面通过OpenCV中最常用的图像表示方法Mat来看一下在计算机中是怎么定义图像的。 Mat的定义 OpenCV在2.0之后改用…

吴恩达作业6:梯度检验

梯度检验的目的就是看反向传播过程中的导数有没有较大的误差&#xff0c;首先看Jtheta*x的梯度检验&#xff1a;代码如下 import numpy as np """ Jx*theta的前向传播 """ def forward_propagation(x,theta):Jx*thetareturn J ""&quo…

10年后的计算机会是怎样的?

作者&#xff1a;孙鹏&#xff08;剑桥大学计算机系博士&#xff09;来源&#xff1a;新原理研究所上个世纪三十年代&#xff0c;邱奇和图灵共同提出了通用计算机的概念[1]。在接下来的十多年里&#xff0c;因为战争需要下的国家推动&#xff0c;计算机得以很快从理论发展成为实…

什么是图像变换

还是看OpenCV官方手册&#xff0c;我觉得这样可以同时学习如何使用函数和如何理解一些基本概念。 首先&#xff0c;这里的几何变换geometrical transformations是针对2D图像而言的&#xff0c;不改变图像内容而是将像素网格变形deform the pixel grid&#xff0c;映射到目标图…

MSRA20周年研究趋势文章|图像识别的未来:机遇与挑战并存

文/微软亚洲研究院 代季峰 林思德 郭百宁识别图像对人类来说是件极容易的事情&#xff0c;但是对机器而言&#xff0c;这也经历了漫长岁月。在计算机视觉领域&#xff0c;图像识别这几年的发展突飞猛进。例如&#xff0c;在 PASCAL VOC 物体检测基准测试中&#xff0c;检测器的…

吴恩达作业7:梯度下降优化算法

先说说BatchGD用整个训练样本进行训练得出损失值&#xff0c;SGD是只用一个训练样本训练就得出损失值&#xff0c;GD导致训练慢&#xff0c;SGD导致收敛到最小值不平滑&#xff0c;故引入Mini-batch GD&#xff0c;选取部分样本进行训练得出损失值&#xff0c; 普通梯度下降算…

什么是单应矩阵和本质矩阵

知乎上面的大牛还是很多&#xff0c;直接搜Homography或者单应矩阵就能得到很多大神的回答&#xff0c;可能回答中的一句话或者一个链接就够自己学习很久。 其实在之前研究双目视觉的时候就接触了对极几何&#xff0c;通过视觉就可以得到物体的远近信息&#xff0c;这也是特斯…

tensorflow实现反卷积

先看ogrid用法 from numpy import ogrid,repeat,newaxis from skimage import io import numpy as np size3 x,yogrid[:size,:size]#第一部分产生多行一列 第二部分产生一行多列 print(x) print(y) 打印结果&#xff1a; newaxis用法&#xff1a; """ newaxis…

寿命能推算吗?加州大学科学家提出“预测方法”

来源&#xff1a;中国科学报从古至今&#xff0c;从国内到国外&#xff0c;从炼丹术到现代科学&#xff0c;长生不老似乎一直是人类乐此不疲的追求。但若要延缓衰老&#xff0c;首先要弄清是什么造成了衰老。近日&#xff0c;加州大学洛杉矶分校&#xff08;UCLA&#xff09;生…