昇思25天学习打卡营第21天|基于MindSpore实现BERT对话情绪识别

模型简介

BERT全称是来自变换器的双向编码器表征量。与BERT模型相似的预训练语言模型例如问答、命名实体识别、自然语言推理、文本分类等在许多自然语言处理任务中发挥着重要作用。模型是基于Transformer中的Encoder并加上双向的结构。

BERT模型的主要创新点都在pre-train方法上,即用了Masked Language Model和Next Sentence Prediction两种方法分别捕捉词语和句子级别的representation。

在用Masked Language Model方法训练BERT的时候,随机把语料库中15%的单词做Mask操作。对于这15%的单词做Mask操作分为三种情况:80%的单词直接用[Mask]替换、10%的单词直接替换成另一个新的单词、10%的单词保持不变。

以文本情感分类任务作为例子说明BERT模型整个应用过程。

import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, contextfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy# prepare dataset
class SentimentDataset:"""Sentiment Dataset"""def __init__(self, path):self.path = pathself._labels, self._text_a = [], []self._load()def _load(self):with open(self.path, "r", encoding="utf-8") as f:dataset = f.read()lines = dataset.split("\n")for line in lines[1:-1]:label, text_a = line.split("\t")self._labels.append(int(label))self._text_a.append(text_a)def __getitem__(self, index):return self._labels[index], self._text_a[index]def __len__(self):return len(self._labels)

数据集

使用已标注的、经过分词预处理的机器人聊天数据集,来自于百度飞桨团队。数据由两列组成,以制表符('\t')分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。

# download dataset
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
!tar xvf emotion_detection.tar.gz

数据加载和数据预处理

import numpy as npdef process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):is_ascend = mindspore.get_context('device_target') == 'Ascend'column_names = ["label", "text_a"]dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)# transformstype_cast_op = transforms.TypeCast(mindspore.int32)def tokenize_and_pad(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text)return tokenized['input_ids'], tokenized['attention_mask']# map datasetdataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset

数据预处理采用静态Shape处理

from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

模型构建

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。

from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_val, metrics=metric,epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])
%%time
# start training
trainer.run(tgt_columns="labels")

模型验证

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

模型推理

遍历推理数据集,将结果与标签进行统一展示。

dataset_infer = SentimentDataset("data/infer.tsv")def predict(text, label=None):label_map = {0: "消极", 1: "中性", 2: "积极"}text_tokenized = Tensor([tokenizer(text).input_ids])logits = model(text_tokenized)predict_label = logits[0].asnumpy().argmax()info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"if label is not None:info += f" , label: '{label_map[label]}'"print(info)from mindspore import Tensorfor label, text in dataset_infer:predict(text, label)

输出结果:


inputs: '我 要 客观', predict: '中性' , label: '中性'
inputs: '靠 你 真是 说 废话 吗', predict: '消极' , label: '消极'
inputs: '口嗅 会', predict: '中性' , label: '中性'
inputs: '每次 是 表妹 带 窝 飞 因为 窝路痴', predict: '中性' , label: '中性'
inputs: '别说 废话 我 问 你 个 问题', predict: '消极' , label: '消极'
inputs: '4967 是 新加坡 那 家 银行', predict: '中性' , label: '中性'
inputs: '是 我 喜欢 兔子', predict: '积极' , label: '积极'
inputs: '你 写 过 黄山 奇石 吗', predict: '中性' , label: '中性'
inputs: '一个一个 慢慢来', predict: '中性' , label: '中性'
inputs: '我 玩 过 这个 一点 都 不 好玩', predict: '消极' , label: '消极'
inputs: '网上 开发 女孩 的 QQ', predict: '中性' , label: '中性'
inputs: '背 你 猜 对 了', predict: '中性' , label: '中性'
inputs: '我 讨厌 你 , 哼哼 哼 。 。', predict: '消极' , label: '消极'

自定义推理数据集

predict("家人们咱就是说一整个无语住了 绝绝子叠buff")

输出结果:

inputs: '家人们咱就是说一整个无语住了 绝绝子叠buff', predict: '中性'

总结

BERT模型采用了Masked Language Model和Next Sentence Prediction两种方法分别捕捉词语和句子级别的representation,从而获取对话情绪。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/43285.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SCI丨返修一作+通讯

中科四区,JCR2 返修转让一作通讯,5个月左右录用 题目:通过机器学习算法XXXXXXXxxx混凝土力学性能的可靠方法

Nginx 配置ssl证书

1. 准备 SSL 证书文件 确保您有以下文件: SSL 证书文件(通常是 .crt 或 .pem 文件) 私钥文件(通常是 .key 文件) 中间证书文件(如果适用,通常是 .crt 或 .pem 文件) 将这些文件上传…

苍穹外卖--完善登录功能:进行MD5加密

目标 TODO:使用MD5加密方式对明文密码。 实现 password DigestUtils.md5DigestAsHex(password.getBytes());

Face_recognition实现人脸识别

这里写自定义目录标题 欢迎使用Markdown编辑器一、安装人脸识别库face_recognition1.1 安装cmake1.2 安装dlib库1.3 安装face_recognition 二、3个常用的人脸识别案例2.1 识别并绘制人脸框2.2 提取并绘制人脸关键点2.3 人脸匹配及标注 欢迎使用Markdown编辑器 本文基于face_re…

双向链表+Map实现LRU

LRU: LRU是Least Recently Used的缩写,即最近最少使用,是一种常用的页面置换算法,选择最近最久未使用的页面予以淘汰。 核心思想: 基于Map实现k-v存储,双向链表中使用一个虚拟头部和虚拟尾部,虚拟头部的…

BioXcell—InVivoMAb anti-West Nile/dengue virus E protein

研发背景: 西尼罗河病毒(WNV)是一种由蚊虫类介导传播的黄病毒,与引起人类感染性流行病的登革热病毒、黄热病病毒和日本脑炎病毒密切相关。 WNV和登革热病毒(DENV)同属黄病毒科(Flaviviridae)黄热病毒属,是具有小包膜单…

实现llava的【单轮对话】调整成【多轮对话】 (输入图片/多模态/多轮对话/llava)

使用llava时,将llava的单轮对话调整成多轮对话 先说好方法一,直接对官方网站的quick start代码进行修改方法二:使用bash结合基础的单轮对话代码进行修改2.1 首先是最基础的llava官网代码回顾2.2 伪多轮对话方法(每次新的prompt输入…

AEC10 SA计算整理 --- 基础SA

LuxSA: LuxSALumaAvgLumaBE16x16 LuxSATarget[setparam/tr:luxlux] LuxSAAdjRatioLuxSATarget/LuxSALumaLuxSALuma: 计算16x16区域的平均亮度(Luma值)。 LuxSATarget: 通过参数设置获取目标亮度值(通常与当前光线条件相关)。 Lux…

【多模态】41、VILA | 打破常规多模态模型训练策略,在预训练阶段就微调 LLM 被证明能取得更好的效果!

论文:VILA: On Pre-training for Visual Language Models 代码:https://github.com/NVlabs/VILA 出处:NVLabs 时间:2024.05 贡献: 证明在预训练阶段对 LLM 进行微调能够提升模型对上下文任务的效果在 SFT 阶段混合…

同三维T80006解码器视频使用操作说明书:高清HDMI解码器,高清SDI解码器,4K超清HDMI解码器,双路4K超高清解码器

同三维T80006解码器视频使用操作说明书:高清HDMI解码器,高清SDI解码器,4K超清HDMI解码器,双路4K超高清解码器 解码器T80006 同三维,十多年老品牌,我们一直专注:视频采集卡、视频编码器、转码器、…

Centos7离线安装ElasticSearch7.4.2

一、官网下载相关的安装包 ElasticSearch7.4.2: elasticsearch-7.4.2-linux-x86_64.tar.gz 下载中文分词器: elasticsearch-analysis-ik-7.4.2.zip 二、上传解压文件到服务器 上传到目录:/home/data/elasticsearch 解压文件&#xff1…

免费无限白嫖阿里云服务器

今天,我来分享一个免费且无限使用阿里云服务器的方法,零成本!这适用于日常测试学习,比如测试 Shell 脚本、学习 Docker 安装、MySQL 等等。跟着我的步骤,你将轻松拥有一个稳定可靠的服务器,为你的学习和实践…

数据库的优点和缺点分别是什么

数据库作为数据存储和管理的核心组件,具有一系列显著的优点,同时也存在一些潜在的缺点。以下是对数据库优点和缺点的详细分析: 优点 数据一致性:数据库通过事务处理和锁机制等手段,确保数据的一致性和完整性。这意味着…

错误记录-SpringCloud-OpenFeign测试远程调用

文章目录 1,org.springframework.beans.factory.UnsatisfiedDependencyException: Error creating bean with name memberController: Unsatisfied dependency expressed through field couponFeign2, Receiver class org.springframework.cloud.netflix…

几种不同的方式禁止IP访问网站(PHP、Nginx、Apache设置方法)

1、PHP禁止IP和IP段访问 <?//禁止某个IP$banned_ip array ("127.0.0.1",//"119.6.20.66","192.168.1.4");if ( in_array( getenv("REMOTE_ADDR"), $banned_ip ) ){die ("您的IP禁止访问&#xff01;");}//禁止某个IP段…

SAP S4 环境下,KSU1 Ob52 转为前台操作,不产生传输请求号

参考 OB52/KSV1/KSU1等与Client状态相关的前台操作-y_q_yang se16n 编辑表 T811FLAGS 新增行 CYCLES MAINTENANCE X 即可

WAIC 2024 AI盛宴大会亮点回顾

&#x1f389; 刚刚落幕的WAIC 2024大会&#xff0c;简直是科技迷们的狂欢节&#xff01;这场全球瞩目的人工智能盛会&#xff0c;不仅汇聚了全球顶尖的智慧&#xff0c;还带来了无数令人惊叹的创新成果。让我们一起回顾那些抓人眼球的亮点吧&#xff01; &#x1f525; 超燃开…

Java面试八股之MySQL的redo log和undo log

MySQL的redo log和undo log 在MySQL的InnoDB存储引擎中&#xff0c;redo log和undo log是两种重要的日志&#xff0c;它们各自服务于不同的目的&#xff0c;对数据库的事务处理和恢复机制至关重要。 Redo Log&#xff08;重做日志&#xff09; 功能 redo log的主要作用是确…

是不是看错了,C++ 构造函数也可以是虚函数?(一)

以下内容为本人的学习笔记&#xff0c;如需要转载&#xff0c;请声明原文链接 微信公众号「ENG八戒」https://mp.weixin.qq.com/s/2HXYlggENXcSwTdydtof0g 首先&#xff0c;C 构造函数可以是虚函数吗&#xff1f; 语法上来说&#xff0c;答案是不行的。类对象的创建依赖于类构…

世界人工智能大会 | 江行智能大模型解决方案入选“AI赋能新型工业化创新应用优秀案例”

日前&#xff0c;2024世界人工智能大会暨人工智能全球治理高级别会议在上海启幕。本次大会主题为“以共商促共享&#xff0c;以善治促善智”&#xff0c;汇聚了上千位全球科技、产业界领军人物&#xff0c;共同探讨大模型、数据、新型工业化等人工智能深度发展时代下的热点话题…