昇思25天学习打卡营第11天|基于MindSpore通过GPT实现情感分类

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

基于MindSpore通过GPT实现情感分类

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com
import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nnfrom mindnlp.dataset import load_datasetfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']
imdb_train.get_dataset_size()

加载IMDB数据集。将IMDB数据集分为训练集和测试集。IMDB (Internet Movie Database) 数据集包含来自著名在线电影数据库 IMDB 的电影评论。每条评论都被标注为正面(positive)或负面(negative),因此该数据集是一个二分类问题,也就是情感分类问题。

import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']if shuffle:dataset = dataset.shuffle(batch_size)# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset

定义数据预处理函数。这个函数输入参数为数据集、分词器(GPT Tokenizer)以及一些可选参数,如最大序列长度、批量大小和是否打乱数据。预处理包括将文本转换为模型可以理解的输入格式(如input_ids和attention_mask),并将标签转换为整数类型。

from mindnlp.transformers import GPTTokenizer
# tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')# add sepcial token: <PAD>
special_tokens_dict = {"bos_token": "<bos>","eos_token": "<eos>","pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

加载GPT分词器并增加特殊标记。

# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

将训练集划分为训练集和验证集。

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

用 process_dataset 函数对训练集、验证集和测试集进行处理,得到相应的数据集对象。

next(dataset_train.create_tuple_iterator())
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam# set bert config and define parameters for training
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)metric = Accuracy()# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)

导入 GPTForSequenceClassification 模型和 Adam 优化器。设置GPT模型的配置信息,包括pad_token_id和词汇表大小。使用Adam优化器对模型的可训练参数进行优化(从这里没有看出是更新部分参数,还是全部参数,有可能是部分参数。通常会改变最后一层分类器的权重和偏置,其他层的权重被冻结不变或者只微小更新些许参数。)。

Accuracy作为评价指标。

定义回调函数用于保存检查点:

   - CheckpointCallback:用于定期保存模型权重,save_path 指定了保存路径,ckpt_name保存文件的前缀,epochs=1 每个epoch保存一次,keep_checkpoint_max=2 表示最多保留2个检查点文件。
   - BestModelCallback:用于保存验证集上表现最好的模型,auto_load=True表示在训练结束后自动加载最优模型的权重。

创建 Trainer 对象,传入以下参数:
      - network:要训练的模型。
      - train_dataset:训练数据集。
      - eval_dataset:验证数据集。
      - metrics:评估指标。
      - epochs:训练轮数。
      - optimizer:优化器。
      - callbacks:回调函数列表,包括检查点保存和最佳模型保存。
      - jit:是否启用JIT编译,这里设置为False。

trainer.run(tgt_columns="labels")

通过 Trainer 的 run 方法启动训练,指定了训练过程中的目标标签列为 "labels"。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

创建 Evaluator 对象,传入以下参数:
      - network:要评估的模型。
      - eval_dataset:测试数据集。
      - metrics:评估指标。

用MindSpore通过GPT实现情感分类(Sentiment Classification)的示例。首先加载了IMDB影评数据集,并将其划分为训练集、验证集和测试集。然后使用GPTTokenizer对文本进行了标记化和转换。接下来,使用GPTForSequenceClassification构建了情感分类模型,并定义了优化器和评估指标。使用Trainer进行模型的训练,并设置了保存检查点的回调函数。训练完成后,通过Evaluator对测试集进行评估,输出分类准确率。通过对IMDB影评数据集进行训练和评估,模型可以自动进行情感分类,识别出正面或负面情感。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/862948.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Kubernetes面试整理-RBAC(基于角色的访问控制) 的理解和配置方法

在 Kubernetes 中,RBAC(基于角色的访问控制,Role-Based Access Control)是一种控制访问权限的机制,用于管理用户和服务账户对集群资源的访问。RBAC 通过定义角色和角色绑定来控制谁可以对哪些资源执行哪些操作。 核心概念 1. Role 和 ClusterRole: ● Role:定义在特定命…

java基于微信小程序+mysql+RocketMQ开发的医院智能问诊系统源码 智能导诊系统 智能导诊小程序源码

java基于微信小程序mysqlRocketMQ开发的医院智能问诊系统源码 智能导诊系统 智能导诊小程序源码 医院导诊系统是一种基于互联网和定位技术的智能化服务系统&#xff0c;旨在为患者提供精准、便捷的医院内部导航和医疗就诊咨询服务。该系统整合了医院的各种医疗服务资源&#x…

【软件实施】软件实施概论

目录 软件实施概述定义主要工作软件项目的实施工作区别于一般的项目&#xff08;如&#xff1a;房地产工程项目&#xff09;软件实施的重要性挑战与对策软件项目实施的流程软件项目实施的周期 软件企业软件企业分类产品型软件企业业务特点产品型软件企业的分类产品型软件企业的…

PortSip测试

安装PBX 下载 免费下载 PortSIP PBX 安装PBX&#xff0c;安装后&#xff0c;运行 &#xff0c;默认用户是admin 密码是admin&#xff0c;然后配置IP 为192.168.0.189 设置域名为192.168.0.189 配置分机 添加分机&#xff0c;添加了10001、10002、9999 三个分机&#xff0c…

Android Studio中使用命令行gradle查看签名信息

Android Studio中使用命令行gradle查看签名信息&#xff1a; 使用 Gradle 插件生成签名报告 打开 Android Studio 的 Terminal。 运行以下命令&#xff1a;./gradlew signingReport 将生成一个签名报告&#xff0c;其中包含 MD5、SHA1 和 SHA-256 的信息。 如果失败&#xf…

10分钟完成微信JSAPI支付对接过程-JAVA后端接口

引入架包 <dependency><groupId>com.github.javen205</groupId><artifactId>IJPay-WxPay</artifactId><version>${ijapy.version}</version></dependency>配置类 package com.joolun.web.config;import org.springframework.b…

【递归、搜索与回溯】记忆化搜索

记忆化搜索 1.记忆化搜索2.不同路径3.最长递增子序列4. 猜数字大小 II5.矩阵中的最长递增路径 点赞&#x1f44d;&#x1f44d;收藏&#x1f31f;&#x1f31f;关注&#x1f496;&#x1f496; 你的支持是对我最大的鼓励&#xff0c;我们一起努力吧!&#x1f603;&#x1f603;…

5000字深入讲解:企业数字化转型优先从哪个板块开始?

很多企业都知道数字化转型重要&#xff0c;但不知道应该怎样入手&#xff0c;分哪些阶段。以下引用国内领先数字化服务商 织信Informat 的数字化转型方法论材料&#xff0c;且看看他们是如何看待数字化转型的&#xff1f;数字化转型应该从哪先开始&#xff1f;如何做&#xff1…

Java中的类与接口:抽象与实现的艺术

Java中的类与接口&#xff1a;抽象与实现的艺术 大家好&#xff0c;我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01;今天我们将深入探讨Java中类与接口的设计原则&#xff0c;以及如…

P1107 [BJWC2008] 雷涛的小猫

[BJWC2008] 雷涛的小猫 题目背景 原最大整数参见 P1012 题目描述 雷涛同学非常的有爱心&#xff0c;在他的宿舍里&#xff0c;养着一只因为受伤被救助的小猫&#xff08;当然&#xff0c;这样的行为是违反学生宿舍管理条例的&#xff09;。在他的照顾下&#xff0c;小猫很快…

阿里云服务器数据库迁云: 数据从传统到云端的安全之旅(WordPress个人博客实战教学)

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 文章目录 一、 开始实战1.2创建实验资源1.3重置云服务器ECS的登录密码&#xff08;请记住密码&#xff09;1.4 设置安全组端口1…

C++ STL: std::vector与std::array的深入对比

什么是 std::vector 和 std::array 首先&#xff0c;让我们简要介绍一下这两种容器&#xff1a; • std::vector&#xff1a;一个动态数组&#xff0c;可以根据需要动态调整其大小。 • std::array&#xff1a;一个固定大小的数组&#xff0c;其大小在编译时确定。 虽然…

Adobe Acrobat Pro或者Adobe Acrobat Reader取消多标签页显示,设置打开一个pdf文件对应一个窗口。

Windows系统&#xff1a;Adobe Acrobat Pro或者Adobe Acrobat Reader首选项-一般-取消在同一窗口的新标签中打开文档&#xff08;需要重启&#xff09;的对勾&#xff0c;点击确定&#xff0c;彻底关闭后重启&#xff0c;这样打开的每一个PDF文件对应的是一个窗口&#xff0c;并…

力扣第214题“最短回文串”

在本篇文章中&#xff0c;我们将详细解读力扣第214题“最短回文串”。通过学习本篇文章&#xff0c;读者将掌握如何使用 KMP 算法来解决这一问题&#xff0c;并了解相关的复杂度分析和模拟面试问答。每种方法都将配以详细的解释&#xff0c;以便于理解。 问题描述 力扣第214题…

拓展方法知识点

拓展方法的基本概念 概念 为现有非静态变量类型添加新方法。 作用 1.提升程序拓展性。 2.不需要再对象中重新写方法。 3.不需要继承来添加方法。 4.为别人封装的类型写额外的方法。 特点 1.一定是写在静态类中。 2.一定是个静态函数。 3.第一个参数为拓展目标。 4…

Bridging nonnull in Objective-C to Swift: Is It Safe?

Bridging nonnull in Objective-C to Swift: Is It Safe? In the world of iOS development, bridging between Objective-C and Swift is a common practice, especially for legacy codebases (遗留代码库) or when integrating (集成) third-party libraries. One importa…

重磅更新-UniApp自定义字体可视化设计

重磅更新-UniApp自定义字体可视化设计。 DIY可视化为了适配不同APP需要&#xff0c;支持用户自定义字体&#xff0c;自定义字体后&#xff0c;设计出来的界面更多样化&#xff0c;不再是单一字体效果。用户可以使用第三方字体加入设计&#xff0c;在设计的时候选择上自己的字体…

简过网:三支一扶有编制吗?考上三支一扶就是入编了吗?

小编看到很多朋友在咨询三支一扶的相关问题&#xff0c;比如三支一扶有没有编制&#xff1f;今天&#xff0c;针对这个问题咱们一块来了解一下吧。 三支一扶是没有编制的&#xff0c;既没有事业编制&#xff0c;也不是公务员编制 什么是三支一扶&#xff1f; 三支一扶是指大…

AI副业赚钱攻略:掌握数字时代的机会

前言 最近国产大模型纷纷上线&#xff0c;飞入寻常百姓家。AI副业正成为许多人寻找额外收入的途径。无论您是想提高家庭收入还是寻求职业发展&#xff0c;这里有一个变现&#xff0c;帮助您掌握AI兼职副业的机会。 1. 了解AI的基础知识 在开始之前&#xff0c;了解AI的基础…

一个开源的、独立的、可自托管的评论系统,专为现代Web平台设计

大家好&#xff0c;今天给大家分享的是一个开源的、独立的、可自托管的评论系统&#xff0c;专为现代Web平台设计。 Remark42是一个自托管的、轻量级的、简单的&#xff08;但功能强大的&#xff09;评论引擎&#xff0c;它不会监视用户。它可以嵌入到博客、文章或任何其他读者…