240707_昇思学习打卡-Day19-基于MindSpore通过GPT实现情感分类

240707_昇思学习打卡-Day19-基于MindSpore通过GPT实现情感分类

今天基于GPT实现一个情感分类的功能,假设已经安装好了MindSpore环境。

# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com

导包导包

import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nnfrom mindnlp.dataset import load_datasetfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
# 加载IMDb数据集
imdb_ds = load_dataset('imdb', split=['train', 'test'])
# 获取训练集
imdb_train = imdb_ds['train']
# 获取测试集
imdb_test = imdb_ds['test']
# 调用get_dataset_size方法来获取训练集的大小
imdb_train.get_dataset_size()
import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):"""处理数据集,使用tokenizer对文本进行编码,并根据指定的batch大小和序列长度组织数据。参数:- dataset: 需要处理的数据集,包含文本和标签。- tokenizer: 用于将文本转换为token序列的tokenizer。- max_seq_len: 最大序列长度,超过该长度的序列将被截断。- batch_size: 打包数据的批次大小。- shuffle: 是否在处理数据集前对其进行洗牌。返回:- 经过tokenization和batch处理后的数据集。"""# 判断是否在Ascend设备上运行is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):"""对文本进行tokenization,并返回input_ids和attention_mask。参数:- text: 需要被tokenize的文本。返回:- tokenize后的input_ids和attention_mask。"""# 根据设备类型选择合适的tokenization方法if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']# 如果需要洗牌,对数据集进行洗牌操作if shuffle:dataset = dataset.shuffle(batch_size)# 对数据集进行tokenization操作# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])# 将标签转换为int32类型dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# 根据设备类型选择合适的批次处理方法# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset
import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):"""处理数据集,使用tokenizer对文本进行编码,并根据指定的batch大小和序列长度组织数据。参数:- dataset: 需要处理的数据集,包含文本和标签。- tokenizer: 用于将文本转换为token序列的tokenizer。- max_seq_len: 最大序列长度,超过该长度的序列将被截断。- batch_size: 打包数据的批次大小。- shuffle: 是否在处理数据集前对其进行洗牌。返回:- 经过tokenization和batch处理后的数据集。"""# 判断是否在Ascend设备上运行is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):"""对文本进行tokenization,并返回input_ids和attention_mask。参数:- text: 需要被tokenize的文本。返回:- tokenize后的input_ids和attention_mask。"""# 根据设备类型选择合适的tokenization方法if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']# 如果需要洗牌,对数据集进行洗牌操作if shuffle:dataset = dataset.shuffle(batch_size)# 对数据集进行tokenization操作# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])# 将标签转换为int32类型dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# 根据设备类型选择合适的批次处理方法# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset
# 导入来自mindnlp库transformers模块中的GPTTokenizer类
from mindnlp.transformers import GPTTokenizer# 初始化GPT分词器,使用预训练的'openai-gpt'模型
# 分词器
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')# 定义一个特殊token字典,包括开始、结束和填充token
special_tokens_dict = {"bos_token": "<bos>",  # 开始符号"eos_token": "<eos>",  # 结束符号"pad_token": "<pad>",  # 填充符号
}# 向分词器中添加特殊token,并返回添加的token数量
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)
# 将训练数据集imdb_train分割成训练集和验证集
# 按照70%训练集和30%验证集的比例进行划分
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])
dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)
# 调用create_tuple_iterator方法创建一个迭代器,并通过next函数获取迭代器的第一个元素
# 这里的目的是为了展示或测试迭代器是否能正常生成数据
# 对于参数和返回值的详细说明,需要查看create_tuple_iterator方法的文档或实现
next(dataset_train.create_tuple_iterator())
# 导入GPT序列分类模型与Adam优化器
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam# 初始化GPT模型用于序列分类任务,设置标签数量为2(二分类任务)
# 设置模型配置并定义训练参数
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
# 配置模型的填充标记ID以匹配分词器设置
model.config.pad_token_id = gpt_tokenizer.pad_token_id
# 调整令牌嵌入层大小以适应新增词汇量
model.resize_token_embeddings(model.config.vocab_size + 3)# 使用2e-5的学习率初始化Adam优化器
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)# 初始化准确度指标来评估模型性能
metric = Accuracy()# 定义回调函数以在训练过程中保存检查点
# 定义保存检查点的回调函数
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
# 初始化最佳模型回调函数以保存表现最优的模型
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)# 初始化训练器,包括模型、训练数据集、评估数据集、性能指标、优化器以及回调函数
trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)
# 导入GPT序列分类模型与Adam优化器
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam# 初始化GPT模型用于序列分类任务,设置标签数量为2(二分类任务)
# 设置模型配置并定义训练参数
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
# 配置模型的填充标记ID以匹配分词器设置
model.config.pad_token_id = gpt_tokenizer.pad_token_id
# 调整令牌嵌入层大小以适应新增词汇量
model.resize_token_embeddings(model.config.vocab_size + 3)# 使用2e-5的学习率初始化Adam优化器
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)# 初始化准确度指标来评估模型性能
metric = Accuracy()# 定义回调函数以在训练过程中保存检查点
# 定义保存检查点的回调函数
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
# 初始化最佳模型回调函数以保存表现最优的模型
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)# 初始化训练器,包括模型、训练数据集、评估数据集、性能指标、优化器以及回调函数
trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)
# 执行模型训练
trainer.run(tgt_columns="labels")
# 初始化Evaluator对象,用于评估模型性能
# 参数说明:
# network: 待评估的模型
# eval_dataset: 用于评估的测试数据集
# metrics: 评估指标
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)# 执行模型评估,指定目标列作为评估标签
# 该步骤将计算模型在测试数据集上的指定评估指标
evaluator.run(tgt_columns="labels")

打卡图片:

image-20240707192022631

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/42433.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mysql数据库索引、事务相关知识

索引 索引是一种特殊的文件&#xff0c;包含着对数据表里所有记录的引用指针。可以对表中的一列或多列创建索引&#xff0c; 并指定索引的类型&#xff0c;各类索引有各自的数据结构实现 查看索引 show index from 表名;创建索引对于非主键、非唯一约束、非外键的字段&#…

基于贝叶斯优化的卷积神经网络-循环神经网络混合模型的的模拟股票时间序列预测(MATLAB R2021B)

将机器学习和深度学习方法运用到股市分析中, 不仅具有一定的理论价值, 也具有一定的实践价值。从理论价值上讲, 中国的量化投资技术&#xff08;投资观念、方法与决策等&#xff09;还不够成熟, 尚处在起步阶段, 能够将量化投资技术运用到投资决策中的公司寥寥无几。目前, 国内…

Spring框架:核心概念与Spring Boot微服务开发指南

引言 Spring框架是一个开源的Java平台,它提供了全面的基础设施支持,用于开发Java应用程序。Spring的核心概念包括依赖注入(DI)、面向切面编程(AOP)和事务管理。随着微服务架构的兴起,Spring Boot作为Spring框架的扩展,提供了一种快速开发独立微服务的方式。本文将详细介…

端口被占用,使用小黑框查杀

netstat -ano &#xff08;查看目前所有被占的端口&#xff09; netstat -ano|findstr " 8080" 查一下目前被占用的端口号 &#xff0c;目前我要查的端口号是&#xff1a;8080&#xff0c;注意 后面打8080的时候&#xff0c;要有空格&#xff0c;要不然报错 **task…

Zabbix 的部署和自定义监控内容

前言 一个完整的项目的业务架构包括 客户端 -> 防火墙 -> 负载均衡层&#xff08;四层、七层 LVS/HAProxy/nginx&#xff09; -> Web缓存/应用层&#xff08;nginx、tomcat&#xff09; -> 业务逻辑层(php/java动态应用服务) -> 数据缓存/持久层&#xff08;r…

C 预处理器

C 预处理器 概述 C预处理器是C语言编译过程中的一个重要环节&#xff0c;它对源代码进行预处理&#xff0c;以扩展和修改代码内容。预处理器的主要功能包括宏定义、文件包含、条件编译等。本文将详细介绍C预处理器的工作原理、功能及其在C编程中的应用。 C预处理器的工作原理…

C# List、LinkedList、Dictionary性能对比

数据结构性能对比 List、LinkedList、Dictionary 1. ArrayList &#xff08;List&#xff1a;前传&#xff09; ArrayList 是一个特殊数组&#xff0c; 通过添加和删除元素就可以动态改变数组的长度。 ArrayList集合相对于数组的优点&#xff1a; 支持…

C 语言总复习

总体上必须清楚的: 1)程序结构是三种: 顺序结构 , 循环结构 (三个循环结构), 选择结构 (if 和 switch) 2)读程序都要从main()入口, 然后从最上面顺序往下读(碰到循环做循环,碰到选择做选择)。 3)计算机的数据在电脑中保存是以二进制的形式. 数据存放的位置就是他的地址。 4…

适合selenium的防自动化检测的方法

Selenium 是一个强大的自动化测试工具&#xff0c;能够模拟真实用户与网页的交互。针对您询问的适合在 Selenium 中实施的策略&#xff0c;以下是一些直接适用于或可以通过 Selenium 配置实现的方法&#xff1a; 修改User-Agent: 通过 Chrome 或 Firefox 的选项在启动时设置自…

操作系统智能助手OS Copilot评测报告

背景 如果不是朋友告知&#xff0c;我还不知道阿里云推出了【操作系统智能助手OS Copilot】这样一款产品。 我做系统运维的工作还是挺多的&#xff0c;知道系统运维工作的一些痛点&#xff1b;例如&#xff1a; Linux命令繁杂&#xff0c;想全部记住不太可能&#xff0c;多数…

软件测试《用例篇》

测试用例 测试用例的概念 测试用例是被测试人员向被测试系统发起的一组集合&#xff0c;包括测试环境&#xff0c;操作步骤&#xff0c;预期结果&#xff0c;测试数据等 使用测试用例的好处 使用测试用例进行测试的好处主要有&#xff1a;提高测试效率&#xff0c;降低测试的重…

YOLOV8改进DSConv分布移位卷积

基础干货&#xff1a;高效卷积&#xff0c;降内存提速度保精度 (eepw.com.cn) 各种卷积性能对比(Conv,DwConv,GhostConv,PConv,DCNV)-CSDN博客

WAWA鱼曲折的大学四年回忆录

声明&#xff1a;本文内容纯属个人主观臆断&#xff0c;如与事实不符&#xff0c;请参考事实 前言&#xff1a; 早想写一下大学四年的总结了&#xff0c;但总是感觉无从下手&#xff0c;不知道从哪里开始写&#xff0c;通过这篇文章主要想做一个记录&#xff0c;并从现在的认…

中国智能制造装备产业发展机遇

导语 大家好&#xff0c;我是社长&#xff0c;老K。专注分享智能制造和智能仓储物流等内容。 新书《智能物流系统构成与技术实践》 更多的海量【智能制造】相关资料&#xff0c;请到智能制造online知识星球自行下载。 随着全球第四次工业革命的浪潮&#xff0c;智能制造装备产业…

刷leetcode中常用且有效的方法总结

刷题的时候经常会因为不知道一个方法多写很多行代码&#xff0c;既然有trick为何不用&#xff01;你问我眼中为何常含泪水&#xff0c;因为我忘记方法忘的深沉。那么我决定出一期&#xff01;刷题中常用且有效的方法们&#xff01;将会陆续补充,有补充欢迎评论区留言 目录 py…

C++ 函数高级——函数的默认参数

函数默认参数 在C中&#xff0c;函数的形参列表中的形参是可以有默认值的 语法&#xff1a;返回值类型 函数名 &#xff08;参数 默认值&#xff09;{ } 示例&#xff1a; 正确代码&#xff1a; 运行结果&#xff1a;

昇思25天学习打卡营第13天|sea_fish

打开第13天。本次学习的内容为LLM原理和实践中基于MindSpore通过GPT实现情感分类的内容。记录学习的过程。 根据实验系统中的内容一步一步学习基于MindSpore通过GPT实现情感分类的整个过程。整个过程分为以下三个过程&#xff1a; 数据集加载与处理&#xff1a;数据集加载和数…

开源六轴协作机械臂myCobot 280接入GPT4大模型!实现更复杂和智能化的任务

本文已经或者同济子豪兄作者授权对文章进行编辑和转载 引言 随着人工智能和机器人技术的快速发展&#xff0c;机械臂在工业、医疗和服务业等领域的应用越来越广泛。通过结合大模型和多模态AI&#xff0c;机械臂能够实现更加复杂和智能化的任务&#xff0c;提升了人机协作的效率…

Laravel批量插入数据:提升数据库操作效率的秘诀

Laravel批量插入数据&#xff1a;提升数据库操作效率的秘诀 Laravel作为PHP的现代Web应用框架&#xff0c;提供了优雅而简洁的方法来处理数据库操作。批量插入数据是数据库操作中常见的需求&#xff0c;尤其是在处理大量数据时&#xff0c;批量插入可以显著提高性能。本文将详…

LDAP技术解析:打造安全、高效的企业数据架构

1.LDAP简介 LDAP&#xff08;Lightweight Directory Access Portocol&#xff0c;轻量目录访问协议&#xff09;是一种用于访问与管理分布式目录服务的开放协议。目录服务是一种特殊的数据库&#xff0c;优化用于读取和查询操作&#xff0c;而不是写入操作。LDAP广泛用于身份验…