昇思25天学习打卡营第11天|基于MindSpore的GPT2文本摘要

数据集

准备nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

数据需要预处理,如下

原始数据格式:
article: [CLS] article_context [SEP]
summary: [CLS] summary_context [SEP]

预处理后的数据格式:
[CLS] article_context [SEP] summary_context [SEP]

代码示例:

# 下载依赖
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
!pip install mindnlpfrom mindnlp.utils import http_get# 下载数据集
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')from mindspore.dataset import TextFileDataset# 加载数据集
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()# 按9:1比例拆分训练集、测试集
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)import json
import numpy as np# 数据集数据预处理
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def merge_and_pad(article, summary):# tokenization# pad to max_seq_length, only truncate the articletokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len)return tokenized['input_ids'], tokenized['input_ids']dataset = dataset.map(read_map, 'text', ['article', 'summary'])# change column names to input_ids and labels for the following trainingdataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])dataset = dataset.batch(batch_size)if shuffle:dataset = dataset.shuffle(batch_size)return datasetfrom mindnlp.transformers import BertTokenizertokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)
next(train_dataset.create_tuple_iterator())

模型构建

代码示例:

# 构建GPT2ForSummarization模型
from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModelclass GPT2ForSummarization(GPT2LMHeadModel):def construct(self,input_ids = None,attention_mask = None,labels = None,):outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :]shift_labels = labels[..., 1:]# Flatten the tokensloss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)return loss# 动态学习率
from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateScheduleclass LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate."""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__()self.learning_rate = learning_rateself.num_warmup_steps = num_warmup_stepsself.num_training_steps = num_training_stepsdef construct(self, global_step):if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_ratereturn ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_rate

模型训练

代码示例:

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4num_training_steps = num_epochs * train_dataset.get_dataset_size()from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModelconfig = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallbackckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',epochs=1, keep_checkpoint_max=2)trainer = Trainer(network=model, train_dataset=train_dataset,epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 开启混合精度trainer.run(tgt_columns="labels")

运行结果:
模型训练结果

模型推理

将向量数据变为中文数据。
代码示例:

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def pad(article):tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)return tokenized['input_ids']dataset = dataset.map(read_map, 'text', ['article', 'summary'])dataset = dataset.map(pad, 'article', ['input_ids'])dataset = dataset.batch(batch_size)return datasettest_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)output_text = tokenizer.decode(output_ids[0].tolist())print(output_text)i += 1if i == 1:break

截图时间
截图时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/43476.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis安装部署与使用,多实例

一、redis基础 1.1 关系型数据库和NoSQL数据库 数据库主要分为两大类&#xff1a;关系型数据库与 NoSQL 数据库。 关系型数据库&#xff0c;是建立在关系模型基础上的数据库&#xff0c;其借助于集合代数等数学概念和方法来处理数据库中的数据。主流的 MySQL、Oracle、MS SQ…

Python爬虫教程第2篇-reqeusts是最好用的网络请求工具

简介 爬虫第一步就是网络请求&#xff0c;一个好用的网络请求库会非常重要。而requests库就是非常好用的一个http库&#xff0c;pyhon中虽然也有内置的urllib库用于网络请求&#xff0c;但是urllib使用起来比较的麻烦&#xff0c;而且缺少很多实用的高级功能&#xff0c;所以这…

Syncthing一款开源去中心化和点对点文件同步工具

Syncthing&#xff1a;一款开源的文件同步工具&#xff0c;去中心化和点对点加密传输&#xff0c;支持多平台&#xff0c;允许用户在多个设备之间安全、灵活地同步和共享文件&#xff0c;无需依赖第三方云服务&#xff0c;特别适合高安全性和自主控制的文件同步场景。 &#x…

使用MySQLInstaller配置MySQL

操作步骤 1.配置High Availability 默认选项Standalone MySQL Server classic MySQL Replication 2.配置Type and Networking ◆端口默认启用TCP/P网络 ◆端口默认为3306 3.配置Account and Roles 设置root账户的密码、添加其他管理员 4.配置Windows Service ◆配置MySQL Serv…

Java线程池及面试题

1.线程池介绍 顾名思义&#xff0c;线程池就是管理一系列线程的资源池&#xff0c;其提供了一种限制和管理线程资源的方式。每个线程池还维护一些基本统计信息&#xff0c;例如已完成任务的数量。 总结一下使用线程池的好处&#xff1a; 降低资源消耗。通过重复利用已创建的…

xcode项目添加README.md文件并进行编辑

想要给xcode项目添加README.md文件其实还是比较简单的&#xff0c;但是对于不熟悉xcode这个工具的人来讲&#xff0c;还是有些陌生&#xff0c;下面简单给大家讲一下流程。 选择“文件”>“新建”>“文件”&#xff0c;在其他&#xff08;滚动到工作表底部&#xff09;下…

Java基础-组件及事件处理(中)

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;请留下您的足迹&#xff09; 目录 BorderLayout布局管理器 说明&#xff1a; 示例&#xff1a; FlowLayout布局管理器 说明&#xff1a; …

【Qt5】入门Qt开发教程,一篇文章就够了(详解含qt源码)

目录 一、Qt概述 1.1 什么是Qt 1.2 Qt的发展史 1.3 Qt的优势 1.4 Qt版本 1.5 成功案例 二、创建Qt项目 2.1 使用向导创建 2.2 一个最简单的Qt应用程序 2.2.1 main函数中 2.2.2 类头文件 2.3 .pro文件 2.4 命名规范 2.5 QtCreator常用快捷键 三、Qt按钮小程序 …

使用Godot4组件制作竖版太空射击游戏_2D卷轴飞机射击-激光组件(二)

文章目录 开发思路发射点添加子弹组件构建子弹处理缩放效果闪光效果 使用Godot4组件制作竖版太空射击游戏_2D卷轴飞机射击&#xff08;一&#xff09; 开发思路 整体开发还是基于组件的思维。相比于工厂模式或者状态机&#xff0c;可能有些老套&#xff0c;但是更容易理解和编…

STM32/GD32驱动步进电机芯片TM2160

文章目录 官方概要简单介绍整体架构流程 官方概要 TMC2160是一款带SPI接口的大功率步进电机驱动IC。它具有业界最先进的步进电机驱动器&#xff0c;具有简单的步进/方向接口。采用外部晶体管&#xff0c;可实现高动态、高转矩驱动。基于TRINAMICs先进的spreadCycle和stealthCh…

STM32 低功耗模式 睡眠、停止和待机 详解

目录 1.睡眠模式&#xff08;Sleep Mode&#xff09; 2.停止模式&#xff08;stop mode&#xff09; 3.待机模式&#xff08;Standby Mode&#xff09; STM32提供了三种低功耗模式&#xff0c;分别是睡眠模式&#xff08;Sleep Mode&#xff09;、停止模式&#xff08;Stop …

MYSQL八股文汇总

目录 1、三大范式 2、DML 语句和 DDL 语句区别 3、主键和外键的区别 4、drop、delete、truncate 区别 5、基础架构 6、MyISAM 和 InnoDB 有什么区别&#xff1f; 7、推荐自增id作为主键问题 8、为什么 MySQL 的自增主键不连续 9、redo log 是做什么的? 10、redo log…

App H5+ 实现下载、查看功能 前后端实现(SpringBoot)

<!doctype html><html><head><meta charset"utf-8"><title>维修指南</title><meta name"viewport" content"widthdevice-width, initial-scale1.0, minimum-scale0, maximum-scale0.85, user-scalableyes&quo…

下半年交火点:智驾全国都能开,智舱多模态大模型

“你猜一猜我现在参加什么样的活动呢&#xff1f;” “你参加的是WAIC&#xff0c;就是那个人工智能的大Party&#xff0c;超多科技高手都在这……” “你帮我介绍一下这本书吧。” “这书叫《反脆弱&#xff0c;从不确定性中获益》&#xff0c;讲的是怎么在混乱里找机会&am…

搞不清啊?伦敦金与上海金区别是?

进入黄金市场的朋友&#xff0c;有可能会被各式各样的黄金交易品种带得眼花缭乱&#xff0c;其实各品种虽然都以黄金作为投资标的物&#xff0c;但是也是各有不同的&#xff0c;下面我们就来比较一下相似的投资品种——伦敦金和上海金。 首先在比较之前&#xff0c;我们要搞清楚…

基于泰坦尼克号生还数据进行 Spark 分析

基于泰坦尼克号生还数据进行 Spark 分析 在这篇博客中&#xff0c;我们将展示如何使用 Apache Spark 分析著名的泰坦尼克号数据集。通过这篇教程&#xff0c;您将学习如何处理数据、分析乘客的生还情况&#xff0c;并生成有价值的统计信息。 数据解析 • PassengerId &#…

快速排序[原理,C++实现,注意事项,时间复杂度分析]

模板&#xff1a; //本模板来自ACwing void quick_sort(int q[],int l,int r) {if(l>r) return;int xq[lr>>1],il-1,jr1;while(i<j){do i;while(q[i]<x);do j--;while(q[j]>x); if(i<j) swap(q[i],q[j]);}quick_sort(q,l,j);quick_sort(q,j1,r); };原理&…

注册中心组成结构和基本原理解析

假如你正在设计和开发一个分布式服务系统&#xff0c;系统中存在一批能够独立运行的服务&#xff0c;而在部署上也采用了集群模式以防止出现单点故障。显然&#xff0c;对于一个完整的业务系统而言&#xff0c;这些服务之间需要相互调用并形成复杂的访问链路&#xff0c;一种可…

codesys多段直线电机跨电机控制

1. 电机描述 在X轴上有多段直线电机&#xff0c;如下图有9个&#xff0c;从X1到X9. 2.codesys程序结构 程序名称&#xff1a;Pou_two_motors 动作名称&#xff1a;ACT_move 把这个程序搞到任务配置里面 通过ethercat总线命名一下这些电机&#xff0c;方便调用。 3.程序内容 P…

油烟监测仪:守护厨房,让蓝天白云成为常态

夏日炎炎&#xff0c;白天的酷暑让人们更加向往夜晚的凉爽与惬意。在这样的季节里&#xff0c;品尝各式烧烤、小龙虾&#xff0c;再搭配一杯冰镇啤酒&#xff0c;成为了许多市民夜晚消遣的不二选择。然而&#xff0c;随之而来的餐饮油烟问题也进入了高发阶段&#xff0c;对周边…