昇思MindSpore学习总结十六 —— 基于MindSpore的GPT2文本摘要

1、mindnlp 版本要求

!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp

2、数据集加载与处理

2.1 数据集加载

 本次实验使用的是nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

from mindspore.dataset import TextFileDataset  # 从mindspore.dataset模块中导入TextFileDataset类# load dataset  # 加载数据集
dataset = TextFileDataset(str(path), shuffle=False)  # 创建一个TextFileDataset实例,参数是文件路径(path)转换成字符串格式,shuffle=False表示不打乱数据顺序
dataset.get_dataset_size()  # 获取数据集的大小,即数据集中样本的数量

# split into training and testing dataset  # 将数据集分割为训练集和测试集
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)  # 将数据集按比例[0.9, 0.1]分割为训练集和测试集,randomize=False表示不随机打乱数据

 2.2 数据预处理

import json  # 导入json模块,用于处理JSON数据
import numpy as np  # 导入numpy模块,并简写为np,用于处理数组和矩阵# preprocess dataset  # 预处理数据集
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):# 定义一个嵌套函数read_map,用于读取并解析JSON文本数据def read_map(text):data = json.loads(text.tobytes())  # 将文本数据转换为字节后用json.loads解析为Python字典return np.array(data['article']), np.array(data['summarization'])  # 返回文章和摘要的numpy数组# 定义一个嵌套函数merge_and_pad,用于合并并填充数据def merge_and_pad(article, summary):# tokenization  # 进行分词操作# pad to max_seq_length, only truncate the article  # 填充到最大序列长度,仅截断文章部分tokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len)  # 使用tokenizer对文章和摘要进行分词,填充到最大长度,仅截断文章部分return tokenized['input_ids'], tokenized['input_ids']  # 返回分词后的输入ID(注意:这里的input_ids和labels是相同的)dataset = dataset.map(read_map, 'text', ['article', 'summary'])  # 使用read_map函数对数据集进行映射,提取文章和摘要# change column names to input_ids and labels for the following training  # 更改列名为input_ids和labels,以便后续训练dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])  # 使用merge_and_pad函数对数据进行映射,生成input_ids和labelsdataset = dataset.batch(batch_size)  # 将数据集按批次大小进行分批处理if shuffle:dataset = dataset.shuffle(batch_size)  # 如果shuffle为True,则对批次进行随机打乱return dataset  # 返回预处理后的数据集

 因GPT2无中文的tokenizer,我们使用BertTokenizer替代。

from mindnlp.transformers import BertTokenizer  # 从mindnlp.transformers模块中导入BertTokenizer类# We use BertTokenizer for tokenizing Chinese context.  # 我们使用BertTokenizer对中文内容进行分词
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # 使用预训练的'bert-base-chinese'模型初始化BertTokenizer
len(tokenizer)  # 获取tokenizer的词汇表大小

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)  # 使用process_dataset函数对训练数据集进行预处理,传入参数包括训练数据集、分词器和批次大小为4
next(train_dataset.create_tuple_iterator())  # 创建一个tuple迭代器并获取其第一个元素

 3、模型构建

3.1 构建GPT2ForSummarization模型,注意shift right的操作。

from mindspore import ops  # 从mindspore模块导入ops操作
from mindnlp.transformers import GPT2LMHeadModel  # 从mindnlp.transformers模块中导入GPT2LMHeadModel类# 定义一个用于摘要生成的GPT2模型类,继承自GPT2LMHeadModel
class GPT2ForSummarization(GPT2LMHeadModel):# 定义模型的构造函数def construct(self,input_ids=None,  # 输入IDattention_mask=None,  # 注意力掩码labels=None,  # 标签):# 调用父类的construct方法,获取模型输出outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :]  # 移动logits,使其与shift_labels对齐shift_labels = labels[..., 1:]  # 移动标签,使其与shift_logits对齐# Flatten the tokens  # 将tokens展平loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)  # 计算交叉熵损失,忽略填充的tokenreturn loss  # 返回计算的损失

 3.2 动态学习率

from mindspore import ops  # 从mindspore模块导入ops操作
from mindspore.nn.learning_rate_schedule import LearningRateSchedule  # 从mindspore.nn.learning_rate_schedule模块导入LearningRateSchedule类# 定义一个线性学习率衰减与热身相结合的学习率调度类,继承自LearningRateSchedule
class LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate.  # 热身-衰减学习率。"""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__()  # 调用父类的构造函数self.learning_rate = learning_rate  # 初始化学习率self.num_warmup_steps = num_warmup_steps  # 初始化热身步数self.num_training_steps = num_training_steps  # 初始化训练步数# 定义构造函数def construct(self, global_step):# 如果当前步数小于热身步数if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate  # 线性增加学习率# 否则,学习率进行线性衰减return ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_rate  # 计算并返回衰减后的学习率

 4、模型训练

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4num_training_steps = num_epochs * train_dataset.get_dataset_size()
from mindspore import nn  # 从mindspore模块导入nn(神经网络)模块
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel  # 从mindnlp.transformers模块导入GPT2Config和GPT2LMHeadModel类# 配置GPT2模型的配置
config = GPT2Config(vocab_size=len(tokenizer))  # 创建GPT2配置实例,设置词汇表大小为tokenizer的长度
model = GPT2ForSummarization(config)  # 使用配置实例创建一个GPT2ForSummarization模型# 创建学习率调度器
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)  # 创建线性热身-衰减学习率调度器# 创建优化器
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)  # 使用AdamWeightDecay优化器,并传入模型的可训练参数和学习率调度器
# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))

from mindnlp._legacy.engine import Trainer  # 从mindnlp._legacy.engine模块导入Trainer类
from mindnlp._legacy.engine.callbacks import CheckpointCallback  # 从mindnlp._legacy.engine.callbacks模块导入CheckpointCallback类# 创建一个CheckpointCallback实例,用于保存检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint',  # 检查点保存路径ckpt_name='gpt2_summarization',  # 检查点文件名epochs=1,  # 每个epoch保存一次检查点keep_checkpoint_max=2  # 最多保留两个检查点
)# 创建一个Trainer实例,用于训练模型
trainer = Trainer(network=model,  # 要训练的模型train_dataset=train_dataset,  # 训练数据集epochs=1,  # 训练的epoch数optimizer=optimizer,  # 优化器callbacks=ckpoint_cb  # 回调函数,包括检查点回调
)trainer.set_amp(level='O1')  # 开启混合精度训练,级别设置为'O1'

下面这段代码,运行时间较长,最好选择较高算力。 

trainer.run(tgt_columns="labels")  # 运行训练器,指定目标列为“labels”

配置不够,训练时间太长。 

5、模型推理

数据处理,将向量数据变为中文数据

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):# 定义一个嵌套函数read_map,用于读取并解析JSON文本数据def read_map(text):data = json.loads(text.tobytes())  # 将文本数据转换为字节后用json.loads解析为Python字典return np.array(data['article']), np.array(data['summarization'])  # 返回文章和摘要的numpy数组# 定义一个嵌套函数pad,用于对文章进行分词和填充def pad(article):tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)  # 对文章进行分词,截断至最大长度减去摘要长度return tokenized['input_ids']  # 返回分词后的输入IDdataset = dataset.map(read_map, 'text', ['article', 'summary'])  # 使用read_map函数对数据集进行映射,提取文章和摘要dataset = dataset.map(pad, 'article', ['input_ids'])  # 使用pad函数对文章进行分词和填充,生成input_idsdataset = dataset.batch(batch_size)  # 将数据集按批次大小进行分批处理return dataset  # 返回预处理后的数据集
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
# 创建一个tuple迭代器并获取其第一个元素,以NumPy数组的形式输出,并打印出来
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)  # 从预训练的检查点加载模型
model.set_train(False)  # 设置模型为评估模式(非训练模式)
model.config.eos_token_id = model.config.sep_token_id  # 设置模型的eos_token_id为sep_token_id
i = 0  # 初始化计数器为0# 遍历测试数据集的迭代器,获取输入ID和原始摘要
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():# 使用模型生成新的摘要,参数包括最大新生成的token数量、束搜索的束数、不重复的ngram大小output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)# 将生成的ID转换为文本output_text = tokenizer.decode(output_ids[0].tolist())print(output_text)  # 打印生成的摘要文本i += 1  # 计数器加1if i == 1:  # 如果计数器达到1break  # 跳出循环,仅生成并打印一个摘要

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/48501.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Amazon Web Services Lambda把天气预报推送到微信

最近北京开始下雨&#xff0c;开始和同事打赌几点能够雨停&#xff0c;虽然Iphone已经提供了实时天气&#xff0c;但是还是想用国内的API试试看看是不是更加准确些。 以下是我使用的服务&#xff1a; 地图SDK/APP获取 经纬度彩云天气API 通过地理位置获取天气信息Lambda 作为…

关于Mysql的面试题(实时更新中~)

一、主键约束与“not null unique”区别 1、作为Primary Key的域/域组不能为null&#xff0c;而Unique Key可以。 2、在一个表中只能有一个Primary Key&#xff0c;而多个Unique Key可以同时存在。unique not null 可以 将表的一列或多列定义为唯一性属性&#xff0c;而prima…

buu做题(6)

目录 [GWCTF 2019]我有一个数据库 [WUSTCTF2020]朴实无华 [GWCTF 2019]我有一个数据库 什么都没有, 尝试用dirsearch扫一下目录 可以扫到一个 /phpmyadmin 可以直接进入到数据库里面 但里面没什么东西 可以看到它的版本不是最新的, 搜一下相关的漏洞 phpMyAdmin 4.8.1后台文…

go关于string与[]byte再学深一点

目标&#xff1a;充分理解string与[]bytes零拷贝转换的实现 先回顾下string与[]byte的基本知识 1. string与[]byte的数据结构 reflect包中关于字符串的数据结构 // StringHeader is the runtime representation of a string.type StringHeader struct {Data uintptrLen int} …

ClickHouse 入门(一)【基本特点、数据类型与表引擎】

前言 今天开始学习 ClickHouse &#xff0c;一种 OLAP 数据库&#xff0c;实时数仓中用到的比较多&#xff1b; 1、ClickHouse 入门 ClickHouse 是俄罗斯的 Yandex&#xff08;搜索引擎公司&#xff09;在 2016 年开源的列式存储数据库&#xff08;HBase 也是列式存储&#xf…

某宝同款度盘不限速后台系统源码

简介&#xff1a; 某宝同款度盘不限速后台系统源码&#xff0c;验证已被我去除&#xff0c;两个后端系统&#xff0c;账号和卡密系统 第一步安装宝塔&#xff0c;部署卡密系统&#xff0c;需要环境php7.4 把源码丢进去&#xff0c;设置php7.4&#xff0c;和伪静态为thinkphp…

山东济南十大杰出人物起名大师颜廷利:影响世界的思想家哲学家教育家

在宇宙的广袤舞台上&#xff0c;各类智者以他们独特的方式揭示着世界的奥秘。数学家们在无尽的符号与公式中穿梭&#xff0c;像探索者般解锁着自然界的深层逻辑。考古学家们则跋涉于古老的土地&#xff0c;用他们的双手拂去岁月的尘埃&#xff0c;让沉睡的历史重见天日。 二十一…

spss是什么软件?spss有什么用

spss是什么软件&#xff1f; SPSS是一款数据统计、分析软件&#xff0c;它由IBM公司出品&#xff0c;这款软件平台提供了文本分析、大量的机器学习算法、数据分析模型、高级统计分析功能等&#xff0c;软件易学且功能非常强大&#xff0c;可以使用SPSS制作图表&#xff0c;例如…

汽车免拆诊断案例 | 2017 款林肯大陆车发动机偶尔无法起动

故障现象 一辆2017款林肯大陆车&#xff0c;搭载2.0T发动机&#xff0c;累计行驶里程约为7.5万km。车主进厂反映&#xff0c;有时按下起动按钮&#xff0c;起动机不工作&#xff0c;发动机无法起动&#xff0c;组合仪表点亮正常&#xff1b;多次按下起动按钮&#xff0c;发动机…

(21)起落架/可伸缩相机支架

文章目录 前言 1 连接到自动驾驶仪 2 通过任务规划器设置 3 其他参数 4 参数说明 前言 Copter 和 Plane 支持可伸缩的起落架/相机支架&#xff0c;由伺服机制激活&#xff08;如 Hobby King 出售的用于copters 的这些&#xff09;。齿轮/支架可以手动缩回或用一个辅助开关…

【 DHT11 温湿度传感器】使用STC89C51读取发送到串口、通过时序图编写C语言

文章目录 DHT11 温湿度传感器概述接线数据传送通讯过程时序图检测模块是否存在 代码实现总结对tmp tmp << 1;的理解对sendByte(datas[0]/10 0x30);的理解 DHT11 温湿度传感器 使用80C51单片机通过读取HDT11温湿度传感的数据&#xff0c;发送到串口。 通过时序图编写相应…

微信小程序数组绑定使用案例(一)

微信小程序数组绑定案例&#xff0c;修改数组中的值 1.Wxml 代码 <view class"list"><view class"item {{item.ischeck?active:}}" wx:for"{{list}}"><view class"title">{{item.name}} <text>({{item.id}…

Redis7(二)Redis持久化双雄

持久化之RDB RDB的持久化方式是在指定时间间隔&#xff0c;执行数据集的时间点快照。也就是在指定的时间间隔将内存中的数据集快照写入磁盘&#xff0c;也就是Snapshot内存快照&#xff0c;它恢复时再将硬盘快照文件直接读回到内存里面。 RDB保存的是dump.rdb文件。 自动触发…

昇思25天学习打卡营第25天|MindNLP ChatGLM-6B StreamChat

配置环节 %%capture captured_output !pip uninstall mindspore -y !pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore2.2.14 !pip install mindnlp !pip install mdtex2html配置国内镜像 !export HF_ENDPOINThttps://hf-mirror.com下载与加载模型 from m…

【计算机视觉】siamfc论文复现实现目标追踪

什么是目标跟踪 使用视频序列第一帧的图像(包括bounding box的位置)&#xff0c;来找出目标出现在后序帧位置的一种方法。 什么是孪生网络结构 孪生网络结构其思想是将一个训练样本(已知类别)和一个测试样本(未知类别)输入到两个CNN(这两个CNN往往是权值共享的)中&#xff0…

代码解读:Diffusion Models中的长宽桶技术(Aspect Ratio Bucketing)

Diffusion Models专栏文章汇总&#xff1a;入门与实战 前言&#xff1a;自从SDXL提出了长宽桶技术之后&#xff0c;彻底解决了不同长宽比的图像输入问题&#xff0c;现在已经成为训练扩散模型必选的方案。这篇博客从代码详细解读如何在模型训练的时候运用长宽桶技术(Aspect Rat…

【机器学习】-- SVM核函数(超详细解读)

支持向量机&#xff08;SVM&#xff09;中的核函数是支持向量机能够处理非线性问题并在高维空间中学习复杂决策边界的关键。核函数在SVM中扮演着将输入特征映射到更高维空间的角色&#xff0c;使得原始特征空间中的非线性问题在高维空间中变得线性可分。 一、SVM是什么&#x…

时间卷积网络(TCN):序列建模的强大工具(附Pytorch网络模型代码)

这里写目录标题 1. 引言2. TCN的核心特性2.1 序列建模任务描述2.2 因果卷积2.3 扩张卷积2.4 残差连接 3. TCN的网络结构4. TCN vs RNN5. TCN的应用TCN的实现 1. 引言 引用自&#xff1a;Bai S, Kolter J Z, Koltun V. An empirical evaluation of generic convolutional and re…

Linux系统之部署扫雷小游戏(三)

Linux系统之部署扫雷小游戏(三) 一、小游戏介绍1.1 小游戏简介1.2 项目预览二、本次实践介绍2.1 本地环境规划2.2 本次实践介绍三、检查本地环境3.1 检查系统版本3.2 检查系统内核版本3.3 检查软件源四、安装Apache24.1 安装Apache2软件4.2 启动apache2服务4.3 查看apache2服…

大厂生产解决方案:泳道隔离机制

更多大厂面试内容可见 -> http://11come.cn 大厂生产解决方案&#xff1a;泳道隔离机制 背景 在公司中&#xff0c;由于项目多、开发人员多&#xff0c;一般会有多套测试环境&#xff08;可以理解为多个服务器&#xff09;&#xff0c;同一套服务会在多套测试环境中都部署…