selenium之 chromedriver与chrome版本映射表_NLP实战篇之tf2训练与评估

本文是基于tensorflow2.2.0版本,介绍了模型的训练与评估。主要介绍了tf.keras的内置训练过程,包括compile、fit,其中compile中包含优化器、loss与metrics的使用,内置api中还包含了很多辅助工具,在Callback中进行介绍;除了简单的单输入单输出模型之外,本文还介绍了多输入、多输出模型的训练过程。本文中涉及的内容都是内置api,相关内容大都可以进行自定义,自定义相关内容会陆续在后续文章里介绍。

实战系列篇章中主要会分享,解决实际问题时的过程、遇到的问题或者使用的工具等等。如问题分解、bug排查、模型部署等等。相关代码实现开源在:https://github.com/wellinxu/nlp_store ,更多内容关注知乎专栏(或微信公众号):NLP杂货铺。4ec3a5df8989e396f691a0b25f20f49a.png

  • 简单文本分类模型示例
  • 训练与评估流程
    • compile
    • fit
    • Callback
  • 多输入、多输出模型
    • compile
    • fit
  • 参考

简单文本分类模型示例

如下面代码所示,根据【NLP实战篇之tensorflow2.0快速入门】获取一个完整的文本分类示例,其中包含数据获取、数据简单预处理、模型构建、训练与评估。

import tensorflow as tf

# 下载IMDB数据
vocab_size = 10000    # 保留词的个数
imdb = tf.keras.datasets.imdb
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=vocab_size)

# 一个将单词映射到整数索引的词典
word_index = imdb.get_word_index()   # 索引从1开始
word_index = {k:(v+3) for k,v in word_index.items()}
word_index[""] = 0
word_index[""] = 1
word_index[""] = 2  # unknown
word_index[""] = 3

# 统一文本序列长度
train_data = tf.keras.preprocessing.sequence.pad_sequences(train_data, value=word_index[""], padding="post", truncating="post", maxlen=256)
test_data = tf.keras.preprocessing.sequence.pad_sequences(test_data, value=word_index[""], padding="post", truncating="post", maxlen=256)


# 模型构建
model = tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab_size, 16),    # [batch_size, seq_len, 16]
        tf.keras.layers.GlobalAveragePooling1D(),    # [batch_size, 16]
        tf.keras.layers.Dense(16, activation='relu'),    # [batch_size, 16]
        tf.keras.layers.Dense(1, activation='sigmoid')    # [batch_size, 1]
    ])



# 配置模型训练参数
# model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.BinaryCrossentropy(), metrics=[tf.keras.metrics.BinaryAccuracy()])
# 训练模型
history = model.fit(train_data, train_labels, epochs=40, batch_size=512)
# 评估测试集
model.evaluate(test_data,  test_labels, verbose=2)

训练与评估流程

compile

如上面代码所示,要使用fit进行训练模型,需要制定优化器、损失函数和评价指标,通过compile方法传递给模型。除了这三个参数,compile还包含了其他参数:

  1. optimizer: 字符串(优化器名称)或者优化器实例,具体可看tf.keras.optimizers
  2. loss: 字符串(损失函数名称)、损失函数或损失实例,具体可看tf.keras.losses.Losstf.keras.losses
  3. metrics: 评价方法列表,列表中每个值可以是字符串(度量函数名称)、度量函数或者度量实例,具体可看tf.keras.metrics.Metrictf.keras.metrics
  4. loss_weights: 列表或者字典格式,模型不同输出的loss权重,用于多输出模型
  5. sample_weight_mode:
  6. weighted_metrics: 通过样本权重或者类别权重计算的度量列表
  7. **kwargs: 其他关键字参数

其中优化器、loss和度量方法都有很多内置的api,如下表所示。除此之外,还支持自定义相关函数,此部分内容在后续文章中介绍。

  1. optimizer优化器:Adadelta、Adagrad、Adam、Adamax、Ftrl、SGD、Nadam、RMSprop等
  2. loss损失:BinaryCrossentropy、CategoricalCrossentropy、CategoricalHinge、CosineSimilarity、Hinge、Huber、KLDivergence、LogCosh、MeanAbsoluteError、MeanAbsolutePercentageError、MeanSquaredError、MeanSquaredLogarithmicError、Poisson、SparseCategoricalCrossentropy、SquaredHinge、MAE、MAPE、MSE、MSLE等等
  3. metrics度量:AUC、Accuracy、BinaryAccuracy、BinaryCrossentropy、CategoricalAccuracy、CosineSimilarity、FalseNegatives、FalsePositives、MeanIoU、Mean、MeanAbsoluteError、MeanSquaredError、Precision、Recall、TopKCategoricalAccuracy等等

在训练神经网络的时候,有个比较重要的超参数:学习率,这个参数的大小或者变化,都严重影响着最终模型的效果。通过优化器中learning_rate参数可以设置学习率的大小与变化。learning_rate可以设置为静态的,比如2e-5,或者设置为动态的,tf.keras.optimizers.schedules中已经提供了部分学习率衰减方法,如:ExponentialDecay、InverseTimeDecay、LearningRateSchedule、PiecewiseConstantDecay、PolynomialDecay等等。

fit

fit方法是tf.keras中内置的训练方法,其除了包含必要的输入数据与训练轮次之外,还有包含很多其他参数,如下所示:

  1. batch_size: Integer或None,每次进行梯度更新的样本数量,默认32
  2. epochs: Integer. 模型训练的轮次Number of epochs to train the model.
  3. verbose: 0, 1, or 2. 信息显示模式0 = silent, 1 = progress bar, 2 = one line per epoch
  4. callbacks: keras.callbacks.Callback示例列表
  5. validation_split: 0-1之间的浮点数,将多少比例的训练数据用作验证数据
  6. validation_data: 评估数据
  7. shuffle: Boolean (每一轮训练之前是否扰乱训练数据) or str (for 'batch')
  8. class_weight: 字典,类别权重,类别索引为key,权重为值
  9. sample_weight: 样本权重
  10. initial_epoch: Integer,从第几轮次开始训练
  11. steps_per_epoch: Integer or None,每轮次训练多少步
  12. validation_steps:
  13. validation_batch_size:
  14. validation_freq:
  15. max_queue_size: Integer. 只用于keras.utils.Sequence输入
  16. workers: Integer. 只用于keras.utils.Sequence输入
  17. use_multiprocessing: Boolean. 只用于keras.utils.Sequence输入

其中validation_split参数可以自动分离训练集留作验证数据,class_weight作为类权重,可以缓解类别不平衡的问题,样本权重sample_weight可以起到类似的作用,其控制程度更细致,能够进一步提高难样本的权重,或者降低简单/无效等样本。

Callback

Callback是训练或评估期间,在不同时间点(某个周期开始时、某个批次结束时、某个周期结束时)调用的对象,这些对象可以实现以下行为:

  1. 定义验证模型
  2. 定期或者触发条件进行模型保存
  3. 训练停滞时改变学习率
  4. 训练停滞时微调结构
  5. 训练结束或者触发条件时发送电子邮件或即时消息等

具体使用方式,如下所示:

callbacks = [
    # 提前终止训练
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", min_delta=1e-2, patience=2, verbose=1),
    # 保存中间模型
    tf.keras.callbacks.ModelCheckpoint(filepath="mymodel_{epoch}", save_best_only=True, monitor="val_loss", verbose=1),
    # 可视化化损失与指标
    tf.keras.callbacks.TensorBoard(log_dir="/full_path_to_your_logs", histogram_freq=0, embeddings_freq=0, update_freq="epoch")
]
model.fit(train_data, train_labels, epochs=40, batch_size=512, callbacks=callbacks)

tf.keras.callbacks中提供了一些内置的callback,我们也可以进行自定义,自定义相关内容后续文章介绍。下面展示了keras中内置的callback:

  1. BaseLogger:累积每轮训练的平均指标,这个Callback会被keras模型默认调用
  2. CSVLogger:将每轮的损失与度量结果数据流写入csv文件中
  3. EarlyStopping:当指定的度量指标停止改进时,停止训练
  4. History:将事件记录到History对象中,这个Callback会被keras模型默认调用
  5. LearningRateScheduler:修改学习率
  6. ModelCheckpoint:以某种频率保存模型或权重
  7. ProgbarLogger:向stdout输出度量指标
  8. ReduceLROnPlateau:当某个指标停止改进时降低学习率
  9. RemoteMonitor
  10. TensorBoard:可视化工具
  11. TerminateOnNaN:当遇到NaN损失时终止训练

多输入、多输出模型

2920d5366e99127d061722152312986d.png
前面的示例中,模型都是单个输入与单个输出,但有很多模型是多个输入或输出,例如上图模型结构所示,我们用以下方法构建模型:

image_input = tf.keras.Input(shape=(32, 32, 3), name="img_input")
timeseries_input = tf.keras.Input(shape=(None, 10), name="ts_input")

x1 = tf.keras.layers.Conv2D(3, 3)(image_input)
x1 = tf.keras.layers.GlobalMaxPooling2D()(x1)

x2 = tf.keras.layers.Conv1D(3, 3)(timeseries_input)
x2 = tf.keras.layers.GlobalMaxPooling1D()(x2)

x = tf.keras.layers.concatenate([x1, x2])

score_output = tf.keras.layers.Dense(1, name="score_output")(x)
class_output = tf.keras.layers.Dense(5, activation="softmax", name="class_output")(x)

model = tf.keras.Model(
        inputs=[image_input, timeseries_input], outputs=[score_output, class_output]
    )

compile

如果loss或者metrics参数只有单个传递给模型,则每一个输出都用一个loss或者metrics,但在很多情况下不同的输出需要不同的loss或者metrics,我们就需要对应每个输出给出不同的值,如下面所示:

model.compile(
        optimizer=tf.keras.optimizers.RMSprop(1e-3),
        loss=[tf.keras.losses.MeanSquaredError(), tf.keras.losses.CategoricalCrossentropy()],
        metrics=[
            [
                tf.keras.metrics.MeanAbsolutePercentageError(),
                tf.keras.metrics.MeanAbsoluteError(),
            ],
            [tf.keras.metrics.CategoricalAccuracy()],
        ],
    )

loss与metrics都是list的形式,按照output的顺序对应,而我们上面的模型中,已经给输出层进行了命名,则可以通过字典来制定loss与metrics,当输出超过2个的时候,尤其推荐字典的方式。同时,我们还可以使用loss_weights参数来给不同的输出指定权重,具体使用方法如下:

model.compile(
        optimizer=tf.keras.optimizers.RMSprop(1e-3),
        loss={
            "score_output": tf.keras.losses.MeanSquaredError(),
            "class_output": tf.keras.losses.CategoricalCrossentropy(),
        },
        metrics={
            "score_output": [
                tf.keras.metrics.MeanAbsolutePercentageError(),
                tf.keras.metrics.MeanAbsoluteError(),
            ],
            "class_output": [tf.keras.metrics.CategoricalAccuracy()],
        },
        loss_weights={"score_output": 2.0, "class_output": 1.0},
    )

除此之外,如果某些输出不为训练,只用来预测,则可以写成这样:

    # list的形式
    model.compile(
        optimizer=tf.keras.optimizers.RMSprop(1e-3),
        loss=[None, tf.keras.losses.CategoricalCrossentropy()],
    )

    # 或dict的形式
    model.compile(
        optimizer=tf.keras.optimizers.RMSprop(1e-3),
        loss={"class_output": tf.keras.losses.CategoricalCrossentropy()},
    )

fit

fit在接受多输入的时候,跟上面loss类似,可以使用numpy数组的list或者dict形式,如下所示:

    # 随机生成NumPy数据
    img_data = np.random.random_sample(size=(100, 32, 32, 3))
    ts_data = np.random.random_sample(size=(100, 20, 10))
    score_targets = np.random.random_sample(size=(100, 1))
    class_targets = np.random.random_sample(size=(100, 5))

    # list形式
    model.fit([img_data, ts_data], [score_targets, class_targets], batch_size=32, epochs=1)

    # 或者dict形式
    model.fit(
        {"img_input": img_data, "ts_input": ts_data},
        {"score_output": score_targets, "class_output": class_targets},
        batch_size=32,
        epochs=1,
    )

当然可以将数据转换成Dataset格式,然后传给fit,如下面代码所示:

    # dataset格式
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (
            {"img_input": img_data, "ts_input": ts_data},
            {"score_output": score_targets, "class_output": class_targets},
        )
    )
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

    model.fit(train_dataset, epochs=1)

参考

https://www.tensorflow.org/guide/keras/train_and_evaluate

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/556523.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java会被rust替代吗_自从尝了 Rust,Java 突然不香了

Rust 是软件行业中相对而言比较新的一门编程语言,如果从语法上来比较,该语言与 C 其实非常类似,但从另一方面而言,Rust 能更高效地提供许多功能来保证性能和安全。而且,Rust 还能在无需使用传统的垃圾收集系统的情况下…

redis单线程原理___Redis为何那么快-----底层原理浅析

redis单线程原理 redis单线程问题 单线程指的是网络请求模块使用了一个线程(所以不需考虑并发安全性),即一个线程处理所有网络请求,其他模块仍用了多个线程。 1. 为什么说redis能够快速执行 (1) 绝大部分请求是纯粹的内存操作…

asm 查看 数据文件 修改 时间_Oracle的ASM介绍及管理

Oracle的ASM介绍及管理Oracle经历过的文件系统历史操作系统--逻辑卷管理器(LVM):管理文件相对容易,性能较差裸设备:管理文件相对困难,性能好OCFS(Oracle Cluster File System):是ORACLE数据库文件系统ASM(Automatic Storage Manag…

深入理解 Redis Template及4种序列化方式__spring boot整合redis实现RedisTemplate三分钟快速入门

概述 使用Spring 提供的 Spring Data Redis 操作redis 必然要使用Spring提供的模板类 RedisTemplate, 今天我们好好的看看这个模板类 。 RedisTemplate 看看4个序列化相关的属性 ,主要是 用于 KEY 和 VALUE 的序列化 。 举个例子,比如说我们…

java仿聊天室项目总结_Java团队课程设计-socket聊天室(Day4总结篇)

Java团队课程设计-socket聊天室(Day4总结篇)团队名称:ChatRoom项目git地址:git提交记录(仅截取部分):面向对象设计包图、类图包图UML类图总结:首先总结一下这几天遇到的问题和解决方案使用ObjectInputStream/ObjectOutputStream的…

python基础代码技巧_Python 代码优化技巧(二)

Python 是一种脚本语言,相比 C/C 这样的编译语言,在效率和性能方面存在一些不足,但是可以通过代码调整来提高代码的执行效率。本文整理一些代码优化技巧。 代码优化基本原则代码正常运行后优化。 很多人一开始写代码就奔着性能优化的目标&…

rpm 讲解

CentOS7主要有rpm和yum这两种包软件的管理。两种包的管理各有用处,其中主要区别是:YUM使用简单但需要联网,YUM会去网上的YUM包源去获取所需要的软件包。而RPM的需要的操作经度比较细,需要我们做的事情比较多。 软件包的安装和卸是…

java顺序表冒泡排序_冒泡排序就这么简单 - Java3y的个人空间 - OSCHINA - 中文开源技术交流社区...

冒泡排序就这么简单在我大一的时候自学c语言和数据结构,我当时就接触到了冒泡排序(当时使用的是C语言编写的)。现在大三了,想要在暑假找到一份实习的工作,又要回顾一下数据结构与算法的知识点了。排序对我们来说是一点也不陌生了,…

python 多线程和协程结合_如何让 python 处理速度翻倍?内含代码

阿里妹导读:作为在日常开发生产中非常实用的语言,有必要掌握一些python用法,比如爬虫、网络请求等场景,很是实用。但python是单线程的,如何提高python的处理速度,是一个很重要的问题,这个问题的…

python批量生成图_利用Python批量生成任意尺寸的图片

实现效果 通过源图片,在当前工作目录的/img目录下生成1000张,分别从1*1到1000*1000像素的图片。 效果如下:目录结构 实现示例 # -*- coding: utf-8 -*- import threading from PIL import Image image_size range(1, 1001) def start(): for…

Mysql 如果有多个可选条件怎么加索引_MySQL|mysql-索引

1、索引是什么 1.1索引简介 索引是表的目录,是数据库中专门用于帮助用户快速查询数据的一种数据结构。类似于字典中的目录,查找字典内容时可以根据目录查找到数据的存放位置,以及快速定位查询数据。对于索引,会保存在额外的文件…

Spring-bean的循环依赖以及解决方式___Spring源码初探--Bean的初始化-循环依赖的解决

本文主要是分析Spring bean的循环依赖,以及Spring的解决方式。 通过这种解决方式,我们可以应用在我们实际开发项目中。 什么是循环依赖?怎么检测循环依赖Spring怎么解决循环依赖Spring对于循环依赖无法解决的场景Spring解决循环依赖的方式我们…

Spring中bean的作用域与生命周期

在Spring中,那些组成应用程序的主体及由Spring IoC容器所管理的对象,被称之为bean。简单地讲,bean就是由IoC容器初始化、装配及管理的对象,除此之外,bean就与应用程序中的其他对象没有什么区别了。而bean的定义以及bea…

Spring循环依赖的三种方式

引言:循环依赖就是N个类中循环嵌套引用,如果在日常开发中我们用new 对象的方式发生这种循环依赖的话程序会在运行时一直循环调用,直至内存溢出报错。下面说一下Spring是如果解决循环依赖的。 第一种:构造器参数循环依赖 Spring容…

Spring 是如何解决循环依赖的?

1.由同事抛的一个问题开始 最近项目组的一个同事遇到了一个问题,问我的意见,一下子引起的我的兴趣,因为这个问题我也是第一次遇到。平时自认为对spring循环依赖问题还是比较了解的,直到遇到这个和后面的几个问题后,重…

java按钮触发另一个页面_前端跨页面通信,你知道哪些方法?

戳蓝字「前端技术优选」关注我们哦! 引言在浏览器中,我们可以同时打开多个Tab页,每个Tab页可以粗略理解为一个“独立”的运行环境,即使是全局对象也不会在多个Tab间共享。然而有些时候,我们希望能在这些“独立”的Tab页…

【Java用法】java 8两个List集合取交集、并集、差集、去重并集

在业务的开发过程中会经常用到两个List集合相互取值的情况&#xff0c;于是记录在此&#xff0c;方便后续使用哦~~~ public class ListTest {public static void main(String[] args) {ArrayList<String> listA CollectionUtil.toList("a", "b", &…

jsonp react 获取返回值_Django+React全栈开发:文章列表

React现在我们有了一个属于文章的API&#xff0c;可以添加、修改、删除、查看文章&#xff0c;但是对于我们的网站来说&#xff0c;还需要一个用户界面才行。现在开始探索一下ReactJS吧。经常听到有前端三大框架Angular、React、Vue的说法&#xff0c;不过React官网对自己的介绍…

24个经典的MySQL索引问题,你都遇到过哪些?

1、什么是索引&#xff1f; 索引是一种特殊的文件(InnoDB数据表上的索引是表空间的一个组成部分)&#xff0c;它们包含着对数据表里所有记录的引用指针。 索引是一种数据结构。数据库索引&#xff0c;是数据库管理系统中一个排序的数据结构&#xff0c;以协助快速查询、更新数…

java 3 4_Java-3/4_树.md at master · yrcDream/Java-3 · GitHub

树二叉树二叉树具有唯一根节点二叉树每个节点最多有两个孩子&#xff0c;最多有一个父亲二叉树具有天然递归结构二叉树不一定是 “满” 的&#xff1a;一个节点也是二叉树、空节点也是二叉树二叉搜索树(BST)BST 的基本功能public class BST> {private Node root;private int…