昇思MindSpore学习总结十七 —— 基于MindSpore通过GPT实现情感分类

 1、要求

2、导入了一些必要的库和模块

        以便在使用MindSpore和MindNLP进行深度学习任务时能使用各种功能,比如数据集处理、模型训练、评估和回调功能。

import os  # 导入操作系统相关功能的模块,如文件和目录操作import mindspore  # 导入MindSpore库,这是一个深度学习框架
from mindspore.dataset import text, GeneratorDataset, transforms  # 从MindSpore的数据集模块导入处理文本、生成数据集和变换功能
from mindspore import nn  # 从MindSpore库导入神经网络模块from mindnlp.dataset import load_dataset  # 从MindNLP库导入加载数据集的功能from mindnlp._legacy.engine import Trainer, Evaluator  # 从MindNLP库的旧版本引擎模块导入训练器和评估器
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback  # 导入旧版本引擎模块的回调功能,用于检查点保存和最佳模型保存
from mindnlp._legacy.metrics import Accuracy  # 从MindNLP库的旧版本指标模块导入准确率指标

 3、加载IMDB数据集

        并将其分为训练集和测试集。load_dataset函数会返回一个包含数据集各个部分的字典,然后你可以通过键 'train''test' 来访问相应的数据。

imdb_ds = load_dataset('imdb', split=['train', 'test'])  # 加载IMDB数据集,并将数据集分为训练集和测试集,返回一个包含两个部分的字典imdb_train = imdb_ds['train']  # 从字典中提取训练集数据
imdb_test = imdb_ds['test']  # 从字典中提取测试集数据

4、获取训练集数据集大小

get_dataset_size() 用于返回数据集中包含的样本数量。这个方法的返回值通常是一个整数,表示训练集中有多少个样本。 

imdb_train.get_dataset_size()  # 获取训练集数据集中样本的数量

 5、定义一个用于处理数据集的函数 process_dataset

        将输入文本数据进行tokenization,并根据设备类型选择不同的批处理方式。如果需要,还可以对数据集进行打乱和批处理。

import numpy as np  # 导入NumPy库,用于数值计算和数组操作def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):# 定义处理数据集的函数,接受数据集、tokenizer、最大序列长度、批量大小和是否打乱数据集作为参数is_ascend = mindspore.get_context('device_target') == 'Ascend'# 检查当前设备是否为Ascend(华为的深度学习处理器),根据设备类型选择不同的tokenizer处理方式def tokenize(text):# 定义tokenize函数,用于对文本进行tokenizationif is_ascend:# 如果在Ascend设备上,使用'padding'和'truncation'进行tokenizationtokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:# 否则只进行'truncation'和设置最大长度tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']# 返回tokenized的'input_ids'和'attention_mask'字段if shuffle:dataset = dataset.shuffle(batch_size)# 如果设置了shuffle参数为True,则对数据集进行打乱# 对数据集应用map操作dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])# 使用tokenize函数处理数据集中的"text"列,并生成新的列'input_ids'和'attention_mask'dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# 将数据集中的"label"列转换为mindspore.int32类型,并生成新的列"labels"# 根据设备类型选择批处理方式if is_ascend:dataset = dataset.batch(batch_size)# 如果在Ascend设备上,直接进行批处理else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})# 否则使用padded_batch进行批处理,设置填充信息(input_ids和attention_mask的填充值)return dataset  # 返回处理后的数据集

6、初始化一个GPT模型的tokenizer

加载预训练的GPT模型,然后添加了一些特殊标记,如句子开始标记、句子结束标记和填充标记。

from mindnlp.transformers import GPTTokenizer
# 从MindNLP库中导入GPTTokenizer,用于加载和处理GPT模型的tokenizergpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')
# 使用预训练的GPT tokenizer进行初始化,指定预训练模型名称为'openai-gpt'# add special token: <PAD>
special_tokens_dict = {"bos_token": "<bos>",  # 句子开始标记"eos_token": "<eos>",  # 句子结束标记"pad_token": "<pad>",  # 填充标记
}
# 定义一个字典,用于添加特定的特殊标记到tokenizer中num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)
# 将定义的特殊标记添加到tokenizer中,并返回添加的标记数量

 7、划分训练集、验证集

将一个训练数据集 imdb_train 按照 70% 和 30% 的比例分割成两个数据集:一个用于训练 (imdb_train),另一个用于验证 (imdb_val)。

# split train dataset into train and valid datasets
# 将训练数据集拆分成训练集和验证集imdb_train, imdb_val = imdb_train.split([0.7, 0.3])
# 将 imdb_train 数据集按 70% 和 30% 的比例分割成两个数据集
# imdb_train 包含 70% 的数据,用于继续训练
# imdb_val 包含 30% 的数据,用于验证模型

8、处理三个数据集:训练集、验证集和测试集。

process_dataset 函数的作用是对数据集进行预处理,包括标记化、清洗或其他转换操作。gpt_tokenizer 是用于将文本数据转换为模型可以理解的格式的标记化工具。数据集的打乱(shuffle=True)有助于防止模型训练中的过拟合和提升泛化能力。

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
# 使用 process_dataset 函数处理 imdb_train 数据集
# gpt_tokenizer 用于对数据进行标记化处理
# shuffle=True 表示对数据进行随机打乱,以提高训练效果
# 处理后的数据集存储在 dataset_train 中dataset_val = process_dataset(imdb_val, gpt_tokenizer)
# 使用 process_dataset 函数处理 imdb_val 数据集
# gpt_tokenizer 用于对数据进行标记化处理
# 处理后的数据集存储在 dataset_val 中
# 这里没有指定 shuffle 参数,默认情况下数据不会被打乱dataset_test = process_dataset(imdb_test, gpt_tokenizer)
# 使用 process_dataset 函数处理 imdb_test 数据集
# gpt_tokenizer 用于对数据进行标记化处理
# 处理后的数据集存储在 dataset_test 中
# 这里也没有指定 shuffle 参数,默认情况下数据不会被打乱

9、从 dataset_train 中获取下一个数据样本

  • dataset_train.create_tuple_iterator():将 dataset_train 数据集转换为一个迭代器,这个迭代器返回的每一项都是一个元组(通常包含输入数据和对应的标签)。
  • next():从迭代器中获取下一个数据项。如果迭代器中没有更多的数据项,它将引发 StopIteration 异常。
next(dataset_train.create_tuple_iterator())
# 创建一个迭代器,用于遍历 dataset_train 数据集中的数据
# 使用 create_tuple_iterator() 方法将 dataset_train 转换为一个元组迭代器
# 然后调用 next() 函数从迭代器中获取下一个数据样本
# 这将返回数据集中的第一个数据项(通常是一个包含特征和标签的元组)

10、配置模型

配置一个 GPT 模型进行序列分类任务,并设置了训练过程中的优化器、评估指标、回调函数等。

  1. 模型定义和配置

    • GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2):从预训练的 GPT 模型加载,并配置为二分类任务。
    • model.config.pad_token_id = gpt_tokenizer.pad_token_id:设置模型的填充标记 ID。
    • model.resize_token_embeddings(model.config.vocab_size + 3):扩展模型的词汇表以适应新增的词汇。
  2. 优化器和学习率

    • optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5):使用 Adam 优化器,学习率设置为 0.00002。
  3. 评估指标

    • metric = Accuracy():使用准确率作为模型性能评估指标。
  4. 回调函数

    • CheckpointCallback:用于保存训练过程中的模型检查点。
    • BestModelCallback:用于保存和自动加载最佳模型检查点。
  5. 训练配置

    • Trainer:用于模型的训练,指定了模型、数据集、优化器、回调函数等参数。
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam# set bert config and define parameters for training
# 设置模型配置和训练参数# 创建一个 GPT 模型用于序列分类任务,num_labels=2 表示有两个分类标签
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)# 配置模型的填充标记 ID 为 tokenzier 中的 pad_token_id
model.config.pad_token_id = gpt_tokenizer.pad_token_id# 调整模型的词汇表大小,为模型词汇表增加 3 个新的词汇
model.resize_token_embeddings(model.config.vocab_size + 3)# 使用 Adam 优化器,并设置学习率为 2e-5
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)# 定义准确率作为评估指标
metric = Accuracy()# 定义回调函数以保存检查点
# ckpoint_cb 用于保存模型检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)# best_model_cb 用于保存最佳模型检查点,并自动加载最佳模型
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)# 创建一个 Trainer 对象用于训练模型
trainer = Trainer(network=model,                       # 训练的模型train_dataset=dataset_train,         # 训练数据集eval_dataset=dataset_train,          # 验证数据集(这里使用了训练集进行验证)metrics=metric,                      # 评估指标epochs=1,                            # 训练轮次optimizer=optimizer,                 # 优化器callbacks=[ckpoint_cb, best_model_cb], # 回调函数jit=False                            # 是否使用 JIT 编译
)

 11、trainer.run(tgt_columns="labels") 启动了模型的训练过程,并指定了目标列。

  • trainer.run(): 这个方法启动了模型的训练过程。它会根据之前配置的训练参数(如模型、优化器、数据集、回调函数等)开始训练。

  • tgt_columns="labels": 这是一个参数,指定了数据集中哪个列作为模型的目标列(即标签列)。在这里,"labels" 表示数据集中用于训练和验证的目标列是 labels。这个列的值用于计算损失函数并进行模型的优化。

这种设置通常用于数据集中包含多个列的情况,其中一个列是模型训练的目标输出。在这个例子中,labels 列包含了分类任务中的标签。

trainer.run(tgt_columns="labels")
# 启动训练过程
# tgt_columns="labels" 指定了训练数据集中包含的目标列(标签列)
# 训练过程将使用这些目标列来计算损失和进行梯度更新

12、 设置并运行了模型的评估过程

  1. 创建 Evaluator 对象:

    • network=model:指定要评估的模型。
    • eval_dataset=dataset_test:指定用于评估的数据集。这里使用了 dataset_test 数据集进行评估。
    • metrics=metric:指定评估时使用的指标。在这里是准确率(Accuracy())。
  2. 运行评估:

    • evaluator.run(tgt_columns="labels"):启动模型的评估过程。
    • tgt_columns="labels":指定目标列(即数据集中用于计算评估指标的列)。在这里,"labels" 表示模型将使用这个列中的标签来计算准确率。

通过这种设置,Evaluator 对象会遍历 dataset_test 数据集,计算模型在测试集上的表现,并根据指定的评估指标(准确率)输出结果。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
# 创建一个 Evaluator 对象,用于评估模型
# 参数解释:
# network=model: 要评估的模型
# eval_dataset=dataset_test: 用于评估的数据集
# metrics=metric: 评估过程中使用的指标(这里是准确率)evaluator.run(tgt_columns="labels")
# 启动评估过程
# tgt_columns="labels" 指定了数据集中包含的目标列(标签列)
# 评估过程中使用这些目标列来计算评估指标(如准确率)

打卡

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/49103.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据开发/数仓工程师上手指南(二)数仓构建分层概念

前言 回顾上篇文章我们可以用思维导图一遍概览&#xff1a; 在了解了数仓的基本架构之后&#xff0c;我们还需要掌握数仓构建方法&#xff0c;也就是了解数仓是如何建模的&#xff0c;有什么规则和通用方法。我们应该如何去构建一个性能良好、稳定高效、契合业务的数据仓库。…

图形/视图架构的坐标系

图形/视图架构有 3 个有效的坐标系&#xff1a;场景坐标系、视图坐标系、图形项坐标系。 视图坐标系 视图坐标系就是视图组件的物理坐标系&#xff0c;单位是像素。QGraphicsView 视口的左上角坐标总是(0,0)。 场景坐标系 场景坐标系定义了所有图形项的基础坐标&#xff0c;场…

如何排查GD32 MCU复位是由哪个复位源导致的?

上期为大家讲解了GD32 MCU复位包括电源复位和系统复位&#xff0c;其中系统复位还包括独立看门狗复位、内核软复位、窗口看门狗复位等&#xff0c;在一个GD32系统中&#xff0c;如果莫名其妙产生了MCU复位&#xff0c;如何排查具体是由哪个复位源导致的呢&#xff1f; GD32 MC…

Idea如何查看Maven依赖树

1、使用idea自带的功能查看依赖树 2、使用Maven Helper插件 https://zhuanlan.zhihu.com/p/699663369

《Milvus Cloud向量数据库指南》——监管机构和社区:开源许可证标准的守护者与推动者

在开源软件的浩瀚宇宙中,监管机构和社区构成了其稳定运行与持续发展的双轮驱动。这些组织不仅定义了开源的本质,还通过制定、维护和执行许可证标准,确保了开源生态的开放性、透明性和协作精神得以传承。其中,开源倡议组织(OSI)、自由软件基金会(FSF)以及Apache软件基金…

【STM32】IIC学习笔记

学习IIC 前言一、基础知识GPIO_WriteBit 写入高低电平 二、放代码三、逐行细读总结 前言 最近沉迷手写笔记~ 尝试解读江科大的IIC程序&#xff0c;结合笔记更理解IIC 一、基础知识 GPIO_WriteBit 写入高低电平 二、放代码 这个是江科大的软件IIC的设置部分 #include "s…

正点原子 通用外设配置模型 GPIO配置步骤 NVIC配置

1. 这个是通用外设驱动模式配置 除了初始化是必须的 其他不是必须的 2. gpio配置步骤 1.使能时钟是相当于开电 2.设置工作模式是配置是输出还是输入 是上拉输入还是下拉输入还是浮空 是高速度还是低速度这些 3 和 4小点就是读写io口的状态了 3. 这个图是正点原子 将GPIO 的时…

Axure设计之轮播图(动态面板+中继器)

轮播图&#xff08;Carousel&#xff09;是一种网页或应用界面中常见的组件&#xff0c;用于展示一系列的图片或内容&#xff0c;通常通过自动播放或用户交互&#xff08;如点击箭头按钮&#xff09;来切换展示不同的内容。轮播图能够吸引用户的注意力&#xff0c;有效展示重要…

全能数据分析工具:Tableau Desktop 2019 for Mac 中文激活版

Tableau Desktop 2019 一款专业的全能数据分析工具&#xff0c;可以让用户将海量数据导入并记性汇总&#xff0c;并且支持多种数据类型&#xff0c;比如像是编程常用的键值对、哈希MAP、JSON类型数据等&#xff0c;因此用户可以将很多常用数据库文件直接导入Tableau Desktop&am…

Django Web开发:构建强大RBAC权限管理系统的实战指南

文章目录 前言一、rbac 基于角色的权限管理1.acl 基于用户的权限管理2.rbac 基于角色的权限管理 二、应用示例1.配置角色资源a.分析表b.核心逻辑c.使用transfer在前端实现资源配置d.页面效果 2.登录时获取对应权限a.员工登录b.中间件c.前端请求d.效果图 3.前端-路由守卫-页面权…

GAT知识总结

《GRAPH ATTENTION NETWORKS》 解决GNN聚合邻居节点的时候没有考虑到不同的邻居节点重要性不同的问题&#xff0c;GAT借鉴了Transformer的idea&#xff0c;引入masked self-attention机制&#xff0c; 在计算图中的每个节点的表示的时候&#xff0c;会根据邻居节点特征的不同来…

解开基于大模型的Text2SQL的神秘面纱

你好&#xff0c;我是 shengjk1&#xff0c;多年大厂经验&#xff0c;努力构建 通俗易懂的、好玩的编程语言教程。 欢迎关注&#xff01;你会有如下收益&#xff1a; 了解大厂经验拥有和大厂相匹配的技术等 希望看什么&#xff0c;评论或者私信告诉我&#xff01; 文章目录 一…

JAVA基础 - 对象

目录 一. 简介 二. 空对象 三. 构造方法 四. 析构方法 五. this关键字 六. 对象销毁 一. 简介 在 Java 中&#xff0c;对象&#xff08;Object&#xff09;是面向对象编程的核心概念。 对象是类的实例化&#xff0c;它将数据&#xff08;属性&#xff09;和操作这些数据…

【运算放大器】输入失调电压和输入偏置电流(2)实例计算

概述 根据上一篇文章的理论&#xff0c;分别计算没有输入电阻和有输入电阻两种情况下的运放总输出误差。例题来自于TI高精度实验室系列课程。 目录 概述实例计算 1&#xff1a;没有输入电阻实例计算 2&#xff1a;有输入电阻总结 实例计算 1&#xff1a;没有输入电阻 要求&am…

通过IEC104转MQTT网关对接阿里云、华为云、亚马逊AWS、ThingsBoard、Ignition、Zabbix

随着工业互联网的快速发展&#xff0c;传统电力系统中的IEC 104协议设备正逐步向更加开放、灵活的物联网架构转型。MQTT&#xff08;Message Queuing Telemetry Transport&#xff09;作为一种轻量级的消息传输协议&#xff0c;因其低带宽消耗、高可靠性和广泛的支持性&#xf…

vue3前端开发-小兔鲜项目-路由拦截器增加token的携带

vue3前端开发-小兔鲜项目-路由拦截器增加token的携带&#xff01;实际开发中&#xff0c;很多业务接口的请求&#xff0c;都要求必须是登录状态&#xff01;为此&#xff0c;这个token信息就会频繁的被加入到了请求头部信息中。request请求头内既然需要频繁的携带这个token.我们…

集团ERP信息化项目实施方案(82页PPT)

集团ERP信息化项目实施方案的82页PPT详尽阐述了企业资源规划&#xff08;ERP&#xff09;系统实施的全过程&#xff0c;旨在帮助集团整合多个业务流程于一个统一的平台。方案从当前市场环境分析入手&#xff0c;解释了ERP系统对于提升集团运营效率、降低成本和优化资源配置的必…

【OpenCV C++20 学习笔记】图片融合

图片融合 原理实现结果展示完整代码 原理 关于OpenCV的配置和基础用法&#xff0c;请参阅本专栏的其他文章&#xff1a;垚武田的OpenCV合集 这里采用的图片熔合的算法来自Richard Szeliski的书《Computer Vision: Algorithms and Applications》&#xff08;《计算机视觉&#…

STM32是使用的内部时钟还是外部时钟

STM32是使用的内部时钟还是外部时钟&#xff0c;经常会有人问这个问题。 1、先了解时钟树&#xff0c;见下图&#xff1a; 2、在MDK中&#xff0c;使用的是HSEPLL作为SYSCLK&#xff0c;因此需要对时钟配置寄存器&#xff08;RCC_CFGR&#xff09;进行配置&#xff0c;寄存器内…

Eaton伊顿触摸屏维修XV-303-15-C00-A00-1C

伊顿触摸屏维修,工业触摸屏维修,主板维修,坏高故障,损坏显示,不损坏,运行稳定,不花屏,无反应慢等故障维修,维修有保障,资费低.,触摸屏主板坏,高压板故障,按键损坏等均可修理。 伊顿触摸屏维修 EATON触摸屏维修 伊顿工控机维修 EATON工控机维修 伊顿人机界面维修 EATON触摸屏维…