昇思25天学习打卡营第7天 | 基于MindSpore的GPT2文本摘要

本次打卡基于gpt2的文本摘要

数据加载及预处理

from mindnlp.utils import http_get# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')from mindspore.dataset import TextFileDataset# load dataset
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)import json
import numpy as np# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def merge_and_pad(article, summary):# tokenization# pad to max_seq_length, only truncate the articletokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len)return tokenized['input_ids'], tokenized['input_ids']dataset = dataset.map(read_map, 'text', ['article', 'summary'])# change column names to input_ids and labels for the following trainingdataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])dataset = dataset.batch(batch_size)if shuffle:dataset = dataset.shuffle(batch_size)return datasetfrom mindnlp.transformers import BertTokenizer# We use BertTokenizer for tokenizing chinese context.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)

 模型构建¶

from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModelclass GPT2ForSummarization(GPT2LMHeadModel):def construct(self,input_ids = None,attention_mask = None,labels = None,):outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :]shift_labels = labels[..., 1:]# Flatten the tokensloss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)return lossfrom mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateScheduleclass LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate."""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__()self.learning_rate = learning_rateself.num_warmup_steps = num_warmup_stepsself.num_training_steps = num_training_stepsdef construct(self, global_step):if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_ratereturn ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_ratenum_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4num_training_steps = num_epochs * train_dataset.get_dataset_size()from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModelconfig = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallbackckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',epochs=1, keep_checkpoint_max=2)trainer = Trainer(network=model, train_dataset=train_dataset,epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 开启混合精度trainer.run(tgt_columns="labels")

 

 结论

gpt2相较bert等模型,在文本识别、文本摘要、命名体识别中有着优秀的表现,但其模型规模相对较大,训练时间较长,打卡中展示的没有完成训练,这里需要更好的gpu来辅助训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/45449.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

以太坊(以太坊solidity合约)

以太坊&#xff08;以太坊solidity合约&#xff09; 1&#xff0c;以太坊2&#xff0c;开发名词解释&#xff08;1&#xff09;钱包&#xff08;2&#xff09;Solidity&#xff08;3&#xff09;Ether&#xff08;以太币&#xff09;&#xff08;4&#xff09;Truffle&#xff…

Redis 7.x 系列【23】哨兵模式

有道无术&#xff0c;术尚可求&#xff0c;有术无道&#xff0c;止于术。 本系列Redis 版本 7.2.5 源码地址&#xff1a;https://gitee.com/pearl-organization/study-redis-demo 文章目录 1. 概述2. 工作原理2.1 监控2.2 标记下线2.3 哨兵领袖2.4 新的主节点2.5 通知更新 3. …

请求响应(后端必备)

一、请求 1.简单参数 原始方式&#xff1a; 在原始的web程序中&#xff0c;获取请求参数&#xff0c;需要通过HttpServletRequest对象手动获取 RequestMapping("/simpleParam")public String simpleParam(HttpServletRequest request){String name request.getP…

什么叫价内期权?直接带你了解期权价内期权怎么使用?!

今天带你了解什么叫价内期权&#xff1f;直接带你了解期权价内期权怎么使用&#xff1f;&#xff01;价内期权是具有内在价值的期权。期权持有人行权时&#xff0c;对看涨期权而言&#xff0c;行权价格低于标的证券结算价格&#xff1b;对看跌期权而言&#xff0c;标的证券结算…

js 请求blob:https:// 图片

方式1 def get_file_content_chrome(driver, uri):result driver.execute_async_script("""var uri arguments[0];var callback arguments[1];var toBase64 function(buffer){for(var r,nnew Uint8Array(buffer),tn.length,anew Uint8Array(4*Math.ceil(t/…

前端Vue组件化实践:自定义加载组件的探索与应用

在前端开发领域&#xff0c;随着业务逻辑复杂度的提升和系统规模的不断扩大&#xff0c;传统的开发方式逐渐暴露出效率低下、维护困难等问题。为了解决这些挑战&#xff0c;组件化开发作为一种高效、灵活的开发模式&#xff0c;受到了越来越多开发者的青睐。本文将结合实践&…

Java基础及进阶

JAVA特性 基础语法 一、Java程序的命令行工具 二、final、finally、finalize 三、继承 class 父类 { //代码 }class 子类 extends 父类 { //代码 }四、Vector、ArrayList、LinkedList 五、原始数据类型和包装类 六、接口和抽象类 JAVA进阶 Java引用队列 Object counter ne…

PostgreSQL行级安全策略探究

前言 最近和朋友讨论oracle行级安全策略(VPD)时&#xff0c;查看了下官方文档&#xff0c;看起来VPD的原理是针对应用了Oracle行级安全策略的表、视图或同义词发出的 SQL 语句动态添加where子句。通俗理解就是将行级安全策略动态添加为where 条件。那么PG中的行级安全策略是怎…

使用UDP通信接收与发送Mavlink2.0协议心跳包完整示例

1.克隆mavlink源码 https://github.com/mavlink/mavlink.git 2.进入mavlink目录,安装依赖 python3 -m pip install -r pymavlink/requirements.txt 3.生成Mavlink的C头文件 mavlink % python3 -m pymavlink.tools.mavgen --lang=C --wire-protocol=2.0 --output=generated…

1-5岁幼儿胼胝体的表面形态测量

摘要 胼胝体(CC)是大脑中的一个大型白质纤维束&#xff0c;它参与各种认知、感觉和运动过程。尽管CC与多种发育和精神疾病有关&#xff0c;但关于这一结构的正常发育(特别是在幼儿阶段)还有很多待解开的谜团。虽然早期文献中报道了性别二态性&#xff0c;但这些研究的观察结果…

【Linux网络】select{理解认识select/select与多线程多进程/认识select函数/使用select开发并发echo服务器}

文章目录 0.理解/认识回顾回调函数select/pollread与直接使用 read 的效率差异 1.认识selectselect/多线程&#xff08;Multi-threading&#xff09;/多进程&#xff08;Multi-processing&#xff09;select函数socket就绪条件select的特点总结 2.select下echo服务器封装套接字…

C++ 类和对象 赋值运算符重载

前言&#xff1a; 在上文我们知道数据类型分为自定义类型和内置类型&#xff0c;当我想用内置类型比较大小是非常容易的但是在C中成员变量都是在类(自定义类型)里面的&#xff0c;那我想给类比较大小那该怎么办呢&#xff1f;这时候运算符重载就出现了 一 运算符重载概念&…

安全防御:防火墙基本模块

目录 一、接口 1.1 物理接口 1.2 虚拟接口 二、区域 三、模式 3.1 路由模式 3.2 透明模式 3.3 旁路检测模式 3.4 混合模式 四、安全策略 五、防火墙的状态检测和会话表技术 一、接口 1.1 物理接口 三层口 --- 可以配置IP地址的接口 二层口&#xff1a; 普通二层…

车载终端_RTK定位|4路摄像头|驾驶辅助系统ADAS定制方案

现代车辆管理行业的发展趋势逐渐向智能化和高效化方向发展&#xff0c;车载终端成为关键的工具之一。在这个背景下&#xff0c;一款特别为车队管理行业设计的车载终端应运而生。该车载终端采用8寸多点触控电容屏&#xff0c;搭载联发科四核处理器&#xff0c;主频2.0GHz&#x…

如何安装node.js

Node.js Node.js 是一个基于 Chrome V8 引擎的 JavaScript 运行时环境。 主要特点和优势&#xff1a; 非阻塞 I/O 和事件驱动&#xff1a;能够高效处理大量并发连接&#xff0c;非常适合构建高并发的网络应用&#xff0c;如 Web 服务器、实时聊天应用等。 例如&#xff0c;在…

网络安全——防御(防火墙)带宽以及双机热备实验

12&#xff0c;对现有网络进行改造升级&#xff0c;将当个防火墙组网改成双机热备的组网形式&#xff0c;做负载分担模式&#xff0c;游客区和DMZ区走FW3&#xff0c;生产区和办公区的流量走FW1 13&#xff0c;办公区上网用户限制流量不超过100M&#xff0c;其中销售部人员在其…

排序相关算法--3.选择排序

之前涉及的堆排序就是选择排序的一种&#xff0c;先进行选择。 基本选择排序&#xff1a; 最简单&#xff0c;也是最没用的排序算法&#xff0c;时间复杂度高并且还是不稳定的排序方法&#xff0c;项目中很少会用。 过程&#xff1a; 在一个长度为 N 的无序数组中&#xff0c;…

智慧公厕系统助力城市卫生管理

在当今快速发展的城市环境中&#xff0c;城市卫生管理面临着诸多挑战。其中&#xff0c;公共厕所的管理一直是一个重要但又常被忽视的环节。然而&#xff0c;随着科技的不断进步&#xff0c;智慧公厕系统的出现为城市卫生管理带来了全新的解决方案&#xff0c;成为提升城市品质…

OrangePi AIpro 浅上手

OrangePi AIpro 浅上手 OrangePi AIpro 介绍开发版介绍硬件规格顶层视图和底层视图接口详情图 玩转 OrangePi AIPro烧录镜像串口调试连接 WiFissh 连接配置下载源 使用感受优点&#xff1a;缺点或需注意的点&#xff1a; OrangePi AIpro 介绍 开发版介绍 OrangePi AIpro是香橙…