CNN/DailyMail训练文本摘要模型

要使用 TensorFlow Datasets (TFDS) 来训练一个文本摘要模型,可以选择一个包含文章和摘要的数据集,例如 CNN/DailyMail 数据集。

这个数据集通常用于训练和评估文本摘要模型。

以下是使用 TFDS 加载数据集并训练一个简单的序列到序列 (seq2seq) 模型的过程。

首先,确保安装了 TensorFlow Datasets:

pip install tensorflow tensorflow-datasets

然后,以下是训练文本摘要模型的完整代码:

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.models import Model
from tensorflow.keras.layers import TextVectorization, Embedding, LSTM, Dense# 加载 CNN/DailyMail 数据集
data, info = tfds.load('cnn_dailymail', with_info=True, as_supervised=True)
train_data, val_data = data['train'], data['validation']# 为了加快演示,我们将只使用一小部分数据
train_data = train_data.take(5000)
val_data = val_data.take(1000)# 定义文本向量化和序列长度
sequence_length = 512
vocab_size = 20000
vectorize_layer = TextVectorization(max_tokens=vocab_size, output_mode='int', output_sequence_length=sequence_length)# 准备数据集
def prepare_dataset(data):articles = data.map(lambda article, summary: article)summaries = data.map(lambda article, summary: summary)vectorize_layer.adapt(articles)vectorized_articles = articles.map(lambda x: vectorize_layer(x))vectorized_summaries = summaries.map(lambda x: vectorize_layer(x))dataset = tf.data.Dataset.zip((vectorized_articles, vectorized_summaries)).batch(32).prefetch(tf.data.AUTOTUNE)return datasettrain_dataset = prepare_dataset(train_data)
val_dataset = prepare_dataset(val_data)# 构建一个简单的 seq2seq 模型
embedding_dim = 128
lstm_units = 256# 编码器
encoder_inputs = tf.keras.Input(shape=(None,), dtype='int64')
encoder_embedding = Embedding(vocab_size, embedding_dim)(encoder_inputs)
_, state_h, state_c = LSTM(lstm_units, return_state=True)(encoder_embedding)
encoder_states = [state_h, state_c]# 解码器
decoder_inputs = tf.keras.Input(shape=(None,), dtype='int64')
decoder_embedding = Embedding(vocab_size, embedding_dim)(decoder_inputs)
decoder_lstm = LSTM(lstm_units, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_embedding, initial_state=encoder_states)
decoder_dense = Dense(vocab_size, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)# 定义模型
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)# 编译模型
model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 训练模型
model.fit(train_dataset, epochs=10, validation_data=val_dataset)# 使用模型进行文本摘要
def summarize_text(text, model, vectorize_layer):vectorized_text = vectorize_layer(tf.convert_to_tensor([text]))summary = tf.constant([vectorize_layer.vocab_size - 1], dtype=tf.int64)  # 使用序列结束标记开始for _ in range(sequence_length):predictions = model.predict([vectorized_text, tf.expand_dims(summary, 0)])predicted_id = tf.argmax(predictions[0, -1, :])if predicted_id == 0:breaksummary = tf.concat([summary, [predicted_id]], axis=0)return vectorize_layer.get_vocabulary()[summary.numpy()]# 测试摘要
for article, summary in val_data.take(1):print('原始文章:', article.numpy().decode('utf-8'))print('真实摘要:', summary.numpy().decode('utf-8'))predicted_summary = summarize_text(article.numpy().decode('utf-8'), model, vectorize_layer)print('预测摘要:', ' '.join(predicted_summary))

这段代码做了如下几件事情:

  1. 加载 CNN/DailyMail 数据集,并选择了一小部分数据以便快速演示。
  2. 定义了文本向量化层,用于将文本转换成整数序列。
  3. 准备了训练和验证数据集,应用了文本向量化,并进行了批处理。
  4. 构建了一个简单的 seq2seq 模型,包含了一个编码器和一个解码器,两者都使用了 LSTM 层。
  5. 编译并训练了模型。
  6. 定义了一个函数 summarize_text,用于生成文本摘要。

针对 CNN/DailyMail 数据集的预处理部分的代码,以及每一步的解释:

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import TextVectorization# 加载 CNN/DailyMail 数据集
data = tfds.load('cnn_dailymail', as_supervised=True)
train_data, val_data = data['train'], data['validation']# 定义文本向量化的参数
sequence_length = 512
vocab_size = 20000# 创建一个文本向量化层
vectorize_layer = TextVectorization(max_tokens=vocab_size,  # 设置最大的词汇量output_mode='int',  # 设置输出模式为整数索引output_sequence_length=sequence_length  # 设置输出的序列长度
)# 准备数据集的函数
def prepare_dataset(data):# 将数据集分为文章和摘要articles = data.map(lambda article, summary: article)summaries = data.map(lambda article, summary: summary)# 适应文本向量化层,只对文章进行适应以构建词汇表vectorize_layer.adapt(articles)# 将文章和摘要映射到整数序列vectorized_articles = articles.map(lambda x: vectorize_layer(x))vectorized_summaries = summaries.map(lambda x: vectorize_layer(x))# 将文章和摘要的整数序列打包成一个新的数据集,并进行批处理和预取dataset = tf.data.Dataset.zip((vectorized_articles, vectorized_summaries))dataset = dataset.batch(32).prefetch(tf.data.AUTOTUNE)return dataset# 应用预处理函数到训练和验证数据集
train_dataset = prepare_dataset(train_data)
val_dataset = prepare_dataset(val_data)

解释:

  1. 加载数据集
    使用 tfds.load 函数加载 CNN/DailyMail 数据集。参数 as_supervised=True 表示我们希望以监督学习的格式加载数据集,即每个数据点都包含输入数据(文章)和标签数据(摘要)。

  2. 定义文本向量化参数
    设置序列长度和词汇量的大小。这些参数对于模型处理文本数据非常重要。sequence_length 确定了模型可以处理的最大文章和摘要长度。vocab_size 决定了词汇表的大小,即模型可以识别的不同单词的最大数量。

  3. 创建文本向量化层
    TextVectorization 层用于将文本转换为整数序列。每个整数都对应词汇表中的一个单词。这一步是将自然语言转换为机器学习模型可以处理的格式的关键步骤。

  4. 预处理数据集的函数
    prepare_dataset 函数负责将原始文本数据集转换为模型可以使用的格式。它首先将数据集分为文章和摘要,然后使用 vectorize_layer.adapt 方法来适应(即构建)词汇表。随后,它将文章和摘要映射到整数序列。

  5. 批处理和预取
    batch(32) 方法将数据集划分为大小为 32 的批次,这意味着模型将一次处理 32 篇文章及其相应的摘要。prefetch(tf.data.AUTOTUNE) 方法用于提前准备好接下来的数据批次,这样在模型训练时可以减少 I/O 阻塞,提高训练效率。

  6. 应用预处理
    prepare_dataset 函数被应用到训练和验证数据集上,这样我们就得到了可以直接用于模型训练和评估的数据集。

这个预处理过程是为了简化示例而设定的,并且假设模型是一个基础的 seq2seq 模型。

在实际应用中,您可能需要更复杂的预处理步骤,例如对文本进行清洗、使用子词分词(subword tokenization)等。

为了使模型训练过程中自动保存最佳模型,我们可以使用 ModelCheckpoint 回调。这个回调会在每个训练周期(epoch)结束时运行,并根据我们指定的条件(如验证集上的损失或准确率)保存模型。下面是如何设置 ModelCheckpoint 回调并将其添加到模型训练中的示例:

首先,导入所需的库并设置 ModelCheckpoint 回调:

from tensorflow.keras.callbacks import ModelCheckpoint# 设置模型检查点回调,保存最佳模型
checkpoint_path = "seq2seq_checkpoint.ckpt"
checkpoint_callback = ModelCheckpoint(filepath=checkpoint_path,save_weights_only=True,save_best_only=True,monitor='val_loss',  # 也可以是 'val_accuracy',取决于你想监控的指标mode='min',  # 如果监控的是 'val_loss',则模式是 'min',即越小越好verbose=1  # 打印保存模型的信息
)

接着,将 checkpoint_callback 添加到 fit 方法的 callbacks 参数中:

# 训练模型
model.fit(train_dataset,epochs=10,validation_data=val_dataset,callbacks=[checkpoint_callback]  # 添加回调函数
)

现在,模型会在每个训练周期结束时自动检查验证损失,并在出现更低的验证损失时保存权重。save_weights_only=True 表示只保存模型的权重,而不是整个模型。这样可以节省存储空间,但需要在加载权重时重建模型结构。

如果你想在训练后恢复模型的权重,你可以使用以下代码:

# 假设模型结构已经定义并编译,然后加载权重
model.load_weights(checkpoint_path)

完成这些之后,你可以使用 summarize_text 函数对新文章进行摘要,或者评估模型在验证集上的表现。这里是完整的代码,包含了自动保存回调的设置:

# 定义模型检查点回调
checkpoint_path = "seq2seq_checkpoint.ckpt"
checkpoint_callback = ModelCheckpoint(filepath=checkpoint_path,save_weights_only=True,save_best_only=True,monitor='val_loss',mode='min',verbose=1
)# 训练模型
model.fit(train_dataset,epochs=10,validation_data=val_dataset,callbacks=[checkpoint_callback]  # 添加回调函数
)# 在需要时加载最佳模型权重
model.load_weights(checkpoint_path)# 使用模型进行文本摘要
# ...

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/626179.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

边缘计算的挑战和机遇(结合RDH-EI)

边缘计算的挑战和机遇 边缘计算面临着数据安全与隐私保护、网络稳定性等挑战,但同时也带来了更强的实时性和本地处理能力,为企业降低了成本和压力,提高了数据处理效率。因此,边缘计算既带来了挑战也带来了机遇,需要我…

虚拟机安装openjdk 输入 javac 报错 javac -version bash: javac: 未找到命令... 相似命令是: ‘java‘

问题 [root#localhost java-1.8.0-openjdk-1.8.0.262.b10-1.el7.x86_64]# javac -version bash: javac: 未找到命令... 相似命令是: java [root#localhost java-1.8.0-openjdk-1.8.0.262.b10-1.el7.x86_64]# java -version openjdk version "1.8.0_392" …

JSP-简化

一、引入 在介绍完JSP的概念之后,我们深感JSP页面开发的困难,但是JSP并非没有做丝毫的努力。本篇将介绍JSP的两个快捷开发手段:EL表达式与JSTL。我们已知道JSP的复杂是来自于将HTML代码与Java代码揉杂在一起,于是我们通过上述的两…

Unity之物理系统

专栏的上一篇角色控制器控制角色移动跳崖,这一篇来说说Unity的物理系统。 本篇小编还要带大家做一个碰撞检测效果实例,先放效果图:流星撞击地面产生爆炸效果 一、Rigidbody 我们给胶囊添加了 Rigidbody 组件它才有的重力,我们来…

逸学Docker【java工程师基础】3.3Docker安装nacos

docker pull nacos/nacos-server docker network create nacos_network #创建容器网络 docker run -d \ --name nacos \ --privileged \ --cgroupns host \ --env JVM_XMX256m \ --env MODEstandalone \ --env JVM_XMS256m \ -p 8848:8848/tcp \ -p 9848:9848…

图解拒付平台:如何应对用户的拒付

这是《百图解码支付系统设计与实现》专栏系列文章中的第(5)篇。 本章主要讲清楚支付系统中拒付涉及的基本概念,产品架构、系统架构,以及一些核心的流程和相关领域模型、状态机设计等。 1. 前言 拒付在中国比较少见,但…

教你用五步让千年的兵马俑跳上现代的科目三?

以下是一张我上月去西安拍的兵马俑照片: 使用通义千问,5步就能它舞动起来,跳上现在流行的“科目三”舞蹈。 千年兵马俑跳上科目三 全民舞王 第1步 打开通义千问App,我使用的是华为手机,苹果版的没试; 在…

【算法题】53. 最大子数组和

题目 给你一个整数数组 nums ,请你找出一个具有最大和的连续子数组(子数组最少包含一个元素),返回其最大和。 子数组 是数组中的一个连续部分 示例 1: 输入:nums [-2,1,-3,4,-1,2,1,-5,4] 输出&#x…

把握现货黄金的基本操作技巧

在投资市场这个大舞台上,有各种各样的投资产品供投资者选择,其中黄金作为一种重要的投资资产,一直受到广大投资者的青睐。然而,黄金交易并非易事,需要掌握一定的操作技巧。那么,如何才能把握住现货黄金的操…

vue3-条件渲染

条件渲染 v-if v-if 指令用于条件性地渲染一块内容。这块内容只会在指令的表达式返回真值时才被渲染。 <h1 v-if"awesome">Vue is awesome!</h1>v-else 你也可以使用 v-else 为 v-if 添加一个“else 区块”。 <button click"awesome !awesom…

YOLOv5改进 | 主干篇 | 12月份最新成果TransNeXt特征提取网络(全网首发)

一、本文介绍 本文给大家带来的改进机制是TransNeXt特征提取网络,其发表于2023年的12月份是一个最新最前沿的网络模型&#xff0c;将其应用在我们的特征提取网络来提取特征&#xff0c;同时本文给大家解决其自带的一个报错&#xff0c;通过结合聚合的像素聚焦注意力和卷积GLU&…

CSS-Flex布局

文章目录 一、Flex布局总结 一、Flex布局 Flex布局是一种弹性盒子布局&#xff0c;适用于构建响应式的页面布局 Flex布局是一种弹性盒子布局&#xff0c;适用于构建响应式的页面布局。以下是一些Flex布局的技巧&#xff1a; 使用flex属性设置弹性容器的布局方式&#xff0c;常…

《绝地求生》职业选手画面设置推荐 绝地求生画面怎么设置最好

《绝地求生》画面怎么设置最好是很多玩家心中的疑问&#xff0c;如果性能不是问题无疑高特效显示效果更好&#xff0c;但并不是所有画面参数都利于战斗&#xff0c;今天闲游盒带来分享的《绝地求生》职业选手画面设置推荐&#xff0c;赶紧来看看吧。 当前PUBG的图像设置的重要性…

彝族民居一大特色——土掌房

彝族民居一大特色——土掌房在彝区&#xff0c;各地、各支系传承的居室建筑形式是多种多样的&#xff0c;并与当地的居住习俗有密切关联&#xff0c;从村寨的聚落到住宅的地址&#xff1b;从房间的分置到什物的堆放&#xff1b;从建筑结构到民居信仰和禁忌&#xff0c;都表现出…

Day 24 回溯算法 1

77. 组合 代码随想录 1. 思路 典型的回溯算法&#xff0c;分为以下几步&#xff1a; &#xff08;1&#xff09;确定递归函数 这里递归函数就是每一层的遍历&#xff0c;起名为backtrace。这里遍历需要用for循环的起始终止位置&#xff0c;因此参数为n和k &#xff08;2&am…

Linux知识(未完成)

一、Linux 1.1 Linux 的应用领域 1.1.1 个人桌面领域的应用 此领域是 Linux 比较薄弱的环节但是随着发展&#xff0c;近几年 linux 在个人桌面领域的占有率在逐渐提高 1.1.2 服务器领域 linux 在服务器领域的应用是最高的 linux 免费、稳定、高效等特点在这里得到了很好的…

2024年1月15日

1、桌面应用用到系统本身api 1. 文件系统&#xff08;File System&#xff09;&#xff1a; 使用 Node.js 的 fs 模块来进行文件系统操作&#xff0c;读写文件&#xff0c;创建文件夹等。 2. 操作系统信息&#xff08;Operating System Information&#xff09;&#xff1a; 使…

探寻爬虫世界01:HTML页面结构

文章目录 一、引言&#xff08;一&#xff09;背景介绍&#xff1a;选择爬取51job网站数据的原因&#xff08;二&#xff09;目标与需求明确&#xff1a;爬取51job网站数据的目的与用户需求 二、网页结构探索&#xff08;一&#xff09;51job网页结构分析1、页面组成&#xff1…

【Nuxt3】nuxt3目录文件详情描述:.nuxt、.output、assets、public、utils(一)

简言 nuxt3的中文网站 上次简单介绍了nuxt3创建项目的方法和目录文件大概用处。 这次详细说下.nuxt、.output、assets、public、utils五个文件夹的用处。 正文 .nuxt Nuxt在开发中使用.nuxt/目录来生成你的Vue应用程序。 为了避免将开发构建的输出推送到你的代码仓库中&…

如何在原型中实现继承和多态

在JavaScript中&#xff0c;我们可以通过原型链来实现继承。以下是如何在原型中实现继承的例子&#xff1a; // 定义一个动物原型 var Animal function() {}; Animal.prototype.move function() { console.log(‘This animal can move.’); }; // 定义一个狗的原型&#xf…