keras中的回调函数


keras训练


fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[], validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None
)
1. x:输入数据。如果模型只有一个输入,那么x的类型是numpy array,如果模型有多个输入,那么x的类型应当为listlist的元素是对应于各个输入的numpy array。如果模型的每个输入都有名字,则可以传入一个字典,将输入名与其输入数据对应起来。
2. y:标签,numpy array。如果模型有多个输出,可以传入一个numpy arraylist。如果模型的输出拥有名字,则可以传入一个字典,将输出名与其标签对应起来。
3. batch_size:整数,指定进行梯度下降时每个batch包含的样本数。训练时一个batch的样本会被计算一次梯度下降,使目标函数优化一步。
4. nb_epoch:整数,训练的轮数,训练数据将会被遍历nb_epoch次。Keras中nb开头的变量均为"number of"的意思
5. verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
6. callbacks:list,其中的元素是keras.callbacks.Callback的对象。这个list中的回调函数将会在训练过程中的适当时机被调用,参考回调函数
7. validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。
8. validation_data:形式为(X,y)或(X,y,sample_weights)的tuple,是指定的验证集。此参数将覆盖validation_spilt。
9. shuffle:布尔值,表示是否在训练过程中每个epoch前随机打乱输入样本的顺序。
10. class_weight:字典,将不同的类别映射为不同的权值,该参数用来在训练过程中调整损失函数(只能用于训练)。该参数在处理非平衡的训练数据(某些类的训练样本数很少)时,可以使得损失函数对样本数不足的数据更加关注。
11. sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行11的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode='temporal'

fit函数返回一个History的对象,其History.history属性记录了损失函数和其他指标的数值随epoch变化的情况,如果有验证集的话,也包含了验证集的这些指标变化情况。


保存模型结构、训练出来的权重、及优化器状态


keras 的 callback参数可以帮助我们实现在训练过程中的适当时机被调用。实现实时保存训练模型以及训练参数。

keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1
)1. filename:字符串,保存模型的路径
2. monitor:需要监视的值
3. verbose:信息展示模式,01
4. save_best_only:当设置为True时,将只保存在验证集上性能最好的模型
5. mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
6. save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)
7. period:CheckPoint之间的间隔的epoch数

当验证损失不再继续降低时,如何中断训练?当监测值不再改善时中止训练


用EarlyStopping回调函数

from keras.callbacksimport EarlyStopping keras.callbacks.EarlyStopping(monitor='val_loss', patience=0, verbose=0, mode='auto'
)model.fit(X, y, validation_split=0.2, callbacks=[early_stopping])1. monitor:需要监视的量
2. patience:当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。
3. verbose:信息展示模式
4. mode:‘auto’,‘min’,‘max’之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。

学习率动态调整1


keras.callbacks.LearningRateScheduler(schedule)
schedule:函数,该函数以epoch号为参数(从0算起的整数),返回一个新学习率(浮点数)
也可以让keras自动调整学习率

keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', 
    factor=0.1, 
    patience=10, 
    verbose=0, 
    mode='auto', 
    epsilon=0.0001, 
    cooldown=0, 
    min_lr=0
)1. monitor:被监测的量
2. factor:每次减少学习率的因子,学习率将以lr = lr*factor的形式被减少
3. patience:当patience个epoch过去而模型性能不提升时,学习率减少的动作会被触发
4. mode:‘auto’,‘min’,‘max’之一,在min模式下,如果检测值触发学习率减少。在max模式下,当检测值不再上升则触发学习率减少。
5. epsilon:阈值,用来确定是否进入检测值的“平原区”
6. cooldown:学习率减少后,会经过cooldown个epoch才重新进行正常操作
7. min_lr:学习率的下限

当学习停滞时,减少2倍或10倍的学习率常常能获得较好的效果


学习率动态2


def step_decay(epoch):initial_lrate = 0.01drop = 0.5epochs_drop = 10.0lrate = initial_lrate * math.pow(drop,math.floor((1+epoch)/epochs_drop))return lrate
lrate = LearningRateScheduler(step_decay)
sgd = SGD(lr=0.0, momentum=0.9, decay=0.0, nesterov=False)
model.fit(train_set_x, train_set_y, validation_split=0.1, nb_epoch=200, batch_size=256, callbacks=[lrate])

具体可以参考这篇文章Using Learning Rate Schedules for Deep Learning Models in Python with Keras


如何记录每一次epoch的训练/验证损失/准确度?


Model.fit函数会返回一个 History 回调,该回调有一个属性history包含一个封装有连续损失/准确的lists。代码如下:

hist = model.fit(X, y,validation_split=0.2)  
print(hist.history)

Keras输出的loss,val这些值如何保存到文本中去
Keras中的fit函数会返回一个History对象,它的History.history属性会把之前的那些值全保存在里面,如果有验证集的话,也包含了验证集的这些指标变化情况,具体写法

hist=model.fit(train_set_x,train_set_y,batch_size=256,shuffle=True,nb_epoch=nb_epoch,validation_split=0.1)
with open('log_sgd_big_32.txt','w') as f:f.write(str(hist.history))

示例,多个回调函数用逗号隔开


# checkpoint
checkpointer = ModelCheckpoint(filepath="./checkpoint.hdf5", verbose=1)
# learning rate adjust dynamic
lrate = ReduceLROnPlateau(min_lr=0.00001)answer.compile(optimizer='rmsprop', loss='categorical_crossentropy',metrics=['accuracy'])
# Note: you could use a Graph model to avoid repeat the input twice
answer.fit([inputs_train, queries_train, inputs_train], answers_train,batch_size=32,nb_epoch=5000,validation_data=([inputs_test, queries_test, inputs_test], answers_test),callbacks=[checkpointer, lrate]
)

keras回调函数中的Tensorboard


keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0,  write_graph=True, write_images=True)
tbCallBack = keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True)
...
model.fit(...inputs and parameters..., callbacks=[tbCallBack])
tensorboard --logdir path_to_current_dir/Graph 

或者

from keras.callbacks import TensorBoardtensorboard = TensorBoard(log_dir='./logs', histogram_freq=0,write_graph=True, write_images=False)
# define model
model.fit(X_train, Y_train,batch_size=batch_size,epochs=nb_epoch,validation_data=(X_test, Y_test),shuffle=True,callbacks=[tensorboard])

https://stackoverflow.com/questions/42112260/how-do-i-use-the-tensorboard-callback-of-keras


参考文献


http://blog.csdn.net/u010159842/article/details/54602217
用Keras搞一个阅读理解机器人
回调函数Callbacks
Keras中文文档
深度学习框架Keras使用心得

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/246779.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

运动合成——机器学习技术

参考文献:《人体运动合成中的机器学习技术合成综述》 根据机器学习的用途分类,在图形学中使用到的大致如下: 1> 回归和函数逼近。回归是一种插值技术,分析已知数据点来合成新的数据。 2> 降维。从高维数的运动数据…

keras如何保存模型

使用model.save(filepath)将Keras模型和权重保存在一个HDF5文件中,该文件将包含: 模型的结构,以便重构该模型 模型的权重 训练配置(损失函数,优化器等) 优化器的状态,以便于从上次训练中断的地…

ICA独立成分分析—FastICA基于负熵最大

1. 概念 官方解释:利用统计原理进行计算的方法,是一种线性变换。 ICA分为基于信息论准则的迭代算法和基于统计学的代数方法两大类,如FastICA算法,Infomax算法,最大似然估计算法等。 这里主要讨论FastICA算法。 先来…

tensorboard的可视化及模型可视化

待整理 How to Check-Point Deep Learning Models in Keras LossWise Tensorboard 中文社区 谷歌发布TensorBoard API,让你自定义机器学习中的可视化 查找tensorflow安装的位置 pip show tensorflow-gpu Name: tensorflow-gpu Version: 1.0.1 Summary: TensorFl…

隐马尔科夫模型——简介

1. 前言 学习了概率有向图模型和概率无向图模型,回头再看了一下隐马尔可夫模型(hidden Markov model,HMM)。 HMM属于树状有向概率图模型,主要用于对时序数据的建模,它的潜在变量是离散的;而另一种状态空间模型&…

训练的神经网络不工作?一文带你跨过这37个坑

近日,Slav Ivanov 在 Medium 上发表了一篇题为《37 Reasons why your Neural Network is not working》的文章,从四个方面(数据集、数据归一化/增强、实现、训练),对自己长久以来的神经网络调试经验做了 37…

HMM——前向算法与后向算法

1. 前言 前向算法和后向算法主要还是针对HMM三大问题之一的评估问题的计算,即给定模型参数,计算观察序列的概率。文章不介绍过多公式,主要看两个例子 复习一下HMM的三大要素(以海藻(可观测)和天气&#x…

HMM——维特比算法(Viterbi algorithm)

1. 前言 维特比算法针对HMM第三个问题,即解码或者预测问题,寻找最可能的隐藏状态序列: 对于一个特殊的隐马尔可夫模型(HMM)及一个相应的观察序列,找到生成此序列最可能的隐藏状态序列。 也就是说给定了HMM的模型参数和一个观测…

tensorflow安装正确, import tf, the problem is Couldn't find field google.protob.ExtensionRange.options

问题描述 import tensorflow error with correct installation, the problem is “Couldn’t find field google.protobuf.DescriptorProto.ExtensionRange.options” python import tensorflow 出现如下问题: I tensorflow/stream_executor/dso_loader.cc:135]…

HMM——前向后向算法

1. 前言 解决HMM的第二个问题:学习问题, 已知观测序列,需要估计模型参数,使得在该模型下观测序列 P(观测序列 | 模型参数)最大,用的是极大似然估计方法估计参数。 根据已知观测序列和对应的状态序列,或者说…

Web安全(吴翰清)

安全工程师的核心竞争力不在于他能拥有多少个 0day,掌握多少种安全技术,而是在于他对安全理解的深度,以及由此引申的看待安全问题的角度和高度。 第一篇 我的安全世界观 脚本小子 “Script Kids”。 黑客精神所代表的 Open、Free、Share。…

机器学习两种方法——监督学习和无监督学习(通俗理解)

前言 机器学习分为:监督学习,无监督学习,半监督学习(也可以用hinton所说的强化学习)等。 在这里,主要理解一下监督学习和无监督学习。 监督学习(supervised learning) 从给定的训…

Tensorflow中padding的两种类型SAME和VALID

边界补充问题 原始图片尺寸为7*7,卷积核的大小为3*3,当卷积核沿着图片滑动后只能滑动出一个5*5的图片出来,这就造成了卷积后的图片和卷积前的图片尺寸不一致,这显然不是我们想要的结果,所以为了避免这种情况&#xff…

机器学习两种距离——欧式距离和马氏距离

我们熟悉的欧氏距离虽然很有用,但也有明显的缺点。它将样品的不同属性(即各指标或各变量)之间的差别等同看待,这一点有时不能满足实际要求。例如,在教育研究中,经常遇到对人的分析和判别,个体的…

HMM前向算法,维比特算法,后向算法,前向后向算法代码

typedef struct { int N; /* 隐藏状态数目;Q{1,2,…,N} */ int M; /* 观察符号数目; V{1,2,…,M}*/ double **A; /* 状态转移矩阵A[1..N][1..N]. a[i][j] 是从t时刻状态i到t1时刻状态j的转移概率 */ double **B; /* 混淆矩阵B[1..N][1..M]. b[j][k]在状态j时观察到符合k的概率。…

Python的GUI框架PySide

PySide学习笔记 PySide安装 Python自带了GUI模块Tkinter,只是界面风格有些老旧。 Python的Qt有PyQt和PySide吧。PyQt 是商业及 GPL 的版权, 而 PySide 是 LGPL。大意也就是PyQt开发商业软件是要购买授权的,而PySide则不需要。二者代码基本一…

Python中的除法保留两位小数

在C/C语言对于整形数执行除法会进行地板除(舍去小数部分)。例如 int a15/10; a的结果为1。 同样的在Java中也是如此,所以两个int型的数据相除需要返回一个浮点型数据的时候就需要强制类型转换,例如 float a (float)b/c ,其中b、…

最小二乘法深入

上次写了一个一次函数yaxb类型的最小二乘法,即可以看做是n维输入列向量对应的一个n维输出列向量,然后对已知结果进行学习,得到拟合公式。这里对m*n的矩阵进行最小二乘法分析。 设模型的输出为和训练集输出,它们之间的平方误差为&…

github访问太慢解决方案

相关文章 版本管理 github访问太慢解决方案 Material for git workshop github 域名列表 先建立一个域名列表haha.txt,下面列表中的gist.github.com是代码片功能 github.com assets-cdn.github.com avatars0.githubusercontent.com avatars1.githubusercontent.…

ubuntu16.04 制作gif

byzanz安装 sudo apt-get install byzanz byzanz-record #录像byzanz-playback #回放 下载完成后打开命令行输入byzanz-record –help 其中我们重点关注几个参数 * -d 动画录制的时间,默认录制10秒 * -e 动画开始延迟 * -x 录制区域的起始X坐标 * -y 录制区域的起始Y坐标 …