昇思MindSpore学习总结十六 —— 基于MindSpore的GPT2文本摘要

1、mindnlp 版本要求

!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp

2、数据集加载与处理

2.1 数据集加载

 本次实验使用的是nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

from mindspore.dataset import TextFileDataset  # 从mindspore.dataset模块中导入TextFileDataset类# load dataset  # 加载数据集
dataset = TextFileDataset(str(path), shuffle=False)  # 创建一个TextFileDataset实例,参数是文件路径(path)转换成字符串格式,shuffle=False表示不打乱数据顺序
dataset.get_dataset_size()  # 获取数据集的大小,即数据集中样本的数量

# split into training and testing dataset  # 将数据集分割为训练集和测试集
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)  # 将数据集按比例[0.9, 0.1]分割为训练集和测试集,randomize=False表示不随机打乱数据

 2.2 数据预处理

import json  # 导入json模块,用于处理JSON数据
import numpy as np  # 导入numpy模块,并简写为np,用于处理数组和矩阵# preprocess dataset  # 预处理数据集
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):# 定义一个嵌套函数read_map,用于读取并解析JSON文本数据def read_map(text):data = json.loads(text.tobytes())  # 将文本数据转换为字节后用json.loads解析为Python字典return np.array(data['article']), np.array(data['summarization'])  # 返回文章和摘要的numpy数组# 定义一个嵌套函数merge_and_pad,用于合并并填充数据def merge_and_pad(article, summary):# tokenization  # 进行分词操作# pad to max_seq_length, only truncate the article  # 填充到最大序列长度,仅截断文章部分tokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len)  # 使用tokenizer对文章和摘要进行分词,填充到最大长度,仅截断文章部分return tokenized['input_ids'], tokenized['input_ids']  # 返回分词后的输入ID(注意:这里的input_ids和labels是相同的)dataset = dataset.map(read_map, 'text', ['article', 'summary'])  # 使用read_map函数对数据集进行映射,提取文章和摘要# change column names to input_ids and labels for the following training  # 更改列名为input_ids和labels,以便后续训练dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])  # 使用merge_and_pad函数对数据进行映射,生成input_ids和labelsdataset = dataset.batch(batch_size)  # 将数据集按批次大小进行分批处理if shuffle:dataset = dataset.shuffle(batch_size)  # 如果shuffle为True,则对批次进行随机打乱return dataset  # 返回预处理后的数据集

 因GPT2无中文的tokenizer,我们使用BertTokenizer替代。

from mindnlp.transformers import BertTokenizer  # 从mindnlp.transformers模块中导入BertTokenizer类# We use BertTokenizer for tokenizing Chinese context.  # 我们使用BertTokenizer对中文内容进行分词
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # 使用预训练的'bert-base-chinese'模型初始化BertTokenizer
len(tokenizer)  # 获取tokenizer的词汇表大小

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)  # 使用process_dataset函数对训练数据集进行预处理,传入参数包括训练数据集、分词器和批次大小为4
next(train_dataset.create_tuple_iterator())  # 创建一个tuple迭代器并获取其第一个元素

 3、模型构建

3.1 构建GPT2ForSummarization模型,注意shift right的操作。

from mindspore import ops  # 从mindspore模块导入ops操作
from mindnlp.transformers import GPT2LMHeadModel  # 从mindnlp.transformers模块中导入GPT2LMHeadModel类# 定义一个用于摘要生成的GPT2模型类,继承自GPT2LMHeadModel
class GPT2ForSummarization(GPT2LMHeadModel):# 定义模型的构造函数def construct(self,input_ids=None,  # 输入IDattention_mask=None,  # 注意力掩码labels=None,  # 标签):# 调用父类的construct方法,获取模型输出outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :]  # 移动logits,使其与shift_labels对齐shift_labels = labels[..., 1:]  # 移动标签,使其与shift_logits对齐# Flatten the tokens  # 将tokens展平loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)  # 计算交叉熵损失,忽略填充的tokenreturn loss  # 返回计算的损失

 3.2 动态学习率

from mindspore import ops  # 从mindspore模块导入ops操作
from mindspore.nn.learning_rate_schedule import LearningRateSchedule  # 从mindspore.nn.learning_rate_schedule模块导入LearningRateSchedule类# 定义一个线性学习率衰减与热身相结合的学习率调度类,继承自LearningRateSchedule
class LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate.  # 热身-衰减学习率。"""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__()  # 调用父类的构造函数self.learning_rate = learning_rate  # 初始化学习率self.num_warmup_steps = num_warmup_steps  # 初始化热身步数self.num_training_steps = num_training_steps  # 初始化训练步数# 定义构造函数def construct(self, global_step):# 如果当前步数小于热身步数if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate  # 线性增加学习率# 否则,学习率进行线性衰减return ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_rate  # 计算并返回衰减后的学习率

 4、模型训练

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4num_training_steps = num_epochs * train_dataset.get_dataset_size()
from mindspore import nn  # 从mindspore模块导入nn(神经网络)模块
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel  # 从mindnlp.transformers模块导入GPT2Config和GPT2LMHeadModel类# 配置GPT2模型的配置
config = GPT2Config(vocab_size=len(tokenizer))  # 创建GPT2配置实例,设置词汇表大小为tokenizer的长度
model = GPT2ForSummarization(config)  # 使用配置实例创建一个GPT2ForSummarization模型# 创建学习率调度器
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)  # 创建线性热身-衰减学习率调度器# 创建优化器
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)  # 使用AdamWeightDecay优化器,并传入模型的可训练参数和学习率调度器
# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))

from mindnlp._legacy.engine import Trainer  # 从mindnlp._legacy.engine模块导入Trainer类
from mindnlp._legacy.engine.callbacks import CheckpointCallback  # 从mindnlp._legacy.engine.callbacks模块导入CheckpointCallback类# 创建一个CheckpointCallback实例,用于保存检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint',  # 检查点保存路径ckpt_name='gpt2_summarization',  # 检查点文件名epochs=1,  # 每个epoch保存一次检查点keep_checkpoint_max=2  # 最多保留两个检查点
)# 创建一个Trainer实例,用于训练模型
trainer = Trainer(network=model,  # 要训练的模型train_dataset=train_dataset,  # 训练数据集epochs=1,  # 训练的epoch数optimizer=optimizer,  # 优化器callbacks=ckpoint_cb  # 回调函数,包括检查点回调
)trainer.set_amp(level='O1')  # 开启混合精度训练,级别设置为'O1'

下面这段代码,运行时间较长,最好选择较高算力。 

trainer.run(tgt_columns="labels")  # 运行训练器,指定目标列为“labels”

配置不够,训练时间太长。 

5、模型推理

数据处理,将向量数据变为中文数据

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):# 定义一个嵌套函数read_map,用于读取并解析JSON文本数据def read_map(text):data = json.loads(text.tobytes())  # 将文本数据转换为字节后用json.loads解析为Python字典return np.array(data['article']), np.array(data['summarization'])  # 返回文章和摘要的numpy数组# 定义一个嵌套函数pad,用于对文章进行分词和填充def pad(article):tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)  # 对文章进行分词,截断至最大长度减去摘要长度return tokenized['input_ids']  # 返回分词后的输入IDdataset = dataset.map(read_map, 'text', ['article', 'summary'])  # 使用read_map函数对数据集进行映射,提取文章和摘要dataset = dataset.map(pad, 'article', ['input_ids'])  # 使用pad函数对文章进行分词和填充,生成input_idsdataset = dataset.batch(batch_size)  # 将数据集按批次大小进行分批处理return dataset  # 返回预处理后的数据集
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
# 创建一个tuple迭代器并获取其第一个元素,以NumPy数组的形式输出,并打印出来
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)  # 从预训练的检查点加载模型
model.set_train(False)  # 设置模型为评估模式(非训练模式)
model.config.eos_token_id = model.config.sep_token_id  # 设置模型的eos_token_id为sep_token_id
i = 0  # 初始化计数器为0# 遍历测试数据集的迭代器,获取输入ID和原始摘要
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():# 使用模型生成新的摘要,参数包括最大新生成的token数量、束搜索的束数、不重复的ngram大小output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)# 将生成的ID转换为文本output_text = tokenizer.decode(output_ids[0].tolist())print(output_text)  # 打印生成的摘要文本i += 1  # 计数器加1if i == 1:  # 如果计数器达到1break  # 跳出循环,仅生成并打印一个摘要

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/48501.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机视觉篇5 图像的位置--边框

计算机视觉篇4 图像的位置--边框 在训练集中&#xff0c;我们将每个锚框视为一个训练样本。 为了训练目标检测模型&#xff0c;我们需要每个锚框的类别&#xff08;class&#xff09;和偏移量&#xff08;offset&#xff09;标签&#xff0c;其中前者是与锚框相关的对象的类别…

Python网络编程:socket模块的入门与实践

Socket模块的基本概念 创建Socket 在Python中&#xff0c;可以使用socket模块创建Socket对象&#xff1a; import socket# 创建一个TCP/IP socket s socket.socket(socket.AF_INET, socket.SOCK_STREAM) 地址族与Socket类型 socket.AF_INET&#xff1a;IPv4地址族 socket…

#systemverilog# 之 event region 和 timeslot 仿真调度(十)高层次视角看仿真调度事件的发生

仿真调度系列文章,已经编写10篇,写到这里,相比大家都已经对VCS仿真工具的运行机制,有了大体了解。学无止境,而且真正的仿真调度行为控制,是每个EDA厂商自己产品的高度机密。言外之意,我们要抱着谦虚的态度说:我们只是懂了一点点。 一 RTL仿真中的竞争现象 在实际仿真…

【信号分解】基于极点对称模态分解ESMD实现信号分解附Matlab代码

% 导入信号数据 load(‘signal_data.mat’); % 假设信号数据保存在signal_data.mat文件中 % 构建ESMD函数 esmd (x) esmd_decomposition(x); % 对信号进行ESMD分解 components esmd(signal_data); % 显示分解结果 figure; subplot(length(components)1,1,1); plot(signal_…

lua 写一个 不同时区之间转换日期和时间 函数

这个函数用于调整时间戳以适应不同的时区。它接受五个参数&#xff1a;format、timeStamp、dontFixForTimeOffset、currentServerTimeZone和showLog。返回 os.date&#xff0c;可以转化成指定格式的年月日时间 ### 功能 该函数的主要功能是根据给定的时区偏移量调整时间戳&am…

springSecurity学习之springSecurity过滤web请求

过滤web请求 在spring中存在一个DelegatingFilterProxy&#xff0c;是一种特殊的Filter&#xff0c;主要任务就是将工作委托给Filter实现类 使用EnableWebSecurity注解时引入FilterChainProxy Bean(name AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME) pub…

使用Amazon Web Services Lambda把天气预报推送到微信

最近北京开始下雨&#xff0c;开始和同事打赌几点能够雨停&#xff0c;虽然Iphone已经提供了实时天气&#xff0c;但是还是想用国内的API试试看看是不是更加准确些。 以下是我使用的服务&#xff1a; 地图SDK/APP获取 经纬度彩云天气API 通过地理位置获取天气信息Lambda 作为…

关于Mysql的面试题(实时更新中~)

一、主键约束与“not null unique”区别 1、作为Primary Key的域/域组不能为null&#xff0c;而Unique Key可以。 2、在一个表中只能有一个Primary Key&#xff0c;而多个Unique Key可以同时存在。unique not null 可以 将表的一列或多列定义为唯一性属性&#xff0c;而prima…

buu做题(6)

目录 [GWCTF 2019]我有一个数据库 [WUSTCTF2020]朴实无华 [GWCTF 2019]我有一个数据库 什么都没有, 尝试用dirsearch扫一下目录 可以扫到一个 /phpmyadmin 可以直接进入到数据库里面 但里面没什么东西 可以看到它的版本不是最新的, 搜一下相关的漏洞 phpMyAdmin 4.8.1后台文…

go关于string与[]byte再学深一点

目标&#xff1a;充分理解string与[]bytes零拷贝转换的实现 先回顾下string与[]byte的基本知识 1. string与[]byte的数据结构 reflect包中关于字符串的数据结构 // StringHeader is the runtime representation of a string.type StringHeader struct {Data uintptrLen int} …

ClickHouse 入门(一)【基本特点、数据类型与表引擎】

前言 今天开始学习 ClickHouse &#xff0c;一种 OLAP 数据库&#xff0c;实时数仓中用到的比较多&#xff1b; 1、ClickHouse 入门 ClickHouse 是俄罗斯的 Yandex&#xff08;搜索引擎公司&#xff09;在 2016 年开源的列式存储数据库&#xff08;HBase 也是列式存储&#xf…

Linux C服务需要在A服务和B服务都启动成功后才能启动

需求 C服务需要在A服务和B服务都启动成功后才能启动 服务编号服务名服务Anginx.service服务Bmashang.service服务Credis.service 实验 如果您想要 redis.service 在 nginx.service 和 mashang.service 都成功启动后才能启动&#xff0c;那么需要在 redis.service 的服务单元…

67| 上海市互联网行业招聘数据集的构建与可视化分析

一、数据集介绍 数据集概述 数据集文件可见 上海市互联网行业招聘数据集(

赞扬的10条原则

来自刘澜《领导力就是说对十句话》 1&#xff09; 赞扬的第一个原则&#xff0c;是赞扬要具体。不要只是说“你做得真好”&#xff0c;而 是要具体说你到底做了什么&#xff0c;怎么做得好。 2&#xff09;赞扬的第二个原则&#xff0c;是赞扬行为&#xff0c;而非品质。你赞扬…

某宝同款度盘不限速后台系统源码

简介&#xff1a; 某宝同款度盘不限速后台系统源码&#xff0c;验证已被我去除&#xff0c;两个后端系统&#xff0c;账号和卡密系统 第一步安装宝塔&#xff0c;部署卡密系统&#xff0c;需要环境php7.4 把源码丢进去&#xff0c;设置php7.4&#xff0c;和伪静态为thinkphp…

山东济南十大杰出人物起名大师颜廷利:影响世界的思想家哲学家教育家

在宇宙的广袤舞台上&#xff0c;各类智者以他们独特的方式揭示着世界的奥秘。数学家们在无尽的符号与公式中穿梭&#xff0c;像探索者般解锁着自然界的深层逻辑。考古学家们则跋涉于古老的土地&#xff0c;用他们的双手拂去岁月的尘埃&#xff0c;让沉睡的历史重见天日。 二十一…

spss是什么软件?spss有什么用

spss是什么软件&#xff1f; SPSS是一款数据统计、分析软件&#xff0c;它由IBM公司出品&#xff0c;这款软件平台提供了文本分析、大量的机器学习算法、数据分析模型、高级统计分析功能等&#xff0c;软件易学且功能非常强大&#xff0c;可以使用SPSS制作图表&#xff0c;例如…

LeetCode 第407场周赛个人题解

目录 100372. 使两个整数相等的位更改次数 原题链接 思路分析 AC代码 100335. 字符串元音游戏 原题链接 思路分析 AC代码 100360. 将 1 移动到末尾的最大操作次数 原题链接 思路分析 AC代码 100329. 使数组等于目标数组所需的最少操作次数 原题链接 思路分析 A…

汽车免拆诊断案例 | 2017 款林肯大陆车发动机偶尔无法起动

故障现象 一辆2017款林肯大陆车&#xff0c;搭载2.0T发动机&#xff0c;累计行驶里程约为7.5万km。车主进厂反映&#xff0c;有时按下起动按钮&#xff0c;起动机不工作&#xff0c;发动机无法起动&#xff0c;组合仪表点亮正常&#xff1b;多次按下起动按钮&#xff0c;发动机…

ubuntu 挂载硬盘,raspberry pi 树莓派,jetson

在Ubuntu中挂载硬盘&#xff0c;首先需要确认硬盘是否被系统识别&#xff0c;可以通过lsblk或fdisk -l命令查看。假设新硬盘被系统识别为/dev/sdb&#xff0c;你可以按照以下步骤进行挂载&#xff1a; 创建一个挂载点&#xff08;例如在/mnt下&#xff09;&#xff1a; sudo …