tf.Keras (tf-1.15)使用记录3-model.compile方法

model.compile 是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。

注意事项: 在开始训练(调用 model.fit)之前,必须先调用 model.compile()

1 基本用法

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

1) optimizer: 优化器

可以是预定义优化器的字符串(如 'adam', 'sgd' 等),也可以是 tf.keras.optimizers 下的优化器实例。优化器负责调整模型的权重以最小化损失函数。

以下是可以使用的字符串参数:

  1. 'sgd': 随机梯度下降优化器
  2. 'adam': Adam 优化器
  3. 'rmsprop': RMSprop 优化器
  4. 'adagrad': Adagrad 优化器
  5. 'adadelta': Adadelta 优化器
  6. 'adamax': Adamax 优化器
  7. 'nadam': Nadam 优化器
  8. 'ftrl': Ftrl 优化器

需要注意的是:

  1. 这些字符串参数是不区分大小写的。例如,‘Adam’ 和 ‘adam’ 都是有效的。

  2. 使用字符串参数时,优化器会使用其默认参数值。如果你需要自定义优化器的参数(如学习率),最好直接使用优化器类:

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    
  3. ‘adam’ 通常是一个很好的默认选择,因为它在各种问题上都表现良好。但对于特定问题,其他优化器可能会表现得更好。

  4. 在实践中,选择合适的优化器和调整其参数(如学习率)往往比选择特定的优化器算法更重要。

2) loss: 损失函数

用于计算模型的预测值和真实值之间的差异。可以是字符串(预定义损失函数的名称),也可以是 tf.keras.losses 下的损失函数对象。对于不同类型的问题(如分类、回归等),需要选择合适的损失函数。

以下是一些常用的字符串参数对应的损失函数:

  1. 'binary_crossentropy': 用于二分类问题的交叉熵损失。
  2. 'categorical_crossentropy': 用于多分类问题的交叉熵损失,要求标签为 one-hot 编码。
  3. 'sparse_categorical_crossentropy': 用于多分类问题的交叉熵损失,标签为整数。
  4. 'mean_squared_error''mse': 均方误差损失,用于回归问题。
  5. 'mean_absolute_error''mae': 平均绝对误差损失,用于回归问题。
  6. 'mean_absolute_percentage_error''mape': 平均绝对百分比误差,用于回归问题。
  7. 'mean_squared_logarithmic_error''msle': 均方对数误差,用于回归问题,对小差异不敏感。
  8. 'poisson': 泊松损失,适用于计数问题或其他泊松分布问题。
  9. 'kullback_leibler_divergence''kld': Kullback-Leibler 散度,用于衡量两个概率分布之间的差异。
  10. 'hinge': 用于“最大间隔”分类问题的铰链损失。
  11. 'squared_hinge': 铰链损失的平方版本。
  12. 'logcosh': 对数双曲余弦损失,用于回归问题,对异常值不敏感。

3) metrics: 评估指标列表,用于评估模型的性能

这些指标在训练过程中不会用于梯度计算,仅用于观察。常见的指标包括 'accuracy''precision''recall' 等。

model.compile() 方法中,metrics 参数用于指定在训练和评估期间模型将评估哪些指标。这些指标不会用于训练过程中的反向传播和权重更新,仅用于观察模型的性能。以下是一些可以通过字符串参数传入的常用指标:

  1. 'accuracy''acc': 准确率,用于分类问题。
  2. 'binary_accuracy': 二分类准确率。
  3. 'categorical_accuracy': 多分类准确率,要求标签为 one-hot 编码。
  4. 'sparse_categorical_accuracy': 多分类准确率,标签为整数。
  5. 'top_k_categorical_accuracy': Top-k 准确率,即目标类别在模型预测的前 k 个最可能的类别中的准确率,用于多分类问题。
  6. 'sparse_top_k_categorical_accuracy': 与 'top_k_categorical_accuracy' 类似,但适用于标签为整数的情况。
  7. 'mean_squared_error''mse': 均方误差,用于回归问题。
  8. 'mean_absolute_error''mae': 平均绝对误差,用于回归问题。
  9. 'mean_absolute_percentage_error''mape': 平均绝对百分比误差,用于回归问题。
  10. 'mean_squared_logarithmic_error''msle': 均方对数误差,用于回归问题。
  11. 'cosine_similarity': 余弦相似度,用于回归问题或多标签分类问题。
  12. 'precision': 精确率,用于二分类或多标签分类问题。
  13. 'recall': 召回率,用于二分类或多标签分类问题。
  14. 'auc': 曲线下面积(Area Under the Curve),用于二分类问题。

使用示例:

# 二分类问题
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy', 'precision', 'recall'])# 多分类问题
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy', 'top_k_categorical_accuracy'])# 回归问题
model.compile(optimizer='adam',loss='mean_squared_error',metrics=['mae', 'mse'])

对于一些特定的指标(如 'precision', 'recall', 'auc' 等),可能需要使用 tf.keras.metrics 下的类实例来获得更多的配置选项,例如设置阈值或为多标签分类问题指定平均方法。

from tensorflow.keras.metrics import Precision, Recallmodel.compile(optimizer='adam',loss='binary_crossentropy',metrics=[Precision(thresholds=0.5), Recall(thresholds=0.5)])

2 高级用法

  • 使用自定义优化器:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  • 使用自定义损失函数:
def custom_loss(y_true, y_pred):# 自定义损失计算逻辑return tf.reduce_mean(tf.square(y_true - y_pred))model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
  • 使用多个损失函数和评估指标:

如果模型有多个输出,你可以为每个输出指定不同的损失函数和评估指标。

model.compile(optimizer='adam',loss={'output_a': 'sparse_categorical_crossentropy', 'output_b': 'mse'},metrics={'output_a': ['accuracy'], 'output_b': ['mae', 'mse']})
  • 使用学习率衰减:
from tensorflow.keras.optimizers.schedules import ExponentialDecaylr_schedule = ExponentialDecay(initial_learning_rate=1e-2, decay_steps=10000, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/68840.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决whisper 本地运行时GPU 利用率不高的问题

我在windows 环境下本地运行whisper 模型,使用的是nivdia RTX4070 显卡,结果发现GPU 的利用率只有2% 。使用 import torch print(torch.cuda.is_available()) 返回TRUE。表示我的cuda 是可用的。 最后在github 的下列网页上找到了问题 极低的 GPU 利…

大模型综合性能考题汇总

- K1.5长思考版本 一、创意写作能力 题目1:老爸笑话 要求:写五个原创的老爸笑话。 考察点:考察模型的幽默感和创意能力,以及对“原创”要求的理解和执行能力。 题目2:创意故事 要求:写一篇关于亚伯拉罕…

在 crag 中用 LangGraph 进行评分知识精炼-下

在上一次给大家展示了基本的 Rag 检索过程,着重描述了增强检索中的知识精炼和补充检索,这些都是 crag 的一部分,这篇内容结合 langgraph 给大家展示通过检索增强生成(Retrieval-Augmented Generation, RAG)的工作流&am…

(二)QT——按钮小程序

目录 前言 按钮小程序 1、步骤 2、代码示例 3、多个按钮 ①信号与槽的一对一 ②多对一(多个信号连接到同一个槽) ③一对多(一个信号连接到多个槽) 结论 前言 按钮小程序 Qt 按钮程序通常包含 三个核心文件: m…

win11本地部署 DeepSeek-R1 大模型!免费开源,媲美OpenAI-o1能力,断网也能用

一、下载ollama 二、安装ollama 三、部署DeepSeek-R1 在cmd窗口中先输入ollama -v查看ollama是否安装成功,然后直接运行部署deepseek-r1的命令 ollama run deepseek-r1,出现下面界面即为安装成功。 C:\Users\admin>ollama -v ollama version is 0.5…

【工欲善其事】利用 DeepSeek 实现复杂 Git 操作:从原项目剥离出子版本树并同步到新的代码库中

文章目录 利用 DeepSeek 实现复杂 Git 操作1 背景介绍2 需求描述3 思路分析4 实现过程4.1 第一次需求确认4.2 第二次需求确认4.3 第三次需求确认4.4 V3 模型:中间结果的处理4.5 方案验证,首战告捷 5 总结复盘 利用 DeepSeek 实现复杂 Git 操作 1 背景介绍…

TensorFlow 示例摄氏度到华氏度的转换(一)

TensorFlow 实现神经网络模型来进行摄氏度到华氏度的转换,可以将其作为一个回归问题来处理。我们可以通过神经网络来拟合这个简单的转换公式。 1. 数据准备与预处理 2. 构建模型 3. 编译模型 4. 训练模型 5. 评估模型 6. 模型应用与预测 7. 保存与加载模型 …

gitea - fatal: Authentication failed

文章目录 gitea - fatal: Authentication failed概述run_gitea_on_my_pkm.bat 笔记删除windows凭证管理器中对应的url认证凭证启动gitea服务端的命令行正常用 TortoiseGit 提交代码备注END gitea - fatal: Authentication failed 概述 本地的git归档服务端使用gitea. 原来的用…

MySql运维篇---008:日志:错误日志、二进制日志、查询日志、慢查询日志,主从复制:概述 虚拟机更改ip注意事项

#先登录mysql mysql -uroot -p1234#通过此系统变量,查看当前mysql的版本中默认的日志格式是哪个 show variables like %binlog\_format%;1.2.3 查看 由于日志是以二进制方式存储的,不能直接读取,需要通过二进制日志查询工具 mysqlbinlog 来查…

【背包问题】二维费用的背包问题

目录 二维费用的背包问题详解 总结: 空间优化: 1. 状态定义 2. 状态转移方程 3. 初始化 4. 遍历顺序 5. 时间复杂度 例题 1,一和零 2,盈利计划 二维费用的背包问题详解 前面讲到的01背包中,对物品的限定条件…

使用 DeepSeek-R1 等推理模型将 RAG 转换为 RAT,以实现更智能的 AI

使用 DeepSeek-R1 等推理模型将 RAG 转换为 RAT,以实现更智能的 AI 传统的检索增强生成(RAG)系统在生成具备上下文感知的答案方面表现出色。然而,它们往往存在以下不足: 精确性不足:单次推理可能会忽略复杂…

小红的合数寻找

A-小红的合数寻找_牛客周赛 Round 79 题目描述 小红拿到了一个正整数 x,她希望你在 [x,2x] 区间内找到一个合数,你能帮帮她吗? 一个数为合数,当且仅当这个数是大于1的整数,并且不是质数。 输入描述 在一行上输入一…

笔灵ai写作技术浅析(三):深度学习

笔灵AI写作的深度学习技术主要基于Transformer架构,尤其是GPT(Generative Pre-trained Transformer)系列模型。 1. Transformer架构 Transformer架构由Vaswani等人在2017年提出,是GPT系列模型的基础。它摒弃了传统的循环神经网络(RNN)和卷积神经网络(CNN),完全依赖自…

IM 即时通讯系统-50-[特殊字符]cim(cross IM) 适用于开发者的分布式即时通讯系统

IM 开源系列 IM 即时通讯系统-41-开源 野火IM 专注于即时通讯实时音视频技术,提供优质可控的IMRTC能力 IM 即时通讯系统-42-基于netty实现的IM服务端,提供客户端jar包,可集成自己的登录系统 IM 即时通讯系统-43-简单的仿QQ聊天安卓APP IM 即时通讯系统-44-仿QQ即…

Zemax 中带有体素探测器的激光谐振腔

激光谐振腔是激光系统的基本组成部分,在光的放大和相干激光辐射的产生中起着至关重要的作用。 激光腔由两个放置在光学谐振器两端的镜子组成。一个镜子反射率高(后镜),而另一个镜子部分透明(输出耦合器)。…

17.2 图形绘制4

版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 17.2.5 线条样式 C#为画笔绘制线段提供了多种样式:一是线帽(包括起点和终点处)样式&#xff1b…

基于微信小程序的酒店管理系统设计与实现(源码+数据库+文档)

酒店管理小程序目录 目录 基于微信小程序的酒店管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、管理员模块的实现 (1) 用户信息管理 (2) 酒店管理员管理 (3) 房间信息管理 2、小程序序会员模块的实现 (1)系统首页 &#xff…

计算机网络 应用层 笔记 (电子邮件系统,SMTP,POP3,MIME,IMAP,万维网,HTTP,html)

电子邮件系统: SMTP协议 基本概念 工作原理 连接建立: 命令交互 客户端发送命令: 服务器响应: 邮件传输: 连接关闭: 主要命令 邮件发送流程 SMTP的缺点: MIME: POP3协议 基本概念…

Golang Gin系列-9:Gin 集成Swagger生成文档

文档一直是一项乏味的工作(以我个人的拙见),但也是编码过程中最重要的任务之一。在本文中,我们将学习如何将Swagger规范与Gin框架集成。我们将实现JWT认证,请求体作为表单数据和JSON。这里唯一的先决条件是Gin服务器。…

零基础学习书生.浦语大模型-入门岛

第一关:Linux基础知识 Cursor连接服务器 使用Remote - SSH插件即可 注:46561:服务器端口号 运行指令 python hello_world.py端口映射 ssh -p 46561 rootssh.intern-ai.org.cn -CNg -L 7860:127.0.0.1:7860 -o StrictHostKeyCheckingno …