TensorFlow 2.0 - Checkpoint 保存变量、TensorBoard 训练可视化

文章目录

    • 1. Checkpoint 保存变量
    • 2. TensorBoard 训练过程可视化

学习于:简单粗暴 TensorFlow 2

1. Checkpoint 保存变量

  • tf.train.Checkpoint 可以保存 tf.keras.optimizertf.Variabletf.keras.Layertf.keras.Model
path = "./checkp.ckpt"
# 建立一个 checkpoint 
mycheckpoint = tf.train.Checkpoint(mybestmodel=mymodel) # 接受 **kwargs 键值对
mycheckpoint.save(path)

  • 恢复指定模型变量
# 待恢复参数的模型
restored_model = LinearModel()
# mybestmodel 名字任意写,跟下面恢复时保持一致
mycheckpoint = tf.train.Checkpoint(mybestmodel=restored_model)
# 恢复指定的变量
path = "./checkp.ckpt-1"
mycheckpoint.restore(path)X_test = tf.constant([[5.1], [6.1]])
res = restored_model.predict(X_test)
print(res)
# [[10.182168] 前一节的线性回归模型
#  [12.176777]]
  • 恢复最近的模型,自动选定目录下最新的存档(后缀数字最大的)
mycheckpoint.restore(tf.train.latest_checkpoint("./"))
  • 管理保存的参数,有时不需要保存太多,占空间
mycheckpoint = tf.train.Checkpoint(mybestmodel=mymodel)  # 接受 **kwargs 键值对
manager = tf.train.CheckpointManager(mycheckpoint, directory="./",checkpoint_name='checkp.ckpt',max_to_keep=2) # 最多保存k个最新的for loop:manager.save() # 自动递增编号manager.save(checkpoint_number=idx) # 指定编号

2. TensorBoard 训练过程可视化

  • summary_writer = tf.summary.create_file_writer(logdir=log_dir)
  • tf.summary.scalar(name='loss', data=loss, step=idx)
  • tf.summary.trace_on(profiler=True)
for loop:with summary_writer.as_default():tf.summary.scalar(name='loss', data=loss, step=idx)
with summary_writer.as_default():tf.summary.trace_export(name='model_trace', step=0,profiler_outdir=log_dir)
  • 示例
import tensorflow as tf
import numpy as npclass MNistLoader():def __init__(self):data = tf.keras.datasets.mnist# 加载数据(self.train_data, self.train_label), (self.test_data, self.test_label) = data.load_data()# 扩展维度,灰度图1通道 [batch_size, 28, 28, chanels=1]self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)self.train_label = self.train_label.astype(np.int32)self.test_label = self.test_label.astype(np.int32)# 样本个数self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]def get_batch(self, batch_size):# 从训练集里随机取出 batch_size 个样本idx = np.random.randint(0, self.num_train_data, batch_size)return self.train_data[idx, :], self.train_label[idx]# 自定义多层感知机模型
class MLPmodel(tf.keras.Model):def __init__(self):super().__init__()# 除第一维以外的维度展平self.flatten = tf.keras.layers.Flatten()self.dense1 = tf.keras.layers.Dense(units=100, activation='relu')self.dense2 = tf.keras.layers.Dense(units=10)def call(self, input):x = self.flatten(input)x = self.dense1(x)x = self.dense2(x)output = tf.nn.softmax(x)return output# %%num_epochs = 5
batch_size = 50
learning_rate = 1e-4
log_dir = './log' # 日志目录
mymodel = MLPmodel()# %%
data_loader = MNistLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
num_batches = int(data_loader.num_train_data // batch_size * num_epochs)# 实例化记录器
summary_writer = tf.summary.create_file_writer(logdir=log_dir)
# 开启 trace,(可选),记录训练时的大量信息(图的结构,耗时等)
tf.summary.trace_on(profiler=True)for idx in range(num_batches):X, y = data_loader.get_batch(batch_size)with tf.GradientTape() as tape:y_pred = mymodel(X)loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch {}, loss {}".format(idx, loss.numpy()))# 记录器记录losswith summary_writer.as_default():tf.summary.scalar(name='loss', data=loss, step=idx)grads = tape.gradient(loss, mymodel.variables)optimizer.apply_gradients(grads_and_vars=zip(grads, mymodel.variables))with summary_writer.as_default():tf.summary.trace_export(name='model_trace', step=0,profiler_outdir=log_dir)
  • 开始训练,命令行进入 可视化界面 tensorboard --logdir=./log
  • 点击命令行中的链接,打开浏览器,查看训练曲线
  • 若重新训练,请删除 log 文件,或设置别的 log 路径,重新 cmd 开启 浏览器

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473041.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

coturn的负载均衡特性_高性能负载均衡

单服务器无论如何优化,无论采用多好的硬件,总会有一个性能天花板,当单服务器的性能无法满足业务需求时,就需要设计高性能集群来提升系统整体的处理性能。高性能集群的本质很简单,通过增加更多的服务器来提升系统整体的…

LintCode MySQL 1928. 网课上课情况分析 I

文章目录1. 题目2. 解题1. 题目 online_class_situation 表展示了一些同学上网课的行为活动。 每行数据记录了一名同学在退出网课之前,当天使用同一台设备登录课程后听过的课程数目(可能是0个)。 写一条 SQL 语句,查询每位同学第…

poj1284:欧拉函数+原根

何为原根?由费马小定理可知 如果a于p互质 则有a^(p-1)≡1(mod p)对于任意的a是不是一定要到p-1次幂才会出现上述情况呢?显然不是,当第一次出现a^k≡1(mod p)时, 记为ep(a)k 当k(p-1)时,称a是p的…

python输入十个数输出最大值_python输入十个数如何输出最大值

python输入十个数输出最大值的方法:1、如果是整数的话,使用函数【a, b, c map(int, input().split())】;2、使用函数【Xinput().split()】。 相关免费学习推荐:python视频教程 python输入十个数输出最大值的方法: 第一…

SQL Server 2005远程连接连不上的解决办法收藏 Microsoft给的方法

SQL Server 2005远程连接连不上的解决办法收藏 Microsoft给的方法http://support.microsoft.com/kb/914277 是可以的,但我怕以后还会遇到这问题,干脆我也写到blog中来. 我的情况是别人怎么连也连不上我本地的DB,我装了2005的sp2也不行,后来发现关了防火墙就可以了,但我总不能什…

LintCode MySQL 1921. 从不充值的玩家(where not in)

文章目录1. 题目2. 解题1. 题目 描述 A game database contains two tables, player table and recharge table. Write a SQL query to find all players who never recharge. 样例 https://www.lintcode.com/problem/players-who-never-recharge/description 2. 解题 -- …

古风一棵桃花树简笔画_广东有个现实版的“桃花源”,藏于秘境之中,最适合情侣来度假!...

上学时,初闻“芳草鲜美,落英缤纷”,并没有多大感触。直到后来长大离家,每每为生活奔波劳累时,为工作琐碎忧心费神时,才骤然明了当年五柳先生所描绘的“桃花源”该是多少人的脑中所想、心中所向……原以为这…

关于NIOS ii烧写的几种方式(转)

源:http://www.cnblogs.com/bingoo/p/3450850.html 1. 方法一:.sof和.elf全部保存在FPGA内,程序加载和运行也是在FPGA内部。 把FPGA的配置文件.sof通过JTAG方式下载(其实是在线运行)进入FPGA本身,此时在NIOS II的界面中&#xff…

clob和blob是不是可以进行模糊查询_你知道什么是 MySQL 的模糊查询?

作者 | luanhz责编 | 郭芮本文对MySQL中几种常用的模糊搜索方式进行了介绍,包括LIKE通配符、RegExp正则匹配、内置字符串函数以及全文索引,最后给出了性能对比。引言MySQL根据不同的应用场景,支持的模糊搜索方式有多种,例如应用最…

一个长文档里,包括封面、不同的章节,如果我想封面不设置页眉页脚,每个章节的页眉都不同,请问应该如何设置页眉页脚?

问:在一个长文档里,包括封面、不同的章节,如果我想封面不设置页眉页脚,每个章节的页眉都不同,请问应该如何设置页眉页脚? 答:如果只需要首页不同,可选择“文件”菜单下的“页面设置”…

LintCode 1917. 切割剩余金属

文章目录1. 题目2. 解题1. 题目 描述 金属棒工厂的厂长拥有 n 根多余的金属棒。 当地的一个承包商提出,只要所有的棒材具有相同的长度(用 saleLength 表示棒材的长度),就将金属棒工厂的剩余棒材全部购买。 厂长可以通过将每根棒…

太原理工电子信焦工程_电气工程及其自动化专业毕业后做什么工作?近几年就业和收入怎样...

本文内容为各大高校往届大学生真实的现身说法内容,但因为是往届,每年该专业的大学情况可能会发生略微变化,所以部分内容较今年,明年甚至以后几年,实际情况可能会略有不同但是对于本专业的相关信息还是非常有参考价值的…

js定时器和linux命令locate

js定时器如果带有参数,应该采用如下方式 setTimeout(function(){function(param)},1000); 匿名函数的方法。 linux locate基于数据库的查找方法。转载于:https://www.cnblogs.com/birdskyws/p/3974556.html

编程竞赛控制系统(PC2)使用说明书

编程竞赛控制系统(PC2)使用说明书 1. 系统简介 PC2是由美国加利福尼亚大学为国际大学生编程竞赛开发研制的竞赛控制系统。目前主要用于ACM/ICPC等国际编程竞赛。PC2最新的版本是8.5d,系统采用JAVA语言编写,可以运行在任何支持JAVA的平台(windows…

怎么查看linux日志里请求量最高的url访问最多的_实用的Linux高级命令,开发运维都要懂!...

在运维的坑里摸爬滚打好几年了,我还记得我刚开始的时候,我只会使用一些简单的命令,写脚本的时候,也是要多简单有多简单,所以有时候写出来的脚本又长又臭。像一些高级点的命令,比如说 Xargs 命令、管道命令、…

ggplot2箱式图两两比较_第十九章_使用ggplot2进行高级绘图

介绍ggplot2包使用形状、颜色和尺寸来对多元数据进行可视化用刻面图比较各组自定义ggplot2图19.1 R中的四种图形系统基础gridlatticeggplot2(用的较多)gghub需要的R包ggpolt2gridExtra(可以拼图)car19.2 ggplot2介绍library(ggplot2)ggplot(datamtcars, aes(xwt, ympg)) geom_p…

centos7 編譯 chmsee

安装libchm及相关的devel包,安装 xulrunner 及 devel 包!否则后面make的时候会出错! 到解压出来的chmsee/src目录下,找到与你系统对应的Makefile文件,我选的是Makefile.fedora,cp Makefile.fedora Makefile…

python调用cmd命令释放端口_详解python调用cmd命令三种方法

目前我使用到的python中执行cmd的方式有三种 使用os.system("cmd") 该方法在调用完shell脚本后,返回一个16位的二进制数,低位为杀死所调用脚本的信号号码,高位为脚本的退出状态码,即脚本中“exit 1”的代码执行后,os.system函数返回值的高位数则是1,如果低位数是0的情…

LeetCode 1742. 盒子中小球的最大数量

文章目录1. 题目2. 解题1. 题目 你在一家生产小球的玩具厂工作,有 n 个小球,编号从 lowLimit 开始,到 highLimit 结束(包括 lowLimit 和 highLimit ,即 n highLimit - lowLimit 1)。 另有无限数量的盒子…

bash shell命令(1)

本文地址:http://www.cnblogs.com/archimedes/p/bash-shell1.html,转载请注明源地址。 ls命令 ls用来列出目录的内容,它是用户最常用的命令之一,ls命令的格式为: ls[选项][目录名或文件名] 选项的主要参数:…