本文学习纪录
PyTorch深度学习项目实战100例
使用LSTM来实现文本匹配任务
使用LSTM(Long Short-Term Memory)网络来实现文本匹配任务是自然语言处理(NLP)中的一个常见应用。文本匹配任务的目标是确定两个文本段落是否在某种程度上相似或相关,例如在问答系统、文档检索、相似问题匹配等场景中非常有用。
句⼦1:我不爱吃剁椒⻥头,但是我爱吃⻥头
句⼦2:我爱吃⼟⾖,但是不爱吃地⽠
模型构建
输入层
:两个独立的输入,分别对应两个文本序列。
LSTM层
:为每个输入文本设计一个LSTM层来捕获序列信息。可以使用双向LSTM(BiLSTM)来获取前后文信息。
相似度计算
:使用余弦相似度、曼哈顿距离、欧式距离等方法计算两个LSTM层的输出向量之间的相似度。
输出层
:根据相似度分数输出匹配程度,可以是二分类(匹配或不匹配)或者回归(相似度得分)。
定义网络
# 定义网络结构
class LSTM(nn.Module):def __init__(self, vocab_size, hidden_dim, num_layers, embedding_dim, output_dim):super(LSTM, self).__init__()self.hidden_dim = hidden_dim # 隐层大小self.num_layers = num_layers # LSTM层数# 嵌入层,会对所有词形成一个连续型嵌入向量,该向量的维度为embedding_dim# 然后利用这个向量来表示该字,而不是用索引继续表示self.embeddings_x = nn.Embedding(vocab_size + 1, embedding_dim)self.embeddings_y = nn.Embedding(vocab_size + 1, embedding_dim)# 定义LSTM层,第一个参数为每个时间步的特征大小,这里就是每个字的维度# 第二个参数为隐层大小# 第三个参数为lstm的层数self.lstm_x = nn.LSTM(embedding_dim, hidden_dim, num_layers)self.lstm_y = nn.LSTM(embedding_dim, hidden_dim, num_layers)self.cos_sim = nn.CosineSimilarity(dim=1, eps=1e-6)# 利用全连接层将其映射为2维,即0和1的概率self.fc = nn.Linear(1, output_dim)def forward(self, x_input, y_input):# 1.首先形成嵌入向量embeds_x = self.embeddings_x(x_input)embeds_y = self.embeddings_y(x_input)# 2.将嵌入向量导入到lstm层output_x, _ = self.lstm_x(embeds_x)output_y, _ = self.lstm_x(embeds_y)timestep, batch_size, hidden_dim = output_x.shapeoutput_x = output_x.reshape(timestep, batch_size, -1)output_y = output_y.reshape(timestep, batch_size, -1)# 3.获取lstm最后一个隐层表示向量output_x = output_x[-1]output_y = output_y[-1]# 4.计算两个向量的余弦相似度sim = self.cos_sim(output_x, output_y)sim = sim.view(-1, 1)# 5.形成最终输出结果output = self.fc(sim)return output
模型训练
# 6.模型训练
model = LSTM(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_layers,embedding_dim=embedding_dim, output_dim=output_dim)Configimizer = optim.Adam(model.parameters(), lr=lr) # 优化器
criterion = nn.CrossEntropyLoss() # 多分类损失函数model.to(device)
loss_meter = meter.AverageValueMeter()best_acc = 0 # 保存最好准确率
best_model = None # 保存对应最好准确率的模型参数for epoch in range(epochs):model.train() # 开启训练模式epoch_acc = 0 # 每个epoch的准确率epoch_acc_count = 0 # 每个epoch训练的样本数train_count = 0 # 用于计算总的样本数,方便求准确率loss_meter.reset()train_bar = tqdm(train_loader) # 形成进度条for data in train_bar:x_input, y_input, label = data # 解包迭代器中的X和Yx_input = x_input.long().transpose(1, 0).contiguous()x_input = x_input.to(device)y_input = y_input.long().transpose(1, 0).contiguous()y_input = y_input.to(device)Configimizer.zero_grad()# 形成预测结果output_ = model(x_input, y_input)# 计算损失loss = criterion(output_, label.long().view(-1))loss.backward()Configimizer.step()loss_meter.add(loss.item())# 计算每个epoch正确的个数epoch_acc_count += (output_.argmax(axis=1) == label.view(-1)).sum()train_count += len(x_input)# 每个epoch对应的准确率epoch_acc = epoch_acc_count / train_count# 打印信息print("【EPOCH: 】%s" % str(epoch + 1))print("训练损失为%s" % (str(loss_meter.mean)))print("训练精度为%s" % (str(epoch_acc.item() * 100)[:5]) + '%')# 保存模型及相关信息if epoch_acc > best_acc:best_acc = epoch_accbest_model = model.state_dict()# 在训练结束保存最优的模型参数if epoch == epochs - 1:# 保存模型torch.save(best_model, './best_model.pkl')
测试语句
try:# 数据预处理input_shape = 20 # 序列长度,就是时间步大小,也就是这里的每句话中的词的个数# 用于测试的话sentence1 = "我不爱吃剁椒鱼头,但是我爱吃鱼头"sentence2 = "我爱吃土豆,但是不爱吃地瓜"# 将对应的字转化为相应的序号x_input = [[word2idx[word] for word in sentence1]]x_input = pad_sequences(maxlen=input_shape, sequences=x_input, padding='post', value=0)x_input = torch.from_numpy(x_input)y_input = [[word2idx[word] for word in sentence2]]y_input = pad_sequences(maxlen=input_shape, sequences=y_input, padding='post', value=0)y_input = torch.from_numpy(y_input)# 加载模型model_path = './best_model.pkl'model = LSTM(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_layers,embedding_dim=embedding_dim, output_dim=output_dim)model.load_state_dict(torch.load(model_path, 'cpu'))# 模型预测,注意输入的数据第一个input_shapey_pred = model(x_input.long().transpose(1, 0), y_input.long().transpose(1, 0))idx2label = {0:"匹配失败!", 1:"匹配成功!"}print('输入语句: %s \t %s' % (sentence1, sentence2))print('文本匹配结果: %s' % idx2label[y_pred.argmax().item()])except KeyError as err:print("您输入的句子有汉字不在词汇表中,请重新输入!")print("不在词汇表中的单词为:%s." % err)
数据集为QA_corpus,训练数据10w条,验证集和测试集均为1w条
其中对应模型文件夹下的args.py文件是超参数
QA_corpus
数据集展示