【MindSpore学习打卡】应用实践-自然语言处理-深入理解LSTM+CRF在序列标注中的应用

在自然语言处理(NLP)领域,序列标注是一项重要的任务。其目标是为给定的输入序列中的每个Token分配一个标签。序列标注的应用范围广泛,包括分词、词性标注、命名实体识别(NER)等。在本文中,我们将探讨如何利用LSTM和CRF模型进行序列标注,并使用MindSpore框架实现这一过程。通过深入了解LSTM和CRF的原理和实现方法,读者将能够更好地理解和应用这些技术来解决实际问题。

条件随机场(Conditional Random Field, CRF)

在序列标注任务中,简单地将每个Token的标签预测视为多分类问题是不够的,因为相邻Token之间存在依赖关系。以命名实体识别为例:

输入序列
输出标注BIIIOOOOOBI

在上述例子中,清华大学北京是地名,需要将其识别出来。我们对每个输入的单词预测其标签,最后根据标签来识别实体。为了捕获这种依赖关系,我们引入条件随机场(CRF)。

为什么需要CRF

在序列标注任务中,简单地将每个Token的标签预测视为多分类问题是不够的,因为相邻Token之间存在依赖关系。比如在命名实体识别任务中,一个实体的开始标签通常是"B",后续的标签是"I",而非实体的标签是"O"。如果我们不考虑这种依赖关系,模型可能会产生不合理的标签序列。条件随机场(CRF)通过引入发射概率和转移概率,能够捕获这种标签间的依赖关系,从而提高预测的准确性。

CRF的定义与参数化形式

CRF是一种概率图模型,适用于捕获序列中相邻Token之间的依赖关系。设 x = { x 0 , . . . , x n } x=\{x_0, ..., x_n\} x={x0,...,xn}为输入序列, y = { y 0 , . . . , y n } y=\{y_0, ..., y_n\} y={y0,...,yn}为输出的标注序列,其中 n n n为序列的最大长度。则输出序列 y y y的概率为:

P ( y ∣ x ) = exp ⁡ ( Score ( x , y ) ) ∑ y ′ ∈ Y exp ⁡ ( Score ( x , y ′ ) ) P(y|x) = \frac{\exp{(\text{Score}(x, y)})}{\sum_{y' \in Y} \exp{(\text{Score}(x, y')})} P(yx)=yYexp(Score(x,y))exp(Score(x,y))

其中, Score ( x , y ) \text{Score}(x, y) Score(x,y)用于衡量序列 x x x和标签 y y y的匹配程度。我们定义两个概率函数来计算 Score \text{Score} Score

  1. 发射概率函数 ψ EMIT \psi_\text{EMIT} ψEMIT:表示 x i → y i x_i \rightarrow y_i xiyi的概率。
  2. 转移概率函数 ψ TRANS \psi_\text{TRANS} ψTRANS:表示 y i − 1 → y i y_{i-1} \rightarrow y_i yi1yi的概率。

基于这两个函数,我们可以得到 Score \text{Score} Score的计算公式:

Score ( x , y ) = ∑ i log ⁡ ψ EMIT ( x i → y i ) + log ⁡ ψ TRANS ( y i − 1 → y i ) \text{Score}(x,y) = \sum_i \log \psi_\text{EMIT}(x_i \rightarrow y_i) + \log \psi_\text{TRANS}(y_{i-1} \rightarrow y_i) Score(x,y)=ilogψEMIT(xiyi)+logψTRANS(yi1yi)

CRF的实现

在实现CRF时,我们需要计算正确标签序列的得分(Score)和所有可能标签序列的对数指数和(Normalizer)。然后通过求解负对数似然损失(NLL)来进行模型训练。

Score计算

为什么需要序列填充和掩码

在实际应用中,输入序列的长度可能不一致。为了将这些序列打包成一个Batch,我们需要对长度不足的序列进行填充。然而,填充的部分不应参与模型的训练和预测。因此,我们引入了掩码矩阵(mask),用于忽略填充部分的计算。这样可以确保模型只关注有效的Token,提高训练和预测的准确性。

首先根据公式计算正确标签序列的得分:

def compute_score(emissions, tags, seq_ends, mask, trans, start_trans, end_trans):seq_length, batch_size = tags.shapemask = mask.astype(emissions.dtype)score = start_trans[tags[0]]score += emissions[0, mnp.arange(batch_size), tags[0]]for i in range(1, seq_length):score += trans[tags[i - 1], tags[i]] * mask[i]score += emissions[i, mnp.arange(batch_size), tags[i]] * mask[i]last_tags = tags[seq_ends, mnp.arange(batch_size)]score += end_trans[last_tags]return score

Normalizer计算

接下来,我们使用动态规划算法计算Normalizer:

def compute_normalizer(emissions, mask, trans, start_trans, end_trans):seq_length = emissions.shape[0]score = start_trans + emissions[0]for i in range(1, seq_length):broadcast_score = score.expand_dims(2)broadcast_emissions = emissions[i].expand_dims(1)next_score = broadcast_score + trans + broadcast_emissionsnext_score = ops.logsumexp(next_score, axis=1)score = mnp.where(mask[i].expand_dims(1), next_score, score)score += end_transreturn ops.logsumexp(score, axis=1)

Viterbi算法

为什么使用Viterbi算法

在解码阶段,我们需要找到使得序列得分最高的标签序列。穷举所有可能的标签序列并计算其得分是不可行的,因为可能的标签序列数量是指数级的。Viterbi算法是一种动态规划算法,能够高效地找到最优标签序列。它通过逐步计算每个Token对应的最优标签,并保存中间结果,避免了重复计算,从而大大提高了解码的效率。

在解码阶段,我们使用Viterbi算法求解最优标签序列:

def viterbi_decode(emissions, mask, trans, start_trans, end_trans):seq_length = mask.shape[0]score = start_trans + emissions[0]history = ()for i in range(1, seq_length):broadcast_score = score.expand_dims(2)broadcast_emission = emissions[i].expand_dims(1)next_score = broadcast_score + trans + broadcast_emissionindices = next_score.argmax(axis=1)history += (indices,)next_score = next_score.max(axis=1)score = mnp.where(mask[i].expand_dims(1), next_score, score)score += end_transreturn score, historydef post_decode(score, history, seq_length):batch_size = seq_length.shape[0]seq_ends = seq_length - 1best_tags_list = []for idx in range(batch_size):best_last_tag = score[idx].argmax(axis=0)best_tags = [int(best_last_tag.asnumpy())]for hist in reversed(history[:seq_ends[idx]]):best_last_tag = hist[idx][best_tags[-1]]best_tags.append(int(best_last_tag.asnumpy()))best_tags.reverse()best_tags_list.append(best_tags)return best_tags_list

CRF层的封装

我们将上述代码封装成一个CRF层:

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniformdef sequence_mask(seq_length, max_length, batch_first=False):range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)result = range_vector < seq_length.view(seq_length.shape + (1,))if batch_first:return result.astype(ms.int64)return result.astype(ms.int64).swapaxes(0, 1)class CRF(nn.Cell):def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:if num_tags <= 0:raise ValueError(f'invalid number of tags: {num_tags}')super().__init__()if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f'invalid reduction: {reduction}')self.num_tags = num_tagsself.batch_first = batch_firstself.reduction = reductionself.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')def construct(self, emissions, tags=None, seq_length=None):if tags is None:return self._decode(emissions, seq_length)return self._forward(emissions, tags, seq_length)def _forward(self, emissions, tags=None, seq_length=None):if self.batch_first:batch_size, max_length = tags.shapeemissions = emissions.swapaxes(0, 1)tags = tags.swapaxes(0, 1)else:max_length, batch_size = tags.shapeif seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)mask = sequence_mask(seq_length, max_length)numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)llh = denominator - numeratorif self.reduction == 'none':return llhif self.reduction == 'sum':return llh.sum()if self.reduction == 'mean':return llh.mean()return llh.sum() / mask.astype(emissions.dtype).sum()def _decode(self, emissions, seq_length=None):if self.batch_first:batch_size, max_length = emissions.shape[:2]emissions = emissions.swapaxes(0, 1)else:batch_size, max_length = emissions.shape[:2]if seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)mask = sequence_mask(seq_length, max_length)return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)

BiLSTM+CRF模型

为什么使用双向LSTM

双向LSTM能够同时捕获序列中前后两个方向的依赖关系。在序列标注任务中,当前Token的标签不仅依赖于前面的Token,还可能依赖于后面的Token。通过使用双向LSTM,我们可以更全面地提取序列特征,从而提高模型的表现。

在实现了CRF层之后,我们设计一个双向LSTM+CRF的模型来进行命名实体识别任务的训练。模型结构如下:

nn.Embedding -> nn.LSTM -> nn.Dense -> CRF

具体实现如下:

class BiLSTM_CRF(nn.Cell):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):super().__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')self.crf = CRF(num_tags, batch_first=True)def construct(self, inputs, seq_length, tags=None):embeds = self.embedding(inputs)outputs, _ = self.lstm(embeds, seq_length=seq_length)feats = self.hidden2tag(outputs)crf_outs = self.crf(feats, tags, seq_length)return crf_outs

数据准备

我们生成两句例子和对应的标签,并构造词表和标签表:

embedding_dim = 16
hidden_dim = 32training_data = [("清 华 大 学 坐 落 于 首 都 北 京".split(),"B I I I O O O O O B I".split()
), ("重 庆 是 一 个 魔 幻 城 市".split(),"B I O O O O O O O".split()
)]word_to_idx = {}
word_to_idx['<pad>'] = 0
for sentence, tags in training_data:for word in sentence:if word not in word_to_idx:word_to_idx[word] = len(word_to_idx)tag_to_idx = {"B": 0, "I": 1, "O": 2}
len(word_to_idx)

模型训练

实例化模型,选择优化器并将模型和优化器送入Wrapper:

model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx))
optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01, weight_decay=1e-4)
grad_fn = ms.value_and_grad(model, None, optimizer.parameters)def train_step(data, seq_length, label):loss, grads = grad_fn(data, seq_length, label)optimizer(grads)return loss

将生成的数据打包成Batch,并进行填充:

def prepare_sequence(seqs, word_to_idx, tag_to_idx):seq_outputs, label_outputs, seq_length = [], [], []max_len = max([len(i[0]) for i in seqs])for seq, tag in seqs:seq_length.append(len(seq))idxs = [word_to_idx[w] for w in seq]labels = [tag_to_idx[t] for t in tag]idxs.extend([word_to_idx['<pad>'] for i in range(max_len - len(seq))])labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))])seq_outputs.append(idxs)label_outputs.append(labels)return ms.Tensor(seq_outputs, ms.int64), \ms.Tensor(label_outputs, ms.int64), \ms.Tensor(seq_length, ms.int64)
data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx)
data.shape, label.shape, seq_length.shape

预编译模型并训练500个step:

from tqdm import tqdmsteps = 500
with tqdm(total=steps) as t:for i in range(steps):loss = train_step(data, seq_length, label)t.set_postfix(loss=loss)t.update(1)

模型评估

训练完成后,我们使用模型进行预测:

score, history = model(data, seq_length)
score

使用后处理函数进行预测得分的处理:

predict = post_decode(score, history, seq_length)
predict

将预测的index序列转换为标签序列并打印输出结果:

idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()}def sequence_to_tag(sequences, idx_to_tag):outputs = []for seq in sequences:outputs.append([idx_to_tag[i] for i in seq])return outputs
sequence_to_tag(predict, idx_to_tag)

通过上述步骤,我们成功实现了一个基于LSTM和CRF的序列标注模型,并在命名实体识别任务中进行了应用。希望这篇博客能帮助你更好地理解LSTM和CRF在序列标注中的应用。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/43172.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无人机之穿越机知识篇

一、定义和类型 穿越机&#xff0c;即FPV Drone或Racing Drone&#xff0c;是一种主要通过第一人称视角&#xff08;FPV&#xff09;进行操作的无人机。这种无人机通常配备有四个电机和相应的飞控系统&#xff0c;使其具有极高的飞行自由度和速度。穿越机主要分为竞速型和花飞…

electron在VSCode和IDEA及webStrom等编辑器控制台打印日志乱码

window10环境下设置 1.打开Windows设置 2.打开时间和语言&#xff0c;选择语言菜单、如何点击管理语言设置 3.打开之后选择管理&#xff0c;选择更改系统区域设置&#xff0c;把Beta版&#xff1a;使用Unicode UTF-8提供全球语言支持 勾上&#xff0c;点击确定&#xff0c;…

甄选范文“论区块链技术及应用”,软考高级论文,系统架构设计师论文

论文真题 区块链作为一种分布式记账技术,目前已经被应用到了资产管理、物联网、医疗管理、政务监管等多个领域。从网络层面来讲,区块链是一个对等网络(Peer to Peer, P2P),网络中的节点地位对等,每个节点都保存完整的账本数据,系统的运行不依赖中心化节点,因此避免了中…

MySQL 9.0 新功能概览

官方文档 https://dev.mysql.com/doc/refman/9.0/en/mysql-nutshell.html 时隔 6 年多&#xff0c;上周 Oracle 发布了 MySQL 最新的大版本 9.0。我们一起来看看新版本有哪些东西。 用 JavaScript 写存储过程 半年前已经单独介绍过 「虽迟但到&#xff01;MySQL 可以用 Java…

方圆资源网,方圆资源官网

在当今这个信息化高速发展的时代&#xff0c;方圆资源网络已成为推动社会进步、促进经济发展的重要力量。方圆资源网不仅汇聚了海量的信息资源&#xff0c;更为我们提供了一个高效、便捷的信息交流平台。本文旨在详细介绍资源网的概念、特点、功能以及其在现代社会中的重要意义…

DeepMind的JEST技术:AI训练速度提升13倍,能效增强10倍,引领绿色AI革命

谷歌旗下的人工智能研究实验室DeepMind发布了一项关于人工智能模型训练的新研究成果&#xff0c;声称其新提出的“联合示例选择”&#xff08;Joint Example Selection&#xff0c;简称JEST&#xff09;技术能够极大地提高训练速度和能源效率&#xff0c;相比其他方法&#xff…

华为乾崑智驾加持:深蓝S07首次亮相

最近&#xff0c;特斯拉FSD即将入华的消息&#xff0c;让智能驾驶成为了汽车行业热议的焦点&#xff0c;而当新能源汽车的代表企业深蓝汽车&#xff0c;与全球领先的华为乾崑智驾强强联手&#xff0c;一场颠覆性的智能出行变革也已蓄势待发。 7月8日&#xff0c;深蓝汽车携其最…

[数仓]四、离线数仓(Hive数仓系统-续)

第8章 数仓搭建-DWT层 8.1 访客主题 1)建表语句 DROP TABLE IF EXISTS dwt_visitor_topic; CREATE EXTERNAL TABLE dwt_visitor_topic (`mid_id` STRING COMMENT 设备id,`brand` STRING COMMENT 手机品牌,`model` STRING COMMENT 手机型号,`channel` ARRAY<STRING> C…

MinIO - 服务端签名直传(前端 + 后端 + 效果演示)

目录 开始 服务端签名直传概述 代码实现 后端实现 前端实现 效果演示 开始 服务端签名直传概述 传统的&#xff0c;我们有两种方式将图片上传到 OSS&#xff1a; a&#xff09;前端请求 -> 后端服务器 -> OSS 好处&#xff1a;在服务端上传&#xff0c;更加安全…

Android - 云游戏本地悬浮输入框实现

一、简述 云游戏输入法分两种情况,以云化原神为例,分为 云端输入法 和 本地输入法,运行效果如下: 云端输入法本地输入法云端输入法 就是运行在云端设备上的输入法,对于不同客户端来说(Android、iPhone),运行效果一致。 本地输入法 则是运行在用户侧设备上的输入法,对…

Fastapi在docekr中进行部署之后,uvicorn占用的CPU非常高

前一段接点小活&#xff0c;做点开发&#xff0c;顺便学了学FASTAPI框架&#xff0c;对比flask据说能好那么一些&#xff0c;至少并发什么的不用研究其他的asgi什么的&#xff0c;毕竟不是专业开发&#xff0c;能少研究一个东西就省了很多的事。 但是部署的过程中突然之间在do…

Zabbix分布式监控

目录 分布式监控架构 实现分布式监控的步骤 优点和应用场景 安装Zabbix_Proxy Server端Web页面配置 测试 Zabbix 的分布式监控架构允许在大规模和地理上分散的环境中进行高效的监控。通过分布式监控&#xff0c;Zabbix 可以扩展其监控能力&#xff0c;支持大量主机和设备…

基于全志Orangepi Zero 2 的语音刷抖音项目

目录 一、硬件连接 二、项目实现功能 三、SU-03T语音模块的配置和烧录 3.1 创建产品&#xff1a; 3.2 设置PIN引脚为串口模式&#xff1a; 3.3 设置唤醒词&#xff1a; 3.4 设置指令语句&#xff1a; 3.5 设置控制详情&#xff1a; 3.6 发音人配置&#xff1a; 3.7 其…

atcoder 357 F Two Sequence Queries (线段树板子)

题目&#xff1a; 分析&#xff1a; 线段树 代码&#xff1a; // Problem: F - Two Sequence Queries // Contest: AtCoder - SuntoryProgrammingContest2024&#xff08;AtCoder Beginner Contest 357&#xff09; // URL: https://atcoder.jp/contests/abc357/tasks/abc357_…

华为HCIP Datacom H12-821 卷29

1.多选题 下面关于LSA age字段&#xff0c;描述正确的是∶ A、LSA age的单位为秒&#xff0c;在LSDB中的LSA的LS age随时间增长而增长 B、LSA age的单位为秒&#xff0c;在LSDB中的LSA的LS age随时间增长而减少 C、如果一条LSA的LS age达到了LS RefreshTime&#xff08…

[spring] Spring MVC - security(下)

[spring] Spring MVC - security&#xff08;下&#xff09; callback 一下&#xff0c;当前项目结构如下&#xff1a; 这里实现的功能是连接数据库&#xff0c;大范围和 [spring] rest api security 重合 数据库连接 - 明文密码 第一部分使用明文密码 设置数据库 主要就是…

内存与硬盘(笔记)

文章目录 1. 内存① 笔记② 相关软件 2. 硬盘① 笔记② 相关软件 3. 区别&#xff1a;4. 推荐 1. 内存 ① 笔记 ① 笔记本的内存条和台式机的内存条是不互通的 ② 我们常说的内存其实指的是运行内存(前台后台同时能运行多少APP) ③ 下图来自京东&#xff1a; 解析&#xff1…

ubuntu 分区情况

ubuntu系统安装与分区指南 - Philbert - 博客园 (cnblogs.com)https://www.cnblogs.com/liangxuran/p/14872811.html 详解安装Ubuntu Linux系统时硬盘分区最合理的方法-腾讯云开发者社区-腾讯云 (tencent.com)https://cloud.tencent.com/developer/article/1711884

终于找到了免费的C盘清理软件(极智C盘清理)

搜了很久&#xff0c;终于让我找到了一款 完全免费的C盘清理软件&#xff08;极智C盘清理&#xff09;。 点击前往官网免费使用极智C盘清理软件&#xff1a; C盘清理 用户好评 完全免费的极智C盘清理 用极智C盘清理清理了下系统的临时文件、缓存等无用数据文件&#xff0c;C盘终…

设计资料:520-基于ZU15EG 适配AWR2243的雷达验证底板 高速信号处理板 AWR2243毫米波板

基于ZU15EG 适配AWR2243的雷达验证底板 一、板卡概述 本板卡系北京太速科技自主研发&#xff0c;基于MPSOC系列SOC XCZU15EG-FFVB1156架构&#xff0c;搭载两组64-bit DDR4&#xff0c;每组容量32Gb&#xff0c;最高可稳定运行在2400MT/s。另有1路10G SFP光纤接口、1路40G…