BERT-pytorch源码实现,解决内存溢出问题
相信很多人都在用BERT这个模型,但是有些人可能是直接从transformers库里导入现成的模型,这种方法不方便我们修改模型结构。于是有些人用pytorch详细实现了BERT,但是博主发现,这些详细实现BERT的代码出现了内存溢出问题,于是博主做了改进。下面的代码可以解决内存溢出问题,原因主要是中间结果没有完全释放,代码如下:
注:大家如果要解决内存溢出问题,关注del语句就可以了。
'''code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
            https://github.com/JayParks/transformer, https://github.com/dhlee347/pytorchic-bert
'''
import math
import re
from random import *  # provides random(), randint() and shuffle() used below

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
# Load the corpus: training sentences/labels and test sentences/labels.
# NOTE(review): `get_data` comes from the project-local `data_process` module.
# Later code treats sentences as whitespace-tokenised strings and labels as
# class indices for a 3-way classifier -- confirm against that module.
from data_process import get_data

setences, label, setences_test, label_test = get_data()
# Everything in this script runs on the CPU.
device = torch.device('cpu')
# Alias matching the commented-out toy example below; the rest of the script
# reads `setences` directly.
sentences = setences
#text = (
# 'Hello, how are you? I am Romeo.\n' # R
# 'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
# 'Nice meet you too. How are you today?\n' # R
# 'Great. My baseball team won the competition.\n' # J
# 'Oh Congratulations, Juliet\n' # R
# 'Thank you Romeo\n' # J
# 'Where are you going today?\n' # R
# 'I am going shopping. What about you?\n' # J
# 'I am going to visit my grandmother. she is not very well' # R
#)
#sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n') # filt
# print(sentences)

# One vocabulary over training *and* test sentences, so every test word has an id.
# Sorting makes the word -> id mapping identical across runs; iterating a `set`
# of strings depends on the per-process hash seed.
word_list = sorted(set(" ".join(setences).split()) | set(" ".join(setences_test).split()))
# Ids 0-3 are reserved for the special tokens; real words start at 4.
word2idx = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for i, w in enumerate(word_list):
    word2idx[w] = i + 4
idx2word = {idx: word for word, idx in word2idx.items()}
vocab_size = len(word2idx)

# Each training sentence as a list of token ids. Illustrative shape, from an
# earlier run on the toy dialogue (actual ids depend on the vocabulary):
# [[12, 7, 22, 5, 39, 21, 15], [12, 15, 13, 35, 10, 27, 34, 14, 19, 5],
#  [34, 19, 5, 17, 7, 22, 5, 8], [33, 13, 37, 32, 28, 11, 16], [30, 23, 27],
#  [6, 5, 15], [36, 22, 5, 31, 8], [39, 21, 31, 18, 9, 20, 5],
#  [39, 21, 31, 14, 29, 13, 4, 25, 10, 26, 38, 24]]
token_list = [[word2idx[s] for s in sentence.split()] for sentence in setences]
# ---- BERT hyper-parameters ----
# Sequences and batching
maxlen = 30          # every input row is padded to this length
batch_size = 6
max_pred = 5         # upper bound on predicted (masked) positions per sequence

# Transformer encoder size
n_layers = 6
n_heads = 12
d_model = 768
d_ff = d_model * 4   # hidden width of the position-wise feed-forward net
d_v = 64             # per-head value dimension
d_k = d_v            # per-head query/key dimension (equal to d_v here)
n_segments = 3       # number of rows in the segment (token-type) embedding table
# Build the masked-LM training examples (one per training sentence).
def make_data():
    """Build one masked-LM training example per sentence in ``setences``.

    Each example is ``[input_ids, segment_ids, masked_tokens, masked_pos, label]``:

    * ``input_ids``: ``[CLS] tokens [SEP]`` padded with ``[PAD]`` (0) to ``maxlen``.
      Sentences longer than ``maxlen - 2`` tokens are truncated, so all rows
      have the same length.
    * ``segment_ids``: all zeros, padded to ``maxlen``.
    * ``masked_tokens`` / ``masked_pos``: original ids and positions of the
      positions chosen for prediction. There are about 15 % of the sequence,
      at least 1 and at most ``max_pred``. Both lists are zero-padded to
      ``max_pred``.
    * ``label``: ``label[i]`` for sentence ``i``.

    Each chosen position becomes ``[MASK]`` with probability 0.8, a random word
    (id >= 4) with probability 0.1, and stays unchanged otherwise.
    """
    batch = []
    for tokens_a_index in range(len(setences)):
        # Leave room for [CLS] and [SEP] so padding is never negative.
        tokens_a = token_list[tokens_a_index][:maxlen - 2]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1)

        # MASK LM
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))  # 15 % of tokens in one sentence
        cand_masked_pos = [pos for pos, token in enumerate(input_ids)
                           if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]
        shuffle(cand_masked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            # Draw one number so the outcomes split 80 / 10 / 10. Two independent
            # draws would make random replacement happen only 2 % of the time.
            r = random()
            if r < 0.8:
                input_ids[pos] = word2idx['[MASK]']
            elif r < 0.9:
                # Random real word: ids 0-3 are [PAD]/[CLS]/[SEP]/[MASK].
                input_ids[pos] = randint(4, vocab_size - 1)
            # else: keep the original token.

        # Zero padding of the sequence.
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero padding of the prediction slots. Pad from the number actually chosen,
        # which is smaller than n_pred when the sentence has too few tokens.
        n_pad = max_pred - len(masked_pos)
        if n_pad > 0:
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        batch.append([input_ids, segment_ids, masked_tokens, masked_pos, label[tokens_a_index]])
    return batch
batch = make_data()
# Transpose the list of examples into one tuple per field. zip(*rows) is the
# inverse of zip, e.g. zip(*[(1, 4), (2, 5), (3, 6)]) -> (1, 2, 3), (4, 5, 6).
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
input_ids, segment_ids, masked_tokens, masked_pos, isNext = (
    torch.LongTensor(input_ids),
    torch.LongTensor(segment_ids),
    torch.LongTensor(masked_tokens),
    torch.LongTensor(masked_pos),
    torch.LongTensor(isNext),
)


class MyDataSet(Data.Dataset):
    """Map-style dataset over five parallel sequences indexed by example.

    Item ``idx`` is ``(input_ids[idx], segment_ids[idx], masked_tokens[idx],
    masked_pos[idx], isNext[idx])``. The length is ``len(input_ids)``.
    """

    def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
        self.input_ids = input_ids
        self.segment_ids = segment_ids
        self.masked_tokens = masked_tokens
        self.masked_pos = masked_pos
        self.isNext = isNext

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return (self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx],
                self.masked_pos[idx], self.isNext[idx])


loader = Data.DataLoader(MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isNext),
                         batch_size=batch_size, shuffle=True)
def get_attn_pad_mask(seq_q, seq_k):
    """Return a bool mask that hides padded *key* positions from attention.

    Args:
        seq_q: query token ids, ``[batch_size, len_q]``.
        seq_k: key token ids, ``[batch_size, len_k]``; id 0 is ``[PAD]``.

    Returns:
        ``[batch_size, len_q, len_k]`` bool tensor, True where the key is padding.
        It is an expanded view of a ``[batch_size, 1, len_k]`` mask, not a copy.
    """
    batch_size, len_q = seq_q.size()
    len_k = seq_k.size(1)
    pad_attn_mask = seq_k.eq(0).unsqueeze(1)  # [batch_size, 1, len_k]
    return pad_attn_mask.expand(batch_size, len_q, len_k)


def gelu(x):
    """Exact GELU: ``x * Phi(x)``, where ``Phi`` is the standard normal CDF."""
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class Embedding(nn.Module):
    """Sum of token, position and segment embeddings, followed by LayerNorm."""

    def __init__(self):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment (token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        # x, seg: [batch_size, seq_len] integer ids
        seq_len = x.size(1)
        # Positions 0..seq_len-1 on the same device as the input ids.
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)  # [seq_len] -> [batch_size, seq_len]
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)


class ScaledDotProductAttention(nn.Module):
    """softmax(Q K^T / sqrt(d_k)) V with masked positions set to -1e9 before softmax."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        # Q, K: [batch_size, n_heads, seq_len, d_k]; V: [batch_size, n_heads, seq_len, d_v]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)  # [B, H, S, S]
        # In-place fill avoids another [B, H, S, S] allocation; the mask only
        # has to broadcast to scores' shape.
        scores.masked_fill_(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V)  # [B, H, S, d_v]


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with output projection, residual and LayerNorm."""

    def __init__(self):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        # The output projection and LayerNorm are registered here so they are
        # trained by the optimizer and saved with the model. Building them inside
        # forward() created fresh random, untrainable weights on every call.
        self.fc = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.attention = ScaledDotProductAttention()

    def forward(self, Q, K, V, attn_mask):
        # Q, K, V: [batch_size, seq_len, d_model]; attn_mask: [batch_size, seq_len, seq_len]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, H*W) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)
        # Every head uses the same mask. expand() is a view, so no per-head copy is made.
        attn_mask = attn_mask.unsqueeze(1).expand(-1, n_heads, -1, -1)
        context = self.attention(q_s, k_s, v_s, attn_mask)  # [B, H, S, d_v]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.fc(context)  # [batch_size, seq_len, d_model]
        return self.norm(output + residual)


class PoswiseFeedForwardNet(nn.Module):
    """Two-layer feed-forward net applied at every position, with GELU in between."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)
        # NOTE(review): there is no residual + LayerNorm around this sub-layer,
        # unlike the reference BERT encoder.
        return self.fc2(gelu(self.fc1(x)))


class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by the feed-forward net."""

    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        # Self-attention: the same tensor serves as Q, K and V.
        enc_outputs = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        return self.pos_ffn(enc_outputs)  # [batch_size, seq_len, d_model]


class BERT(nn.Module):
    """Encoder stack with a 3-way [CLS] classifier and a masked-LM head."""

    def __init__(self):
        super().__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        # Pooler applied to the first ([CLS]) position.
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Dropout(0.5),
            nn.Tanh(),
        )
        self.classifier = nn.Linear(d_model, 3)
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        # fc2 (masked-LM decoder) shares its weight with the token embedding.
        embed_weight = self.embedding.tok_embed.weight
        self.fc2 = nn.Linear(d_model, vocab_size, bias=False)
        self.fc2.weight = embed_weight

    def forward(self, input_ids, segment_ids, masked_pos):
        """Run the encoder and both heads.

        Args:
            input_ids: ``[batch_size, seq_len]`` token ids (0 = ``[PAD]``).
            segment_ids: ``[batch_size, seq_len]`` segment ids.
            masked_pos: ``[batch_size, max_pred]`` positions whose tokens are predicted.

        Returns:
            ``(logits_lm, logits_clsf)`` with shapes
            ``[batch_size, max_pred, vocab_size]`` and ``[batch_size, 3]``.
        """
        output = self.embedding(input_ids, segment_ids)  # [batch_size, seq_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)  # [batch_size, seq_len, seq_len]
        for layer in self.layers:
            output = layer(output, enc_self_attn_mask)  # [batch_size, seq_len, d_model]

        # Sentence classification from the [CLS] position (index 0).
        h_pooled = self.fc(output[:, 0])  # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled)  # [batch_size, 3]

        # Masked LM: gather the hidden states at the masked positions.
        gather_idx = masked_pos[:, :, None].expand(-1, -1, d_model)  # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, gather_idx)  # [batch_size, max_pred, d_model]
        h_masked = self.activ2(self.linear(h_masked))
        logits_lm = self.fc2(h_masked)  # [batch_size, max_pred, vocab_size]
        return logits_lm, logits_clsf
# ---- Model, loss and optimizer ----
model = BERT().to(device)
# print(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-6)

# Leftover demo of the index tensor torch.gather expects
# (out = torch.gather(input, dim, index)): a [2, 3] index broadcast
# along a new trailing axis of size 10 -> [2, 3, 10].
index = torch.tensor([[1, 2, 0], [2, 0, 1]], dtype=torch.long)
index = index.unsqueeze(-1).expand(-1, -1, 10)

# Per-step training losses; appended to by the training loop and plotted at the end.
loss_list = []
# ---- Training ----
model.train()
# Masked-LM loss that skips the zero-padded prediction slots. make_data pads
# masked_tokens with 0 ([PAD]), while real masked words have ids >= 4. A separate
# criterion is needed because class 0 is a valid classifier target.
criterion_lm = nn.CrossEntropyLoss(ignore_index=0)
for epoch in range(10):
    loss_sum = 0.0
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
        logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
        # logits_lm: [batch_size, max_pred, vocab_size] -> [batch_size * max_pred, vocab_size].
        # Each of the batch_size * max_pred slots is a vocab_size-way classification.
        loss_lm = criterion_lm(logits_lm.view(-1, vocab_size), masked_tokens.view(-1))  # for masked LM
        # isNext = isNext.to(device)
        loss_clsf = criterion(logits_clsf, isNext)  # for sentence classification
        loss = loss_lm + loss_clsf
        # Accumulate plain Python floats. Adding the loss *tensor* (loss_sum + loss)
        # keeps every batch's autograd history reachable until the epoch ends.
        step_loss = loss.item()
        loss_sum += step_loss
        loss_list.append(step_loss)
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(step_loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Release this step's tensors now instead of when the names are rebound
        # in the next iteration, so they are not alive during the next forward.
        del loss, logits_clsf, input_ids, segment_ids, masked_tokens, masked_pos, logits_lm, isNext, loss_clsf, loss_lm

# Predict mask tokens and isNext
print('test')
# Test sentences as lists of token ids (the vocabulary covers both splits).
token_list = [[word2idx[s] for s in sentence.split()] for sentence in setences_test]


def make_data_test():
    """Build one masked example per sentence in ``setences_test``.

    The layout is the same as the training examples:
    ``[input_ids, segment_ids, masked_tokens, masked_pos, label_test[i]]``.
    Rows are truncated and padded to ``maxlen``. Prediction slots are
    zero-padded to ``max_pred``. Each chosen position becomes ``[MASK]``
    (p=0.8), a random word with id >= 4 (p=0.1), or is kept unchanged (p=0.1).
    """
    batch = []
    for tokens_a_index in range(len(setences_test)):
        # Leave room for [CLS] and [SEP] so padding is never negative.
        tokens_a = token_list[tokens_a_index][:maxlen - 2]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1)

        # MASK LM
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))  # 15 % of tokens in one sentence
        cand_masked_pos = [pos for pos, token in enumerate(input_ids)
                           if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]
        shuffle(cand_masked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            # One draw so the outcomes really split 80 / 10 / 10.
            r = random()
            if r < 0.8:
                input_ids[pos] = word2idx['[MASK]']
            elif r < 0.9:
                # Random real word: ids 0-3 are [PAD]/[CLS]/[SEP]/[MASK].
                input_ids[pos] = randint(4, vocab_size - 1)
            # else: keep the original token.

        # Zero padding of the sequence.
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero padding of the prediction slots, based on how many were actually chosen.
        n_pad = max_pred - len(masked_pos)
        if n_pad > 0:
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        batch.append([input_ids, segment_ids, masked_tokens, masked_pos, label_test[tokens_a_index]])
    return batch
# Preprocessing finished: build the *test* examples.
batch = make_data_test()

# Predict each test example separately. eval() disables the Dropout in the [CLS]
# pooler so predictions are deterministic. no_grad() skips building autograd
# graphs that would never be used.
model.eval()
predict_list = []
with torch.no_grad():
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in batch:
        print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])
        logits_lm, logits_clsf = model(torch.LongTensor([input_ids]),
                                       torch.LongTensor([segment_ids]),
                                       torch.LongTensor([masked_pos]))
        logits_lm = logits_lm.argmax(dim=2)[0].tolist()  # predicted id per prediction slot
        print('masked tokens list : ', [pos for pos in masked_tokens if pos != 0])
        print('predict masked tokens list : ', [pos for pos in logits_lm if pos != 0])
        logits_clsf = logits_clsf.argmax(dim=1)[0].item()  # predicted class
        print('isNext : ', isNext)
        print('predict isNext : ', logits_clsf)
        predict_list.append(logits_clsf)
test_loss = 0
# ---- Macro-averaged metrics over the 3 classes ----
n_classes = 3
target_num = [0] * n_classes   # true examples per class
predict_num = [0] * n_classes  # predictions per class
acc_num = [0] * n_classes      # correct predictions per class
for i in label_test:
    target_num[i] += 1
# predict_list is in the same order as label_test.
for pred, gold in zip(predict_list, label_test):
    index = int(pred)
    if 0 <= index < n_classes:
        predict_num[index] += 1
        if index == gold:
            acc_num[index] += 1
# print(target_num)
# print(predict_num)
# print(acc_num)

recall_sum = 0
precision_sum = 0
f1_sum = 0
for c in range(n_classes):
    # A class with no examples (or no predictions) contributes 0.
    recall_c = acc_num[c] / target_num[c] if target_num[c] else 0
    precision_c = acc_num[c] / predict_num[c] if predict_num[c] else 0
    recall_sum += recall_c
    precision_sum += precision_c
    if recall_c + precision_c != 0:
        f1_sum += 2 * recall_c * precision_c / (recall_c + precision_c)
print()
total_targets = sum(target_num)
accuracy = sum(acc_num) / total_targets if total_targets else 0.0
print('recall:', recall_sum / n_classes)
print('precision:', precision_sum / n_classes)
print('F1:', f1_sum / n_classes)
print('accuracy', accuracy)

# loss_list holds one entry per training batch (step), not per epoch.
plt.plot(loss_list, label='BERT')
plt.xlabel('training step')
plt.ylabel('loss')
plt.legend()
plt.title('loss-epoch')
plt.show()