TextCNN的复现

TextCNN的复现–pytorch的实现

对于TextCNN的讲解,可以参考这篇文章

Convolutional Neural Networks for Sentence Classification - 知乎 (zhihu.com)

接下来主要是对代码内容的详解,完整代码将在文章末尾给出。

使用的数据集为电影评论数据集,其中正面数据集5000条左右,负面的数据集也为5000条。

pyroch的基本训练过程:

加载训练集–构建模型–模型训练–模型评价

首先,是要对数据集进行加载,在对数据集加载时候需要继承一下Dataset类,代码如下

class Data_loader(Dataset):def __init__(self, file_pos, file_neg, model_path, word2_vec=False):self.file_pos = file_posself.file_neg = file_negif word2_vec:self.x_train, self.y_train = self.get_word2vec(model_path)else:self.x_train, self.y_train, self.dictionary = self.pre_process()def __getitem__(self, idx):data = self.x_train[idx]label = self.y_train[idx]data = torch.tensor(data)label = torch.tensor(label)return data, labeldef __len__(self):return len(self.x_train)def clean_sentences(self, string):string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)string = re.sub(r"\'s", " \'s", string)string = re.sub(r"\'ve", " \'ve", string)string = re.sub(r"n\'t", " n\'t", string)string = re.sub(r"\'re", " \'re", string)string = re.sub(r"\'d", " \'d", string)string = re.sub(r"\'ll", " \'ll", string)string = re.sub(r",", " , ", string)string = re.sub(r"!", " ! ", string)string = re.sub(r"\(", " \( ", string)string = re.sub(r"\)", " \) ", string)string = re.sub(r"\?", " \? ", string)string = re.sub(r"\s{2,}", " ", string)return string.strip().lower()def load_data_and_labels(self):positive_examples = list(open(self.file_pos, "r", encoding="utf-8").readlines())positive_examples = [s.strip() for s in positive_examples]  # 对评论数据删除每一行数据的\t,\nnegative_examples = list(open(self.file_neg, "r", encoding="utf-8").readlines())negative_examples = [s.strip() for s in negative_examples]  # 对评论数据删除每一行数据的\t,\nx_text = positive_examples + negative_examplesx_text = [self.clean_sentences(_) for _ in x_text]positive_labels = [[1, 0] for _ in positive_examples]  # 正样本数据为1negative_labels = [[0, 1] for _ in negative_examples]  # 负样本数据为0y = np.concatenate([positive_labels, negative_labels], 0)return x_text, y  # 返回的是dataframe对象,[0]data[0]为文本数据,data[1]为标签def pre_process(self):'''加载数据,并对之前使用的数据进行打乱返回,同时根据训练集和测试集的比列进行划分,默认百分80和百分20:return:测试数据、训练数据、以及生成的词汇表'''x_data, y_label = self.load_data_and_labels()max_document_length = max(len(x.split(' ')) for x in x_data)voc = []word_split = [][voc.extend(x.split()) for x in x_data]  # 生成词典[word_split.append(x.split()) for x in x_data]if len(voc) != 0:ordere_dict = OrderedDict(sorted(Counter(_flatten(voc)).items(), key=lambda x: x[1], reverse=True))# 把文档映射成词汇的索引序列dictionary = vocab(ordere_dict)x_data = []for words in word_split:x = list(dictionary.lookup_indices(words))temp_pos = max_document_length - len(x)if temp_pos != 0:for i in range(1, temp_pos + 1):x.extend([0])x_data.append(x)x_data = np.array(x_data)np.random.seed(10)# 将标签打乱顺序,返回索引shuffle_indices = np.random.permutation(np.arange(len(y_label)))x_shuffled = x_data[shuffle_indices]y_shuffled = y_label[shuffle_indices]return x_shuffled, y_shuffled, dictionarydef get_word2vec(self, model_path):model = gensim.models.Word2Vec.load(model_path)x_data, y_label = self.load_data_and_labels()word_split = [][word_split.append(x.split()) for x in x_data]sentence_vectors = []for sentence in word_split:sentence_vector = []for word in sentence:try:v = model.wv.get_vector(word)except Exception as e:v = np.zeros(shape=(model.vector_size,), dtype=np.float32)sentence_vector.append(v)sentence_vectors.append(sentence_vector)max_document_length = max(len(x) for x in sentence_vectors)for vector in sentence_vectors:for i in range(1, max_document_length - len(vector) + 1):v = np.zeros(shape=(model.vector_size,), dtype=np.float32)vector.append(v)vector_data = np.asarray(sentence_vectors, dtype=np.float32)np.random.seed(10)# 将标签打乱顺序,返回索引shuffle_indices = np.random.permutation(np.arange(len(y_label)))x_shuffled = vector_data[shuffle_indices]y_shuffled = y_label[shuffle_indices]return x_shuffled, y_shuffled

上述代码中的__init__ 、getitem 、len是必须要继承实现的方法,clean_sentence是对读取的数据进行清洗,load_data_and_label是加载数据且返回清洗过后的数据以及数据标签。pre_process是对数据进行编码,原始的数据是英文数据,因此需要对其进行分词、编码,最后返回的数据将是数字,一行数据就是一句评论。

例如:

I like this movie

在对其进行编码返回后将是 0 1 2 3,0对应的为I,1对应的为like以此类推。

get_word2vec则是使用word2vec预训练模型来对每个单词对应的数据内容进行映射。1个单词对应的将会是一个100维的矩阵,该维度可以根据自己训练word2vec模型时候自己进行调整。

接下来是word2vec模型的训练及保存,出于简便性,训练word2vec模型时候直接使用了该数据集对word2vec模型进行训练。

代码如下所示:

def get_model(p_file, n_file):x_data, y_label = load_data_and_labels(p_file, n_file)x_data = [x.split() for x in x_data]max_document_length = max(len(x) for x in x_data)model = Word2Vec(x_data, vector_size=256)return model

在这儿设置的每个词的维度是256维。

接下来就是TextCNN模型的构建

class GlobalMaxPool1d(nn.Module):def __init__(self):super(GlobalMaxPool1d, self).__init__()def forward(self, x):return F.max_pool1d(x, kernel_size=x.shape[2])class TextCNN(nn.Module):def __init__(self, num_classes, num_embeddings=-1, embedding_dim=512, kernel_size=[3, 4, 5, 6],num_channels=[32, 32, 32, 32], embeddings_pretrained=None):super(TextCNN, self).__init__()self.num_classes = num_classesself.num_embeddings = num_embeddingsif self.num_embeddings > 0:self.embedding = nn.Embedding(num_embeddings, embedding_dim)if embeddings_pretrained is not None:self.embedding = self.embedding.from_pretrained(embeddings_pretrained, freeze=False)self.cnn_layers = nn.ModuleList()  # 创建多个一维卷积层for c, k in zip(num_channels, kernel_size):cnn = nn.Sequential(nn.Conv1d(in_channels=embedding_dim, out_channels=c, kernel_size=k),nn.BatchNorm1d(c),nn.ReLU(inplace=True))# cnn = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=c, kernel_size=k),#                     nn.BatchNorm1d(c),#                     nn.ReLU(inplace=True)#                     )self.cnn_layers.append(cnn)self.pool = GlobalMaxPool1d()self.classify = nn.Sequential(nn.Dropout(p=0.2),nn.Linear(sum(num_channels), self.num_classes))def forward(self, x):if self.num_embeddings > 0:x = self.embedding(x)# input = torch.unsqueeze(x, dim=1)# print(input.size())input = x.permute(0, 2, 1)# print(input.size())# print(len(input[0]))y = []for layer in self.cnn_layers:x = layer(input)x = self.pool(x).squeeze(-1)y.append(x)# print(y)y = torch.cat(y, dim=1)out = self.classify(y)# out = torch.sigmoid(out)return out

在构建模型时候需要继承nn.moudule,同时要实现__init__、以及forward方法,可以看作init在定义各个层,forward在对各个层之间来进行连接。

接下来就是对模型进行训练,代码如下所示:

batch_size = 832
num_classes = 2
file_pos = 'E:\\PostGraduate\\Paper_review\\pytorch_TextCnn/data/rt-polarity.pos'
file_neg = 'E:\\PostGraduate\\Paper_review\\pytorch_TextCnn/data/rt-polarity.neg'
word2vec_path = 'E:\\PostGraduate\\Paper_review\\pytorch_TextCnn/word2vec1.model'
train_data = Data_loader(file_pos, file_neg, word2vec_path)
train_size = int(len(train_data) * 0.8)
test_size = len(train_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(train_data, [train_size, test_size])
train_iter = DataLoader(train_dataset, batch_size=830, shuffle=True)
test_iter = DataLoader(test_dataset, batch_size=2133, shuffle=True)
model = TextCNN(num_classes, embeddings_pretrained=True)
# model = TextCNN(num_classes, num_embeddings=18764)
# 开始训练
epoch = 100  # 训练轮次
optmizer = torch.optim.Adam(model.parameters(), lr=0.01)
# optmizer = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.4)
train_losses = []
train_counter = []
test_losses = []
log_interval = 5
test_counter = [i * len(train_iter.dataset) for i in range(epoch + 1)]
device = 'cpu'def train_loop(n_epochs, optimizer, model, train_loader, device, test_iter):for epoch in range(1, n_epochs + 1):print("开始第{}轮训练".format(epoch))model.train()correct = 0for i, data in enumerate(train_loader):optimizer.zero_grad()(text_data, label) = datatext_data = text_data.to(device)label = label.to(device)label = label.long()output = model(text_data)loss_func = nn.BCEWithLogitsLoss()# output = output.long()loss = loss_func(output, label.float())loss.backward()optimizer.step()pred = output.data.max(1, keepdim=True)[1]label = label.data.max(1, keepdim=True)[1]correct += pred.eq(label.data.view_as(pred)).sum()if i % log_interval == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, i * len(text_data), len(train_loader.dataset),100. * i / len(train_loader), loss.item()))train_losses.append(loss.item())train_counter.append((i * 64) + ((epoch - 1) * len(train_loader.dataset)))torch.save(model.state_dict(), './model.pth')torch.save(optimizer.state_dict(), './optimizer.pth')print("Accuracy: {}/{} ({:.0f}%)\n".format(correct, len(train_loader.dataset),100. * correct / len(train_loader.dataset)))print("开始第{}轮评价".format(epoch))model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_iter:# for data, target in train_iter:data = data.to(device)target = target.to(device)output = model(data)loss_func = nn.BCEWithLogitsLoss()# output = output.long()loss = loss_func(output, target.float())test_loss += losspred = output.data.max(1, keepdim=True)[1]label = target.data.max(1, keepdim=True)[1]correct += pred.eq(label.data.view_as(pred)).sum()test_loss /= len(test_iter.dataset)# test_loss /= len(train_iter.dataset)test_losses.append(test_loss)print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_iter.dataset),100. * correct / len(test_iter.dataset)))train_loop(epoch, optmizer, model, train_iter, device, test_iter)

在上述中,首先会对数据集加载进来,然后分为80%的训练集和20%的测试集,定义使用的优化器为adam。同时在训练的过程中会对优化器、损失函数等信息进行保存。

训练结果如下所示:

完整代码链接

t, len(test_iter.dataset),
100. * correct / len(test_iter.dataset)))

train_loop(epoch, optmizer, model, train_iter, device, test_iter)


在上述中,首先会对数据集加载进来,然后分为80%的训练集和20%的测试集,定义使用的优化器为adam。同时在训练的过程中会对优化器、损失函数等信息进行保存。训练结果为75%左右。完整代码链接[木南/TextCNN (gitee.com)](https://gitee.com/nanwang-crea/text-cnn)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/656183.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(五)MySQL的备份及恢复

1、MySQL日志管理 在数据库保存数据时,有时候不可避免会出现数据丢失或者被破坏,这样情况下,我们必须保证数据的安全性和完整性,就需要使用日志来查看或者恢复数据了 数据库中数据丢失或被破坏可能原因: 误删除数据…

idea破解方法

idea破解:IDEA 2023.2.5 最新激活码,注册码(亲测好用)

GitHub 开启 2FA 双重身份验证的方法

为什么要开启 2FA 自2023年3月13日起,我们登录 GitHub 都会看到一个要求 Enable 2FA 的重要提示,具体如下: GitHub users are now required to enable two-factor authentication as an additional security measure. Your activity on GitHub includes you in this requi…

vivado 硬块规划器

硬块规划器 Versal自适应SoC的硬块规划GT组件从通用/通道更新为AMD的GT_QUAD粒度Versal™ 自适应SoC。为了启用某些GT共享用例,对GT向导流进行了修改使用Vivado IP集成商。使用Vivado IP集成商构建使用单个或多个GT_ QUAD。连接到GT_QUAD的自定义IP的设计条目为通过…

认知篇:什么是逆转诅咒?一个提问GPT的错误姿势

本系列文章主要是分享一些关于大模型的一些学术研究或者实验性质的探索,为大家更新一些针对大模型的认知。所有的结论我都会附上对应的参考文献,有理有据,也希望这些内容可以对大家使用大模型的过程有一些启发。 注:本系列研究关注…

养猫家庭如何挑选宠物空气净化器?猫用空气净化器品牌推荐!

家里的猫咪真的太可爱了,但它们的毛发总是无处不在。而且猫砂盆一天不清理,整个屋子都会弥漫着臭味。每天打扫也很费时费力,虽然享受着猫咪带来的快乐,但也不得不面对这些困扰。 一直以来,我都想购买一台空气净化器&a…

宠物处方单子怎么开,宠物门诊处方管理软件教程

宠物处方单子怎么开,宠物门诊处方管理软件教程 一、前言 宠物店电子处方软件操作教程以 佳易王宠物店电子处方管理系统V16.0为例说明。 如图,在开处方的时候,点击导航栏菜单,兽医处方按钮 点击 增加新单,填写宠物及…

Security ❀ HTTP/HTTPS逐包解析交互过程细节

文章目录 1. TCP三次握手机制2. HTTP Request 请求报文3. HTTP Response 响应报文4. SSL/TLS协议4.1. ClientHello 客户端Hello报文4.2 ServerHello 服务器Hello报文4.3. *ServerKeyExchange 服务公钥交换4.4. ClientKeyExchange 客户端公钥交换4.5. *CertificateVerify 证书验…

graphviz下载与使用-----决策树可视化

下载graphviz 官网:https://www.graphviz.org/download/ 安装graphviz 双击安装程序

《葡萄与葡萄酒鉴赏》期末考核

题目一:随着时代的发展,人们的生活水平逐渐提高,人们更加注重生活的质量和品位,葡萄酒已经成为人们生活中不可缺少的一部分。请论述葡萄酒的营养价值和经济价值。 葡萄酒的营养价值: 抗氧化物质:葡萄酒中富含抗氧化物…

vue中父组件直接调用子组件方法(通过ref)

目录 1、vue2 中,父组件调用子组件的方法 2、vue3 中,父组件调用子组件的方法 1、vue2 中,父组件调用子组件的方法 在Vue 2中,父组件可以通过使用ref属性来引用子组件的实例,然后通过该实例调用子组件的方法。 首先…

报错“MySql配置文件已损坏,请联系技术支持”的解决方法

目录 第一步 打开控制面板,选择管理工具,再选择事件查看器 第二步 在【应用程序】里找到这条报错,记下来文件内容。我自己的来源是“MsiInstaller” 第三步 winR组合键,输入regedit打开注册表 第四步 根据前面报错的文件名定位…

Linux ip命令

IP命令 从centos7以前我们一直使用ifconfig命令来执行网络相关的任务,比如检查和配置网卡信息,但是ifconfig已经不再被维护,并且在最近版本的Linux中被废除了!ifconfig命令已经被ip命令所代替了。 ip 命令跟 ifconfig 命令有些类似&#xff…

靠着这篇笔记,我拿下了16k车载测试offer!

🔥 交流讨论:欢迎加入我们一起学习! 🔥 资源分享:耗时200小时精选的「软件测试」资料包 🔥 教程推荐:火遍全网的《软件测试》教程 📢欢迎点赞 👍 收藏 ⭐留言 &#x1…

Android 熄屏录音一分钟后没有声音

在使用录音功能的时候发现熄屏的时候过了一分钟之后就没有声音了,虽然录音还在录制但是没有声音,推测是熄屏后手机声音的什么服务关闭了。 可以用前台服务使录音这个动作保活,Android官方文档 服务概览 | Background work | Android De…

构建基于Flask的跑腿外卖小程序

跑腿外卖小程序作为现代生活中的重要组成部分,其技术实现涉及诸多方面,其中Web开发框架是至关重要的一环。在这篇文章中,我们将使用Python的Flask框架构建一个简单的跑腿外卖小程序的原型,展示其基本功能和实现原理。 首先&…

NVIDIA Isaac Sim 入门教程(二)

系列文章目录 前言 一、简介 1.1. Isaac Sim Interface 1.1.1. 学习目标 本教程介绍了Omniverse Isaac Sim中最常用的用户界面按钮、菜单和控件。学完本教程后,您应该能够更自信地在 Isaac Sim 界面中浏览和查找内容。 1.1.2. 入门 首先在场景中添加一个立方体。…

JavaScript学习大纲

1.基本概念和语法 JavaScript简介和历史JavaScript的用途和应用领域JavaScript的基本语法(变量、数据类型、运算符等)控制流程(条件语句、循环语句等)函数和作用域 2.DOM操作 了解DOM(文档对象模型)的基…

银行数据仓库体系实践(14)--数据应用之内部报表及数据分析

在银行日常经营中,每个部门、分支行随时随地都需要进行数据统计和分析,才能对银行当前业务状况及时了解,以进行后续经营策略、营销活动、风险策略的调整和决策。那在平时进行数据分析时除了各数据应用系统(如各类监管报表系统、财…

DAY35:贪心算法part4、860\406\452

Leetcode: 860 柠檬水找零 有如下三种情况: 情况一:账单是5,直接收下。 情况二:账单是10,消耗一个5,增加一个10 情况三:账单是20,优先消耗一个10和一个5,如果不够&am…