BERT-CRF 微调中文 NER 模型

文章目录

  • 数据集
  • 模型定义
  • 数据集预处理
    • BIO 标签转换
    • 自定义Dataset
    • 拆分训练、测试集
  • 训练
  • 验证、测试
  • 指标计算
  • 推理
  • 其它
    • 相关参数
    • CRF 模块

数据集

  • CLUE-NER数据集:https://github.com/CLUEbenchmark/CLUENER2020/blob/master/pytorch_version/README.md
    在这里插入图片描述

模型定义

import torch
import torch.nn as nn
from pytorch_crf import CRF
from transformers import BertPreTrainedModel, BertModel


class BertCrfForNer(BertPreTrainedModel):
    """BERT encoder with a linear emission layer and a CRF on top, for NER.

    The encoder output goes through dropout and a linear classifier to give
    per-token emission scores; when labels are supplied, the CRF scores the
    label sequence against those emissions.
    """

    def __init__(self, config):
        """Build the encoder, dropout, classifier and CRF from ``config``.

        Args:
            config: model configuration; ``hidden_dropout_prob``,
                ``hidden_size`` and ``num_labels`` are read from it.
        """
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.num_labels = config.num_labels
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, input_lens=None):
        """Run the encoder and, when ``labels`` is given, compute the CRF loss.

        Args:
            input_ids: token ids passed to the BERT encoder.
            token_type_ids: optional segment ids passed to the encoder.
            attention_mask: optional mask passed to the encoder and, when
                ``labels`` is given, to the CRF as its sequence mask.
            labels: optional gold tag ids; enables the loss.
            input_lens: accepted for interface compatibility; not used here.

        Returns:
            ``(logits,)`` when ``labels`` is ``None``, otherwise
            ``(loss, logits)`` where ``loss`` is the negated CRF output.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,)
        if labels is not None:
            # The CRF module returns a log likelihood; negate it so that the
            # value can be minimised as a loss.
            loss = self.crf(emissions=logits, tags=labels, mask=attention_mask)
            outputs = (-1 * loss,) + outputs
        return outputs  # (loss), scores

其中 CRF 模块 pytorch_crf.py 见后文。

数据集预处理

BIO 标签转换

ALLOW_LABEL = ["name", "organization", "address", "company", "government"]


def generate_bio_tags(tokenizer, text_json, allowed_type=ALLOW_LABEL):
    """Tokenize one annotated record and assign a BIO tag to every token.

    Args:
        tokenizer: object providing ``encode_plus(text,
            return_offsets_mapping=True)`` (whose result exposes
            ``input_ids`` and ``offset_mapping``) and ``decode(token_id)``.
        text_json: record with a ``"text"`` string and an optional
            ``"label"`` mapping of the form
            ``{entity_type: {entity_text: [[start, end], ...]}}`` where
            ``end`` is inclusive.
        allowed_type: entity types to tag; every other type is left as "O".

    Returns:
        ``(tokens, bio_tags)``: two lists of equal length.
    """

    def tokenize_with_location(tokenizer, input_data):
        """Return ``[(token_text, (char_start, char_end)), ...]``."""
        encoded_input = tokenizer.encode_plus(input_data, return_offsets_mapping=True)
        return list(zip([tokenizer.decode(i) for i in encoded_input.input_ids],
                        encoded_input.offset_mapping))

    def get_bio_tag(labels, token_start, token_end):
        """Return the BIO tag for the token covering ``[token_start, token_end)``."""
        # Tokens with an empty span (e.g. offsets (0, 0)) are never entities.
        if token_start >= token_end:
            return "O"
        for entity_type, entities in labels.items():
            if entity_type not in allowed_type:
                continue
            for positions in entities.values():
                for start, end in positions:
                    # ``end`` is inclusive, the token end offset is exclusive.
                    if token_start >= start and token_end <= end + 1:
                        if token_start == start:
                            return f"B-{entity_type}"
                        return f"I-{entity_type}"
        return "O"

    text = text_json["text"]
    # Records without annotations are tagged entirely as "O".
    labels = text_json.get("label") or {}

    # Tokenize with the BERT tokenizer.
    tokenized_text = tokenize_with_location(tokenizer, text)
    tokens, bio_tags = [], []
    for token, (loc_s, loc_e) in tokenized_text:
        bio_tags.append(get_bio_tag(labels, loc_s, loc_e))
        tokens.append(token)
    return tokens, bio_tags


# Input JSON data
# Example record: its only entity type is "game", which is not listed in
# ALLOW_LABEL, so every token in the output shown below is tagged "O".
input_json = {"text": "你们是最棒的!#英雄联盟d学sanchez创作的原声王", "label": {"game": {"英雄联盟": [[8, 11]]}}}
generate_bio_tags(tokenizer, input_json)
"""
(['[CLS]','你','们','是','最','棒','的','!','#','英','雄','联','盟','d','学','san','##che','##z','创','作','的','原','声','王','[SEP]'],['O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O','O'])"""

自定义Dataset

from tqdm.notebook import tqdm
import json
import pickle
import os

# Build the tagged dataset once and cache it on disk; later runs reuse the cache.
cached_dataset = 'train.dataset.pkl'
train_file = 'train.json'
if not os.path.exists(cached_dataset):
    dataset = []
    # The training file holds one JSON object per line with Chinese text, so
    # read it as UTF-8 explicitly instead of relying on the platform default.
    with open(train_file, 'r', encoding='utf-8') as file:
        for line in tqdm(file.readlines()):
            line = line.strip()
            if not line:
                # Skip blank lines instead of failing in json.loads.
                continue
            data = json.loads(line)
            tokens, bio_tags = generate_bio_tags(tokenizer, data)
            # Keep only sentences with at least one non-"O" tag.
            if len(set(bio_tags)) > 1:
                dataset.append({"text": data["text"], "tokens": tokens, "tags": bio_tags})
    with open(cached_dataset, 'wb') as f:
        pickle.dump(dataset, f)
else:
    # NOTE: pickle.load executes arbitrary code from the file; only load a
    # cache that this script created itself.
    with open(cached_dataset, 'rb') as f:
        dataset = pickle.load(f)

先把原始数据 {"text": …, "label": …} 转换成 {"text": …, "tokens": …, "tags": …}。

from itertools import product
from torch.utils.data import Dataset, DataLoader

# Tag vocabulary: "O" first, then every B-* tag, then every I-* tag.
labels = ["O"] + [f"{i}-{j}" for i, j in product(['B', 'I'], ALLOW_LABEL)]
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = dict(enumerate(labels))


class BertDataset(Dataset):
    """Dataset that turns pre-tokenized, BIO-tagged examples into fixed-length tensors."""

    def __init__(self, dataset, tokenizer, max_len):
        """Store the examples and the settings used to encode them.

        Args:
            dataset: indexable collection of items with ``"tokens"`` and
                ``"tags"`` entries.
            tokenizer: object providing ``convert_tokens_to_ids``.
            max_len: length every example is truncated or padded to.
        """
        self.len = len(dataset)
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, index):
        """Return ``{'ids', 'mask', 'targets'}`` long tensors for one example."""
        item = self.data[index]
        tokens = item["tokens"]
        tags = item["tags"]

        # Truncate or pad both sequences to exactly max_len entries.
        max_len = self.max_len
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
            tags = tags[:max_len]
        else:
            tokens = tokens + ['[PAD]'] * (max_len - len(tokens))
            tags = tags + ["O"] * (max_len - len(tags))

        # Attention mask: 1 for real tokens, 0 for padding.
        attn_mask = [1 if tok != '[PAD]' else 0 for tok in tokens]

        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        label_ids = [label2id[tag] for tag in tags]

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(attn_mask, dtype=torch.long),
            'targets': torch.tensor(label_ids, dtype=torch.long),
        }

    def __len__(self):
        """Return the number of examples."""
        return self.len

拆分训练、测试集

import numpy as np
import random
def split_train_test_valid(dataset, train_size=0.9, test_size=0.1, seed=None):
    """Randomly split ``dataset`` into train, test and validation parts.

    Args:
        dataset: sequence of examples; it is converted with ``np.array``.
        train_size: fraction of examples assigned to the training part.
        test_size: fraction of examples assigned to the test part.
        seed: optional seed for a reproducible shuffle; when ``None`` the
            module-level ``random`` state is used.

    Returns:
        ``(data_train, data_test, data_valid)`` NumPy arrays. The validation
        part receives whatever is left after the first two, which may be
        empty.
    """
    dataset = np.array(dataset)
    total_size = len(dataset)

    # Sizes are truncated to whole examples.
    train_len = int(total_size * train_size)
    test_len = int(total_size * test_size)

    # Shuffle the indices rather than the data itself.
    idx = list(range(total_size))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(idx)

    data_train = dataset[idx[:train_len]]
    data_test = dataset[idx[train_len:train_len + test_len]]
    data_valid = dataset[idx[train_len + test_len:]]  # the remainder
    return data_train, data_test, data_valid


data_train, data_test, data_valid = split_train_test_valid(dataset)
print("FULL Dataset: {}".format(len(dataset)))
print("TRAIN Dataset: {}".format(data_train.shape))
print("TEST Dataset: {}".format(data_test.shape))

# Wrap both splits so that every example is encoded to MAX_LEN entries.
training_set = BertDataset(data_train, tokenizer, MAX_LEN)
testing_set = BertDataset(data_test, tokenizer, MAX_LEN)

train_params = {
    'batch_size': TRAIN_BATCH_SIZE,
    'shuffle': True,
    'num_workers': 0,
}
# NOTE(review): shuffling is also enabled for the test loader, so the order
# of evaluated examples differs between runs.
test_params = {
    'batch_size': VALID_BATCH_SIZE,
    'shuffle': True,
    'num_workers': 0,
}
training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

训练

# Load the pretrained encoder weights into the BERT-CRF model and register
# the tag vocabulary on its config.
# Alternative kept from the original notes:
#   AutoModelForTokenClassification.from_pretrained('save_model', ...)
model = BertCrfForNer.from_pretrained(
    'models/bert-base-chinese',
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
if MULTI_GPU:
    model = torch.nn.DataParallel(model)
model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings('ignore')


def train(epoch):
    """Train the global ``model`` for one pass over ``training_loader``.

    Prints the running loss every 100 steps and, at the end, the mean loss
    and mean per-batch token accuracy. The accuracy is computed from the
    argmax of the emission scores over unmasked positions.

    Args:
        epoch: index of the current epoch; not used inside the function.
    """
    tr_loss, tr_accuracy = 0, 0
    nb_tr_steps = 0
    # With DataParallel the real model sits behind ``.module``.
    num_labels = getattr(model, "module", model).num_labels

    # Put the model in training mode.
    model.train()

    for idx, batch in enumerate(training_loader):
        ids = batch['ids'].to(device, dtype=torch.long)
        mask = batch['mask'].to(device, dtype=torch.long)
        targets = batch['targets'].to(device, dtype=torch.long)

        outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
        loss, tr_logits = outputs[0], outputs[1]
        if MULTI_GPU:
            # One loss value per replica; reduce to a scalar.
            loss = loss.mean()
        tr_loss += loss.item()
        nb_tr_steps += 1

        if idx % 100 == 0:
            loss_step = tr_loss / nb_tr_steps
            print(f"Training loss per 100 training steps: {loss_step}")

        # Token accuracy over positions where the attention mask is on
        # (this includes the [CLS] and [SEP] positions).
        flattened_targets = targets.view(-1)  # (batch_size * seq_len,)
        active_logits = tr_logits.view(-1, num_labels)  # (batch_size * seq_len, num_labels)
        flattened_predictions = torch.argmax(active_logits, dim=1)  # (batch_size * seq_len,)
        active_positions = mask.view(-1) == 1
        active_targets = torch.masked_select(flattened_targets, active_positions)
        predictions = torch.masked_select(flattened_predictions, active_positions)
        tr_accuracy += accuracy_score(active_targets.cpu().numpy(), predictions.cpu().numpy())

        # Backward pass. Gradients must be clipped after backward() has
        # produced them and before the optimizer step; clipping before
        # zero_grad()/backward() has no effect on the update.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=MAX_GRAD_NORM)
        optimizer.step()

    # max(..., 1) avoids a division by zero when the loader is empty.
    epoch_loss = tr_loss / max(nb_tr_steps, 1)
    tr_accuracy = tr_accuracy / max(nb_tr_steps, 1)
    print(f"Training loss epoch: {epoch_loss}")
    print(f"Training accuracy epoch: {tr_accuracy}")


for epoch in range(EPOCHS):
    print(f"Training epoch: {epoch + 1}")
    train(epoch)
"""
Training epoch: 1
Training loss per 100 training steps: 76.82186126708984
Training loss per 100 training steps: 26.512494955912675
Training loss per 100 training steps: 18.23713019356799
Training loss per 100 training steps: 14.71561597431221
Training loss per 100 training steps: 12.793566083075698
Training loss epoch: 12.138352865534845
Training accuracy epoch: 0.9093487211512798
"""

验证、测试

def valid(model, testing_loader):
    """Evaluate ``model`` on ``testing_loader`` and collect tag sequences.

    Prints the running loss every 100 steps, then the mean loss and mean
    per-batch token accuracy. Predictions are the argmax of the emission
    scores over unmasked positions.

    Args:
        model: the model to evaluate (optionally wrapped in DataParallel).
        testing_loader: iterable of batches with ``'ids'``, ``'mask'`` and
            ``'targets'`` tensors.

    Returns:
        ``(labels, predictions)``: two flat lists of tag strings covering
        every unmasked position of every batch, in iteration order.
    """
    # Put the model in evaluation mode.
    model.eval()

    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps = 0
    eval_preds, eval_labels = [], []
    # With DataParallel the real model sits behind ``.module``.
    num_labels = getattr(model, "module", model).num_labels

    with torch.no_grad():
        for idx, batch in enumerate(testing_loader):
            ids = batch['ids'].to(device, dtype=torch.long)
            mask = batch['mask'].to(device, dtype=torch.long)
            targets = batch['targets'].to(device, dtype=torch.long)

            outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
            loss, eval_logits = outputs[0], outputs[1]
            if MULTI_GPU:
                loss = loss.mean()
            eval_loss += loss.item()
            nb_eval_steps += 1

            if idx % 100 == 0:
                loss_step = eval_loss / nb_eval_steps
                print(f"Validation loss per 100 evaluation steps: {loss_step}")

            # Token accuracy over positions where the attention mask is on
            # (this includes the [CLS] and [SEP] positions).
            flattened_targets = targets.view(-1)  # (batch_size * seq_len,)
            active_logits = eval_logits.view(-1, num_labels)  # (batch_size * seq_len, num_labels)
            flattened_predictions = torch.argmax(active_logits, dim=1)  # (batch_size * seq_len,)
            active_positions = mask.view(-1) == 1
            active_targets = torch.masked_select(flattened_targets, active_positions)
            predictions = torch.masked_select(flattened_predictions, active_positions)

            # Store plain ints rather than device tensors.
            eval_labels.extend(active_targets.tolist())
            eval_preds.extend(predictions.tolist())

            eval_accuracy += accuracy_score(active_targets.cpu().numpy(), predictions.cpu().numpy())

    labels = [id2label[tag_id] for tag_id in eval_labels]
    predictions = [id2label[tag_id] for tag_id in eval_preds]

    # max(..., 1) avoids a division by zero when the loader is empty.
    eval_loss = eval_loss / max(nb_eval_steps, 1)
    eval_accuracy = eval_accuracy / max(nb_eval_steps, 1)
    print(f"Validation Loss: {eval_loss}")
    print(f"Validation Accuracy: {eval_accuracy}")

    return labels, predictions


# NOTE: this rebinds the module-level name ``labels`` (previously the tag
# vocabulary list) to the gold tag sequence.
labels, predictions = valid(model, testing_loader)
"""
Validation loss per 100 evaluation steps: 5.371463775634766
Validation Loss: 5.623965330123902
Validation Accuracy: 0.9547014622783095
"""

指标计算

from seqeval.metrics import classification_report

# Entity-level precision/recall/F1; all evaluated tokens are passed as one
# long sequence.
print(classification_report([labels], [predictions]))
"""precision    recall  f1-score   supportaddress       0.50      0.62      0.55       316company       0.65      0.77      0.70       270government       0.69      0.85      0.76       208name       0.87      0.87      0.87       374
organization       0.76      0.82      0.79       343micro avg       0.69      0.79      0.74      1511macro avg       0.69      0.79      0.73      1511
weighted avg       0.70      0.79      0.74      1511
"""

推理

from transformers import pipeline

# Unwrap DataParallel (if used) so the pipeline receives the underlying model.
model_to_test = model.module if hasattr(model, "module") else model

# NOTE(review): without labels the model returns only emission scores, so the
# pipeline presumably tags tokens from those scores without CRF decoding —
# confirm against the pipeline implementation.
pipe = pipeline(
    task="token-classification",
    model=model_to_test.to("cpu"),
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)
pipe("我的名字是michal johnson,我的手机号是13425456344,我家住在东北松花江上8幢7单元6楼5号房")
"""
[{'entity_group': 'name','score': 0.83746755,'word': 'michal johnson','start': 5,'end': 19},{'entity_group': 'address','score': 0.924768,'word': '东 北 松 花 江 上 8 幢 7 单 元 6 楼 5 号 房','start': 42,'end': 58}]
"""

其它

相关参数

import torch
import os
# Restrict the process to GPUs 1 and 3; this has to be set before the first
# CUDA call below.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,3'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MAX_LEN = 128            # every example is truncated/padded to this length
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 32
EPOCHS = 1
LEARNING_RATE = 1e-05
MAX_GRAD_NORM = 10       # max_norm used for gradient clipping
MULTI_GPU = False        # wrap the model in torch.nn.DataParallel when True
ALLOW_LABEL = ["name", "organization", "address", "company", "government"]

CRF 模块

参考:https://github.com/CLUEbenchmark/CLUENER2020/blob/master/pytorch_version/models/crf.py

import torch
import torch.nn as nn
from typing import List, Optionalclass CRF(nn.Module):"""Conditional random field.This module implements a conditional random field [LMP01]_. The forward computationof this class computes the log likelihood of the given sequence of tags andemission score tensor. This class also has `~CRF.decode` method which findsthe best tag sequence given an emission score tensor using `Viterbi algorithm`_.Args:num_tags: Number of tags.batch_first: Whether the first dimension corresponds to the size of a minibatch.Attributes:start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size``(num_tags,)``.end_transitions (`~torch.nn.Parameter`): End transition score tensor of size``(num_tags,)``.transitions (`~torch.nn.Parameter`): Transition score tensor of size``(num_tags, num_tags)``... [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001)."Conditional random fields: Probabilistic models for segmenting andlabeling sequence data". *Proc. 18th International Conf. on MachineLearning*. Morgan Kaufmann. pp. 282–289... 
_Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm"""def __init__(self, num_tags: int, batch_first: bool = False) -> None:if num_tags <= 0:raise ValueError(f'invalid number of tags: {num_tags}')super().__init__()self.num_tags = num_tagsself.batch_first = batch_firstself.start_transitions = nn.Parameter(torch.empty(num_tags))self.end_transitions = nn.Parameter(torch.empty(num_tags))self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))self.reset_parameters()def reset_parameters(self) -> None:"""Initialize the transition parameters.The parameters will be initialized randomly from a uniform distributionbetween -0.1 and 0.1."""nn.init.uniform_(self.start_transitions, -0.1, 0.1)nn.init.uniform_(self.end_transitions, -0.1, 0.1)nn.init.uniform_(self.transitions, -0.1, 0.1)def __repr__(self) -> str:return f'{self.__class__.__name__}(num_tags={self.num_tags})'def forward(self, emissions: torch.Tensor,tags: torch.LongTensor,mask: Optional[torch.ByteTensor] = None,reduction: str = 'mean') -> torch.Tensor:"""Compute the conditional log likelihood of a sequence of tags given emission scores.Args:emissions (`~torch.Tensor`): Emission score tensor of size``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,``(batch_size, seq_length, num_tags)`` otherwise.tags (`~torch.LongTensor`): Sequence of tags tensor of size``(seq_length, batch_size)`` if ``batch_first`` is ``False``,``(batch_size, seq_length)`` otherwise.mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.reduction: Specifies  the reduction to apply to the output:``none|sum|mean|token_mean``. ``none``: no reduction will be applied.``sum``: the output will be summed over batches. ``mean``: the output will beaveraged over batches. ``token_mean``: the output will be averaged over tokens.Returns:`~torch.Tensor`: The log likelihood. 
This will have size ``(batch_size,)`` ifreduction is ``none``, ``()`` otherwise."""if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f'invalid reduction: {reduction}')if mask is None:mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device)if mask.dtype != torch.uint8:mask = mask.byte()self._validate(emissions, tags=tags, mask=mask)if self.batch_first:emissions = emissions.transpose(0, 1)tags = tags.transpose(0, 1)mask = mask.transpose(0, 1)# shape: (batch_size,)numerator = self._compute_score(emissions, tags, mask)# shape: (batch_size,)denominator = self._compute_normalizer(emissions, mask)# shape: (batch_size,)llh = numerator - denominatorif reduction == 'none':return llhif reduction == 'sum':return llh.sum()if reduction == 'mean':return llh.mean()return llh.sum() / mask.float().sum()def decode(self, emissions: torch.Tensor,mask: Optional[torch.ByteTensor] = None,nbest: Optional[int] = None,pad_tag: Optional[int] = None) -> List[List[List[int]]]:"""Find the most likely tag sequence using Viterbi algorithm.Args:emissions (`~torch.Tensor`): Emission score tensor of size``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,``(batch_size, seq_length, num_tags)`` otherwise.mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.nbest (`int`): Number of most probable paths for each sequencepad_tag (`int`): Tag at padded positions. Often input varies in length andthe length will be padded to the maximum length in the batch. Tags atthe padded positions will be assigned with a padding tag, i.e. 
`pad_tag`Returns:A PyTorch tensor of the best tag sequence for each batch of shape(nbest, batch_size, seq_length)"""if nbest is None:nbest = 1if mask is None:mask = torch.ones(emissions.shape[:2], dtype=torch.uint8,device=emissions.device)if mask.dtype != torch.uint8:mask = mask.byte()self._validate(emissions, mask=mask)if self.batch_first:emissions = emissions.transpose(0, 1)mask = mask.transpose(0, 1)if nbest == 1:return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0)return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag)def _validate(self, emissions: torch.Tensor,tags: Optional[torch.LongTensor] = None,mask: Optional[torch.ByteTensor] = None) -> None:if emissions.dim() != 3:raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')if emissions.size(2) != self.num_tags:raise ValueError(f'expected last dimension of emissions is {self.num_tags}, 'f'got {emissions.size(2)}')if tags is not None:if emissions.shape[:2] != tags.shape:raise ValueError('the first two dimensions of emissions and tags must match, 'f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')if mask is not None:if emissions.shape[:2] != mask.shape:raise ValueError('the first two dimensions of emissions and mask must match, 'f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')no_empty_seq = not self.batch_first and mask[0].all()no_empty_seq_bf = self.batch_first and mask[:, 0].all()if not no_empty_seq and not no_empty_seq_bf:raise ValueError('mask of the first timestep must all be on')def _compute_score(self, emissions: torch.Tensor,tags: torch.LongTensor,mask: torch.ByteTensor) -> torch.Tensor:# emissions: (seq_length, batch_size, num_tags)# tags: (seq_length, batch_size)# mask: (seq_length, batch_size)seq_length, batch_size = tags.shapemask = mask.float()# Start transition score and first emission# shape: (batch_size,)score = self.start_transitions[tags[0]]score += emissions[0, torch.arange(batch_size), tags[0]]for i in range(1, 
seq_length):# Transition score to next tag, only added if next timestep is valid (mask == 1)# shape: (batch_size,)score += self.transitions[tags[i - 1], tags[i]] * mask[i]# Emission score for next tag, only added if next timestep is valid (mask == 1)# shape: (batch_size,)score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]# End transition score# shape: (batch_size,)seq_ends = mask.long().sum(dim=0) - 1# shape: (batch_size,)last_tags = tags[seq_ends, torch.arange(batch_size)]# shape: (batch_size,)score += self.end_transitions[last_tags]return scoredef _compute_normalizer(self, emissions: torch.Tensor,mask: torch.ByteTensor) -> torch.Tensor:# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)seq_length = emissions.size(0)# Start transition score and first emission; score has size of# (batch_size, num_tags) where for each batch, the j-th column stores# the score that the first timestep has tag j# shape: (batch_size, num_tags)score = self.start_transitions + emissions[0]for i in range(1, seq_length):# Broadcast score for every possible next tag# shape: (batch_size, num_tags, 1)broadcast_score = score.unsqueeze(2)# Broadcast emission score for every possible current tag# shape: (batch_size, 1, num_tags)broadcast_emissions = emissions[i].unsqueeze(1)# Compute the score tensor of size (batch_size, num_tags, num_tags) where# for each sample, entry at row i and column j stores the sum of scores of all# possible tag sequences so far that end with transitioning from tag i to tag j# and emitting# shape: (batch_size, num_tags, num_tags)next_score = broadcast_score + self.transitions + broadcast_emissions# Sum over all possible current tags, but we're in score space, so a sum# becomes a log-sum-exp: for each sample, entry i stores the sum of scores of# all possible tag sequences so far, that end in tag i# shape: (batch_size, num_tags)next_score = torch.logsumexp(next_score, dim=1)# Set score to the next score if this timestep is valid 
(mask == 1)# shape: (batch_size, num_tags)score = torch.where(mask[i].unsqueeze(1), next_score, score)# End transition score# shape: (batch_size, num_tags)score += self.end_transitions# Sum (log-sum-exp) over all possible tags# shape: (batch_size,)return torch.logsumexp(score, dim=1)def _viterbi_decode(self, emissions: torch.FloatTensor,mask: torch.ByteTensor,pad_tag: Optional[int] = None) -> List[List[int]]:# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)# return: (batch_size, seq_length)if pad_tag is None:pad_tag = 0device = emissions.deviceseq_length, batch_size = mask.shape# Start transition and first emission# shape: (batch_size, num_tags)score = self.start_transitions + emissions[0]history_idx = torch.zeros((seq_length, batch_size, self.num_tags),dtype=torch.long, device=device)oor_idx = torch.zeros((batch_size, self.num_tags),dtype=torch.long, device=device)oor_tag = torch.full((seq_length, batch_size), pad_tag,dtype=torch.long, device=device)# - score is a tensor of size (batch_size, num_tags) where for every batch,#   value at column j stores the score of the best tag sequence so far that ends#   with tag j# - history_idx saves where the best tags candidate transitioned from; this is used#   when we trace back the best tag sequence# - oor_idx saves the best tags candidate transitioned from at the positions#   where mask is 0, i.e. 
out of range (oor)# Viterbi algorithm recursive case: we compute the score of the best tag sequence# for every possible next tagfor i in range(1, seq_length):# Broadcast viterbi score for every possible next tag# shape: (batch_size, num_tags, 1)broadcast_score = score.unsqueeze(2)# Broadcast emission score for every possible current tag# shape: (batch_size, 1, num_tags)broadcast_emission = emissions[i].unsqueeze(1)# Compute the score tensor of size (batch_size, num_tags, num_tags) where# for each sample, entry at row i and column j stores the score of the best# tag sequence so far that ends with transitioning from tag i to tag j and emitting# shape: (batch_size, num_tags, num_tags)next_score = broadcast_score + self.transitions + broadcast_emission# Find the maximum score over all possible current tag# shape: (batch_size, num_tags)next_score, indices = next_score.max(dim=1)# Set score to the next score if this timestep is valid (mask == 1)# and save the index that produces the next score# shape: (batch_size, num_tags)score = torch.where(mask[i].unsqueeze(-1), next_score, score)indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx)history_idx[i - 1] = indices# End transition score# shape: (batch_size, num_tags)end_score = score + self.end_transitions_, end_tag = end_score.max(dim=1)# shape: (batch_size,)seq_ends = mask.long().sum(dim=0) - 1# insert the best tag at each sequence end (last position with mask == 1)history_idx = history_idx.transpose(1, 0).contiguous()history_idx.scatter_(1, seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags),end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags))history_idx = history_idx.transpose(1, 0).contiguous()# The most probable path for each sequencebest_tags_arr = torch.zeros((seq_length, batch_size),dtype=torch.long, device=device)best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device)for idx in range(seq_length - 1, -1, -1):best_tags = torch.gather(history_idx[idx], 1, best_tags)best_tags_arr[idx] = 
best_tags.data.view(batch_size)return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1)def _viterbi_decode_nbest(self, emissions: torch.FloatTensor,mask: torch.ByteTensor,nbest: int,pad_tag: Optional[int] = None) -> List[List[List[int]]]:# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)# return: (nbest, batch_size, seq_length)if pad_tag is None:pad_tag = 0device = emissions.deviceseq_length, batch_size = mask.shape# Start transition and first emission# shape: (batch_size, num_tags)score = self.start_transitions + emissions[0]history_idx = torch.zeros((seq_length, batch_size, self.num_tags, nbest),dtype=torch.long, device=device)oor_idx = torch.zeros((batch_size, self.num_tags, nbest),dtype=torch.long, device=device)oor_tag = torch.full((seq_length, batch_size, nbest), pad_tag,dtype=torch.long, device=device)# + score is a tensor of size (batch_size, num_tags) where for every batch,#   value at column j stores the score of the best tag sequence so far that ends#   with tag j# + history_idx saves where the best tags candidate transitioned from; this is used#   when we trace back the best tag sequence# - oor_idx saves the best tags candidate transitioned from at the positions#   where mask is 0, i.e. 
out of range (oor)# Viterbi algorithm recursive case: we compute the score of the best tag sequence# for every possible next tagfor i in range(1, seq_length):if i == 1:broadcast_score = score.unsqueeze(-1)broadcast_emission = emissions[i].unsqueeze(1)# shape: (batch_size, num_tags, num_tags)next_score = broadcast_score + self.transitions + broadcast_emissionelse:broadcast_score = score.unsqueeze(-1)broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2)# shape: (batch_size, num_tags, nbest, num_tags)next_score = broadcast_score + self.transitions.unsqueeze(1) + broadcast_emission# Find the top `nbest` maximum score over all possible current tag# shape: (batch_size, nbest, num_tags)next_score, indices = next_score.view(batch_size, -1, self.num_tags).topk(nbest, dim=1)if i == 1:score = score.unsqueeze(-1).expand(-1, -1, nbest)indices = indices * nbest# convert to shape: (batch_size, num_tags, nbest)next_score = next_score.transpose(2, 1)indices = indices.transpose(2, 1)# Set score to the next score if this timestep is valid (mask == 1)# and save the index that produces the next score# shape: (batch_size, num_tags, nbest)score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), next_score, score)indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, oor_idx)history_idx[i - 1] = indices# End transition score shape: (batch_size, num_tags, nbest)end_score = score + self.end_transitions.unsqueeze(-1)_, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1)# shape: (batch_size,)seq_ends = mask.long().sum(dim=0) - 1# insert the best tag at each sequence end (last position with mask == 1)history_idx = history_idx.transpose(1, 0).contiguous()history_idx.scatter_(1, seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest),end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest))history_idx = history_idx.transpose(1, 0).contiguous()# The most probable path for each sequencebest_tags_arr = torch.zeros((seq_length, batch_size, 
nbest),dtype=torch.long, device=device)best_tags = torch.arange(nbest, dtype=torch.long, device=device) \.view(1, -1).expand(batch_size, -1)for idx in range(seq_length - 1, -1, -1):best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, best_tags)best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbestreturn torch.where(mask.unsqueeze(-1), best_tags_arr, oor_tag).permute(2, 1, 0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/2515.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【iOS开发】(五)react Native路由和导航20240421-22

【iOS开发】(五)react Native 路由和导航Navigation 20240421 在（一）（二）中我们 Reactnative搭建了开发环境、学习了 基础语法、状态管理，JSX、组件、状态和生命周期以及样式布局等。 在（三）…

MATLAB 数据类型

MATLAB 数据类型 MATLAB 不需要任何类型声明或维度语句。每当 MATLAB 遇到一个新的变量名，它就创建变量并分配适当的内存空间。 如果变量已经存在，那么MATLAB将用新内容替换原始内容，并在必要时分配新的存储空间。 例如， Tota…

Vue3中使用无缝滚动插件vue3-seamless-scroll

官网：https://www.npmjs.com/package/vue-seamless-scroll 1、实现效果文字描述： 表格中的列数据进行横向无缝滚动，某一列进行筛选的时候，重新请求后端的数据，进行刷新 2、安装：npm i vue3-seamless-scrol…

小程序 rich-text 解析富文本 图片过大时如何自适应?

在微信小程序中，用rich-text 解析后端返回的数据，当图片尺寸太大时，会溢出屏幕，导致横向出现滚动 查看富文本代码 图片是用 <img 标签，所以写个正则匹配一下图片标签，手动加上样式即可 // content 为后…

Python 面向对象——5.多态

本章学习链接如下： Python 面向对象——1.基本概念 Python 面向对象——2.类与对象实例属性补充解释，self的作用等 Python 面向对象——3.实例方法，类方法与静态方法 Python 面向对象——4.继承 1.基本概念 多态是面向对象编程…

贪吃蛇(C语言版)

在我们学习完C语言 和单链表知识点后 我们开始写个贪吃蛇的代码 目标：使用C语言在Windows环境的控制台模拟实现经典小游戏贪吃蛇 贪吃蛇代码实现的基本功能： 地图的绘制 蛇、食物的创建 蛇的状态（正常 撞墙 撞到自己 正常退出…

Python蜘蛛侠

目录 写在前面 蜘蛛侠 编写代码 代码分析 更多精彩 写在后面 写在前面 本期小编给大家推荐一个酷酷的Python蜘蛛侠，一起来看看叭~ 蜘蛛侠 蜘蛛侠（Spider-Man）是美国漫威漫画宇宙中的一位标志性人物，由传奇创作者斯坦李与艺…

探索ChatGPT在提高人脸识别与软性生物识准确性的表现与可解释性

概述 从GPT-1到GPT-3，OpenAI的模型不断进步，推动了自然语言处理技术的发展。这些模型在处理语言任务方面展现出了强大的能力，包括文本生成、翻译、问答等。 然而，当涉及到面部识别和生物特征估计等任务时，这些基于文…

设计模式-00 设计模式简介之几大原则

设计模式-00 设计模式简介之几大原则 本专栏主要分析自己学习设计模式相关的浅解，并运用modern cpp 来实现，描述相关设计模式。 通过编写代码，深入理解设计模式精髓，并且很好的帮助自己掌握设计模式，顺便巩固自己的c…

用于车载T-BOX汽车级的RA8900CE

用于车载T-BOX等高精度计时的汽车级时钟模块RTC:RA8900CE.车载实时时钟芯片RA8900CE内置32.768Khz的晶体，实现年、月、日、星期、小时、分钟和秒精准计时。RA8900CE满足AEC-Q200认证，内置温补功能，保证实时时钟的稳定可靠，功耗低至…

【Linux】解决ubuntu20.04版本插入无线网卡没有wifi显示【无线网卡Realtek 8811cu】

ubuntu为Realtek 8811cu安装驱动，解决wifi连接问题 1、确认无线网卡的型号-Realtek 8810cu2、下载并配置驱动 一句话总结：先确定网卡的型号，然后根据网卡的型号去寻找对应的驱动下载，下载完成之后在ubuntu系统中进行编译…

HTTP慢连接攻击的原理和防范措施

随着互联网的快速发展，网络安全问题日益凸显，网络攻击事件频繁发生。其中，HTTP慢速攻击作为一种隐蔽且高效的攻击方式，近年来逐渐出现的越来越多。 为了防范这些网络攻击，我们需要先了解这些攻击情况，这样…

【笔试】03

FLOPS FLOPS 是 Floating Point Operations Per Second 的缩写，意为每秒浮点运算次数。它是衡量计算机性能的指标，特别是用于衡量计算机每秒能够执行多少浮点运算。在高性能计算领域，FLOPS 被广泛用来评估超级计算机、CPU、GPU 和其他处理器…

2024年区块链链游即将迎来大爆发

随着区块链技术的不断发展和成熟，其应用领域也在不断扩展。其中，区块链链游（Blockchain Games）作为区块链技术在游戏行业中的应用，备受关注。2024年，区块链链游行业即将迎来爆发，这一趋势不容忽…

Windows10如何关闭Edge浏览器的Copilot

在Windows10更新后，打开Edge浏览器，无论复制什么内容，都会弹出Copilot人工智能插件，非常令人反感，网上搜索的关闭方法都非常麻烦，比如：组策略和注册表。自己摸索得出最简便有效的关闭方法。 1、…

【java毕业设计】 基于Spring Boot+mysql的高校心理教育辅导系统设计与实现(程序源码)-高校心理教育辅导系统

基于Spring Boot+mysql的高校心理教育辅导系统设计与实现（程序源码毕业论文） 大家好，今天给大家介绍基于Spring Boot+mysql的高校心理教育辅导系统设计与实现，本论文只截取部分文章重点，文章末尾附有本毕业设计完整源码及…

一致性hash

一、什么是一致性hash 普通的hash算法 (hashcode % size )，如果size发生变化，几乎所有的历史数据都需要重hash、移动，代价非常大，常见的java中的hashmap就是如此。 那如果在hash表扩容或者收缩的时候size能够保持不变…

gitee / github 配置git, 实现免密码登录

文章目录 怎么配置公钥和私钥验证配置成功问题 怎么配置公钥和私钥 以下内容参考自 github ssh 配置，gitee的配置也是一样的； 粘贴以下文本，将示例中使用的电子邮件替换为 GitHub 电子邮件地址。 ssh-keygen -t ed25519 -C "your_emai…

线性代数 --- 矩阵的对角化以及矩阵的n次幂

矩阵的对角化以及矩阵的n次幂 （特征向量与特征值的应用） 前言： 在上一篇文章中，我记录了学习矩阵的特征向量和特征值的学习笔记，所关注的是那些矩阵A作用于向量x后，方向不发生改变的x(仅有尺度的缩放)。线…

VMware 15 安装centos7虚拟机

1. 安装前准备 1.1 下载centos 阿里巴巴开源镜像站-OPSX镜像站-阿里云开发者社区 下载需要版本的centos版本 直达链接 centos7.9 ： centos-7.9.2009-isos-x86_64安装包下载_开源镜像站-阿里云 .基础使用的话安装选择这个就行了，大概下载几分钟 2. …