《昇思25天学习打卡营第22天|基于MindSpore的GPT2文本摘要》

#学习打卡第22天#

1. 数据集

1.1 数据下载

        使用nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

from mindnlp.utils import http_get
from mindspore.dataset import TextFileDataset# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')# load dataset
dataset = TextFileDataset(str(path), shuffle=False, num_samples=2500)
dataset.get_dataset_size()# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)

1.2 数据预处理

原始数据格式:

        article: [CLS] article_context [SEP]
        summary: [CLS] summary_context [SEP]
预处理后的数据格式:

        [CLS] article_context [SEP] summary_context [SEP]

import json
import numpy as np
from mindnlp.transformers import BertTokenizer# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def merge_and_pad(article, summary):# tokenization# pad to max_seq_length, only truncate the articletokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len)return tokenized['input_ids'], tokenized['input_ids']dataset = dataset.map(read_map, 'text', ['article', 'summary'])# change column names to input_ids and labels for the following trainingdataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])dataset = dataset.batch(batch_size)if shuffle:dataset = dataset.shuffle(batch_size)return dataset# We use BertTokenizer for tokenizing chinese context.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)

2. 模型构建

构建GPT2ForSummarization模型

from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel
from mindspore.nn.learning_rate_schedule import LearningRateScheduleclass LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate."""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__()self.learning_rate = learning_rateself.num_warmup_steps = num_warmup_stepsself.num_training_steps = num_training_stepsdef construct(self, global_step):if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_ratereturn ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_rateclass GPT2ForSummarization(GPT2LMHeadModel):def construct(self,input_ids = None,attention_mask = None,labels = None,):outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :]shift_labels = labels[..., 1:]# Flatten the tokensloss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)return loss

3. 模型训练

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModelfrom mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallbacknum_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4
num_training_steps = num_epochs * train_dataset.get_dataset_size()config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)
# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',epochs=1, keep_checkpoint_max=2)trainer = Trainer(network=model, train_dataset=train_dataset,epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 开启混合精度
trainer.run(tgt_columns="labels")

4. 模型推理

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def pad(article):tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)return tokenized['input_ids']dataset = dataset.map(read_map, 'text', ['article', 'summary'])dataset = dataset.map(pad, 'article', ['input_ids'])dataset = dataset.batch(batch_size)return datasettest_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
model.set_train(False)
model.config.eos_token_id = model.config.sep_token_idi = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)output_text = tokenizer.decode(output_ids[0].tolist())print(output_text)i += 1if i == 1:break

5. 心得总结

        GPT-2是OpenAI推出的基于Transformer的生成式预训练模型,擅长文本生成。在文本摘要任务中,GPT-2通过预训练学习语言模式,再通过微调适应摘要任务,能有效提取文章要点并生成简洁摘要。

        尽管GPT-2在文本摘要任务中表现出色,但它也面临一些挑战和限制。例如生成的摘要可能包含不准确或冗余的信息,且模型的可解释性较低。此外,GPT-2的计算资源需求较高,限制了其在资源受限环境中的应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/46468.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

活用 localStorage

我维护的这款工具 https://editor.yunwow.cn/ 已经帮我写了 7 篇文章了&#xff0c; 用起来很顺手&#xff0c;因此我打算再给它升级下让它更方便&#xff0c;我决定要给它加个本地缓存功能。我给它提的要求是&#xff1a; 1. 至少能缓存 5 篇文章 2. 能有选择的加载模板 3…

MySQL-对数据库和表的DDL命令

文章目录 一、什么是DDL操作二、数据库编码集和数据库校验集三、使用步骤对数据库的增删查改1.创建数据库2.进入数据库3.显示数据库4.修改数据库mysqldump 5.删除数据库 对表的增删查改1.添加/创建表2.插入表内容3.查看表查看所有表查看表结构查看表内容 4.修改表修改表的名字修…

融云:换头像=换人设?社交应用中隐秘而重要的「用户信息管理」

当代年轻人失眠三大原因&#xff0c;最近新上的《喜人奇妙夜》帮你找到了—— 基金绿了、吵架输了、前任头像换了。 当你半夜翻看前任的社交账号&#xff0c;一场盛大的失眠就开始了&#xff0c;就算古希腊掌柜睡眠的神躺你旁边也不好使。即便 Ta 没有更新内容&#xff0c;昵…

Redis 中String类型操作命令(命令演示,时间复杂度,返回值,注意事项)

String 类型 文章目录 String 类型set 命令get 命令mset 命令mget 命令get 和 mget 的区别incr 命令incrby 命令decr 命令decrby 命令incrbyfloat 命令append 命令getrange 命令setrange 命令 字符串类型是 Redis 中最基础的数据类型&#xff0c;在讲解命令之前&#xff0c;我们…

Linux的load(负载)

负载(load)是Linux机器的一个重要指标&#xff0c;直观了反应了机器当前的状态。 在Linux系统中&#xff0c;系统负载是对当前CPU工作量的度量&#xff0c;被定义为特定时间间隔内运行队列中的平均线程数。 Linux的负载高&#xff0c;主要是由于CPU使用、内存使用、10消…

新款S32K3 MCU可解决汽车软件开发的成本和复杂性问题(器件编号包含S32K322E、S32K322N、S32K328)

全新的S32K3系列专门用于车身电子系统、电池管理和新兴的域控制器&#xff0c;利用涵盖网络安全、功能安全和底层驱动程序的增强型封装持续简化软件开发。 相关产品&#xff1a;S32K328NHT1VPCSR S32K328GHT1MPCSR S32K322NHT0VPASR S32K322EHT0VPBSR S32K322NHT0VPBSR S32K32…

Doris数据库---建表、调整表结构操作

一、简介 本文章主讲创建 Doris 自维护的表的语法&#xff0c;以下为本人最近为数据中台接入doris所踩的坑及其解决方案&#xff0c;欢迎点评。 二、doris建表语法&#xff1a; 官网建表语法网址链接&#xff1a;CREATE-TABLE - Apache Doris 官网建表语法如图所示&#xf…

【C++】构造函数详解

&#x1f4e2;博客主页&#xff1a;https://blog.csdn.net/2301_779549673 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01; &#x1f4e2;本文由 JohnKi 原创&#xff0c;首发于 CSDN&#x1f649; &#x1f4e2;未来很长&#…

windows服务器搭建区块链环境(node.js+truffle+ganache)

windows服务器搭建区块链环境&#xff08;node.jstruffleganache&#xff09; 1&#xff0c;安装node.js中文版的2&#xff0c;更改下载源3&#xff0c;安装truffle4&#xff0c;安装ganache&#xff08;可以跳过使用ganache-cli&#xff09;5&#xff0c;安装ganache-cli&…

starRocks搭建

公司要使用新的大数据架构&#xff0c;打算用国产代替国外的大数据平台。所以这里我就纠结用doris还是starrocks&#xff0c;如果用doris&#xff0c;因为是开源的&#xff0c;以后就可以直接用云厂商的。如果用starrocks就得自己搭建&#xff0c;但是以后肯定会商业化&#xf…

医院护士站卫星电子钟,时间精准,为众人提供精确的时间引导

在医院这个充满紧张与关怀的环境中&#xff0c;每一刻的时间都承载着生命的重量。医院护士站卫星电子钟以其精准的时间显示&#xff0c;成为了为众人提供精确时间引导的重要存在。 一、医院卫星电子钟应用原因 首先&#xff0c;护士站是医院内信息交流和医疗服务协调的核心区域…

Springboot自定义banner启动动画

一、banner文件自定义编写 1、创建banner文件 banner文件的文件名称默认为“banner.txt”&#xff0c;这个在SpringApplication.java中定义的 一般自定义就新建一个banner.txt文件,放在项目resources中。这时在banner.txt中编写启动动画展示内容。例如&#xff1a; banner.t…

【排序算法】—— 归并排序

归并排序时间复杂度O(NlongN)&#xff0c;空间复杂度O(N)&#xff0c;是一种稳定的排序&#xff0c;其次可以用来做外排序算法&#xff0c;即对磁盘(文件)上的数据进行排序。 目录 一、有序数组排序 二、排序思路 三、递归实现 四、非递归实现 一、有序数组排序 要理解归…

mysql(5.5)启动服务和环境配置

正常启动 参考&#xff1a;Javaweb基础之mysql回溯笔记(一) 总的来说就是在mysql的安装目录下&#xff0c;找到bin下面的msyqld.exe&#xff0c;双击即启动了mysql服务&#xff1b; 启动方式二 也可以直接找到windows的服务项进行启动&#xff0c;操作如下&#xff1a; 打开…

Mac电脑下运行java命令行出现:错误: 找不到或无法加载主类

mac 电脑 问题复现 随手写了一个main方法&#xff0c;想用命令行操作 进入 BlockDemo.java 所在目录&#xff1a; wnwangnandeMBP wn % cd /Users/wn/IdeaProjects/test/JianZhiOffer/src/main/java/com/io/wn wnwangnandeMBP wn % ls -l total 16 -rw-r--r-- 1 wangnan …

换手机了怎么恢复微信聊天记录?教你3招实用技巧

随着科技的飞速发展&#xff0c;手机更新换代的速度也越来越快。当我们换上一部新手机时&#xff0c;最头疼的问题之一往往是如何将旧手机中的重要数据&#xff0c;尤其是微信聊天记录&#xff0c;迁移到新手机上。微信聊天记录不仅记录了我们的日常沟通&#xff0c;还承载了许…

踩坑日记 | 记一次流程图问题排查

踩坑日记&#xff1a;记一次流程图问题排查 标签&#xff1a; activiti | 流程 引言 今天排查了一个流程图问题&#xff0c;耗时2个小时终于解决&#xff0c;记录下来 现象 流程审批驳回报错&#xff1a;Unknown property used in expression: ${xxxx} 使用的是 activiti …

[C/C++入门][循环]12、等差数列和等差数列末项计算

等差数列是什么&#xff1f; 想象一下&#xff0c;你获得了一个神奇的糖果盒&#xff0c;他有一个神奇的功能&#xff0c;每次你打开盒子时&#xff0c;里面都会多出同样数量的糖。你只要给里面放上1颗糖&#xff0c;然后想着可以多几颗&#xff0c;比如我希望打开的时候多两颗…

【C++练级之路】【Lv.26】类型转换

快乐的流畅&#xff1a;个人主页 个人专栏&#xff1a;《算法神殿》《数据结构世界》《进击的C》 远方有一堆篝火&#xff0c;在为久候之人燃烧&#xff01; 文章目录 一、C风格类型转换1.1 隐式类型转换1.2 显式类型转换 二、C风格类型转换2.1 static_cast2.2 dynamic_cast2.3…

配置Redis时yml的格式导致报错

报错如下 java.lang.IllegalStateException: Failed to load ApplicationContext at org.springframework.test.context.cache.DefaultCacheAwareContextLoaderDelegate.loadContext(DefaultCacheAwareContextLoaderDelegate.java:98) at org.springframework.test.context.su…