文档分类DPCNN简介(pytorch实现)

文档分类DPCNN简介

        • DPCNN简介
      • 模型结构
          • 区域嵌入
          • 等长卷积
          • 1/2池化
          • DPCNN模型代码实现

DPCNN简介

论文中提出了一种基于 word-level 级别的网络-DPCNN,由于 TextCNN 不能通过卷积获得文本的长距离依赖关系,而论文中 DPCNN 通过不断加深网络,可以抽取长距离的文本依赖关系。

实验证明在不增加太多计算成本的情况下,增加网络深度就可以获得最佳的准确率。‍

前面我们提到过TextRCNN就是将CNN中的池化加入到RNN中,来解决RNN是一个有的偏倚,现在DPCNN通过不断加深网络,来弥补自身短缺的长距离依赖问题,可见每一种模型都不是十全十美的,只有不断探索,不断创新,相互借鉴,才能够使性能进一步提升。

模型结构

在这里插入图片描述

区域嵌入

这里是将TextCNN的包含多尺寸卷积滤波器的卷积层的卷积结果称之为区域嵌入,即对一个文本区域文本片段(比如3-gram)进行一组卷积操作后生成的embedding。这里不同于textCNN的二维卷积,DPCNN采用的是一维卷积。以3-gram为例子回顾textCNN,设置了一个大小为3xD的二维卷积核进行卷积操作(其中D是词嵌入的维度),其实这是一种保留词序的做法。那么对于DPCNN,采用的是不保留词序的做法,即:首先对3-grm中的3个词的词向量取均值得到一个大小为1xD的向量,然后设置一组大小为1*D的一维卷积核对该3-grm进行卷积操作。

等长卷积

经过区域嵌入后,是两层卷积层,这里采用的是等长卷积,以此来提高词位embedding的表示的丰富性。首先先介绍一下三种卷积的概念:

假设输入的序列长度为 n,卷积核大小为 m,步长(stride)为 s,输入序列两端各填补 p 个零(zero padding),那么该卷积层的输出序列为 (n-m+2p)/s+1。

  1. 窄卷积:步长 s=1,两端不补零,即 p=0,卷积后输出长度为 n-m+1。
  2. 宽卷积:步长 s=1,两端补零 p=m-1,卷积后输出长度 n+m-1。
  3. 等长卷积:步长 s=1,两端补零 p=(m-1)/2,卷积后输出长度为 n。

输入输出序列的位置数一样多,即为等长卷积,该卷积的意义是:输出的词是由该位置输入的词以及其左右词的上下文信息提取得到的,也就是说,这个词包含被上下文信息修饰过的更高级别的语义。

1/2池化

本文使用一个 size=3,stride=2(大小为3,步长为2)的池化层进行最大池化,在此称为1/2池化层。每经过一个1/2池化层,序列的长度就被压缩成了原来的一半。因此,经过1/2池化后,同样一个size为3的卷积核,其能够感知到的文本片段就比之前长了一倍。
在堆叠多层卷积池化层之后,就得到了加深的可以抽取长距离的文本依赖关系的网络。最后的池化层把每段文本聚合为一个向量。

主要区别在于输入层由无监督词嵌入层作为输入,把文档的每个词的词向量作出二维数组作为输入;卷积层有两个卷积层组成,卷积层输入通过跳跃连接,恒等映射和卷积层输出相加作为卷积层输出;采样层以尺度大小为2进行下采样,达到尺度缩放的目的。堆叠几层卷积层和采样层,形成尺度缩放金字塔,达到维度缩放的目的。最终将卷积层输出拼接成向量通过隐藏层和softmax层作为输出分类。

DPCNN模型代码实现
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Config(object):"""配置参数"""def __init__(self):self.dropout = 0.5  # 随机失活self.require_improvement = 1000  # 若超过1000batch效果还没提升,则提前结束训练self.num_classes =10 # 类别数self.n_vocab = 10000  # 词表大小,在运行时赋值self.num_epochs = 20  # epoch数self.batch_size = 128  # mini-batch大小self.pad_size = 32  # 每句话处理成的长度(短填长切)self.learning_rate = 1e-3  # 学习率self.embed = 300  # 字向量维度self.num_filters = 250  # 卷积核数量(channels数)'''Deep Pyramid Convolutional Neural Networks for Text Categorization'''class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.conv_region = nn.Conv2d(1, config.num_filters, (3, config.embed), stride=1)self.conv = nn.Conv2d(config.num_filters, config.num_filters, (3, 1), stride=1)self.max_pool = nn.MaxPool2d(kernel_size=(3, 1), stride=2)# (pad_left, pad_right, pad_top, pad_bottom)填充self.padding1 = nn.ZeroPad2d((0, 0, 1, 1))  # top bottomself.padding2 = nn.ZeroPad2d((0, 0, 0, 1))  # bottomself.relu = nn.ReLU()self.fc = nn.Linear(config.num_filters, config.num_classes)def forward(self, x):x = x[0]  # torch.Size([128, 32])x = self.embedding(x)  # torch.Size([128, 32,300])x = x.unsqueeze(1)  # torch.Size([128, 1, 32, 300])x = self.conv_region(x)  # torch.Size([128, 250, 30, 300])x = self.padding1(x)  # [128, 250, 32, 1]x = self.relu(x)  # [128, 250, 32, 1]x = self.conv(x)  # [125, 250, 30, 1]x = self.padding1(x)  # [128, 250, 32, 1]x = self.relu(x)  # [128, 250, 32, 1]x = self.conv(x)  # [128, 250, 30, 1]while x.size()[2] > 2:x = self._block(x)# print("x10", x)#torch.Size([128, 250, 1, 1])x = x.squeeze()  # [128, 250]x = self.fc(x)  # [128, 10]return xdef _block(self, x):x = self.padding2(x)px = self.max_pool(x)x = self.padding1(px)x = F.relu(x)x = self.conv(x)x = self.padding1(x)x = F.relu(x)x = self.conv(x)x = x + pxreturn xconfig=Config()
model=Model(config)
print(model)

输出:

Model((embedding): Embedding(10000, 300, padding_idx=9999)(conv_region): Conv2d(1, 250, kernel_size=(3, 300), stride=(1, 1))(conv): Conv2d(250, 250, kernel_size=(3, 1), stride=(1, 1))(max_pool): MaxPool2d(kernel_size=(3, 1), stride=2, padding=0, dilation=1, ceil_mode=False)(padding1): ZeroPad2d((0, 0, 1, 1))(padding2): ZeroPad2d((0, 0, 0, 1))(relu): ReLU()(fc): Linear(in_features=250, out_features=10, bias=True)
)

参考:
https://blog.csdn.net/sikh_0529/article/details/126912490
https://blog.csdn.net/qq_43592352/article/details/122764889

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/13409.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ATA-2021B高压放大器在光纤超声传感器中的应用

实验名称:超声传感性能研究 测试目的: 光纤马赫-曾德尔干涉仪是一种灵敏度高、结构灵活的传感结构。当在MZI上施加超声波信号时,会影响所涉及的干涉光之间的光程差,并导致干涉光谱的漂移。由于模式耦合是基于MZI的光纤传感器的关键…

脑中风也会出现眩晕?快速识别中风,一定要牢记这些!

眩晕是许多人都会经历的不适感,发作时仿佛整个世界都在旋转,可能还伴随着站立不稳、脚步虚浮、恶心等症状。然而,你可能不知道的是,这些症状在某些情况下可能是脑中风的前兆。如果不及时关注并采取相应措施,一旦发展为…

Failed to register @ServerEndpoint class

springboot集成websocket启动异常信息 java.lang.IllegalStateException: Failed to register ServerEndpoint class: springboot集成websocket引用依赖 <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter…

比特币能否跨过量子时代的这道槛?(2/2)

在上一篇文章《比特币能否跨过量子时代的这道槛&#xff1f;&#xff08;1/2&#xff09;》里说到&#xff0c;比特币要进入量子时代&#xff0c;在技术上必须对加密算法进行升级&#xff0c;而改变比特币的加密基础需要进行硬分叉&#xff0c;这一过程不仅技术复杂&#xff0c…

【算法】二分查找——在排序数组中查找元素的第一个和最后一个位置

本节博客主要是通过“在排序数组中查找元素的第一个和最后一个位置”总结关于二分算法的左右界代码模板&#xff0c;有需要借鉴即可。 目录 1.题目2.二分边界算法2.1查找区间左端点2.1.1循环条件2.1.2求中点的操作2.1.3总结 2.2查找区间右端点2.1.1循环条件2.1.2求中点的操作2.…

O2OA平台流程催办怎么做

O2OA平台设计了灵活的消息提醒数据交互方式&#xff0c;开发者可以根据自己的需要&#xff0c;来消费消息提醒数据&#xff0c;也可以将消息提醒数据接入到Kafka消息中间件来实现消息的准实时提醒。本篇主要介绍如何在O2OA服务器中设置流程的催办提醒消息。 催办提醒服务&#…

centos无法联网解决方案(9步完成

1.打开终端&#xff0c;输入 su - root 进入到管理员模式&#xff08;-的前后都有空格哈&#xff09; 切换后&#xff0c;显示的就是root... 2.. &#xff0c;输入命令ip addr 2. 切换当前目录 cd /etc/sysconfig/network-scripts/ 3.输入命令&#xff0c;打开文件 vi /etc…

一.常见算法--动态规划

&#xff08;1&#xff09;0-1背包问题 问题描述&#xff1a; 0-1背包问题的描述&#xff1a;在n种物品中选择1个或0个第i种物品&#xff0c;装入背包容量为m的背包&#xff0c;使得背包价值达到最大。 思路与关键点&#xff1a; 用到了max函数&#xff0c;用于返回两个数之中…

为何Linux成为你不可或缺的技能

在数字化飞速发展的今天&#xff0c;无论你是IT行业的精英&#xff0c;还是其他领域的专业人士&#xff0c;掌握Linux都已经成为一项至关重要的技能。那么&#xff0c;为什么一定要学会Linux呢&#xff1f;以下文章仅供参考 1. 开源的力量&#xff1a;无限的可能性 Linux是一…

工厂自动化升级改造(3)-Modbus与MQTT的转换

什么是MQTT,Modbus,见下面文章 工厂自动化升级改造参考(01)--设备通信协议详解及选型-CSDN博客文章浏览阅读608次,点赞9次,收藏6次。>>特点:基于标准的以太网技术,使用TCP/IP协议栈,支持高速数据传输和局域网内的设备通信。>>>特点:跨平台的通信协议,…

java版数字藏品深色UI仿鲸探数藏盲盒合成短视频卡牌模式支持高并发功能介绍

根据您提供的艺术品发售系统的需求&#xff0c;以下是一个更为详细和全面的系统设计概述&#xff1a; 1. 藏品发售 藏品分类&#xff1a;藏品可以按照不同的类别进行分类&#xff0c;如绘画、雕塑、摄影等。稀有度设置&#xff1a;后台可以为每个藏品设置不同的稀有度&#x…

ssl证书价格一年多少钱?如何申请?

由于行业新规&#xff0c;现在阿里云、腾讯云等几乎所有平台都不再提供一年期免费证书&#xff0c;如果需要一年期证书则需要支付一定的费用。SSL证书的价格根据类型不同几十到几百上千不等。 一年期SSL证书申请通道https://www.joyssl.com/?nid16 一年期SSL证书申请流程&am…

人工智能(一)架构

一、引言 人工智能这个词不是很新鲜&#xff0c;早就有开始研究的&#xff0c;各种推荐系统、智能客服都是有一定的智能服务的&#xff0c;但是一直都没有体现出多高的智能性&#xff0c;很多时候更像是‘人工智障’。 但是自从chatGpt3被大范围的营销和使用之后&#xff0c;人…

基于springboot的中小型医院网站源码数据库

基于springboot的中小型医院网站源码数据库 本基于Spring Boot的中小型医院网站设计目标是实现用户网络预约挂号的功能&#xff0c;同时提高医院管理效率&#xff0c;更好的为广大用户服务。 本文重点阐述了中小型医院网站的开发过程&#xff0c;以实际运用为开发背景&#x…

基于PHP+MySQL开发的百娣美业课程管理软件系统后端功能介绍

如何开发一个美容产业链的商户管理系统。 1. 需求分析。 在开发美容产业链商户管理系统之前&#xff0c;必须首先进行需求分析。商户需要明确自己的需求和目标&#xff0c;了解系统的功能模块和业务流程&#xff0c;为后续发展提供明确的方向。 2. 系统设计。 根据需求…

python常用基础知识

目录 &#xff08;1&#xff09;print函数 &#xff08;2&#xff09;注释 &#xff08;3&#xff09;input函数 &#xff08;4&#xff09;同时赋值和连续赋值 &#xff08;5&#xff09;type函数和id函数 &#xff08;6&#xff09;python赋值是地址赋值 &#xff08;…

Qt编译和使用freetype矢量字库方法

在之前讲过QT中利用freetype提取字库生成图片的方法&#xff1a; #QT利用freetype提取字库图片_qt freetype-CSDN博客文章浏览阅读1.2k次。这是某个项目中要用到的片段&#xff0c;结合上一篇文章#QT从字体名获取字库文件路径使用// 保存位图int SaveBitmapToFile(HBITMAP hBi…

【会议征稿,ACM出版】第四届人工智能,大数据与算法国际学术会议 (CAIBDA 2024, 7/5-7)

由河南省科学院、河南大学主办&#xff0c;河南省科学院智慧创制研究所、河南大学学术发展部、河南大学人工智能学院承办的第四届人工智能&#xff0c;大数据与算法国际学术会议 (CAIBDA 2024)将于2024年7月5-7日于中国郑州隆重举行。CAIBDA 2024致力于为人工智能&#xff0c;大…

你还去营业厅注销流量卡吗?别浪费时间了,现在有三种方法都可以

家人们&#xff0c;不用的手机号你会注销吗&#xff1f;在这里小编提醒大家&#xff0c;不用的手机卡千万要记得注销&#xff0c;不能直接扔掉&#xff0c;不然可能收到天价欠费单。 ​  今天&#xff0c;小编整理汇总了三种常见的注销方法&#xff0c;不用再跑去营业厅了&a…

Spring的IOC(Inversion of Control)设计模式

Spring的IOC&#xff08;Inversion of Control&#xff09;是一种设计模式&#xff0c;它通过控制反转的思想来降低组件之间的耦合度。在Spring框架中&#xff0c;IOC容器负责管理应用程序中的对象&#xff0c;使得对象之间的依赖关系由容器来维护和注入。 以下是Spring IOC的…