【从零开始实现意图识别】中文对话意图识别详解

前言

意图识别(Intent Recognition)是自然语言处理(NLP)中的一个重要任务,它旨在确定用户输入的语句中所表达的意图或目的。简单来说,意图识别就是对用户的话语进行语义理解,以便更好地回答用户的问题或提供相关的服务。

在NLP中,意图识别通常被视为一个分类问题,即通过将输入语句分类到预定义的意图类别中来识别其意图。这些类别可以是各种不同的任务、查询、请求等,例如搜索、购买、咨询、命令等。

下面是一个简单的例子来说明意图识别的概念:

用户输入: "我想订一张从北京到上海的机票。

意图识别:预订机票。

在这个例子中,通过将用户输入的语句分类到“预订机票”这个意图类别中,系统可以理解用户的意图并为其提供相关的服务。

意图识别是NLP中的一项重要任务,它可以帮助我们更好地理解用户的需求和意图,从而为用户提供更加智能和高效的服务。

在智能对话任务中,意图识别是一种非常重要的技术,它可以帮助系统理解用户的输入,从而提供更加准确和个性化的回答和服务。

模型

意图识别和槽位填充是对话系统中的基础任务。本仓库实现了一个基于BERT的意图(intent)和槽位(slots)联合预测模块。想法上实际与JoinBERT类似(GitHub:BERT for Joint Intent Classification and Slot Filling),利用 [CLS] token对应的last hidden state去预测整句话的intent,并利用句子tokens的last hidden states做序列标注,找出包含slot values的tokens。你可以自定义自己的意图和槽位标签,并提供自己的数据,通过下述流程训练自己的模型,并在JointIntentSlotDetector类中加载训练好的模型直接进行意图和槽值预测。

源GitHub:https://github.com/Linear95/bert-intent-slot-detector

在本文使用的模型中对数据进行了扩充、对代码进行注释、对部分代码进行了修改

Bert模型下载

Bert模型下载地址:https://huggingface.co/bert-base-chinese/tree/main

下载下方红框内的模型即可。

数据集介绍

训练数据以json格式给出,每条数据包括三个关键词:text表示待检测的文本,intent代表文本的类别标签,slots是文本中包括的所有槽位以及对应的槽值,以字典形式给出。

{

"text": "搜索西红柿的做法。",

"domain": "cookbook",

"intent": "QUERY",

"slots": {"ingredient": "西红柿"}

}

原始数据集:https://conference.cipsc.org.cn/smp2019/

本项目中在原始数据集中新增了部分数据,用来平衡数据。

模型训练

python train.py

# -----------training-------------
max_acc = 0
for epoch in range(args.train_epochs):total_loss = 0model.train()for step, batch in enumerate(train_dataloader):input_ids, intent_labels, slot_labels = batchoutputs = model(input_ids=torch.tensor(input_ids).long().to(device),intent_labels=torch.tensor(intent_labels).long().to(device),slot_labels=torch.tensor(slot_labels).long().to(device))loss = outputs['loss']total_loss += loss.item()if args.gradient_accumulation_steps > 1:loss = loss / args.gradient_accumulation_stepsloss.backward()if step % args.gradient_accumulation_steps == 0:# 用于对梯度进行裁剪,以防止在神经网络训练过程中出现梯度爆炸的问题。torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)optimizer.step()scheduler.step()model.zero_grad()train_loss = total_loss / len(train_dataloader)dev_acc, intent_avg, slot_avg = dev(model, val_dataloader, device, slot_dict)flag = Falseif max_acc < dev_acc:max_acc = dev_accflag = Truesave_module(model, model_save_dir)print(f"[{epoch}/{args.train_epochs}] train loss: {train_loss}  dev intent_avg: {intent_avg} "f"def slot_avg: {slot_avg} save best model: {'*' if flag else ''}")dev_acc, intent_avg, slot_avg = dev(model, val_dataloader, device, slot_dict)
print("last model dev intent_avg: {} def slot_avg: {}".format(intent_avg, slot_avg))

运行过程:

模型推理

python predict.py
 
def detect(self, text, str_lower_case=True):"""text : list of string, each string is a utterance from user"""list_input = Trueif isinstance(text, str):text = [text]list_input = Falseif str_lower_case:text = [t.lower() for t in text]batch_size = len(text)inputs = self.tokenizer(text, padding=True)with torch.no_grad():outputs = self.model(input_ids=torch.tensor(inputs['input_ids']).long().to(self.device))intent_logits = outputs['intent_logits']slot_logits = outputs['slot_logits']intent_probs = torch.softmax(intent_logits, dim=-1).detach().cpu().numpy()slot_probs = torch.softmax(slot_logits, dim=-1).detach().cpu().numpy()slot_labels = self._predict_slot_labels(slot_probs)intent_labels = self._predict_intent_labels(intent_probs)slot_values = self._extract_slots_from_labels(inputs['input_ids'], slot_labels, inputs['attention_mask'])outputs = [{'text': text[i], 'intent': intent_labels[i], 'slots': slot_values[i]}for i in range(batch_size)]if not list_input:return outputs[0]return outputs

推理结果:

模型检测相关代码

将概率值转换为实际标注值

def _predict_slot_labels(self, slot_probs):"""slot_probs : probability of a batch of tokens into slot labels, [batch, seq_len, slot_label_num], numpy array"""slot_ids = np.argmax(slot_probs, axis=-1)return self.slot_dict[slot_ids.tolist()]def _predict_intent_labels(self, intent_probs):"""intent_labels : probability of a batch of intent ids into intent labels, [batch, intent_label_num], numpy array"""intent_ids = np.argmax(intent_probs, axis=-1)return self.intent_dict[intent_ids.tolist()]

槽位验证(确保检测结果的正确性)

def _extract_slots_from_labels_for_one_seq(self, input_ids, slot_labels, mask=None):results = {}unfinished_slots = {}  # dict of {slot_name: slot_value} pairsif mask is None:mask = [1 for _ in range(len(input_ids))]def add_new_slot_value(results, slot_name, slot_value):if slot_name == "" or slot_value == "":return resultsif slot_name in results:results[slot_name].append(slot_value)else:results[slot_name] = [slot_value]return resultsfor i, slot_label in enumerate(slot_labels):if mask[i] == 0:continue# 检测槽位的第一字符(B_)开头if slot_label[:2] == 'B_':slot_name = slot_label[2:]  # 槽位名称 (B_ 后面)if slot_name in unfinished_slots:results = add_new_slot_value(results, slot_name, unfinished_slots[slot_name])unfinished_slots[slot_name] = self.tokenizer.decode(input_ids[i])# 检测槽位的后面字符(I_)开头elif slot_label[:2] == 'I_':slot_name = slot_label[2:]if slot_name in unfinished_slots and len(unfinished_slots[slot_name]) > 0:unfinished_slots[slot_name] += self.tokenizer.decode(input_ids[i])for slot_name, slot_value in unfinished_slots.items():if len(slot_value) > 0:results = add_new_slot_value(results, slot_name, slot_value)return results

源码获取

NLP/bert-intent-slot at main · mzc421/NLP (github.com)icon-default.png?t=N7T8https://github.com/mzc421/NLP/tree/main/bert-intent-slot

链接作者

欢迎关注我的公众号:@AI算法与电子竞赛

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/164862.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

XUbuntu22.04之解决gpg keyserver receive failed no data(一百九十三)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…

DevExpress WinForms TreeMap组件,用嵌套矩形可视化复杂分层数据

DevExpress WinForms TreeMap控件允许用户使用嵌套的矩形来可视化复杂的平面或分层数据结构。 DevExpress WinForms有180组件和UI库&#xff0c;能为Windows Forms平台创建具有影响力的业务解决方案。同时能完美构建流畅、美观且易于使用的应用程序&#xff0c;无论是Office风…

中文rlhf数据集50w条数据解析

中文rlhf数据集50w条数据解析 解析代码数据名代码解析 解析代码 import jieba from tqdm import tqdm import re import pandas as pd import numpy as npdef find_non_english_text(text):pattern re.compile(r[^a-zA-Z])return pattern.sub(, text)def find_chinese_text(t…

教育数字化转型:塑造未来学习新范式

在国家教育数字化战略行动指引下&#xff0c;我国正积极推动数字化赋能教育高质量发展&#xff0c;以塑造教育发展的新优势。如今&#xff0c;随着科技新基建的普及和数字化赋能教育的深入推进&#xff0c;未来的教育模型正在逐渐形成。 在新的教育模型中&#xff0c;数字化学…

算法基础(python版本)

第二章 算法设计思想 一、搜索排序 1.排序算法 https://visualgo.net/zh/sorting (1)冒泡排序 # 思路&#xff1a; # (1)比较相邻元素&#xff0c;如果第一个比第二个大&#xff0c;则交换他们 # (2)第一轮下来&#xff0c;可以保证最后一个数一定是最大的&#xff1b;第二…

2023最全的Web自动化测试介绍

做测试的同学们都了解&#xff0c;做Web自动化&#xff0c;我们主要用Selenium或者是QTP。 有的人可能就会说&#xff0c;我没这个Java基础&#xff0c;没有Selenium基础&#xff0c;能行吗&#xff1f;测试虽然属于计算机行业&#xff0c;但其实并不需要太深入的编程知识&…

介绍一个功能强大的shopify app——TINYIMG

各位观众老爷&#xff0c;南来的北往的&#xff0c;东去的西走的&#xff0c;今天给大家推荐一个功能很强大的shopify app 当当当 那就是 tinyimg 这个app有多牛逼呢&#xff0c;且听我慢慢道来 首先这个app可以用来优化图片大小&#xff0c;给你的网站提提速 然后这个app还可…

Android使用AIDL+MemoryFile传递大数据

Android进程间通信经常会使用AIDL&#xff0c;简单方便&#xff0c;但是数据量有限制&#xff0c;超过一定值会报错&#xff1a; E !!! FAILED BINDER TRANSACTION !!! (parcel size 2073744) 可以通过使用AIDLMemoryFile传递大数据 新建AIDL接口&#xff1a; interface On…

CCFCSP试题编号:201803-2试题名称:碰撞的小球

一、题目描述 二、思路 1.首先妾身分析这个题目&#xff0c;想要解题&#xff0c;得得解决2个问题。 1&#xff09;判断小球到达端点或碰撞然后改变方向&#xff1b; 2&#xff09;每时刻都要改变位置 两个问题都比较好解决&#xff0c;1&#xff09;只要简单判断坐标&…

形态学操作—膨胀

在 OpenCV 中&#xff0c;图像形态学操作是一组基于图像形状的处理技术&#xff0c;其中膨胀&#xff08;Dilation&#xff09;是其中之一。膨胀操作可用于图像处理中的特征增强、去噪、分割和边缘检测等。其基本原理是利用结构元素&#xff08;Kernel 或 Structuring Element&…

Tomcat实现WebSocket即时通讯 Java实现WebSocket的两种方式

HTTP协议是“请求-响应”模式&#xff0c;浏览器必须先发请求给服务器&#xff0c;服务器才会响应该请求。即服务器不会主动发送数据给浏览器。 实时性要求高的应用&#xff0c;如在线游戏、股票实时报价和在线协同编辑等&#xff0c;浏览器需实时显示服务器的最新数据&#x…

UML建模图文详解教程06——顺序图

版权声明 本文原创作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl本文参考资料&#xff1a;《UML面向对象分析、建模与设计&#xff08;第2版&#xff09;》吕云翔&#xff0c;赵天宇 著 顺序图概述 顺序图(sequence diagram&#xff0c;也…

(三)C语言之for语句概述

&#xff08;三&#xff09;C语言之for语句概述 一、使用for语句实现打印华氏温度与摄氏温度转换二、for语句概述三、练习 一、使用for语句实现打印华氏温度与摄氏温度转换 #include <stdio.h> /*当华氏温度为 0,20,40,...300时&#xff0c;打印出华氏温度与摄氏温度对照…

一个简单的QT应用示例

一个简单的QT应用示例&#xff1a;创建一个窗口程序。 首先&#xff0c;确保已经安装了Qt开发环境。接下来&#xff0c;按照以下步骤创建一个简单的窗口程序&#xff1a; 1. 打开Qt Creator&#xff0c;点击“新建文件或项目”。 2. 选择“应用程序”&#xff0c;然后点击“下…

【MATLAB】根轨迹的绘制及rltool工具的使用

目录 一、MATLAB中传递函数的表示二、rlocus函数绘制根轨迹1.常规根轨迹仿真示例2.参数根轨迹仿真示例3.零度根轨迹仿真示例 三、图形化工具rltool介绍 一、MATLAB中传递函数的表示 在绘制系统的根轨迹之前&#xff0c;需要知道传递函数在matlab中如何表示。 在matlab中&#…

VOC数据集和COCO数据集直接的相互转换

VOC数据集格式 get_list.py import os import random import shutil# 设置随机种子 random.seed(1000)# 判断Annotations和JpegImages是否对应 train_precent=0.8 label_path= "../../Annotations" print(os.path.abspath(label_path)) save="../Main" pr…

repo init报error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed

repo init报error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed 1 repo init出错的信息2 解决方法 在ubuntu执行repo init的时候报了repo init报error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed这种错误&#xff0c;解决方法是需要更新本地…

PS给图片增加一个白色边框。

问题描述&#xff1a;PS如何给图片增加一个白色边框&#xff1f; 解决办法&#xff1a; 第一步&#xff1a;使用shiftAltA快捷键&#xff0c;在图片四周拉出一个灰白色的边框。如下图所示&#xff1a; 第二步&#xff0c;使用快捷键Ctrlshiftn新建一个图层。 并把新建的图层…

创建maven的web项目

&#xff08;一&#xff09;创建maven的web项目 Step1、创建一个普通的maven项目 &#xff08;1&#xff09;新建一个empty project&#xff0c;命名为SSM2。 点击项目名&#xff0c;右键new&#xff0c;选择Module&#xff0c;左侧选择“Maven archetype”&#xff0c;可以给…

我叫:快速排序【JAVA】

1.自我介绍 1.快速排序是由东尼霍尔所发展的一种排序算法。 2.快速排序又是一种分而治之思想在排序算法上的典型应用。 3.本质上来看&#xff0c;快速排序应该算是在冒泡排序基础上的递归分治法。 2.思想共享 快速排序(Quicksort)是对冒泡排序的一种改进。基本思想是:通过一趟…