TensorFlow 2.0 - Checkpoint 保存变量、TensorBoard 训练可视化

文章目录

    • 1. Checkpoint 保存变量
    • 2. TensorBoard 训练过程可视化

学习于:简单粗暴 TensorFlow 2

1. Checkpoint 保存变量

  • tf.train.Checkpoint 可以保存 tf.keras.optimizertf.Variabletf.keras.Layertf.keras.Model
path = "./checkp.ckpt"
# 建立一个 checkpoint 
mycheckpoint = tf.train.Checkpoint(mybestmodel=mymodel) # 接受 **kwargs 键值对
mycheckpoint.save(path)

  • 恢复指定模型变量
# 待恢复参数的模型
restored_model = LinearModel()
# mybestmodel 名字任意写,跟下面恢复时保持一致
mycheckpoint = tf.train.Checkpoint(mybestmodel=restored_model)
# 恢复指定的变量
path = "./checkp.ckpt-1"
mycheckpoint.restore(path)X_test = tf.constant([[5.1], [6.1]])
res = restored_model.predict(X_test)
print(res)
# [[10.182168] 前一节的线性回归模型
#  [12.176777]]
  • 恢复最近的模型,自动选定目录下最新的存档(后缀数字最大的)
mycheckpoint.restore(tf.train.latest_checkpoint("./"))
  • 管理保存的参数,有时不需要保存太多,占空间
mycheckpoint = tf.train.Checkpoint(mybestmodel=mymodel)  # 接受 **kwargs 键值对
manager = tf.train.CheckpointManager(mycheckpoint, directory="./",checkpoint_name='checkp.ckpt',max_to_keep=2) # 最多保存k个最新的for loop:manager.save() # 自动递增编号manager.save(checkpoint_number=idx) # 指定编号

2. TensorBoard 训练过程可视化

  • summary_writer = tf.summary.create_file_writer(logdir=log_dir)
  • tf.summary.scalar(name='loss', data=loss, step=idx)
  • tf.summary.trace_on(profiler=True)
for loop:with summary_writer.as_default():tf.summary.scalar(name='loss', data=loss, step=idx)
with summary_writer.as_default():tf.summary.trace_export(name='model_trace', step=0,profiler_outdir=log_dir)
  • 示例
import tensorflow as tf
import numpy as npclass MNistLoader():def __init__(self):data = tf.keras.datasets.mnist# 加载数据(self.train_data, self.train_label), (self.test_data, self.test_label) = data.load_data()# 扩展维度,灰度图1通道 [batch_size, 28, 28, chanels=1]self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)self.train_label = self.train_label.astype(np.int32)self.test_label = self.test_label.astype(np.int32)# 样本个数self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]def get_batch(self, batch_size):# 从训练集里随机取出 batch_size 个样本idx = np.random.randint(0, self.num_train_data, batch_size)return self.train_data[idx, :], self.train_label[idx]# 自定义多层感知机模型
class MLPmodel(tf.keras.Model):def __init__(self):super().__init__()# 除第一维以外的维度展平self.flatten = tf.keras.layers.Flatten()self.dense1 = tf.keras.layers.Dense(units=100, activation='relu')self.dense2 = tf.keras.layers.Dense(units=10)def call(self, input):x = self.flatten(input)x = self.dense1(x)x = self.dense2(x)output = tf.nn.softmax(x)return output# %%num_epochs = 5
batch_size = 50
learning_rate = 1e-4
log_dir = './log' # 日志目录
mymodel = MLPmodel()# %%
data_loader = MNistLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
num_batches = int(data_loader.num_train_data // batch_size * num_epochs)# 实例化记录器
summary_writer = tf.summary.create_file_writer(logdir=log_dir)
# 开启 trace,(可选),记录训练时的大量信息(图的结构,耗时等)
tf.summary.trace_on(profiler=True)for idx in range(num_batches):X, y = data_loader.get_batch(batch_size)with tf.GradientTape() as tape:y_pred = mymodel(X)loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch {}, loss {}".format(idx, loss.numpy()))# 记录器记录losswith summary_writer.as_default():tf.summary.scalar(name='loss', data=loss, step=idx)grads = tape.gradient(loss, mymodel.variables)optimizer.apply_gradients(grads_and_vars=zip(grads, mymodel.variables))with summary_writer.as_default():tf.summary.trace_export(name='model_trace', step=0,profiler_outdir=log_dir)
  • 开始训练,命令行进入 可视化界面 tensorboard --logdir=./log
  • 点击命令行中的链接,打开浏览器,查看训练曲线
  • 若重新训练,请删除 log 文件,或设置别的 log 路径,重新 cmd 开启 浏览器

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473041.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

coturn的负载均衡特性_高性能负载均衡

单服务器无论如何优化,无论采用多好的硬件,总会有一个性能天花板,当单服务器的性能无法满足业务需求时,就需要设计高性能集群来提升系统整体的处理性能。高性能集群的本质很简单,通过增加更多的服务器来提升系统整体的…

LintCode MySQL 1928. 网课上课情况分析 I

文章目录1. 题目2. 解题1. 题目 online_class_situation 表展示了一些同学上网课的行为活动。 每行数据记录了一名同学在退出网课之前,当天使用同一台设备登录课程后听过的课程数目(可能是0个)。 写一条 SQL 语句,查询每位同学第…

python输入十个数输出最大值_python输入十个数如何输出最大值

python输入十个数输出最大值的方法:1、如果是整数的话,使用函数【a, b, c map(int, input().split())】;2、使用函数【Xinput().split()】。 相关免费学习推荐:python视频教程 python输入十个数输出最大值的方法: 第一…

LintCode MySQL 1921. 从不充值的玩家(where not in)

文章目录1. 题目2. 解题1. 题目 描述 A game database contains two tables, player table and recharge table. Write a SQL query to find all players who never recharge. 样例 https://www.lintcode.com/problem/players-who-never-recharge/description 2. 解题 -- …

古风一棵桃花树简笔画_广东有个现实版的“桃花源”,藏于秘境之中,最适合情侣来度假!...

上学时,初闻“芳草鲜美,落英缤纷”,并没有多大感触。直到后来长大离家,每每为生活奔波劳累时,为工作琐碎忧心费神时,才骤然明了当年五柳先生所描绘的“桃花源”该是多少人的脑中所想、心中所向……原以为这…

clob和blob是不是可以进行模糊查询_你知道什么是 MySQL 的模糊查询?

作者 | luanhz责编 | 郭芮本文对MySQL中几种常用的模糊搜索方式进行了介绍,包括LIKE通配符、RegExp正则匹配、内置字符串函数以及全文索引,最后给出了性能对比。引言MySQL根据不同的应用场景,支持的模糊搜索方式有多种,例如应用最…

LintCode 1917. 切割剩余金属

文章目录1. 题目2. 解题1. 题目 描述 金属棒工厂的厂长拥有 n 根多余的金属棒。 当地的一个承包商提出,只要所有的棒材具有相同的长度(用 saleLength 表示棒材的长度),就将金属棒工厂的剩余棒材全部购买。 厂长可以通过将每根棒…

太原理工电子信焦工程_电气工程及其自动化专业毕业后做什么工作?近几年就业和收入怎样...

本文内容为各大高校往届大学生真实的现身说法内容,但因为是往届,每年该专业的大学情况可能会发生略微变化,所以部分内容较今年,明年甚至以后几年,实际情况可能会略有不同但是对于本专业的相关信息还是非常有参考价值的…

怎么查看linux日志里请求量最高的url访问最多的_实用的Linux高级命令,开发运维都要懂!...

在运维的坑里摸爬滚打好几年了,我还记得我刚开始的时候,我只会使用一些简单的命令,写脚本的时候,也是要多简单有多简单,所以有时候写出来的脚本又长又臭。像一些高级点的命令,比如说 Xargs 命令、管道命令、…

ggplot2箱式图两两比较_第十九章_使用ggplot2进行高级绘图

介绍ggplot2包使用形状、颜色和尺寸来对多元数据进行可视化用刻面图比较各组自定义ggplot2图19.1 R中的四种图形系统基础gridlatticeggplot2(用的较多)gghub需要的R包ggpolt2gridExtra(可以拼图)car19.2 ggplot2介绍library(ggplot2)ggplot(datamtcars, aes(xwt, ympg)) geom_p…

LeetCode 1742. 盒子中小球的最大数量

文章目录1. 题目2. 解题1. 题目 你在一家生产小球的玩具厂工作,有 n 个小球,编号从 lowLimit 开始,到 highLimit 结束(包括 lowLimit 和 highLimit ,即 n highLimit - lowLimit 1)。 另有无限数量的盒子…

bash shell命令(1)

本文地址:http://www.cnblogs.com/archimedes/p/bash-shell1.html,转载请注明源地址。 ls命令 ls用来列出目录的内容,它是用户最常用的命令之一,ls命令的格式为: ls[选项][目录名或文件名] 选项的主要参数:…

LeetCode 1743. 从相邻元素对还原数组(拓扑排序)

文章目录1. 题目2. 解题1. 题目 存在一个由 n 个不同元素组成的整数数组 nums ,但你已经记不清具体内容。 好在你还记得 nums 中的每一对相邻元素。 给你一个二维整数数组 adjacentPairs ,大小为 n - 1 ,其中每个 adjacentPairs[i] [ui, v…

BP神经网络算法学习

BP(Back Propagation)网络是1986年由Rumelhart和McCelland为首的科学家小组提出,是一种按误差逆传播算法训练的多层前馈网络,是眼下应用最广泛的神经网络模型之中的一个。BP网络能学习和存贮大量的输入-输出模式映射关系&#xff…

无向图的深度优先遍历非递归_【数据结构图(一)】什么是图

一、什么是“图”(Graph) 表示“多对多”的关系包含一组顶点:通常用 V (Vertex) 表示顶点集合一组边:通常用 E (Edge) 表示边的集合无向边:(v, w) 有向边:不考虑重边和自回路二、抽象数据类型定义类型名称:图(Graph)数…

LeetCode 1744. 你能在你最喜欢的那天吃到你最喜欢的糖果吗?(前缀和)

文章目录1. 题目2. 解题1. 题目 给你一个下标从 0 开始的正整数数组 candiesCount ,其中 candiesCount[i] 表示你拥有的第 i 类糖果的数目。 同时给你一个二维数组 queries ,其中 queries[i] [favoriteTypei, favoriteDayi, dailyCapi] 。 你按照如下…

wdcp-apache开启KeepAlive提高响应速度

因为我们的网站,媒体文件,js文件,css文件等都在同一个服务器上,并且,我们网站有非常多的图片,所以当建立好tcp链接之后,不应该马上关闭连接,因为每建立一次连接还要进行dns解析&…

如何将网页保存为图片_网页账号密码该如何保存?

我们在使用浏览器浏览一些网页的时候,需要输入我们的账号密码才能登陆,以保证安全。但是有时候浏览网页,不小心关掉了,重新打开时又要重新输入密码,这样会显得很繁琐。那么有什么办法能让网页记住我们的账号密码吗&…

scala学习-类与对象

类  /  对象 【《快学Scala》笔记】 一、类 1、Scala中的类是公有可见性的,且多个类可以包含在同一个源文件中; 1 class Counter{ 2 private var value 0  //类成员变量必须初始化,否则报错 3 4 def increment(){ //类中的…

LeetCode 1745. 回文串分割 IV(区间DP)

文章目录1. 题目2. 解题1. 题目 给你一个字符串 s ,如果可以将它分割成三个 非空 回文子字符串,那么返回 true ,否则返回 false 。 当一个字符串正着读和反着读是一模一样的,就称其为 回文字符串 。 示例 1: 输入&a…