Mindspore 公开课 - GPT

GPT Task

在模型 finetune 中,需要根据不同的下游任务来处理输入,主要的下游任务可分为以下四类:

  • 分类(Classification):给定一个输入文本,将其分为若干类别中的一类,如情感分类、新闻分类等;
  • 蕴含(Entailment):给定两个输入文本,判断它们之间是否存在蕴含关系(即一个文本是否可以从另一个文本中推断出来);
  • 相似度(Similarity):给定两个输入文本,计算它们之间的相似度得分;
  • 多项选择题(Multiple choice):给定一个问题和若干个答案选项,选择最佳的答案。

在这里插入图片描述
我们使用IMDb数据集,通过finetune GPT进行情感分类任务。

IMDb数据集是一个常用的情感分类数据集,其中包含50,000条影评文本,其中25,000条用作训练数据,另外25,000条用作测试数据。每个样本都有一个二元标签,表示影评的情感是正面还是负面。

import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nnfrom mindnlp import load_dataset
from mindnlp.transforms import PadTransform, GPTTokenizerfrom mindnlp.engine import Trainer, Evaluator
from mindnlp.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp.metrics import Accuracy

数据预处理

# 加载数据集
imdb_train, imdb_test = load_dataset('imdb', shuffle=False)

通过load_dataset加载IMDb数据集后,我们需要对数据进行如下处理:

  • 将文本内容进行分词,并映射为对应的数字索引;
  • 统一序列长度:超过进行截断,不足通过占位符进行补全;
  • 按照分类任务的输入要求,在句首和句末分别添加Start与Extract占位符(此处用 <bos>与 <eos>表示);
  • 批处理。
import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=256, batch_size=32, shuffle=False):"""数据集预处理"""def pad_sample(text):if len(text) + 2 >= max_seq_len:return np.concatenate([np.array([tokenizer.bos_token_id]), text[: max_seq_len-2], np.array([tokenizer.eos_token_id])])else:pad_len = max_seq_len - len(text) - 2return np.concatenate( [np.array([tokenizer.bos_token_id]), text,np.array([tokenizer.eos_token_id]),np.array([tokenizer.pad_token_id] * pad_len)])column_names = ["text", "label"]rename_columns = ["input_ids", "label"]if shuffle:dataset = dataset.shuffle(batch_size)dataset = dataset.map(operations=[tokenizer, pad_sample], input_columns="text")dataset = dataset.rename(input_columns=column_names, output_columns=rename_columns)dataset = dataset.batch(batch_size)return dataset

加载 GPT tokenizer,并添加上述使用到的 <bos>, <eos>, <pad>占位符。

gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')special_tokens_dict = {"bos_token": "<bos>","eos_token": "<eos>","pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

由于IMDb数据集本身不包含验证集,我们手动将其分割为训练和验证两部分,比例取0.7, 0.3。

imdb_train, imdb_val = imdb_train.split([0.7, 0.3])dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

模型训练

同BERT课件中的情感分类任务实现,这里我们依旧使用了混合精度。另外需要注意的一点是,由于在前序数据处理中,我们添加了3个特殊占位符,所以在token embedding中需要调整词典的大小(vocab_size + 3)。

from mindnlp.models import GPTForSequenceClassification
from mindnlp._legacy.amp import auto_mixed_precisionmodel = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)
model = auto_mixed_precision(model, 'O1')loss = nn.CrossEntropyLoss()
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)metric = Accuracy()ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='sentiment_model', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_val, metrics=metric,epochs=3, loss_fn=loss, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb], jit=True)trainer.run(tgt_columns="label")

模型评估

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="label")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/625125.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

报名活动怎么做_小程序创建线上报名活动最详细攻略

报名活动怎么做&#xff1a;一篇让你掌握活动策划与营销的秘籍 在当今社会&#xff0c;无论是线上还是线下&#xff0c;活动已经成为企业营销和品牌推广的重要手段。但是&#xff0c;如何策划一场成功的活动呢&#xff1f;这篇文章将为你揭示活动策划与营销的秘籍&#xff0c;…

政采网调试要求及常见问题解决方法

登录平台软件环境要求&#xff1a; 操作系统&#xff1a;建议Win10及以上&#xff08;Win10-64位专业版 版本号17134纯净安装版本&#xff09; 浏 览 器&#xff1a;IE11浏览器、谷歌120.0.6099.217&#xff08;64位正式版&#xff09;浏览器 必要软件&#xff1a;CA互联互通…

Mindspore 公开课 - BERT

BERT BERT模型本质上是结合了 ELMo 模型与 GPT 模型的优势。 相比于ELMo&#xff0c;BERT仅需改动最后的输出层&#xff0c;而非模型架构&#xff0c;便可以在下游任务中达到很好的效果&#xff1b;相比于GPT&#xff0c;BERT在处理词元表示时考虑到了双向上下文的信息&#…

微服务架构设计核心理论:掌握微服务设计精髓

文章目录 一、微服务与服务治理1、概述2、Two Pizza原则和微服务团队3、主链路规划4、服务治理和微服务生命周期5、微服务架构的网络层搭建6、微服务架构的部署结构7、面试题 二、配置中心1、为什么要配置中心2、配置中心高可用思考 三、服务监控1、业务埋点的技术选型2、用户行…

2023年总结:雄关漫道真如铁,而今迈步从头越,今朝得失

2023年悄然离去&#xff0c;感谢大家的帮助、鼓励和陪伴&#xff0c;感谢家人的理解和支持&#xff0c;祝大家新年快乐&#xff0c;阖家幸福&#xff0c;身体健康。像往常一样&#xff0c;今年也会写一篇年终总结&#xff0c;也是自己的第11篇年终总结&#xff0c;题目就叫《雄…

32 二叉树的定义

之前的通用树结构 采用双亲孩子表示法模型 孩子兄弟表示法模型 引出二叉树 二叉树的定义&#xff1a; 满二叉树和完全二叉树 对此图要有印象 满二叉树一定是完全二叉树&#xff0c;但是完全二叉树不一定是满二叉树 小结

Javaweb之SpringBootWeb案例员工管理分页查询的详细解析

3. 员工管理 完成了部门管理的功能开发之后&#xff0c;我们进入到下一环节员工管理功能的开发。 基于以上原型&#xff0c;我们可以把员工管理功能分为&#xff1a; 分页查询&#xff08;今天完成&#xff09; 带条件的分页查询&#xff08;今天完成&#xff09; 删除员工&…

HNU-算法设计与分析-实验4

算法设计与分析实验4 计科210X 甘晴void 202108010XXX 目录 文章目录 算法设计与分析<br>实验41 回溯算法求解0-1背包问题问题重述想法代码验证算法分析 2 回溯算法实现题5-4运动员最佳配对问题问题重述想法代码验证算法分析 3 分支限界法求解0-1背包问题问题重述想法…

gogs git创建仓库步骤

目录 引言创建仓库clone 仓库推送代码 引言 Gogs 是一款类似GitHub的开源文件/代码管理系统&#xff08;基于Git&#xff09;&#xff0c;Gogs 的目标是打造一个最简单、最快速和最轻松的方式搭建自助 Git 服务。 创建仓库 git中的组织可以把它看成是相关仓库的集合&#xff0c…

DNS主从服务器配置

主从服务器配置&#xff1a; &#xff08;1&#xff09;完全区域传送&#xff1a;复制整个区域文件 #主DNS服务器的配置【主dns服务器的ip地址为192.168.168.129】 #编辑DNS系统配置信息&#xff08;我这里写的增加的信息&#xff0c;源文件里面有很多内容&#xff09; [root…

python中小数据池和编码

嗨喽&#xff0c;大家好呀~这里是爱看美女的茜茜呐 ⼀. 小数据池 在说小数据池之前. 我们先看⼀个概念. 什么是代码块: 根据提示我们从官⽅⽂档找到了这样的说法&#xff1a; A Python program is constructed from code blocks. A block is a piece of Python program text…

大电流直流恒温控制电路

一个电子制冷器控制芯片 实物照片 驱动芯片 使用环境12V直流&#xff0c;电流10A 特此记录 anlog 2024年1月15日

2.1.2 一个关于y=ax+b的故事

跳转到根目录&#xff1a;知行合一&#xff1a;投资篇 已完成&#xff1a; 1、投资&技术   1.1.1 投资-编程基础-numpy   1.1.2 投资-编程基础-pandas   1.2 金融数据处理   1.3 金融数据可视化 2、投资方法论   2.1.1 预期年化收益率   2.1.2 一个关于yaxb的…

【C初阶——内存函数】鹏哥C语言系列文章,基本语法知识全面讲解

本文由睡觉待开机原创&#xff0c;转载请注明出处。 本内容在csdn网站首发 欢迎各位点赞—评论—收藏 如果存在不足之处请评论留言&#xff0c;共同进步&#xff01; 这里写目录标题 1.memcpy使用和模拟实现2.memmove的使用和模拟实现3.memset函数的使用4.memcpy函数的使用 1.m…

linux安装MySQL5.7(安装、开机自启、定时备份)

一、安装步骤 我喜欢安装在/usr/local/mysql目录下 #切换目录 cd /usr/local/ #下载文件 wget https://dev.mysql.com/get/Downloads/MySQL-5.7/mysql-5.7.38-linux-glibc2.12-x86_64.tar.gz #解压文件 tar -zxvf mysql-5.7.38-linux-glibc2.12-x86_64.tar.gz -C /usr/local …

ERP和MES对接的几种接口方式

在数字化工厂的规划建设中&#xff0c;信息化系统的集成&#xff0c;既是重点&#xff0c;但同时也是难点。ERP和MES对接时&#xff0c;ERP主要负责下达生产计划&#xff0c;MES是执行生产计划&#xff0c;二套系统在数据交互时&#xff0c;需要确保基础数据的一致性&#xff0…

SpringBoot源码分析

一&#xff1a;简介 由Pivotal团队提供的全新框架其设计目的是用来简化新Spring应用的初始搭建以及开发过程使用了特定的方式来进行配置快速应用开发领域 二&#xff1a;运行原理以及特点 运行原理&#xff1a; SpringBoot为我们做的自动配置&#xff0c;确实方便快捷&#…

STC8H8K蓝牙智能巡线小车——2. 点亮左右转弯灯与危险报警灯

任务调用示例 RTX 51 TNY 可做多任务调度&#xff0c;API较为简单。 /* 接口API */// 创建任务 extern unsigned char os_create_task (unsigned char task_id); // 结束任务 extern unsigned char os_delete_task (unsigned char task_id);// 等待 extern unsig…

RTKlib操作手册--使用样例数据演示

简介 RTKLIB&#xff08;Real-Time Kinematic Library&#xff09;是一款开源的实时差分全球导航卫星系统&#xff08;GNSS&#xff09;软件库。它旨在提供高精度的位置解算&#xff0c;特别是在实时应用中&#xff0c;如精密农业、测绘、无人机导航等领域。 RTKLIB支持多种G…

目标检测数据集 - 人脸检测数据集下载「包含VOC、COCO、YOLO三种格式」

数据集介绍&#xff1a;行人检测数据集&#xff0c;真实场景高质量图片数据&#xff0c;涉及场景丰富&#xff0c;比如校园行人、街景行人、道路行人、遮挡行人、严重遮挡行人数据&#xff1b;适用实际项目应用&#xff1a;公共场所监控场景下行人检测项目&#xff0c;以及作为…