TensorFlow 模型中的回调函数与损失函数

回调函数

tf.keras 的回调函数实际上是一个类,一般是在 model.fit 时作为参数指定,用于控制在训练过程开始或者在训练过程结束,在每个 epoch 训练开始或者训练结束,在每个 batch 训练开始或者训练结束时执行一些操作,例如收集一些日志信息,改变学习率等超参数,提前终止训练过程等等。同样地,针对 model.evaluate 或者 model.predict 也可以指定 callbacks 参数,用于控制在评估或预测开始或者结束时,在每个 batch 开始或者结束时执行一些操作,但这种用法相对少见。

大部分时候,keras.callbacks 子模块中定义的回调函数类已经足够使用了,如果有特定的需要,我们也可以通过对 keras.callbacks.Callbacks 实施子类化构造自定义的回调函数。所有回调函数都继承至 keras.callbacks.Callbacks 基类,拥有 params 和 model 这两个属性。

其中 params 是一个 dict,记录了训练相关参数 (例如 verbosity, batch size, number of epochs 等等)。model 即当前关联的模型的引用。此外,对于回调类中的一些方法如 on_epoch_begin,on_batch_end,还会有一个输入参数 logs, 提供有关当前 epoch 或者 batch 的一些信息,并能够记录计算结果,如果 model.fit 指定了多个回调函数类,这些 logs 变量将在这些回调函数类的同名函数间依顺序传递。

内置回调函数

  • BaseLogger:收集每个 epoch 上 metrics 在各个 batch 上的平均值,对 stateful_metrics 参数中的带中间状态的指标直接拿最终值无需对各个 batch 平均,指标均值结果将添加到 logs 变量中。该回调函数被所有模型默认添加,且是第一个被添加的。
  • History:将 BaseLogger 计算的各个 epoch 的 metrics 结果记录到 history 这个 dict 变量中,并作为 model.fit 的返回值。该回调函数被所有模型默认添加,在 BaseLogger 之后被添加。
  • EarlyStopping:当被监控指标在设定的若干个 epoch 后没有提升,则提前终止训练。
  • TensorBoard:为 Tensorboard 可视化保存日志信息。支持评估指标,计算图,模型参数等的可视化。
  • ModelCheckpoint:在每个 epoch 后保存模型。
  • ReduceLROnPlateau:如果监控指标在设定的若干个 epoch 后没有提升,则以一定的因子减少学习率。
  • TerminateOnNaN:如果遇到 loss 为 NaN,提前终止训练。
  • LearningRateScheduler:学习率控制器。给定学习率 lr 和 epoch 的函数关系,根据该函数关系在每个 epoch 前调整学习率。
  • CSVLogger:将每个 epoch 后的 logs 结果记录到 CSV 文件中。
  • ProgbarLogger:将每个 epoch 后的 logs 结果打印到标准输出流中。

自定义回调函数

可以使用 callbacks.LambdaCallback 编写较为简单的回调函数,也可以通过对 callbacks.Callback 子类化编写更加复杂的回调函数逻辑。

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers,models,losses,metrics,callbacks
import tensorflow.keras.backend as K# 示范使用LambdaCallback编写较为简单的回调函数
import json
json_log = open('./data/keras_log.json', mode='wt', buffering=1)
json_logging_callback = callbacks.LambdaCallback(on_epoch_end=lambda epoch, logs: json_log.write(json.dumps(dict(epoch = epoch,**logs)) + '\n'),on_train_end=lambda logs: json_log.close()
)# 示范通过Callback子类化编写回调函数(LearningRateScheduler的源代码)
class LearningRateScheduler(callbacks.Callback):def __init__(self, schedule, verbose=0):super(LearningRateScheduler, self).__init__()self.schedule = scheduleself.verbose = verbosedef on_epoch_begin(self, epoch, logs=None):if not hasattr(self.model.optimizer, 'lr'):raise ValueError('Optimizer must have a "lr" attribute.')try:lr = float(K.get_value(self.model.optimizer.lr))lr = self.schedule(epoch, lr)except TypeError:  # Support for old API for backward compatibilitylr = self.schedule(epoch)if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)):raise ValueError('The output of the "schedule" function ''should be float.')if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating:raise ValueError('The dtype of Tensor should be float')K.set_value(self.model.optimizer.lr, K.get_value(lr))if self.verbose > 0:print('\nEpoch %05d: LearningRateScheduler reducing learning ''rate to %s.' % (epoch + 1, lr))def on_epoch_end(self, epoch, logs=None):logs = logs or {}logs['lr'] = K.get_value(self.model.optimizer.lr)

损失函数

一般来说,监督学习的目标函数由损失函数和正则化项组成。(Objective = Loss + Regularization),对于 keras 模型,目标函数中的正则化项一般在各层中指定,例如使用 Dense 的 kernel_regularizer 和 bias_regularizer 等参数指定权重使用 l1 或者 l2 正则化项,此外还可以用 kernel_constraint 和 bias_constraint 等参数约束权重的取值范围,这也是一种正则化手段。

损失函数在模型编译时候指定。对于回归模型,通常使用的损失函数是均方损失函数 mean_squared_error。对于二分类模型,通常使用的是二元交叉熵损失函数 binary_crossentropy。

对于多分类模型,如果 label 是 one-hot 编码的,则使用类别交叉熵损失函数 categorical_crossentropy。如果 label 是类别序号编码的,则需要使用稀疏类别交叉熵损失函数 sparse_categorical_crossentropy。如果有需要,也可以自定义损失函数,自定义损失函数需要接收两个张量 y_true,y_pred 作为输入参数,并输出一个标量作为损失函数值。

损失函数和正则化项

tf.keras.backend.clear_session()model = models.Sequential()
model.add(layers.Dense(64, input_dim=64,kernel_regularizer=regularizers.l2(0.01),activity_regularizer=regularizers.l1(0.01),kernel_constraint = constraints.MaxNorm(max_value=2, axis=0)))
model.add(layers.Dense(10,kernel_regularizer=regularizers.l1_l2(0.01,0.01),activation = "sigmoid"))
model.compile(optimizer = "rmsprop",loss = "binary_crossentropy",metrics = ["AUC"])
model.summary()Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 64)                4160
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650
=================================================================
Total params: 4,810
Trainable params: 4,810
Non-trainable params: 0
_________________________________________________________________

内置损失函数

内置的损失函数一般有类的实现和函数的实现两种形式。如:CategoricalCrossentropy 和 categorical_crossentropy 都是类别交叉熵损失函数,前者是类的实现形式,后者是函数的实现形式。常用的一些内置损失函数说明如下。

  • mean_squared_error(均方误差损失,用于回归,简写为 mse, 类与函数实现形式分别为 MeanSquaredError 和 MSE)
  • mean_absolute_error (平均绝对值误差损失,用于回归,简写为 mae, 类与函数实现形式分别为 MeanAbsoluteError 和 MAE)
  • mean_absolute_percentage_error (平均百分比误差损失,用于回归,简写为 mape, 类与函数实现形式分别为 MeanAbsolutePercentageError 和 MAPE)
  • Huber(Huber 损失,只有类实现形式,用于回归,介于 mse 和 mae 之间,对异常值比较鲁棒,相对 mse 有一定的优势)
  • binary_crossentropy(二元交叉熵,用于二分类,类实现形式为 BinaryCrossentropy)
  • categorical_crossentropy(类别交叉熵,用于多分类,要求 label 为 onehot 编码,类实现形式为 CategoricalCrossentropy)
  • sparse_categorical_crossentropy(稀疏类别交叉熵,用于多分类,要求 label 为序号编码形式,类实现形式为 SparseCategoricalCrossentropy)
  • hinge(合页损失函数,用于二分类,最著名的应用是作为支持向量机 SVM 的损失函数,类实现形式为 Hinge)
  • kld(相对熵损失,也叫 KL 散度,常用于最大期望算法 EM 的损失函数,两个概率分布差异的一种信息度量。类与函数实现形式分别为 KLDivergence 或 KLD)
  • cosine_similarity(余弦相似度,可用于多分类,类实现形式为 CosineSimilarity)

自定义损失函数

自定义损失函数接收两个张量 y_true, y_pred 作为输入参数,并输出一个标量作为损失函数值。也可以对 tf.keras.loss.Loss 进行子类化,重写 call 方法实现损失的计算逻辑,从而得到损失函数的类的实现。

下面是一个 focal Loss 的自定义实现示范。focal Loss 是一种对 binary_crossentropy 的改进损失函数形式。它在样本不均衡和存在较多易分类的样本时相比 binary_crossentropy 具有明显的优势。

它有两个可调参数,alpha 参数和 gamma 参数。其中 alpha 参数主要用于衰减负样本的权重,gamma 参数主要用于衰减容易训练样本的权重。从而让模型更加聚焦在正样本和困难样本上。这就是为什么这个损失函数叫做 focal Loss。

def focal_loss(gamma=2., alpha=0.75):def focal_loss_fixed(y_true, y_pred):bce = tf.losses.binary_crossentropy(y_true, y_pred)p_t = (y_true * y_pred) + ((1 - y_true) * (1 - y_pred))alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)modulating_factor = tf.pow(1.0 - p_t, gamma)loss = tf.reduce_sum(alpha_factor * modulating_factor * bce,axis = -1 )return lossreturn focal_loss_fixedclass FocalLoss(tf.keras.losses.Loss):def __init__(self,gamma=2.0,alpha=0.75,name = "focal_loss"):self.gamma = gammaself.alpha = alphadef call(self,y_true,y_pred):bce = tf.losses.binary_crossentropy(y_true, y_pred)p_t = (y_true * y_pred) + ((1 - y_true) * (1 - y_pred))alpha_factor = y_true * self.alpha + (1 - y_true) * (1 - self.alpha)modulating_factor = tf.pow(1.0 - p_t, self.gamma)loss = tf.reduce_sum(alpha_factor * modulating_factor * bce,axis = -1 )return loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/577053.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GIT 不同仓库之间合并代码

合并两个不同仓库的代码通常需要以下步骤。这里以合并两个远程仓库为例&#xff1a; 添加远程仓库&#xff1a; 在本地仓库中&#xff0c;使用以下命令添加第二个远程仓库&#xff1a; git remote add <远程仓库名> <远程仓库URL>例如&#xff1a; git remote add …

国内厉害的游戏开发公司有哪些?

中懿游游戏软件开发,中国有许多厉害的游戏开发公司&#xff0c;其中一些在国际上也享有盛誉。以下是一些在中国游戏开发领域中备受关注的公司&#xff1a; 腾讯游戏&#xff08;Tencent Games&#xff09;&#xff1a; 作为中国最大的互联网公司之一&#xff0c;腾讯的游戏分支…

视觉学习(6) —— 接收事件规则列表

条件&#xff1a; 两个地址 绑定地址1&#xff0c;条件是值为1才执行流程 &#xff08;1&#xff09;字节起止位置为 0-0 向100写入值1&#xff0c;流程次数是否会增加 答案是不会&#xff0c;字节0是在哪里 所以当写入值1 而因为字节起止位置是0 0 &#xff0c;所以只读字…

前端项目重构的深度思考和复盘

摘要&#xff1a; 项目重构是每一家稳定发展的互联企业的必经之路, 就像一个产品的诞生, 会经历产品试错和产品迭代 一样, 随着业务或新技术的不断发展, 已有架构已无法满足更多业务扩展的需求, 所以只有通过重构来让产品“进化”, 才能跟上飞速发展的时代浪潮. 技术因素 早期…

W5500-EVB-Pico评估版介绍

文章目录 1 概述2 板载资源2.1 硬件规格2.2 硬件规格2.3 工作条件 3 参考资料3.2 原理图3.3 尺寸图 (单位 : mm)3.4 参考例程 4 硬件协议栈优势 1 概述 W5500-EVB-Pico是基于树莓派RP2040和完全硬连线TCP/IP控制器W5500的微控制器开发板-基本上与树莓派Pico板相同&#xff0c;但…

【MATLAB库函数系列】线性调频Z(Chirp-Z,CZT)的MATLAB源码和C语言实现

在上一篇博客 【数字信号处理】线性调频Z(Chirp-Z,CZT)算法详解 已经详细介绍了CZT变换的应用背景和原理,先回顾一下: 回顾CZT算法 采用 FFT 算法可以很快计算出全部 N N N点 DFT 值,即Z变换 X ( z ) X(z) <

220v电源转换12v和24v用什么芯片

问&#xff1a;将220V电源转换为12V和24V - 使用什么芯片&#xff1f; 答&#xff1a;常用于将220V电源转换为12V和24V的芯片是AH8669和AH8665。 问&#xff1a;AH8669芯片提供了什么特点&#xff1f; 答&#xff1a;AH8669芯片适用于最大电流为700mA的应用。它内置了MOSFET…

01-黑马程序员大数据开发

一. Hadoop概述 1. 什么是大数据 &#xfeff;狭义上&#xff1a;对海量数据进行处理的软件技术体系&#xfeff;广义上&#xff1a;数字化、信息化时代的基础支撑&#xff0c;以数据为生活赋 2. 大数据的核心工作&#xff1a; &#xfeff;存储&#xff1a;妥善保存海量待…

5g消息-5G时代短信升级-富媒体智能交互-互联网新入口

在5G时代&#xff0c;运营商和各大手机厂商都在积极推进5G消息的商用&#xff0c;基于短信入口的富媒体消息应用在近两年得到快速发展&#xff0c;并在企业端形成了广泛应用。 作为5G时代的数字原生应用&#xff0c;5G消息支持用户通过文字、图片、音频、视频、位置等富媒体方式…

【算法题】链表重排(js)

力扣链接&#xff1a;https://leetcode.cn/problems/LGjMqU/description/ /*** Definition for singly-linked list.* function ListNode(val, next) {* this.val (valundefined ? 0 : val)* this.next (nextundefined ? null : next)* }*/ /*** param {ListNode…

C++11(上):新特性讲解

C11新特性讲解 前言1.列表初始化1.1{ }初始化1.2std::initializer_list 2.类型推导2.1 auto2.2 typeid2.3 decltype 3.范围for4.STL的变化4.1新容器4.2容器的新方法 5.右值引用和移动语义5.1 左值引用和右值引用5.2 左值引用与右值引用比较5.3 右值引用的使用场景5.4 右值、左值…

浙江大唐乌沙山电厂选择ZStack Cloud打造新一代云基础设施

浙江大唐乌沙山电厂选择云轴科技ZStack Cloud云平台为其提供高性能、高可用的云主机、云存储和云网络&#xff0c;构建了简单、稳定、安全、高效的云基础设施&#xff1b;通过ZStackCloud为其提供可视化服务编排、多租户自服务等模块&#xff0c;帮助电厂提高IT资源利用率&…

解决FTP传输慢的问题(ftp传输慢为什么)

在企业运营中&#xff0c;使用FTP进行文件或数据传输是相当普遍的做法。尽管FTP是一种传统的文件传输工具&#xff0c;但在实际应用中&#xff0c;我们可能会面临传输速度缓慢的问题&#xff0c;这不仅影响工作效率&#xff0c;还浪费时间。为了解决这一问题&#xff0c;我们可…

泛微OA xmlrpcServlet接口任意文件读取漏洞(CNVD-2022-43245)

CNVD-2022-43245 泛微e-cology XmlRpcServlet接口处存在任意文件读取漏洞&#xff0c;攻击者可利用漏洞获取敏感信息。 1.漏洞级别 中危 2.影响范围 e-office < 9.5 202201133.漏洞搜索 fofa 搜索 app"泛微-OA&#xff08;e-cology&#xff09;"4.漏洞复现 …

vue 项目/备案网页/ip网页打包成 apk 安装到平板/手机(含vue项目跨域代理打包成apk后无法访问接口的解决方案)

下载安装HBuilder X编辑器 https://www.dcloud.io/hbuilderx.html 新建 5APP 项目 打开 HBuilder X&#xff0c;新建项目 此处项目名以 ‘test’ 为例 含跨域代理的vue项目改造 若 vue 项目中含跨域代理&#xff0c;如 vue.config.js module.exports {publicPath: "./&…

【C++】开源:FTXUI终端界面库配置使用

&#x1f60f;★,:.☆(&#xffe3;▽&#xffe3;)/$:.★ &#x1f60f; 这篇文章主要介绍FTXUI终端界面库配置使用。 无专精则不能成&#xff0c;无涉猎则不能通。——梁启超 欢迎来到我的博客&#xff0c;一起学习&#xff0c;共同进步。 喜欢的朋友可以关注一下&#xff0c…

【文本处理】正则表达式

一、简介 正则表达式&#xff0c;又称规则表达式,&#xff08;Regular Expression&#xff0c;在代码中常简写为regex、regexp或RE&#xff09;&#xff0c;是一种文本模式&#xff0c;包括普通字符&#xff08;例如&#xff0c;a 到 z 之间的字母&#xff09;和特殊字符&…

苹果发布了一个Ferret(雪貂)多模态大模型,在一个无人问津的角落被一位博主捞起来

苹果12月14日释放了一个名为Ferret的多模态大语言模型&#xff0c;有的翻译是雪貂&#xff0c;有的是法学硕士&#xff0c;要我说&#xff0c;还是叫雪貂吧&#xff0c;接地气亲民&#xff0c;将来犯蠢的时候出来发张雪貂的可爱表情包作公关就完事了&#xff0c;你个法学硕士到…

C语言—每日选择题—Day63

指针相关博客 打响指针的第一枪&#xff1a;指针家族-CSDN博客 深入理解&#xff1a;指针变量的解引用 与 加法运算-CSDN博客 第一题 1. 设C语言中&#xff0c;一个int型数据在内存中占2个字节&#xff0c;则unsigned int型数据的取值范围为 A&#xff1a;0~255 B&#xff1a;0…

mysql8.x版本_select语句源码跟踪

总结 源码基于8.0.34版本分析&#xff0c;函数执行流程含义大致如下&#xff1a; do_command 方法从连接中读取命令并执行&#xff0c;调用 dispatch_command 对命令进行分发。dispatch_command 调用 mysql_parse 对命令进行解析&#xff0c;如果遇到一条语句用 ; 分隔多条命…