昇思25天学习打卡营第11天|基于MindSpore的GPT2文本摘要

数据集

准备nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

数据需要预处理,如下

原始数据格式:
article: [CLS] article_context [SEP]
summary: [CLS] summary_context [SEP]

预处理后的数据格式:
[CLS] article_context [SEP] summary_context [SEP]

代码示例:

# 下载依赖
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
!pip install mindnlpfrom mindnlp.utils import http_get# 下载数据集
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')from mindspore.dataset import TextFileDataset# 加载数据集
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()# 按9:1比例拆分训练集、测试集
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)import json
import numpy as np# 数据集数据预处理
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def merge_and_pad(article, summary):# tokenization# pad to max_seq_length, only truncate the articletokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len)return tokenized['input_ids'], tokenized['input_ids']dataset = dataset.map(read_map, 'text', ['article', 'summary'])# change column names to input_ids and labels for the following trainingdataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])dataset = dataset.batch(batch_size)if shuffle:dataset = dataset.shuffle(batch_size)return datasetfrom mindnlp.transformers import BertTokenizertokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)
next(train_dataset.create_tuple_iterator())

模型构建

代码示例:

# 构建GPT2ForSummarization模型
from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModelclass GPT2ForSummarization(GPT2LMHeadModel):def construct(self,input_ids = None,attention_mask = None,labels = None,):outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :]shift_labels = labels[..., 1:]# Flatten the tokensloss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)return loss# 动态学习率
from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateScheduleclass LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate."""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__()self.learning_rate = learning_rateself.num_warmup_steps = num_warmup_stepsself.num_training_steps = num_training_stepsdef construct(self, global_step):if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_ratereturn ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_rate

模型训练

代码示例:

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4num_training_steps = num_epochs * train_dataset.get_dataset_size()from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModelconfig = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallbackckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',epochs=1, keep_checkpoint_max=2)trainer = Trainer(network=model, train_dataset=train_dataset,epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 开启混合精度trainer.run(tgt_columns="labels")

运行结果:
模型训练结果

模型推理

将向量数据变为中文数据。
代码示例:

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def pad(article):tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)return tokenized['input_ids']dataset = dataset.map(read_map, 'text', ['article', 'summary'])dataset = dataset.map(pad, 'article', ['input_ids'])dataset = dataset.batch(batch_size)return datasettest_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)output_text = tokenizer.decode(output_ids[0].tolist())print(output_text)i += 1if i == 1:break

截图时间
截图时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/43476.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

docker部署ES遇到的问题

在使用docker部署ES 集群时遇到了一些问题&#xff0c;自己记录一下 集群启动成功后有一个节点加入不了集群 failed to validate incoming join request from node Caused by: org.elasticsearch.cluster.coordination.CoordinationStateRejectedException: This node previo…

Java socket 获取gps定位

1.Java socket 获取gps定位的方法 在Java中使用Socket来直接获取GPS定位信息并不直接可行&#xff0c;因为GPS数据通常不是通过Socket通信来获取的。GPS数据通常由设备&#xff08;如智能手机、GPS接收器&#xff09;上的GPS硬件模块生成&#xff0c;并通过操作系统或专门的GP…

Redis安装部署与使用,多实例

一、redis基础 1.1 关系型数据库和NoSQL数据库 数据库主要分为两大类&#xff1a;关系型数据库与 NoSQL 数据库。 关系型数据库&#xff0c;是建立在关系模型基础上的数据库&#xff0c;其借助于集合代数等数学概念和方法来处理数据库中的数据。主流的 MySQL、Oracle、MS SQ…

Python爬虫教程第2篇-reqeusts是最好用的网络请求工具

简介 爬虫第一步就是网络请求&#xff0c;一个好用的网络请求库会非常重要。而requests库就是非常好用的一个http库&#xff0c;pyhon中虽然也有内置的urllib库用于网络请求&#xff0c;但是urllib使用起来比较的麻烦&#xff0c;而且缺少很多实用的高级功能&#xff0c;所以这…

Syncthing一款开源去中心化和点对点文件同步工具

Syncthing&#xff1a;一款开源的文件同步工具&#xff0c;去中心化和点对点加密传输&#xff0c;支持多平台&#xff0c;允许用户在多个设备之间安全、灵活地同步和共享文件&#xff0c;无需依赖第三方云服务&#xff0c;特别适合高安全性和自主控制的文件同步场景。 &#x…

牛客周赛 Round 50

A题&#xff1a;小红的最小最大 思路&#xff1a; 大水题 code&#xff1a; inline void solve() {int a, b, c; cin >> a >> b >> c;if (min(a, b) c > max(a, b)) cout << "YES\n";else cout << "NO\n";return; }…

传感器标定(二)摄像头外参标定(camera2lidar)

一、数据采集 1、ros包数据采集 rosbag record -a -O output_filename --duration6 //设置bag包名字为 my_rosbag.bag rosbag record -a -O my_rosbag.bag --duration62、参数解释 -a&#xff1a;订阅所有话题。-O output_filename&#xff1a;指定输出文件名称。–duration…

使用MySQLInstaller配置MySQL

操作步骤 1.配置High Availability 默认选项Standalone MySQL Server classic MySQL Replication 2.配置Type and Networking ◆端口默认启用TCP/P网络 ◆端口默认为3306 3.配置Account and Roles 设置root账户的密码、添加其他管理员 4.配置Windows Service ◆配置MySQL Serv…

Java线程池及面试题

1.线程池介绍 顾名思义&#xff0c;线程池就是管理一系列线程的资源池&#xff0c;其提供了一种限制和管理线程资源的方式。每个线程池还维护一些基本统计信息&#xff0c;例如已完成任务的数量。 总结一下使用线程池的好处&#xff1a; 降低资源消耗。通过重复利用已创建的…

xcode项目添加README.md文件并进行编辑

想要给xcode项目添加README.md文件其实还是比较简单的&#xff0c;但是对于不熟悉xcode这个工具的人来讲&#xff0c;还是有些陌生&#xff0c;下面简单给大家讲一下流程。 选择“文件”>“新建”>“文件”&#xff0c;在其他&#xff08;滚动到工作表底部&#xff09;下…

【云原生】AWS云平台,ECR推送Helm chart包

文章目录 1、背景信息2、AWS ECR推送OCI1、背景信息 背景一:OCI 是一个围绕容器格式和运行时的开放治理结构,旨在创建开放的行业标准。OCI 由 Docker、CoreOS 和其他容器技术相关的公司于 2015 年创立,现在由 Linux 基金会托管。OCI 的目标是提供一个中立的论坛,以解决容器…

Vue 3<script setup>使用v-for渲染数组中的元素,根据传入id删除数组元素(filter方法根据元素id过滤数组中的元素)

首先&#xff0c;需要在<script setup>中定义组件的数据和方法。然后&#xff0c;在模板中使用v-for来遍历数组并渲染元素&#xff0c;每个元素旁边添加一个删除按钮&#xff0c;并通过点击事件调用删除方法。 <template> <div> <div v-for"item …

Java基础-组件及事件处理(中)

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;请留下您的足迹&#xff09; 目录 BorderLayout布局管理器 说明&#xff1a; 示例&#xff1a; FlowLayout布局管理器 说明&#xff1a; …

vue extend的作用和使用方法

Vue.extend 是 Vue.js 提供的一个全局 API&#xff0c;用于扩展 Vue 组件。它的作用是创建一个可以被多次使用的组件构造器&#xff0c;可以像普通组件一样使用&#xff0c;并且可以在多个地方可以实例化该组件。 Vue.extend 的原理是通过 Vue.extend 方法创建一个新的构造器&…

【Qt5】入门Qt开发教程,一篇文章就够了(详解含qt源码)

目录 一、Qt概述 1.1 什么是Qt 1.2 Qt的发展史 1.3 Qt的优势 1.4 Qt版本 1.5 成功案例 二、创建Qt项目 2.1 使用向导创建 2.2 一个最简单的Qt应用程序 2.2.1 main函数中 2.2.2 类头文件 2.3 .pro文件 2.4 命名规范 2.5 QtCreator常用快捷键 三、Qt按钮小程序 …

使用Godot4组件制作竖版太空射击游戏_2D卷轴飞机射击-激光组件(二)

文章目录 开发思路发射点添加子弹组件构建子弹处理缩放效果闪光效果 使用Godot4组件制作竖版太空射击游戏_2D卷轴飞机射击&#xff08;一&#xff09; 开发思路 整体开发还是基于组件的思维。相比于工厂模式或者状态机&#xff0c;可能有些老套&#xff0c;但是更容易理解和编…

STM32/GD32驱动步进电机芯片TM2160

文章目录 官方概要简单介绍整体架构流程 官方概要 TMC2160是一款带SPI接口的大功率步进电机驱动IC。它具有业界最先进的步进电机驱动器&#xff0c;具有简单的步进/方向接口。采用外部晶体管&#xff0c;可实现高动态、高转矩驱动。基于TRINAMICs先进的spreadCycle和stealthCh…

STM32 低功耗模式 睡眠、停止和待机 详解

目录 1.睡眠模式&#xff08;Sleep Mode&#xff09; 2.停止模式&#xff08;stop mode&#xff09; 3.待机模式&#xff08;Standby Mode&#xff09; STM32提供了三种低功耗模式&#xff0c;分别是睡眠模式&#xff08;Sleep Mode&#xff09;、停止模式&#xff08;Stop …

Electron 简单搭建项目

准备工作 全局安装 node npm创建文件夹&#xff0c;并执行 npm init安装 electron npm i electron --save-dev在 package.json 配置文件中的scripts字段下增加一条start命令&#xff1a; {"scripts": {"start": "electron ."} }由于配置中的入…

MYSQL八股文汇总

目录 1、三大范式 2、DML 语句和 DDL 语句区别 3、主键和外键的区别 4、drop、delete、truncate 区别 5、基础架构 6、MyISAM 和 InnoDB 有什么区别&#xff1f; 7、推荐自增id作为主键问题 8、为什么 MySQL 的自增主键不连续 9、redo log 是做什么的? 10、redo log…