Day10【基于encoder- decoder架构实现新闻文本摘要的提取】

实现新闻文本摘要的提取

      • 1. 概述与背景
      • 2.参数配置
      • 3.数据准备
      • 4.数据加载
      • 5.主程序
      • 6.预测评估
      • 7.生成效果
      • 8.总结

1. 概述与背景

新闻摘要生成是自然语言处理(NLP)中的一个重要任务,其目标是自动从长篇的新闻文章中提取出简洁、准确的摘要。近年来,基于深度学习的摘要生成方法已成为主流,尤其是采用 Encoder-Decoder 架构的模型。这个架构在机器翻译、文本摘要、文档标注、多模态交互等领域取得了显著的效果。

本文基于现有数据集,先将输入的新闻文本数据和对应的标题摘要在已知词表上序列化,然后将序列化后的输入索引数据(作为输入文本数据)和标签索引数据(作为生成式文本摘要标签)共同输入到Encoder-Decoder模型架构中得到输出预测的文本摘要数据,之后将输出的预测文本摘要数据以及另一份标签索引数据(作为真实的文本标签)两者使用交叉熵损失函数计算loss,最后反向传播更新梯度。

2.参数配置

config.py

# -*- coding: utf-8 -*-"""
配置参数信息
"""
import os
import torchConfig = {"model_path": "output","input_max_length": 120,"output_max_length": 30,"epoch": 200,"batch_size": 32,"optimizer": "adam","learning_rate":1e-3,"seed":42,"vocab_size":6219,"vocab_path":"vocab.txt","train_data_path": r"sample_data.json","valid_data_path": r"sample_data.json","beam_size":5}

3.数据准备

词表文件vocab.txt词表文件
新闻文本数据训练和验证数据

4.数据加载

loader.py

# -*- coding: utf-8 -*-import json
import torch
from torch.utils.data import DataLoader
"""
数据加载
"""class DataGenerator:def __init__(self, data_path, config, logger):self.config = configself.logger = loggerself.path = data_pathself.vocab = load_vocab(config["vocab_path"])self.config["vocab_size"] = len(self.vocab)self.config["pad_idx"] = self.vocab["[PAD]"]self.config["start_idx"] = self.vocab["[CLS]"]self.config["end_idx"] = self.vocab["[SEP]"]self.load()def load(self):self.data = []with open(self.path, encoding="utf8") as f:for i, line in enumerate(f):line = json.loads(line)title = line["title"]content = line["content"]self.prepare_data(title, content)return#文本到对应的index#头尾分别加入[cls]和[sep]def encode_sentence(self, text, max_length, with_cls_token=True, with_sep_token=True):input_id = []if with_cls_token:input_id.append(self.vocab["[CLS]"])for char in text:input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))if with_sep_token:input_id.append(self.vocab["[SEP]"])input_id = self.padding(input_id, max_length)return input_id#补齐或截断输入的序列,使其可以在一个batch内运算def padding(self, input_id, length):input_id = input_id[:length]input_id += [self.vocab["[PAD]"]] * (length - len(input_id))return input_id#输入输出转化成序列def prepare_data(self, title, content):input_seq = self.encode_sentence(content, self.config["input_max_length"], False, False) #输入序列output_seq = self.encode_sentence(title, self.config["output_max_length"], True, False) #输出序列gold = self.encode_sentence(title, self.config["output_max_length"], False, True) #不进入模型,用于计算lossself.data.append([torch.LongTensor(input_seq),torch.LongTensor(output_seq),torch.LongTensor(gold)])returndef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]def load_vocab(vocab_path):token_dict = {}with open(vocab_path, encoding="utf8") as f:for index, line in enumerate(f):token = line.strip()token_dict[token] = indexreturn token_dict#用torch自带的DataLoader类封装数据
def load_data(data_path, config, logger, shuffle=True):dg = DataGenerator(data_path, config, logger)dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)return dl

输入数据和标签的编码主要通过 encode_sentence 方法实现。具体来说,输入数据(如新闻内容)和标签(如新闻标题)都需要转化为对应的索引序列,以便供模型进行训练。编码过程如下:

  1. 输入数据(content)编码encode_sentence 方法将新闻内容转换为词汇表中的索引序列。首先,如果需要,添加 [CLS] 标记作为序列的开始,然后遍历文本中的每个字符,将其映射为词汇表中的索引,如果词汇表中没有该字符,则使用 [UNK](未知词)表示。最后,如果需要,添加 [SEP] 标记作为序列的结束。生成的索引序列会通过 padding 方法填充或截断至预设的最大长度。

  2. 标签数据(title)编码:标签(即标题)也会通过 encode_sentence 方法进行编码,步骤与输入数据类似,因为标题是需要预测生成表示要输出的序列,因此会包含 [CLS] 标记作为开头,不包含 [SEP],以区分输入和输出。

  3. 计算损失的 gold 序列:在训练中,为了计算损失,gold 序列会与输出序列相似,作为真实的标签,在它后面包含 [SEP] 标记和输出序列对齐,作为模型训练时的目标序列。

  4. 生成解码过程:模型训练完毕后,Decoder会根据输入的Encoder编码向量及输出序列的第一个标记CLS输出第一个预测的token,根据输入的Encoder编码向量及输出序列(第一个标记CLS+生成的前一个token)输出第二个预测token,之后再根据输入的Encoder编码向量及输出序列(第一个标记CLS+生成的前2个token)输出第三个预测token,以此类推。直到输出最后一个预测的tokenSEP时,生成解码过程结束。

  5. 在这里插入图片描述
    通过这样的编码方式,输入数据和标签数据被转化为整数索引序列,并进行填充或截断,以确保它们具有相同的长度,从而可以批量处理并输入到模型进行训练。

5.主程序

# -*- coding: utf-8 -*-
import sys
import torch
import random
import os
import numpy as np
import time
import logging
import json
from config import Config
from evaluate import Evaluator
from loader import load_data#这个transformer是本文件夹下的代码,和我们之前用来调用bert的transformers第三方库是两回事
from transformer.Models import Transformerlogging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)"""
模型训练主程序
"""# seed = Config["seed"]
# random.seed(seed)
# np.random.seed(seed)
# torch.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)def choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["learning_rate"]if optimizer == "adam":return torch.optim.Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return torch.optim.SGD(model.parameters(), lr=learning_rate)def main(config):#创建保存模型的目录if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])#加载模型logger.info(json.dumps(config, ensure_ascii=False, indent=2))model = Transformer(config["vocab_size"], config["vocab_size"], 0, 0,d_word_vec=128, d_model=128, d_inner=256,n_layers=1, n_head=2, d_k=64, d_v=64,)# 标识是否使用gpucuda_flag = torch.cuda.is_available()if cuda_flag:logger.info("gpu可以使用,迁移模型至gpu")model = model.cuda()#加载优化器optimizer = choose_optimizer(config, model)# 加载训练数据train_data = load_data(config["train_data_path"], config, logger)#加载效果测试类evaluator = Evaluator(config, model, logger)#加载lossloss_func = torch.nn.CrossEntropyLoss(ignore_index=0)#训练for epoch in range(config["epoch"]):epoch += 1model.train()if cuda_flag:model.cuda()logger.info("epoch %d begin" % epoch)train_loss = []for index, batch_data in enumerate(train_data):if cuda_flag:batch_data = [d.cuda() for d in batch_data]input_seq, target_seq, gold = batch_datapred = model(input_seq, target_seq)loss = loss_func(pred, gold.view(-1))train_loss.append(float(loss))loss.backward()optimizer.step()optimizer.zero_grad()logger.info("epoch average loss: %f" % np.mean(train_loss))evaluator.eval(epoch)model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)torch.save(model.state_dict(), model_path)returnif __name__ == "__main__":main(Config)

主程序主要实现了基于Transformer架构的模型训练过程。在训练过程中,首先通过配置文件Config获取相关参数,并根据配置创建一个Transformer模型。训练过程在指定的轮次(epoch)内进行,每一轮开始时,首先设定模型为训练模式。接着,对于每个训练批次,输入数据(input_seq)、目标序列(target_seq)和真实标签(gold)被送入模型中进行前向传播,计算出模型预测值(pred)。通过交叉熵损失函数(CrossEntropyLoss)与真实标签进行对比,得到当前批次的损失。损失值会被累积并进行反向传播(loss.backward()),优化器更新参数(optimizer.step()),并清空梯度缓存(optimizer.zero_grad())。每一轮训练结束后,打印出平均损失值并进行模型效果评估。

6.预测评估

evaluate.py

# -*- coding: utf-8 -*-
from loader import load_data
from collections import defaultdict
from transformer.Translator import Translator"""
模型效果测试
"""class Evaluator:def __init__(self, config, model, logger):self.config = configself.model = modelself.logger = loggerself.valid_data = load_data(config["valid_data_path"], config, logger, shuffle=False)self.reverse_vocab = dict([(y, x) for x, y in self.valid_data.dataset.vocab.items()])self.translator = Translator(self.model,config["beam_size"],config["output_max_length"],config["pad_idx"],config["pad_idx"],config["start_idx"],config["end_idx"])def eval(self, epoch):self.logger.info("开始测试第%d轮模型效果:" % epoch)self.model.eval()self.model.cpu()self.stats_dict = defaultdict(int)  # 用于存储测试结果for index, batch_data in enumerate(self.valid_data):input_seqs, target_seqs, gold = batch_datafor input_seq in input_seqs:generate = self.translator.translate_sentence(input_seq.unsqueeze(0))print("输入:", self.decode_seq(input_seq))print("输出:", self.decode_seq(generate))breakreturndef decode_seq(self, seq):pre_seq = []for idx in seq:if idx < 6 :continuechar = self.reverse_vocab[int(idx)]pre_seq.append(char)return "".join(pre_seq)

在模型的评估过程中,验证集数据被加载并逐批传入模型进行推理。每一批数据中的输入序列通过 Translator 进行翻译,生成相应的预测输出。预测过程通常涉及使用模型的前向传播,将输入序列转化为目标语言的输出。为了评估模型效果,生成的输出是通过索引序列的方式进行表示,而这些索引随后会被映射回具体的词汇,通过反向词汇表解码为可读的文本。每次翻译后,模型的输入和生成的输出都会被打印出来,以便进行直观的对比。通过反复的测试与评估,能够逐步提高模型的准确性和生成质量。

7.生成效果

训练200轮效果:

2025-04-19 12:44:56,206 - __main__ - INFO - epoch 200 begin
2025-04-19 12:44:57,086 - __main__ - INFO - epoch average loss: 0.416101
2025-04-19 12:44:57,086 - __main__ - INFO - 开始测试第200轮模型效果:
输入: 阿根廷布宜诺斯艾利斯省奇尔梅斯市一服装店,8个月内被抢了三次。最后被抢劫的经历,更是直接让老板心理崩溃:歹徒在抢完不久后发现衣服“抢错了尺码”,理直气壮地拿着衣服到店里换,老板又不敢声张,只好忍气吞声。(中国新闻网)
输出: 阿根廷歹徒抢服装尺码不对拿回店里换
输入: 就俄罗斯免费医疗话题,国家卫生计生委国际司司长任明辉表示,真正的免费医疗制度不存在。或由税收支持,或个人和企业支付的医疗保险社会保险解决。免费医疗国家的患者看病不花钱,费用在各种税收或缴纳的保险中体现了。(网图)
输出: 卫生计生委国际司司长:真正的免费医疗不存在
输入: 6月合格境外机构投资者(QFII)加快入市步伐。据中登公司发布的20136月份统计月报显示,QFII基金6月份在沪深两市分别新增开户1415个A股股票账户,这29个账户让QFII在沪深两市的总账户数达到465个。
输出: 6月QFII积极入市新增开户户9户
输入: 路透社消息,一艘从利比亚横渡地中海开往意大利的偷渡船倾覆,约400人身亡。船上载有550多名偷渡客,许多是年轻人和儿童,大部分来自撒哈拉以南非洲地区。事发后意大利海防部队展开搜救,获救的150人被送往意大利南部港口。
输出: 从利比亚开往意大利:400偷渡客沉船身亡

8.总结

本文实现了一个基于 Transformer Encoder-Decoder 架构的新闻摘要生成系统。通过使用词汇表将输入数据和目标输出数据转化为索引序列,并通过交叉熵损失函数训练模型,模型通过 Beam Search 解码生成摘要。训练过程中使用了多轮的模型评估和优化,使得最终模型能够生成简洁、准确的新闻摘要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/77416.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【大疆dji】ESDK开发环境搭建(软件准备篇)

接上一篇【大疆dji】ESDK开发环境搭建&#xff08;硬件准备篇&#xff09; 1. 编译环境 ESDK 提供 x86_64/aarch64 基于 Linux 平台 Ubuntu 发行版操作系统构建的静态库&#xff0c;运行 demo 先正确安装所需的依赖包。arm32位就不支持了。建议使用编译安装的方式&#xff0c;…

Java数据结构——ArrayList

Java中ArrayList 一 ArrayList的简介二 ArrayList的构造方法三 ArrayList常用方法1.add()方法2.remove()方法3.get()和set()方法4.index()方法5.subList截取方法 四 ArrayList的遍历for循环遍历增强for循环(for each)迭代器遍历 ArrayList问题及其思考 前言 ArrayList是一种 顺…

【信息获取能力】

第一层&#xff1a;表象观察 现象&#xff1a;AI系统&#xff08;如GPT-4&#xff09;可以瞬间调用并整合全球互联网上的公开信息&#xff0c;而人类即使穷尽一生也无法完成同等规模的知识储备。 底层逻辑&#xff1a; 存储与检索效率&#xff1a;人类大脑的记忆容量有限&…

03、GPIO外设(三):标准库代码示例

标准库代码示例 1、点亮LED2、LED闪烁3、LED流水灯4、按键控制LED5、蜂鸣器 本章源代码链接&#xff1a; 链接: link 1、点亮LED 实验要求&#xff1a;点亮LED ①LED.c文件的代码如下: #include "LED.h"/*** LED引脚初始化*//* 定义数组&#xff0c;想要添加引脚…

卷积神经网络(CNN)与VGG16在图像识别中的实验设计与思路

卷积神经网络&#xff08;CNN&#xff09;与VGG16在图像识别中的实验设计与思路 以下从基础原理、VGG16架构解析、实验设计步骤三个层面展开说明&#xff0c;结合代码示例与关键参数设置&#xff0c;帮助理解其应用逻辑。 一、CNN与VGG16的核心差异 基础CNN结构 通常包含33~55个…

java导出word含表格并且带图片

背景 我们需要通过 Java 动态导出 Word 文档&#xff0c;基于预定义的 模板文件&#xff08;如 .docx 格式&#xff09;。模板中包含 表格&#xff0c;程序需要完成以下操作&#xff1a; 替换模板中的文本&#xff08;如占位符 ${设备类型} 等&#xff09;。 替换模板中的图…

Oracle19C低版本一天遭遇两BUG(ORA-04031/ORA-600)

昨天帮朋友看一个系统异常卡顿的案例&#xff0c;在这里分享给大家 环境&#xff1a;Exadata X8M 数据库版本19.11 1.系统报错信息 表象为系统卡顿&#xff0c;页面无法刷出&#xff0c;登陆到主机上看到节点1 系统等待存在大量的 cursor: pin S wait on X等待 查看两个节…

2025年Q1数据安全政策、规范、标准以及报告汇总共92份(附下载)

一、政策演进趋势分析 &#xff08;一&#xff09;国家级政策新动向 数据要素市场建设 数据流通安全治理方案&#xff08;重点解析数据确权与交易规则&#xff09; 公共数据授权运营规范&#xff08;创新性提出分级授权机制&#xff09; 新兴技术安全规范 人工智能安全标准…

ERR_PNPM_DLX_NO_BIN No binaries found in tailwindcss

场景复现&#xff1a; 最近在vue3项目中安装了tailwindcss&#xff0c;但是它默认帮我安装的版本是4XX的&#xff0c;导致我执行 npx tailwindcss init -p报错了。 解决方案&#xff1a; 更改tailwindcss的版本为3 pnpm add -D tailwindcss3再次执行生成tailwindcss的初始…

第 4 篇:Motion 拖拽与手势动画(交互篇)—— 打造直觉化交互体验

Framer Motion 的拖拽与手势系统让实现复杂交互变得异常简单。本文将深入解析核心 API&#xff0c;并通过实战案例演示如何创造自然流畅的交互体验。 &#x1f9f2; 拖拽动画基础 1. 启用拖拽 使用 drag 属性即可开启拖拽能力。支持的值有&#xff1a;true&#xff08;全方向…

CF148D Bag of mice

题目传送门 思路 状态设计 设 d p i , j dp_{i, j} dpi,j​ 表示袋中有 i i i 个白鼠和 j j j 个黑鼠时&#xff0c; A A A 能赢的概率。 状态转移 现在考虑抓鼠情况&#xff1a; A A A 抓到白鼠&#xff1a;直接判 A A A 赢&#xff0c;概率是 i i j \frac{i}{i j}…

BT1120 BT656驱动相关代码示例

前些年做视频输出项目的时候用过bt1120 tx与rx模块&#xff0c;现将部分代码进行记录整理。代码功能正常&#xff0c;可正常应用。 1. rx部分&#xff1a; /****************************************************************************** Copyright (C) 2021,All rights …

服务器简介(含硬件外观接口介绍)

服务器&#xff08;Server&#xff09;是指提供资源、服务、数据或应用程序的计算机系统或设备。它通常比普通的个人计算机更强大、更可靠&#xff0c;能够长时间无间断运行&#xff0c;支持多个用户或客户端的请求。简单来说&#xff0c;服务器就是专门用来存储、管理和提供数…

SQL-exists和in核心区别​、 性能对比​、适用场景​

EXISTS和IN的基本区别。IN用于检查某个值是否在子查询返回的结果集中,而EXISTS用于检查子 查询是否至少返回了一行数据。通常来说,EXISTS在子查询结果集较大时表现更好,因为一旦找 到匹配项就会停止搜索,而IN则需要遍历整个结果集。 在 SQL 中,EXISTS 和 IN 都可以用于…

焕活身心,解锁健康养生新方式

健康养生是一门科学&#xff0c;更是一种生活智慧。从日常点滴做起&#xff0c;才能筑牢健康根基。​ 饮食上&#xff0c;应遵循 “食物多样&#xff0c;谷类为主” 原则。多摄入新鲜蔬果&#xff0c;它们富含维生素与膳食纤维&#xff0c;有助于增强免疫力&#xff1b;选择全…

QT+Cmake+mingw32-make编译64位的zlib-1.3.1源码成功过程

由于开源的软件zlib库是很多相关库libpng等基础库&#xff0c;因此掌握使用mingw编译器来编译zlib源码的步骤十分重要。本文主要是通过图文模式讲解完整的qtcmakezlib源码搭建和测试过程&#xff0c;为后续的其他源码编译环境搭建做基础准备。 详细步骤如下&#xff1a; 1、下…

健身会员管理系统(ssh+jsp+mysql8.x)含运行文档

健身会员管理系统(sshjspmysql8.x) 对健身房的健身器材、会员、教练、办卡、会员健身情况进行管理&#xff0c;可根据会员号或器材进行搜索&#xff0c;查看会员健身情况或器材使用情况。

【langchain4j】Springboot如何接入大模型以及实战开发-AI问答助手(一)

langchain4j介绍 官网地址&#xff1a;https://docs.langchain4j.dev/get-started langchain4j可以说是java和spring的关系&#xff0c;spring让我们开发java应用非常简单&#xff0c;那么langchain4j对应的就是java开发ai的 “Spring” 他集成了AI应用的多种场景&#xff0c…

平均池化(Average Pooling)

1. 定义与作用​​ ​​平均池化​​是一种下采样操作&#xff0c;通过对输入区域的数值取​​平均值​​来压缩数据空间维度。其核心作用包括&#xff1a; ​​降低计算量​​&#xff1a;减少特征图尺寸&#xff0c;提升模型效率。​​保留整体特征​​&#xff1a;平滑局部…

【dify实战】chatflow结合deepseek实现基于自然语言的数据库问答、Echarts可视化展示、Excel报表下载

dify结合deepseek实现基于自然语言的数据库问答、Echarts可视化展示、Excel报表下载 观看视频&#xff0c;您将学会 在dify下如何快速的构建一个chatflow&#xff0c;来完成数据分析工作&#xff1b;如何在AI的回复中展示可视化的图表&#xff1b;如何在AI 的回复中加入Excel报…