昇思25天学习打卡营第12天|基于MindSpore的GPT2文本摘要

基于MindSpore的GPT2文本摘要

数据集加载

使用nlpcc2017摘要数据,共包含5万个样本,内容是新闻正文及其摘要。

from mindnlp.utils import http_get
from mindspore.dataset import TextFileDataset# 下载数据集
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')# 加载数据集
dataset = TextFileDataset(str(path), shuffle=False)

数据预处理


train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)

按9:1划分测试集与训练集,randomize表示不对数据进行随机排序,按照原顺序直接拆分。

原始数据格式是:
article: [CLS] article_context [SEP]
summary: [CLS] summary_context [SEP]

期望的预处理后的数据格式是:
[CLS] article_context [SEP] summary_context [SEP]

import json
import numpy as npdef process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):# 加载json格式的数据并转成numpy数组def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])# 使用分词器处理artical和summary。# text=article表示主文本输入,text_pair=summary表示辅助文本# padding指将输入序列填充或截断的最大长度# truncation指定截断策略,only_first表示指仅截断主文本(article)。# 通常是主文本较长需要截断,而辅助文本较短并且需要完整保留。def merge_and_pad(article, summary):tokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len)return tokenized['input_ids'], tokenized['input_ids']# 提取article和summarydataset = dataset.map(read_map, 'text', ['article', 'summary'])#  将列名修改为input_ids和labelsdataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])# 将数据进行分批dataset = dataset.batch(batch_size)# 如果shuffle是true,则打乱数据if shuffle:dataset = dataset.shuffle(batch_size)return dataset

这里的tokenizer使用BertTokenizer,因为GPT2没有中文的分词器。

from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# 使用刚刚定义的函数分成四批进行处理。
train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)

模型构建

from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel
​
class GPT2ForSummarization(GPT2LMHeadModel):def construct(self,input_ids = None,attention_mask = None,labels = None,):outputs=super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits=outputs.logits[..., :-1, :]shift_labels=labels[..., 1:]loss=ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)return loss

具体解释一下以上代码:

shift 的作用使为了对齐预测和标签,使模型输出和标签对应,从而得到每个位置的预测误差。
举例 I love program
在自回归模型中,模型会根据前文逐步预测(不熟悉的盆友可以看一下上一篇文章:自回归模型与文本生成方法)

  • 输入 “I”,输出 “love”
  • 输入 “I love”,输出 “programming”

也就是位置1的输出对应的是位置2(的标签),位置2的输出对应位置3(的标签)

对应到实际的数据就是

  • outputs.logits[…, :-1, :]
    去除 logits 的最后一个时间步,因为没有标签与之对应。继续以上面的为例,就是去掉programming,因为programing没有后面输出了。
  • labels[…, 1:]
    去除 labels 的第一个时间步,因为没有预测值与之对应。去掉I,因为I前面没有输入,因此也不是输出的一部分。

由此完成了shift错位操作。

搞好数据结构之后再计算损失

  1. 将shift_logits形状调整成二维张量,使每一行对应一个token的预测分布。
  2. 将shift_labels形状调整为一维张量,使每个元素对应一个标签。
  3. 使用cross_entropy()计算交叉熵损失,忽略填充token的损失。

定义学习率warmup

from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateScheduleclass LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate."""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__()self.learning_rate = learning_rateself.num_warmup_steps = num_warmup_stepsself.num_training_steps = num_training_stepsdef construct(self, global_step):if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_ratereturn ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_rate

这里定义了LinearWithWarmUp作为自定义学习率调度类。它可以在训练的初始阶段进行学习率的线性预热,再在剩余的训练步骤中线性衰减

初始化 预热步数、训练步数、学习率
构建时如果步数小于预热步数则开始进行线性增长,从0增长到learning rate。
如果步数大于等于时,则进行线性衰减,由learning rate变回0。maximum保证不会降低到0以下。

模型训练

内容详见注释内容

# 初始化参数
num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4#训练的步数=数据集的大小 乘以 准备完整遍历数据集的次数(epoch)
num_training_steps = num_epochs * train_dataset.get_dataset_size()from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModelconfig = GPT2Config(vocab_size=len(tokenizer))
# 初始化一个用于文本摘要的GPT2模型
model = GPT2ForSummarization(config)# 初始化学习率调度器和adam优化器
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback# 设置检查点回调,用于在训练过程中保存模型检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',epochs=1, keep_checkpoint_max=2)# 初始化训练器,设置模型,训练集,训练轮次,优化器和回调函数
trainer = Trainer(network=model, train_dataset=train_dataset,epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
# 开启混合精度训练,以提高训练速度和节省显存
trainer.set_amp(level='O1')  
# 开始训练,并指定目标列(labels)作为标签。
trainer.run(tgt_columns="labels")

模型推理

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
# 处理测试集的过程和训练集差不多
# 依然是提取出article和sumarization
# 再进行分词,制定最大长度和截断def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def pad(article):tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)return tokenized['input_ids']dataset = dataset.map(read_map, 'text', ['article', 'summary'])dataset = dataset.map(pad, 'article', ['input_ids'])dataset = dataset.batch(batch_size)return datasettest_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
# 加载已经训练好的模型
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
# 设为非训练模式,可以禁用掉一些训练相关的操作如dropout
model.set_train(False)
# 模型的结束标记设置成分隔符标记,这样生成的文本遇到分隔符就会终止
model.config.eos_token_id = model.config.sep_token_id
i = 0for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)output_text = tokenizer.decode(output_ids[0].tolist())i += 1

总结

本章介绍了使用GPT2进行文本总结任务的基本流程,包括数据导入、数据预处理、模型训练、和模型推理。

打卡凭证

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/869063.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MVC 可以把通用命名空间放在配置文件

这种方式的引入,是将命名空间引入到所有视图中了,不需要在使用using单独引用了。

MATLAB实现-基于CNN-LSTM卷积神经网络结合长短期记忆神经网络数据分类预测(多输入多分类)

MATLAB实现-基于CNN-LSTM卷积神经网络结合长短期记忆神经网络数据分类预测&#xff08;多输入多分类&#xff09; 基于CNN-LSTM卷积神经网络结合长短期记忆神经网络数据分类预测&#xff08;多输入多分类&#xff09; 1.数据均为Excel数据&#xff0c;直接替换数据就可以运行…

【ASSEHR出版】第四届现代教育技术与社会科学国际学术会议(ICMETSS 2024)

第四届现代教育技术与社会科学国际学术会议&#xff08;ICMETSS 2024&#xff09;将于2024年8月23-25日在马来西亚 吉隆坡举行。 会议旨在为从事教育相关领域的专家学者、工程技术人员、技术研发人员提供一个共享科研成果和前沿技术&#xff0c;了解学术发展趋势&#xff0c;拓…

非升即走保命刊:审稿速度堪比“水刊”的1区Top,国人优势大,无爆雷风险,2个月可录!

本周投稿推荐 SCI • 地质遥感类&#xff0c;1.0-2.0&#xff08;34天沾边可录&#xff09; • CCF推荐&#xff0c;4.5-5.0&#xff08;2天见刊&#xff09; • 生物医学制药类&#xff08;2天逢投必中&#xff09; EI • 各领域沾边均可&#xff08;2天录用&#xff09…

人工智能+病理组学的交叉课题,患者的临床特征如何收集与整理|顶刊专题汇总·24-07-09

小罗碎碎念 本期文献主题&#xff1a;人工智能病理组学的交叉课题&#xff0c;患者的临床特征如何收集与整理 我们在阅读文献的时候会发现&#xff0c;有的文章会详细给出自己的数据集分析表&#xff0c;分别列出训练集、验证集的数量&#xff0c;以及每个特征对应的患者人数。…

解码技术债:AI代码助手与智能体的革新之道

技术债 技术债可能来源于多种原因&#xff0c;比如时间压力、资源限制、技术选型不当等。它可以表现为代码中的临时性修补、未能彻底解决的设计问题、缺乏文档或测试覆盖等。虽然技术债可以帮助快速推进项目进度&#xff0c;但长期来看&#xff0c;它会增加软件维护的成本和风险…

无线充电宝哪个牌子好?绿联、西圣、小米充电宝测评对比!

随着科技的不断进步和智能设备的普及&#xff0c;无线充电宝逐渐成为了现代人生活中的必需品。它们不仅方便了我们的日常充电需求&#xff0c;更减少了线缆的束缚&#xff0c;提高了使用的便捷性。在众多品牌中&#xff0c;绿联、西圣和小米作为市场上广受好评的无线充电宝品牌…

【FreeRTOS】freeRTOS的版本号在哪个源文件定义

在task.h中定义 可以通过宏 tskKERNEL_VERSION_NUMBER 找到&#xff0c; 具体如下图&#xff1a;记录一下

【系统架构设计】计算机组成与体系结构(一)

计算机组成与体系结构 计算机系统组成计算机硬件组成控制器运算器主存储器辅助存储器输入设备输出设备 计算机系统结构的分类存储程序的概念Flynn分类 复杂指令集系统与精简指令集系统总线 存储器系统流水线 兜兜转转&#xff0c;最后还是回到了4大件&#xff0c;补基础&#x…

通证经济促进企业数字化转型

在数字化时代的大潮中&#xff0c;通证经济犹如一股新兴力量&#xff0c;以其前所未有的创新模式和深远潜力&#xff0c;正悄然重塑着全球经济格局。通证经济生态体系&#xff0c;作为这一变革的核心驱动力&#xff0c;正逐步构建起一个高效、透明且充满创新活力的新经济生态系…

转型之路:从G端项目到梦想领域的跨越

在职业生涯的十字路口&#xff0c;面对公司G端项目减少与潜在的降薪危机&#xff0c;我毅然决定踏上转型之旅&#xff0c;不再让环境的不确定性左右我的未来。毕业两年间&#xff0c;我深耕于建筑行业的G端项目招标投标解决方案&#xff0c;但内心的声音告诉我&#xff0c;是时…

汇川伺服 (2)DDR、MSI电机、SV510、SV520、SV660软件简单调试

一、DDR DDR 简介 应用场合 二、MSI电机系列 综合概述 三、SV510压合伺服 四、SV520 相序辨识 角度辨识 五、SV660 六、简单调试 两种不同的显示状态 状态显示参数 调试案例 设置账户密码 面板JOG功能 DO强制输出 惯量辨识 计算驱动器电阻 负载惯量比 计算案例&#…

免费试用Aicbo AI绘图软件,你的艺术梦想触手可及

最近AI绘图技术风靡全球&#xff0c;今天要给大家推荐一款集成了免费试用AI绘图软件的神器&#xff0c;即便你是从零开始&#xff0c;也能迅速掌握&#xff0c;创作出令人惊叹的艺术作品。平台是叫&#xff1a;Aicbo 这款神器设计人性化&#xff0c;操作极其简便&#xff0c;只…

python-课程满意度计算(赛氪OJ)

[题目描述] 某个班主任对学生们学习的的课程做了一个满意度调查&#xff0c;一共在班级内抽取了 N 个同学&#xff0c;对本学期的 M 种课程进行满意度调查。他想知道&#xff0c;有多少门课是被所有调查到的同学都喜欢的。输入格式&#xff1a; 第一行输入两个整数 N , M 。 接…

【BUG】RestTemplate发送Post请求后,响应中编码为gzip而导致的报错

BUG描述 20240613-09:59:59.062|INFO|null|810184|xxx|xxx||8|http-nio-xxx-exec-1|com.xxx.jim.xxx.XXXController.?.?|MSG接收到来自xxx的文件请求 headers:[host:"xxx", accept:"text/html,application/json,application/xhtmlxml,application/xml;q0.9,*…

智启未来,共筑工业软件新梦 ——清华大学博士生天洑软件实习启航

2024年6月30日&#xff0c;清华大学工程物理系、深圳国际研究生院、航天航空学院、机械工程系、能源与动力工程系的10名博士研究生抵达南京天洑软件有限公司&#xff0c;正式开启为期6周的博士生必修环节社会实践。 “天洑软件清华基地”成立于2021年&#xff0c;旨在为清华理工…

C编程使用clock函数实现计算一段代码的执行时间:毫秒单位

一、函数原型 在Linux系统中&#xff0c;clock()函数是一个非常重要且常用的函数&#xff0c;它主要用于测量程序运行的CPU时间。这个函数是C/C语言中的一个标准函数&#xff0c;其原型定义在<time.h>头文件中。以下是对clock()函数的详细解析&#xff1a; #include <…

uniapp安卓端实现语音合成播报

最初尝试使用讯飞语音合成方式,能获取到语音数据,但是数据是base64格式的,在安卓端无法播放,网上有说通过转成blob格式的url可以播放,但是uniapp不支持转换的api;于是后面又想其他办法,使用安卓插件播报原生安卓语音播报插件 - DCloud 插件市场 方案一(讯飞语音合成) 1.在讯飞…

C++ 可调用对象

文章目录 概述1.函数以及函数指针函数函数指针 2.成员函数指针3.lamda表达式4.函数对象&#xff08;Func&#xff09;5.通过 std::function 包装的可调用对象 小结 概述 在C中&#xff0c;“可调用对象”&#xff08;Callable&#xff09;是一个可以被调用的对象&#xff0c;它…

Codeforces Round 956 F. array-value 【01Trie查询异或最小值】

题意 给定一个非负整数数组 a a a 对每个长度至少为 2 2 2 的子数组&#xff0c;定义其权值为&#xff1a;子数组内两两异或值最小值 即 b ⊂ a [ l , r ] , w ( b ) min ⁡ l ≤ i < j ≤ r { a i ⨁ a j } b \subset a[l, r], \quad w(b) \min_{l \leq i < j \le…