昇思训练营打卡第二十四天(LSTM+CRF序列标注)

LSTM(Long Short-Term Memory,长短时记忆网络)是一种特殊的循环神经网络(RNN),由Hochreiter和Schmidhuber在1997年提出。它旨在解决传统RNN在处理长距离依赖问题时出现的梯度消失和梯度爆炸问题。以下是LSTM的一些主要特点:

  1. 细胞状态(Cell State):LSTM的核心是细胞状态,它贯穿整个LSTM网络,使得信息可以在网络中长时间传递。

  2. 门结构(Gates):LSTM通过门结构来控制信息的流入、流出和遗忘。主要包括以下三个门:

    • 遗忘门(Forget Gate):决定从细胞状态中丢弃什么信息。
    • 输入门(Input Gate):决定哪些新的信息被存储在细胞状态中。
    • 输出门(Output Gate):决定从细胞状态中输出什么信息。
  3. 梯度消失和梯度爆炸问题的缓解:由于LSTM的门结构,它能够在长序列中保持稳定的梯度,从而有效地缓解梯度消失和梯度爆炸问题。

CRF(Conditional Random Field,条件随机场)是一种概率图模型,常用于自然语言处理(NLP)中的序列标注任务,如词性标注、命名实体识别(NER)等。CRF模型能够考虑上下文信息,对序列中的每个元素(如词语)进行标注,使得整个序列的标注结果尽可能合理。

以下是CRF的一些关键特点:

  1. 无向图:CRF是一种无向图模型,它假设序列中的每个元素与其相邻元素之间存在依赖关系。

  2. 条件概率:CRF模型定义了一个条件概率分布,它表示在给定输入序列的情况下,输出标签序列的概率。

  3. 特征函数:CRF模型使用特征函数来描述输入与输出之间的关系。特征函数可以是基于输入序列和输出标签的任意函数,例如当前词、前后词、词性等。

  4. 全局最优标注:CRF在预测时会考虑整个序列的信息,寻找全局最优的标签序列,而不是单独对每个元素进行最优标注。

  5. 训练与解码:CRF的训练通常使用最大似然估计,而解码(即预测)时通常使用维特比算法(Viterbi algorithm)来找到最有可能的标签序列。

CRF模型的结构通常包括以下两个部分:

  • 发射特征(Emission Features):这些特征与输入序列中的每个元素相关,描述了输入元素与其对应标签的关系。
  • 转移特征(Transition Features):这些特征描述了标签序列中相邻标签之间的关系。

CRF模型的优势在于它能够有效地利用上下文信息,并且能够通过特征函数灵活地定义输入与输出之间的关系。这使得CRF在处理序列标注问题时通常比其他模型(如基于规则的模型或简单的概率模型)表现得更好。

Score计算

def compute_score(emissions, tags, seq_ends, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags)# tags: (seq_length, batch_size)# mask: (seq_length, batch_size)seq_length, batch_size = tags.shapemask = mask.astype(emissions.dtype)# 将score设置为初始转移概率# shape: (batch_size,)score = start_trans[tags[0]]# score += 第一次发射概率# shape: (batch_size,)score += emissions[0, mnp.arange(batch_size), tags[0]]for i in range(1, seq_length):# 标签由i-1转移至i的转移概率(当mask == 1时有效)# shape: (batch_size,)score += trans[tags[i - 1], tags[i]] * mask[i]# 预测tags[i]的发射概率(当mask == 1时有效)# shape: (batch_size,)score += emissions[i, mnp.arange(batch_size), tags[i]] * mask[i]# 结束转移# shape: (batch_size,)last_tags = tags[seq_ends, mnp.arange(batch_size)]# score += 结束转移概率# shape: (batch_size,)score += end_trans[last_tags]return score

Normalizer计算

def compute_normalizer(emissions, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)seq_length = emissions.shape[0]# 将score设置为初始转移概率,并加上第一次发射概率# shape: (batch_size, num_tags)score = start_trans + emissions[0]for i in range(1, seq_length):# 扩展score的维度用于总score的计算# shape: (batch_size, num_tags, 1)broadcast_score = score.expand_dims(2)# 扩展emission的维度用于总score的计算# shape: (batch_size, 1, num_tags)broadcast_emissions = emissions[i].expand_dims(1)# 根据公式(7),计算score_i# 此时broadcast_score是由第0个到当前Token所有可能路径# 对应score的log_sum_exp# shape: (batch_size, num_tags, num_tags)next_score = broadcast_score + trans + broadcast_emissions# 对score_i做log_sum_exp运算,用于下一个Token的score计算# shape: (batch_size, num_tags)next_score = ops.logsumexp(next_score, axis=1)# 当mask == 1时,score才会变化# shape: (batch_size, num_tags)score = mnp.where(mask[i].expand_dims(1), next_score, score)# 最后加结束转移概率# shape: (batch_size, num_tags)score += end_trans# 对所有可能的路径得分求log_sum_exp# shape: (batch_size,)return ops.logsumexp(score, axis=1)

Viterbi算法

def viterbi_decode(emissions, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)seq_length = mask.shape[0]score = start_trans + emissions[0]history = ()for i in range(1, seq_length):broadcast_score = score.expand_dims(2)broadcast_emission = emissions[i].expand_dims(1)next_score = broadcast_score + trans + broadcast_emission# 求当前Token对应score取值最大的标签,并保存indices = next_score.argmax(axis=1)history += (indices,)next_score = next_score.max(axis=1)score = mnp.where(mask[i].expand_dims(1), next_score, score)score += end_transreturn score, historydef post_decode(score, history, seq_length):# 使用Score和History计算最佳预测序列batch_size = seq_length.shape[0]seq_ends = seq_length - 1# shape: (batch_size,)best_tags_list = []# 依次对一个Batch中每个样例进行解码for idx in range(batch_size):# 查找使最后一个Token对应的预测概率最大的标签,# 并将其添加至最佳预测序列存储的列表中best_last_tag = score[idx].argmax(axis=0)best_tags = [int(best_last_tag.asnumpy())]# 重复查找每个Token对应的预测概率最大的标签,加入列表for hist in reversed(history[:seq_ends[idx]]):best_last_tag = hist[idx][best_tags[-1]]best_tags.append(int(best_last_tag.asnumpy()))# 将逆序求解的序列标签重置为正序best_tags.reverse()best_tags_list.append(best_tags)return best_tags_list

CRF层

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniformdef sequence_mask(seq_length, max_length, batch_first=False):"""根据序列实际长度和最大长度生成mask矩阵"""range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)result = range_vector < seq_length.view(seq_length.shape + (1,))if batch_first:return result.astype(ms.int64)return result.astype(ms.int64).swapaxes(0, 1)class CRF(nn.Cell):def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:if num_tags <= 0:raise ValueError(f'invalid number of tags: {num_tags}')super().__init__()if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f'invalid reduction: {reduction}')self.num_tags = num_tagsself.batch_first = batch_firstself.reduction = reductionself.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')def construct(self, emissions, tags=None, seq_length=None):if tags is None:return self._decode(emissions, seq_length)return self._forward(emissions, tags, seq_length)def _forward(self, emissions, tags=None, seq_length=None):if self.batch_first:batch_size, max_length = tags.shapeemissions = emissions.swapaxes(0, 1)tags = tags.swapaxes(0, 1)else:max_length, batch_size = tags.shapeif seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)mask = sequence_mask(seq_length, max_length)# shape: (batch_size,)numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)# shape: (batch_size,)denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)# shape: (batch_size,)llh = denominator - numeratorif self.reduction == 'none':return llhif self.reduction == 'sum':return llh.sum()if self.reduction == 'mean':return llh.mean()return llh.sum() / mask.astype(emissions.dtype).sum()def _decode(self, emissions, seq_length=None):if self.batch_first:batch_size, max_length = emissions.shape[:2]emissions = emissions.swapaxes(0, 1)else:batch_size, max_length = emissions.shape[:2]if seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)mask = sequence_mask(seq_length, max_length)return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)

BiLSTM+CRF模型

class BiLSTM_CRF(nn.Cell):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):super().__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')self.crf = CRF(num_tags, batch_first=True)def construct(self, inputs, seq_length, tags=None):embeds = self.embedding(inputs)outputs, _ = self.lstm(embeds, seq_length=seq_length)feats = self.hidden2tag(outputs)crf_outs = self.crf(feats, tags, seq_length)return crf_outs
embedding_dim = 16
hidden_dim = 32training_data = [("清 华 大 学 ".split(),"B I I I ".split()
), ("重 庆 ".split(),"B I ".split()
)]word_to_idx = {}
word_to_idx['<pad>'] = 0
for sentence, tags in training_data:for word in sentence:if word not in word_to_idx:word_to_idx[word] = len(word_to_idx)tag_to_idx = {"B": 0, "I": 1, "O": 2}len(word_to_idx)
model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx))
optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01, weight_decay=1e-4)
grad_fn = ms.value_and_grad(model, None, optimizer.parameters)def train_step(data, seq_length, label):loss, grads = grad_fn(data, seq_length, label)optimizer(grads)return loss
def prepare_sequence(seqs, word_to_idx, tag_to_idx):seq_outputs, label_outputs, seq_length = [], [], []max_len = max([len(i[0]) for i in seqs])for seq, tag in seqs:seq_length.append(len(seq))idxs = [word_to_idx[w] for w in seq]labels = [tag_to_idx[t] for t in tag]idxs.extend([word_to_idx['<pad>'] for i in range(max_len - len(seq))])labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))])seq_outputs.append(idxs)label_outputs.append(labels)return ms.Tensor(seq_outputs, ms.int64), \ms.Tensor(label_outputs, ms.int64), \ms.Tensor(seq_length, ms.int64)
data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx)
data.shape, label.shape, seq_length.shape
from tqdm import tqdmsteps = 500
with tqdm(total=steps) as t:for i in range(steps):loss = train_step(data, seq_length, label)t.set_postfix(loss=loss)t.update(1)
score, history = model(data, seq_length)
score
predict = post_decode(score, history, seq_length)
predict
idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()}def sequence_to_tag(sequences, idx_to_tag):outputs = []for seq in sequences:outputs.append([idx_to_tag[i] for i in seq])return outputs
sequence_to_tag(predict, idx_to_tag)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/871136.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

羧基聚乙二醇生物素的制备方法;COOH-PEG-Biotin

羧基聚乙二醇生物素&#xff08;COOH-PEG-Biotin&#xff09;是一种常见的生物分子聚合物&#xff0c;具有多种应用&#xff0c;特别是在生物实验、药物研发和生物技术等领域。以下是对该化合物的详细解析&#xff1a; 一、基本信息 名称&#xff1a;羧基聚乙二醇生物素&#x…

数据结构:链表详解 (c++实现)

前言 对于数据结构的线性表&#xff0c;其元素在逻辑结构上都是序列关系&#xff0c;即数据元素之间有前驱和后继关系。 但在物理结构上有两种存储方式&#xff1a; 顺序存储结构&#xff1a; 使用此结构的线性表也叫 顺序表物理存储上是连续的&#xff0c;因此可以随机访问…

电压反馈型运算放大器的增益和带宽

简介 本教程旨在考察标定运算放大器的增益和带宽的常用方法。需要指出的是&#xff0c;本讨论适用于电压反馈(VFB)型运算放大器。 开环增益 与理想的运算放大器不同&#xff0c;实际的运算放大器增益是有限的。开环直流增益(通常表示为AVOL)指放大器在反馈环路未闭合时的增益…

借人工智能之手,编织美妙歌词篇章

在音乐的领域中&#xff0c;歌词宛如璀璨的明珠&#xff0c;为旋律增添了无尽的魅力和情感深度。然而&#xff0c;对于许多创作者来说&#xff0c;编织出美妙动人的歌词并非易事。但如今&#xff0c;随着科技的飞速发展&#xff0c;人工智能为我们带来了全新的创作可能。 “妙…

Cornerstone3D导致浏览器崩溃的踩坑记录

WebGL: CONTEXT_LOST_WEBGL: loseContext: context lost ⛳️ 问题描述 在使用vue3vite重构Cornerstone相关项目后&#xff0c;在Mac本地运行良好&#xff0c;但是部署测试环境后&#xff0c;在window系统的Chrome浏览器中切换页面会导致页面崩溃。查看Chrome的任务管理器&am…

浅析Kafka Streams消息流式处理流程及原理

以下结合案例&#xff1a;统计消息中单词出现次数&#xff0c;来测试并说明kafka消息流式处理的执行流程 Maven依赖 <dependencies><dependency><groupId>org.apache.kafka</groupId><artifactId>kafka-streams</artifactId><exclusio…

【三维AIGC】扩散模型LDM辅助3D Gaussian重建三维场景

标题&#xff1a;《Sampling 3D Gaussian Scenes in Seconds with Latent Diffusion Models》 来源&#xff1a;Glasgow大学&#xff1b;爱丁堡大学 连接&#xff1a;https://arxiv.org/abs/2406.13099 提示&#xff1a;写完文章后&#xff0c;目录可以自动生成&#xff0c;如何…

Spring Security学习笔记(一)Spring Security架构原理

前言&#xff1a;本系列博客基于Spring Boot 2.6.x依赖的Spring Security5.6.x版本 Spring Security中文文档&#xff1a;https://springdoc.cn/spring-security/index.html 一、什么是Spring Security Spring Security是一个安全控制相关的java框架&#xff0c;它提供了一套全…

海外ASO:iOS与谷歌优化的相同点和区别

海外ASO是针对iOS的App Store和谷歌的Google Play这两个主要海外应用商店进行的优化过程&#xff0c;两个不同的平台需要采取不同的优化策略&#xff0c;以下是对iOS优化和谷歌优化的详细解析&#xff1a; 一、iOS优化&#xff08;App Store&#xff09; 1、关键词覆盖 选择关…

用node.js写一个简单的图书管理界面——功能:添加,删除,修改数据

涉及到的模块&#xff1a; var fs require(‘fs’)——内置模块 var ejs require(‘ejs’)——第三方模块 var mysql require(‘mysql’)——第三方模块 var express require(‘express’)——第三方模块 var bodyParser require(‘body-parser’)——第三方中间件 需要…

打造你的智能家居指挥中心:基于STM32的多协议(zigbee、http)网关(附代码示例)

1. 项目概述 随着物联网技术的蓬勃发展&#xff0c;智能家居正逐步融入人们的日常生活。然而&#xff0c;市面上琳琅满目的智能家居设备通常采用不同的通信协议&#xff0c;导致不同品牌设备之间难以实现互联互通。为了解决这一难题&#xff0c;本文设计了一种基于STM32的多协…

ant design form动态增减表单项Form.List如何进行动态校验规则

项目需求&#xff1a; 在使用ant design form动态增减表单项Form.List时&#xff0c;Form.List中有多组表单项&#xff0c;一组中的最后一个表单项的校验规则是动态的&#xff0c;该组为最后一组时&#xff0c;最后一个表单项是非必填项&#xff0c;其他时候为必填项。假设动态…

docker inspect 如何提取容器的ip和端口 网络信息?

目录 通过原生Linux命令过滤找到IP 通过jq工具找到IP 使用docker -f 的过滤&#xff08;模板&#xff09; 查找端口映射信息 查看容器内部细节 docker inspect 容器ID或容器名 通过原生Linux命令过滤找到IP 通过jq工具找到IP jq 是一个轻量级且灵活的命令行工具&#xf…

(视频演示)基于OpenCV的实时视频跟踪火焰识别软件V1.0源码及exe下载

本文介绍了基于OpenCV的实时视频跟踪火焰识别软件&#xff0c;该软件通过先进的图像处理技术实现对实时视频中火焰的检测与跟踪&#xff0c;同时支持导入图片进行火焰识别。主要功能包括相机选择、实时跟踪和图片模式。软件适用于多种场合&#xff0c;用于保障人民生命财产安全…

OpenGL笔记二之glad加载opengl函数以及opengl-API(函数)初体验

OpenGL笔记二之glad加载opengl函数以及opengl-API(函数)初体验 总结自bilibili赵新政老师的教程 code review! 文章目录 OpenGL笔记二之glad加载opengl函数以及opengl-API(函数)初体验1.运行2.重点3.目录结构4.main.cpp5.CMakeList.txt 1.运行 2.重点 3.目录结构 01_GLFW_WI…

Python-PLAXIS自动化建模技术与典型岩土工程

有限单元法在岩土工程问题中应用非常广泛&#xff0c;很多软件都采用有限单元解法。在使用各大软件进行数值模拟建模的过程中&#xff0c;您是否发现GUI界面中重复性的点击输入工作太繁琐&#xff1f;从而拖慢了设计或方案必选进程&#xff1f; 搭建自己的Plaxis模型&#xff…

设计模式的七大原则

1.单一职责原则 单一职责原则(Single responsibility principle)&#xff0c;即一个类应该只负责一项职责。如类A负责两个不同职责&#xff1a;职责1&#xff0c;职责2。当职责1需求变更而改变A时&#xff0c;可能造成职责2执行错误&#xff0c;所以需要将类A的粒度分解为A1、…

安卓14中Zygote初始化流程及源码分析

文章目录 日志抓取结合日志与源码分析systemServer zygote创建时序图一般应用 zygote 创建时序图向 zygote socket 发送数据时序图 本文首发地址 https://h89.cn/archives/298.html 最新更新地址 https://gitee.com/chenjim/chenjimblog 本文主要结合日志和代码看安卓 14 中 Zy…

C/C++ 进阶(7)模拟实现map/set

个人主页&#xff1a;仍有未知等待探索-CSDN博客 专题分栏&#xff1a;C 一、简介 map和set都是关联性容器&#xff0c;底层都是用红黑树写的。 特点&#xff1a;存的Key值都是唯一的&#xff0c;不重复。 map存的是键值对&#xff08;Key—Value&#xff09;。 set存的是键…

Git的命令使用与IDEA内置git图形化的使用

Git 简介 Git 是分布式版本控制系统&#xff0c;它可以帮助开发人员跟踪和管理代码的更改。Git 可以记录代码的历史记录&#xff0c;并允许您在不同版本之间切换。 通过历史记录可以查看&#xff1a; 进行了哪些更改&#xff1f;谁进行了更改&#xff1f;何时进行了更改&#…