昇思25天学习打卡营第24天 | LSTM+CRF序列标注

内容介绍:

序列标注指给定输入序列,给序列中每个Token进行标注标签的过程。序列标注问题通常用于从文本中进行信息抽取,包括分词(Word Segmentation)、词性标注(Position Tagging)、命名实体识别(Named Entity Recognition, NER)等。以命名实体识别为例:

输入序列
输出标注BIIIOOOOOBI

如上表所示,清华大学 和 北京是地名,需要将其识别,我们对每个输入的单词预测其标签,最后根据标签来识别实体。

这里使用了一种常见的命名实体识别的标注方法——“BIOE”标注,将一个实体(Entity)的开头标注为B,其他部分标注为I,非实体标注为O。

条件随机场(Conditional Random Field, CRF)

从上文的举例可以看到,对序列进行标注,实际上是对序列中每个Token进行标签预测,可以直接视作简单的多分类问题。但是序列标注不仅仅需要对单个Token进行分类预测,同时相邻Token直接有关联关系。以清华大学一词为例:

输入序列
输出标注BIII
输出标注OIII×

如上表所示,正确的实体中包含的4个Token有依赖关系,I前必须是B或I,而错误输出结果将字标注为O,违背了这一依赖。将命名实体识别视为多分类问题,则每个词的预测概率都是独立的,易产生类似的问题,因此需要引入一种能够学习到此种关联关系的算法来保证预测结果的正确性。而条件随机场是适合此类场景的一种概率图模型概率图模型概率图模型。

具体内容:

1. 导包

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform
from tqdm import tqdm

2. score计算

def compute_score(emissions, tags, seq_ends, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags)# tags: (seq_length, batch_size)# mask: (seq_length, batch_size)seq_length, batch_size = tags.shapemask = mask.astype(emissions.dtype)# 将score设置为初始转移概率# shape: (batch_size,)score = start_trans[tags[0]]# score += 第一次发射概率# shape: (batch_size,)score += emissions[0, mnp.arange(batch_size), tags[0]]for i in range(1, seq_length):# 标签由i-1转移至i的转移概率(当mask == 1时有效)# shape: (batch_size,)score += trans[tags[i - 1], tags[i]] * mask[i]# 预测tags[i]的发射概率(当mask == 1时有效)# shape: (batch_size,)score += emissions[i, mnp.arange(batch_size), tags[i]] * mask[i]# 结束转移# shape: (batch_size,)last_tags = tags[seq_ends, mnp.arange(batch_size)]# score += 结束转移概率# shape: (batch_size,)score += end_trans[last_tags]return score

3. Normalizer计算

def compute_normalizer(emissions, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)seq_length = emissions.shape[0]# 将score设置为初始转移概率,并加上第一次发射概率# shape: (batch_size, num_tags)score = start_trans + emissions[0]for i in range(1, seq_length):# 扩展score的维度用于总score的计算# shape: (batch_size, num_tags, 1)broadcast_score = score.expand_dims(2)# 扩展emission的维度用于总score的计算# shape: (batch_size, 1, num_tags)broadcast_emissions = emissions[i].expand_dims(1)# 根据公式(7),计算score_i# 此时broadcast_score是由第0个到当前Token所有可能路径# 对应score的log_sum_exp# shape: (batch_size, num_tags, num_tags)next_score = broadcast_score + trans + broadcast_emissions# 对score_i做log_sum_exp运算,用于下一个Token的score计算# shape: (batch_size, num_tags)next_score = ops.logsumexp(next_score, axis=1)# 当mask == 1时,score才会变化# shape: (batch_size, num_tags)score = mnp.where(mask[i].expand_dims(1), next_score, score)# 最后加结束转移概率# shape: (batch_size, num_tags)score += end_trans# 对所有可能的路径得分求log_sum_exp# shape: (batch_size,)return ops.logsumexp(score, axis=1)

4. Viterbi算法

在完成前向训练部分后,需要实现解码部分。这里我们选择适合求解序列最优路径的Viterbi算法。与计算Normalizer类似,使用动态规划求解所有可能的预测序列得分。不同的是在解码时同时需要将第𝑖个Token对应的score取值最大的标签保存,供后续使用Viterbi算法求解最优预测序列使用。

def viterbi_decode(emissions, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)seq_length = mask.shape[0]score = start_trans + emissions[0]history = ()for i in range(1, seq_length):broadcast_score = score.expand_dims(2)broadcast_emission = emissions[i].expand_dims(1)next_score = broadcast_score + trans + broadcast_emission# 求当前Token对应score取值最大的标签,并保存indices = next_score.argmax(axis=1)history += (indices,)next_score = next_score.max(axis=1)score = mnp.where(mask[i].expand_dims(1), next_score, score)score += end_transreturn score, historydef post_decode(score, history, seq_length):# 使用Score和History计算最佳预测序列batch_size = seq_length.shape[0]seq_ends = seq_length - 1# shape: (batch_size,)best_tags_list = []# 依次对一个Batch中每个样例进行解码for idx in range(batch_size):# 查找使最后一个Token对应的预测概率最大的标签,# 并将其添加至最佳预测序列存储的列表中best_last_tag = score[idx].argmax(axis=0)best_tags = [int(best_last_tag.asnumpy())]# 重复查找每个Token对应的预测概率最大的标签,加入列表for hist in reversed(history[:seq_ends[idx]]):best_last_tag = hist[idx][best_tags[-1]]best_tags.append(int(best_last_tag.asnumpy()))# 将逆序求解的序列标签重置为正序best_tags.reverse()best_tags_list.append(best_tags)return best_tags_list

5. CRF层

完成上述前向训练和解码部分的代码后,将其组装完整的CRF层。考虑到输入序列可能存在Padding的情况,CRF的输入需要考虑输入序列的真实长度,因此除发射矩阵和标签外,加入seq_length参数传入序列Padding前的长度,并实现生成mask矩阵的sequence_mask方法。

综合上述代码,使用nn.Cell进行封装,最后实现完整的CRF层如下:

def sequence_mask(seq_length, max_length, batch_first=False):"""根据序列实际长度和最大长度生成mask矩阵"""range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)result = range_vector < seq_length.view(seq_length.shape + (1,))if batch_first:return result.astype(ms.int64)return result.astype(ms.int64).swapaxes(0, 1)class CRF(nn.Cell):def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:if num_tags <= 0:raise ValueError(f'invalid number of tags: {num_tags}')super().__init__()if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f'invalid reduction: {reduction}')self.num_tags = num_tagsself.batch_first = batch_firstself.reduction = reductionself.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')def construct(self, emissions, tags=None, seq_length=None):if tags is None:return self._decode(emissions, seq_length)return self._forward(emissions, tags, seq_length)def _forward(self, emissions, tags=None, seq_length=None):if self.batch_first:batch_size, max_length = tags.shapeemissions = emissions.swapaxes(0, 1)tags = tags.swapaxes(0, 1)else:max_length, batch_size = tags.shapeif seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)mask = sequence_mask(seq_length, max_length)# shape: (batch_size,)numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)# shape: (batch_size,)denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)# shape: (batch_size,)llh = denominator - numeratorif self.reduction == 'none':return llhif self.reduction == 'sum':return llh.sum()if self.reduction == 'mean':return llh.mean()return llh.sum() / mask.astype(emissions.dtype).sum()def _decode(self, emissions, seq_length=None):if self.batch_first:batch_size, max_length = emissions.shape[:2]emissions = emissions.swapaxes(0, 1)else:batch_size, max_length = emissions.shape[:2]if seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)mask = sequence_mask(seq_length, max_length)return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)

6. BiLSTM+CRF模型

class BiLSTM_CRF(nn.Cell):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):super().__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')self.crf = CRF(num_tags, batch_first=True)def construct(self, inputs, seq_length, tags=None):embeds = self.embedding(inputs)outputs, _ = self.lstm(embeds, seq_length=seq_length)feats = self.hidden2tag(outputs)crf_outs = self.crf(feats, tags, seq_length)return crf_outs

7. 词表

embedding_dim = 16
hidden_dim = 32training_data = [("清 华 大 学 坐 落 于 首 都 北 京".split(),"B I I I O O O O O B I".split()
), ("重 庆 是 一 个 魔 幻 城 市".split(),"B I O O O O O O O".split()
)]word_to_idx = {}
word_to_idx['<pad>'] = 0
for sentence, tags in training_data:for word in sentence:if word not in word_to_idx:word_to_idx[word] = len(word_to_idx)tag_to_idx = {"B": 0, "I": 1, "O": 2}

8. 初始化模型

model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx))
optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01, weight_decay=1e-4)

9. 每步计算

grad_fn = ms.value_and_grad(model, None, optimizer.parameters)def train_step(data, seq_length, label):loss, grads = grad_fn(data, seq_length, label)optimizer(grads)return loss

10. 打包Batch

def prepare_sequence(seqs, word_to_idx, tag_to_idx):seq_outputs, label_outputs, seq_length = [], [], []max_len = max([len(i[0]) for i in seqs])for seq, tag in seqs:seq_length.append(len(seq))idxs = [word_to_idx[w] for w in seq]labels = [tag_to_idx[t] for t in tag]idxs.extend([word_to_idx['<pad>'] for i in range(max_len - len(seq))])labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))])seq_outputs.append(idxs)label_outputs.append(labels)return ms.Tensor(seq_outputs, ms.int64), \ms.Tensor(label_outputs, ms.int64), \ms.Tensor(seq_length, ms.int64)

11. 训练

steps = 500
with tqdm(total=steps) as t:for i in range(steps):loss = train_step(data, seq_length, label)t.set_postfix(loss=loss)t.update(1)

在深入学习了LSTM(长短期记忆网络)结合CRF(条件随机场)这一强大的序列标注模型之后,我深感这一组合在解决自然语言处理中的序列标注任务时展现出了非凡的魅力和实用性。这段学习旅程不仅拓宽了我的技术视野,也让我对自然语言处理领域的复杂性和精妙性有了更深一层的理解。

LSTM作为循环神经网络(RNN)的一种变体,通过引入“门”机制(遗忘门、输入门、输出门)有效解决了传统RNN在长序列处理中容易出现的梯度消失或梯度爆炸问题。它能够捕捉数据中的长期依赖关系,这对于理解自然语言这种高度上下文依赖的序列数据至关重要。而CRF作为一种统计建模方法,在给定输入序列的条件下,能够计算整个输出序列的联合概率分布,特别适合于序列标注这类需要全局最优解的任务。将LSTM与CRF结合,使得模型既能够捕捉到序列中的长期依赖信息,又能在全局范围内优化标注序列,从而显著提升了标注的准确性和鲁棒性。

在将LSTM+CRF模型应用于如命名实体识别(NER)、词性标注(POS Tagging)等具体任务时,我深刻体会到了理论与实践相结合的重要性。通过调整模型参数、优化网络结构、引入预训练词向量等技术手段,可以显著提升模型的性能。同时,面对不同领域、不同规模的数据集,模型的泛化能力和适应性也成为了考验模型优劣的关键指标。这让我意识到,在实际应用中,需要根据具体任务的特点和数据情况,灵活调整模型策略,以达到最佳效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/45401.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JavaScript 算法】二分查找:快速定位目标元素

&#x1f525; 个人主页&#xff1a;空白诗 文章目录 一、算法原理二、算法实现三、应用场景四、优化与扩展五、总结 二分查找&#xff08;Binary Search&#xff09;是一种高效的查找算法&#xff0c;适用于在有序数组中快速定位目标元素。相比于线性查找&#xff0c;二分查找…

护(H)网(W)行动正当时:你对HW知多少,一文带你全面了解护网行动

引言&#xff1a;2016年我国发布了《网络安全法》&#xff08;于2017年6月1日正式生效&#xff09;&#xff0c;明确规定了关键信息基础设施的运营者必须制定网络安全事件应急预案&#xff0c;并定期进行演练&#xff0c;为HW行动的开展提供了法律依据&#xff0c;通过红蓝对抗…

Unity 中使用状态机模式来管理UI

1. 清晰的状态管理 状态机模式允许你以结构化的方式管理不同的UI状态。每个状态&#xff08;比如主菜单、设置菜单、游戏中界面等&#xff09;都有其独立的行为和属性&#xff0c;这使得管理复杂UI逻辑变得更加清晰和可维护。 2. 简化的状态切换 状态机模式可以简化不同UI状…

报表控件DevExpress Reporting中文教程 - 如何创建穿透钻取报表?

DevExpress Reporting是.NET Framework下功能完善的报表平台&#xff0c;它附带了易于使用的Visual Studio报表设计器和丰富的报表控件集&#xff0c;包括数据透视表、图表&#xff0c;因此您可以构建无与伦比、信息清晰的报表。 钻取报表允许用户通过单击主/活动报表文档中的…

Android的dtbo文件介绍

文章目录 设备树&#xff08;Device Tree&#xff09;设备树覆盖&#xff08;Device Tree Overlay, DTO&#xff09;dtbo文件的作用使用流程示例 dtbo 文件是 Android 设备中的设备树覆盖文件&#xff08;Device Tree Blob Overlay&#xff09;。它用于动态地修改设备树配置&am…

智能酒精壁炉与会所会客厅的氛围搭配

智能酒精壁炉与会所会客厅的氛围搭配可以创造出现代、高雅且舒适的环境&#xff0c;提升客人的整体体验。以下是如何将智能酒精壁炉与会所会客厅氛围相协调的几点建议&#xff1a; 现代化与高品位感&#xff1a; 智能酒精壁炉展现出现代化的设计和高科技特点&#xff0c;与会所…

应急响应-战后溯源反制社会工程学

&#x1f3bc;个人主页&#xff1a;金灰 &#x1f60e;作者简介:一名简单的大一学生;易编橙终身成长社群的嘉宾.✨ 专注网络空间安全服务,期待与您的交流分享~ 感谢您的点赞、关注、评论、收藏、是对我最大的认可和支持&#xff01;❤️ &#x1f34a;易编橙终身成长社群&#…

开源的混合AI搜索引擎;定制 Claude 3 Haiku 模型; 和gpt-4o同样Transformer架构的开源视觉语言模型;离线自动转录工具

✨ 1: MemFree MemFree是一款开源的混合AI搜索引擎&#xff0c;可搜索个人知识库和互联网。 MemFree 是一个开源的混合AI搜索引擎&#xff0c;可以同时在你的个人知识库&#xff08;如书签、笔记、文档等&#xff09;和互联网中进行搜索。这款搜索引擎的主要特点包括&#xf…

嵌入式智能手表项目实现分享

简介 这是一个基于STM32F411CUE6和FreeRTOS和LVGL的低成本的超多功能的STM32智能手表~ 推荐 如果觉得这个手表的硬件难做,又想学习相关的东西,可以试下这个新出的开发板,功能和例程demo更多!FriPi炸鸡派STM32F411开发板: 【STM32开发板】 FryPi炸鸡派 - 嘉立创EDA开源硬件平…

GD32MCU最小系统构成条件

大家是否有这个疑惑&#xff1a;大学课程学习51的时候&#xff0c;老师告诉我们51的最小系统构成&#xff1f;那么进入32位单片机时代&#xff0c;gd32最小系统构成又是怎么样的呢&#xff1f; 1.供电电路 需要确保供电的电压电流稳定&#xff0c;以东方红开发版为例&#xff…

ABAQUS广东正版代理商:亿达四方——达索官方授权

在粤港澳大湾区建设的浪潮中&#xff0c;广东作为中国改革开放的前沿阵地&#xff0c;始终走在科技创新的最前线。亿达四方&#xff0c;作为国际领先的仿真软件ABAQUS在广东地区的官方授权代理商&#xff0c;正以先进的技术和服务&#xff0c;推动着广东地区制造业向智能化、高…

【Tomcat目录详解】关于Tomcat你还需要了解的详细内容

希望文章能给到你启发和灵感&#xff5e; 如果觉得文章对你有帮助的话&#xff0c;点赞 关注 收藏 支持一下博主吧&#xff5e; 阅读指南 开篇说明一、基础环境说明1.1 硬件环境1.2 软件环境 二、Tomcat的文件结构2.1 bin目录2.1.1 startup和shutdown2.1.2 Catalina2.1.3 serv…

【43页PPT】企业数据架构数据治理设计规划咨询项目建议

本项目聚焦于企业数据资产的深度挖掘与价值最大化&#xff0c;旨在通过一系列定制化策略与架构设计&#xff0c;重塑企业的数据生态体系。我们的核心任务包括&#xff1a; 企业现状深度剖析&#xff1a;全面审视企业当前的数据环境、业务流程及战略方向&#xff0c;精准把握数…

Docker 基本管理及部署

目录 1.Docker概述 1.1 Docker是什么&#xff1f; 1.2 Docker的宗旨 1.3 容器的优点 1.4 Docker与虚拟机的区别 1.5 容器在内核中支持的两种技术 1.6 namespace的六大类型 2.Docker核心概念 2.1 镜像 2.2 容器 2.3 仓库 3.安装Docker 3.1 查看 docker 版本信息 4.…

FPGA上板项目(一)——点灯熟悉完整开发流程、ILA在线调试

目录 创建工程创建 HDL 代码仿真添加管脚约束添加时序约束生成 bit 文件下载ILA 在线调试 创建工程 型号选择&#xff1a;以 AXU9EG 开发板为例&#xff0c;芯片选择 xczu9eg-ffvb1156-2-i 创建 HDL 代码 注意&#xff1a;由于输入时钟为 200MHz 的差分时钟&#xff0c;因此…

2024年高职云计算实验室建设及云计算实训平台整体解决方案

随着云计算技术的飞速发展&#xff0c;高职院校亟需构建一个与行业需求紧密结合的云计算实验室和实训平台。以下是针对2024年高职院校云计算实验室建设的全面解决方案。 1、在高职云计算实验室的建设与规划中&#xff0c;首要任务是立足于云计算学科的精准定位&#xff0c;紧密…

4.SpringBoot自定义封装Starter实践

目录 概述旧版2.7之后自定义Starter 概述 SpringBoot自定义封装Starter实践 旧版 在SpringBoot2.7之前&#xff0c;META-INF 下 spring.factories 加 org.springframework.boot.autoconfigure.EnableAutoConfigurationXXAutoConfiguration 2.7之后 SpringBoot2.7推出新的自动配…

爬虫-浏览器自动化

什么是selenium selenium是浏览器自动化测试框架&#xff0c;原本用于网页测试。但到了爬虫领域&#xff0c;它又成为了爬虫的好帮手。有了 selenium&#xff0c;我们便不再需要判断网页数据加载的方式&#xff0c;只要让 selenium 自动控制浏览器&#xff0c;就像有双无形的手…

【以史为镜、以史明志,知史爱党、知史爱国】中华上下五千年之-元朝

元朝是中国历史上第一个由蒙古族族建立的大统一封建王朝。完整的元王朝历史进程分为四个阶段&#xff1a; 元朝的历史让我们一笔带过&#xff0c;相信大家也不怎么喜欢看。同意的点赞&#xff01; 元朝的前身——蒙古汗国&#xff08;1206年—1271年&#xff09; 建立王朝统治—…

快速体验 Llama3 的 4 种方式,本地部署,800 tokens/s 的推理速度真的太快了!

北京时间4月19日凌晨&#xff0c;Meta在官网上官宣了Llama-3&#xff0c;作为继Llama1、Llama2和CodeLlama之后的第三代模型&#xff0c;Llama3在多个基准测试中实现了全面领先&#xff0c;性能优于业界同类最先进的模型&#xff0c;你有没有第一时间体验上呢&#xff0c;这篇文…