Relation Extraction – TPLinker: https://blog.csdn.net/weixin_42223207/article/details/116425447
Tagging
TPLinker converts each annotated relation triple (subject, relation, object) into token-pair tags of three kinds:
(1) entity head to entity tail (EH-to-ET)
(2) subject head to object head (SH-to-OH)
(3) subject tail to object tail (ST-to-OT)
The figure below shows a tagging example: EH-to-ET links are drawn in purple, SH-to-OH links in red and ST-to-OT links in blue.
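To make the three link types concrete, here is a minimal sketch that turns one triple's token spans into links. The helper `link_spots` is hypothetical (not from the TPLinker repository); it assumes end-exclusive token spans like the dataset's `tok_span` fields, uses the tag values of the tagging scheme defined later in this post (1 = subject first, 2 = pair swapped into the upper triangle), and leaves out the relation id that the real relation matrices are indexed by.

```python
def link_spots(subj_span, obj_span):
    """Return (eh2et, sh2oh, st2ot) links for one (subject, relation, object) triple.

    Spans are token-level [start, end) pairs. Every link is (row, col, tag) with
    row <= col, so it lands in the upper triangle of the token-pair matrix.
    Entity links always use tag 1; relation links use tag 1 when the subject
    comes first and tag 2 when the pair had to be swapped.
    """
    s_head, s_tail = subj_span[0], subj_span[1] - 1  # inclusive tail positions
    o_head, o_tail = obj_span[0], obj_span[1] - 1
    eh2et = [(s_head, s_tail, 1), (o_head, o_tail, 1)]
    sh2oh = (s_head, o_head, 1) if s_head <= o_head else (o_head, s_head, 2)
    st2ot = (s_tail, o_tail, 1) if s_tail <= o_tail else (o_tail, s_tail, 2)
    return eh2et, sh2oh, st2ot
```

For a subject at tokens [0, 2) and an object at [5, 6), `link_spots([0, 2], [5, 6])` gives entity links (0, 1, 1) and (5, 5, 1), head link (0, 5, 1) and tail link (1, 5, 1); swapping subject and object hits the same cells with tag 2.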
Paper notes: Router TPLinker also does relation extraction: https://zhuanlan.zhihu.com/p/304104571
Relation extraction: TPLinker explained, with source code analysis: https://zhuanlan.zhihu.com/p/342300800
Decoding
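After entity extraction has produced a dictionary D that maps each entity head position to the entity spans starting there, the decoder iterates over the relations. For each relation, the ST-to-OT links give a set E of (subject tail, object tail) pairs. The SH-to-OH links give pairs of subject and object heads; looking these heads up in D yields the candidate subjects set(s) and candidate objects set(o), together with their tails. A candidate pair (s, o) is output as a triple only if its (subject tail, object tail) pair is in E.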
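The decoding steps above can be sketched in plain Python as follows. This is an illustration rather than the repository's decoder: `decode` is a hypothetical name, the links are assumed to have already been read off the predicted matrices, and tag-2 (swapped) relation links are assumed to have been flipped back into subject/object order.

```python
from collections import defaultdict

def decode(ent_links, head_links, tail_links):
    """Rebuild (subject_span, relation, object_span) triples from predicted links.

    ent_links:  iterable of (head, tail) inclusive token positions (EH-to-ET).
    head_links: iterable of (rel, subj_head, obj_head) (SH-to-OH).
    tail_links: iterable of (rel, subj_tail, obj_tail) (ST-to-OT).
    """
    # D: entity head position -> all entity spans starting at that position
    head2ents = defaultdict(list)
    for h, t in ent_links:
        head2ents[h].append((h, t))
    # E: (relation, subject tail, object tail) links
    tail_set = set(tail_links)
    triples = []
    for rel, sh, oh in head_links:
        for subj in head2ents.get(sh, []):       # candidate subjects, set(s)
            for obj in head2ents.get(oh, []):    # candidate objects, set(o)
                # keep the pair only if its tails are linked for the same relation
                if (rel, subj[1], obj[1]) in tail_set:
                    triples.append((subj, rel, obj))
    return triples
```

The tail check is what filters out a spurious candidate that shares a head with the real entity, for example a one-token span (0, 0) next to the entity (0, 1).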
The relevant formulas:
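For reference, the TPLinker paper scores every token pair with the handshaking kernel below, where h_i is the encoder output for token i and W_h, b_h, W_o, b_o are learned parameters. The `cat` shaking type (concatenation, a linear layer and tanh) and the per-task linear classifiers in the code later in this post follow this form; in training, the softmax is folded into the cross-entropy loss.

```latex
\begin{aligned}
h_{i,j} &= \tanh\left(W_h \cdot [h_i; h_j] + b_h\right), \quad j \ge i \\
P(y_{i,j}) &= \operatorname{Softmax}\left(W_o \cdot h_{i,j} + b_o\right) \\
\operatorname{link}(w_i, w_j) &= \arg\max_{l} P(y_{i,j} = l)
\end{aligned}
```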
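That covers the main ideas of the paper. Now for the code walkthrough:
Input data:
The input sequence length is seq_len. First, compute the maximum sentence length in tokens: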
```python
# train and valid max token num
max_tok_num = 0
all_data = train_data + valid_data
for sample in all_data:
    tokens = tokenize(sample["text"])
    max_tok_num = max(max_tok_num, len(tokens))
max_tok_num  # maximum sentence length
```
Texts longer than the limit are then split with a sliding window:
```python
if max_tok_num > hyper_parameters["max_seq_len"]:  # too long: split into new samples with a sliding window
    train_data = preprocessor.split_into_short_samples(train_data,
                                                       hyper_parameters["max_seq_len"],
                                                       sliding_len = hyper_parameters["sliding_len"],
                                                       encoder = config["encoder"])
    valid_data = preprocessor.split_into_short_samples(valid_data,
                                                       hyper_parameters["max_seq_len"],
                                                       sliding_len = hyper_parameters["sliding_len"],
                                                       encoder = config["encoder"])
```
Here is how the sliding window works:
```python
import copy
from tqdm import tqdm

def split_into_short_samples(self, sample_list, max_seq_len, sliding_len = 50, encoder = "BERT", data_type = "train"):
    new_sample_list = []
    for sample in tqdm(sample_list, desc = "Splitting into subtexts"):
        text_id = sample["id"]
        text = sample["text"]
        tokens = self._tokenize(text)
        tok2char_span = self._get_tok2char_span_map(text)  # character span of every token

        # sliding at token level
        split_sample_list = []
        for start_ind in range(0, len(tokens), sliding_len):  # sliding_len: window stride
            if encoder == "BERT":  # if use bert, do not split a word into two samples
                while "##" in tokens[start_ind]:
                    start_ind -= 1
            end_ind = start_ind + max_seq_len  # window end (exclusive)

            char_span_list = tok2char_span[start_ind:end_ind]  # tokens inside the window
            char_level_span = [char_span_list[0][0], char_span_list[-1][1]]  # from the first to the last token's characters
            sub_text = text[char_level_span[0]:char_level_span[1]]  # cut the original text

            new_sample = {
                "id": text_id,
                "text": sub_text,
                "tok_offset": start_ind,  # token-level offset of the window
                "char_offset": char_level_span[0],  # character-level offset of the window
            }
            if data_type == "test":  # test set
                if len(sub_text) > 0:
                    split_sample_list.append(new_sample)
            else:  # train or valid dataset, only save spo and entities in the subtext
                # spo
                sub_rel_list = []
                for rel in sample["relation_list"]:
                    subj_tok_span = rel["subj_tok_span"]
                    obj_tok_span = rel["obj_tok_span"]
                    # if subject and object are both in this subtext, add this spo to new sample
                    if subj_tok_span[0] >= start_ind and subj_tok_span[1] <= end_ind \
                            and obj_tok_span[0] >= start_ind and obj_tok_span[1] <= end_ind:
                        new_rel = copy.deepcopy(rel)
                        new_rel["subj_tok_span"] = [subj_tok_span[0] - start_ind, subj_tok_span[1] - start_ind]  # shift by the token offset
                        new_rel["obj_tok_span"] = [obj_tok_span[0] - start_ind, obj_tok_span[1] - start_ind]
                        new_rel["subj_char_span"][0] -= char_level_span[0]  # shift by the character offset
                        new_rel["subj_char_span"][1] -= char_level_span[0]
                        new_rel["obj_char_span"][0] -= char_level_span[0]
                        new_rel["obj_char_span"][1] -= char_level_span[0]
                        sub_rel_list.append(new_rel)

                # entity
                sub_ent_list = []
                for ent in sample["entity_list"]:
                    tok_span = ent["tok_span"]
                    # if entity in this subtext, add the entity to new sample
                    if tok_span[0] >= start_ind and tok_span[1] <= end_ind:
                        new_ent = copy.deepcopy(ent)
                        new_ent["tok_span"] = [tok_span[0] - start_ind, tok_span[1] - start_ind]
                        new_ent["char_span"][0] -= char_level_span[0]
                        new_ent["char_span"][1] -= char_level_span[0]
                        sub_ent_list.append(new_ent)

                # event
                if "event_list" in sample:
                    sub_event_list = []
                    for event in sample["event_list"]:
                        trigger_tok_span = event["trigger_tok_span"]
                        if trigger_tok_span[1] > end_ind or trigger_tok_span[0] < start_ind:
                            continue
                        new_event = copy.deepcopy(event)
                        new_arg_list = []
                        for arg in new_event["argument_list"]:
                            if arg["tok_span"][0] >= start_ind and arg["tok_span"][1] <= end_ind:
                                new_arg_list.append(arg)
                        new_event["argument_list"] = new_arg_list
                        sub_event_list.append(new_event)
                    new_sample["event_list"] = sub_event_list  # maybe empty

                new_sample["entity_list"] = sub_ent_list  # maybe empty
                new_sample["relation_list"] = sub_rel_list  # maybe empty
                split_sample_list.append(new_sample)

            # all segments covered, no need to continue
            if end_ind > len(tokens):
                break

        new_sample_list.extend(split_sample_list)
    return new_sample_list
```
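The window arithmetic can be isolated in a small sketch. `window_starts` and `shift_span` are hypothetical helpers that mirror the loop above, assuming BERT-style word pieces marked with "##"; they leave out the character offsets and the event handling.

```python
def window_starts(tokens, max_seq_len, sliding_len):
    """Return the (start, end) token windows visited by split_into_short_samples.

    As in the BERT branch, a window never starts on a "##" word piece: the start
    moves left until it reaches the first piece of the word.
    """
    windows = []
    for start in range(0, len(tokens), sliding_len):
        while "##" in tokens[start]:
            start -= 1
        end = start + max_seq_len  # exclusive end
        windows.append((start, end))
        if end > len(tokens):  # this window already reaches past the end of the text
            break
    return windows


def shift_span(span, start):
    """Re-express a [start, end) token span relative to a window beginning at `start`."""
    return [span[0] - start, span[1] - start]
```

With 8 tokens, max_seq_len = 4 and sliding_len = 3, the second window backs off from the word piece "##d" to start at index 2, and the loop stops after the window that runs past the end of the text.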
The model inputs are built in DataMaker4Bert:
```python
import torch
from tqdm import tqdm

class DataMaker4Bert():
    def __init__(self, tokenizer, handshaking_tagger):
        self.tokenizer = tokenizer
        self.handshaking_tagger = handshaking_tagger

    def get_indexed_data(self, data, max_seq_len, data_type = "train"):  # convert samples into indexed tensors
        indexed_samples = []
        for ind, sample in tqdm(enumerate(data), desc = "Generate indexed train or valid data"):
            text = sample["text"]
            # codes for bert input
            codes = self.tokenizer.encode_plus(text,
                                               return_offsets_mapping = True,
                                               add_special_tokens = False,
                                               max_length = max_seq_len,
                                               truncation = True,
                                               pad_to_max_length = True)  # deprecated in newer transformers; padding = "max_length" replaces it

            # tagging
            spots_tuple = None
            if data_type != "test":
                spots_tuple = self.handshaking_tagger.get_spots(sample)  # entity, head-link and tail-link spots

            # get codes
            input_ids = torch.tensor(codes["input_ids"]).long()
            attention_mask = torch.tensor(codes["attention_mask"]).long()
            token_type_ids = torch.tensor(codes["token_type_ids"]).long()
            tok2char_span = codes["offset_mapping"]

            sample_tp = (sample,
                         input_ids,
                         attention_mask,
                         token_type_ids,
                         tok2char_span,
                         spots_tuple,
                         )
            indexed_samples.append(sample_tp)
        return indexed_samples
```
It takes a tokenizer and a handshaking_tagger. The tokenizer produces the standard inputs for BERT-style models, while get_spots produces the entity, head-link and tail-link labels:
```python
def get_spots(self, sample):
    '''
    entity spot: (span_pos1, span_pos2, tag_id)
    head_rel and tail_rel spot: (rel_id, span_pos1, span_pos2, tag_id)
    '''
    ent_matrix_spots, head_rel_matrix_spots, tail_rel_matrix_spots = [], [], []

    for rel in sample["relation_list"]:
        subj_tok_span = rel["subj_tok_span"]
        obj_tok_span = rel["obj_tok_span"]
        # (head position, tail position, entity tag 1); tok_span ends are exclusive, hence the - 1
        ent_matrix_spots.append((subj_tok_span[0], subj_tok_span[1] - 1, self.tag2id_ent["ENT-H2T"]))  # subject
        ent_matrix_spots.append((obj_tok_span[0], obj_tok_span[1] - 1, self.tag2id_ent["ENT-H2T"]))  # object

        if subj_tok_span[0] <= obj_tok_span[0]:
            # (relation id, subject head, object head, tag 1: subject comes first)
            head_rel_matrix_spots.append((self.rel2id[rel["predicate"]], subj_tok_span[0], obj_tok_span[0], self.tag2id_head_rel["REL-SH2OH"]))
        else:
            # swapped to stay in the upper triangle: (relation id, object head, subject head, tag 2)
            head_rel_matrix_spots.append((self.rel2id[rel["predicate"]], obj_tok_span[0], subj_tok_span[0], self.tag2id_head_rel["REL-OH2SH"]))

        if subj_tok_span[1] <= obj_tok_span[1]:
            # (relation id, subject tail, object tail, tag 1)
            tail_rel_matrix_spots.append((self.rel2id[rel["predicate"]], subj_tok_span[1] - 1, obj_tok_span[1] - 1, self.tag2id_tail_rel["REL-ST2OT"]))
        else:
            # (relation id, object tail, subject tail, tag 2)
            tail_rel_matrix_spots.append((self.rel2id[rel["predicate"]], obj_tok_span[1] - 1, subj_tok_span[1] - 1, self.tag2id_tail_rel["REL-OT2ST"]))

    return ent_matrix_spots, head_rel_matrix_spots, tail_rel_matrix_spots
```
Build the indexed inputs (handshaking_tagger and max_seq_len are created in the next step):
```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"], add_special_tokens = False, do_lower_case = False)
data_maker = DataMaker4Bert(tokenizer, handshaking_tagger)  # each sample -> (sample, input_ids, attention_mask, token_type_ids, tok2char_span, spots_tuple)

indexed_train_data = data_maker.get_indexed_data(train_data, max_seq_len)  # indexed training inputs
# index_train_data = data_maker.get_indexed_data(train_test_data, max_seq_len)
indexed_valid_data = data_maker.get_indexed_data(valid_data, max_seq_len)
```
The HandshakingTaggingScheme is set up as follows:
```python
import json

max_seq_len = min(max_tok_num, hyper_parameters["max_seq_len"])  # effective max length
rel2id = json.load(open(rel2id_path, "r", encoding = "utf-8"))
handshaking_tagger = HandshakingTaggingScheme(rel2id = rel2id, max_seq_len = max_seq_len)  # initialize
```
Its definition:
```python
class HandshakingTaggingScheme(object):
    """docstring for HandshakingTaggingScheme"""
    def __init__(self, rel2id, max_seq_len):
        super(HandshakingTaggingScheme, self).__init__()
        self.rel2id = rel2id
        self.id2rel = {ind:rel for rel, ind in rel2id.items()}

        self.tag2id_ent = {  # entity head-to-tail tags
            "O": 0,
            "ENT-H2T": 1,  # entity head to entity tail
        }
        self.id2tag_ent = {id_:tag for tag, id_ in self.tag2id_ent.items()}

        self.tag2id_head_rel = {  # 1: subject head -> object head, 2: object head -> subject head
            "O": 0,
            "REL-SH2OH": 1,  # subject head to object head
            "REL-OH2SH": 2,  # object head to subject head
        }
        self.id2tag_head_rel = {id_:tag for tag, id_ in self.tag2id_head_rel.items()}

        self.tag2id_tail_rel = {
            "O": 0,
            "REL-ST2OT": 1,  # subject tail to object tail
            "REL-OT2ST": 2,  # object tail to subject tail
        }
        self.id2tag_tail_rel = {id_:tag for tag, id_ in self.tag2id_tail_rel.items()}

        # mapping shaking sequence and matrix
        self.matrix_size = max_seq_len
        # e.g. [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]: the upper triangle flattened row by row
        self.shaking_ind2matrix_ind = [(ind, end_ind) for ind in range(self.matrix_size) for end_ind in list(range(self.matrix_size))[ind:]]

        self.matrix_ind2shaking_ind = [[0 for i in range(self.matrix_size)] for j in range(self.matrix_size)]
        for shaking_ind, matrix_ind in enumerate(self.shaking_ind2matrix_ind):
            # each upper-triangle cell stores its position in the flattened sequence
            self.matrix_ind2shaking_ind[matrix_ind[0]][matrix_ind[1]] = shaking_ind
```
The key pieces here are shaking_ind2matrix_ind and matrix_ind2shaking_ind. shaking_ind2matrix_ind, shown below, lists the (row, column) coordinates of the upper triangle flattened row by row:
[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10), (0, 11), (0, 12), (0, 13), …]
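matrix_ind2shaking_ind is the reverse map: a two-dimensional matrix in which each upper-triangle cell (i, j) stores that cell's position in the flattened sequence. Before the loop fills it, it is all zeros: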
[[0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], …]
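Filling it produces the upper-triangular index matrix. The printout below is only partially filled: once the loop completes, row i holds the flattened indices in its columns j ≥ i (row 1, for example, begins 0, max_seq_len, max_seq_len + 1, ...).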
[[0, 1, 2, 3, 4, 5, 6, 7, 8, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], [0, 0, 0, 0, 0, 0, 0, 0, 0, …], …]
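Both maps can be reproduced without the class. The sketch below (hypothetical helper names) builds them for a small n and checks a closed form for the flattened index: rows 0 to i - 1 contribute n + (n - 1) + ... + (n - i + 1) = i·n - i(i - 1)/2 cells, so cell (i, j) sits at that count plus j - i. For n = 3 the completed matrix is [[0, 1, 2], [0, 3, 4], [0, 0, 5]], which is what a fully filled version of the printout above looks like.

```python
def build_index_maps(n):
    """Flatten the upper triangle (i <= j) of an n x n matrix row by row."""
    shaking2matrix = [(i, j) for i in range(n) for j in range(i, n)]
    matrix2shaking = [[0] * n for _ in range(n)]  # lower-triangle cells stay 0
    for k, (i, j) in enumerate(shaking2matrix):
        matrix2shaking[i][j] = k
    return shaking2matrix, matrix2shaking


def shaking_index(i, j, n):
    """Closed form of matrix2shaking[i][j] for i <= j."""
    return i * n - i * (i - 1) // 2 + (j - i)
```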
Loading the data:
```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(MyDataset(indexed_train_data),
                              batch_size = hyper_parameters["batch_size"],
                              shuffle = False,
                              num_workers = 5,
                              drop_last = False,
                              collate_fn = data_maker.generate_batch,
                              )
```
The batches returned by the DataLoader are assembled by generate_batch:
```python
def generate_batch(self, batch_data, data_type = "train"):
    sample_list = []
    input_ids_list = []
    attention_mask_list = []
    token_type_ids_list = []
    tok2char_span_list = []

    ent_spots_list = []
    head_rel_spots_list = []
    tail_rel_spots_list = []

    for tp in batch_data:
        sample_list.append(tp[0])
        input_ids_list.append(tp[1])
        attention_mask_list.append(tp[2])
        token_type_ids_list.append(tp[3])
        tok2char_span_list.append(tp[4])

        if data_type != "test":
            ent_matrix_spots, head_rel_matrix_spots, tail_rel_matrix_spots = tp[5]
            ent_spots_list.append(ent_matrix_spots)
            head_rel_spots_list.append(head_rel_matrix_spots)
            tail_rel_spots_list.append(tail_rel_matrix_spots)

    # @specific: indexed by bert tokenizer
    batch_input_ids = torch.stack(input_ids_list, dim = 0)
    batch_attention_mask = torch.stack(attention_mask_list, dim = 0)
    batch_token_type_ids = torch.stack(token_type_ids_list, dim = 0)

    batch_ent_shaking_tag, batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = None, None, None
    if data_type != "test":
        batch_ent_shaking_tag = self.handshaking_tagger.sharing_spots2shaking_tag4batch(ent_spots_list)
        batch_head_rel_shaking_tag = self.handshaking_tagger.spots2shaking_tag4batch(head_rel_spots_list)
        batch_tail_rel_shaking_tag = self.handshaking_tagger.spots2shaking_tag4batch(tail_rel_spots_list)

    return sample_list, \
        batch_input_ids, batch_attention_mask, batch_token_type_ids, tok2char_span_list, \
        batch_ent_shaking_tag, batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag
```
The important outputs are the three label tensors batch_ent_shaking_tag, batch_head_rel_shaking_tag and batch_tail_rel_shaking_tag. The entity labels batch_ent_shaking_tag are built with handshaking_tagger.sharing_spots2shaking_tag4batch:
```python
def sharing_spots2shaking_tag4batch(self, batch_spots):
    '''
    convert spots to batch shaking seq tag
    Stacking long sequences is slow, so this function builds the shaking tags for the
    whole batch at once. Generating one shaking tag per sample and then stacking took
    about 1 s for a batch of 32, which is too expensive.
    spots: [(start_ind, end_ind, tag_id), ], for entity
    return:
        batch_shake_seq_tag: (batch_size, shaking_seq_len)
    '''
    shaking_seq_len = self.matrix_size * (self.matrix_size + 1) // 2
    batch_shaking_seq_tag = torch.zeros(len(batch_spots), shaking_seq_len).long()
    for batch_id, spots in enumerate(batch_spots):
        for sp in spots:
            shaking_ind = self.matrix_ind2shaking_ind[sp[0]][sp[1]]  # flattened position of the entity's (start, end) cell
            tag_id = sp[2]
            batch_shaking_seq_tag[batch_id][shaking_ind] = tag_id  # mark the entity (tag 1) in the flattened upper triangle
    return batch_shaking_seq_tag
```
The relation labels batch_head_rel_shaking_tag and batch_tail_rel_shaking_tag use spots2shaking_tag4batch:
```python
def spots2shaking_tag4batch(self, batch_spots):
    '''
    convert spots to batch shaking seq tag
    spots: [(rel_id, start_ind, end_ind, tag_id), ], for head relation and tail_relation
    return:
        batch_shake_seq_tag: (batch_size, rel_size, shaking_seq_len)
    '''
    shaking_seq_len = self.matrix_size * (self.matrix_size + 1) // 2
    batch_shaking_seq_tag = torch.zeros(len(batch_spots), len(self.rel2id), shaking_seq_len).long()
    for batch_id, spots in enumerate(batch_spots):
        for sp in spots:
            shaking_ind = self.matrix_ind2shaking_ind[sp[1]][sp[2]]
            tag_id = sp[3]
            rel_id = sp[0]
            batch_shaking_seq_tag[batch_id][rel_id][shaking_ind] = tag_id
    return batch_shaking_seq_tag
```
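A dependency-free sketch of the same idea, using nested lists instead of tensors. `spots_to_tags` mirrors spots2shaking_tag4batch, and `tags_to_spots` is a hypothetical inverse that reads non-zero tags back into spots, which is roughly the first step of decoding. The flattened index uses the closed form i·n - i(i - 1)/2 + (j - i), which equals matrix_ind2shaking_ind[i][j] for i ≤ j.

```python
def spots_to_tags(batch_spots, num_rels, n):
    """List-based version of spots2shaking_tag4batch: shape (batch, num_rels, n*(n+1)/2)."""
    shaking_len = n * (n + 1) // 2
    tags = [[[0] * shaking_len for _ in range(num_rels)] for _ in batch_spots]
    for b, spots in enumerate(batch_spots):
        for rel_id, i, j, tag_id in spots:
            k = i * n - i * (i - 1) // 2 + (j - i)  # flattened upper-triangle position of (i, j)
            tags[b][rel_id][k] = tag_id
    return tags


def tags_to_spots(tags, n):
    """Inverse direction: recover (rel_id, i, j, tag_id) for every non-zero tag."""
    pairs = [(i, j) for i in range(n) for j in range(i, n)]  # same order as shaking_ind2matrix_ind
    out = []
    for per_sample in tags:
        spots = []
        for rel_id, seq in enumerate(per_sample):
            for k, tag_id in enumerate(seq):
                if tag_id != 0:
                    spots.append((rel_id, *pairs[k], tag_id))
        out.append(spots)
    return out
```

The entity version is the same computation without the relation dimension.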
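This mirrors the entity version, with an extra dimension for the relation types. generate_batch then returns:
```python
return sample_list, \
    batch_input_ids, batch_attention_mask, batch_token_type_ids, tok2char_span_list, \
    batch_ent_shaking_tag, batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag
```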
Initializing the model:
```python
rel_extractor = TPLinkerBert(encoder,
                             len(rel2id),
                             hyper_parameters["shaking_type"],
                             hyper_parameters["inner_enc_type"],
                             hyper_parameters["dist_emb_size"],
                             hyper_parameters["ent_add_dist"],
                             hyper_parameters["rel_add_dist"],
                             )
```
The model definition:
```python
import torch.nn as nn

class TPLinkerBert(nn.Module):
    def __init__(self, encoder,
                 rel_size,
                 shaking_type,
                 inner_enc_type,
                 dist_emb_size,
                 ent_add_dist,
                 rel_add_dist
                 ):
        super().__init__()
        self.encoder = encoder
        hidden_size = encoder.config.hidden_size

        self.ent_fc = nn.Linear(hidden_size, 2)  # entity link: 2 classes (0, 1)
        self.head_rel_fc_list = [nn.Linear(hidden_size, 3) for _ in range(rel_size)]  # rel_size: number of relation types
        self.tail_rel_fc_list = [nn.Linear(hidden_size, 3) for _ in range(rel_size)]  # one 3-class (0, 1, 2) linear layer per relation

        # a plain Python list does not register submodules, so each layer's parameters are registered explicitly
        for ind, fc in enumerate(self.head_rel_fc_list):
            self.register_parameter("weight_4_head_rel{}".format(ind), fc.weight)
            self.register_parameter("bias_4_head_rel{}".format(ind), fc.bias)
        for ind, fc in enumerate(self.tail_rel_fc_list):
            self.register_parameter("weight_4_tail_rel{}".format(ind), fc.weight)
            self.register_parameter("bias_4_tail_rel{}".format(ind), fc.bias)

        # handshaking kernel
        self.handshaking_kernel = HandshakingKernel(hidden_size, shaking_type, inner_enc_type)

        # distance embedding
        self.dist_emb_size = dist_emb_size
        self.dist_embbedings = None  # it will be set in the first forwarding

        self.ent_add_dist = ent_add_dist
        self.rel_add_dist = rel_add_dist
```
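self.head_rel_fc_list and self.tail_rel_fc_list have the same structure: a plain Python list holding one independent linear classifier per relation type, each with three labels (0, 1, 2).
These layers make the entity and relation predictions: the entity task and every relation get their own classifier. With 5 relation types, for example, there are 11 linear classifiers: 1 entity layer + (1 head layer + 1 tail layer) × 5.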
```python
import math
import torch

def forward(self, input_ids, attention_mask, token_type_ids):
    # input_ids, attention_mask, token_type_ids: (batch_size, seq_len)
    context_outputs = self.encoder(input_ids, attention_mask, token_type_ids)  # [0]: last_hidden_state, [1]: pooled output
    # last_hidden_state: (batch_size, seq_len, hidden_size)
    last_hidden_state = context_outputs[0]

    # shaking_hiddens: (batch_size, 1 + ... + seq_len, hidden_size)
    shaking_hiddens = self.handshaking_kernel(last_hidden_state)  # flattened upper triangle
    shaking_hiddens4ent = shaking_hiddens
    shaking_hiddens4rel = shaking_hiddens

    # add distance embeddings if it is set
    if self.dist_emb_size != -1:
        # set self.dist_embbedings
        hidden_size = shaking_hiddens.size()[-1]
        if self.dist_embbedings is None:
            dist_emb = torch.zeros([self.dist_emb_size, hidden_size]).to(shaking_hiddens.device)
            for d in range(self.dist_emb_size):
                for i in range(hidden_size):
                    if i % 2 == 0:
                        dist_emb[d][i] = math.sin(d / 10000**(i / hidden_size))
                    else:
                        dist_emb[d][i] = math.cos(d / 10000**((i - 1) / hidden_size))
            seq_len = input_ids.size()[1]
            dist_embbeding_segs = []
            for after_num in range(seq_len, 0, -1):  # flatten in the same row order as the shaking sequence
                dist_embbeding_segs.append(dist_emb[:after_num, :])
            self.dist_embbedings = torch.cat(dist_embbeding_segs, dim = 0)

        if self.ent_add_dist:
            shaking_hiddens4ent = shaking_hiddens + self.dist_embbedings[None,:,:].repeat(shaking_hiddens.size()[0], 1, 1)
        if self.rel_add_dist:
            shaking_hiddens4rel = shaking_hiddens + self.dist_embbedings[None,:,:].repeat(shaking_hiddens.size()[0], 1, 1)

    # if self.dist_emb_size != -1 and self.ent_add_dist:
    #     shaking_hiddens4ent = shaking_hiddens + self.dist_embbedings[None,:,:].repeat(shaking_hiddens.size()[0], 1, 1)
    # else:
    #     shaking_hiddens4ent = shaking_hiddens
    # if self.dist_emb_size != -1 and self.rel_add_dist:
    #     shaking_hiddens4rel = shaking_hiddens + self.dist_embbedings[None,:,:].repeat(shaking_hiddens.size()[0], 1, 1)
    # else:
    #     shaking_hiddens4rel = shaking_hiddens

    ent_shaking_outputs = self.ent_fc(shaking_hiddens4ent)  # entity prediction (0/1)

    head_rel_shaking_outputs_list = []
    for fc in self.head_rel_fc_list:
        head_rel_shaking_outputs_list.append(fc(shaking_hiddens4rel))  # head-link classification for each relation

    tail_rel_shaking_outputs_list = []
    for fc in self.tail_rel_fc_list:  # tail-link classification for each relation
        tail_rel_shaking_outputs_list.append(fc(shaking_hiddens4rel))

    head_rel_shaking_outputs = torch.stack(head_rel_shaking_outputs_list, dim = 1)  # stack the n relations: (batch, n, shaking_len, 3)
    tail_rel_shaking_outputs = torch.stack(tail_rel_shaking_outputs_list, dim = 1)  # stack the n relations

    return ent_shaking_outputs, head_rel_shaking_outputs, tail_rel_shaking_outputs
```
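The distance embedding is easy to misread, so here is a standalone sketch using only `math`. `flat_distance_embedding` is a hypothetical name; it rebuilds the same sinusoidal table as the forward pass and flattens it in the same row order, assuming dist_emb_size ≥ seq_len. Within row i of the upper triangle, pair (i, j) gets distance j - i, which is why every segment starts again from distance 0.

```python
import math

def flat_distance_embedding(seq_len, dist_emb_size, hidden_size):
    """Sinusoidal distance table flattened in shaking order, as in TPLinkerBert.forward.

    Row i of the upper triangle covers pairs (i, i), (i, i + 1), ..., whose distances
    are 0, 1, ..., seq_len - 1 - i, so its segment is the first seq_len - i table rows.
    Assumes dist_emb_size >= seq_len; otherwise the slices below come up short.
    """
    table = []
    for d in range(dist_emb_size):
        row = []
        for k in range(hidden_size):
            if k % 2 == 0:
                row.append(math.sin(d / 10000 ** (k / hidden_size)))
            else:
                row.append(math.cos(d / 10000 ** ((k - 1) / hidden_size)))
        table.append(row)
    flat = []
    for after_num in range(seq_len, 0, -1):
        flat.extend(table[:after_num])
    return flat
```

The result has one vector per token pair, seq_len(seq_len + 1)/2 in total, so it can be added position by position to shaking_hiddens.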
ent_shaking_outputs holds the entity predictions, while head_rel_shaking_outputs and tail_rel_shaking_outputs hold the per-relation head-link and tail-link classifications. Both shaking_hiddens4ent and shaking_hiddens4rel come from the HandshakingKernel module, defined as follows:
```python
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
# LayerNorm below is the repository's conditional LayerNorm, not torch.nn.LayerNorm

class HandshakingKernel(nn.Module):
    def __init__(self, hidden_size, shaking_type, inner_enc_type):
        super().__init__()
        self.shaking_type = shaking_type
        if shaking_type == "cat":
            self.combine_fc = nn.Linear(hidden_size * 2, hidden_size)  # fc layer over [h_i; h_j]
        elif shaking_type == "cat_plus":
            self.combine_fc = nn.Linear(hidden_size * 3, hidden_size)
        elif shaking_type == "cln":
            self.tp_cln = LayerNorm(hidden_size, hidden_size, conditional = True)
        elif shaking_type == "cln_plus":
            self.tp_cln = LayerNorm(hidden_size, hidden_size, conditional = True)
            self.inner_context_cln = LayerNorm(hidden_size, hidden_size, conditional = True)

        self.inner_enc_type = inner_enc_type  # inner-context encoder: pooling or a 1-layer unidirectional LSTM
        if inner_enc_type == "mix_pooling":
            self.lamtha = Parameter(torch.rand(hidden_size))
        elif inner_enc_type == "lstm":
            self.inner_context_lstm = nn.LSTM(hidden_size,
                                              hidden_size,
                                              num_layers = 1,
                                              bidirectional = False,
                                              batch_first = True)

    def enc_inner_hiddens(self, seq_hiddens, inner_enc_type = "lstm"):
        # seq_hiddens: (batch_size, seq_len, hidden_size)
        def pool(seqence, pooling_type):
            if pooling_type == "mean_pooling":
                pooling = torch.mean(seqence, dim = -2)
            elif pooling_type == "max_pooling":
                pooling, _ = torch.max(seqence, dim = -2)
            elif pooling_type == "mix_pooling":
                pooling = self.lamtha * torch.mean(seqence, dim = -2) + (1 - self.lamtha) * torch.max(seqence, dim = -2)[0]
            return pooling
        if "pooling" in inner_enc_type:
            inner_context = torch.stack([pool(seq_hiddens[:, :i+1, :], inner_enc_type) for i in range(seq_hiddens.size()[1])], dim = 1)
        elif inner_enc_type == "lstm":
            inner_context, _ = self.inner_context_lstm(seq_hiddens)

        return inner_context

    def forward(self, seq_hiddens):
        '''
        seq_hiddens: (batch_size, seq_len, hidden_size)
        return:
            shaking_hiddens: (batch_size, (1 + seq_len) * seq_len / 2, hidden_size) (e.g. (32, 5+4+3+2+1, 5))
        '''
        # every token is paired with itself and each later token, so for length 5 the rows are
        # [batch, 5, hidden_size], [batch, 4, hidden_size], ...
        seq_len = seq_hiddens.size()[-2]  # sentence length
        shaking_hiddens_list = []
        for ind in range(seq_len):
            hidden_each_step = seq_hiddens[:, ind, :]  # current token for every sentence in the batch: (batch, hidden_size)
            visible_hiddens = seq_hiddens[:, ind:, :]  # from the current token to the end
            repeat_hiddens = hidden_each_step[:, None, :].repeat(1, seq_len - ind, 1)  # repeat along dim 1 to match visible_hiddens

            if self.shaking_type == "cat":  # "cat" mode, set in the config file
                shaking_hiddens = torch.cat([repeat_hiddens, visible_hiddens], dim = -1)  # concatenate the current token with each later token
                shaking_hiddens = torch.tanh(self.combine_fc(shaking_hiddens))  # linear layer + tanh
            elif self.shaking_type == "cat_plus":
                inner_context = self.enc_inner_hiddens(visible_hiddens, self.inner_enc_type)
                shaking_hiddens = torch.cat([repeat_hiddens, visible_hiddens, inner_context], dim = -1)
                shaking_hiddens = torch.tanh(self.combine_fc(shaking_hiddens))
            elif self.shaking_type == "cln":
                shaking_hiddens = self.tp_cln(visible_hiddens, repeat_hiddens)
            elif self.shaking_type == "cln_plus":
                inner_context = self.enc_inner_hiddens(visible_hiddens, self.inner_enc_type)
                shaking_hiddens = self.tp_cln(visible_hiddens, repeat_hiddens)
                shaking_hiddens = self.inner_context_cln(shaking_hiddens, inner_context)

            shaking_hiddens_list.append(shaking_hiddens)  # one row of the upper triangle
        long_shaking_hiddens = torch.cat(shaking_hiddens_list, dim = 1)  # flatten the upper triangle
        return long_shaking_hiddens
```
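The input seq_hiddens has shape [batch, seq_len, hidden_size]; it is the sentence as encoded by BERT. HandshakingKernel reduces the full seq_len × seq_len pair matrix to its upper triangle: row i keeps only the columns j ≥ i, so each row is one element shorter than the one before. torch.cat then flattens the rows into long_shaking_hiddens, whose length is seq_len + (seq_len - 1) + (seq_len - 2) + … + 1 = seq_len(seq_len + 1)/2. This corresponds to the figure.
The function loops over the tokens of the sentence. At step ind, hidden_each_step is the encoding of token ind, with shape [batch, hidden_size]; visible_hiddens holds the encodings of token ind and every later token, [batch, seq_len - ind, hidden_size] (that is, [batch, seq_len, hidden_size] when ind is 0); repeat_hiddens repeats hidden_each_step along a new dimension 1 to the same shape. Concatenating the two gives [batch, seq_len - ind, hidden_size * 2], one row of the upper triangle, and after the linear layer and tanh, shaking_hiddens is [batch, seq_len - ind, hidden_size]. Every following row is built the same way.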
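The pairing order of the `cat` branch can be checked with plain lists. In this sketch, `handshake_pairs` is hypothetical; it handles a single sentence instead of a batch and skips `combine_fc` and `tanh`, so each output is just the concatenation [h_i; h_j] for j ≥ i.

```python
def handshake_pairs(seq_hiddens):
    """Pair every token vector with itself and every later token, in shaking order.

    seq_hiddens: list of seq_len vectors (plain lists) for one sentence.
    Returns the concatenated [h_i; h_j] vectors for all j >= i, row by row.
    """
    out = []
    seq_len = len(seq_hiddens)
    for i in range(seq_len):
        repeat = [seq_hiddens[i]] * (seq_len - i)  # plays the role of repeat_hiddens
        visible = seq_hiddens[i:]                  # plays the role of visible_hiddens
        out.extend(r + v for r, v in zip(repeat, visible))  # list + list concatenates
    return out
```

For a 5-token sentence it returns 5 + 4 + 3 + 2 + 1 = 15 pair vectors, in the same order as shaking_ind2matrix_ind.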
The loss:
```python
total_loss, total_ent_sample_acc, total_head_rel_sample_acc, total_tail_rel_sample_acc = 0., 0., 0., 0.
for batch_ind, batch_train_data in enumerate(dataloader):
    t_batch = time.time()
    z = (2 * len(rel2id) + 1)  # number of classifiers: 1 entity + 2 per relation
    steps_per_ep = len(dataloader)  # batches per epoch
    total_steps = hyper_parameters["loss_weight_recover_steps"] + 1  # + 1 avoid division by zero error; steps until the weights reach their final values
    current_step = steps_per_ep * ep + batch_ind  # global step across epochs
    w_ent = max(1 / z + 1 - current_step / total_steps, 1 / z)
    w_rel = min((len(rel2id) / z) * current_step / total_steps, (len(rel2id) / z))
    loss_weights = {"ent": w_ent, "rel": w_rel}  # per-task loss weights

    loss, ent_sample_acc, head_rel_sample_acc, tail_rel_sample_acc = train_step(batch_train_data, optimizer, loss_weights)
    scheduler.step()

    total_loss += loss
    total_ent_sample_acc += ent_sample_acc
    total_head_rel_sample_acc += head_rel_sample_acc
    total_tail_rel_sample_acc += tail_rel_sample_acc

    avg_loss = total_loss / (batch_ind + 1)
    avg_ent_sample_acc = total_ent_sample_acc / (batch_ind + 1)
    avg_head_rel_sample_acc = total_head_rel_sample_acc / (batch_ind + 1)
    avg_tail_rel_sample_acc = total_tail_rel_sample_acc / (batch_ind + 1)
```
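As the step count grows, w_ent decreases and w_rel increases: training first concentrates on getting the entities right and then shifts toward relation extraction. Concretely, w_ent falls linearly from 1 + 1/z to 1/z and w_rel rises from 0 to len(rel2id)/z over loss_weight_recover_steps steps, after which both stay fixed. Due to work commitments, I'll go through the remaining details when I have time.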
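The schedule can be tabulated with a small function that copies the two formulas from the loop above; `loss_weights` here is a hypothetical standalone version (the training loop computes the values inline).

```python
def loss_weights(current_step, num_rels, recover_steps):
    """Entity/relation loss weights, using the same formulas as the training loop."""
    z = 2 * num_rels + 1             # 1 entity classifier + 2 per relation
    total_steps = recover_steps + 1  # + 1 avoids division by zero
    w_ent = max(1 / z + 1 - current_step / total_steps, 1 / z)
    w_rel = min((num_rels / z) * current_step / total_steps, num_rels / z)
    return w_ent, w_rel
```

With 2 relations (z = 5) and loss_weight_recover_steps = 99, the weights move from (1.2, 0.0) at step 0 through (0.7, 0.2) at step 50 to (0.2, 0.4) from step 100 onward.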
Baidu information extraction Lic2020, relation extraction: https://zhuanlan.zhihu.com/p/138858558