nlp系列(6)文本实体识别(Bi-LSTM+CRF)pytorch

模型介绍

LSTM:长短期记忆网络(Long-short-term-memory),能够记住长句子的前后信息,解决了RNN的问题(时间间隔较大时,网络对前面的信息会遗忘,从而出现梯度消失问题,会形成长期依赖问题),避免长期依赖问题。
Bi-LSTM:由前向LSTM与后向LSTM组合而成。

模型结构

Bi-LSTM

同LSTM,区别在于模型的输出和结构上不同,如下图:

图1 Bi-LSTM的数据输入形式

一共有两个LSTM网络,一个网络从一句话的首段进行学习,另一个网络从一句话的末端进行学习。

相关详情请看nlp系列(5)文本实体识别(LSTM)pytorch 中模型详解

CRF

CRF(条件随机场):是一个判别模型,用于解决标注偏差问题,使用P(Y|X)建模,为全局归一化
适用领域:词性标注、分词、命名实体识别等
以命名实体为例:
在这里插入图片描述
损失计算:
lg ⁡ P ( Y ∣ X ) = − l g e s ( X , Y ) ∑ y ‾ ϵ Y x e s ( X , y ‾ ) = − S ( X , y ) + lg ⁡ ∑ y ‾ ϵ Y x e s ( X , y ‾ ) \lg P(Y|X) = -lg \frac{e^s(X,Y)}{\sum_{\overline{y}\epsilon Y_x}{e^s(X,\overline y)}} = - S(X, y) + \lg\sum_{\overline{y}\epsilon Y_x}{e^s(X,\overline y)} lgP(YX)=lgyϵYxes(X,y)es(X,Y)=S(X,y)+lgyϵYxes(X,y)
推荐一个视频讲解,全程手写推导,讲得很细
机器学习-白板推导系列(十七)-条件随机场CRF(Conditional Random Field)

数据介绍

数据集用的是论文【ACL 2018Chinese NER using Lattice LSTM】中从新浪财经收集的简历数据。每一句话用换行进行隔开。

图2 数据样式

模型准备

方法一:使用ptorch库自带的CRF库,其CRF库关键函数介绍链接

    def forward(self, sentence, tags=None, mask=None):# sentence=(batch, seq_len)   tags=(batch, seq_len)  masks=(batch, seq_len)# 1. 从 sentence 到 Embedding 层embeds = self.word_embeds(sentence).permute(1, 0, 2)  # shape [seq_len, batch_size, embedding_size]# 2. 从 Embedding 层到 Bi-LSTM 层# Bi-lstm 层的隐藏节点设置# 隐藏层就是(h_0, c_0)    num_directions = 2 if self.bidirectional else 1# h_0 的结构:(num_layers*num_directions, batch_size, hidden_size)self.hidden = (torch.randn(2, sentence.shape[0], self.hidden_dim // 2, device=self.device),torch.randn(2, sentence.shape[0], self.hidden_dim // 2, device=self.device))# input=(seq_length, batch_size, embedding_num)# output(lstm_out)=(seq_length, batch_size, num_directions * hidden_size)# h_0 = (num_layers*num_directions, batch_size, hidden_size)lstm_out, self.hidden = self.lstm(embeds, self.hidden)# 3. 从 Bi-LSTM 层到全连接层# 从 Bi-lstm 的输出转为 target_size 长度的向量组(即输出了每个 tag 的可能性)# 输出 shape=(seq_length, batch_size, len(tag_to_ix))lstm_feats = self.linear(lstm_out)# 4. 全连接层到 CRF 层if tags is not None:# 训练用if mask is not None:loss = -1. * self.crf(emissions=lstm_feats.permute(1, 0, 2), tags=tags, mask=mask, reduction='mean')# outputs=(batch_size,)   输出 log 形式的 likelihoodelse:loss = -1. * self.crf(emissions=lstm_feats.permute(1, 0, 2), tags=tags, reduction='mean')return losselse:# 测试if mask is not None:prediction = self.crf.decode(emissions=lstm_feats.permute(1, 0, 2), mask=mask)else:prediction = self.crf.decode(emissions=lstm_feats.permute(1, 0, 2))return prediction

方法2:编写CRF实现代码

def argmax(vec):"""返回 vec 中每一行最大的那个元素的下标"""# return the argmax as a python int_, idx = torch.max(vec, 1)# 获取该元素:tensor只有一个元素才能调用item方法return idx.item()def log_sum_exp(vec, device):"""vec 维度为 1*5Compute log sum exp in a numerically stable way for the forward algorithm前向算法是不断累积之前的结果,这样就会有个缺点指数和累积到一定程度后,会超过计算机浮点值的最大值,变成inf,这样取log后也是inf为了避免这种情况,用一个合适的值clip去提指数和的公因子,这样就不会使某项变得过大而无法计算计算一维向量 vec 与其最大值的 log_sum_exp"""max_score = vec[0, argmax(vec)]  # max_score的维度为1max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])  # 维度为 1*5return max_score.to(device) + torch.log(torch.sum(torch.exp(vec - max_score_broadcast))).to(device)class BiLSTM_CRF(nn.Module):def __init__(self, vocab_size, tag_to_index, embedding_dim, hidden_dim):# 调用父类的initsuper(BiLSTM_CRF, self).__init__()self.embedding_dim = embedding_dim  # word embedding dim  嵌入维度: 词向量维度self.hidden_dim = hidden_dim  # Bi-LSTM hidden dim  隐藏层维度self.vocab_size = vocab_size  # 词汇量大小self.tag_to_index = tag_to_index  # 标签转下标的词典self.target_size = len(tag_to_index)  # 输出维度:目标取值范围大小,标签预测类别数self.device = "cuda:0" if torch.cuda.is_available() else "cpu"''' Embedding 的用法A simple lookup table that stores embeddings of a fixed dictionary and size.This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings.一个简单的查找表,用于存储固定字典和大小的嵌入。该模块通常用于存储词嵌入并使用索引检索它们。模块的输入是索引列表,输出是相应的词嵌入。requires_grad: 用于说明当前量是否需要在计算中保留对应的梯度信息'''self.word_embeds = nn.Embedding(vocab_size, embedding_dim)'''embedding_dim:特征维度hidden_dim:隐藏层层数num_layers:循环层数bidirectional:是否采用 Bi-LSTM(前向LSTM+反向LSTM)'''self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True)# 将 Bi-LSTM 提取的特征向量映射到特征空间,即经过全连接得到发射分数self.hidden2tag = nn.Linear(hidden_dim, self.target_size)# 转移矩阵的参数初始化,transitions[i,j]代表的是从第j个tag转移到第i个tag的转移分数# 转移矩阵是随机的,在网络中会随着训练不断更新self.transitions = nn.Parameter(torch.randn(self.target_size, self.target_size))# 初始化所有其他 tag 转移到 START_TAG 的分数非常小,即不可能由其他 tag 转移到 START_TAG# 初始化 STOP_TAG 转移到所有其他 tag 的分数非常小,即不可能由 STOP_TAG 转移到其他 tag# 转移矩阵: 列标 转 行标# 规定:其他 tag 不能转向 start,stop 也不能转向其他 tagself.transitions.data[self.tag_to_index[START_TAG], :] = -10000  # 从任何标签转移到 START_TAG 不可能self.transitions.data[:, self.tag_to_index[STOP_TAG]] = -10000  # 从 STOP_TAG 转移到任何标签不可能# 初始化 hidden layerself.hidden = self.init_hidden()def init_hidden(self):# 初始化 Bi-LSTM 的参数 h_0, c_0return (torch.randn(2, 1, self.hidden_dim // 2).to(self.device),torch.randn(2, 1, self.hidden_dim // 2).to(self.device))def _get_lstm_features(self, sentence):# 通过 Bi-LSTM 提取特征self.hidden = self.init_hidden()embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)'''默认参数意义:input_size,hidden_size,num_layershidden_size : LSTM在运行时里面的维度。隐藏层状态的维数,即隐藏层节点的个数torch里的LSTM单元接受的输入都必须是3维的张量(Tensors):第一维体现的每个句子的长度,即提供给LSTM神经元的每个句子的长度,如果是其他的带有带有序列形式的数据,则表示一个明确分割单位长度,第二维度体现的是batch_size,即每一次给网络句子条数第三维体现的是输入的元素,即每个具体的单词用多少维向量来表示'''lstm_out, self.hidden = self.lstm(embeds, self.hidden)lstm_out = lstm_out.view(len(sentence), self.hidden_dim)lstm_feats = self.hidden2tag(lstm_out)return lstm_featsdef _score_sentence(self, feats, tags):"""CRF 的输出,即 emit + transition scores"""# 计算给定 tag 序列的分数,即一条路径的分数score = torch.zeros(1).to(self.device)tags = torch.cat([torch.tensor([self.tag_to_index[START_TAG]], dtype=torch.long).to(self.device), tags])# 转移 + 前向for i, feat in enumerate(feats):# 递推计算路径分数:转移分数 + 发射分数score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]score = score + self.transitions[self.tag_to_index[STOP_TAG], tags[-1]]return scoredef _forward_alg(self, feats):  # 预测序列的得分,就是 Loss 的右边第一项"""前向算法:feats 表示发射矩阵(emit score),是 Bi-LSTM 所有时间步的输出 意思是经过 Bi-LSTM 的 sentence 的每个 word 对应于每个 label 的得分"""# 通过前向算法递推计算 alpha 初始为 -10000init_alphas = torch.full((1, self.target_size), -10000.).to(self.device)  # 用-10000.来填充一个形状为[1,target_size]的tensor# 初始化 step 0 即 START 位置的发射分数,START_TAG 取 0 其他位置取 -10000  start 位置的 alpha 为 0# 因为 start tag 是4,所以tensor([[-10000., -10000., -10000., 0., -10000.]]),# 将 start 的值为零,表示开始进行网络的传播,init_alphas[0][self.tag_to_index[START_TAG]] = 0.# 将初始化 START 位置为 0 的发射分数赋值给 previous  包装进变量,实现自动反向传播previous = init_alphas# 迭代整个句子for obs in feats:# The forward tensors at this timestep# 当前时间步的前向 tensoralphas_t = []for next_tag in range(self.target_size):# 取出当前tag的发射分数,与之前时间步的tag无关'''Bi-LSTM 生成的矩阵是 emit score[观测/发射概率], 即公式中的H()函数的输出CRF 是判别式模型emit score: Bi-LSTM 对序列中每个位置的对应标签打分的和transition score: 是该序列状态转移矩阵中对应的和Score = EmissionScore + TransitionScore'''# Bi-LSTM的生成矩阵是 emit_score,维度为 1*5emit_score = obs[next_tag].view(1, -1).expand(1, self.target_size).to(self.device)# 取出当前 tag 由之前 tag 转移过来的转移分数trans_score = self.transitions[next_tag].view(1, -1)# 当前路径的分数:之前时间步分数 + 转移分数 + 发射分数next_tag_var = previous.to(self.device) + trans_score.to(self.device) + emit_score.to(self.device)# 对当前分数取 log-sum-expalphas_t.append(log_sum_exp(next_tag_var, self.device).view(1))# 更新 previous 递推计算下一个时间步previous = torch.cat(alphas_t).view(1, -1)# 考虑最终转移到 STOP_TAGterminal_var = previous + self.transitions[self.tag_to_index[STOP_TAG]]# 计算最终的分数scores = log_sum_exp(terminal_var, self.device)return scores.to(self.device)def _viterbi_decode(self, feats):"""Decoding的意义:给定一个已知的观测序列,求其最有可能对应的状态序列"""# 预测序列的得分,维特比解码,输出得分与路径值backpointers = []# 初始化 viterbi 的 previous 变量init_vvars = torch.full((1, self.target_size), -10000.).cpu()  # 这就保证了一定是从START到其他标签init_vvars[0][self.tag_to_index[START_TAG]] = 0# 第 i 步的 forward_var 保存第 i-1 步的维特比变量previous = init_vvarsfor obs in feats:# 保存当前时间步的回溯指针bptrs_t = []# 保存当前时间步的 viterbi 变量viterbivars_t = []for next_tag in range(self.target_size):# 其他标签(B,I,E,Start,End)到标签next_tag的概率# 维特比算法记录最优路径时只考虑上一步的分数以及上一步 tag 转移到当前 tag 的转移分数# 并不取决与当前 tag 的发射分数next_tag_var = previous.cpu() + self.transitions[next_tag].cpu()  # previous 保存的是之前的最优路径的值# 找到此刻最好的状态转入点best_tag_id = argmax(next_tag_var)  # 返回最大值对应的那个tag# 记录点bptrs_t.append(best_tag_id)viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))# 更新 previous,加上当前 tag 的发射分数 obs# 从 step0 到 step(i-1) 时 5 个序列中每个序列的最大 scoreprevious = (torch.cat(viterbivars_t).cpu() + obs.cpu()).view(1, -1)# 回溯指针记录当前时间步各个 tag 来源前一步的 tagbackpointers.append(bptrs_t)# 考虑转移到 STOP_TAG 的转移分数# 其他标签到STOP_TAG的转移概率terminal_var = previous.cpu() + self.transitions[self.tag_to_index[STOP_TAG]].cpu()best_tag_id = argmax(terminal_var)path_score = terminal_var[0][best_tag_id]# 通过回溯指针解码出最优路径best_path = [best_tag_id]# best_tag_id 作为线头,反向遍历 backpointers 找到最优路径for bptrs_t in reversed(backpointers):best_tag_id = bptrs_t[best_tag_id]best_path.append(best_tag_id)# 去除 START_TAGstart = best_path.pop()assert start == self.tag_to_index[START_TAG]  # Sanity checkbest_path.reverse()  # 把从后向前的路径正过来return path_score, best_pathdef neg_log_likelihood(self, sentence, tags):# CRF 损失函数由两部分组成,真实路径的分数和所有路径的总分数。# 真实路径的分数应该是所有路径中分数最高的。# log 真实路径的分数/log所有可能路径的分数,越大越好,构造 crf loss 函数取反,loss 越小越好feats = self._get_lstm_features(sentence)  # 经过LSTM+Linear后的输出作为CRF的输入# 前向算法分数forward_score = self._forward_alg(feats)  # loss的log部分的结果# 真实分数gold_score = self._score_sentence(feats, tags)  # loss的后半部分S(X,y)的结果# log P(y|x) = forward_score - gold_scorereturn forward_score - gold_score# 这里 Bi-LSTM 和 CRF 共同前向输出def forward(self, sentence):"""重写原 module 里的 forward"""sentence = sentence.reshape(-1)# 通过 Bi-LSTM 提取发射分数lstm_feats = self._get_lstm_features(sentence)# 根据发射分数以及转移分数,通过 viterbi 解码找到一条最优路径score, tag_seq = self._viterbi_decode(lstm_feats)return score, tag_seq

模型预测

注:模型只训练了一轮,预测结果与实际会有差异。
方法一:
在这里插入图片描述

图3 方法1预测结果
方法二:

在这里插入图片描述

图4 方法2预测结果

源码获取

Bi-LSTM-CRF 实体识别

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/7920.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开源项目注意事项

fork项目后,记得另外开启一个分支然后在新分支上进行开发,push到仓库后从分支往原项目提交。 否则会出现Partially verified(导致提交pr后auto-merge失败) 注意git提交操作 https://blog.csdn.net/sonichenn/article/details/13…

flask中的werkzeug介绍

flask中的werkzeug Werkzeug是一个Python库,用于开发Web应用程序。它是一个WSGI(Web Server Gateway Interface)工具包,提供了一系列实用功能来帮助开发者处理HTTP请求、响应、URLs等等。Werkzeug的设计非常灵活,可以…

请问学JavaScript 前要学html 和css 吗?

前言 html和css可以理解为是一个网站的骨架和皮肤,这两部分做好后整个网站的外观展示的完成度基本就有了个90%左右,所以在学习js前是需要学习html和css 的,这两部分不用花特别多的时间(虽然css如果想做一些非常炫酷的效果个人认为…

vue中重新获取数据导致页面加长,要求在页面更新之后浏览器滚动条滚动到之前浏览记录的位置。以及获取当前页面中是哪个元素产生滚动条的方法。

目前的页面样式为&#xff1a; 代码是&#xff1a; <section id"detailSection"><el-tableref"multipleTable":data"logDetailList"style"width: 650px;margin:20px auto;"id"dialogDetail":show-header"fals…

App测试流程及测试点

1 APP测试基本流程 1.1流程图 1.2测试周期 测试周期可按项目的开发周期来确定测试时间&#xff0c;一般测试时间为两三周&#xff08;即15个工作日&#xff09;&#xff0c;根据项目情况以及版本质量可适当缩短或延长测试时间。正式测试前先向主管确认项目排期。 1.3测试资源…

启动es容器错误

说明&#xff1a;启动es容器&#xff0c;刚启动就停止&#xff0c;查看日志&#xff0c;出现以下错误信息&#xff08;java.lang.IllegalArgumentException: Plugin [analysis-ik] was built for Elasticsearch version 8.8.2 but version 7.12.1 is running&#xff09; 解决&…

【2023】HashMap详细源码分析解读

前言 在弄清楚HashMap之前先介绍一下使用到的数据结构&#xff0c;在jdk1.8之后HashMap中为了优化效率加入了红黑树这种数据结构。 树 在计算机科学中&#xff0c;树&#xff08;英语&#xff1a;tree&#xff09;是一种抽象数据类型&#xff08;ADT&#xff09;或是实作这种…

数据结构【栈和队列】

第三章 栈与队列 一、栈 1.定义&#xff1a;只允许一端进行插入和删除的线性表&#xff0c;结构与手枪的弹夹差不多&#xff0c;可以作为实现递归函数&#xff08;调用和返回都是后进先出&#xff09;调用的一种数据结构&#xff1b; 栈顶&#xff1a;允许插入删除的那端&…

网络知识点之-BGP协议

边界网关协议&#xff08;BGP&#xff09;是运行于 TCP 上的一种自治系统的路由协议。 BGP 是唯一一个用来处理像因特网大小的网络的协议&#xff0c;也是唯一能够妥善处理好不相关路由域间的多路连接的协议。 BGP 构建在 EGP 的经验之上。 BGP 系统的主要功能是和其他的 BGP 系…

特征选择策略:为检测乳腺癌生物标志物寻找新出口

内容一览&#xff1a;microRNA&#xff08;小分子核糖核酸&#xff09;是一类短小的单链非编码 RNA 转录体。这些分子在多种恶性肿瘤中呈现失控性生长&#xff0c;因此近年来被诸多研究确定为确诊癌症的可靠的生物标志物 (biomarker)。在多种病理分析中&#xff0c;差异表达分析…

vue3下的uniapp跨域踩坑

uniapp vue3 H5跨域踩坑 开发移动端H5的时候由于后端接口没有做跨域处理&#xff0c;因此需要做下服务器代理&#xff0c;于是百度搜索了uniapp下h5的跨域配置 在manifest下的h5配置proxy&#xff0c;大概是这样: "h5": {"devServer": {"https"…

安全—01day

文章目录 1. 编码1.1 ASCLL编码1.2 URL编码1.3 Unicode编码1.4 HTML编码1.5 Base64编码 2. form表单2.1 php接收form表单2.2 python接收form表单 1. 编码 1.1 ASCLL编码 ASCII 是基于拉丁字母的一套电脑编码系统&#xff0c;主要用于显示现代英语和其他西欧语言。它是最通用的…

ajax/axios访问后端测试方法

文章目录 1、浏览器执行javascript方法GET请求POST请求 2、Postman测试工具GET请求POST请求 3、idea IDE提供的httpclient4、Apache JMeter 1、浏览器执行javascript方法 GET请求 http://localhost:6060/admin/get/123 POST请求 技巧&#xff1a;打开谷歌浏览器&#xff0c…

C数据结构与算法——队列 应用(C语言纯享版 迷宫)

实验任务 (1) 掌握顺序循环队列及其C语言的表示&#xff1b; (2) 掌握入队、出队等基本算法的实现&#xff1b; (3) 掌握顺序循环队列的基本应用&#xff08;求解迷宫通路&#xff09;。 实验内容 使用C语言实现顺序循环队列的类型定义与算法函数&#xff1b;编写main()函数…

算法与数据结构(三)--栈

一.栈的基本概念 栈是一种特殊的表&#xff0c;这种表只在表首进行插入和删除操作。 因此&#xff0c;表首对于栈来说具有特殊的意义&#xff0c;称为栈顶。相应的&#xff0c;表尾称为栈底。不含任何元素的栈称为空栈。 栈的修改遵循后进先出的原则&#xff0c;Last In First…

Zabbix邮件报警(163网易邮箱)

目录 一、电脑登录网易邮箱配置 二、Server端安装配置邮件服务器 邮箱查看 三、编辑zabbix_server.conf 引用邮件脚本 查看邮件 五、配置zabbix web监控项邮件报警 操作思路 Server.zabbix.com web操作 确认报警媒介信息 配置zabbix中的用户所使用的报警媒介类型以及接收邮…

【网络】HTTPS协议

目录 一、概念 1、HTTPS 2、加密解密 3、加密的必要性 4、常见的加密方式 4.1、对称加密 4.2、非对称加密 5、数据摘要 && 数据指纹 6、数字签名 二、HTTPS的工作过程 1、只使用对称加密 2、只使用非对称加密 3、双方都使用非对称加密 4、非对称加密 对…

rust gtk 桌面应用 demo

《精通Rust》里介绍了 GTK框架的开发&#xff0c;这篇博客记录并扩展一下。rust 可以用于桌面应用开发&#xff0c;我还挺惊讶的&#xff0c;大学的时候也有学习过 VC&#xff0c;对桌面编程一直都很感兴趣&#xff0c;而且一直有一种妄念&#xff0c;总觉得自己能开发一款很好…

深入学习 Redis - 深挖经典数据类型之 set

目录 前言 一、Set 类型 1.1、操作命令 sadd / smembers&#xff08;添加&#xff09; sismember&#xff08;判断存在&#xff09; scard&#xff08;获取元素个数&#xff09; spop&#xff08;删除元素&#xff09; smove&#xff08;移动&#xff09; srem&#x…

Golang time 包以及日期函数

time 包 在 golang 中 time 包提供了时间的显示和测量用的函数。 time.Now()获取当前时间 可以通过 time.Now()函数获取当前的时间对象&#xff0c;然后获取时间对象的年月日时分秒等信息。 示例代码如下&#xff1a; package mainimport ("fmt""time" )…