三、自然语言分类

文章目录

  • 1 数据准备
    • 1.1 数据集拆分
    • 1.2 创建词库vocabulary
    • 1.3 batch数据,创建Iterator
  • 2 Word Averaging模型
  • 3 RNN模型
  • 4 CNN

三种分类方式:Word Averaging模型、RNN、CNN。

1 数据准备

第一步是准备数据。代码中用到的类库有spacy、torchtext。
torchtext中有data.Field用来处理每个文本字段。dataLabelFiled处理标签字段。

1.1 数据集拆分

将数据集分为训练集、验证集和测试集。了解每种数据集的数据量,查看每一条数据的样子。
每一条句子是一个样本。

1.2 创建词库vocabulary

TEXT.build_vocab(train_data, max_size=25000, vectors="glove.6B.100d", unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

设置词表大小,用词向量初始化词表。

1.3 batch数据,创建Iterator

训练的时候是一个batch,一个batch训练的。torchtext会将短句子pad到和最长的句子长度相同。

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits((train_data, valid_data, test_data), batch_size=BATCH_SIZE,device=device)

数据准备好了,就开始用模型训练吧。

2 Word Averaging模型

我们首先介绍一个简单的Word Averaging模型。这个模型非常简单,我们把每个单词都通过Embedding层投射成word embedding vector,然后把一句话中的所有word vector做个平均,就是整个句子的vector表示了。
接下来把这个sentence vector传入一个Linear层,做分类即可。
怎么做平均呢?我们使用avg_pool2d来做average pooling。我们的目标是把sentence length那个维度平均成1,然后保留embedding这个维度。
avg_pool2d的kernel size是 (embedded.shape[1], 1),所以句子长度的那个维度会被压扁。
在这里插入图片描述

import torch.nn as nn
import torch.nn.functional as Fclass WordAVGModel(nn.Module):def __init__(self, vocab_size, embedding_dim, output_dim, pad_idx):super().__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)self.fc = nn.Linear(embedding_dim, output_dim)def forward(self, text):embedded = self.embedding(text) # [sent len, batch size, emb dim]embedded = embedded.permute(1, 0, 2) # [batch size, sent len, emb dim]pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1) # [batch size, embedding_dim]return self.fc(pooled)

接下来就是模型训练、评价。

优化函数:nn.BCEWithLogitsLoss()
优化方法:Adm
评价:损失函数值,准确率。

3 RNN模型

在这里插入图片描述

我们使用最后一个hidden state hTh_ThT来表示整个句子。
然后我们把hTh_ThT通过一个线性变换f,然后用来预测句子的情感。

class RNN(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout, pad_idx):super().__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, dropout=dropout)self.fc = nn.Linear(hidden_dim*2, output_dim)self.dropout = nn.Dropout(dropout)def forward(self, text):embedded = self.dropout(self.embedding(text)) #[sent len, batch size, emb dim]output, (hidden, cell) = self.rnn(embedded)#output = [sent len, batch size, hid dim * num directions]#hidden = [num layers * num directions, batch size, hid dim]#cell = [num layers * num directions, batch size, hid dim]#concat the final forward (hidden[-2,:,:]) and backward (hidden[-1,:,:]) hidden layers#and apply dropouthidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)) # [batch size, hid dim * num directions]return self.fc(hidden.squeeze(0))

4 CNN

CNN可以把每个局域的特征提出来。在用于句子情感分类的时候,其实是做了一个ngram的特征抽取。
首先做一个embedding,得到每个词的词向量,是一个k维的向量。每个词的词向量结合到一起,得到一个nxk的矩阵。n是单词个数。
在这里插入图片描述

其次卷积层filter。现在相当于一个hgram。最后转化为h个单词。
在这里插入图片描述
w是hxk的。
每一个h单词的窗口都会被这个filter转化。c=[c1,c2,...cn−h−1]c=[c_1,c_2,...c_{n-h-1}]c=[c1,c2,...cnh1]
在这里插入图片描述
上图中是一个卷积核为(3,embedding_size)的卷积。每次作用于3个单词,形成3-gram结果。

一般来说会选用多个卷积核:下图是作用了3个(3,embedding_size)卷积核。也可以使用不同尺寸的卷积核。
在这里插入图片描述

第三步,做一个max-over-time pooling。C^=max{c}\widehat{C}=max\{c\}C=max{c}

如果有m个filter,我们会得到z=[C1^,C2^,...Cm^]z=[\widehat{C_1},\widehat{C_2},...\widehat{C_m}]z=[C1,C2,...Cm]

第四步,对z做一个线性变换,得到分类。

class CNN(nn.Module):def __init__(self, vocab_size, embedding_size, output_size, pad_idx, n_filters, filter_sizes, dropout):super(CNN, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx)self.convs = nn.ModuleList([nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(fs, embedding_size)) for fs in filter_sizes])self.linear = nn.Linear(len(filter_sizes) * n_filters, output_size)self.dropout = nn.Dropout(dropout)def forward(self, text):# text : [seq_len, batch_size]text = text.permute(1, 0) # [batch_size,seq_len]embed = self.embedding(text)  # [batch_size, seq_len, embedding_size]embed = embed.unsqueeze(1) # [batch_size, 1, seq_len, embedding_size]# conved中每一个的形状是:[batch_size, n_filters, seq_len-filter_sizes[i]+1,embedding_size-embedding_size+1]# 因为kernel_size的第二个维度是embedding_sizeconved = [conv(embed) for conv in self.convs]# conved中每一个的形状是:[batch_size, n_filters, seq_len-filter_sizes[i]+1]conved = [F.relu(conv).squeeze(3) for conv in conved]# pooled中的每一个形状是:[batch_size, n_filters, 1]pooled = [F.max_pool1d(conv, conv.shape[2]) for conv in conved]# pooled中的每一个形状是:[batch_size, n_filters]pooled = [conv.squeeze(2) for conv in pooled]cat = self.dropout(torch.cat(pooled, dim=1)) # [batch_size, len(filter_sizes) * n_filters]return self.linear(cat)  # [batch_size,output_size]N_FILTERS = 100
FILTER_SIZES = [3, 4, 5]
DROPOUT = 0.5
cnn_model = CNN(VOCAB_SIZE, EMBEDDING_SIZE, OUTPUT_SIZE, PAD_IDX, N_FILTERS, FILTER_SIZES, DROPOUT)

代码中的卷积核分别为:(3,embedding_size),(4,embedding_size),(5,embedding_size)
每次卷积完成之后,[batch_size, n_filters, seq_len-filter_sizes[i]+1,embedding_size-embedding_size+1],最后一个维度为1。

因为多个卷积之后的形状不一样,通过max_pooling层后统一变为[batch_size, n_filters]。这样就可以将多个卷积的结果拼接在一起,送入全连接层。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/424033.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

spring mvc学习(38):Unknow tag(c:forEach)错误解决办法,jstl.jar包以及standard.jar包下载与导入

解决问题步骤&#xff1a; ①&#xff1a;下载jstl.jar和standard.jar 点击下载jstl.jar 点击下载standard.jar ②&#xff1a;将两个包剪切到项目中的WEB-INF/lib文件夹内 右键加到eclipse环境中---bulidpath--add to path 第一行代码<% pagelanguage"java" im…

nlp中的经典深度学习模型(一)

文章目录1 DNN与词向量1.1 DNN1.2 skip-gram1.3 简单句子分类模型DAN2 RNNLSTMGRU2.1 RNN2.2 LSTM2.3 LSTM变种2.4 递归神经网络2.5 双向RNN2.6 堆叠RNN1 DNN与词向量 1.1 DNN 神经网络中每一个神经单元是一个线性变化加一个激活函数 sUTasU^TasUTa af(z)af(z)af(z) zWxbzWxb…

spring mvc学习(39):restful的crud实现删除方式

上图是目录结构&#xff0c;本节是有问同学的&#xff0c;当好好总结 pom.xml <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation"http://maven.apache.org/POM/4.0…

像程序员一样地思考

在成为程序员的道路上&#xff0c;要经历四个坎坷&#xff0c;让我们用四个境界来标明他们。 第一境界&#xff0c;就是前面所说的&#xff0c;掌握一门或则几门编程语言&#xff0c;会模仿例子来实现程序代码&#xff0c;并且让代码在计算机系统中运行起来。达到这个境界的人…

nlp中的经典深度学习模型(二)

attention和transformer都是面试重点。 文章目录3 seq2seqAttention3.1 Sequence to Sequence Model3.1.2 模型介绍3.1.2 模型训练3.2注意力机制3.2.1介绍3.2.1“Bahdanau” style attention3.2.2“Luong” style attention4 Transformer4.1 Multi-head Attention4.1.1 自注意力…

spring mvc学习(40):restful的crud实现增加方式

上图是目录结构&#xff0c;本节是有问同学的&#xff0c;当好好总结 pom.xml <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation"http://maven.apache.org/POM/4.0…

spring mvc学习(41):restful的crud的项目原型介绍

上图是目录结构&#xff0c;本节是有问同学的&#xff0c;当好好总结 pom.xml <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation"http://maven.apache.org/POM/4.0.…

nlp中的经典模型(三)

文章目录5 NLP中的卷积神经网络5.1 卷积5.2 多通道5.2 max pooling5 NLP中的卷积神经网络 RNN的问题&#xff1a; 1 时间复杂度高 2 最后一个向量包含所有信息。有点不可靠 CNN可以通过卷积核捕捉局部特征&#xff0c;那是不是可以用于句子&#xff0c;表示特定长度的词序列呢…

第一百二十九期:阿里内部员工,排查Java问题常用的工具单

平时的工作中经常碰到很多疑难问题的处理&#xff0c;在解决问题的同时&#xff0c;有一些工具起到了相当大的作用&#xff0c;在此书写下来&#xff0c;一是作为笔记&#xff0c;可以让自己后续忘记了可快速翻阅&#xff0c;二是分享&#xff0c;希望看到此文的同学们可以拿出…

第一百三十期:14种常见编程语言的优缺点及应用范围

C语言是一门通用计算机编程语言&#xff0c;应用广泛。面向过程的&#xff0c;数据与算法分开。它的重点在于算法和数据结构。1972年由美国贝尔实验室在B语言的基础上设计出。 作者&#xff1a;编程小新 C 概述:C语言是一门通用计算机编程语言&#xff0c;应用广泛。面向过程…

第一百三十一期:2019年容器使用报告:Docker 和 Kubernetes 王者地位不倒!

近日&#xff0c;容器创业公司 Sysdig 发布了 2019 年容器使用报告。这是 Sysdig 第三年发布容器年度使用报告&#xff0c;与之前不同的是&#xff0c;今年的调查结合了更多的数据源&#xff0c;并深入挖掘了 Kubernetes 的使用模式。 作者&#xff1a;高效开发运维 近日&…

Flask 路由映射对于双斜线的处理 //a//b

例子 from flask import Flask import time from tornado.wsgi import WSGIContainer from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoopapp Flask(__name__)app.route(//abc//a) def index():# time.sleep(5)return OKapp.route(/abc//a) def in…

⼤规模⽆监督预训练语⾔模型与应⽤(上)

文章目录1 单词作为语言模型的基本单位的缺点2 character level modeling3预训练句子向量3.1 skip-thought3.2 InferSent3.3 句子向量评价数据集4 预训练文档向量5 ELMO1 单词作为语言模型的基本单位的缺点 单词量有限&#xff0c;遇到没有见过的单词只能以UNK表示。 模型参数…

第一百三十二期:MySQL系列:一句SQL,MySQL是怎么工作的?

当我们在mysql窗口或者数据库连接工具中输入一句sql后&#xff0c;我们就可以获取到想要的数据&#xff0c;这中间MySQL到底是怎么工作的呢&#xff1f; 作者&#xff1a;Java架构学习交流 对于MySQL而言&#xff0c;其实分为客户端与服务端。 服务端&#xff0c;就是MySQL应…

Visual Studio 2005 Tip:编辑项目文件

原文参考自&#xff1a;http://blogs.msdn.com/shawnfa/archive/2006/04/26/582326.aspx很多时候我们需要手动修改VS的项目文件&#xff08;.csproj/.vbproj&#xff09;&#xff0c;这时大多数人会简单的使用记事本&#xff08;notepad&#xff09;打开并编辑。虽然这没什么不…

transformer bert GPT(未完)

原文标题&#xff1a;⼤规模⽆监督预训练语⾔模型与应⽤&#xff08;中&#xff09; 文章目录1 transformer1.1 encoder部分1.1.1 Attention定义1.1.2 Multi-head Attention1.1.3 position-wise feed-forward networks1.1.4 positional encoding1.1.5 残差链接1.1.6 layer norm…

spring mvc学习(42):restful的编辑功能实现

上图是目录结构&#xff0c;本节是有问同学的&#xff0c;当好好总结 pom.xml <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation"http://maven.apache.org/POM/4.0.…

那些年用过的Redis集群架构(含面试解析)

引言 今天&#xff0c;我接到了高中同学刘有码面试失利的消息。 他面试的时候&#xff0c;身份是某知名公司的小码农一枚&#xff0c;却因为不懂自己生产上Redis是如何部署的&#xff0c;导致面试失败&#xff01; 人间惨剧&#xff0c;莫过于此。 接到他面试失利的消息&#x…

再谈BERT

三次讲到了BERT。第一次是nlp中的经典深度学习模型(二)&#xff0c;第二次是transformer & bert &GPT&#xff0c;这是第三次。 文章目录1 关于预训练模型1.1预训练概念1.2 再谈语言模型1.3 ELMo1.4 GPT2 BERT2.1 BERT特点2.2架构2.3 预训练任务2.3.1 masked language …

第一百三十三期:MySQL锁会不会,你就差看一看咯

本文章向大家介绍MySQL锁详细讲解&#xff0c;包括数据库锁基本知识、表锁、表读锁、表写锁、行锁、MVCC、事务的隔离级别、悲观锁、乐观锁、间隙锁GAP、死锁等等&#xff0c;需要的朋友可以参考一下。 作者&#xff1a;php自学中心 本文章向大家介绍MySQL锁详细讲解&#xff…