昇思25天学习打卡营第24天 | LSTM+CRF序列标注

内容介绍:

序列标注指给定输入序列,给序列中每个Token进行标注标签的过程。序列标注问题通常用于从文本中进行信息抽取,包括分词(Word Segmentation)、词性标注(Position Tagging)、命名实体识别(Named Entity Recognition, NER)等。以命名实体识别为例:

输入序列
输出标注BIIIOOOOOBI

如上表所示,清华大学 和 北京是地名,需要将其识别,我们对每个输入的单词预测其标签,最后根据标签来识别实体。

这里使用了一种常见的命名实体识别的标注方法——“BIOE”标注,将一个实体(Entity)的开头标注为B,其他部分标注为I,非实体标注为O。

条件随机场(Conditional Random Field, CRF)

从上文的举例可以看到,对序列进行标注,实际上是对序列中每个Token进行标签预测,可以直接视作简单的多分类问题。但是序列标注不仅仅需要对单个Token进行分类预测,同时相邻Token直接有关联关系。以清华大学一词为例:

输入序列
输出标注BIII
输出标注OIII×

如上表所示,正确的实体中包含的4个Token有依赖关系,I前必须是B或I,而错误输出结果将字标注为O,违背了这一依赖。将命名实体识别视为多分类问题,则每个词的预测概率都是独立的,易产生类似的问题,因此需要引入一种能够学习到此种关联关系的算法来保证预测结果的正确性。而条件随机场是适合此类场景的一种概率图模型概率图模型概率图模型。

具体内容:

1. 导包

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform
from tqdm import tqdm

2. score计算

def compute_score(emissions, tags, seq_ends, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags)# tags: (seq_length, batch_size)# mask: (seq_length, batch_size)seq_length, batch_size = tags.shapemask = mask.astype(emissions.dtype)# 将score设置为初始转移概率# shape: (batch_size,)score = start_trans[tags[0]]# score += 第一次发射概率# shape: (batch_size,)score += emissions[0, mnp.arange(batch_size), tags[0]]for i in range(1, seq_length):# 标签由i-1转移至i的转移概率(当mask == 1时有效)# shape: (batch_size,)score += trans[tags[i - 1], tags[i]] * mask[i]# 预测tags[i]的发射概率(当mask == 1时有效)# shape: (batch_size,)score += emissions[i, mnp.arange(batch_size), tags[i]] * mask[i]# 结束转移# shape: (batch_size,)last_tags = tags[seq_ends, mnp.arange(batch_size)]# score += 结束转移概率# shape: (batch_size,)score += end_trans[last_tags]return score

3. Normalizer计算

def compute_normalizer(emissions, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)seq_length = emissions.shape[0]# 将score设置为初始转移概率,并加上第一次发射概率# shape: (batch_size, num_tags)score = start_trans + emissions[0]for i in range(1, seq_length):# 扩展score的维度用于总score的计算# shape: (batch_size, num_tags, 1)broadcast_score = score.expand_dims(2)# 扩展emission的维度用于总score的计算# shape: (batch_size, 1, num_tags)broadcast_emissions = emissions[i].expand_dims(1)# 根据公式(7),计算score_i# 此时broadcast_score是由第0个到当前Token所有可能路径# 对应score的log_sum_exp# shape: (batch_size, num_tags, num_tags)next_score = broadcast_score + trans + broadcast_emissions# 对score_i做log_sum_exp运算,用于下一个Token的score计算# shape: (batch_size, num_tags)next_score = ops.logsumexp(next_score, axis=1)# 当mask == 1时,score才会变化# shape: (batch_size, num_tags)score = mnp.where(mask[i].expand_dims(1), next_score, score)# 最后加结束转移概率# shape: (batch_size, num_tags)score += end_trans# 对所有可能的路径得分求log_sum_exp# shape: (batch_size,)return ops.logsumexp(score, axis=1)

4. Viterbi算法

在完成前向训练部分后,需要实现解码部分。这里我们选择适合求解序列最优路径的Viterbi算法。与计算Normalizer类似,使用动态规划求解所有可能的预测序列得分。不同的是在解码时同时需要将第𝑖个Token对应的score取值最大的标签保存,供后续使用Viterbi算法求解最优预测序列使用。

def viterbi_decode(emissions, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)seq_length = mask.shape[0]score = start_trans + emissions[0]history = ()for i in range(1, seq_length):broadcast_score = score.expand_dims(2)broadcast_emission = emissions[i].expand_dims(1)next_score = broadcast_score + trans + broadcast_emission# 求当前Token对应score取值最大的标签,并保存indices = next_score.argmax(axis=1)history += (indices,)next_score = next_score.max(axis=1)score = mnp.where(mask[i].expand_dims(1), next_score, score)score += end_transreturn score, historydef post_decode(score, history, seq_length):# 使用Score和History计算最佳预测序列batch_size = seq_length.shape[0]seq_ends = seq_length - 1# shape: (batch_size,)best_tags_list = []# 依次对一个Batch中每个样例进行解码for idx in range(batch_size):# 查找使最后一个Token对应的预测概率最大的标签,# 并将其添加至最佳预测序列存储的列表中best_last_tag = score[idx].argmax(axis=0)best_tags = [int(best_last_tag.asnumpy())]# 重复查找每个Token对应的预测概率最大的标签,加入列表for hist in reversed(history[:seq_ends[idx]]):best_last_tag = hist[idx][best_tags[-1]]best_tags.append(int(best_last_tag.asnumpy()))# 将逆序求解的序列标签重置为正序best_tags.reverse()best_tags_list.append(best_tags)return best_tags_list

5. CRF层

完成上述前向训练和解码部分的代码后,将其组装完整的CRF层。考虑到输入序列可能存在Padding的情况,CRF的输入需要考虑输入序列的真实长度,因此除发射矩阵和标签外,加入seq_length参数传入序列Padding前的长度,并实现生成mask矩阵的sequence_mask方法。

综合上述代码,使用nn.Cell进行封装,最后实现完整的CRF层如下:

def sequence_mask(seq_length, max_length, batch_first=False):"""根据序列实际长度和最大长度生成mask矩阵"""range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)result = range_vector < seq_length.view(seq_length.shape + (1,))if batch_first:return result.astype(ms.int64)return result.astype(ms.int64).swapaxes(0, 1)class CRF(nn.Cell):def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:if num_tags <= 0:raise ValueError(f'invalid number of tags: {num_tags}')super().__init__()if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f'invalid reduction: {reduction}')self.num_tags = num_tagsself.batch_first = batch_firstself.reduction = reductionself.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')def construct(self, emissions, tags=None, seq_length=None):if tags is None:return self._decode(emissions, seq_length)return self._forward(emissions, tags, seq_length)def _forward(self, emissions, tags=None, seq_length=None):if self.batch_first:batch_size, max_length = tags.shapeemissions = emissions.swapaxes(0, 1)tags = tags.swapaxes(0, 1)else:max_length, batch_size = tags.shapeif seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)mask = sequence_mask(seq_length, max_length)# shape: (batch_size,)numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)# shape: (batch_size,)denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)# shape: (batch_size,)llh = denominator - numeratorif self.reduction == 'none':return llhif self.reduction == 'sum':return llh.sum()if self.reduction == 'mean':return llh.mean()return llh.sum() / mask.astype(emissions.dtype).sum()def _decode(self, emissions, seq_length=None):if self.batch_first:batch_size, max_length = emissions.shape[:2]emissions = emissions.swapaxes(0, 1)else:batch_size, max_length = emissions.shape[:2]if seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)mask = sequence_mask(seq_length, max_length)return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)

6. BiLSTM+CRF模型

class BiLSTM_CRF(nn.Cell):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):super().__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')self.crf = CRF(num_tags, batch_first=True)def construct(self, inputs, seq_length, tags=None):embeds = self.embedding(inputs)outputs, _ = self.lstm(embeds, seq_length=seq_length)feats = self.hidden2tag(outputs)crf_outs = self.crf(feats, tags, seq_length)return crf_outs

7. 词表

embedding_dim = 16
hidden_dim = 32training_data = [("清 华 大 学 坐 落 于 首 都 北 京".split(),"B I I I O O O O O B I".split()
), ("重 庆 是 一 个 魔 幻 城 市".split(),"B I O O O O O O O".split()
)]word_to_idx = {}
word_to_idx['<pad>'] = 0
for sentence, tags in training_data:for word in sentence:if word not in word_to_idx:word_to_idx[word] = len(word_to_idx)tag_to_idx = {"B": 0, "I": 1, "O": 2}

8. 初始化模型

model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx))
optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01, weight_decay=1e-4)

9. 每步计算

grad_fn = ms.value_and_grad(model, None, optimizer.parameters)def train_step(data, seq_length, label):loss, grads = grad_fn(data, seq_length, label)optimizer(grads)return loss

10. 打包Batch

def prepare_sequence(seqs, word_to_idx, tag_to_idx):seq_outputs, label_outputs, seq_length = [], [], []max_len = max([len(i[0]) for i in seqs])for seq, tag in seqs:seq_length.append(len(seq))idxs = [word_to_idx[w] for w in seq]labels = [tag_to_idx[t] for t in tag]idxs.extend([word_to_idx['<pad>'] for i in range(max_len - len(seq))])labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))])seq_outputs.append(idxs)label_outputs.append(labels)return ms.Tensor(seq_outputs, ms.int64), \ms.Tensor(label_outputs, ms.int64), \ms.Tensor(seq_length, ms.int64)

11. 训练

steps = 500
with tqdm(total=steps) as t:for i in range(steps):loss = train_step(data, seq_length, label)t.set_postfix(loss=loss)t.update(1)

在深入学习了LSTM(长短期记忆网络)结合CRF(条件随机场)这一强大的序列标注模型之后,我深感这一组合在解决自然语言处理中的序列标注任务时展现出了非凡的魅力和实用性。这段学习旅程不仅拓宽了我的技术视野,也让我对自然语言处理领域的复杂性和精妙性有了更深一层的理解。

LSTM作为循环神经网络(RNN)的一种变体,通过引入“门”机制(遗忘门、输入门、输出门)有效解决了传统RNN在长序列处理中容易出现的梯度消失或梯度爆炸问题。它能够捕捉数据中的长期依赖关系,这对于理解自然语言这种高度上下文依赖的序列数据至关重要。而CRF作为一种统计建模方法,在给定输入序列的条件下,能够计算整个输出序列的联合概率分布,特别适合于序列标注这类需要全局最优解的任务。将LSTM与CRF结合,使得模型既能够捕捉到序列中的长期依赖信息,又能在全局范围内优化标注序列,从而显著提升了标注的准确性和鲁棒性。

在将LSTM+CRF模型应用于如命名实体识别(NER)、词性标注(POS Tagging)等具体任务时,我深刻体会到了理论与实践相结合的重要性。通过调整模型参数、优化网络结构、引入预训练词向量等技术手段,可以显著提升模型的性能。同时,面对不同领域、不同规模的数据集,模型的泛化能力和适应性也成为了考验模型优劣的关键指标。这让我意识到,在实际应用中,需要根据具体任务的特点和数据情况,灵活调整模型策略,以达到最佳效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/45401.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JavaScript 算法】二分查找:快速定位目标元素

&#x1f525; 个人主页&#xff1a;空白诗 文章目录 一、算法原理二、算法实现三、应用场景四、优化与扩展五、总结 二分查找&#xff08;Binary Search&#xff09;是一种高效的查找算法&#xff0c;适用于在有序数组中快速定位目标元素。相比于线性查找&#xff0c;二分查找…

护(H)网(W)行动正当时:你对HW知多少,一文带你全面了解护网行动

引言&#xff1a;2016年我国发布了《网络安全法》&#xff08;于2017年6月1日正式生效&#xff09;&#xff0c;明确规定了关键信息基础设施的运营者必须制定网络安全事件应急预案&#xff0c;并定期进行演练&#xff0c;为HW行动的开展提供了法律依据&#xff0c;通过红蓝对抗…

嵌入式裸机开发与 Linux 开发

引言 嵌入式系统在现代电子设备中占有重要地位&#xff0c;其开发模式主要分为裸机开发和基于操作系统&#xff08;如 Linux&#xff09;的开发。本文将详细介绍嵌入式裸机开发和 Linux 开发的特点、优缺点&#xff0c;并进行对比分析&#xff0c;以帮助读者更好地理解和选择合…

js 移动数组元素的几个方法

位置交换 /*** param {any[]} arr - 原始数组。* param {number} fromIndex - 当前元素所在位置索引。* param {number} toIndex - 移动到交换的位置索引。* returns {any[]} 返回修改后的数组。*/ const swapItem function(arr, fromIndex, toIndex) {arr[toIndex] arr.spl…

35、php 实现构建乘积数组、正则表达式匹配

题目&#xff1a; uniapp-v3是基于vue3语法的&#xff0c;在hbuilderx中运行即可 Project setup npm install Compiles and hot-reloads for development npm run serve Compiles and minifies for production npm run build 在HBuilderX中导入src文件打包;打包H5手机版可以…

Unity 中使用状态机模式来管理UI

1. 清晰的状态管理 状态机模式允许你以结构化的方式管理不同的UI状态。每个状态&#xff08;比如主菜单、设置菜单、游戏中界面等&#xff09;都有其独立的行为和属性&#xff0c;这使得管理复杂UI逻辑变得更加清晰和可维护。 2. 简化的状态切换 状态机模式可以简化不同UI状…

报表控件DevExpress Reporting中文教程 - 如何创建穿透钻取报表?

DevExpress Reporting是.NET Framework下功能完善的报表平台&#xff0c;它附带了易于使用的Visual Studio报表设计器和丰富的报表控件集&#xff0c;包括数据透视表、图表&#xff0c;因此您可以构建无与伦比、信息清晰的报表。 钻取报表允许用户通过单击主/活动报表文档中的…

Android的dtbo文件介绍

文章目录 设备树&#xff08;Device Tree&#xff09;设备树覆盖&#xff08;Device Tree Overlay, DTO&#xff09;dtbo文件的作用使用流程示例 dtbo 文件是 Android 设备中的设备树覆盖文件&#xff08;Device Tree Blob Overlay&#xff09;。它用于动态地修改设备树配置&am…

智能酒精壁炉与会所会客厅的氛围搭配

智能酒精壁炉与会所会客厅的氛围搭配可以创造出现代、高雅且舒适的环境&#xff0c;提升客人的整体体验。以下是如何将智能酒精壁炉与会所会客厅氛围相协调的几点建议&#xff1a; 现代化与高品位感&#xff1a; 智能酒精壁炉展现出现代化的设计和高科技特点&#xff0c;与会所…

应急响应-战后溯源反制社会工程学

&#x1f3bc;个人主页&#xff1a;金灰 &#x1f60e;作者简介:一名简单的大一学生;易编橙终身成长社群的嘉宾.✨ 专注网络空间安全服务,期待与您的交流分享~ 感谢您的点赞、关注、评论、收藏、是对我最大的认可和支持&#xff01;❤️ &#x1f34a;易编橙终身成长社群&#…

开源的混合AI搜索引擎;定制 Claude 3 Haiku 模型; 和gpt-4o同样Transformer架构的开源视觉语言模型;离线自动转录工具

✨ 1: MemFree MemFree是一款开源的混合AI搜索引擎&#xff0c;可搜索个人知识库和互联网。 MemFree 是一个开源的混合AI搜索引擎&#xff0c;可以同时在你的个人知识库&#xff08;如书签、笔记、文档等&#xff09;和互联网中进行搜索。这款搜索引擎的主要特点包括&#xf…

嵌入式智能手表项目实现分享

简介 这是一个基于STM32F411CUE6和FreeRTOS和LVGL的低成本的超多功能的STM32智能手表~ 推荐 如果觉得这个手表的硬件难做,又想学习相关的东西,可以试下这个新出的开发板,功能和例程demo更多!FriPi炸鸡派STM32F411开发板: 【STM32开发板】 FryPi炸鸡派 - 嘉立创EDA开源硬件平…

使用mediapip 检测pose 并作为一个服务

代码 import uvicorn from fastapi import FastAPI, HTTPException import cv2 import mediapipe as mp from pydantic import BaseModelapp FastAPI()# 创建一个模型来序列化姿态数据 class PoseData(BaseModel):landmarks: list# 初始化MediaPipe的姿态估计模型 mp_pose m…

GD32MCU最小系统构成条件

大家是否有这个疑惑&#xff1a;大学课程学习51的时候&#xff0c;老师告诉我们51的最小系统构成&#xff1f;那么进入32位单片机时代&#xff0c;gd32最小系统构成又是怎么样的呢&#xff1f; 1.供电电路 需要确保供电的电压电流稳定&#xff0c;以东方红开发版为例&#xff…

Qt WARNING: Failure to find: xxxxxx.h

重新规划了自定义文件夹后&#xff0c;编译出现错误&#xff0c;如 Qmake WARNING: Failure to find: xxxxxx.h 或者 error: XXXX.h: No such file or directory 如果文件是在windows下直接重新放置新的目录&#xff0c;那么需要修改.pro文件 老文件的可能没有注释或删除&am…

ABAQUS广东正版代理商:亿达四方——达索官方授权

在粤港澳大湾区建设的浪潮中&#xff0c;广东作为中国改革开放的前沿阵地&#xff0c;始终走在科技创新的最前线。亿达四方&#xff0c;作为国际领先的仿真软件ABAQUS在广东地区的官方授权代理商&#xff0c;正以先进的技术和服务&#xff0c;推动着广东地区制造业向智能化、高…

【Tomcat目录详解】关于Tomcat你还需要了解的详细内容

希望文章能给到你启发和灵感&#xff5e; 如果觉得文章对你有帮助的话&#xff0c;点赞 关注 收藏 支持一下博主吧&#xff5e; 阅读指南 开篇说明一、基础环境说明1.1 硬件环境1.2 软件环境 二、Tomcat的文件结构2.1 bin目录2.1.1 startup和shutdown2.1.2 Catalina2.1.3 serv…

深入解析EtherCAT `CheckProductCode` 属性:确保系统一致性与安全性

在工业自动化领域&#xff0c;EtherCAT&#xff08;Ethernet for Control Automation Technology&#xff09;已成为一种广泛应用的实时以太网协议。它的高性能、灵活性和可靠性使其成为复杂自动化任务的理想选择。然而&#xff0c;确保系统的正确配置和安全运行是使用EtherCAT…

【43页PPT】企业数据架构数据治理设计规划咨询项目建议

本项目聚焦于企业数据资产的深度挖掘与价值最大化&#xff0c;旨在通过一系列定制化策略与架构设计&#xff0c;重塑企业的数据生态体系。我们的核心任务包括&#xff1a; 企业现状深度剖析&#xff1a;全面审视企业当前的数据环境、业务流程及战略方向&#xff0c;精准把握数…

Opencv中的直方图

cv2.calcHist() 直方图是图像中像素强度分布的图形表达方式&#xff0c;统计了每一个强度值所具有的像素个数。并可以计算图像中的一个或多个通道的直方图。 dst cv2.calcHist(images, channels, mask, histSize, ranges[, hist[, accumulate]])images&#xff1a;源图像&am…