240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)

240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)

今天做LSTM+CRF序列标注第三部分,同样,仅作简单记录及注释,最近确实太忙了。

Viterbi算法

在完成前向训练部分后,需要实现解码部分。这里我们选择适合求解序列最优路径的Viterbi算法。与计算Normalizer类似,使用动态规划求解所有可能的预测序列得分。不同的是在解码时同时需要将第𝑖个Token对应的score取值最大的标签保存,供后续使用Viterbi算法求解最优预测序列使用。

取得最大概率得分ScoreScore,以及每个Token对应的标签历史HistoryHistory后,根据Viterbi算法可以得到公式:

请添加图片描述

从第0个至第𝑖个Token对应概率最大的序列,只需要考虑从第0个至第𝑖−1个Token对应概率最大的序列,以及从第𝑖个至第𝑖−1个概率最大的标签即可。因此我们逆序求解每一个概率最大的标签,构成最佳的预测序列。

由于静态图语法限制,我们将Viterbi算法求解最佳预测序列的部分作为后处理函数,不纳入后续CRF层的实现。

# 定义维特比解码算法,用于找出具有最大概率的标签序列
def viterbi_decode(emissions, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags) 发射概率矩阵# mask: (seq_length, batch_size) 序列掩码,用于标记有效序列长度# trans: 转移概率矩阵# start_trans: 初始状态转移概率向量# end_trans: 终止状态转移概率向量seq_length = mask.shape[0]  # 获取序列长度# 初始化分数矩阵,等于初始状态转移概率加上第一个发射概率score = start_trans + emissions[0]history = ()  # 初始化历史路径记录# 遍历序列中的每个时间步for i in range(1, seq_length):# 扩展维度以便广播运算broadcast_score = score.expand_dims(2)broadcast_emission = emissions[i].expand_dims(1)# 计算所有可能的转移分数next_score = broadcast_score + trans + broadcast_emission# 找出当前Token对应的最大分数标签,并保存indices = next_score.argmax(axis=1)history += (indices,)  # 保存历史路径信息# 取出最大分数next_score = next_score.max(axis=1)# 更新分数矩阵,只更新mask为True的部分score = mnp.where(mask[i].expand_dims(1), next_score, score)# 加上终止状态转移概率score += end_trans# 返回最终的分数矩阵和历史路径信息return score, history# 根据解码过程中的得分和历史路径信息,重构最优标签序列
def post_decode(score, history, seq_length):# score: 最终得分矩阵# history: 历史路径信息# seq_length: 每个样本的实际序列长度batch_size = seq_length.shape[0]  # 获取批次大小seq_ends = seq_length - 1  # 计算每个样本的最后一个Token位置# 初始化最佳标签序列列表best_tags_list = []# 对批次中的每个样本进行解码for idx in range(batch_size):# 找出使最后一个Token对应的预测概率最大的标签best_last_tag = score[idx].argmax(axis=0)best_tags = [int(best_last_tag.asnumpy())]  # 添加最佳标签到序列# 从历史路径信息中反向追踪,找到每个Token的最佳标签for hist in reversed(history[:seq_ends[idx]]):best_last_tag = hist[idx][best_tags[-1]]best_tags.append(int(best_last_tag.asnumpy()))# 将逆序的标签序列反转,得到正序的最优标签序列best_tags.reverse()best_tags_list.append(best_tags)  # 添加到结果列表# 返回最优标签序列列表return best_tags_list

CRF层

完成上述前向训练和解码部分的代码后,将其组装完整的CRF层。考虑到输入序列可能存在Padding的情况,CRF的输入需要考虑输入序列的真实长度,因此除发射矩阵和标签外,加入seq_length参数传入序列Padding前的长度,并实现生成mask矩阵的sequence_mask方法。

综合上述代码,使用nn.Cell进行封装,最后实现完整的CRF层如下:

# 导入MindSpore相关模块
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform# 定义序列掩码生成函数
def sequence_mask(seq_length, max_length, batch_first=False):"""根据序列的实际长度和最大长度生成mask矩阵。参数:seq_length: 实际序列长度张量。max_length: 序列的最大长度。batch_first: 是否将批次放在第一维度。返回:mask矩阵,形状为(batch_size, max_length),其中True表示有效位置,False表示填充位置。"""# 生成从0到max_length的范围向量range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)# 创建mask矩阵,shape为(seq_length.shape + (1,))result = range_vector < seq_length.view(seq_length.shape + (1,))# 转换数据类型并根据batch_first参数调整维度顺序if batch_first:return result.astype(ms.int64)return result.astype(ms.int64).swapaxes(0, 1)# 定义条件随机场(CRF)模型类
class CRF(nn.Cell):def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:"""初始化CRF模型。参数:num_tags: 标签数量。batch_first: 是否将批次放在第一维度。reduction: 损失函数的缩减方式。"""# 检查标签数量是否有效if num_tags <= 0:raise ValueError(f'无效的标签数量: {num_tags}')super().__init__()# 检查reduction参数是否有效if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f'无效的缩减方式: {reduction}')self.num_tags = num_tags  # 标签数量self.batch_first = batch_first  # 批次是否在第一维度self.reduction = reduction  # 损失函数缩减方式# 初始化起始和结束状态转移权重self.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')# 初始化状态间转移权重self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')def construct(self, emissions, tags=None, seq_length=None):"""CRF模型的前向传播方法。参数:emissions: 发射概率张量。tags: 真实标签张量。seq_length: 序列长度张量。返回:如果tags为None,则返回解码结果;否则返回损失值。"""if tags is None:return self._decode(emissions, seq_length)return self._forward(emissions, tags, seq_length)def _forward(self, emissions, tags=None, seq_length=None):"""计算损失值。参数:emissions: 发射概率张量。tags: 真实标签张量。seq_length: 序列长度张量。返回:损失值。"""# 根据batch_first参数调整emissions和tags的维度顺序if self.batch_first:batch_size, max_length = tags.shapeemissions = emissions.swapaxes(0, 1)tags = tags.swapaxes(0, 1)else:max_length, batch_size = tags.shape# 如果seq_length未给出,则假设所有序列都是最大长度if seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)# 生成mask矩阵mask = sequence_mask(seq_length, max_length)# 计算分子部分(真实路径的得分)numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)# 计算分母部分(所有可能路径的得分总和)denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)# 计算对数似然比llh = denominator - numerator# 根据reduction参数选择损失值的缩减方式if self.reduction == 'none':return llhelif self.reduction == 'sum':return llh.sum()elif self.reduction == 'mean':return llh.mean()return llh.sum() / mask.astype(emissions.dtype).sum()def _decode(self, emissions, seq_length=None):"""解码方法,用于预测最优标签序列。参数:emissions: 发射概率张量。seq_length: 序列长度张量。返回:最优标签序列。"""# 根据batch_first参数调整emissions的维度顺序if self.batch_first:batch_size, max_length = emissions.shape[:2]emissions = emissions.swapaxes(0, 1)else:batch_size, max_length = emissions.shape[:2]# 如果seq_length未给出,则假设所有序列都是最大长度if seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)# 生成mask矩阵mask = sequence_mask(seq_length, max_length)# 使用维特比算法解码最优路径return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)

打卡图片:

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/45877.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从“Hello,World”谈起(C++入门)

前言 c的发展史及c能干什么不能干什么不是我们今天的重点&#xff0c;不在这里展开&#xff0c;有兴趣的朋友可以自行查阅相关资料。今天我们主要是围绕c的入门程序&#xff0c;写一个“hello&#xff0c;world”&#xff0c;并且围绕这个入门程序简单介绍一下c和c的一些语法&…

C++ Qt 自制开源科学计算器

C Qt 自制开源科学计算器 项目地址 软件下载地址 目录 0. 效果预览1. 数据库准备2. 按键&快捷键说明3. 颜色切换功能(初版)4. 未来开发展望5. 联系邮箱 0. 效果预览 普通计算模式效果如下&#xff1a; 科学计算模式效果如下&#xff1a; 更具体的功能演示视频见如下链接…

stm32入门-----初识stm32

目录 前言 ARM stm32 1.stm32家族 2.stm32的外设资源 3.命名规则 4.系统结构 ​编辑 5.引脚定义 6.启动配置 7.STM32F103C8T6芯片 8.STM32F103C8T6芯片原理图与最小系统电路 前言 已经很久没跟新了&#xff0c;上次发文的时候是好几个月之前了&#xff0c;现在我是想去…

论文分享|NeurIPS2022‘华盛顿大学|俄罗斯套娃表示学习(OpenAI使用的文本表示学习技术)

论文题目&#xff1a;Matryoshka Representation Learning 来源&#xff1a;NeurIPS2022/华盛顿大学谷歌 方向&#xff1a;表示学习 开源地址&#xff1a;https://github.com/RAIVNLab/MRL 摘要 学习表征对于现代机器学习很重要&#xff0c;广泛用于很多下游任务。大多数情…

java配置nginx网络安全,防止国外ip访问,自动添加黑名单,需手动重新加载nginx

通过访问日志自动添加国外ip黑名单 创建一个类&#xff0c;自己添加一个main启动类即可测试 import lombok.AccessLevel; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.json.JSONArray; import org.json.JSONObject; import org.sp…

面试经验之谈

优质博文&#xff1a;IT-BLOG-CN ​通常面试官会把每一轮面试分为三个环节&#xff1a;① 行为面试 ② 技术面试 ③ 应聘者提问 行为面试环节 面试开始的5~10分钟通常是行为面试的时间&#xff0c;面试官会参照简历和你的自我介绍了解应聘者的过往经验和项目经历。由于面试官…

nodejs模板引擎(一)

在 Node.js 中使用模板引擎可以让您更轻松地生成动态 HTML 页面&#xff0c;通过将静态模板与动态数据结合&#xff0c;您可以创建可维护且易于扩展的 Web 应用程序。以下是一个使用 Express 框架和 EJS 模板引擎的基本示例&#xff1a; 安装必要的依赖&#xff1a; 首先&#…

分享浏览器被hao123网页劫持,去除劫持的方式

昨天看python相关的自动化工作代码时&#xff0c;发现谷歌浏览器被hao123劫持了&#xff0c;把那些程序删了也不管用 方法1&#xff1a;删除hao123注册表&#xff0c;这个方式不太好用&#xff0c;会找不到注册表 方法2&#xff1a;看浏览器快捷方式的属性页面&#xff0c;一…

【C++】入门基础(命名空间、缺省参数、函数重载)

目录 一.命名空间&#xff1a;namespace 1.namespace的价值 2.namespace的定义 3.namespace的使用方法 3.1 域解析运算符:: 3.2 using展开 3.3 using域解析运算符 二.输入输出 三.缺省参数 四.函数重载 1.参数类型不同 2.参数个数不同 3.参数顺序不同 一.命名空间&…

APP专项测试之网络测试

背景 当前app网络环境比较复杂&#xff0c;越来越多的公共wifi&#xff0c;网络制式有2G、3G、4G网络&#xff0c;会对用户使用app造成一定影响&#xff1b;当前app使用场景多变&#xff0c;如进地铁、上公交、进电梯等&#xff0c;使得弱网测试显得尤为重要&#xff1b; 网络正…

链路追踪系列-02.演示zipkin

当本机启动docker es zipkinServer之后&#xff1a; 启动3个项目&#xff1a;先eureka-server&#xff0c;再 PaymentMain8001,… 浏览器打开&#xff1a;http://localhost:9001/consumer/payment/zipkin consumer代码 &#xff1a; provider: 此时查询es:

3-2 多层感知机的从零开始实现

import torch from torch import nn from d2l import torch as d2lbatch_size 256 # 批量大小为256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size) # load进来训练集和测试集初始化模型参数 回想一下&#xff0c;Fashion-MNIST中的每个图像由 28 28 784…

学习C++,应该循序渐进的看哪些书?

学习C是一个循序渐进的过程&#xff0c;需要根据自己的基础和目标来选择合适的书籍。以下是一个推荐的学习路径&#xff0c;包含了从入门到进阶的书籍&#xff1a; 1. 入门阶段 《C Primer Plus 第6版 中文版》 推荐理由&#xff1a;这本书同样适合C零基础的学习者&#xff0…

[CISCN2018]2ex

啊!好恶心的mips寄存器 好多IDA都查不到,这寄存器~! fuck! 但是这种寄存器一般的题都不难 这道题就是 我用平常的方法,没找到 左边函数一个一个点 看见这里0X3F base64 密文呢? 我giao 外面的txt文件里面 脚本 import base64 import string# 定义你的自定义字符集 st…

聊点基础---Java和.NET开发技术异同全方位分析

1. C#语言基础 1.1 C#语法概览 欢迎来到C#的世界&#xff01;对于刚从Java转过来的开发者来说&#xff0c;你会发现C#和Java有很多相似之处&#xff0c;但C#也有其独特的魅力和强大之处。让我们一起来探索C#的基本语法&#xff0c;并比较一下与Java的异同。 程序结构 C#程序…

美团收银Android一面凉经(2024)

美团收银Android一面凉经(2024) 笔者作为一名双非二本毕业7年老Android, 最近面试了不少公司, 目前已告一段落, 整理一下各家的面试问题, 打算陆续发布出来, 供有缘人参考。今天给大家带来的是《美团收银Android一面凉经(2024)》。 应聘岗位: 美团餐饮PaaS平台Android开发工程师…

【2-1:RPC设计】

RPC 1. 基础1.1 定义&特点1.2 具体实现框架1.3 应用场景2. RPC的关键技术点&一次调用rpc流程2.1 RPC流程流程两个网络模块如何连接的呢?其它特性RPC优势2.2 序列化技术序列化方式PRC如何选择序列化框架考虑因素2.3 应用层的通信协议-http2.3.1 基础概念大多数RPC大多自…

【C++ | 虚函数】虚函数详解 及 例子代码演示(包含虚函数使用、动态绑定、虚函数表、虚表指针)

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

Matlab-Simulink模型保存为图片的方法

有好多种办法将模型保存为图片&#xff0c;这里直接说经常用的 而且贴到Word文档中清晰、操作简单。 simulink自带有截图功能&#xff0c;这两种方法都可以保存模型图片。选择后直接就复制到截切板上了。直接去文档中粘贴就完事了。 这两个格式效果不太一样&#xff0c;第一种清…

JS登录页源码 —— 可一键复制抱走

前期回顾 https://blog.csdn.net/m0_57904695/article/details/139838176?spm1001.2014.3001.5501https://blog.csdn.net/m0_57904695/article/details/139838176?spm1001.2014.3001.5501 登录页预览效果 <!DOCTYPE html> <html lang"en"><head…