The following is a Transformer sampling code example:
# Assumed module-level imports for this snippet; sample_forward_model, self.voc,
# self.decodertf, self._nll_loss and self.device are defined elsewhere in the class/module.
import torch
import torch.nn.functional as F
from torch.autograd import Variable

def sample(self, batch_size, max_length=140,
           con_token_list=['is_JNK3', 'is_GSK3', 'high_QED', 'good_SA']):
    """Sample a batch of sequences.

    Args:
        batch_size: Number of sequences to sample.
        max_length: Maximum length of the sequences.

    Outputs:
        seqs: (batch_size, seq_length) The sampled sequences.
        log_probs: (batch_size) Log likelihood for each sequence.
        entropy: (batch_size) The entropies for the sequences. Not currently used.
    """
    # Encode the conditional tokens and broadcast them to every sequence in the batch.
    con_token_list = Variable(self.voc.encode(con_token_list))
    con_tokens = Variable(torch.zeros(batch_size, len(con_token_list)).long())  # (batch_size, n_cond)
    for ind, token in enumerate(con_token_list):
        con_tokens[:, ind] = token

    # Every sequence starts with the 'GO' token.
    start_token = Variable(torch.zeros(batch_size, 1).long())  # (batch_size, 1)
    start_token[:] = self.voc.vocab['GO']
    input_vector = start_token  # updated every step; grows together with sequences
    sequences = start_token
    log_probs = Variable(torch.zeros(batch_size))
    finished = torch.zeros(batch_size).byte().to(self.device)

    for step in range(max_length):
        logits = sample_forward_model(self.decodertf, input_vector, con_tokens)  # (batch_size, seq_length, vocab_size)
        logits_step = logits[:, step, :]  # logits of the current time step, (batch_size, vocab_size)
        prob = F.softmax(logits_step, dim=1)
        log_prob = F.log_softmax(logits_step, dim=1)
        input_vector = torch.multinomial(prob, 1)

        # Append the sampled token so the prefix can be fed back in at the next step.
        sequences = torch.cat((sequences, input_vector), 1)  # (batch_size, seq_length) generated sequences
        log_probs += self._nll_loss(log_prob, input_vector.view(-1))  # (batch_size) log likelihood of each generated sequence

        # Track which sequences have emitted EOS and stop once all of them have.
        EOS_sampled = (input_vector.view(-1) == self.voc.vocab['EOS']).data
        finished = torch.ge(finished + EOS_sampled, 1)  # (batch_size) binary flag per sequence
        if torch.prod(finished) == 1:
            break

        # The Transformer has no recurrent hidden state, so the whole generated
        # prefix is passed back as the input at every step.
        input_vector = sequences

    return sequences[:, 1:].data, log_probs
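The loop above relies on a helper, sample_forward_model, that is not shown. Below is a minimal sketch of what such a helper might look like, assuming the condition tokens are simply prepended to the generated prefix and the decoder maps a (batch, seq_len) tensor of token ids to (batch, seq_len, vocab_size) logits; the call signature of the decoder and the slicing of the condition positions are assumptions for illustration, not the original implementation.

import torch

def sample_forward_model(decoder, input_vector, con_tokens):
    """Hypothetical sketch of one conditional decoder forward pass.

    Assumes the decoder takes a (batch, seq_len) LongTensor of token ids and
    returns (batch, seq_len, vocab_size) logits. The condition tokens are
    prepended to the generated prefix, and their output positions are dropped
    so the returned logits align one-to-one with the generated tokens.
    """
    n_cond = con_tokens.size(1)
    dec_input = torch.cat((con_tokens, input_vector), dim=1)  # (batch, n_cond + seq_len)
    logits = decoder(dec_input)                               # (batch, n_cond + seq_len, vocab_size)
    return logits[:, n_cond:, :]                              # keep only the generated positions

Under this reading, logits[:, step, :] in the sampling loop picks out the prediction for the token that follows the current prefix, which is consistent with feeding the full generated sequence back in at every step.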