基于多种CNN模型在清华新闻语料分类效果上的对比

该实验项目目录如图: 

1、 模型

1.1. TextCNN

# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Config(object):"""配置参数"""def __init__(self, dataset, embedding):self.model_name = 'TextCNN'self.train_path = dataset + '/data/train.txt'                                # 训练集self.dev_path = dataset + '/data/dev.txt'                                    # 验证集self.test_path = dataset + '/data/test.txt'                                  # 测试集self.predict_path = dataset + '/data/predict.txt'# self.class_list = [x.strip() for x in open(dataset + '/data/class.txt', encoding='utf-8').readlines()]              # 类别名单self.class_list = ['财经', '房产', '股票', '教育', '科技', '社会', '时政', '体育', '游戏','娱乐']self.vocab_path = dataset + '/data/vocab.pkl'                                # 词表self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'        # 模型训练结果self.log_path = dataset + '/log/' + self.model_nameself.embedding_pretrained = torch.tensor(np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\if embedding != 'random' else None                                       # 预训练词向量self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # 设备self.dropout = 0.5                                              # 随机失活self.require_improvement = 1000                                 # 若超过1000batch效果还没提升,则提前结束训练self.num_classes = len(self.class_list)                         # 类别数self.n_vocab = 0                                                # 词表大小,在运行时赋值self.num_epochs = 20                                             # epoch数self.batch_size = 128                                           # mini-batch大小self.pad_size = 32                                              # 每句话处理成的长度(短填长切)self.learning_rate = 1e-3                                       # 学习率self.embed = self.embedding_pretrained.size(1)\if self.embedding_pretrained is not None else 300           # 字向量维度self.filter_sizes = (2, 3, 4)                                   # 卷积核尺寸self.num_filters = 256                                          # 卷积核数量(channels数)'''Convolutional Neural Networks for Sentence Classification'''class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.convs = nn.ModuleList([nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes])self.dropout = nn.Dropout(config.dropout)self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes)def conv_and_pool(self, x, conv):x = F.relu(conv(x)).squeeze(3)x = F.max_pool1d(x, x.size(2)).squeeze(2)return xdef forward(self, x):out = self.embedding(x[0])out = out.unsqueeze(1)out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)out = self.dropout(out)out = self.fc(out)return out

1.2. TextRCNN

# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Config(object):"""配置参数"""def __init__(self, dataset, embedding):self.model_name = 'TextRCNN'self.train_path = dataset + '/data/train.txt'                                # 训练集self.dev_path = dataset + '/data/dev.txt'                                    # 验证集self.test_path = dataset + '/data/test.txt'                                  # 测试集self.class_list = [x.strip() for x in open(dataset + '/data/class.txt', encoding='utf-8').readlines()]              # 类别名单self.vocab_path = dataset + '/data/vocab.pkl'                                # 词表self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'        # 模型训练结果self.log_path = dataset + '/log/' + self.model_nameself.embedding_pretrained = torch.tensor(np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\if embedding != 'random' else None                                       # 预训练词向量self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # 设备self.dropout = 1.02                                              # 随机失活self.require_improvement = 1000                                 # 若超过1000batch效果还没提升,则提前结束训练self.num_classes = len(self.class_list)                         # 类别数self.n_vocab = 0                                                # 词表大小,在运行时赋值self.num_epochs = 20                                            # epoch数self.batch_size = 128                                           # mini-batch大小self.pad_size = 32                                              # 每句话处理成的长度(短填长切)self.learning_rate = 1e-3                                       # 学习率self.embed = self.embedding_pretrained.size(1)\if self.embedding_pretrained is not None else 300           # 字向量维度, 若使用了预训练词向量,则维度统一self.hidden_size = 256                                          # lstm隐藏层self.num_layers = 1                                             # lstm层数'''Recurrent Convolutional Neural Networks for Text Classification'''class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,bidirectional=True, batch_first=True, dropout=config.dropout)self.maxpool = nn.MaxPool1d(config.pad_size)self.fc = nn.Linear(config.hidden_size * 2 + config.embed, config.num_classes)def forward(self, x):x, _ = xembed = self.embedding(x)  # [batch_size, seq_len, embeding]=[64, 32, 64]out, _ = self.lstm(embed)out = torch.cat((embed, out), 2)out = F.relu(out)out = out.permute(0, 2, 1)out = self.maxpool(out).squeeze()out = self.fc(out)return out

1.3. TextRNN

# coding: UTF-8
import torch
import torch.nn as nn
import numpy as npclass Config(object):"""配置参数"""def __init__(self, dataset, embedding):self.model_name = 'TextRNN'self.train_path = dataset + '/data/train.txt'                                # 训练集self.dev_path = dataset + '/data/dev.txt'                                    # 验证集self.test_path = dataset + '/data/test.txt'                                  # 测试集# self.class_list = [x.strip() for x in open(dataset + '/data/class.txt', encoding='utf-8').readlines()]self.class_list = ['体育', '军事', '娱乐', '政治', '教育', '灾难', '社会', '科技', '财经', '违法']self.vocab_path = dataset + '/data/vocab.pkl'                                # 词表self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'        # 模型训练结果self.log_path = dataset + '/log/' + self.model_nameself.embedding_pretrained = torch.tensor(np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\if embedding != 'random' else None                                       # 预训练词向量self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # 设备self.dropout = 0.5                                              # 随机失活self.require_improvement = 1000                                 # 若超过1000batch效果还没提升,则提前结束训练self.num_classes = len(self.class_list)                         # 类别数self.n_vocab = 0                                                # 词表大小,在运行时赋值self.num_epochs = 20                                            # epoch数self.batch_size = 128                                           # mini-batch大小self.pad_size = 32                                              # 每句话处理成的长度(短填长切)self.learning_rate = 1e-3                                       # 学习率self.embed = self.embedding_pretrained.size(1)\if self.embedding_pretrained is not None else 300           # 字向量维度, 若使用了预训练词向量,则维度统一self.hidden_size = 128                                          # lstm隐藏层self.num_layers = 2                                             # lstm层数'''Recurrent Neural Network for Text Classification with Multi-Task Learning'''class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,bidirectional=True, batch_first=True, dropout=config.dropout)self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)def forward(self, x):x, _ = xout = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]out, _ = self.lstm(out)out = self.fc(out[:, -1, :])  # 句子最后时刻的 hidden statereturn out'''变长RNN,效果差不多,甚至还低了点...'''# def forward(self, x):#     x, seq_len = x#     out = self.embedding(x)#     _, idx_sort = torch.sort(seq_len, dim=0, descending=True)  # 长度从长到短排序(index)#     _, idx_unsort = torch.sort(idx_sort)  # 排序后,原序列的 index#     out = torch.index_select(out, 0, idx_sort)#     seq_len = list(seq_len[idx_sort])#     out = nn.utils.rnn.pack_padded_sequence(out, seq_len, batch_first=True)#     # [batche_size, seq_len, num_directions * hidden_size]#     out, (hn, _) = self.lstm(out)#     out = torch.cat((hn[2], hn[3]), -1)#     # out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)#     out = out.index_select(0, idx_unsort)#     out = self.fc(out)#     return out

2、训练

 2.1. utils.py代码编写

# coding: UTF-8
import os
import torch
import numpy as np
import pickle as pkl
from tqdm import tqdm
import time
from datetime import timedeltaMAX_VOCAB_SIZE = 10000  # 词表长度限制
UNK, PAD = '<UNK>', '<PAD>'  # 未知字,padding符号def build_vocab(file_path, tokenizer, max_size, min_freq):vocab_dic = {}with open(file_path, 'r', encoding='UTF-8') as f:for line in tqdm(f):lin = line.strip()if not lin:continuecontent = lin.split('\t')[0]for word in tokenizer(content):vocab_dic[word] = vocab_dic.get(word, 0) + 1vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1],reverse=True)[:max_size]vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})return vocab_dicdef build_dataset(config, ues_word):if ues_word:tokenizer = lambda x: x.split(' ')  # 以空格隔开,word-levelelse:tokenizer = lambda x: [y for y in x]  # char-levelif os.path.exists(config.vocab_path):vocab = pkl.load(open(config.vocab_path, 'rb'))else:vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)pkl.dump(vocab, open(config.vocab_path, 'wb'))print(f"Vocab size: {len(vocab)}")def load_dataset(path, pad_size=32):contents = []with open(path, 'r', encoding='UTF-8') as f:for line in tqdm(f):lin = line.strip()if not lin:continuecontent, label = lin.split('\t')words_line = []token = tokenizer(content)seq_len = len(token)if pad_size:if len(token) < pad_size:token.extend([PAD] * (pad_size - len(token)))else:token = token[:pad_size]seq_len = pad_size# word to idfor word in token:words_line.append(vocab.get(word, vocab.get(UNK)))contents.append((words_line, int(label), seq_len))return contents  # [([...], 0), ([...], 1), ...]train = load_dataset(config.train_path, config.pad_size)dev = load_dataset(config.dev_path, config.pad_size)test = load_dataset(config.test_path, config.pad_size)return vocab, train, dev, test, # predictclass DatasetIterater(object):def __init__(self, batches, batch_size, device):self.batch_size = batch_sizeself.batches = batchesself.n_batches = len(batches) // batch_sizeself.residue = False  # 记录batch数量是否为整数if len(batches) % self.n_batches != 0:self.residue = Trueself.index = 0self.device = devicedef _to_tensor(self, datas):x = torch.LongTensor([_[0] for _ in datas]).to(self.device)y = torch.LongTensor([_[1] for _ in datas]).to(self.device)# pad前的长度(超过pad_size的设为pad_size)seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)return (x, seq_len), ydef __next__(self):if self.residue and self.index == self.n_batches:batches = self.batches[self.index * self.batch_size: len(self.batches)]self.index += 1batches = self._to_tensor(batches)return batcheselif self.index >= self.n_batches:self.index = 0raise StopIterationelse:batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]self.index += 1batches = self._to_tensor(batches)return batchesdef __iter__(self):return selfdef __len__(self):if self.residue:return self.n_batches + 1else:return self.n_batchesdef build_iterator(dataset, config, predict):if predict==True:config.batch_size = 1iter = DatasetIterater(dataset, config.batch_size, config.device)return iterdef get_time_dif(start_time):"""获取已使用时间"""end_time = time.time()time_dif = end_time - start_timereturn timedelta(seconds=int(round(time_dif)))if __name__ == "__main__":'''提取预训练词向量'''# 下面的目录、文件名按需更改。train_dir = "./THUCNews/data/train.txt"vocab_dir = "./THUCNews/data/vocab.pkl"pretrain_dir = "./THUCNews/data/sgns.sogou.char"emb_dim = 300filename_trimmed_dir = "./THUCNews/data/embedding_SougouNews"if os.path.exists(vocab_dir):word_to_id = pkl.load(open(vocab_dir, 'rb'))else:# tokenizer = lambda x: x.split(' ')  # 以词为单位构建词表(数据集中词之间以空格隔开)tokenizer = lambda x: [y for y in x]  # 以字为单位构建词表word_to_id = build_vocab(train_dir, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)pkl.dump(word_to_id, open(vocab_dir, 'wb'))embeddings = np.random.rand(len(word_to_id), emb_dim)f = open(pretrain_dir, "r", encoding='UTF-8')for i, line in enumerate(f.readlines()):# if i == 0:  # 若第一行是标题,则跳过#     continuelin = line.strip().split(" ")if lin[0] in word_to_id:idx = word_to_id[lin[0]]emb = [float(x) for x in lin[1:301]]embeddings[idx] = np.asarray(emb, dtype='float32')f.close()np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)

 2.2. train_eval.py代码编写

# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from torch.utils.tensorboard import SummaryWriter
from utils import get_time_dif# 权重初始化,默认xavier
def init_network(model, method='xavier', exclude='embedding', seed=123):for name, w in model.named_parameters():if exclude not in name:if 'weight' in name:if method == 'xavier':nn.init.xavier_normal_(w)elif method == 'kaiming':nn.init.kaiming_normal_(w)else:nn.init.normal_(w)elif 'bias' in name:nn.init.constant_(w, 0)else:passdef train(config, model, train_iter, dev_iter, test_iter):start_time = time.time()model.train()optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)# 学习率指数衰减,每次epoch:学习率 = gamma * 学习率# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)total_batch = 0  # 记录进行到多少batchdev_best_loss = float('inf')last_improve = 0  # 记录上次验证集loss下降的batch数flag = False  # 记录是否很久没有效果提升writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))for epoch in range(config.num_epochs):print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))# scheduler.step() # 学习率衰减for i, (trains, labels) in enumerate(train_iter):outputs = model(trains)model.zero_grad()loss = F.cross_entropy(outputs, labels)loss.backward()optimizer.step()if total_batch % 100 == 0:# 每多少轮输出在训练集和验证集上的效果true = labels.data.cpu()predic = torch.max(outputs.data, 1)[1].cpu()train_acc = metrics.accuracy_score(true, predic)dev_acc, dev_loss = evaluate(config, model, dev_iter)if dev_loss < dev_best_loss:dev_best_loss = dev_losstorch.save(model.state_dict(), config.save_path)improve = '*'last_improve = total_batchelse:improve = ''time_dif = get_time_dif(start_time)msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  ' \'Val Acc: {4:>6.2%},  Time: {5} {6}'print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))writer.add_scalar("loss/train", loss.item(), total_batch)writer.add_scalar("loss/dev", dev_loss, total_batch)writer.add_scalar("acc/train", train_acc, total_batch)writer.add_scalar("acc/dev", dev_acc, total_batch)model.train()total_batch += 1if total_batch - last_improve > config.require_improvement:# 验证集loss超过1000batch没下降,结束训练print("No optimization for a long time, auto-stopping...")flag = Truebreakif flag:breakwriter.close()test(config, model, test_iter)def test(config, model, test_iter):# testmodel.load_state_dict(torch.load(config.save_path))model.eval()start_time = time.time()test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'print(msg.format(test_loss, test_acc))print("Precision, Recall and F1-Score...")print(test_report)print("Confusion Matrix...")print(test_confusion)time_dif = get_time_dif(start_time)print("Time usage:", time_dif)def evaluate(config, model, data_iter, test=False):model.eval()loss_total = 0predict_all = np.array([], dtype=int)labels_all = np.array([], dtype=int)with torch.no_grad():for texts, labels in data_iter:outputs = model(texts)loss = F.cross_entropy(outputs, labels)loss_total += losslabels = labels.data.cpu().numpy()predic = torch.max(outputs.data, 1)[1].cpu().numpy()labels_all = np.append(labels_all, labels)predict_all = np.append(predict_all, predic)acc = metrics.accuracy_score(labels_all, predict_all)if test:report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)confusion = metrics.confusion_matrix(labels_all, predict_all)return acc, loss_total / len(data_iter), report, confusionreturn acc, loss_total / len(data_iter)

2.3. text_mixture_predict.py代码编写

# coding:utf-8import torch
import numpy as np
import pickle as pkl
from importlib import import_module
from utils import build_iterator
import argparseparser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, TextRCNN')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()MAX_VOCAB_SIZE = 10000  # 词表长度限制
tokenizer = lambda x: [y for y in x]  # char-level
UNK, PAD = '<UNK>', '<PAD>'  # 未知字,padding符号def load_dataset(content, vocab, pad_size=32):contents = []for line in content:lin = line.strip()if not lin:continue# content, label = lin.split('\t')words_line = []token = tokenizer(line)seq_len = len(token)if pad_size:if len(token) < pad_size:token.extend([PAD] * (pad_size - len(token)))else:token = token[:pad_size]seq_len = pad_size# word to idfor word in token:words_line.append(vocab.get(word, vocab.get(UNK)))contents.append((words_line, int(0), seq_len))return contents  # [([...], 0), ([...], 1), ...]def match_label(pred, config):label_list = config.class_listreturn label_list[pred]def final_predict(config, model, data_iter):map_location = lambda storage, loc: storagemodel.load_state_dict(torch.load(config.save_path, map_location=map_location))model.eval()predict_all = np.array([])with torch.no_grad():for texts, _ in data_iter:outputs = model(texts)pred = torch.max(outputs.data, 1)[1].cpu().numpy()pred_label = [match_label(i, config) for i in pred]predict_all = np.append(predict_all, pred_label)return predict_alldef main(text):dataset = 'THUCNews'  # 数据集# 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:randomembedding = 'embedding_SougouNews.npz'if args.embedding == 'random':embedding = 'random'model_name = args.model  # 'TextRCNN'  # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformerx = import_module('models.' + model_name)config = x.Config(dataset, embedding)vocab = pkl.load(open(config.vocab_path, 'rb'))content = load_dataset(text, vocab, 64)predict = Truepredict_iter = build_iterator(content, config, predict)config.n_vocab = len(vocab)model = x.Model(config).to(config.device)result = final_predict(config, model, predict_iter)for i, j in enumerate(result):print('text:{}'.format(text[i]),'\t','label:{}'.format(j))if __name__ == '__main__':test = ['国考28日网上查报名序号查询后务必牢记报名参加2011年国家公务员的考生,如果您已通过资格审查,那么请于10月28日8:00后,登录考录专题网站查询自己的“关键数字”——报名序号。''国家公务员局等部门提醒:报名序号是报考人员报名确认和下载打印准考证等事项的重要依据和关键字,请务必牢记。此外,由于年龄在35周岁以上、40周岁以下的应届毕业硕士研究生和''博士研究生(非在职),不通过网络进行报名,所以,这类人报名须直接与要报考的招录机关联系,通过电话传真或发送电子邮件等方式报名。','高品质低价格东芝L315双核本3999元作者:徐彬【北京行情】2月20日东芝SatelliteL300(参数图片文章评论)采用14.1英寸WXGA宽屏幕设计,配备了IntelPentiumDual-CoreT2390''双核处理器(1.86GHz主频/1MB二级缓存/533MHz前端总线)、IntelGM965芯片组、1GBDDR2内存、120GB硬盘、DVD刻录光驱和IntelGMAX3100集成显卡。目前,它的经销商报价为3999元。','国安少帅曾两度出山救危局他已托起京师一代才俊新浪体育讯随着联赛中的连续不胜,卫冕冠军北京国安的队员心里到了崩溃的边缘,俱乐部董事会连夜开会做出了更换主教练洪元硕的决定。''而接替洪元硕的,正是上赛季在李章洙下课风波中同样下课的国安俱乐部副总魏克兴。生于1963年的魏克兴球员时代并没有特别辉煌的履历,但也绝对称得上特别:15岁在北京青年队获青年''联赛最佳射手,22岁进入国家队,著名的5-19一战中,他是国家队的替补队员。','汤盈盈撞人心情未平复眼泛泪光拒谈悔意(附图)新浪娱乐讯汤盈盈日前醉驾撞车伤人被捕,原本要彩排《欢乐满东华2008》的她因而缺席,直至昨日(12月2日),盈盈继续要与王君馨、马''赛、胡定欣等彩排,大批记者在电视城守候,她足足迟了约1小时才到场。全身黑衣打扮的盈盈,神情落寞、木无表情,回答记者问题时更眼泛泪光。盈盈因为迟到,向记者说声“不好意思”后''便急步入场,其助手坦言盈盈没什么可以讲。后来在《欢乐满东华2008》监制何小慧陪同下,盈盈接受简短访问,她小声地说:“多谢大家关心,交给警方处理了,不方便讲,','甲醇期货今日挂牌上市继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式挂牌交易。基准价均为3050元/吨继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式''挂牌交易。郑州商品交易所(郑商所)昨日公布首批甲醇期货8合约的上市挂牌基准价,均为3050元/吨。据此推算,买卖一手甲醇合约至少需要12200元。业内人士认为,作为国际市场上的''首个甲醇期货品种,其今日挂牌后可能会因炒新资金追捧而出现冲高走势,脉冲式行情过后可能有所回落,不过,投资者在上市初期应关注期现价差异常带来的无风险套利交易机会。']main(test)

2.4. utils_fasttext.py代码编写 

# coding: UTF-8
import os
import torch
import numpy as np
import pickle as pkl
from tqdm import tqdm
import time
from datetime import timedeltaMAX_VOCAB_SIZE = 10000
UNK, PAD = '<UNK>', '<PAD>'def build_vocab(file_path, tokenizer, max_size, min_freq):vocab_dic = {}with open(file_path, 'r', encoding='UTF-8') as f:for line in tqdm(f):lin = line.strip()if not lin:continuecontent = lin.split('\t')[0]for word in tokenizer(content):vocab_dic[word] = vocab_dic.get(word, 0) + 1vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1],reverse=True)[:max_size]vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})return vocab_dicdef build_dataset(config, ues_word):if ues_word:tokenizer = lambda x: x.split(' ')  # 以空格隔开,word-levelelse:tokenizer = lambda x: [y for y in x]  # char-levelif os.path.exists(config.vocab_path):vocab = pkl.load(open(config.vocab_path, 'rb'))else:vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)pkl.dump(vocab, open(config.vocab_path, 'wb'))print(f"Vocab size: {len(vocab)}")def biGramHash(sequence, t, buckets):t1 = sequence[t - 1] if t - 1 >= 0 else 0return (t1 * 14918087) % bucketsdef triGramHash(sequence, t, buckets):t1 = sequence[t - 1] if t - 1 >= 0 else 0t2 = sequence[t - 2] if t - 2 >= 0 else 0return (t2 * 14918087 * 18408749 + t1 * 14918087) % bucketsdef load_dataset(path, pad_size=32):contents = []with open(path, 'r', encoding='UTF-8') as f:for line in tqdm(f):lin = line.strip()if not lin:continuecontent, label = lin.split('\t')words_line = []token = tokenizer(content)seq_len = len(token)if pad_size:if len(token) < pad_size:token.extend([PAD] * (pad_size - len(token)))else:token = token[:pad_size]seq_len = pad_size# word to idfor word in token:words_line.append(vocab.get(word, vocab.get(UNK)))# fasttext ngrambuckets = config.n_gram_vocabbigram = []trigram = []# ------ngram------for i in range(pad_size):bigram.append(biGramHash(words_line, i, buckets))trigram.append(triGramHash(words_line, i, buckets))# -----------------contents.append((words_line, int(label), seq_len, bigram, trigram))return contents  # [([...], 0), ([...], 1), ...]train = load_dataset(config.train_path, config.pad_size)dev = load_dataset(config.dev_path, config.pad_size)test = load_dataset(config.test_path, config.pad_size)return vocab, train, dev, testclass DatasetIterater(object):def __init__(self, batches, batch_size, device):self.batch_size = batch_sizeself.batches = batchesself.n_batches = len(batches) // batch_sizeself.residue = False  # 记录batch数量是否为整数 if len(batches) % self.n_batches != 0:self.residue = Trueself.index = 0self.device = devicedef _to_tensor(self, datas):# xx = [xxx[2] for xxx in datas]# indexx = np.argsort(xx)[::-1]# datas = np.array(datas)[indexx]x = torch.LongTensor([_[0] for _ in datas]).to(self.device)y = torch.LongTensor([_[1] for _ in datas]).to(self.device)bigram = torch.LongTensor([_[3] for _ in datas]).to(self.device)trigram = torch.LongTensor([_[4] for _ in datas]).to(self.device)# pad前的长度(超过pad_size的设为pad_size)seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)return (x, seq_len, bigram, trigram), ydef __next__(self):if self.residue and self.index == self.n_batches:batches = self.batches[self.index * self.batch_size: len(self.batches)]self.index += 1batches = self._to_tensor(batches)return batcheselif self.index >= self.n_batches:self.index = 0raise StopIterationelse:batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]self.index += 1batches = self._to_tensor(batches)return batchesdef __iter__(self):return selfdef __len__(self):if self.residue:return self.n_batches + 1else:return self.n_batchesdef build_iterator(dataset, config, predict):if predict == True:config.batch_size = 1iter = DatasetIterater(dataset, config.batch_size, config.device)return iterdef get_time_dif(start_time):"""获取已使用时间"""end_time = time.time()time_dif = end_time - start_timereturn timedelta(seconds=int(round(time_dif)))if __name__ == "__main__":'''提取预训练词向量'''vocab_dir = "./THUCNews/data/vocab.pkl"pretrain_dir = "./THUCNews/data/sgns.sogou.char"emb_dim = 300filename_trimmed_dir = "./THUCNews/data/vocab.embedding.sougou"word_to_id = pkl.load(open(vocab_dir, 'rb'))embeddings = np.random.rand(len(word_to_id), emb_dim)f = open(pretrain_dir, "r", encoding='UTF-8')for i, line in enumerate(f.readlines()):# if i == 0:  # 若第一行是标题,则跳过#     continuelin = line.strip().split(" ")if lin[0] in word_to_id:idx = word_to_id[lin[0]]emb = [float(x) for x in lin[1:301]]embeddings[idx] = np.asarray(emb, dtype='float32')f.close()np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)

 2.5. run.py代码编写

# coding: UTF-8
import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse# 当碰到argparse的参数时,运行方式需要在终端执行我们的程序
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, help='choose a model: TextCNN, TextRNN, TextRCNN', default="TextRNN") #调用模型
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()if __name__ == '__main__':dataset = 'THUCNews'  # 数据集# 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:randomembedding = 'embedding_SougouNews.npz'if args.embedding == 'random':embedding = 'random'model_name = args.model  # 'TextRCNN'  # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformerif model_name == 'FastText':from utils_fasttext import build_dataset, build_iterator, get_time_difembedding = 'random'else:from utils import build_dataset, build_iterator, get_time_difx = import_module('models.' + model_name)config = x.Config(dataset, embedding)np.random.seed(1)torch.manual_seed(1)torch.cuda.manual_seed_all(1)torch.backends.cudnn.deterministic = True  # 保证每次结果一样start_time = time.time()print("Loading data...")vocab, train_data, dev_data, test_data = build_dataset(config, args.word)train_iter = build_iterator(train_data, config, False)dev_iter = build_iterator(dev_data, config, False)test_iter = build_iterator(test_data, config, False)# predict_iter = build_iterator(predict_data, config)time_dif = get_time_dif(start_time)print("Time usage:", time_dif)# trainconfig.n_vocab = len(vocab)model = x.Model(config).to(config.device)if model_name != 'Transformer':init_network(model)print(model.parameters)train(config, model, train_iter, dev_iter, test_iter)

 3、训练效果

3.1. TextCNN训练效果

TextCNN训练效果

Test Time

0:04:23

Test Loaa:

0.3

Test Acc

90.58%

类别

Precision

recall

F1-score

support

财经

0.9123

0.8950

0.9036

1000

房产

0.9043

0.9360

0.9199

1000

股票

0.8812

0.8230

0.8511

1000

教育

0.9540

0.9530

0.9535

1000

科技

0.8354

0.8880

0.9609

1000

社会

0.8743

0.9110

0.8923

1000

时政

0.8816

0.8860

0.8838

1000

体育

0.9682

0.9430

0.9554

1000

游戏

0.9249

0.9110

0.9179

1000

娱乐

0.9287

0.9120

0.9203

1000

Accuracy

0.9058

1000

Macro avg

0.9065

0.9058

0.9059

1000

Weighted avg

0.9065

0.9058

0.9059

1000

表1 TextCNN网络训练结果

3.2. TextRNN训练效果

TextRNN训练效果

Test Time

0:05:07

Test Loaa:

0.29

Test Acc

91.03%

类别

Precision

recall

F1-score

support

财经

0.9195

0.8790

0.8988

1000

房产

0.9181

0.9190

0.9185

1000

股票

0.8591

0.8290

0.8438

1000

教育

0.9349

0.9480

0.9414

1000

科技

0.8642

0.8720

0.8681

1000

社会

0.9190

0.9080

0.9135

1000

时政

0.8578

0.8990

0.8779

1000

体育

0.9690

0.9690

0.9690

1000

游戏

0.9454

0.9350

0.9402

1000

娱乐

0.9175

0.9450

0.9310

1000

Accuracy

0.9103

1000

Macro avg

0.9104

0.9103

0.9102

1000

Weighted avg

0.9104

0.9103

0.9102

1000

表2 TextRNN网络训练结果

TextRNN网络的训练效果最好,准确率达到了91.03%,明显高于TextCNN网络的效果。

 3.3. TextRCNN训练效果

TextRCNN训练效果

Test Time

0:03:20

Test Loaa:

0.29

Test Acc

90.96%

类别

Precision

recall

F1-score

support

财经

0.9134

0.8970

0.9051

1000

房产

0.9051

0.9350

0.9198

1000

股票

0.8658

0.8320

0.8485

1000

教育

0.9295

0.9500

0.9397

1000

科技

0.8352

0.8770

0.8556

1000

社会

0.8993

0.9290

0.9139

1000

时政

0.8921

0.9680

0.8799

1000

体育

0.9851

0.9670

0.9743

1000

游戏

0.9551

0.9140

0.9341

1000

娱乐

0.9233

0.9270

0.9251

1000

Accuracy

0.9096

1000

Macro avg

0.9101

0.9096

0.9096

1000

Weighted avg

0.9101

0.9096

0.9096

1000

表3 TextRCNN网络训练结果

TextRCNN网络的效果为90.96%,与TextCNN网络模型效果相近。

mode

time

Cpu

TextCNN

0:04:23

CORE i5

TextRNN

0:05:07

TextRCNN

0:03:20

表4 模型对比

最终发现TextRNN的训练结果最好,但所需的时间也是最久的。

4、操作异常问题与解决方案

(1)ModuleNotFoundError: No module named 'tensorboard'

使用以下命令在命令行中安装Tensorboard: pip install tensorboard

(2)module 'distutils' has no attribute 'version'

发生此错误是因为 setuptools 版本59.6.0中的更改以某种方式中断了对 version 属性的调用,而且,在最新版本中,version 属性似乎已经从 distutils 中删除。使用tensorboard包时,setuptools版本过高导致的问题。

卸载当前setuptools:

pip uninstall setuptools

使用pip,不能使用 conda uninstall setuptools ;切记不能使用conda的命令,原因是,conda在卸载的时候,会自动分析与其相关的库,然后全部删除。

pip install setuptools==58.0.4

5、总结

通过分别对模型TextCNN、TextRNN、TextRCNN和不同的硬件环境进行实验,分别对实验结果中的训练时间、准确率、召回率和F1值进行比较,进一步确定哪个模型在给定数据集和硬件环境下表现最佳。最终发现TextRNN的训练结果最好,但所需的时间也是最久的,而TextRCNN模型的训练结果与TextCNN几乎相同,但TextRCNN所需的时间最少。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/652124.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C++类与对象(上)】

C类与对象(上&#xff09; 1.面向过程和面向对象初步认识2.类的引入3.类的定义4.类的访问限定符及封装4.1 访问限定符4.2 封装 5.类的作用域6.类的实例化7.类的对象大小的计算7.1如何计算类对象的大小7.2 类对象的存储方式猜测7.3结构体内存对齐规则 8.类成员函数的this指针8.1…

Java多线程基础-18:线程安全的集合类与ConcurrentHashMap

Java标准库提供了很多集合类&#xff0c;但有一些集合类是线程不安全的&#xff0c;也就是说&#xff0c;在多线程环境下可能会出问题的。常用的ArrayList&#xff0c;LinkedList&#xff0c;HashMap&#xff0c;PriorityQueue等都是线程不安全的&#xff08;Vector, Stack, Ha…

Android创建工程

语言选择Java&#xff0c;我用的Java 最小SDK&#xff1a;就是开发的APP支持的最小安卓版本 Gradle 是一款Google 推出的基于 JVM、通用灵活的项目构建工具&#xff0c;支持 Maven&#xff0c;JCenter 多种第三方仓库;支持传递性依赖管理、废弃了繁杂的xml 文件&#xff0c;转而…

关于ArcGIS的Update更新工具的疑问

Update更新工具官方帮助文件解释如下&#xff1a; 但是根据这个插图很让人疑惑&#xff0c;输入要素是蓝色&#xff0c;更新要素是黄色&#xff0c;输出要素为绿色&#xff0c;而且全部是绿色。我一直以为是与更新要素相交&#xff08;被包含切割&#xff09;的哪些输入要素都被…

【常用工具】7-Zip 解/压缩软件——基本使用方法

在实际日常工作或项目中&#xff0c;经常会遇到需要在window操作系统上压缩文件&#xff0c;在Linux操作系统上解压缩的场景&#xff0c;一款实用的压缩软件迫在眉睫&#xff0c;经过实际使用总结&#xff0c;7-Zip可以很好的解决很多压缩和解压缩问题&#xff0c;其基本使用方…

WordPress如何自定义日期和时间格式?附PHP日期和时间格式字符串

WordPress网站在很多地方都需要用到日期和时间&#xff0c;那么我们应该在哪里设置日期和时间呢&#xff1f;又如何自定义日期和时间格式呢&#xff1f;下面boke112百科就跟大家一起来学习一下PHP标准化的日期和时间格式字符串。 特别说明&#xff1a;格式字符是标准化的&#…

canvas绘制旋转的大风车

查看专栏目录 canvas实例应用100专栏&#xff0c;提供canvas的基础知识&#xff0c;高级动画&#xff0c;相关应用扩展等信息。canvas作为html的一部分&#xff0c;是图像图标地图可视化的一个重要的基础&#xff0c;学好了canvas&#xff0c;在其他的一些应用上将会起到非常重…

LCweekly-game

ExScorecomplete situation1220717/719(解答错误)30523/537(超时,弱智题已AC)40 有用的是Ex2和Ex4 Ex2 my solution class Solution { public://calculate xs l-time 幂乘int jiecheng(int x,int l){int zx;for(int i0;i<l;i){if(z>pow(10,4.5))return 0;zz*z;}return…

C#算法(11)—求三个点构成圆的圆心坐标和半径

前言 我们在上位机开发领域也经常会碰到根据三个点求出圆的圆形、半径等信息的场景,本文就是详细的介绍如何根据三个点使用C#代码求出三点构成的圆的圆心坐标、圆半径、三点构成的圆弧的角度。 1、3点求圆分析 A、B、C三个点都是圆上的坐标点,过向量AB做中垂线,过向量AC做…

What is `@Scheduled` does?

Scheduled 是Spring框架中用于定时任务调度的注解&#xff0c;它允许我们在类的方法上声明一个方法作为定时任务&#xff0c;由Spring容器统一管理和执行。使用此注解后&#xff0c;Spring会根据注解中的属性配置&#xff0c;按照指定的时间规则自动调用该方法。 public class…

文心一言 VS ChatGPT :谁是更好的选择?

前言 目前各种大模型、人工智能相关内容覆盖了朋友圈已经各种媒体平台&#xff0c;对于Ai目前来看只能说各有千秋。GPT的算法迭代是最先进的&#xff0c;但是它毕竟属于国外产品&#xff0c;有着网络限制、注册限制、会员费高昂等弊端&#xff0c;难以让国内用户享受。文心一言…

2023年度AI盘点 AIGC|AGI|ChatGPT|人工智能大模型

前言 「作者主页」&#xff1a;雪碧有白泡泡 「个人网站」&#xff1a;雪碧的个人网站 2023年是人工智能大语言模型大爆发的一年&#xff0c;一些概念和英文缩写也在这一年里集中出现&#xff0c;很容易混淆&#xff0c;甚至把人搞懵。 文章目录 前言01 《ChatGPT 驱动软件开…

使用一个定时器(timer_fd)管理多个定时事件

使用一个定时器(timer_fd)管理多个定时事件 使用 timerfd_xxx 系列函数可以很方便的与 select、poll、epoll 等IO复用函数相结合&#xff0c;实现基于事件的定时器功能。大体上有两种实现思路&#xff1a; 为每个定时事件创建一个 timer_fd&#xff0c;绑定对应的定时回调函数…

QEMU源码全解析41 —— Machine(11)

接前一篇文章&#xff1a;QEMU源码全解析40 —— Machine&#xff08;10&#xff09; 本文内容参考&#xff1a; 《趣谈Linux操作系统》 —— 刘超&#xff0c;极客时间 《QEMU/KVM》源码解析与应用 —— 李强&#xff0c;机械工业出版社 特此致谢&#xff01; 时间过去了几…

go语言(二十一)---- channel的关闭

channel不像文件一样需要经常去关闭&#xff0c;只有当你确实没有任何发送数据了&#xff0c;或者你想显示的结束range循环之类的&#xff0c;才去关闭channel。关闭channel后&#xff0c;无法向channel再发送数据&#xff0c;&#xff08;引发pannic错误后&#xff0c;导致接收…

Linux编译实时内核和打补丁

目录 1.Linux内核2.实时内核3.编译实时内核3.1 准备3.2 获取内核源码3.3 编译3.4 设置GRUB确保启动到实时内核 4.给内核打补丁5.安装新的内核 1.Linux内核 https://github.com/torvalds/linux Linux内核是Linux操作系统的核心部分&#xff0c;它是操作系统的基本组成部分&…

spring整合mybatis的底层原理

spring整合mybatis的底层原理 原理&#xff1a; FactoryBean的自定义对象jdk动态代理Mapper接口对象 一、手写一个spring集成mybatis 目录结构&#xff1a; 1.1 入口类 public class Test {public static void main(String[] args) {AnnotationConfigApplicationContext co…

Linux中的软链接与硬链接

Linux链接概念 Linux链接分两种&#xff0c;一种被称为硬链接&#xff08;Hard Link&#xff09;&#xff0c;另一种被称为符号链接&#xff08;Symbolic Link&#xff09;。默认情况下&#xff0c;使用 ln 命令不加参数创建硬链接&#xff0c;加 -s 参数则创建软链接 硬链接…

【mongoDB】创建用户账号和权限

使用use database_name 命令创建或切换到一个数据库 查看用户 show users 输入该命令后&#xff0c;无数据表示该数据库没有用户 创建用户 user:" freedom " 表示用户名为freedom pwd:" 123456 ” 表示密码为123456 roles:[" root "] …

一键去除图片背景——background-removal-js

一些JavaScript库和工具可以帮助实现背景去除&#xff1a; OpenCV.js&#xff1a;OpenCV的JavaScript版本&#xff0c;提供了许多计算机视觉功能&#xff0c;包括背景去除。Jimp&#xff1a;一个用于处理图像的JavaScript库&#xff0c;提供了许多图像处理功能&#xff0c;包括…