BiLSTM-CRF的中文命名实体识别

项目地址:NLP-Application-and-Practice/11_BiLSTM-ner-bilstm-crf/11.3-BiLSTM-CRF的中文命名实体识别/ner_bilstm_crf at master · zz-zik/NLP-Application-and-Practice (github.com)

 读取renmindata.pkl文件

read_file_pkl.py

# encoding:utf-8import pickle# 读取数据
def load_data():pickle_path = './data_target_pkl/renmindata.pkl'with open(pickle_path, 'rb') as inp:word2id = pickle.load(inp)id2word = pickle.load(inp)tag2id = pickle.load(inp)id2tag = pickle.load(inp)x_train = pickle.load(inp)y_train = pickle.load(inp)x_test = pickle.load(inp)y_test = pickle.load(inp)x_valid = pickle.load(inp)y_valid = pickle.load(inp)print("train len:", len(x_train))print("test len:", len(x_test))print("valid len:", len(x_valid))return word2id, tag2id, x_train, x_test, x_valid, y_train, y_test, y_valid, id2tagdef main():word = load_data()print(len(word))if __name__ == '__main__':main()

这段代码定义了一个函数load_data(),用于读取存储在文件'../data_target_pkl/renminddata.pkl'中的数据。函数首先使用pickle模块打开文件,然后逐个加载文件中的数据并赋值给相应的变量。最后,打印出训练集、测试集和验证集的长度,并返回这些变量。在main()函数中,调用load_data()函数并打印其返回值。这段代码的目的是读取并加载pickle文件中的数据,并在main()函数中测试load_data()函数的正确性。 

构建BiLSTM-CRF

bilstm_crf_model.py

# encoding:utf-8import torch
import torch.nn as nn
from TorchCRF import CRF
from torch.utils.data import Dataset# 命名体识别数据
class NERDataset(Dataset):def __init__(self, X, Y, *args, **kwargs):self.data = [{'x': X[i], 'y': Y[i]} for i in range(X.shape[0])]def __getitem__(self, index):return self.data[index]def __len__(self):return len(self.data)# LSTM_CRF模型
class NERLSTM_CRF(nn.Module):def __init__(self, config):super(NERLSTM_CRF, self).__init__()self.embedding_dim = config.embedding_dimself.hidden_dim = config.hidden_dimself.vocab_size = config.vocab_sizeself.num_tags = config.num_tagsself.embeds = nn.Embedding(self.vocab_size, self.embedding_dim)self.dropout = nn.Dropout(config.dropout)self.lstm = nn.LSTM(self.embedding_dim,self.hidden_dim // 2,num_layers=1,bidirectional=True,batch_first=True,  # 该属性设置后,需要特别注意数据的形状)self.linear = nn.Linear(self.hidden_dim, self.num_tags)# CRF 层self.crf = CRF(self.num_tags)def forward(self, x, mask):embeddings = self.embeds(x)feats, hidden = self.lstm(embeddings)emissions = self.linear(self.dropout(feats))outputs = self.crf.viterbi_decode(emissions, mask)return outputsdef log_likelihood(self, x, labels, mask):embeddings = self.embeds(x)feats, hidden = self.lstm(embeddings)emissions = self.linear(self.dropout(feats))loss = -self.crf.forward(emissions, labels, mask)return torch.sum(loss)# ner chinese

这段代码定义了一个用于命名体识别的LSTM_CRF模型。NERDataset类是一个自定义的用于存储命名体识别数据的类,继承自torch.utils.data.Dataset。NERLSTM_CRF类是一个自定义的继承自torch.nn.Module的类,用于实现LSTM_CRF模型的前向传播和训练过程。该模型包含嵌入层、LSTM层、线性层和CRF层。通过调用log_likelihood方法可以计算给定输入序列的对数似然。 

模型信息

utils.py

# encoding:utf-8
import torch
from utils import load_data
from utils import parse_tags
from utils import utils_to_train
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import classification_reportword2id = load_data()[0]
max_epoch, device, train_data_loader, valid_data_loader, test_data_loader, optimizer, model = utils_to_train()# 中文命名体识别
class ChineseNER(object):def train(self):for epoch in range(max_epoch):# 训练模式model.train()for index, batch in enumerate(train_data_loader):# 梯度归零optimizer.zero_grad()# 训练数据-->gpux = batch['x'].to(device)mask = (x > 0).to(device)y = batch['y'].to(device)# 前向计算计算损失loss = model.log_likelihood(x, y, mask)# 反向传播loss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(parameters=model.parameters(),max_norm=10)# 更新参数optimizer.step()if index % 200 == 0:print('epoch:%5d,------------loss:%f' %(epoch, loss.item()))# 验证损失和精度aver_loss = 0preds, labels = [], []for index, batch in enumerate(valid_data_loader):# 验证模式model.eval()# 验证数据-->gpuval_x, val_y = batch['x'].to(device), batch['y'].to(device)val_mask = (val_x > 0).to(device)predict = model(val_x, val_mask)# 前向计算损失loss = model.log_likelihood(val_x, val_y, val_mask)aver_loss += loss.item()# 统计非0的,也就是真实标签的长度leng = []res = val_y.cpu()for i in val_y.cpu():tmp = []for j in i:if j.item() > 0:tmp.append(j.item())leng.append(tmp)for index, i in enumerate(predict):preds += i[:len(leng[index])]for index, i in enumerate(val_y.tolist()):labels += i[:len(leng[index])]# 损失值与评测指标aver_loss /= (len(valid_data_loader) * 64)precision = precision_score(labels, preds, average='macro')recall = recall_score(labels, preds, average='macro')f1 = f1_score(labels, preds, average='macro')report = classification_report(labels, preds)print(report)torch.save(model.state_dict(), 'params1.data_target_pkl')# 预测,输入为单句,输出为对应的单词和标签def predict(self, input_str=""):model.load_state_dict(torch.load("../models/ner/params1.data_target_pkl"))model.eval()if not input_str:input_str = input("请输入文本: ")input_vec = []for char in input_str:if char not in word2id:input_vec.append(word2id['[unknown]'])else:input_vec.append(word2id[char])# convert to tensorsentences = torch.tensor(input_vec).view(1, -1).to(device)mask = sentences > 0paths = model(sentences, mask)res = parse_tags(input_str, paths[0])return res# 在测试集上评判性能def test(self, test_dataloader):model.load_state_dict(torch.load("../models/ner/params1.data_target_pkl"))aver_loss = 0preds, labels = [], []for index, batch in enumerate(test_dataloader):# 验证模式model.eval()# 验证数据-->gpuval_x, val_y = batch['x'].to(device), batch['y'].to(device)val_mask = (val_x > 0).to(device)predict = model(val_x, val_mask)# 前向计算损失loss = model.log_likelihood(val_x, val_y, val_mask)aver_loss += loss.item()# 统计非0的,也就是真实标签的长度leng = []for i in val_y.cpu():tmp = []for j in i:if j.item() > 0:tmp.append(j.item())leng.append(tmp)for index, i in enumerate(predict):preds += i[:len(leng[index])]for index, i in enumerate(val_y.tolist()):labels += i[:len(leng[index])]# 损失值与评测指标aver_loss /= len(test_dataloader)precision = precision_score(labels, preds, average='macro')recall = recall_score(labels, preds, average='macro')f1 = f1_score(labels, preds, average='macro')report = classification_report(labels, preds)print(report)if __name__ == '__main__':cn = ChineseNER()cn.train()

这段代码定义了一个用于命名实体识别的模型和训练函数。其中,parse_tags函数用于将模型的预测结果解码成可读的实体类别;Config类定义了一些超参数;utils_to_train函数返回训练过程中需要用到的各种对象和参数。 

BiLSTM-CRF的训练

train.py

# encoding:utf-8
import torch
from utils import load_data
from utils import parse_tags
from utils import utils_to_train
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import classification_reportword2id = load_data()[0]
max_epoch, device, train_data_loader, valid_data_loader, test_data_loader, optimizer, model = utils_to_train()# 中文命名体识别
class ChineseNER(object):def train(self):for epoch in range(max_epoch):# 训练模式model.train()for index, batch in enumerate(train_data_loader):# 梯度归零optimizer.zero_grad()# 训练数据-->gpux = batch['x'].to(device)mask = (x > 0).to(device)y = batch['y'].to(device)# 前向计算计算损失loss = model.log_likelihood(x, y, mask)# 反向传播loss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(parameters=model.parameters(),max_norm=10)# 更新参数optimizer.step()if index % 200 == 0:print('epoch:%5d,------------loss:%f' %(epoch, loss.item()))# 验证损失和精度aver_loss = 0preds, labels = [], []for index, batch in enumerate(valid_data_loader):# 验证模式model.eval()# 验证数据-->gpuval_x, val_y = batch['x'].to(device), batch['y'].to(device)val_mask = (val_x > 0).to(device)predict = model(val_x, val_mask)# 前向计算损失loss = model.log_likelihood(val_x, val_y, val_mask)aver_loss += loss.item()# 统计非0的,也就是真实标签的长度leng = []res = val_y.cpu()for i in val_y.cpu():tmp = []for j in i:if j.item() > 0:tmp.append(j.item())leng.append(tmp)for index, i in enumerate(predict):preds += i[:len(leng[index])]for index, i in enumerate(val_y.tolist()):labels += i[:len(leng[index])]# 损失值与评测指标aver_loss /= (len(valid_data_loader) * 64)precision = precision_score(labels, preds, average='macro')recall = recall_score(labels, preds, average='macro')f1 = f1_score(labels, preds, average='macro')report = classification_report(labels, preds)print(report)torch.save(model.state_dict(), 'params1.data_target_pkl')# 预测,输入为单句,输出为对应的单词和标签def predict(self, input_str=""):model.load_state_dict(torch.load("../models/ner/params1.data_target_pkl"))model.eval()if not input_str:input_str = input("请输入文本: ")input_vec = []for char in input_str:if char not in word2id:input_vec.append(word2id['[unknown]'])else:input_vec.append(word2id[char])# convert to tensorsentences = torch.tensor(input_vec).view(1, -1).to(device)mask = sentences > 0paths = model(sentences, mask)res = parse_tags(input_str, paths[0])return res# 在测试集上评判性能def test(self, test_dataloader):model.load_state_dict(torch.load("../models/ner/params1.data_target_pkl"))aver_loss = 0preds, labels = [], []for index, batch in enumerate(test_dataloader):# 验证模式model.eval()# 验证数据-->gpuval_x, val_y = batch['x'].to(device), batch['y'].to(device)val_mask = (val_x > 0).to(device)predict = model(val_x, val_mask)# 前向计算损失loss = model.log_likelihood(val_x, val_y, val_mask)aver_loss += loss.item()# 统计非0的,也就是真实标签的长度leng = []for i in val_y.cpu():tmp = []for j in i:if j.item() > 0:tmp.append(j.item())leng.append(tmp)for index, i in enumerate(predict):preds += i[:len(leng[index])]for index, i in enumerate(val_y.tolist()):labels += i[:len(leng[index])]# 损失值与评测指标aver_loss /= len(test_dataloader)precision = precision_score(labels, preds, average='macro')recall = recall_score(labels, preds, average='macro')f1 = f1_score(labels, preds, average='macro')report = classification_report(labels, preds)print(report)if __name__ == '__main__':cn = ChineseNER()cn.train()

这段代码实现了一个中文命名体识别的训练和预测功能。通过加载数据和训练参数,使用循环神经网络模型进行训练和验证,计算损失和评估指标,然后在测试集上进行性能评估。最后,提供一个函数用于对输入文本进行预测,并返回预测结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/183315.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分享一些基于php商城案例

案例1: ​​​​​​http://www.9520.xin/ 案例2: http://ptll.hasbuy.com/ 案例3: http://likeshop.9520.xin/mobile 案例4: http://www.hasbuy.com/

Ubuntu Linux玩童年小霸王插卡游戏

1.下载安装模拟器 在Windows平台模拟器非常多,而且效果也很优秀,Linux平台的用户常常很羡慕,却因为系统的缘故,无法使用这样的模拟器,但是随着时代的发展,Linux平台也出现了许多优秀的模拟器,现…

CTF ssrf+pin

什么是pin码 pin码是flask在开启debug模式下,进行代码调试模式所需的进入密码,需要正确的PIN码才能进入调试模式,可以理解为自带的webshell pin码如何生成 pin码生成要六要素 1.username 在可以任意文件读的条件下读 /etc/passwd进行猜测 2.modname 默…

navigator.clipboard is undefined in JavaScript issue [Fixed]

navigator.clipboard 在不安全的网站是无法访问的。 在本地开发使用localhost或127.0.0.1没有这个问题。因为它不是不安全网站。 在现实开发中,可能遇到测试环境为不安全网站。 遇到这个问题,就需要将不安全网站标记为非不安全网站即可。 外网提供了3…

【HTML】VScode不打开浏览器实时预览html

1. 问题描述 预览HTML时,不想打开浏览器,想在VScode中直接实时预览 2. 解决方案 下载Microsoft官方的Live Preview 点击预览按钮即可预览

Unity中Shader优化通用规则

文章目录 前言一、精度优化1、三种精度 fixed / half / float2、位置坐标、物理坐标类使用float3、HDR颜色、方向向量类使用half4、普通纹理、颜色类使用 fixed5、实际上,使用的精度取决于 平台 和 GPU6、现在桌面级GPU都是直接采用 float , Shader中的 fixed / hal…

J2EE征程——第一个纯servletCURD

第一个纯servletCURD 前言在此之前 一,概述二、CURD1介绍2查询并列表显示准备实体类country编写 CountryListServlet配置web.xml为web应用导入mysql-jdbc的jar包 3增加准备增加的页面addc.html编写 CAddServlet配置web.xml测试 4删除修改CountryListServlet&#xf…

RabbitMQ消息模型之Routing-Topic

Routing Topic Topic类型的Exchange与Direct相比,都是可以根据RoutingKey把消息路由到不同的队列。只不过Topic类型Exchange可以让队列在绑定Routing key的时候使用通配符!这种模型Routingkey一般都是由一个或多个单词组成,多个单词之间以”…

ESP32-Web-Server编程- WebSocket 编程

ESP32-Web-Server编程- WebSocket 编程 概述 在前述 ESP32-Web-Server 实战编程-通过网页控制设备的 GPIO 中,我们创建了一个基于 HTTP 协议的 ESP32 Web 服务器,每当浏览器向 Web 服务器发送请求,我们将 HTML/CSS 文件提供给浏览器。 使用…

智能手表上的音频(四):语音通话

上篇讲了智能手表上音频文件播放。本篇开始讲语音通话。同音频播放一样有两种case:内置codec和BT。先看这两种case下audio data path,分别如下图: 内置codec下的语音通话audio data path 蓝牙下的语音通话audio data path 从上面两张图可以看…

享元模式 (Flyweight Pattern)

定义: 享元模式(Flyweight Pattern)是一种结构型设计模式,用于优化性能和内存使用。它通过共享尽可能多的相似对象来减少内存占用,特别是在有大量对象时。这种模式通常用于减少应用程序中对象的数量,并在多…

Redis 实战缓存

本篇概要: 1. 设置、查询、获取过期时间;2. 缓存穿透:设置空键;3. 封杀单ip;4. 封杀ip段;5. 缓存预热;6. 使用 hash 数据类型保存新闻的缓存,增加点击量;7. Sorted set&a…

纯js实现录屏并保存视频到本地的尝试

前言:先了解下:navigator.mediaDevices,mediaDevices 是 Navigator 只读属性,返回一个 MediaDevices 对象,该对象可提供对相机和麦克风等媒体输入设备的连接访问,也包括屏幕共享。 const media navigator…

【刷题】DFS

DFS 递归: 1.判断是否失败终止 2.判断是否成功终止,如果成功的,记录一个成果 3.遍历各种选择,在这部分可以进行剪枝 4.在每种情况下进行DFS,并进行回退。 199. 二叉树的右视图 给定一个二叉树的 根节点 root&#x…

深度学习之十二(图像翻译AI算法--UNIT(Unified Neural Translation))

概念 UNIT(Unified Neural Translation)是一种用于图像翻译的 AI 模型。它是一种基于生成对抗网络(GAN)的框架,用于将图像从一个域转换到另一个域。在图像翻译中,这意味着将一个风格或内容的图像转换为另一个风格或内容的图像,而不改变图像的内容或语义。 UNIT 的核心…

IDEA2022 Git 回滚及回滚内容恢复

IDEA2022 Git 回滚 ①选择要回滚的地方,右键选择 ②选择要回滚的模式 ③开始回滚 ④soft模式回滚的内容会保留在暂存区 ⑤输入git push -f ,然后重新提交,即可同步远程 注意观察IDEA右下角分支的标记,蓝色代表远程内容未同步到本…

数据结构 / day06 作业

1.下面的代码打印在屏幕上的值是多少? /下面的代码打印在屏幕上的值是多少?#include "stdio.h"int compute_data(int arr[], unsigned int len) {long long int result 0;if(result len)return arr[0];resultcompute_data(arr,--len);printf("len%d, res…

elementui中table进行表单验证

<el-form :model"ruleForm" ref"ruleForm" class"demo-ruleForm"><el-table :data"ruleForm.tableDataShou" border style"width: 100%;"><el-table-column type"index" label"序号" wi…

Android12源码分析

Android 12的源码结构与之前的版本类似&#xff0c;但也有一些新的变化和特性。以下是对Android 12源码结构的简要解析&#xff1a; 1. 系统源代码&#xff1a;这部分包含了整个Android操作系统的核心代码&#xff0c;包括Linux内核、系统库、运行时环境&#xff08;ART&#…

Flink源码解析零之重要名词的理解

名词解释 1)StreamGraph 根据用户通过 Stream API 编写的代码生成的最初的图。 (1)StreamNode 用来代表 operator 的类,并具有所有相关的属性,如并发度、入边和出边等。 (2)StreamEdge 表示连接两个StreamNode的边。 2)JobGraph StreamGraph经过优化后生成了 J…