# 接上一篇 (continued from the previous post): https://blog.csdn.net/m0_60688978/article/details/139046644
# Optional loader for a larger Q&A corpus (currently disabled).
# Each line of ./data/flink.txt is expected to look like "question----answer";
# the function returns (answers, questions).
# def getAQ():
#     ask = []
#     answer = []
#     with open("./data/flink.txt", "r", encoding="utf-8") as f:
#         lines = f.readlines()
#     for line in lines:
#         ask.append(line.split("----")[0])
#         answer.append(line.split("----")[1].replace("\n", ""))
#     return answer, ask
#
# seq_answer, seq_example = getAQ()
import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import os
from tqdm import tqdm

# Toy corpus: seq_example[k] is a question and seq_answer[k] is its reply.
seq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "谁是张学友"]
seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "她旁边那个就是"]
# Word segmentation: split every question and every answer into jieba tokens.
example_cut = [list(jieba.cut(sentence)) for sentence in seq_example]
answer_cut = [list(jieba.cut(sentence)) for sentence in seq_answer]
# All distinct tokens, in order of first appearance (questions first, then answers).
# dict.fromkeys keeps insertion order while dropping duplicates.
word_all = list(dict.fromkeys(
    token for sentence in example_cut + answer_cut for token in sentence
))
# Vocabulary lookup tables. Real tokens are numbered from 3 because indices
# 0, 1 and 2 are reserved for the special tokens added right after.
word2index = {token: idx for idx, token in enumerate(word_all, start=3)}
# PAD = padding, SOS = start of sentence, EOS = end of sentence.
word2index.update(PAD=0, SOS=1, EOS=2)
index2word = {idx: token for token, idx in word2index.items()}
# Model / data hyper-parameters.
vocab_size = len(word2index)
# +1 leaves room for the 'SOS' / 'EOS' token that is prepended / appended to answers.
seq_length = max([len(i) for i in example_cut + answer_cut]) + 1
print("vocab_size is",vocab_size,", seq_length is ",seq_length)
embedding_size = 128
num_classes = vocab_size
hidden_size = 256
# Derived from the corpus instead of hard-coded (previously 6 and 7), so they
# stay correct when the corpus changes: the whole corpus is one training batch,
# and every padded sequence has length seq_length.
batch_size = len(example_cut)
seq_len = seq_length
# Represent sentences as index sequences.
def make_data(seq_list, vocab=None, max_len=None):
    """Convert tokenized sentences into index lists padded to a fixed length.

    Args:
        seq_list: iterable of token lists (e.g. jieba output).
        vocab: token -> index mapping; defaults to the module-level ``word2index``.
        max_len: target length; defaults to the module-level ``seq_length``.

    Returns:
        A list with one index list per sentence. Sentences shorter than
        ``max_len`` are right-padded with ``vocab['PAD']`` (0 if the vocab has
        no 'PAD' entry); longer sentences are returned unchanged, not truncated.

    Raises:
        KeyError: if a token is not present in ``vocab``.
    """
    if vocab is None:
        vocab = word2index
    if max_len is None:
        max_len = seq_length
    pad_index = vocab.get('PAD', 0)
    result = []
    for tokens in seq_list:
        indices = [vocab[token] for token in tokens]
        # A non-positive repeat count yields [], so over-long rows are left as-is.
        indices += [pad_index] * (max_len - len(indices))
        result.append(indices)
    return result
# Encoder reads the questions. The decoder is teacher-forced with 'SOS' + answer
# and trained to predict answer + 'EOS' (the same tokens shifted by one step).
encoder_input = make_data(example_cut)
decoder_input = make_data([['SOS', *tokens] for tokens in answer_cut])
decoder_target = make_data([[*tokens, 'EOS'] for tokens in answer_cut])
# Training data as int64 tensors, one row per example.
encoder_input = torch.tensor(encoder_input, dtype=torch.long)
decoder_input = torch.tensor(decoder_input, dtype=torch.long)
decoder_target = torch.tensor(decoder_target, dtype=torch.long)


class encoder(nn.Module):
    """Embedding followed by a single-layer LSTM."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size)

    def forward(self, inputx):
        """Encode a [batch_size, seq_len] tensor of token indices.

        Returns:
            output: [seq_len, batch_size, hidden_size], the LSTM output at every step.
            (h_n, c_n): each [num_layers, batch_size, hidden_size], the final states.
        """
        # nn.LSTM is built without batch_first, so put the time axis first.
        time_major = self.embedding(inputx).transpose(0, 1)
        output, (h_n, c_n) = self.lstm(time_major)
        return output, (h_n, c_n)


# Quick shape check with an untrained encoder.
cc1 = encoder()
out, (hn, cn) = cc1(encoder_input)
# Notebook-style inspection; this expression has no effect when run as a script.
# out: [seq_len, batch_size, hidden_size], hn: [num_layers, batch_size, hidden_size]
out.size(), hn.size(), out, hn
# concat
# class Attention(nn.Module):
# def __init__(self):
# super(Attention, self).__init__()
# self.wa=nn.Linear(hidden_size*2,hidden_size*2,bias=False)# self.wa1=nn.Linear(hidden_size*2,hidden_size,bias=False)# def forward(self, hidden, encoder_outputs):
# """
# hidden:[layer_num,batch_size,hidden_size]
# encoder_outputs:[seq_len,batch_size,hidden_size]
# """# hiddenchange=hidden.repeat(encoder_outputs.size(0),1,1)#[seq_len,batch_size,hidden_size]# concated=torch.cat([hiddenchange.permute(1,0,2),encoder_outputs.permute(1,0,2)],dim=-1)# [batch_size,seq_len,hidden_size*2]
# waed=self.wa(concated)# [batch_size,seq_len,hidden_size*2]
# tanhed=torch.tanh(waed)# [batch_size,seq_len,hidden_size*2]# self.va=nn.Parameter(torch.FloatTensor(encoder_outputs.size(1),hidden_size*2))#[batch_size,hidden_size*2]
# # print("tanhed size",tanhed.size(),self.va.unsqueeze(2).size())# attr=tanhed.bmm(self.va.unsqueeze(2))# [batch_size,seq_len,1]# context=attr.permute(0,2,1).bmm(encoder_outputs.permute(1,0,2))# [batch_size,1,seq_len]# return context,attr# general
class Attention(nn.Module):
    """Luong-style "general" attention: score(q, e) = e . (W_a q).

    Args:
        hidden_dim: size of the query and encoder vectors. Defaults to the
            module-level ``hidden_size``, so ``Attention()`` behaves as before.
    """

    def __init__(self, hidden_dim=None):
        super(Attention, self).__init__()
        if hidden_dim is None:
            hidden_dim = hidden_size
        self.va = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Attend over ``encoder_outputs`` once per query vector in ``hidden``.

        Args:
            hidden: [num_queries, batch_size, hidden_dim], for example h_n
                (num_queries = layer_num) or a decoder LSTM output
                (num_queries = target length).
            encoder_outputs: [seq_len, batch_size, hidden_dim].

        Returns:
            context: [batch_size, num_queries, hidden_dim], the weighted sums of encoder outputs.
            attr: [batch_size, seq_len, num_queries], weights summing to 1 over seq_len.
        """
        keys = encoder_outputs.permute(1, 0, 2)  # [batch_size, seq_len, hidden_dim]
        queries = self.va(hidden).permute(1, 2, 0)  # [batch_size, hidden_dim, num_queries]
        score = keys.bmm(queries)  # [batch_size, seq_len, num_queries]
        # Normalize over source positions (dim=1), independently for each query.
        attr = nn.functional.softmax(score, dim=1)
        context = attr.permute(0, 2, 1).bmm(keys)  # [batch_size, num_queries, hidden_dim]
        return context, attr


# # dot
# class Attention(nn.Module):
# def __init__(self):
# super(Attention, self).__init__()# def forward(self, hidden, encoder_outputs):
# """
# hidden:[layer_num,batch_size,hidden_size]
# encoder_outputs:[seq_len,batch_size,hidden_size]
# """
# score=encoder_outputs.permute(1,0,2).bmm(hidden.permute(1,2,0))# [batch_size,seq_len,layer_num]
#         attr=nn.functional.softmax(score,dim=-1)# [batch_size,seq_len,layer_num]
#         context=attr.permute(0,2,1).bmm(encoder_outputs.permute(1,0,2)) #[batch_size,layer_num,hidden_size]
#         return context,attr


class decoder(nn.Module):
    """LSTM decoder with attention over the encoder outputs.

    forward(input, encoder_outputs, hn, cn):
        input: [batch_size, tgt_len] token indices (tgt_len is 1 during step-wise decoding).
        encoder_outputs: [src_len, batch_size, hidden_size] from ``encoder``.
        hn, cn: [num_layers, batch_size, hidden_size] initial LSTM states.
    Returns (logits, lstm_hn, lstm_cn), with logits of shape [batch_size, tgt_len, vocab_size].
    """

    def __init__(self):
        super(decoder, self).__init__()
        self.embeded = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size)
        self.attr = Attention()
        self.fc = nn.Linear(hidden_size * 2, vocab_size, bias=False)
        self.tan = nn.Linear(hidden_size * 2, hidden_size * 2, bias=False)

    def forward(self, input, encoder_outputs, hn, cn):
        decoder_input_embedded = self.embeded(input)  # [batch_size, tgt_len, embedding_size]
        # The LSTM is time-major (no batch_first).
        lstm_output, (lstm_hn, lstm_cn) = self.lstm(
            decoder_input_embedded.permute(1, 0, 2), (hn, cn)
        )  # lstm_output: [tgt_len, batch_size, hidden_size]
        # Query attention with every step's own decoder output. Attending once with the
        # incoming hn and repeating that context over all steps made teacher-forced
        # training (encoder state for every step) disagree with step-wise inference
        # (previous decoder state). This form gives identical per-step results for
        # tgt_len = 1 and tgt_len = seq_length.
        context, _ = self.attr(lstm_output, encoder_outputs)  # [batch_size, tgt_len, hidden_size]
        concated = torch.cat([lstm_output.permute(1, 0, 2), context], dim=-1)  # [batch_size, tgt_len, hidden_size*2]
        concat_output = torch.tanh(self.tan(concated))  # [batch_size, tgt_len, hidden_size*2]
        # Raw logits: nn.CrossEntropyLoss applies log_softmax itself. (Earlier versions
        # applied softmax / log_softmax here.)
        return self.fc(concat_output), lstm_hn, lstm_cn


class seq2seq(nn.Module):
    """Encoder-decoder wrapper supporting teacher-forced training and greedy decoding."""

    def __init__(self):
        super(seq2seq, self).__init__()
        # Not used by forward(); removing it would change the state_dict keys.
        self.word_vec = nn.Linear(vocab_size, embedding_size)
        self.encoder = encoder()
        self.decoder = decoder()

    def forward(self, encoder_input, decoder_input, istrain=1):
        """Run the model.

        If ``istrain`` is truthy: teacher forcing over the whole ``decoder_input``;
        returns (logits [batch_size, seq_len, vocab_size], hn, cn).
        Otherwise: greedy decoding starting from ``decoder_input`` (e.g. [[SOS]]);
        returns a list of generated token indices, excluding the stop token.
        Decoding expects a batch of one.
        """
        if istrain:
            encoder_outputs, (hn, cn) = self.encoder(encoder_input)
            return self.decoder(decoder_input, encoder_outputs, hn, cn)
        # Inference needs no autograd graph.
        with torch.no_grad():
            return self._greedy_decode(encoder_input, decoder_input)

    def _greedy_decode(self, encoder_input, decoder_input):
        """Feed back the most likely token for up to seq_length steps; stop at PAD/EOS."""
        encoder_outputs, (hn, cn) = self.encoder(encoder_input)
        stop_ids = (word2index['PAD'], word2index['EOS'])
        finaloutputs = []
        for _ in range(seq_length):
            outputs, hn, cn = self.decoder(decoder_input, encoder_outputs, hn, cn)
            gailv = nn.functional.softmax(outputs[0], dim=1)  # [1, vocab_size]
            out = torch.topk(gailv, 1)[1].item()
            if out in stop_ids:
                break
            finaloutputs.append(out)
            decoder_input = torch.tensor([[out]], dtype=torch.long, device=decoder_input.device)
        return finaloutputs


model = seq2seq()
print(model)
# decoder_target is right-padded with 'PAD' (index 0 via make_data). Ignoring those
# positions keeps the loss and gradients focused on real answer tokens and 'EOS',
# instead of rewarding the model for predicting padding.
criterion = nn.CrossEntropyLoss(ignore_index=word2index['PAD'])
optimizer = optim.SGD(model.parameters(), lr=0.05)
model.train()
# Full-batch training: each step feeds the whole corpus through the model with teacher forcing.
for epoch in tqdm(range(3000)):
    pred, lstm_hn, lstm_cn = model(encoder_input, decoder_input, 1)
    # Flatten [batch_size, seq_len, vocab_size] logits and [batch_size, seq_len]
    # targets so that every time step is scored as one classification.
    flat_logits = pred.reshape(-1, vocab_size)
    flat_targets = decoder_target.view(-1)
    loss = criterion(flat_logits, flat_targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    step = epoch + 1
    if step % 100 == 0:
        print("Epoch: %d, loss: %.5f " % (step, loss.item()))
# Save the trained weights, then switch to evaluation mode for inference.
torch.save(model.state_dict(), './seq2seqModel.pkl')
model.eval()

question_text = '谁是张学友'
question_cut = list(jieba.cut(question_text))
# Encoder input: the padded index sequence of the question (a batch of one).
encoder_x = torch.LongTensor(make_data([question_cut]))
# Decoder input: only the start-of-sentence token; the model decodes from there.
decoder_x = torch.LongTensor([[word2index['SOS']]])
out = model(encoder_x, decoder_x, 0)
answer = ''.join(index2word[idx] for idx in out)
print('问题:', question_text)
print('回答:', answer)