【神经网络扩展】:断点续训和参数提取

课程来源:人工智能实践:Tensorflow笔记2

文章目录

  • 前言
  • 断点续训主要步骤
  • 参数提取主要步骤
  • 总结


前言

本讲目标:断点续训,存取最优模型;保存可训练参数至文本


断点续训主要步骤

读取模型:

先定义出存放模型的路径和文件名,命名为.ckpt文件。
生成ckpt文件的时候会同步生成索引表,所以通过判断是否存在索引表来知晓是不是已经保存过模型参数。
如果有了索引表就利用load_weights函数读取已经保存的模型参数。

code:


checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)

在这里插入图片描述

保存模型:

保存模型参数可以使用TensorFlow给出的回调函数,直接保存训练出来的模型参数
tf.keras.callbacks.ModelCheckpoint( filepath=路径文件名(文件存储路径),
save_weights_only=True/False,(是否只保留参数模型)
save_best_only=True/False(是否只保留最优结果)) 执行训练过程中时,加入callbacks选项:
history=model.fit(callbacks=[cp_callback])

code:

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])

第一次运行:
在这里插入图片描述
第二次运行:可以发现模型并不是从初始训练,而是在基于保存的模型开始训练的(这一点可以从准确率和损失看出):
在这里插入图片描述
全部代码:

import tensorflow as tf
import osfashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])
model.summary()

参数提取主要步骤

设置打印的格式,使所有参数都打印出来

np.set_printoptions(threshold=np.inf)
print(model.trainable_variables)

将所有可训练参数存入文本:

file = open('./weights.txt', 'w')
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()

完整代码:

import tensorflow as tf
import os
import numpy as npnp.set_printoptions(threshold=np.inf)fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])
model.summary()print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()

效果:
在这里插入图片描述

总结

课程链接:MOOC人工智能实践:TensorFlow笔记2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/378275.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开发DBA(APPLICATION DBA)的重要性

开发DBA是干什么的? 1. 审核开发人员写的SQL,并且纠正存在性能问题的SQL ---非常重要 2. 编写复杂业务逻辑SQL,因为复杂业务逻辑SQL开发人员写出的SQL基本上都是有性能问题的,与其让开发人员写,不如DBA自己写。---非常…

javascript和var之间的区别?

You can define your variables in JavaScript using two keywords - the let keyword and the var keyword. The var keyword is the oldest way of defining and declaring variables in JavaScript whereas the let is fairly new and was introduced by ES15. 您可以使用两…

小米手环6NFC安装太空人表盘

以前看我室友峰哥、班长都有手环,一直想买个手环,不舍得,然后今年除夕的时候降价,一狠心,入手了,配上除夕的打年兽活动还有看春晚京东敲鼓领的红包和这几年攒下来的京东豆豆,原价279的小米手环6…

计算机二级c语言题库缩印,计算机二级C语言上机题库(可缩印做考试小抄资料)...

小抄,答案,形成性考核册,形成性考核册答案,参考答案,小抄资料,考试资料,考试笔记第一套1.程序填空程序通过定义学生结构体数组,存储了若干个学生的学号、姓名和三门课的成绩。函数fun 的功能是将存放学生数据的结构体数组,按照姓名的字典序(从小到大排序…

为什么两层3*3卷积核效果比1层5*5卷积核效果要好?

目录1、感受野2、2层3 * 3卷积与1层5 * 5卷积3、2层3 * 3卷积与1层5 * 5卷积的计算量比较4、2层3 * 3卷积与1层5 * 5卷积的非线性比较5、2层3 * 3卷积与1层5 * 5卷积的参数量比较1、感受野 感受野:卷积神经网络各输出特征像素点,在原始图片映射区域大小。…

算法正确性和复杂度分析

算法正确性——循环不变式 算法复杂度的计算 方法一 代换法 —局部代换 这里直接对n变量进行代换 —替换成对数或者指数的情形 n 2^m —整体代换 这里直接对递推项进行代换 —替换成内部递推下标的形式 T(2^n) S(n) 方法二 递归树法 —用实例说明 —分析每一层的内容 —除了…

第十五章 Python和Web

第十五章 Python和Web 本章讨论Python Web编程的一些方面。 三个重要的主题:屏幕抓取、CGI和mod_python。 屏幕抓取 屏幕抓取是通过程序下载网页并从中提取信息的过程。 下载数据并对其进行分析。 从Python Job Board(http://python.org/jobs&#x…

array_chunk_PHP array_chunk()函数与示例

array_chunkPHP array_chunk()函数 (PHP array_chunk() Function) array_chunk() function is an array function, it is used to split a given array in number of array (chunks of arrays). array_chunk()函数是一个数组函数,用于将给定数组拆分为多个数组(数组…

raise

raise - Change a windows position in the stacking order button .b -text "Hi there!"pack [frame .f -background blue]pack [label .f.l1 -text "This is above"]pack .b -in .fpack [label .f.l2 -text "This is below"]raise .b转载于:ht…

c语言输出最大素数,for语句计算输出10000以内最大素数怎么搞最简单??各位大神们...

该楼层疑似违规已被系统折叠 隐藏此楼查看此楼#include #include int* pt NULL; // primes_tableint pt_size 0; // primes_table 数量大小int init_primes_table(void){FILE* pFile;pFile fopen("primes_table.bin", "rb");if (pFile NULL) {fputs(&q…

【数据结构基础笔记】【图】

代码参考《妙趣横生的算法.C语言实现》 文章目录前言1、图的概念2、图的存储形式1、邻接矩阵:2、邻接表3、代码定义邻接表3、图的创建4、深度优先搜索DFS5、广度优先搜索BFS6、实例分析前言 本章总结:图的概念、图的存储形式、邻接表定义、图的创建、图…

第十六章 测试基础

第十六章 测试基础 在编译型语言中,需要不断重复编辑、编译、运行的循环。 在Python中,不存在编译阶段,只有编辑和运行阶段。测试就是运行程序。 先测试再编码 极限编程先锋引入了“测试一点点,再编写一点点代码”的理念。 换而…

如何蹭网

引言蹭网,在普通人的眼里,是一种很高深的技术活,总觉得肯定很难,肯定很难搞。还没开始学,就已经败给了自己的心里,其实,蹭网太过于简单。我可以毫不夸张的说,只要你会windows的基本操…

android对象缓存,Android简单实现 缓存数据

前言1、每一种要缓存的数据都是有对应的versionCode,通过versionCode请求网络获取是否需要更新2、提前将要缓存的数据放入assets文件夹中,打包上线。缓存设计代码实现/*** Created by huangbo on 2017/6/19.** 主要是缓存的工具类** 缓存设计&#xff1a…

通信原理.绪论

今天刚上通信原理的第一节课,没有涉及过多的讲解,只是讲了下大概的知识框架。现记录如下: 目录1、基本概念消息、信息与信号2、通信系统模型1、信息源2、发送设备3、信道4、接收设备5、信宿6、模拟通信系统模型7、数字通信系统模型8、信源编…

Android版本演进中的兼容性问题

原文:http://android.eoe.cn/topic/summary Android 3.0 的主要变化包括: 不再使用硬件按键进行导航 (返回、菜单、搜索和主屏幕),而是采用虚拟按键 (返回,主屏幕和最近的应用)。在操作栏固定菜单。 Android 4.0 则把这些变化带到了手机平台。…

css rgba透明_rgba()函数以及CSS中的示例

css rgba透明Introduction: 介绍: Functions are used regularly while we are developing a web page or website. Therefore, to be a good developer you need to master as many functions as you can. This way your coding knowledge will increase as well …

第十七章 扩展Python

第十七章 Python什么都能做,真的是这样。这门语言功能强大,但有时候速度有点慢。 鱼和熊掌兼得 本章讨论确实需要进一步提升速度的情形。在这种情况下,最佳的解决方案可能不是完全转向C语言(或其他中低级语言)&…

android studio资源二进制,无法自动检测ADB二进制文件 – Android Studio

我尝试在Android Studio上测试我的应用程序,但我遇到了困难"waiting for AVD to come online..."我读过Android设备监视器重置adb会做到这一点,它确实……对于1次测试,当我第二天重新启动电脑时,我不仅仅是:"waiting for AVD to come online..."…

犀牛脚本:仿迅雷的增强批量下载

迅雷的批量下载满好用。但是有两点我不太中意。在这个脚本里会有所增强 1、不能设置保存的文件名。2、不能单独设置这批下载的线程限制。 使用方法 // 下载从编号001到编号020的图片,保存名为猫咪写真*.jpg 使用6个线程 jdlp http://bizhi.zhuoku.com/bizhi/200804/…