240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)
今天做LSTM+CRF序列标注第三部分,同样,仅作简单记录及注释,最近确实太忙了。
Viterbi算法
在完成前向训练部分后,需要实现解码部分。这里我们选择适合求解序列最优路径的Viterbi算法。与计算Normalizer类似,使用动态规划求解所有可能的预测序列得分。不同的是在解码时同时需要将第𝑖个Token对应的score取值最大的标签保存,供后续使用Viterbi算法求解最优预测序列使用。
取得最大概率得分ScoreScore,以及每个Token对应的标签历史HistoryHistory后,根据Viterbi算法可以得到公式:
从第0个至第𝑖个Token对应概率最大的序列,只需要考虑从第0个至第𝑖−1个Token对应概率最大的序列,以及从第𝑖个至第𝑖−1个概率最大的标签即可。因此我们逆序求解每一个概率最大的标签,构成最佳的预测序列。
由于静态图语法限制,我们将Viterbi算法求解最佳预测序列的部分作为后处理函数,不纳入后续CRF层的实现。
# 定义维特比解码算法,用于找出具有最大概率的标签序列
def viterbi_decode(emissions, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags) 发射概率矩阵# mask: (seq_length, batch_size) 序列掩码,用于标记有效序列长度# trans: 转移概率矩阵# start_trans: 初始状态转移概率向量# end_trans: 终止状态转移概率向量seq_length = mask.shape[0] # 获取序列长度# 初始化分数矩阵,等于初始状态转移概率加上第一个发射概率score = start_trans + emissions[0]history = () # 初始化历史路径记录# 遍历序列中的每个时间步for i in range(1, seq_length):# 扩展维度以便广播运算broadcast_score = score.expand_dims(2)broadcast_emission = emissions[i].expand_dims(1)# 计算所有可能的转移分数next_score = broadcast_score + trans + broadcast_emission# 找出当前Token对应的最大分数标签,并保存indices = next_score.argmax(axis=1)history += (indices,) # 保存历史路径信息# 取出最大分数next_score = next_score.max(axis=1)# 更新分数矩阵,只更新mask为True的部分score = mnp.where(mask[i].expand_dims(1), next_score, score)# 加上终止状态转移概率score += end_trans# 返回最终的分数矩阵和历史路径信息return score, history# 根据解码过程中的得分和历史路径信息,重构最优标签序列
def post_decode(score, history, seq_length):# score: 最终得分矩阵# history: 历史路径信息# seq_length: 每个样本的实际序列长度batch_size = seq_length.shape[0] # 获取批次大小seq_ends = seq_length - 1 # 计算每个样本的最后一个Token位置# 初始化最佳标签序列列表best_tags_list = []# 对批次中的每个样本进行解码for idx in range(batch_size):# 找出使最后一个Token对应的预测概率最大的标签best_last_tag = score[idx].argmax(axis=0)best_tags = [int(best_last_tag.asnumpy())] # 添加最佳标签到序列# 从历史路径信息中反向追踪,找到每个Token的最佳标签for hist in reversed(history[:seq_ends[idx]]):best_last_tag = hist[idx][best_tags[-1]]best_tags.append(int(best_last_tag.asnumpy()))# 将逆序的标签序列反转,得到正序的最优标签序列best_tags.reverse()best_tags_list.append(best_tags) # 添加到结果列表# 返回最优标签序列列表return best_tags_list
CRF层
完成上述前向训练和解码部分的代码后,将其组装完整的CRF层。考虑到输入序列可能存在Padding的情况,CRF的输入需要考虑输入序列的真实长度,因此除发射矩阵和标签外,加入seq_length
参数传入序列Padding前的长度,并实现生成mask矩阵的sequence_mask
方法。
综合上述代码,使用nn.Cell
进行封装,最后实现完整的CRF层如下:
# 导入MindSpore相关模块
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform# 定义序列掩码生成函数
def sequence_mask(seq_length, max_length, batch_first=False):"""根据序列的实际长度和最大长度生成mask矩阵。参数:seq_length: 实际序列长度张量。max_length: 序列的最大长度。batch_first: 是否将批次放在第一维度。返回:mask矩阵,形状为(batch_size, max_length),其中True表示有效位置,False表示填充位置。"""# 生成从0到max_length的范围向量range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)# 创建mask矩阵,shape为(seq_length.shape + (1,))result = range_vector < seq_length.view(seq_length.shape + (1,))# 转换数据类型并根据batch_first参数调整维度顺序if batch_first:return result.astype(ms.int64)return result.astype(ms.int64).swapaxes(0, 1)# 定义条件随机场(CRF)模型类
class CRF(nn.Cell):def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:"""初始化CRF模型。参数:num_tags: 标签数量。batch_first: 是否将批次放在第一维度。reduction: 损失函数的缩减方式。"""# 检查标签数量是否有效if num_tags <= 0:raise ValueError(f'无效的标签数量: {num_tags}')super().__init__()# 检查reduction参数是否有效if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f'无效的缩减方式: {reduction}')self.num_tags = num_tags # 标签数量self.batch_first = batch_first # 批次是否在第一维度self.reduction = reduction # 损失函数缩减方式# 初始化起始和结束状态转移权重self.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')# 初始化状态间转移权重self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')def construct(self, emissions, tags=None, seq_length=None):"""CRF模型的前向传播方法。参数:emissions: 发射概率张量。tags: 真实标签张量。seq_length: 序列长度张量。返回:如果tags为None,则返回解码结果;否则返回损失值。"""if tags is None:return self._decode(emissions, seq_length)return self._forward(emissions, tags, seq_length)def _forward(self, emissions, tags=None, seq_length=None):"""计算损失值。参数:emissions: 发射概率张量。tags: 真实标签张量。seq_length: 序列长度张量。返回:损失值。"""# 根据batch_first参数调整emissions和tags的维度顺序if self.batch_first:batch_size, max_length = tags.shapeemissions = emissions.swapaxes(0, 1)tags = tags.swapaxes(0, 1)else:max_length, batch_size = tags.shape# 如果seq_length未给出,则假设所有序列都是最大长度if seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)# 生成mask矩阵mask = sequence_mask(seq_length, max_length)# 计算分子部分(真实路径的得分)numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)# 计算分母部分(所有可能路径的得分总和)denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)# 计算对数似然比llh = denominator - numerator# 根据reduction参数选择损失值的缩减方式if self.reduction == 'none':return llhelif self.reduction == 'sum':return llh.sum()elif self.reduction == 'mean':return llh.mean()return llh.sum() / mask.astype(emissions.dtype).sum()def _decode(self, emissions, seq_length=None):"""解码方法,用于预测最优标签序列。参数:emissions: 发射概率张量。seq_length: 序列长度张量。返回:最优标签序列。"""# 根据batch_first参数调整emissions的维度顺序if self.batch_first:batch_size, max_length = emissions.shape[:2]emissions = emissions.swapaxes(0, 1)else:batch_size, max_length = emissions.shape[:2]# 如果seq_length未给出,则假设所有序列都是最大长度if seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)# 生成mask矩阵mask = sequence_mask(seq_length, max_length)# 使用维特比算法解码最优路径return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)
打卡图片: