- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
本周任务:
- 结合 Word2Vec,根据文本内容预测文本标签
加载数据
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os
import PIL
import pathlib
import warnings
import pandas as pd

# Silence library warnings so the printed output stays readable.
warnings.filterwarnings('ignore')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)
# Read the text contents and labels from a local CSV file.
# The file is parsed as tab-separated with no header row, so the columns are
# addressed by position (0 and 1) below.
train_data = pd.read_csv("D:/桌面/365/train.csv", sep='\t', header=None)
train_data.head()


def coustom_data_iter(texts, labels):
    """Yield ``(text, label)`` pairs by walking both sequences in lockstep.

    Iteration ends as soon as the shorter of the two inputs is exhausted.
    """
    yield from zip(texts, labels)


# Column 0 is used as the model input and column 1 as the target label.
x, y = train_data[0].values[:], train_data[1].values[:]
构建词典
from gensim.models.word2vec import Word2Vec
import numpy as np

# Train 100-dimensional vectors; min_count=3 drops tokens seen fewer than 3 times.
# NOTE(review): x appears to hold raw strings, in which case each string is
# iterated character by character as its token sequence — confirm.
w2v = Word2Vec(vector_size=100, min_count=3)
w2v.build_vocab(x)
w2v.train(x, total_examples=w2v.corpus_count, epochs=20)


def average_vec(text):
    """Return a (1, 100) array holding the summed vectors of the tokens in *text*.

    Tokens that are missing from the Word2Vec vocabulary are skipped.

    NOTE(review): despite the name, the vectors are summed and never divided
    by the token count — confirm that this is intended.
    """
    total = np.zeros((1, 100))
    for token in text:
        try:
            total += w2v.wv[token].reshape((1, 100))
        except KeyError:
            # Out-of-vocabulary token: it contributes nothing to the sum.
            continue
    return total


# One feature row per training text, stacked along axis 0.
x_vec = np.concatenate([average_vec(z) for z in x])

w2v.save('./w2c_model.pkl')
生成数据批次和迭代器
# Ordered list of the distinct labels; a label's position in this list is its
# class index. The name was used by label_pipeline, num_class and the final
# prediction printout without ever being defined, which raised NameError.
# Series.unique() keeps first-appearance order, so the indices are the same on
# every run.
label_name = train_data[1].unique().tolist()


def text_pipeline(x):
    """Turn a raw text into its Word2Vec feature array via average_vec."""
    return average_vec(x)


def label_pipeline(x):
    """Map a label value to its integer class index in label_name.

    Raises:
        ValueError: if *x* is not one of the labels in label_name.
    """
    return label_name.index(x)


text_pipeline('你在干嘛')
array([[ 0.78121352, 1.93111382, 0.96291968, 0.39362412, -1.67714586,-0.55152619, 1.7284598 , 0.69204517, 1.1396839 , -0.9755076 ,-0.55864345, -3.68676656, 1.41707338, -0.44626126, 0.2580443 ,1.09325009, 2.28043211, -2.26334408, 3.32311766, -1.24760717,2.2325974 , -0.48408172, -0.55063696, 0.36853465, -1.32127168,-0.53377433, -1.48909409, -0.5050023 , 1.42371842, -0.4252875 ,2.52355766, 0.60818394, -1.68924798, -0.16912293, 1.26915893,-0.4575564 , 0.02507078, 3.33139969, -2.1995108 , 0.44307417,-0.41596803, 1.39861814, -0.58643346, 0.91654699, -0.08089826,0.08773175, 1.51611513, -0.22212304, -3.55333737, 1.93851076,0.42497785, -1.47862379, -0.96684674, 1.20408788, -0.86870126,-1.12228102, 1.67186388, -1.11024326, -0.18936946, 1.0811481 ,1.82965288, -0.78202841, 2.17574303, -1.03871018, -0.51042572,0.40746585, -1.70572275, 1.3409467 , 1.38298857, 1.11757374,-0.8333215 , 0.04856796, 1.43110101, -0.02333559, 0.82732772,-0.9469737 , -4.43783602, -0.20290428, 1.04759257, -1.21757071,-1.30356295, 0.50049417, -1.87846385, 2.47995635, -2.41918275,-1.72291106, 2.65663178, -0.96948189, -1.30033612, -0.37353188,0.53420451, -1.99955091, 0.12223354, 1.74861516, 0.99491888,-1.43117569, 0.063243 , 0.84598846, -2.79536995, 0.02697589]])
from torch.utils.data import DataLoader


def collate_batch(batch):
    """Collate ``(text, label)`` samples into a feature tensor and a label tensor.

    Args:
        batch: iterable of ``(text, label)`` pairs as produced by the dataset.

    Returns:
        ``(features, labels)``, both moved to ``device``. ``features`` is the
        concatenation of each sample's ``text_pipeline`` output converted to
        float32; ``labels`` is an int64 tensor of class indices.
    """
    label_list, text_list = [], []
    for _text, _label in batch:
        label_list.append(label_pipeline(_label))
        text_list.append(torch.tensor(text_pipeline(_text), dtype=torch.float32))

    labels = torch.tensor(label_list, dtype=torch.int64)
    features = torch.cat(text_list)
    return features.to(device), labels.to(device)


# train_iter is only created in the training section further down, so passing
# it here raised NameError; a bare generator is also not indexable, which the
# default sampler requires. Materialise the samples into a list instead.
# NOTE: the misspelled name `datalodaer` is kept so existing references still work.
datalodaer = DataLoader(
    list(coustom_data_iter(train_data[0].values[:], train_data[1].values[:])),
    batch_size=8,
    shuffle=False,
    collate_fn=collate_batch,
)
模型构建
from torch import nn


class TextClassificationModel(nn.Module):
    """Single linear layer that maps a text feature vector to class logits.

    Args:
        num_class: number of output classes.
        embed_dim: length of each input feature vector. Defaults to 100, the
            Word2Vec vector size used in this file, so existing
            ``TextClassificationModel(num_class)`` calls are unaffected.
    """

    def __init__(self, num_class, embed_dim=100):
        super().__init__()
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, text):
        """Return the logits for *text*, a tensor whose last dimension is embed_dim."""
        # Cast to float32 so inputs of another dtype match the layer's weights.
        return self.fc(text.float())
在这组词汇中不匹配的词汇:书
初始化模型
# One output unit per distinct label.
num_class = len(label_name)

# NOTE(review): vocab_size and em_size are not referenced anywhere else in the
# visible code — confirm before removing them.
vocab_size = 100000
em_size = 12

# Build the classifier, then move its parameters to the selected device.
model = TextClassificationModel(num_class)
model = model.to(device)
定义训练及评估函数
import time


def train(dataloader):
    """Run one training pass over *dataloader*, logging every 50 batches.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and the
    current ``epoch`` number. The running accuracy and loss counters are
    reset after each log line, so every line covers only the batches since
    the previous one.

    NOTE(review): the reported loss is the sum of the per-batch loss values
    divided by the number of samples, not by the number of batches.
    """
    model.train()
    log_interval = 50
    correct, loss_sum, seen = 0, 0, 0
    start_time = time.time()

    # Each batch must unpack as (text, label); swapping the two raises an error.
    for idx, (text, label) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text)
        loss = criterion(predicted_label, label)
        loss.backward()
        # Clip the total gradient norm to 0.1 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        correct += (predicted_label.argmax(1) == label).sum().item()
        loss_sum += loss.item()
        seen += label.size(0)

        if idx > 0 and idx % log_interval == 0:
            print('| epoch {:1d} | {:4d}/{:4d} batches'
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(
                      epoch, idx, len(dataloader), correct / seen, loss_sum / seen))
            correct, loss_sum, seen = 0, 0, 0
            start_time = time.time()


def evaluate(dataloader):
    """Return ``(accuracy, loss)`` for *dataloader* with gradients disabled.

    Both values are divided by the number of samples seen, using the same
    module-level ``model`` and ``criterion`` as train().
    """
    model.eval()
    correct, loss_sum, seen = 0, 0, 0
    with torch.no_grad():
        for text, label in dataloader:
            predicted_label = model(text)
            loss_sum += criterion(predicted_label, label).item()
            correct += (predicted_label.argmax(1) == label).sum().item()
            seen += label.size(0)
    return correct / seen, loss_sum / seen
训练模型
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# Hyper-parameters.
EPOCHS = 10
LR = 5
BATCH_SIZE = 64

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Each scheduler.step() multiplies the learning rate by 0.1.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
# Best validation accuracy seen so far; None until the first epoch finishes.
total_accu = None

train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

# 80/20 train/validation split. The validation size is the remainder so that
# the two lengths always sum to len(train_dataset); two independent int()
# truncations can fall short of the total, which random_split rejects with
# ValueError.
num_train = int(len(train_dataset) * 0.8)
num_valid = len(train_dataset) - num_train
split_train_, split_valid_ = random_split(train_dataset, [num_train, num_valid])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    # Read the learning rate before any scheduler step so the log line shows
    # the rate that was used during this epoch.
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    # Decay the learning rate only when validation accuracy drops below the
    # best value so far; otherwise record the new best.
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc

    print('-' * 69)
    print('|epoch {:1d} | time: {:4.2f}s |'
          'valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(
              epoch, time.time() - epoch_start_time, val_acc, val_loss, lr))
    print('-' * 69)
| epoch 1 | 50/ 152 batches| train_acc 0.752 train_loss 0.02433
| epoch 1 | 100/ 152 batches| train_acc 0.836 train_loss 0.01740
| epoch 1 | 150/ 152 batches| train_acc 0.831 train_loss 0.01821
---------------------------------------------------------------------
|epoch 1 | time: 2.83s |valid_acc 0.847 valid_loss 0.016 | lr 5.000000
---------------------------------------------------------------------
| epoch 2 | 50/ 152 batches| train_acc 0.843 train_loss 0.01709
| epoch 2 | 100/ 152 batches| train_acc 0.835 train_loss 0.01863
| epoch 2 | 150/ 152 batches| train_acc 0.854 train_loss 0.01577
---------------------------------------------------------------------
|epoch 2 | time: 1.28s |valid_acc 0.852 valid_loss 0.017 | lr 5.000000
---------------------------------------------------------------------
| epoch 3 | 50/ 152 batches| train_acc 0.854 train_loss 0.01663
| epoch 3 | 100/ 152 batches| train_acc 0.855 train_loss 0.01743
| epoch 3 | 150/ 152 batches| train_acc 0.846 train_loss 0.01738
---------------------------------------------------------------------
|epoch 3 | time: 1.34s |valid_acc 0.862 valid_loss 0.017 | lr 5.000000
---------------------------------------------------------------------
| epoch 4 | 50/ 152 batches| train_acc 0.862 train_loss 0.01514
| epoch 4 | 100/ 152 batches| train_acc 0.854 train_loss 0.01638
| epoch 4 | 150/ 152 batches| train_acc 0.854 train_loss 0.01920
---------------------------------------------------------------------
|epoch 4 | time: 1.18s |valid_acc 0.847 valid_loss 0.018 | lr 5.000000
---------------------------------------------------------------------
| epoch 5 | 50/ 152 batches| train_acc 0.898 train_loss 0.00902
| epoch 5 | 100/ 152 batches| train_acc 0.897 train_loss 0.00885
| epoch 5 | 150/ 152 batches| train_acc 0.900 train_loss 0.00893
---------------------------------------------------------------------
|epoch 5 | time: 1.37s |valid_acc 0.879 valid_loss 0.011 | lr 0.500000
---------------------------------------------------------------------
| epoch 6 | 50/ 152 batches| train_acc 0.900 train_loss 0.00788
| epoch 6 | 100/ 152 batches| train_acc 0.904 train_loss 0.00703
| epoch 6 | 150/ 152 batches| train_acc 0.901 train_loss 0.00681
---------------------------------------------------------------------
|epoch 6 | time: 1.33s |valid_acc 0.883 valid_loss 0.010 | lr 0.500000
---------------------------------------------------------------------
| epoch 7 | 50/ 152 batches| train_acc 0.922 train_loss 0.00573
| epoch 7 | 100/ 152 batches| train_acc 0.901 train_loss 0.00728
| epoch 7 | 150/ 152 batches| train_acc 0.894 train_loss 0.00702
---------------------------------------------------------------------
|epoch 7 | time: 1.12s |valid_acc 0.879 valid_loss 0.009 | lr 0.500000
---------------------------------------------------------------------
| epoch 8 | 50/ 152 batches| train_acc 0.908 train_loss 0.00630
| epoch 8 | 100/ 152 batches| train_acc 0.905 train_loss 0.00593
| epoch 8 | 150/ 152 batches| train_acc 0.911 train_loss 0.00526
---------------------------------------------------------------------
|epoch 8 | time: 1.11s |valid_acc 0.881 valid_loss 0.009 | lr 0.050000
---------------------------------------------------------------------
| epoch 9 | 50/ 152 batches| train_acc 0.911 train_loss 0.00580
| epoch 9 | 100/ 152 batches| train_acc 0.905 train_loss 0.00611
| epoch 9 | 150/ 152 batches| train_acc 0.917 train_loss 0.00516
---------------------------------------------------------------------
|epoch 9 | time: 1.12s |valid_acc 0.881 valid_loss 0.009 | lr 0.005000
---------------------------------------------------------------------
| epoch 10 | 50/ 152 batches| train_acc 0.912 train_loss 0.00564
| epoch 10 | 100/ 152 batches| train_acc 0.905 train_loss 0.00575
| epoch 10 | 150/ 152 batches| train_acc 0.916 train_loss 0.00565
---------------------------------------------------------------------
|epoch 10 | time: 1.12s |valid_acc 0.881 valid_loss 0.009 | lr 0.000500
---------------------------------------------------------------------
测试指定数据
def predict(text, text_pipeline):
    """Return the predicted class index for a single raw *text*.

    Args:
        text: the raw input text.
        text_pipeline: callable that converts *text* into the feature array
            the module-level ``model`` is fed with.

    The shape of the feature tensor is printed before the forward pass.
    """
    with torch.no_grad():
        features = torch.tensor(text_pipeline(text), dtype=torch.float32)
        print(features.shape)
        logits = model(features)
        return logits.argmax(1).item()


ex_text_str = '还有双鸭山到淮阴的汽车票吗13号的'

# predict() creates its input tensor without a device argument, so the model
# is moved to the CPU to match it.
model = model.to('cpu')

predicted_index = predict(ex_text_str, text_pipeline)
print('该文本的类别是: %s' % label_name[predicted_index])
torch.Size([1, 100])
该文本的类别是: Travel-Query
总结
- 本周是结合前几周的内容,使用Word2Vec进行词嵌入之后,再实现中文文本分类
- 本次自己的错误:将 `for idx, (text, label) in enumerate(dataloader):` 中 text 与 label 的顺序写反了,导致输入和模型的输出无法匹配而报错,排查这个问题花费了很多时间