昇思25天学习打卡营第15天|基于 MindSpore 实现 BERT 对话情绪识别

文章目录

      • 昇思MindSpore应用实践
          • 1、基于 MindSpore 实现 BERT 对话情绪识别
            • BERT 模型简介
            • 数据集
            • 数据加载和数据预处理
          • 2、模型训练
            • 模型验证
          • 3、模型推理
      • Reference

昇思MindSpore应用实践

本系列文章主要用于记录昇思25天学习打卡营的学习心得。

1、基于 MindSpore 实现 BERT 对话情绪识别
BERT 模型简介

自2017年Google发表"Attention is ALL You Need"之后,基于自注意力机制结构的网络模型开始在各领域发展,特别是Transformer网络,针对序列的强大长距离关联处理能力在NLP领域取得了成效,BERT 的创新点在于它将双向 Transformer 用于语言模型,曾一度刷新各项NLP任务的SOTA记录,包括问答 Question Answering (SQuAD v1.1),推理 Natural Language Inference (MNLI) 等,也由此拉开了LLM的序幕。

在此之前,以往的NLP模型,如RNN网络,通常是从左向右输入一个文本序列,或者将 left-to-right 和 right-to-left 的训练结合起来(如Bi-LSTM)。实验结果表明,双向训练的语言模型对语境的理解会比单向的语言模型更深刻,

BERT 利用了 Transformer 的 Encoder 部分。Transformer 基于自注意力机制可以学习文本中单词之间的上下文关系。Transformer 的原型包括两个独立的机制,一个 Encoder 负责接收文本作为输入,一个 Decoder 负责预测任务的结果。BERT 的目标是生成语言模型,所以只需要 Encoder 机制。

Transformer 的 encoder 是一次性读取整个文本序列,而不是从左到右或从右到左地按顺序读取,这个特征使得模型能够基于单词的两侧学习,相当于是一个双向的功能

下图是 Transformer 的 encoder 部分,输入是一个 token 序列,先对其进行 embedding 称为向量,然后输入给神经网络,输出是大小为 H 的向量序列,每个向量对应着具有相同索引的 token。
在这里插入图片描述

BERT模型的主要创新点都在pre-train方法上,即用了Masked Language ModelNext Sentence Prediction两种方法分别捕捉词语和句子级别的representation。

在用Masked Language Model方法训练BERT的时候,随机把语料库中15%的单词做Mask操作。对于这15%的单词做Mask操作分为三种情况:80%的单词直接用[Mask]替换、10%的单词直接替换成另一个新的单词、10%的单词保持不变。

因为涉及到Question Answering (QA) 和 Natural Language Inference (NLI)之类的任务,增加了Next Sentence Prediction预训练任务,目的是让模型理解两个句子之间的联系。与Masked Language Model任务相比,Next Sentence Prediction更简单些,训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,BERT模型预测B是不是A的下一句

BERT预训练之后,会保存它的Embedding table和12层Transformer权重(BERT-BASE)或24层Transformer权重(BERT-LARGE)。使用预训练好的BERT模型可以对下游任务进行Fine-tuning,比如:文本分类、相似度判断、阅读理解等。

对话情绪识别(Emotion Detection,简称EmoTect),专注于识别智能对话场景中用户的情绪,针对智能对话场景中的用户文本,自动判断该文本的情绪类别并给出相应的置信度,情绪类型分为积极、消极、中性。 对话情绪识别适用于聊天、客服等多个场景,能够帮助企业更好地把握对话质量、改善产品的用户交互体验,也能分析客服服务质量、降低人工质检成本。

下面以MindSpore给出的一个文本情感分类任务为例子来说明BERT模型的整个应用过程:

import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, contextfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy# prepare dataset
class SentimentDataset:"""Sentiment Dataset"""def __init__(self, path):self.path = pathself._labels, self._text_a = [], []self._load()def _load(self):with open(self.path, "r", encoding="utf-8") as f:dataset = f.read()lines = dataset.split("\n")for line in lines[1:-1]:label, text_a = line.split("\t")self._labels.append(int(label))self._text_a.append(text_a)def __getitem__(self, index):return self._labels[index], self._text_a[index]def __len__(self):return len(self._labels)
数据集

这里提供一份已标注的、经过分词预处理的机器人聊天数据集,来自于百度飞桨团队。数据由两列组成,以制表符(‘\t’)分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。

label–text_a

0–谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?

1–我有事等会儿就回来和你聊

2–我见到你很高兴谢谢你帮我

这部分主要包括数据集读取,数据格式转换,数据 Tokenize 处理和 pad 操作。

数据加载和数据预处理

新建 process_dataset 函数用于数据加载和数据预处理:

import numpy as npdef process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):is_ascend = mindspore.get_context('device_target') == 'Ascend'column_names = ["label", "text_a"]dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)# transformstype_cast_op = transforms.TypeCast(mindspore.int32)def tokenize_and_pad(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text)return tokenized['input_ids'], tokenized['attention_mask']# map datasetdataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset

昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理:

from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
tokenizer.pad_token_iddataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)
dataset_train.get_col_names()

[‘input_ids’, ‘attention_mask’, ‘labels’]

print(next(dataset_train.create_tuple_iterator()))

[Tensor(shape=[32, 64], dtype=Int64, value=
[[ 101, 6443, 3221 … 0, 0, 0],
[ 101, 872, 1963 … 0, 0, 0],
[ 101, 6929, 872 … 0, 0, 0],

[ 101, 872, 4268 … 0, 0, 0],
[ 101, 671, 4991 … 0, 0, 0],
[ 101, 1376, 1480 … 0, 0, 0]]), Tensor(shape=[32, 64], dtype=Int64, value=
[[1, 1, 1 … 0, 0, 0],
[1, 1, 1 … 0, 0, 0],
[1, 1, 1 … 0, 0, 0],

[1, 1, 1 … 0, 0, 0],
[1, 1, 1 … 0, 0, 0],
[1, 1, 1 … 0, 0, 0]]), Tensor(shape=[32], dtype=Int32, value= [1, 1, 1, 1, 1, 1, 1, 2, 2, 0, 2, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1,
1, 1, 2, 1, 1, 1, 1, 1])]

2、模型训练

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。

from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_val, metrics=metric,epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])
# start training
trainer.run(tgt_columns="labels")

在这里插入图片描述

模型验证

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")% print_log:
Evaluate Score: {'Accuracy': 0.9083011583011583}
3、模型推理

遍历推理数据集,将结果与标签进行统一展示。

dataset_infer = SentimentDataset("data/infer.tsv")
def predict(text, label=None):label_map = {0: "消极", 1: "中性", 2: "积极"}text_tokenized = Tensor([tokenizer(text).input_ids])logits = model(text_tokenized)predict_label = logits[0].asnumpy().argmax()info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"if label is not None:info += f" , label: '{label_map[label]}'"print(info)
from mindspore import Tensorfor label, text in dataset_infer:predict(text, label)# print_log:
inputs: '我 要 客观', predict: '中性' , label: '中性'
inputs: '靠 你 真是 说 废话 吗', predict: '消极' , label: '消极'
inputs: '口嗅 会', predict: '中性' , label: '中性'
inputs: '每次 是 表妹 带 窝 飞 因为 窝路痴', predict: '中性' , label: '中性'
inputs: '别说 废话 我 问 你 个 问题', predict: '消极' , label: '消极'
inputs: '4967 是 新加坡 那 家 银行', predict: '中性' , label: '中性'
inputs: '是 我 喜欢 兔子', predict: '积极' , label: '积极'
inputs: '你 写 过 黄山 奇石 吗', predict: '中性' , label: '中性'
inputs: '一个一个 慢慢来', predict: '中性' , label: '中性'
inputs: '我 玩 过 这个 一点 都 不 好玩', predict: '消极' , label: '消极'
inputs: '网上 开发 女孩 的 QQ', predict: '中性' , label: '中性'
inputs: '背 你 猜 对 了', predict: '中性' , label: '中性'
inputs: '我 讨厌 你 , 哼哼 哼 。 。', predict: '消极' , label: '消极'

在这里插入图片描述

Reference

昇思大模型平台
不会停的蜗牛-5 分钟入门 Google 最强NLP模型:BERT

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44405.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ArduPilot开源代码之AP_OpticalFlow_CXOF

ArduPilot开源代码之AP_OpticalFlow_CXOF 1. 源由2. Library设计3. 重要例程3.1 AP_OpticalFlow_CXOF::init3.2 AP_OpticalFlow_CXOF::update3.3 AP_OpticalFlow_CXOF::detect 4. 总结5. 参考资料 1. 源由 AP_OpticalFlow_CXOF是就是一个光流计,与前面传感模块&…

【网络】为什么SCTP四次握手可以抵御SYN攻击

深入理解SCTP的安全性:从四次握手到抵御SYN攻击 引言 在网络通信的世界中,安全性和可靠性是至关重要的。传统的TCP(传输控制协议)在建立连接时使用三次握手,但这种机制存在一些安全漏洞,比如SYN攻击。而S…

解决IDEA每次新建项目都需要重新配置maven的问题

每次打开IDEA都要重新配置maven,这是因为在DEA中分为项目设置和全局设置,这个时候我们就需要去到全局中设置maven了。我用的是IntelliJ IDEA 2023.3.4 (Ultimate Edition),以此为例。 第一步:打开一个空的IDEA,选择左…

数据结构day6链式队列

主程序 #include "fun.h" int main(int argc, const char *argv[]) { que_p Qcreate(); enqueue(Q,10); enqueue(Q,20); enqueue(Q,30); enqueue(Q,40); enqueue(Q,50); show_que(Q); dequeue(Q); show_que(Q); printf(&qu…

stm32按键设置闹钟数进退位不正常?如何解决

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

A63 STM32_HAL库函数 之 Uart通用驱动 -- B -- 所有函数的介绍及使用

A63 STM32_HAL库函数 之 Uart通用驱动 -- B -- 所有函数的介绍及使用 1 该驱动函数预览1.15 HAL_UART_DMAResume1.16 HAL_UART_DMAStop1.17 HAL_UART_Abort1.18 HAL_UART_AbortTransmit1.19 HAL_UART_AbortReceive1.20 HAL_UART_Abort_IT1.21 HAL_UART_AbortTransmit_IT1.22 HA…

【Zoom安全解析】深入Zoom的端到端加密机制

标题:【Zoom安全解析】深入Zoom的端到端加密机制 在远程工作和在线会议变得越来越普及的今天,视频会议平台的安全性成为了用户关注的焦点。Zoom作为全球领先的视频会议软件,其端到端加密(E2EE)功能保证了通话的安全性…

李彦宏: 开源模型是智商税|马斯克: OpenAI 闭源不如叫 CloseAI

在 2024 年世界人工智能大会(WAIC 2024)上,百度创始人、董事长兼首席执行官李彦宏发表对开源模型的评价。 李彦宏认为:开源模型实际上是一种智商税,而闭源模型才是人工智能(AI)行业的未来。 马…

python库 - sentencepiece

SentencePiece 是一个开源的文本处理库,由 Google 开发,专门用于处理和生成无监督的文本符号化(tokenization)模型。它支持字节对编码(BPE)和 Unigram 语言模型两种主要的符号化算法,广泛应用于…

关于es中的kibana中的一些常用操作命令语句

1、查询某个索引的文档数据 get /index_name/_search2、根据某个索引下的具体某个字段的值来查询数据,这里比如说是:根据就modelId 1787774412315766785来查询数据 GET /index_name/_search {"query": {"term": {"modelId&q…

如何批量更改很多个文件夹里的文件名中包含文件夹名?

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

【学术会议征稿】第五届大数据、人工智能与物联网工程国际会议

第五届大数据、人工智能与物联网工程国际会议 2024 5th International Conference on Big Data, Artificial Intelligence and Internet of Things 第五届大数据、人工智能与物联网工程国际会议(ICBAIE 2024)定于2024年10月25-27号在中国深圳隆重举行。…

SIFT代码,MATLAB

SIFT特征详解 - Brook_icv - 博客园 (cnblogs.com) 这一篇文章写的挺好的,个人觉得,大家感兴趣的话可以看一下,但是有一些地方个人觉得还是没有讲的很清楚。 代码下载路径: https://pan.baidu.com/s/1c_NULsDrF3nH6xTmC1cY2Q 提…

【系统架构设计师】计算机组成与体系结构 ⑪ ( 数据传输控制方式 | 程序直接控制方式 | 中断控制方式 | 直接内存访问方式 )

文章目录 一、数据传输控制方式1、IO 设备数据传输2、数据传输控制方式 二、程序直接控制方式 ( 重点考点 )1、无条件传送 和 程序查询方式2、程序查询方式3、程序直接控制方式 的 优缺点 三、程序中断方式1、程序中断方式 流程2、程序中断方式 优缺点 四、DMA 方式1、DMA 简介…

好用的声音分析的软件和网站

有许多软件和网站可以帮助进行声音分析,从专业级的音频处理软件到在线工具,以下是一些推荐: 专业音频分析软件 Audacity 开源且免费的音频编辑和分析工具。提供基本的音频录制、编辑和分析功能。支持多种插件,扩展其功能。 Adob…

项目实战--Spring Boot 3整合Flink实现大数据文件处理

一、应用背景 公司大数据项目中,需要构建和开发高效、可靠的数据处理子系统,实现大数据文件处理、整库迁移、延迟与乱序处理、数据清洗与过滤、实时数据聚合、增量同步(CDC)、状态管理与恢复、反压问题处理、数据分库分表、跨数据…

把 .py 文件编译成 .pyd 文件

将 .py 文件编译成 .pyd 文件(在Windows上)或 .so 文件(在Linux或macOS上),实际上是将Python代码编译成一种可以被Python解释器直接加载的二进制模块。这种编译过程通常使用cython、pyinstaller的钩子(hook…

JRT打印药敏报告

最近没写jrt系列博客,不是中途而废了。而是在写微生物系统。今天终于把微生物大体完成了,伴随着业务的实现,框架趋于完善和稳定。构建一套完美而强大的打印体系一直是我的理想,从最开始C#的winform打印控件到刚接触bs时候用js打印…

day11:文件处理

一、文件与文件模式介绍 1、什么是文件 文件是操作系统提供给用户/应用程序操作硬盘的一种虚拟的概念/接口 用户/应用程序(open()) 操作系统(文件) 计算机硬件(硬盘)2、为何要用文件 ①用户/应用程序可以通过文件将数据永久保存…

【最强八股文 -- 计算机网络】【快速版】DNS 解析过程

步骤一:查询浏览器及计算机本地 HOST 文件中是否有对应的 缓存 步骤二:请求本地 DNS 服务器 (无则请求 根级 获取能提供信息的权威 DNS 服务器 ) 步骤三:逐级返回对应的 IP 至浏览器 步骤四:浏览器发起连接并缓存 参考资料&#x…