昇思25天学习打卡营第17天 | 基于MindSpore实现BERT对话情绪识别

昇思25天学习打卡营第17天 | 基于MindSpore实现BERT对话情绪识别

文章目录

  • 昇思25天学习打卡营第17天 | 基于MindSpore实现BERT对话情绪识别
    • BERT模型
    • 对话情绪识别
    • BERT模型的文本情绪分类任务
      • 数据集
        • 数据下载
        • 数据加载与预处理
      • 模型构建
      • 模型验证
      • 模型推理
    • 总结
    • 打卡

BERT模型

BERT(Bidirectional Encoder Representations from Transformers)是Google开发的一种新型语言模型,模型主要基于Transformer中的Encoder并加上双向的结构。

BERT的主要创新点在pre-train方法上,使用了

  • Masked Language Model :随机把语料库中 15 % 15\% 15%的单词做Mask操作,对这些单词的Mask操作分为:
    • 80 % 80\% 80%的单词直接用[Mask]替换;
    • 10 % 10\% 10%的单词直接替换成另一个新的单词;
    • 10 % 10\% 10%的单词保持不变。
  • Next Sentence Prediction:目的是让模型理解两个句子之间的联系。训练的输入是句子A和B,B有一半的几率是A的下一句,BERT模型预测B是否为A的下一句。

对话情绪识别

对话情绪识别(Emotion Detection, EmoTect),专注于识别智能对话场景中用户的情绪。针对用户文本,自动判断该文本的情绪类别并给出相应的置信度。情绪类型分为积极、消极、中性。

BERT模型的文本情绪分类任务

数据集

import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, contextfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy# prepare dataset
class SentimentDataset:"""Sentiment Dataset"""def __init__(self, path):self.path = pathself._labels, self._text_a = [], []self._load()def _load(self):with open(self.path, "r", encoding="utf-8") as f:dataset = f.read()lines = dataset.split("\n")for line in lines[1:-1]:label, text_a = line.split("\t")self._labels.append(int(label))self._text_a.append(text_a)def __getitem__(self, index):return self._labels[index], self._text_a[index]def __len__(self):return len(self._labels)

数据来自于百度飞桨团队,由两列组成,以制表符(\t)分隔:

  • 第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极)
  • 第二列是以空格分词的中文文本。

label–text_a
0–谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?
1–我有事等会儿就回来和你聊
2–我见到你很高兴谢谢你帮我

数据下载
# download dataset
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
!tar xvf emotion_detection.tar.gz
数据加载与预处理
import numpy as npdef process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):is_ascend = mindspore.get_context('device_target') == 'Ascend'column_names = ["label", "text_a"]dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)# transformstype_cast_op = transforms.TypeCast(mindspore.int32)def tokenize_and_pad(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text)return tokenized['input_ids'], tokenized['attention_mask']# map datasetdataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return datasetfrom mindnlp.transformers import BertTokenizertokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)

模型构建

通过BertForSequenceClassification构建用于情绪分类的BERT模型,加载预训练权重,设置情绪三分类的超参数自动构建模型。

from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_val, metrics=metric,epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])%%time
# start training
trainer.run(tgt_columns="labels")

模型验证

在验证集上对模型验证,查看模型在验证数据上的指标。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

模型推理

dataset_infer = SentimentDataset("data/infer.tsv")def predict(text, label=None):label_map = {0: "消极", 1: "中性", 2: "积极"}text_tokenized = Tensor([tokenizer(text).input_ids])logits = model(text_tokenized)predict_label = logits[0].asnumpy().argmax()info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"if label is not None:info += f" , label: '{label_map[label]}'"print(info)from mindspore import Tensorfor label, text in dataset_infer:predict(text, label)

总结

这一节对BERT模型进行了介绍,其主要创新点为Masked Language Model和Next Sentence Prediction,这两种方法可以捕获词语和句子级别的特征表示。使用预训练的BERT模型,可以很方便的对下游任务进行Fine-tuning。这一节介绍了BERT在情绪分类任务上的应用。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/48302.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Espressif-ESP32S3】【VScode】安装【ESP-IDF】插件及相关工具链

一、ESP-IDF简介 二、VScode安装ESP-IDF插件 三、安装ESP-IDF、ESP-IDF-Tools以及相关工具链 四、测试例程&编译烧录 五、IDF常用指令 资料下载: 链接:https://pan.baidu.com/s/15Q2rl2jpIaKfj5rATkYE6g?pwdGLNG 提取码:GLNG 一、ESP-…

opencv—常用函数学习_“干货“_7

目录 十九、模板匹配 从图像中提取矩形区域的子像素精度补偿 (getRectSubPix) 在图像中搜索和匹配模板 (matchTemplate) 比较两个形状(轮廓)的相似度 (matchShapes) 解释 二十、图像矩 计算图像或轮廓的矩 (moments) 计算图像或轮廓的Hu不变矩 (H…

IntelliJ IDEA 2024.1 最新变化 附问卷调查 AI

IntelliJ IDEA 2024.1 最新变化 问卷调查项目在线AI IntelliJ IDEA 2024.1 最新变化关键亮点全行代码补全 Ultimate对 Java 22 功能的支持新终端 Beta编辑器中的粘性行 AI AssistantAI Assistant 改进 UltimateAI Assistant 中针对 Java 和 Kotlin 的改进代码高亮显示 Ultimate…

Android14之调试广播实例(二百二十五)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…

shell脚本检查OGG同步进程状态

服务器环境中在root用户下部署了ogg同步进程,在oracle用户下也部署了同步进程。在不用脚本检查的情况下,进程需要在root用户和oracle用户下来回切换,比较麻烦,所以考虑用脚本实现,在root用户下一键检查root用户和oracl…

Grid Search:解锁模型优化新境界

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 非常期待和您一起在这个小…

一键复制页面

<script src"./html2canvas.js"></script> <div id"target"><ol style"margin-top: 8px;margin-bottom: 8px;padding-left: 25px;" class"list-paddingleft-1"><li><p>递归遍历 DOM 树:</p&…

【Android性能优化】Android CPU占用率检测原理和优化方向

【Android性能优化】Android CPU占用率检测原理和优化方向 CPU相关知识 CPU占用的基本计算公式 (1 - 空闲态运行时间/总运行时间) * 100% Hz、Tick、Jiffies&#xff1a; Hz&#xff1a;Linux核心每隔固定周期会发出timer interrupt (IRQ 0)&#xff0c;HZ是用来定义每一秒有…

python 66 个冷知识 0720

66个有趣的Python冷知识 一行反转列表 使用切片一行反转列表&#xff1a;reversed_list my_list[::-1] 统计文件单词数量 使用 collections.Counter 统计文件中每个单词的数量&#xff1a;from collections import Counter; with open(file.txt) as f: word_count Counter(f…

【数据结构初阶】复杂度

目录 一、时间复杂度 1、时间复杂度的概念 2、大O的渐进表示法 3、常见的时间复杂度计算举例 二、空间复杂度 1、空间复杂度的概念 2、常见的空间复杂度计算举例 三、常见复杂度对比 正文开始—— 前言 一个算法&#xff0c;并非越简洁越好&#xff0c;那该如何衡量一个算法…

FLINK-checkpoint失败原因及处理方式

在 Flink 或其他分布式数据处理系统中&#xff0c;Checkpoint 失败可能由多种原因引起。以下是一些常见的原因&#xff1a; 资源不足&#xff1a; 如果 TaskManager 的内存或磁盘空间不足&#xff0c;可能无法完成状态的快照&#xff0c;导致 Checkpoint 失败。 网络问题&am…

微信小程序开发入门指南

文章目录 一、微信小程序简介二、微信小程序开发准备三、微信小程序开发框架四、微信小程序开发实例六、微信小程序开发进阶6.1 组件化开发6.2 API调用6.3 云开发 七、微信小程序开发注意事项7.1 遵守规范7.2 注意性能7.3 保护用户隐私 八、总结 大家好&#xff0c;今天将为大家…

源码安装 AMD GPGPU 生态 ROCm 备忘

0, 前言 如果初步接触 AMD这套&#xff0c;可以先在ubuntu上使用apt工具安装&#xff0c;并针对特定感兴趣的模块从源码编译安装替换&#xff0c;并开展研究。对整体感兴趣时可以考虑从源码编译安装整个ROCm生态。 1, 预制二进制通过apt 安装 待补。。。 2, 从源码安装 sudo …

C:一些题目

1.分数求和 计算1/1-1/21/3-1/41/5 …… 1/99 - 1/100 的值 #include <stdio.h>int main(){double sum 0.0; // 使用 double 类型来存储结果&#xff0c;以处理可能的小数部分int sign 1; // 符号标志&#xff0c;初始为 1 表示正数for (int i 1; i < 100; i)…

Vue3 内置组件Teleport以及Susponse

1、Teleport 1.1 概念 将组件模版中的指定的dom挂载&#xff08;传送&#xff09;到指定的dom元素上&#xff0c;如挂载到body中&#xff0c;挂载到#app选择器上面。 1.2 应用场景 经典案例如&#xff1a;模态框。 <template><teleport to"body">&l…

处理AI模型中的“Type Mismatch”报错:数据类型转换技巧

处理AI模型中的“Type Mismatch”报错&#xff1a;数据类型转换技巧 &#x1f504; 处理AI模型中的“Type Mismatch”报错&#xff1a;数据类型转换技巧 &#x1f504;摘要引言正文内容1. 错误解析&#xff1a;什么是“Type Mismatch”&#xff1f;2. 数据类型转换技巧2.1 检查…

Redis之Zset

目录 一.介绍 二.命令 三.编码方式 四.应用场景 Redis的学习专栏&#xff1a;http://t.csdnimg.cn/a8cvV 一.介绍 ZSET&#xff08;有序集合&#xff09;是 Redis 提供的一种数据结构&#xff0c;它与普通集合&#xff08;SET&#xff09;类似&#xff0c;不同之处在于每个…

【带你了解软件系统架构的演变】

🌈个人主页: 程序员不想敲代码啊 🏆CSDN优质创作者,CSDN实力新星,CSDN博客专家 👍点赞⭐评论⭐收藏 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步! 1. 介绍 🍋‍🟩软件系统架构的演变是一个响应技术变革、业务需求…

Tailwind CSS常见组合用法

1、一般布局组合 <main className"flex min-h-screen flex-col items-center justify-between p-24"></main>flex将元素的显示类型设置为 flexbox。这意味着子元素将以 flex 项的方式排列。min-h-screen将元素的最小高度设置为全屏高度&#xff08;视口高…

【Powershell】超越限制:获取Azure AD登录日志

你是否正在寻找一种方法来追踪 Azure Active Directory&#xff08;Azure AD&#xff09;中用户的登录活动&#xff1f; 如果是的话&#xff0c;查看Azure AD用户登录日志最简单的方法是使用Microsoft Entra管理中心。打开 https://entra.microsoft.com/&#xff0c;然后进入 监…