240707_昇思学习打卡-Day19-基于MindSpore通过GPT实现情感分类

240707_昇思学习打卡-Day19-基于MindSpore通过GPT实现情感分类

今天基于GPT实现一个情感分类的功能,假设已经安装好了MindSpore环境。

# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com

导包导包

import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nnfrom mindnlp.dataset import load_datasetfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
# 加载IMDb数据集
imdb_ds = load_dataset('imdb', split=['train', 'test'])
# 获取训练集
imdb_train = imdb_ds['train']
# 获取测试集
imdb_test = imdb_ds['test']
# 调用get_dataset_size方法来获取训练集的大小
imdb_train.get_dataset_size()
import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):"""处理数据集,使用tokenizer对文本进行编码,并根据指定的batch大小和序列长度组织数据。参数:- dataset: 需要处理的数据集,包含文本和标签。- tokenizer: 用于将文本转换为token序列的tokenizer。- max_seq_len: 最大序列长度,超过该长度的序列将被截断。- batch_size: 打包数据的批次大小。- shuffle: 是否在处理数据集前对其进行洗牌。返回:- 经过tokenization和batch处理后的数据集。"""# 判断是否在Ascend设备上运行is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):"""对文本进行tokenization,并返回input_ids和attention_mask。参数:- text: 需要被tokenize的文本。返回:- tokenize后的input_ids和attention_mask。"""# 根据设备类型选择合适的tokenization方法if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']# 如果需要洗牌,对数据集进行洗牌操作if shuffle:dataset = dataset.shuffle(batch_size)# 对数据集进行tokenization操作# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])# 将标签转换为int32类型dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# 根据设备类型选择合适的批次处理方法# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset
import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):"""处理数据集,使用tokenizer对文本进行编码,并根据指定的batch大小和序列长度组织数据。参数:- dataset: 需要处理的数据集,包含文本和标签。- tokenizer: 用于将文本转换为token序列的tokenizer。- max_seq_len: 最大序列长度,超过该长度的序列将被截断。- batch_size: 打包数据的批次大小。- shuffle: 是否在处理数据集前对其进行洗牌。返回:- 经过tokenization和batch处理后的数据集。"""# 判断是否在Ascend设备上运行is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):"""对文本进行tokenization,并返回input_ids和attention_mask。参数:- text: 需要被tokenize的文本。返回:- tokenize后的input_ids和attention_mask。"""# 根据设备类型选择合适的tokenization方法if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']# 如果需要洗牌,对数据集进行洗牌操作if shuffle:dataset = dataset.shuffle(batch_size)# 对数据集进行tokenization操作# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])# 将标签转换为int32类型dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# 根据设备类型选择合适的批次处理方法# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset
# 导入来自mindnlp库transformers模块中的GPTTokenizer类
from mindnlp.transformers import GPTTokenizer# 初始化GPT分词器,使用预训练的'openai-gpt'模型
# 分词器
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')# 定义一个特殊token字典,包括开始、结束和填充token
special_tokens_dict = {"bos_token": "<bos>",  # 开始符号"eos_token": "<eos>",  # 结束符号"pad_token": "<pad>",  # 填充符号
}# 向分词器中添加特殊token,并返回添加的token数量
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)
# 将训练数据集imdb_train分割成训练集和验证集
# 按照70%训练集和30%验证集的比例进行划分
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])
dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)
# 调用create_tuple_iterator方法创建一个迭代器,并通过next函数获取迭代器的第一个元素
# 这里的目的是为了展示或测试迭代器是否能正常生成数据
# 对于参数和返回值的详细说明,需要查看create_tuple_iterator方法的文档或实现
next(dataset_train.create_tuple_iterator())
# 导入GPT序列分类模型与Adam优化器
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam# 初始化GPT模型用于序列分类任务,设置标签数量为2(二分类任务)
# 设置模型配置并定义训练参数
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
# 配置模型的填充标记ID以匹配分词器设置
model.config.pad_token_id = gpt_tokenizer.pad_token_id
# 调整令牌嵌入层大小以适应新增词汇量
model.resize_token_embeddings(model.config.vocab_size + 3)# 使用2e-5的学习率初始化Adam优化器
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)# 初始化准确度指标来评估模型性能
metric = Accuracy()# 定义回调函数以在训练过程中保存检查点
# 定义保存检查点的回调函数
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
# 初始化最佳模型回调函数以保存表现最优的模型
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)# 初始化训练器,包括模型、训练数据集、评估数据集、性能指标、优化器以及回调函数
trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)
# 导入GPT序列分类模型与Adam优化器
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam# 初始化GPT模型用于序列分类任务,设置标签数量为2(二分类任务)
# 设置模型配置并定义训练参数
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
# 配置模型的填充标记ID以匹配分词器设置
model.config.pad_token_id = gpt_tokenizer.pad_token_id
# 调整令牌嵌入层大小以适应新增词汇量
model.resize_token_embeddings(model.config.vocab_size + 3)# 使用2e-5的学习率初始化Adam优化器
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)# 初始化准确度指标来评估模型性能
metric = Accuracy()# 定义回调函数以在训练过程中保存检查点
# 定义保存检查点的回调函数
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
# 初始化最佳模型回调函数以保存表现最优的模型
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)# 初始化训练器,包括模型、训练数据集、评估数据集、性能指标、优化器以及回调函数
trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)
# 执行模型训练
trainer.run(tgt_columns="labels")
# 初始化Evaluator对象,用于评估模型性能
# 参数说明:
# network: 待评估的模型
# eval_dataset: 用于评估的测试数据集
# metrics: 评估指标
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)# 执行模型评估,指定目标列作为评估标签
# 该步骤将计算模型在测试数据集上的指定评估指标
evaluator.run(tgt_columns="labels")

打卡图片:

image-20240707192022631

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/42433.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mysql数据库索引、事务相关知识

索引 索引是一种特殊的文件&#xff0c;包含着对数据表里所有记录的引用指针。可以对表中的一列或多列创建索引&#xff0c; 并指定索引的类型&#xff0c;各类索引有各自的数据结构实现 查看索引 show index from 表名;创建索引对于非主键、非唯一约束、非外键的字段&#…

基于贝叶斯优化的卷积神经网络-循环神经网络混合模型的的模拟股票时间序列预测(MATLAB R2021B)

将机器学习和深度学习方法运用到股市分析中, 不仅具有一定的理论价值, 也具有一定的实践价值。从理论价值上讲, 中国的量化投资技术&#xff08;投资观念、方法与决策等&#xff09;还不够成熟, 尚处在起步阶段, 能够将量化投资技术运用到投资决策中的公司寥寥无几。目前, 国内…

端口被占用,使用小黑框查杀

netstat -ano &#xff08;查看目前所有被占的端口&#xff09; netstat -ano|findstr " 8080" 查一下目前被占用的端口号 &#xff0c;目前我要查的端口号是&#xff1a;8080&#xff0c;注意 后面打8080的时候&#xff0c;要有空格&#xff0c;要不然报错 **task…

Zabbix 的部署和自定义监控内容

前言 一个完整的项目的业务架构包括 客户端 -> 防火墙 -> 负载均衡层&#xff08;四层、七层 LVS/HAProxy/nginx&#xff09; -> Web缓存/应用层&#xff08;nginx、tomcat&#xff09; -> 业务逻辑层(php/java动态应用服务) -> 数据缓存/持久层&#xff08;r…

操作系统智能助手OS Copilot评测报告

背景 如果不是朋友告知&#xff0c;我还不知道阿里云推出了【操作系统智能助手OS Copilot】这样一款产品。 我做系统运维的工作还是挺多的&#xff0c;知道系统运维工作的一些痛点&#xff1b;例如&#xff1a; Linux命令繁杂&#xff0c;想全部记住不太可能&#xff0c;多数…

软件测试《用例篇》

测试用例 测试用例的概念 测试用例是被测试人员向被测试系统发起的一组集合&#xff0c;包括测试环境&#xff0c;操作步骤&#xff0c;预期结果&#xff0c;测试数据等 使用测试用例的好处 使用测试用例进行测试的好处主要有&#xff1a;提高测试效率&#xff0c;降低测试的重…

WAWA鱼曲折的大学四年回忆录

声明&#xff1a;本文内容纯属个人主观臆断&#xff0c;如与事实不符&#xff0c;请参考事实 前言&#xff1a; 早想写一下大学四年的总结了&#xff0c;但总是感觉无从下手&#xff0c;不知道从哪里开始写&#xff0c;通过这篇文章主要想做一个记录&#xff0c;并从现在的认…

中国智能制造装备产业发展机遇

导语 大家好&#xff0c;我是社长&#xff0c;老K。专注分享智能制造和智能仓储物流等内容。 新书《智能物流系统构成与技术实践》 更多的海量【智能制造】相关资料&#xff0c;请到智能制造online知识星球自行下载。 随着全球第四次工业革命的浪潮&#xff0c;智能制造装备产业…

C++ 函数高级——函数的默认参数

函数默认参数 在C中&#xff0c;函数的形参列表中的形参是可以有默认值的 语法&#xff1a;返回值类型 函数名 &#xff08;参数 默认值&#xff09;{ } 示例&#xff1a; 正确代码&#xff1a; 运行结果&#xff1a;

开源六轴协作机械臂myCobot 280接入GPT4大模型!实现更复杂和智能化的任务

本文已经或者同济子豪兄作者授权对文章进行编辑和转载 引言 随着人工智能和机器人技术的快速发展&#xff0c;机械臂在工业、医疗和服务业等领域的应用越来越广泛。通过结合大模型和多模态AI&#xff0c;机械臂能够实现更加复杂和智能化的任务&#xff0c;提升了人机协作的效率…

盘点当下智能体应用开发的几种形态

现在多智能体系统开发的关注度越来越高了&#xff0c;不光在开发者的圈子热度很高&#xff0c;很多职场人士&#xff0c;甚至是小白也参与其中&#xff0c;因为现在的门槛越来越低了&#xff0c;尤其是&#xff0c;最近特别火的扣子&#xff08;coze&#xff09;和百度的appbui…

【TB作品】51单片机 Proteus仿真00016 乒乓球游戏机

课题任务 本课题任务 (联机乒乓球游戏)如下图所示: 同步显示 oo 8个LED ooooo oo ooooo 8个LED 单片机 单片机 按键 主机 从机 按键 设计题目:两机联机乒乓球游戏 图1课题任务示意图 具体说明: 共有两个单片机,每个单片机接8个LED和1 个按键,两个单片机使用串口连接。 (2)单片机…

数据结构学生信息顺序表

主程序 #include "fun.h" int main(int argc, const char *argv[]) { seq_p Screate_seq(); stu data; printf("请问要输入几个学生的数据&#xff1a;"); int n; scanf("%d",&n); while(n--) { prin…

MySQL Binlog详解:提升数据库可靠性的核心技术

文章目录 1. 引言1.1 什么是MySQL Bin Log&#xff1f;1.2 Bin Log的作用和应用场景 2. Bin Log的基本概念2.1 Bin Log的工作原理2.2 Bin Log的三种格式 3. 配置与管理Bin Log3.1 启用Bin Log3.2 配置Bin Log参数3.3 管理Bin Log文件3.4 查看Bin Log内容3.5 使用mysqlbinlog工具…

STM32崩溃问题排查

文章目录 前言1. 问题说明2. STM32&#xff08;Cortex M4内核&#xff09;的寄存器3. 崩溃问题分析3.1 崩溃信息的来源是哪里&#xff1f;3.2 崩溃信息中的每个关键字代表的含义3.3 利用崩溃信息去查找造成崩溃的点3.4 keil5中怎么根据地址找到问题点3.5 keil5上编译时怎么输出…

【NTN 卫星通信】Starlink基于终端用户的测量以及测试概述

1 概述 收集了一些starlink的资料&#xff0c;是基于终端侧部署在野外的一些测试以及测量结果。 2 低地球轨道卫星网络概述 低地球轨道卫星网络(lsn)被认为是即将到来的6G中真正实现全球覆盖的关键基础设施。本文介绍了我们对Starlink端到端网络特征的初步测量结果和观测结果&…

STM32-ADC+DMA

本内容基于江协科技STM32视频学习之后整理而得。 文章目录 1. ADC模拟-数字转换器1.1 ADC模拟-数字转换器1.2 逐次逼近型ADC1.3 ADC框图1.4 ADC基本结构1.5 输入通道1.6 规则组的转换模式1.6.1 单次转换&#xff0c;非扫描模式1.6.2 连续转换&#xff0c;非扫描模式1.6.3 单次…

Tabu Search — 温和介绍

Tabu Search — 温和介绍 目录 Tabu Search — 温和介绍 一、说明 二、什么是禁忌搜索以及我可以在哪里使用它&#xff1f; 三、禁忌搜索原则 四、短期记忆和积极搜索&#xff1a; 五、举例时间 六、结论&#xff1a; 七、参考&#xff1a; 一、说明 最近&#xff0c;我参加了…

在DevEco运行typeScript代码,全网详细解决执行Set-ExecutionPolicy RemoteSigned报出的错

目录 基本思路 网络推荐 本人实践 如下操作,报错: 基本思路 //在DevEco运行typeScript代码 /** * 1.保证node -v出现版本,若没有,配置环境变量(此电脑-属性-高级系统变量配置-path-粘贴路径);DevEco在local.properties中可看到当前nodejs的路径 * 2.npm install …

海外仓一件代发功能自动化:海外仓WMS系统配置方法

根据数据显示&#xff0c;2014-2019年短短几年之间&#xff0c;跨境电商销售总额增长了160%以上。这为跨境电商商家和海外仓&#xff0c;国际物流等服务端企业都提供了巨大的发展机遇。 然而&#xff0c;作为海外仓&#xff0c;要想服务好跨境电商&#xff0c;仓库作业的每一个…