浅析Estimator、model_fn与EstimatorSpec

参考阅读:https://zhuanlan.zhihu.com/p/74857888

文章目录

  • 综合对比
      • Estimator
      • model_fn
      • EstimatorSpec
      • 关系
      • 总结
  • Estimator
      • 主要功能
      • 构造函数参数
      • 示例用法
      • 小结
  • model_fn
  • EstimatorSpec
      • 字段解释
      • 解释代码
      • 用途

综合对比

Estimatormodel_fnEstimatorSpec 是 TensorFlow 中用于构建、训练和评估模型的三个核心组件。它们之间的关系可以总结如下:

Estimator

  • 定义: Estimator 是 TensorFlow 提供的高层 API,用于简化和标准化模型的训练、评估和预测。
  • 功能:
    • 封装训练、评估和预测的逻辑。
    • 管理检查点、日志记录和模型保存。
    • 提供一致的接口来处理不同类型的模型。
  • 参数:
    • model_fn: 定义模型的函数。
    • model_dir: 模型保存目录。
    • config: 执行环境的配置信息。
    • params: 超参数字典。
    • warm_start_from: 热启动配置。

model_fn

  • 定义: model_fn 是一个函数,定义了模型的结构和行为。它由 Estimator 在训练、评估和预测时调用。
  • 功能:
    • 构建模型的计算图。
    • 根据运行模式(TRAIN、EVAL、PREDICT)返回不同的操作。
    • 接受特征、标签、模式、超参数和配置信息作为输入。
  • 返回值:
    • 返回一个 EstimatorSpec 对象,定义了模型在不同模式下的行为。

EstimatorSpec

  • 定义: EstimatorSpec 是一个对象,包含了模型在训练、评估和预测模式下的所有必要信息。
  • 功能:
    • 定义模型的预测、损失、训练操作和评估指标。
    • 提供一致的接口,使 Estimator 能够在不同模式下正确运行模型。
  • 字段:
    • mode: 运行模式(TRAIN、EVAL、PREDICT)。
    • predictions: 预测结果。
    • loss: 损失值。
    • train_op: 训练操作。
    • eval_metric_ops: 评估指标操作。
    • export_outputs: 导出输出。
    • training_chief_hooks, training_hooks, scaffold, evaluation_hooks, prediction_hooks: 各种钩子和脚手架对象,用于在不同阶段执行自定义操作。

关系

  1. Estimator 使用 model_fn:

    • Estimator 调用 model_fn 来构建模型的计算图并定义其行为。
    • model_fn 接受特征、标签、模式、超参数和配置信息,并返回一个 EstimatorSpec 对象。
  2. model_fn 返回 EstimatorSpec:

    • model_fn 根据当前的运行模式(TRAIN、EVAL、PREDICT)创建并返回一个 EstimatorSpec 对象。
    • EstimatorSpec 对象包含了模型在当前模式下所需的所有操作和输出。
  3. Estimator 使用 EstimatorSpec:

    • Estimator 使用 EstimatorSpec 中定义的操作来执行训练、评估和预测。
    • 根据 EstimatorSpec 中的信息,Estimator 知道如何处理模型的预测、损失计算和训练步骤。

总结

  • Estimator 是高层接口,用于管理和运行模型。
  • model_fn 是用户定义的函数,用于构建模型的计算图并返回 EstimatorSpec
  • EstimatorSpec 定义了模型在不同模式下的行为,由 model_fn 返回,并由 Estimator 使用。

Estimator

Estimator 是 TensorFlow 提供的一个高层 API,用于简化模型的训练和评估。它封装了一个模型,模型通过 model_fn 指定。Estimator 负责处理训练、评估和预测所需的所有操作,并将结果输出到指定的目录。

主要功能

  1. 模型训练、评估和预测: Estimator 封装了这些操作,简化了模型的开发和部署过程。
  2. 模型保存和恢复: 所有输出(如检查点、事件文件等)都写入 model_dir,或其子目录。这样可以方便地保存和恢复模型。
  3. 运行配置: 通过 config 参数,Estimator 可以获取有关执行环境的信息,并将其传递给 model_fn
  4. 超参数传递: 通过 params 参数,Estimator 可以将超参数传递给 model_fn 和输入函数。

构造函数参数

  • model_fn: 模型函数,定义了如何构建模型。它接受以下参数:

    • features: 从 input_fn 返回的特征,通常是 TensorTensor 字典。
    • labels: 从 input_fn 返回的标签,通常是 TensorTensor 字典。在预测模式下,labelsNone
    • mode: 运行模式,可以是 TRAINEVALPREDICT
    • params: 超参数字典,包含传递给 Estimator 的超参数。
    • config: RunConfig 对象,包含执行环境的配置信息。
  • model_dir: 模型参数、图等的保存目录,也可以用于从目录加载检查点以继续训练之前保存的模型。

  • config: RunConfig 配置对象,包含执行环境的配置信息。如果model_fn函数也定义config这个变量,则会将config传给model_fn。

  • params: 超参数字典,包含传递给 model_fn 的超参数。

  • warm_start_from: 检查点或 SavedModel 的文件路径,用于热启动,或一个 WarmStartSettings 对象以完全配置热启动。

示例用法

  1. 创建一个 Estimator 实例

    estimator = tf.estimator.DNNClassifier(feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],hidden_units=[1024, 512, 256],warm_start_from="/path/to/checkpoint/dir"
    )
    
  2. 定义 model_fn

    def my_model_fn(features, labels, mode, params):# 构建模型logits = build_model(features, mode, params)predictions = {'classes': tf.argmax(input=logits, axis=1),'probabilities': tf.nn.softmax(logits)}# PREDICT 模式if mode == tf.estimator.ModeKeys.PREDICT:return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)# 计算损失loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)# 训练操作if mode == tf.estimator.ModeKeys.TRAIN:optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)# 评估指标eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels=labels, predictions=predictions['classes'])}return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
    
  3. 使用 Estimator 进行训练、评估和预测

    # 训练
    estimator.train(input_fn=train_input_fn, steps=1000)# 评估
    eval_result = estimator.evaluate(input_fn=eval_input_fn)
    print(eval_result)# 预测
    predictions = estimator.predict(input_fn=predict_input_fn)
    for pred in predictions:print(pred)
    

小结

Estimator 提供了一种结构化的方法来定义和管理 TensorFlow 模型,使得模型的训练、评估和预测更加方便和标准化。它通过 model_fn 将模型的构建与训练、评估和预测逻辑分离,并且通过配置和参数化提供了灵活性。

model_fn

输入:

  • features: 从 input_fn 返回的特征,通常是 TensorTensor 字典。
  • labels: 从 input_fn 返回的标签,通常是 TensorTensor 字典。在预测模式下,labelsNone
  • mode: 运行模式,可以是 TRAINEVALPREDICT
  • params: 超参数字典,包含传递给 Estimator 的超参数。
  • config: RunConfig 对象,包含执行环境的配置信息。

返回值:
一个EstimatorSpec

前两个参数是从输入函数中返回的特征和标签批次;也就是说,features 和 labels 是模型将使用的数据。

params 是一个字典,它可以传入许多参数用来构建网络或者定义训练方式等。例如通过设置params[‘n_classes’]来定义最终输出节点的个数等。
config 通常用来控制checkpoint或者分布式什么,这里不深入研究。
mode 参数表示调用程序是请求训练、评估还是预测,分别通过tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT 来定义。另外通过观察DNNClassifier的源代码可以看到,mode这个参数并不用手动传入,因为Estimator会自动调整。例如当你调用estimator.train(…)的时候,mode则会被赋值tf.estimator.ModeKeys.TRAIN。

模型有训练,验证和测试三种阶段,而且对于不同模式,对数据有不同的处理方式。例如在训练阶段,我们需要将数据喂给模型,模型基于输入数据给出预测值,然后我们在通过预测值和真实值计算出loss,最后用loss更新网络参数,而在评估阶段,我们则不需要反向传播更新网络参数,换句话说,model_fn需要对三种模式设置三套代码

EstimatorSpec

collections.namedtuple 是 Python 标准库中的一个函数,用于创建不可变的、具名的元组(named tuple)。这些具名元组可以像类一样使用,有字段名称,使代码更具可读性和可维护性。

在这段代码中,collections.namedtuple 被用来创建一个名为 EstimatorSpec 的具名元组,它包含了一组用于定义模型在不同模式下行为的字段。以下是每个字段的解释:

字段解释

  1. mode: 模式,表示当前的运行模式,可以是训练(TRAIN)、评估(EVAL)或预测(PREDICT)模式。
  2. predictions: 预测值,可以是一个 TensorTensor 字典,用于预测模式下输出结果。
  3. loss: 损失值,一个标量 Tensor,表示模型的损失,用于训练和评估模式。
  4. train_op: 训练操作,表示在训练模式下执行的操作(通常是优化步骤)。
  5. eval_metric_ops: 评估指标操作,是一个字典,包含评估模式下的度量结果。
  6. export_outputs: 导出输出,是一个字典,定义了模型在导出为 SavedModel 时的输出签名。
  7. training_chief_hooks: 主训练钩子,是一个迭代器,包含在主 worker 上运行的 SessionRunHook 对象。
  8. training_hooks: 训练钩子,是一个迭代器,包含在所有 worker 上运行的 SessionRunHook 对象。
  9. scaffold: 脚手架,是一个 tf.train.Scaffold 对象,用于设置初始化、保存和恢复操作。
  10. evaluation_hooks: 评估钩子,是一个迭代器,包含在评估过程中运行的 SessionRunHook 对象。
  11. prediction_hooks: 预测钩子,是一个迭代器,包含在预测过程中运行的 SessionRunHook 对象。

解释代码

collections.namedtuple('EstimatorSpec', ['mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops','export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold','evaluation_hooks', 'prediction_hooks'
])

这行代码创建了一个名为 EstimatorSpec 的具名元组类,它包含了上述的这些字段。EstimatorSpec 类可以用于存储和传递这些字段的值,使得在模型函数(model_fn)中可以方便地定义和返回这些值。

用途

EstimatorSpec 主要用于 TensorFlow 的 Estimator API 中,以统一的方式定义模型的各个组成部分。通过使用 EstimatorSpec,可以确保模型在不同模式下的行为是一致且正确的。例如:

  • 在训练模式下,必须提供 losstrain_op
  • 在评估模式下,必须提供 loss
  • 在预测模式下,必须提供 predictions

使用 EstimatorSpec,可以更简洁和清晰地定义模型的各个部分,并且通过具名元组的方式,使代码更加可读和易于维护。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/39321.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

西电811考研、140分专业课及811/821经验

被拟录取了,说一说自己考研经验,本人跟的研梦考研全程班,胖覃学长很负责任,貌似已经直博西电了,但也很负责。 1、通信工程学院分为学硕与专硕,学硕包含信息与通信工程、交通运输工程、军队指挥学&#xff…

Perl语言中的排序艺术:深入探讨内置排序函数

Perl是一种功能强大的脚本语言,以其灵活的文本处理能力而闻名。在Perl中,排序是一项常见的任务,无论是对数组元素进行排序,还是对复杂数据结构进行排序,Perl都提供了多种内置的排序函数,以满足不同的需求。…

深入掌握Symfony与Composer:PHP依赖管理的艺术

引言 Composer是PHP的依赖管理工具,广泛用于Symfony等现代PHP应用程序中。它允许开发者声明依赖项,自动处理依赖的安装和更新,确保应用程序的依赖项得到有效管理。本文将详细介绍Composer的使用方法,包括基本命令、依赖管理、自动…

Linux环境安装配置nginx服务流程

Linux环境的Centos、麒麟、统信操作系统安装配置nginx服务流程操作: 1、官网下载 下载地址 或者通过命令下载 wget http://nginx.org/download/nginx-1.20.2.tar.gz 2、上传到指定的服务器并解压 tar -zxvf nginx-1.20.1.tar.gzcd nginx-1.20.1 3、编译并安装到…

条件过滤检索

背景介绍 在大多数业务场景中,单纯使用向量进行相似性检索并无法满足业务需求,通常需要在满足特定过滤条件、或者特定的“标签”的前提下,再进行相似性检索。 向量检索服务DashVector支持条件过滤和向量相似性检索相结合,在精确满…

数字化供应链:背景特点

​背景 1、外部环境 近年来,供应链脆弱性凸显,企业供应链压力难以缓解。 美国媒体针对美国零售联合会、美国服装和鞋类协会、美国供应链管理专业委员会等主体进行的一项供应链调查显示: 61%的供应链经理预计,供应链紊乱问题至少…

C++(第一天-----命名空间和引用)

一、C/C的区别 1、与C相比   c语言面向过程,c面向对象。   c能够对函数进行重载,可使同名的函数功能变得更加强大。   c引入了名字空间,可以使定义的变量名更多。   c可以使用引用传参,引用传参比起指针传参更加快&#…

企业化运维(5)_mysql数据库

###1.源码编译mysql### 对压缩包进行解压,并对mysql进行源码编译,其中需要下载依赖才能编译成功。 官网: www.mysql.com解压并进入目录 [rootserver1 ~]# tar xf mysql-boost-5.7.40.tar.gz [rootserver1 ~]# cd mysql-5.7.40/安装依赖性…

初识Java(复习版)

一. 什么是Java Java是一种面向对象的编程语言,和C语言有所不同,C语言是一门面向过程的语言。偏底层实现,比较注重底层的逻辑实现。不能一味的说某一种语言特别好,每一种语言都是在特定的情况下有自己的优势。 二.Java语言发展史…

昇思25天学习打卡营第2天|yulang

今天主要了解快速入门,主要包含了处理数据集、网络构建、模型训练、保存模型和加载模型,这些对于不是算法工程师理解起来可能稍微有一点的难度,学习起来有点枯燥,期待后续实战部分能完成一些独立的比较有意思的项目。

鸿蒙项目实战-月木学途:2.自定义底部导航

效果预览 Tabs组件简介 Tabs组件的页面组成包含两个部分,分别是TabContent和TabBar。TabContent是内容页,TabBar是导航页签栏,页面结构如下图所示,根据不同的导航类型,布局会有区别,可以分为底部导航、顶部…

使用ECharts实现动态数据可视化的最佳实践

使用ECharts实现动态数据可视化的最佳实践 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! 引言 随着数据驱动决策的重要性日益增强,动态数据可视…

第二十站:Java未来光谱——量子计算与新兴技术的展望

Java作为一门成熟且广泛使用的编程语言,其在传统计算领域已经取得了巨大的成功。然而,随着量子计算等新兴技术的出现,Java也在探索其在这些领域的应用潜力。IBM Qiskit是一个开源的量子计算软件框架,它允许开发者使用多种编程语言…

登录验证码高扩展性设计方案

登录验证码高扩展性建设方案 本文分享了一种登录验证码高扩展性的建设方案,通过工厂模式策略模式,增强了验证码服务中验证码生成器、验证码存储器、验证码图片生成器的扩展性,实现了服务组件的多样化,降低了维护成本 登录验证码高…

8617 阶乘数字和

这是一个关于计算阶乘结果所有位上的数字之和的问题。我们可以通过以下步骤来解决这个问题: 1. 首先,我们需要一个函数来计算阶乘。由于n的范围可以达到50,阶乘的结果可能非常大,所以我们需要使用一个可以处理大整数的数据类型&a…

adb shell logcat -b all|grep如何可以grep两个子串?

在adb shell logcat命令中结合grep来过滤日志时,如果你想要同时匹配两个子串,你可以使用管道(|)将两个grep命令连接起来,或者使用grep的-E(或egrep,它等同于-E)选项来支持扩展的正则…

[课程][原创]opencv图像在C#与C++之间交互传递

opencv图像在C#与C之间交互传递 课程地址:https://edu.csdn.net/course/detail/39689 无限期视频有效期 课程介绍课程目录讨论留言 你将收获 学会如何封装C的DLL 学会如何用C#调用C的DLL 掌握opencv在C#和C传递思路 学会如何配置C的opencv 适用人群 拥有C#…

报错:pathspec ‘xxx‘ did not match any file(s) known to git

在 escode 中进行分支切换时报如下错误 PS > git checkout xxx error: pathspec xxx did not match any file(s) known to git远程分支已经在 gitlab 客户端手动创建,在 escode 中也使用了拉取之类的操作,但是切换分支时依然报错。 解决方案 查看分…

怎么找到DNS服务器的地址?

所有域都注册到域名名称服务器(DNS)点,以解析域名应指向的IP地址。此查找类似于在查找个人名称并查找其电话号码时的电话簿如何运行。如果DNS服务器设置错误或指向错误的名称服务器,则域可能无法加载相应的网页。 如何查找当前的…

【深度学习】C++ onnx Yolov8 目标检测推理

【深度学习】C onnx Yolov8 目标检测推理 导出onnx模型代码onnx_detect_infer.honnx_detect_infer.cppmain.cppCMAKELIST 导出onnx模型 python 中导出 from ultralytics import YOLO# Load the YOLOv8 model model YOLO("best.pt")# # Export the model to ONNX f…