昇思25天学习打卡营第15天|基于 MindSpore 实现 BERT 对话情绪识别

文章目录

      • 昇思MindSpore应用实践
          • 1、基于 MindSpore 实现 BERT 对话情绪识别
            • BERT 模型简介
            • 数据集
            • 数据加载和数据预处理
          • 2、模型训练
            • 模型验证
          • 3、模型推理
      • Reference

昇思MindSpore应用实践

本系列文章主要用于记录昇思25天学习打卡营的学习心得。

1、基于 MindSpore 实现 BERT 对话情绪识别
BERT 模型简介

自2017年Google发表"Attention is ALL You Need"之后,基于自注意力机制结构的网络模型开始在各领域发展,特别是Transformer网络,针对序列的强大长距离关联处理能力在NLP领域取得了成效,BERT 的创新点在于它将双向 Transformer 用于语言模型,曾一度刷新各项NLP任务的SOTA记录,包括问答 Question Answering (SQuAD v1.1),推理 Natural Language Inference (MNLI) 等,也由此拉开了LLM的序幕。

在此之前,以往的NLP模型,如RNN网络,通常是从左向右输入一个文本序列,或者将 left-to-right 和 right-to-left 的训练结合起来(如Bi-LSTM)。实验结果表明,双向训练的语言模型对语境的理解会比单向的语言模型更深刻,

BERT 利用了 Transformer 的 Encoder 部分。Transformer 基于自注意力机制可以学习文本中单词之间的上下文关系。Transformer 的原型包括两个独立的机制,一个 Encoder 负责接收文本作为输入,一个 Decoder 负责预测任务的结果。BERT 的目标是生成语言模型,所以只需要 Encoder 机制。

Transformer 的 encoder 是一次性读取整个文本序列,而不是从左到右或从右到左地按顺序读取,这个特征使得模型能够基于单词的两侧学习,相当于是一个双向的功能

下图是 Transformer 的 encoder 部分,输入是一个 token 序列,先对其进行 embedding 称为向量,然后输入给神经网络,输出是大小为 H 的向量序列,每个向量对应着具有相同索引的 token。
在这里插入图片描述

BERT模型的主要创新点都在pre-train方法上,即用了Masked Language ModelNext Sentence Prediction两种方法分别捕捉词语和句子级别的representation。

在用Masked Language Model方法训练BERT的时候,随机把语料库中15%的单词做Mask操作。对于这15%的单词做Mask操作分为三种情况:80%的单词直接用[Mask]替换、10%的单词直接替换成另一个新的单词、10%的单词保持不变。

因为涉及到Question Answering (QA) 和 Natural Language Inference (NLI)之类的任务,增加了Next Sentence Prediction预训练任务,目的是让模型理解两个句子之间的联系。与Masked Language Model任务相比,Next Sentence Prediction更简单些,训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,BERT模型预测B是不是A的下一句

BERT预训练之后,会保存它的Embedding table和12层Transformer权重(BERT-BASE)或24层Transformer权重(BERT-LARGE)。使用预训练好的BERT模型可以对下游任务进行Fine-tuning,比如:文本分类、相似度判断、阅读理解等。

对话情绪识别(Emotion Detection,简称EmoTect),专注于识别智能对话场景中用户的情绪,针对智能对话场景中的用户文本,自动判断该文本的情绪类别并给出相应的置信度,情绪类型分为积极、消极、中性。 对话情绪识别适用于聊天、客服等多个场景,能够帮助企业更好地把握对话质量、改善产品的用户交互体验,也能分析客服服务质量、降低人工质检成本。

下面以MindSpore给出的一个文本情感分类任务为例子来说明BERT模型的整个应用过程:

import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, contextfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy# prepare dataset
class SentimentDataset:"""Sentiment Dataset"""def __init__(self, path):self.path = pathself._labels, self._text_a = [], []self._load()def _load(self):with open(self.path, "r", encoding="utf-8") as f:dataset = f.read()lines = dataset.split("\n")for line in lines[1:-1]:label, text_a = line.split("\t")self._labels.append(int(label))self._text_a.append(text_a)def __getitem__(self, index):return self._labels[index], self._text_a[index]def __len__(self):return len(self._labels)
数据集

这里提供一份已标注的、经过分词预处理的机器人聊天数据集,来自于百度飞桨团队。数据由两列组成,以制表符(‘\t’)分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。

label–text_a

0–谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?

1–我有事等会儿就回来和你聊

2–我见到你很高兴谢谢你帮我

这部分主要包括数据集读取,数据格式转换,数据 Tokenize 处理和 pad 操作。

数据加载和数据预处理

新建 process_dataset 函数用于数据加载和数据预处理:

import numpy as npdef process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):is_ascend = mindspore.get_context('device_target') == 'Ascend'column_names = ["label", "text_a"]dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)# transformstype_cast_op = transforms.TypeCast(mindspore.int32)def tokenize_and_pad(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text)return tokenized['input_ids'], tokenized['attention_mask']# map datasetdataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset

昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理:

from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
tokenizer.pad_token_iddataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)
dataset_train.get_col_names()

[‘input_ids’, ‘attention_mask’, ‘labels’]

print(next(dataset_train.create_tuple_iterator()))

[Tensor(shape=[32, 64], dtype=Int64, value=
[[ 101, 6443, 3221 … 0, 0, 0],
[ 101, 872, 1963 … 0, 0, 0],
[ 101, 6929, 872 … 0, 0, 0],

[ 101, 872, 4268 … 0, 0, 0],
[ 101, 671, 4991 … 0, 0, 0],
[ 101, 1376, 1480 … 0, 0, 0]]), Tensor(shape=[32, 64], dtype=Int64, value=
[[1, 1, 1 … 0, 0, 0],
[1, 1, 1 … 0, 0, 0],
[1, 1, 1 … 0, 0, 0],

[1, 1, 1 … 0, 0, 0],
[1, 1, 1 … 0, 0, 0],
[1, 1, 1 … 0, 0, 0]]), Tensor(shape=[32], dtype=Int32, value= [1, 1, 1, 1, 1, 1, 1, 2, 2, 0, 2, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1,
1, 1, 2, 1, 1, 1, 1, 1])]

2、模型训练

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。

from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_val, metrics=metric,epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])
# start training
trainer.run(tgt_columns="labels")

在这里插入图片描述

模型验证

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")% print_log:
Evaluate Score: {'Accuracy': 0.9083011583011583}
3、模型推理

遍历推理数据集,将结果与标签进行统一展示。

dataset_infer = SentimentDataset("data/infer.tsv")
def predict(text, label=None):label_map = {0: "消极", 1: "中性", 2: "积极"}text_tokenized = Tensor([tokenizer(text).input_ids])logits = model(text_tokenized)predict_label = logits[0].asnumpy().argmax()info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"if label is not None:info += f" , label: '{label_map[label]}'"print(info)
from mindspore import Tensorfor label, text in dataset_infer:predict(text, label)# print_log:
inputs: '我 要 客观', predict: '中性' , label: '中性'
inputs: '靠 你 真是 说 废话 吗', predict: '消极' , label: '消极'
inputs: '口嗅 会', predict: '中性' , label: '中性'
inputs: '每次 是 表妹 带 窝 飞 因为 窝路痴', predict: '中性' , label: '中性'
inputs: '别说 废话 我 问 你 个 问题', predict: '消极' , label: '消极'
inputs: '4967 是 新加坡 那 家 银行', predict: '中性' , label: '中性'
inputs: '是 我 喜欢 兔子', predict: '积极' , label: '积极'
inputs: '你 写 过 黄山 奇石 吗', predict: '中性' , label: '中性'
inputs: '一个一个 慢慢来', predict: '中性' , label: '中性'
inputs: '我 玩 过 这个 一点 都 不 好玩', predict: '消极' , label: '消极'
inputs: '网上 开发 女孩 的 QQ', predict: '中性' , label: '中性'
inputs: '背 你 猜 对 了', predict: '中性' , label: '中性'
inputs: '我 讨厌 你 , 哼哼 哼 。 。', predict: '消极' , label: '消极'

在这里插入图片描述

Reference

昇思大模型平台
不会停的蜗牛-5 分钟入门 Google 最强NLP模型:BERT

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44405.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决IDEA每次新建项目都需要重新配置maven的问题

每次打开IDEA都要重新配置maven,这是因为在DEA中分为项目设置和全局设置,这个时候我们就需要去到全局中设置maven了。我用的是IntelliJ IDEA 2023.3.4 (Ultimate Edition),以此为例。 第一步:打开一个空的IDEA,选择左…

数据结构day6链式队列

主程序 #include "fun.h" int main(int argc, const char *argv[]) { que_p Qcreate(); enqueue(Q,10); enqueue(Q,20); enqueue(Q,30); enqueue(Q,40); enqueue(Q,50); show_que(Q); dequeue(Q); show_que(Q); printf(&qu…

stm32按键设置闹钟数进退位不正常?如何解决

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

李彦宏: 开源模型是智商税|马斯克: OpenAI 闭源不如叫 CloseAI

在 2024 年世界人工智能大会(WAIC 2024)上,百度创始人、董事长兼首席执行官李彦宏发表对开源模型的评价。 李彦宏认为:开源模型实际上是一种智商税,而闭源模型才是人工智能(AI)行业的未来。 马…

如何批量更改很多个文件夹里的文件名中包含文件夹名?

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

【学术会议征稿】第五届大数据、人工智能与物联网工程国际会议

第五届大数据、人工智能与物联网工程国际会议 2024 5th International Conference on Big Data, Artificial Intelligence and Internet of Things 第五届大数据、人工智能与物联网工程国际会议(ICBAIE 2024)定于2024年10月25-27号在中国深圳隆重举行。…

JRT打印药敏报告

最近没写jrt系列博客,不是中途而废了。而是在写微生物系统。今天终于把微生物大体完成了,伴随着业务的实现,框架趋于完善和稳定。构建一套完美而强大的打印体系一直是我的理想,从最开始C#的winform打印控件到刚接触bs时候用js打印…

通过Arcgis从逐月平均气温数据中提取并计算年平均气温

通过Arcgis快速将逐月平均气温数据生成年平均气温数据。本次用2020年逐月平均气温数据操作说明。 一、准备工作 (1)准备Arcmap桌面软件; (2)准备2020年逐月平均气温数据(NC格式)、范围图层数据&…

JAVA分布式事务详情分布式事务的解决方案Java中的分布式事务实现

本人详解 作者:王文峰,参加过 CSDN 2020年度博客之星,《Java王大师王天师》 公众号:JAVA开发王大师,专注于天道酬勤的 Java 开发问题中国国学、传统文化和代码爱好者的程序人生,期待你的关注和支持!本人外号:神秘小峯 山峯 转载说明:务必注明来源(注明:作者:王文峰…

数一140+上岸|七月强化一定要避开这3个雷区!

当然可以,强化阶段的主要任务就是做题! 但是不用一刀切,强化阶段听课和做题可以二八原则,就是听课占20%,做题占80%。 因为自己去自学讲义的话,比如张宇18讲,会漏掉一些重点,有的技…

进程间的通信--管道

文章目录 一、进程通信的介绍1.1进程间为什么需要通信1.2进程如何通信 二、管道2.1匿名管道2.1.1文件描述符理解管道2.1.2接口使用2.1.3管道的4种情况2.1.4管道的五种特征 2.2管道的使用场景2.2.1命令行中的管道2.2.2进程池 2.命名管道2.1.1原理2.2.2接口2.2.3代码实例 一、进程…

scipy库中,不同应用滤波函数的区别,以及FIR滤波器和IIR滤波器的区别

一、在 Python 中,有多种函数可以用于应用 FIR/IIR 滤波器,每个函数的使用场景和特点各不相同。以下是一些常用的 FIR /IIR滤波器应用函数及其区别: from scipy.signal import lfiltery lfilter(fir_coeff, 1.0, x)from scipy.signal impo…

【Docker-compose】搭建php 环境

文章目录 Docker-compose容器编排1. 是什么2. 能干嘛3. 去哪下4. Compose 核心概念5. 实战 :linux 配置dns 服务器,搭建lemp环境(Nginx MySQL (MariaDB) PHP )要求6. 配置dns解析配置 lemp Docker-compose容器编排 1. 是什么 …

【智能算法改进】一种混合多策略改进的麻雀搜索算法

目录 1.算法原理2.改进点3.结果展示4.参考文献5.代码获取 1.算法原理 【智能算法】麻雀搜索算法(SSA)原理及实现 2.改进点 精英反向学习策略 将精英反向学习策略应用到初始化阶段, 通过反向解的生成与精英个体的选择, 不仅使算法搜索范围得到扩大, 提…

从零开学C++:类和对象(上)

引言:在学习了C的入门级知识之后,现在就让我们一起进入类和对象的学习吧,该知识点我将分为上,中,下三个部分对其进行讲解。 更多有关C语言和数据结构的知识详解可前往个人主页:计信猫 目录 一,类…

Android liveData 监听异常,fragment可见时才收到回调记录

背景&#xff1a;在app的fragment不可见的情况下使用&#xff0c;发现注册了&#xff0c;但是没有回调导致数据一直未更新&#xff0c;只有在fragment可见的时候才收到回调 // 观察通用信息mLightNaviTopViewModel.getUpdateCommonInfo().observe(this, new Observer<Common…

[嵌入式 C 语言] 按位与、或、取反、异或

若协议中如下图所示&#xff1a; 注意&#xff1a; 长度为1&#xff0c;表示1个字节&#xff0c;也就是0xFF&#xff0c;也就是 1111 1111 &#xff08;这里0xFF只是单纯表示一个数&#xff0c;也可以是其他数&#xff0c;这里需要注意的是1个字节的意思&#xff09; 一、按位…

第三课网关作用

实验拓扑图&#xff1a; 基础配置&#xff1a; PC1的基础配置 PC2的基础配置&#xff1a; PC4的基础配置 AR1添加PC4网段: 并且添加pc1,pc2的网段。 并且添加pc1,pc2的网段。 原理&#xff1a;PC4先把数据交给100.100.100.1&#xff0c;交给了路由器&#xff0c;路由器再把数…

瑞萨RH850 RTC计时进位异常

RH850 MCU的RTC&#xff08;实时时钟&#xff09;采用BCD&#xff08;二进制编码的十进制&#xff09;编码格式&#xff0c;支持闰年自动识别&#xff0c;并具有秒、分、时、日、周、月、年的进位功能。其中&#xff0c;秒和分为60进位&#xff0c;时为12或24进位&#xff0c;周…