『TensorFlow』模型保存和载入方法汇总

一、TensorFlow常规模型加载方法

保存模型

tf.train.Saver()类,.save(sess, ckpt文件目录)方法

参数名称功能说明默认值
var_listSaver中存储变量集合全局变量集合
reshape加载时是否恢复变量形状True
sharded是否将变量轮循放在所有设备上True
max_to_keep保留最近检查点个数5
restore_sequentially是否按顺序恢复变量,模型较大时顺序恢复内存消耗小True

 

var_list是字典形式{变量名字符串: 变量符号},相对应的restore也根据同样形式的字典将ckpt中的字符串对应的变量加载给程序中的符号。

如果Saver给定了字典作为加载方式,则按照字典来,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否则每个变量寻找自己的name属性在ckpt中的对应值进行加载。

 

加载模型

当我们基于checkpoint文件(ckpt)加载参数时,实际上我们使用Saver.restore取代了initializer的初始化

checkpoint文件会记录保存信息,通过它可以定位最新保存的模型:

1

2

ckpt = tf.train.get_checkpoint_state('./model/')

print(ckpt.model_checkpoint_path)

 

.meta文件保存了当前图结构

.data文件保存了当前参数名和值

.index文件保存了辅助索引信息

.data文件可以查询到参数名和参数值,使用下面的命令可以查询保存在文件中的全部变量{名:值}对,

1

2

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

print_tensors_in_checkpoint_file(os.path.join(savedir,savefile),None,True)

tf.train.import_meta_graph函数给出model.ckpt-n.meta的路径后会加载图结构,并返回saver对象

1

ckpt = tf.train.get_checkpoint_state('./model/')

tf.train.Saver函数会返回加载默认图的saver对象,saver对象初始化时可以指定变量映射方式,根据名字映射变量(『TensorFlow』滑动平均)

1

saver = tf.train.Saver({"v/ExponentialMovingAverage":v}) 

saver.restore函数给出model.ckpt-n的路径后会自动寻找参数名-值文件进行加载

1

2

saver.restore(sess,'./model/model.ckpt-0')

saver.restore(sess,ckpt.model_checkpoint_path)

1.不加载图结构,只加载参数

由于实际上我们参数保存的都是Variable变量的值,所以其他的参数值(例如batch_size)等,我们在restore时可能希望修改,但是图结构在train时一般就已经确定了,所以我们可以使用tf.Graph().as_default()新建一个默认图(建议使用上下文环境),利用这个新图修改和变量无关的参值大小,从而达到目的。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

'''

使用原网络保存的模型加载到自己重新定义的图上

可以使用python变量名加载模型,也可以使用节点名

'''

import AlexNet as Net

import AlexNet_train as train

import random

import tensorflow as tf

 

IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'

 

with tf.Graph().as_default() as g:

 

    x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])

    y = Net.inference_1(x, N_CLASS=5, train=False)

 

    with tf.Session() as sess:

        # 程序前面得有 Variable 供 save or restore 才不报错

        # 否则会提示没有可保存的变量

        saver = tf.train.Saver()

 

        ckpt = tf.train.get_checkpoint_state('./model/')

        img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()

        img = sess.run(tf.expand_dims(tf.image.resize_images(

            tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))

 

        if ckpt and ckpt.model_checkpoint_path:

            print(ckpt.model_checkpoint_path)

            saver.restore(sess,'./model/model.ckpt-0')

            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

            res = sess.run(y, feed_dict={x: img})

            print(global_step,sess.run(tf.argmax(res,1)))

  2.加载图结构和参数

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

'''

直接使用使用保存好的图

无需加载python定义的结构,直接使用节点名称加载模型

由于节点形状已经定下来了,所以有不便之处,placeholder定义batch后单张传会报错

现阶段不推荐使用,以后如果理解深入了可能会找到使用方法

'''

import AlexNet_train as train

import random

import tensorflow as tf

 

IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'

 

 

ckpt = tf.train.get_checkpoint_state('./model/')                          # 通过检查点文件锁定最新的模型

saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')   # 载入图结构,保存在.meta文件中

 

with tf.Session() as sess:

    saver.restore(sess,ckpt.model_checkpoint_path)                        # 载入参数,参数保存在两个文件中,不过restore会自己寻找

 

    img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()

    img = sess.run(tf.image.resize_images(

        tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))

    imgs = []

    for i in range(128):

       imgs.append(img)

    print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs}))

 

    '''

    img = sess.run(tf.expand_dims(tf.image.resize_images(

        tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))

    print(img)

    imgs = []

    for i in range(128):

        imgs.append(img)

    print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'),

                   feed_dict={'Placeholder:0':img}))

注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。

  3.简化版本

1

2

3

4

5

6

7

8

9

10

11

12

# 连同图结构一同加载

ckpt = tf.train.get_checkpoint_state('./model/')

saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')

with tf.Session() as sess:

    saver.restore(sess,ckpt.model_checkpoint_path)

             

# 只加载数据,不加载图结构,可以在新图中改变batch_size等的值

# 不过需要注意,Saver对象实例化之前需要定义好新的图结构,否则会报错

saver = tf.train.Saver()

with tf.Session() as sess:

    ckpt = tf.train.get_checkpoint_state('./model/')

    saver.restore(sess,ckpt.model_checkpoint_path)

回到顶部

二、TensorFlow二进制模型加载方法

这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

# 新建空白图

self.graph = tf.Graph()

# 空白图列为默认图

with self.graph.as_default():

    # 二进制读取模型文件

    with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:

        # 新建GraphDef文件,用于临时载入模型中的图

        graph_def = tf.GraphDef()

        # GraphDef加载模型中的图

        graph_def.ParseFromString(f.read())

        # 在空白图中加载GraphDef中的图

        tf.import_graph_def(graph_def,name='')

        # 在图中获取张量需要使用graph.get_tensor_by_name加张量名

        # 这里的张量可以直接用于session的run方法求值了

        # 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量

        self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)

        self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name   in self.layer_operation_names]

 『TensorFlow』迁移学习_他山之石,可以攻玉

参考自https://www.cnblogs.com/hellcat/p/6925757.html#_label0_1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/567927.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STL13-list容器(链表)

链表是由一系列的结点组成,结点包括两个域:一个数据域,一个指针域 1、链表内存是非连续的,添加删除元素效率较高,时间复杂度都是常数项,不需要移动元素 2、链表只有在需要的时候才会分配内存 3、链表只要…

php 前往页面,PHP实现网页截图?

如何使用PHP实现网页截图PHP实现网页截图是一个在日常开发中不常见的需求,但是如果实现还是非常有意思的。目前业界有很多成熟的方案,下面我推荐使用一个很稳定的第三方服务来直接实现,该服务有如下特点:支持多线路支持登录截图支…

STL14-set/multiset容器

set只有一个方法就是insert #include<iostream> #include<set> //set和multiset是一个头文件 //set内部实现机制 红黑色&#xff08;平衡二叉树的一种&#xff09; //关联式容器 //set不允许有重复元素 //multiset运行有重复元素 //容器查找效率高 //容器根据元素的…

普通的java类型是指,String是一个很普通的类 - Java那些事儿

上一篇我们讲了Java中的数组&#xff0c;其实是为本章的内容做准备的&#xff0c;String这个类是我们在写Java代码中用得最多的一个类&#xff0c;没有之一&#xff0c;今天我们就讲讲它&#xff0c;我们打开String这个类的源码&#xff1a;声明了一个char[]数组&#xff0c;变…

STL15-map/multimap容器

map的key值不可以重复 multimap的key值可以重复 #if 1 #include<iostream> #include<map> using namespace std; //初始化 void test01() {//map容器参数 第一个参数key的类型 第二个参数value类型map<int, int> mymap;//插入元素 pair.first key值 pair.se…

php nginx日志分析,如何通过NGINX的log日志来分析网站的访问情况,试试这些命令...

想知道你的网站每天的访问情况吗&#xff1f;有多少人访问了&#xff1f;访问最多的页面是哪个&#xff1f;哪个时段访问的人最多&#xff1f;哪个地方访问的最多&#xff1f;每秒有多少请求&#xff1f;很好奇吧&#xff0c;只要你是使用了nginx进行请求抓发&#xff0c;那么就…

php带来互联网的影响,网络对我们的影响有哪些?

影响有&#xff1a;1、丰富了我们的业余生活&#xff1b;2、降低了获取知识的成本&#xff0c;降低了提升工作的能力的成本&#xff0c;提高了工作的效率&#xff0c;可以快速建立良好的人脉关系&#xff1b;3、让购物变得更加简单便捷&#xff1b;4、朋友间深度沟通与交流越来…

STL17-函数对象

仿函数&#xff1a; #include<iostream> #include<vector> #include<algorithm> using namespace std; //仿函数&#xff08;函数对象&#xff09;重载“&#xff08;&#xff09;”操作符 使类对象可以像函数那样调用 //仿函数是一个类&#xff0c;不是一个…

STL18常用算法

#include<iostream> #include<algorithm> #include<vector> using namespace std; //transform 将一个容器中的元素搬运在另一个容器中 #if 0 //错误 struct PrintVector {void operator()(int v) {cout << v << " ";} }; void test0…

php中页面平滑回到顶部代码,原生JS实现平滑回到顶部组件

返回顶部组件是一种极其常见的网页功能&#xff0c;需求简单&#xff1a;页面滚动一定距离后&#xff0c;显示返回顶部的按钮&#xff0c;点击该按钮可以将滚动条滚回至页面开始的位置。实现思路也很容易&#xff0c;只要改变document.documentElement.scrollTop或document.bod…

C++基础01-C++对c的增强

所谓namespace&#xff0c;是指标识符的各种可见范围。C标准程序库中的所 有标识符都被定义于一个名为std的namespace中。 一 &#xff1a;<iostream>和<iostream.h>格式不一样&#xff0c;前者没有后缀&#xff0c;实际上&#xff0c; 在你的编译器include文件夹…

C++基础02-C++对c的拓展

变量名实质上是一段连续存储空间的别名&#xff0c;是一个标号(门牌号) 通过变量来申请并命名内存空间. 通过变量的名字可以使用存储空间. 变量名&#xff0c;本身是一段内存的引用&#xff0c;即别名(alias). 引用可以看作一个已定义变量的别名。 引用的语法&#xff…

php小程序onload,微信小程序 loading 组件实例详解

这篇文章主要介绍了微信小程序 loading 组件实例详解的相关资料,需要的朋友可以参考下loading通常使用在请求网络数据时的一种方式&#xff0c;通过hidden属性设置显示与否主要属性&#xff1a;wxml显示loading正在加载jsPage({data:{// text:"这是一个页面"hiddenLo…

C++基础04-类基础

一、类和对象 面向对象三大特点&#xff1a;封装、继承、多态。 struct 中所有行为和属性都是 public 的(默认)。C中的 class 可以指定行为和属性的访问方式。 封装,可以达到,对内开放数据,对外屏蔽数据,对外提供接口。达到了信息隐蔽的功能。 class 封装的本质,在于将数…

C++基础05-类构造函数与析构函数

总结&#xff1a; 1、类对象的作用域为两个{}之间。在遇到}后开始执行析构函数 2、当没有任何显式的构造函数&#xff08;无参&#xff0c;有参&#xff0c;拷贝构造&#xff09;时&#xff0c;默认构造函数才会发挥作用 一旦提供显式的构造函数&#xff0c;默认构造函数不复…

PHP网站配置项,Thinkphp5通用网站后台配置项的动态添加及更新

一、引入无论平时我们自己制作&#xff0c;还是浏览别人的网站&#xff0c;它都具有其相应的一些共用的、通用的属性&#xff0c;比如&#xff1a;网站的名字&#xff0c;关键字、备案号、分页数量、是否开启缓存等信息。一些网站可能将配置项写死在后台&#xff0c;无法动态更…

oracle 查询cpu 100%,Oracle 11g中查询CPU占有率高的SQL

oracle版本&#xff1a;oracle11g背景&#xff1a;今天在Linux中的oracle服务上&#xff0c;运用top命令发现许多进程的CPU占有率是100%。操作步骤&#xff1a;以进程PID:7851为例执行以下语句&#xff1a;方法一&#xff1a;(1)通过PID&#xff0c;查得相对应的系统进程对应的…

C++基础08-this指针-const修饰成员函数-函数返回引用/值

一、this指针 1、C类对象中的成员变量和成员函数是分开存储的。C语言中的内存四区模型仍然有效&#xff01; 2、C中类的普通成员函数都隐式包含一个指向当前对象的this指针。 3、静态成员函数、成员变量属于类 4、静态成员函数与普通成员函数的区别 静态成员函数不包含指…

C++基础09-货物售卖和MyArray实现

1、货物出货与进货 #if 0 #include<iostream> using namespace std; /* 某商店经销一种货物。货物购进和卖出时以箱为单位。各箱 的重量不一样&#xff0c;因此商店需要记录目前库存的总重量&#xff0c;现在用 C模拟商店货物购进和卖出的情况 */ class Goods { public:…

C++基础11-类和对象之操作符重载1

总结&#xff1a; 1、运算符重载的本质是函数重载 2、运算符重载可分为成员函数重载和全局函数重载(差一个参数) 3、运算符重载函数的参数至少有一个是类对象&#xff08;或类对象的引用&#xff09; 4、不可以被重载的操作符有&#xff1a;成员选择符(.) 成员对象选择符(.*) …