NLP(20)--知识图谱+实体抽取

前言

仅记录学习过程,有问题欢迎讨论

基于LLM的垂直领域问答方案:

  • 特点:不是通用语料;准确度要求高,召回率可以低(转人工);拓展性和可控性(改变特定内容的回答);确切的评价指标

实现:
传统方法:

  • 知识库+文本匹配 (问题转向量找知识库的相似度最高的问题来回答)

基于LLM:

1.直接生成:

  • 直接使用LLM获答案,然后使用SFT微调
  • 缺点:fine-tune的成本高;模型的泛用性下降;不易调整数据

2.RAG思路(推荐):

  • 段落召回+阅读理解(基于获取的信息给模型提示,期望获取一个更好的答案)
  • 召回的就是你垂直领域的内容,去给llm提示
  • 缺点:对LLM要求高,受召回算法限制(如果正确答案被丢弃,LLM无法挽回);生成结果不可控

3.基于知识体系(图谱)

  • 树形结构设置知识体系结构,然后给LLM去匹配最可能的知识点选择项,一级一级往下走
  • 缺点:需要大量标注数据,标注数据需要人工标注,标注成本高

知识图谱:

  • 知识图谱是图数据库的一种,用于存储和表示知识。
  • 如:姚明-身高-226cm(三元组)
  • 知识图谱的构建取决于你想要完成的任务,如果想要获取实体之间的关系,就要构建实体关系的图谱
以下为知识图谱的具体内容:

实体抽取:ner任务获取实体属性
关系抽取:

  • 限定领域:
    • 文本+实体送入模型 预测关系(本质还是分类任务)
    • 可以同时训练实体抽取和关系抽取,loss为二者相加
  • 开放领域:
    基于序列标注 (NER)

属性抽取:同关系抽取

知识融合:

  • 实体对齐:通过判断不同来源的属性的相似度
  • 实体消歧:根据上下文和语义关系进行实体消歧
  • 属性对齐:属性和属性值的相似度计算

知识推理:通过模型来推断两个实体的关系
知识表示:实体关系属性都转化为向量,都可以用id表示某个信息

图数据库的使用:

noe4j
使用NL2SQL把输入的文字变为sql查询

  • 方法1:基于模版+文本匹配,输入的文本去匹配对应的问题模版–再去匹配SQL(依赖模版,易于拓展,可以写复杂sql)
  • 方法2:semantic parsing(语义解析)–通过训练多个模型来获取sql(不易于拓展)
  • 方法3:用LLM写sql

代码展示

构建基于neo4j的知识图谱问答:
这里采用的是方法1,依赖问题模版。

import re
import json
import pandas
import itertools
from py2neo import Graph"""
使用neo4j 构建基于知识图谱的问答
需要自定义问题模板
"""class GraphQA:def __init__(self):# 启动neo4j neo4j consoleself.graph = Graph("http://localhost:7474", auth=("neo4j", "password"))schema_path = "kg_schema.json"templet_path = "question_templet.xlsx"self.load(schema_path, templet_path)print("知识图谱问答系统加载完毕!\n===============")# 对外提供问答接口def query(self, sentence):print("============")print(sentence)# 对输入的句子找和模板中最匹配的问题info = self.parse_sentence(sentence)  # 信息抽取print("info:", info)# 匹配模板templet_cypher_score = self.cypher_match(sentence, info)  # cypher匹配for templet, cypher, score, answer in templet_cypher_score:graph_search_result = self.graph.run(cypher).data()# 最高分命中的模板不一定在图上能找到答案, 当不能找到答案时,运行下一个搜索语句, 找到答案时停止查找后面的模板if graph_search_result:answer = self.parse_result(graph_search_result, answer)return answerreturn None# 加载模板def load(self, schema_path, templet_path):self.load_kg_schema(schema_path)self.load_question_templet(templet_path)return# 加载模板信息def load_question_templet(self, templet_path):dataframe = pandas.read_excel(templet_path)self.question_templet = []for index in range(len(dataframe)):question = dataframe["question"][index]cypher = dataframe["cypher"][index]cypher_check = dataframe["check"][index]answer = dataframe["answer"][index]self.question_templet.append([question, cypher, json.loads(cypher_check), answer])return# 返回输入中的实体,关系,属性def parse_sentence(self, sentence):# 先提取实体,关系,属性entitys = self.get_mention_entitys(sentence)relations = self.get_mention_relations(sentence)labels = self.get_mention_labels(sentence)attributes = self.get_mention_attributes(sentence)# 然后根据模板进行匹配return {"%ENT%": entitys,"%REL%": relations,"%LAB%": labels,"%ATT%": attributes}# 获取问题中谈到的实体,可以使用基于词表的方式,也可以使用NER模型def get_mention_entitys(self, sentence):return re.findall("|".join(self.entity_set), sentence)# 获取问题中谈到的关系,也可以使用各种文本分类模型def get_mention_relations(self, sentence):return re.findall("|".join(self.relation_set), sentence)# 获取问题中谈到的属性def get_mention_attributes(self, sentence):return re.findall("|".join(self.attribute_set), sentence)# 获取问题中谈到的标签def get_mention_labels(self, sentence):return re.findall("|".join(self.label_set), sentence)# 加载图谱信息def load_kg_schema(self, path):with open(path, encoding="utf8") as f:schema = json.load(f)self.relation_set = set(schema["relations"])self.entity_set = set(schema["entitys"])self.label_set = set(schema["labels"])self.attribute_set = set(schema["attributes"])return# 匹配模板的问题def cypher_match(self, sentence, info):# 根据提取到的实体,关系等信息,将模板展开成待匹配的问题文本templet_cypher_pair = self.expand_question_and_cypher(info)result = []for templet, cypher, answer in templet_cypher_pair:# 求相似度 距离函数score = self.sentence_similarity_function(sentence, templet)# print(sentence, templet, score)result.append([templet, cypher, score, answer])# 取最相似的result = sorted(result, reverse=True, key=lambda x: x[2])return result# 根据提取到的实体,关系等信息,将模板展开成待匹配的问题文本def expand_question_and_cypher(self, info):templet_cypher_pair = []# 模板的数据for templet, cypher, cypher_check, answer in self.question_templet:# 匹配模板cypher_check_result = self.match_cypher_check(cypher_check, info)if cypher_check_result:templet_cypher_pair += self.expand_templet(templet, cypher, cypher_check, info, answer)return templet_cypher_pair# 校验 减少比较次数def match_cypher_check(self, cypher_check, info):for key, required_count in cypher_check.items():if len(info.get(key, [])) < required_count:return Falsereturn True# 对于单条模板,根据抽取到的实体属性信息扩展,形成一个列表# info:{"%ENT%":["周杰伦", "方文山"], “%REL%”:[“作曲”]}def expand_templet(self, templet, cypher, cypher_check, info, answer):# 获取所有组合combinations = self.get_combinations(cypher_check, info)templet_cpyher_pair = []for combination in combinations:# 替换模板中的实体,属性,关系replaced_templet = self.replace_token_in_string(templet, combination)replaced_cypher = self.replace_token_in_string(cypher, combination)replaced_answer = self.replace_token_in_string(answer, combination)templet_cpyher_pair.append([replaced_templet, replaced_cypher, replaced_answer])return templet_cpyher_pair# 对于找到了超过模板中需求的实体数量的情况,需要进行排列组合# info:{"%ENT%":["周杰伦", "方文山"], “%REL%”:[“作曲”]}def get_combinations(self, cypher_check, info):slot_values = []for key, required_count in cypher_check.items():# 生成所有组合slot_values.append(itertools.combinations(info[key], required_count))value_combinations = itertools.product(*slot_values)combinations = []for value_combination in value_combinations:combinations.append(self.decode_value_combination(value_combination, cypher_check))return combinations# 将提取到的值分配到键上def decode_value_combination(self, value_combination, cypher_check):res = {}for index, (key, required_count) in enumerate(cypher_check.items()):if required_count == 1:res[key] = value_combination[index][0]else:for i in range(required_count):key_num = key[:-1] + str(i) + "%"res[key_num] = value_combination[index][i]return res# 将带有token的模板替换成真实词# string:%ENT1%和%ENT2%是%REL%关系吗# combination: {"%ENT1%":"word1", "%ENT2%":"word2", "%REL%":"word"}def replace_token_in_string(self, string, combination):for key, value in combination.items():string = string.replace(key, value)return string# 求相似度 距离函数 Jaccard相似度def sentence_similarity_function(self, sentence1, sentence2):# print("计算  %s %s"%(string1, string2))jaccard_distance = len(set(sentence1) & set(sentence2)) / len(set(sentence1) | set(sentence2))return jaccard_distance# 解析结果def parse_result(self, graph_search_result, answer):graph_search_result = graph_search_result[0]# 关系查找返回的结果形式较为特殊,单独处理if "REL" in graph_search_result:graph_search_result["REL"] = list(graph_search_result["REL"].types())[0]answer = self.replace_token_in_string(answer, graph_search_result)return answerif __name__ == '__main__':graph = GraphQA()res = graph.query("谁导演的不能说的秘密")print(res)res = graph.query("发如雪的谱曲是谁")print(res)

代码2

通过联合训练同时实现 属性和实体(关系)的抽取
config

"""
配置参数信息
"""
Config = {"model_path": "./output/","model_name": "model.pt","schema_path": r"schema.json","train_data_path": r"D:\NLP\video\第十四周\week13 知识问答\triplet_data\train_triplet_data.json","valid_data_path": r"D:\NLP\video\第十四周\week13 知识问答\triplet_data\valid_triplet_data.json","vocab_path": r"D:\NLP\video\第九周 序列标注任务\课件\week9 序列标注问题\ner\chars.txt","model_type": "lstm",# 文本向量大小"char_dim": 20,# 文本长度"max_len": 200,# 词向量大小"hidden_size": 64,# 训练 轮数"epoch_size": 15,# 批量大小"batch_size": 32,# 训练集大小"simple_size": 500,# 学习率"lr": 1e-3,# dropout"dropout": 0.5,# 优化器"optimizer": "adam",# 卷积核"kernel_size": 3,# 最大池 or 平均池"pooling_style": "avg",# 模型层数"num_layers": 3,"bert_model_path": r"D:\NLP\video\第六周\bert-base-chinese",# 输出层大小"output_size": 9,# 随机数种子"seed": 987,"attribute_loss_ratio": 0.01
}

loader

"""
数据加载
"""
import os
import numpy as np
import json
import re
import os
import torch
import torch.utils.data as Data
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer# 获取字表集
def load_vocab(path):vocab = {}with open(path, 'r', encoding='utf-8') as f:for index, line in enumerate(f):word = line.strip()# 0留给padding位置,所以从1开始vocab[word] = index + 1vocab['unk'] = len(vocab) + 1return vocabclass DataGenerator:def __init__(self, data_path, config):self.data_path = data_pathself.config = configself.schema = self.load_schema(config["schema_path"])if self.config["model_type"] == "bert":self.tokenizer = BertTokenizer.from_pretrained(config["bert_model_path"])self.vocab = load_vocab(config["vocab_path"])self.config["vocab_size"] = len(self.vocab)self.max_len = config["max_len"]# 实体标签self.bio_schema = {"B_object": 0,"I_object": 1,"B_value": 2,"I_value": 3,"O": 4}self.config["label_count"] = len(self.bio_schema)self.config["attribute_count"] = len(self.schema)# 中文的语句list# self.sentence_list = []self.load_data()def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]def load_schema(self, path):with open(path, encoding="utf8") as f:return json.load(f)def load_data(self):# 储存实体数据 测试使用self.text_data = []# 处理完的数据self.data = []# 是否超长self.exceed_max_length = 0with open(self.data_path, 'r', encoding='utf-8') as f:for line in f:data = json.loads(line.strip())# 文本sentence = data["context"]# 获取实体object = data["object"]attribute = data["attribute"]value = data["value"]if attribute not in self.schema:attribute = "UNRELATED"self.text_data.append([sentence, object, attribute, value])input_id, attribute_label, sentence_label = self.process_sentence(sentence, object, attribute, value)self.data.append([torch.LongTensor(input_id),torch.LongTensor([attribute_label]),torch.LongTensor(sentence_label)])return data# 文本预处理# 转化为向量def sentence_to_index(self, text):input_ids = []vocab = self.vocabif self.config["model_type"] == "bert":# 中文的文本转化为tokenizer的idinput_ids = self.tokenizer.encode(text, padding="max_length", truncation=True,max_length=self.config["max_len"], add_special_tokens=False)else:for char in text:input_ids.append(vocab.get(char, vocab['unk']))# 填充or裁剪input_ids = self.padding(input_ids)return input_ids# 数据预处理 裁剪or填充def padding(self, input_ids, padding_dot=0):length = self.config["max_len"]padded_input_ids = input_idsif len(input_ids) >= length:return input_ids[:length]else:padded_input_ids += [padding_dot] * (length - len(input_ids))return padded_input_ids# 处理数据def process_sentence(self, sentence, object, attribute, value):if len(sentence) > self.max_len:self.exceed_max_length += 1object_start = sentence.index(object)value_start = sentence.index(value)# 文本转向量input_id = self.sentence_to_index(sentence)# 属性转向量attribute_index = self.schema[attribute]# assert len(sentence) == len(input_id)# 构建实体和属性标签 在一句话中 一个变量就可以 初始化要用 Olabel = [self.bio_schema["O"]] * len(input_id)# print(sentence,"======",object,"=====",object_start)if object_start < self.max_len:label[object_start] = self.bio_schema["B_object"]# 标记实体for index in range(object_start + 1, object_start + len(object)):if index >= self.max_len:breaklabel[index] = self.bio_schema["I_object"]# 标记属性if value_start < self.max_len:label[value_start] = self.bio_schema["B_value"]for index in range(value_start + 1, value_start + len(value)):if index >= self.max_len:breaklabel[index] = self.bio_schema["I_value"]if len(input_id) != len(label):label = self.padding(label, -1)return input_id, attribute_index, label# 用torch自带的DataLoader类封装数据
def load_data_batch(data_path, config, shuffle=True):dg = DataGenerator(data_path, config)# DataLoader 类封装数据 dg除了data 还包含其他信息(后面需要使用)dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)return dlif __name__ == '__main__':from config import Configdg = DataGenerator(Config["train_data_path"], Config)print(len(dg))print(dg[0])

model

import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from transformers import BertModel
from torchcrf import CRF"""
建立网络模型结构
"""class TorchModel(nn.Module):def __init__(self, config):super(TorchModel, self).__init__()hidden_size = config["hidden_size"]vocab_size = config["vocab_size"] + 1output_size = config["output_size"]self.model_type = config["model_type"]num_layers = config["num_layers"]self.pooling_style = config["pooling_style"]# self.use_bert = config["use_bert"]# self.use_crf = config["use_crf"]self.emb = nn.Embedding(vocab_size + 1, hidden_size, padding_idx=0)if self.model_type == 'rnn':self.encoder = nn.RNN(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers,batch_first=True)elif self.model_type == 'lstm':# 双向lstm,输出的是 hidden_size * 2(num_layers 要写2)self.encoder = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers, bidirectional=True,batch_first=True)hidden_size = hidden_size * 2elif self.model_type == 'bert':self.encoder = BertModel.from_pretrained(config["bert_model_path"])# 需要使用预训练模型的hidden_sizehidden_size = self.encoder.config.hidden_sizeelif self.model_type == 'cnn':self.encoder = CNN(config)elif self.model_type == "gated_cnn":self.encoder = GatedCNN(config)elif self.model_type == "bert_lstm":self.encoder = BertLSTM(config)# 需要使用预训练模型的hidden_sizehidden_size = self.encoder.config.hidden_size# self.classify = nn.Linear(hidden_size, output_size)# 实体的分类输出self.bio_classifier = nn.Linear(hidden_size, config["label_count"])# 属性的分类输出self.attribute_classifier = nn.Linear(hidden_size, config["attribute_count"])self.pooling_style = config["pooling_style"]# self.crf_layer = CRF(output_size, batch_first=True)# 一起计算loss的权重self.attribute_loss_ratio = config["attribute_loss_ratio"]self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  # loss采用交叉熵损失def forward(self, x, attribute_target=None, bio_target=None):if self.model_type == 'bert':# 输入x为[batch_size, seq_len]# bert返回的结果是 (sequence_output, pooler_output)# sequence_output:batch_size, max_len, hidden_size# pooler_output:batch_size, hidden_sizex = self.encoder(x)[0]else:x = self.emb(x)x = self.encoder(x)# 判断x是否是tupleif isinstance(x, tuple):x = x[0]# 序列标注(单个字)bio_predict = self.bio_classifier(x)  # (batch_size, max_length, 5) (Head-B, Head-I, Tail-B, Tail-I, O)# 文本分类# 池化层if self.pooling_style == "max":# shape[1]代表列数,shape是行和列数构成的元组self.pooling_style = nn.MaxPool1d(x.shape[1])elif self.pooling_style == "avg":self.pooling_style = nn.AvgPool1d(x.shape[1])x = self.pooling_style(x.transpose(1, 2)).squeeze()# 属性分类attribute_predict = self.attribute_classifier(x)if bio_target is not None:bio_loss = self.loss(bio_predict.view(-1, bio_predict.shape[-1]), bio_target.view(-1))# attribute_loss = self.loss(attribute_predict,  attribute_target.view(-1))attribute_loss = self.loss(attribute_predict.view(x.shape[0], -1), attribute_target.view(-1))return bio_loss + self.attribute_loss_ratio * attribute_losselse:return attribute_predict, bio_predict# 优化器的选择
def choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["lr"]if optimizer == "adam":return Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return SGD(model.parameters(), lr=learning_rate)# 定义CNN模型
class CNN(nn.Module):def __init__(self, config):super(CNN, self).__init__()hidden_size = config["hidden_size"]kernel_size = config["kernel_size"]pad = int((kernel_size - 1) / 2)self.cnn = nn.Conv1d(hidden_size, hidden_size, kernel_size, bias=False, padding=pad)def forward(self, x):  # x : (batch_size, max_len, embeding_size)return self.cnn(x.transpose(1, 2)).transpose(1, 2)# 定义GatedCNN模型
class GatedCNN(nn.Module):def __init__(self, config):super(GatedCNN, self).__init__()self.cnn = CNN(config)self.gate = CNN(config)# 定义前向传播函数 比普通cnn多了一次sigmoid 然后互相卷积def forward(self, x):a = self.cnn(x)b = self.gate(x)b = torch.sigmoid(b)return torch.mul(a, b)# 定义BERT-LSTM模型
class BertLSTM(nn.Module):def __init__(self, config):super(BertLSTM, self).__init__()self.bert = BertModel.from_pretrained(config["bert_model_path"], return_dict=False)self.rnn = nn.LSTM(self.bert.config.hidden_size, self.bert.config.hidden_size, batch_first=True)def forward(self, x):x = self.bert(x)[0]x, _ = self.rnn(x)return x# if __name__ == "__main__":
#     from config import Config
#
#     Config["output_size"] = 2
#     Config["vocab_size"] = 20
#     Config["max_length"] = 5
#     Config["model_type"] = "bert"
#     Config["use_bert"] = True
#     # model = BertModel.from_pretrained(Config["bert_model_path"], return_dict=False)
#     x = torch.LongTensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
#     # sequence_output, pooler_output = model(x)
#     # print(x[1], type(x[2]), len(x[2]))
#
#     model = TorchModel(Config)
#     label = torch.LongTensor([0,1])
#     print(model(x, label))

main

# -*- coding: utf-8 -*-
import reimport torch
import os
import random
import os
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from loader import load_data_batch
from evaluate import Evaluator# [DEBUG, INFO, WARNING, ERROR, CRITICAL]logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)"""
模型训练主程序
"""
# 通过设置随机种子来复现上一次的结果(避免随机性)
seed = Config["seed"]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)def main(config):# 保存模型的目录if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])# 加载数据dataset = load_data_batch(config["train_data_path"], config)# 加载模型model = TorchModel(config)# 是否使用gpuif torch.cuda.is_available():logger.info("gpu可以使用,迁移模型至gpu")model.cuda()# 选择优化器optim = choose_optimizer(config, model)# 加载效果测试类evaluator = Evaluator(config, model, logger)for epoch in range(config["epoch_size"]):epoch += 1logger.info("epoch %d begin" % epoch)epoch_loss = []# 训练模型model.train()for index, batch_data in enumerate(dataset):if torch.cuda.is_available():batch_data = [d.cuda() for d in batch_data]# x, y = dataiter# 反向传播optim.zero_grad()input_id, attribute_index, label = batch_data     # 输入变化时这里需要修改,比如多输入,多输出的情况# 计算梯度loss = model(input_id, attribute_index,label)# 梯度更新loss.backward()# 优化器更新模型optim.step()# 记录损失epoch_loss.append(loss.item())logger.info("epoch average loss: %f" % np.mean(epoch_loss))# 测试模型效果evaluator.eval(epoch)model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)torch.save(model.state_dict(), model_path)return model, datasetif __name__ == "__main__":main(Config)

evaluate

"""
模型效果测试
"""
import re
from collections import defaultdictimport numpy as np
import torch
from loader import load_data_batchclass Evaluator:def __init__(self, config, model, logger):self.config = configself.model = modelself.logger = logger# 选择验证集合self.dataset = load_data_batch(config["valid_data_path"], config, shuffle=False)# 获取loader时的字符集self.bio_schema = self.dataset.dataset.bio_schemaself.attribute_schema = self.dataset.dataset.schemaself.text_data = self.dataset.dataset.text_data# 将index转换为标签self.index_to_label = dict((y, x) for x, y in self.attribute_schema.items())def eval(self, epoch):self.logger.info("开始测试第%d轮模型效果:" % epoch)# 测试模式self.model.eval()self.stats_dict = {"object_acc": 0, "attribute_acc": 0, "value_acc": 0, "full_match_acc": 0}# 遍历验证集for index, batch_data in enumerate(self.dataset):# 分割为一个batch 因为text_data是没分割batch的 所以评估数据集shuffle = Falsetext_data = self.text_data[index * self.config["batch_size"]: (index + 1) * self.config["batch_size"]]if torch.cuda.is_available():batch_data = [d.cuda() for d in batch_data]# 这里只取input_id (32*200)input_id = batch_data[0]  # 输入变化时这里需要修改,比如多输入,多输出的情况with torch.no_grad():attribute_pred, bio_pred = self.model(input_id)  # 不输入labels,使用模型当前参数进行预测self.write_stats(attribute_pred, bio_pred, text_data)self.show_stats()returndef write_stats(self, attribute_pred, bio_pred, text_data):attribute_pred = torch.argmax(attribute_pred, dim=-1)bio_pred = torch.argmax(bio_pred, dim=-1)for attribute_p, bio_p, info in zip(attribute_pred, bio_pred, text_data):context, object, attribute, value = infobio_p = bio_p.cpu().detach().tolist()# 获取实体的预测值pred_object, pred_value = self.decode1(bio_p, context)# 获取属性的预测值pred_attribute = self.index_to_label[int(attribute_p)]self.stats_dict["object_acc"] += int(pred_object == object)self.stats_dict["attribute_acc"] += int(pred_attribute == attribute)self.stats_dict["value_acc"] += int(pred_value == value)if pred_value == value and pred_attribute == attribute and pred_object == object:self.stats_dict["full_match_acc"] += 1return# 打印结果def show_stats(self):for key, value in self.stats_dict.items():self.logger.info("%s : %s " % (key, value / len(self.text_data)))self.logger.info("--------------------")return# 相当于截取对应的句子def decode1(self, labels, sentence):labels = "".join([str(i) for i in labels])pred_obj, pred_value = "", ""# * 表示0 or d多个for location in re.finditer("(01*)", labels):s, e = location.span()pred_obj = sentence[s:e]breakfor location in re.finditer("(23*)", labels):s, e = location.span()pred_value = sentence[s:e]breakreturn pred_obj, pred_valuedef decode(self, pred_label, context):pred_label = "".join([str(i) for i in pred_label])pred_obj = self.seek_pattern("01*", pred_label, context)pred_value = self.seek_pattern("23*", pred_label, context)return pred_obj, pred_valuedef seek_pattern(self, pattern, pred_label, context):pred_obj = re.search(pattern, pred_label)if pred_obj:s, e = pred_obj.span()pred_obj = context[s:e]else:pred_obj = ""return pred_obj

predict

# -*- coding: utf-8 -*-
import torch
import re
import json
import numpy as np
from collections import defaultdict
from config import Config
from model import TorchModel"""
模型效果测试
"""class SentenceLabel:def __init__(self, config, model_path):self.config = configself.bio_schema = {"B_object": 0,"I_object": 1,"B_value": 2,"I_value": 3,"O": 4}self.attribute_schema = json.load(open(config["schema_path"], encoding="utf8"))self.index_to_label = dict((y, x) for x, y in self.attribute_schema.items())self.config["label_count"] = len(self.bio_schema)self.config["attribute_count"] = len(self.attribute_schema)self.vocab = self.load_vocab(config["vocab_path"])self.config["vocab_size"] = len(self.vocab)self.model = TorchModel(config)self.model.load_state_dict(torch.load(model_path))self.model.eval()print("模型加载完毕!")# 加载字表或词表def load_vocab(self, path):vocab = {}with open(path, 'r', encoding='utf-8') as f:for index, line in enumerate(f):word = line.strip()# 0留给padding位置,所以从1开始vocab[word] = index + 1vocab['unk'] = len(vocab) + 1return vocabdef decode(self, attribute_label, bio_label, context):pred_attribute = self.index_to_label[int(attribute_label)]bio_label = "".join([str(i) for i in bio_label.detach().tolist()])pred_obj = self.seek_pattern("01*", bio_label, context)pred_value = self.seek_pattern("23*", bio_label, context)return pred_obj, pred_attribute, pred_valuedef seek_pattern(self, pattern, pred_label, context):pred_obj = re.search(pattern, pred_label)if pred_obj:s, e = pred_obj.span()pred_obj = context[s:e]else:pred_obj = ""return pred_objdef predict(self, sentence):input_id = []for char in sentence:input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))with torch.no_grad():attribute_pred, bio_pred = self.model(torch.LongTensor([input_id]))attribute_pred = torch.argmax(attribute_pred)bio_pred = torch.argmax(bio_pred[0], dim=-1)object, attribute, value = self.decode(attribute_pred, bio_pred, sentence)return object, attribute, valueif __name__ == "__main__":sl = SentenceLabel(Config, "output/epoch_15.pth")sentence = "浙江理工大学是一所以工为主,特色鲜明,优势突出,理、工、文、经、管、法、艺术、教育等多学科协调发展的省属重点建设大学。"res = sl.predict(sentence)print(res)sentence = "出露地层的岩石以沉积岩为主(其中最多为碳酸盐岩),在受到乌江的切割下,由内外力共同作用,形成沟壑密布、崎岖复杂的地貌景观。"res = sl.predict(sentence)print(res)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/20237.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

字符串 | 字符串匹配之 KMP 算法以及 C++ 代码实现

目录 1 为什么使用 KMP&#xff1f;2 什么是 next 数组&#xff1f;2.1 什么是字符串的前后缀&#xff1f;2.2 如何计算 next 数组&#xff1f; 3 KMP 部分的算法4 完整代码 &#x1f608;前言&#xff1a;这篇文章比较长&#xff0c;但我感觉自己是讲明白了的 1 为什么…

让低代码平台插上AI的翅膀 - 记开源驰骋AI平台升级

让低代码系统插上AI的翅膀——驰骋低代码开发平台引领新时代 在当今日新月异的科技世界中&#xff0c;人工智能&#xff08;AI&#xff09;已经成为各个行业不可或缺的一部分。从制造业的自动化生产到金融行业的智能风控&#xff0c;再到医疗领域的精准诊断&#xff0c;AI技术…

Kafka自定义分区器编写教程

1.创建java类MyPartitioner并实现Partitioner接口 点击灯泡选择实现方法&#xff0c;导入需要实现的抽象方法 2.实现方法 3.自定义分区器的使用 在自定义生产者消息发送时&#xff0c;属性配置上加入自定义分区器 properties.put(ProducerConfig.PARTITIONER_CLASS_CONFIG,&q…

基于STM32的轻量级Web服务器设计

文章目录 一、前言1.1 开发背景1.2 实现的功能1.3 硬件模块组成1.4 ENC28J60网卡介绍1.5 UIP协议栈【1】目标与特点【2】核心组件【3】应用与优势 1.6 添加UIP协议栈实现创建WEB服务器步骤1.7 ENC28J60添加UIP协议栈实现创建WEB客户端1.8 ENC28J60移植UIP协议并编写服务器测试示…

SD-WAN:企业网络转型的必然趋势

随着SD-WAN技术的不断进步和完善&#xff0c;越来越多的企业选择利用SD-WAN进行网络转型。根据IDC的研究&#xff0c;47%的企业已经成功迁移到SD-WAN&#xff0c;另有48%的公司计划在未来两个月内部署这一技术。 据Channel Futures报道&#xff0c;一位合作伙伴透露&#xff0c…

NetMizer 日志管理系统前台RCE漏洞

声明 本文仅用于技术交流&#xff0c;请勿用于非法用途 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;文章作者不为此承担任何责任。 一、产品介绍 NetMizer日志管理系统是一个与NetMizer流量管理设备配合…

矩阵1-范数与二重求和的求和可交换

矩阵1-范数与二重求和的求和可交换 1、矩阵1-范数 A [ a 11 a 12 ⋯ a 1 n a 21 a 22 ⋯ a 2 n ⋮ ⋮ ⋱ ⋮ a n 1 a n 2 ⋯ a n n ] A \begin{bmatrix} a_{11} &a_{12} &\cdots &a_{1n} \\ a_{21} &a_{22} &\cdots &a_{2n} \\ \vdots &\vdots …

【开源】Wordpress自定义鼠标样式插件

插件简介 使用此插件可一键自定义Wordpress前端鼠标指针样式。利用该插件&#xff0c;站长可以快速实现替换多种鼠标指针样式于网站前端。 鼠标图案均来自于网络&#xff0c;插件仅作收集整理&#xff0c;插件完全开源无任何商业性质。 插件截图 使用教程 下载插件文件 下载…

文件系统小册(FusePosixK8s csi)【2 Posix标准】

文件系统小册&#xff08;Fuse&Posix&K8s csi&#xff09;【2 Posix】 往期文章&#xff1a;文件系统小册&#xff08;Fuse&Posix&K8s csi&#xff09;【1 Fuse】 POSIX&#xff1a;可移植操作系统接口&#xff08;标准&#xff09; 1 概念 POSIX&#xff1a;…

前端Vue自定义支付密码输入框键盘与设置弹框组件的设计与实现

摘要 随着信息技术的不断发展&#xff0c;前端开发的复杂性日益加剧。传统的开发方式&#xff0c;即将整个系统构建为一个庞大的整体应用&#xff0c;往往会导致开发效率低下和维护成本高昂。任何微小的改动或新功能的增加都可能引发对整个应用逻辑的广泛影响&#xff0c;这种…

【原创】springboot+mysql医院预约挂号管理系统设计与实现

个人主页&#xff1a;程序猿小小杨 个人简介&#xff1a;从事开发多年&#xff0c;Java、Php、Python、前端开发均有涉猎 博客内容&#xff1a;Java项目实战、项目演示、技术分享 文末有作者名片&#xff0c;希望和大家一起共同进步&#xff0c;你只管努力&#xff0c;剩下的交…

【Redis】List源码剖析

大家好&#xff0c;我是白晨&#xff0c;一个不是很能熬夜&#xff0c;但是也想日更的人。如果喜欢这篇文章&#xff0c;点个赞&#x1f44d;&#xff0c;关注一下&#x1f440;白晨吧&#xff01;你的支持就是我最大的动力&#xff01;&#x1f4aa;&#x1f4aa;&#x1f4aa…

linux:命令别名,文件描述符及重定向

命令别名 命令别名是Shell提供的一种快捷方式&#xff0c;允许为命令创建简短的替代名称。&#xff0c;可以通过输入较短的别名来执行较长的命令&#xff0c;从而提高效率。 1.查看所有别名: [rootlocalhost ~]# alias 2.创建临时别名,当前会话关闭即清除 alias 别名完整命令…

游戏交易平台源码游戏帐号交易平台系统源码

功能介绍 1&#xff1a;后台可以添加删除游戏分类 2&#xff1a;会员中心可以出售游戏币,账号&#xff0c;装备 3&#xff1a;后台可以对会员和商品进行管理 4&#xff1a;多商家入驻,商家发布信息 5&#xff1a;手机版功能可以生成APP 6&#xff1a;在线支付可支持微信和支…

VQAScore开启文本到视觉生成评估新篇章

随着生成式人工智能技术的飞速发展&#xff0c;如何全面评估生成内容的质量和与输入提示的一致性成为了一个挑战。在图像-文本对齐领域&#xff0c;传统的评估方法如CLIPScore存在局限性&#xff0c;尤其是在处理涉及多个对象、属性和关系的复杂提示时。它们通常基于简单的词袋…

MES系统的功能、架构及应用价值

MES系统生产过程控制的主要方面涵盖了生产计划与控制、生产调度与排程、数据采集与监控、质量控制与管理、物料管理与控制以及设备管理与维护等多个方面。这些功能共同构成了MES系统的核心价值&#xff0c;帮助企业实现生产过程的数字化、智能化和精细化管理。 一、工厂使用MES…

【Oracle】修改已经存在的序列的当前值

前情提要 在oracle中一般使用序列来实现ID自增。但是oracle中序列维护的没有mysql那么好。只是单存的递增。 比如新建了一个序列&#xff0c;从1开始&#xff0c;每次递增1。此时我向数据库里插入一条id10的数据。那么在序列查询到10的时候&#xff0c;插入就会报错。 所以比较…

vue父组件如何向子组件传递数据?

Vue.js 中,父组件向子组件传递数据的主要方式是通过 props。具体步骤如下: 在父组件中定义要传递的数据:<!-- 父组件模板 --> <template><child-component :message"parentMessage"></child-component> </template><script> ex…

2024-05-29_二进制文件和文本文件作业

1.关于文本文件和二进制文件描述错误的是&#xff1f;&#xff08; &#xff09; A.文本文件是可以读懂的&#xff0c;二进制文件没办法直接读懂 B.数据在内存中以二进制的形式存储&#xff0c;如果不加转换的输出到外存&#xff0c;就是二进制文件 C.将内存中的数据转化成ASC…

Vue3-Setup-“集大成者”

何为Setup&#xff1a; setup是Vue3中一个新的配置项&#xff0c;值是一个函数&#xff0c;它是 Composition API 行为的根基&#xff0c;组件中所用到的&#xff1a;数据、方法、计算属性、钩子、自定义方法、自定义插槽、自定义Ref、监视......等等&#xff0c;均配置在setup…