【详细注释+流程讲解】基于深度学习的文本分类 TextCNN

前言

这篇文章用于记录阿里天池 NLP 入门赛,详细讲解了整个数据处理流程,以及如何从零构建一个模型,适合新手入门。

赛题以新闻数据为赛题数据,数据集报名后可见并可下载。赛题数据为新闻文本,并按照字符级别进行匿名处理。整合划分出14个候选分类类别:财经、彩票、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐的文本数据。实质上是一个 14 分类问题。

赛题数据由以下几个部分构成:训练集20w条样本,测试集A包括5w条样本,测试集B包括5w条样本。

比赛地址:零基础入门NLP - 新闻文本分类_学习赛_天池大赛-阿里云天池的赛制

数据可以通过上面的链接下载。

其中还用到了训练好的词向量文件。

词向量下载链接: 百度网盘 请输入提取码 提取码: qbpr

这篇文章中使用的模型主要是CNN + LSTM + Attention,主要学习的是数据处理的完整流程,以及模型构建的完整流程。虽然还没有使用 Bert 等方案,不过如果看完了这篇文章,理解了整个流程之后,即使你想要使用其他模型来处理,也能更快实现。

1. 为什么写篇文章

首先,这篇文章的代码全部都来源于 Datawhale 提供的开源代码,我添加了自己的笔记,帮助新手更好地理解这个代码。

1.1 Datawhale 提供的代码有哪些需要改进?

Datawhale 提供的代码里包含了数据处理,以及从 0 到 1模型建立的完整流程。但是和前面提供的 basesline 的都不太一样,它包含了非常多数据处理的细节,模型也是由 3 个部分构成,所以看起来难度陡然上升。

其次,代码里的注释非常少,也没有讲解整个数据处理和网络的整体流程。这些对于新手来说,增加了理解的门槛。 在数据竞赛方面,我也是一个新人,花了一天的时间,仔细研究数据在一种每一个步骤的转化,对于一些难以理解的代码,在群里询问之后,也得到了 Datawhale 成员的热心解答。最终才明白了全部的代码。

1.2 我做了什么改进?

所以,为了减少对于新手的阅读难度,我添加了一些内容。

  1. 首先,梳理了整个流程,包括两大部分:数据处理模型

    因为代码不是从上到下顺序阅读的。因此,更容易让人理解的做法是:先从整体上给出宏观的数据转换流程图,其中要包括数据在每一步的 shape,以及包含的转换步骤,让读者心中有一个框架图,再带着这个框架图去看细节,会更加了然于胸。

  2. 其次,除了了解了整体流程,在真正的代码细节里,读者可能还是会看不懂某一段小逻辑。因此,我在原有代码的基础之上增添了许多注释,以降低代码的理解门槛。

2. 数据处理

2.1 数据拆分为 10 份

数据首先会经过all_data2fold函数,这个函数的作用是把原始的 DataFrame 数据,转换为一个list,有 10 个元素,表示交叉验证里的 10 份,每个元素是 dict,每个dict包括 label 和 text

首先根据 label 来划分数据行所在 index, 生成 label2id

label2id 是一个 dictkey 为 labelvalue 是一个 list,存储的是该类对应的 index

​然后根据`label2id`,把每一类别的数据,划分到 10 份数据中。​![](https://image.zhangxiann.com/数据处理.gif)​最终得到的数据`fold_data`是一个`list`,有 10 个元素,每个元素是 `dict`,包括 `label` 和 `text`的列表:`[{labels:textx}, {labels:textx}. . .]`。

最后,把前 9 份数据作为训练集train_data,最后一份数据作为验证集dev_data,并读取测试集test_data

2.2 定义并创建 Vacab

Vocab 的作用是:

  • 创建 词 和 index 对应的字典,这里包括 2 份字典,分别是:_id2word 和 _id2extword

  • 其中 _id2word 是从新闻得到的, 把词频小于 5 的词替换为了 UNK。对应到模型输入的 batch_inputs1

  • _id2extword 是从 word2vec.txt 中得到的,有 5976 个词。对应到模型输入的 batch_inputs2

  • 后面会有两个 embedding 层,其中 _id2word 对应的 embedding 是可学习的,_id2extword 对应的 embedding 是从文件中加载的,是固定的。

  • 创建 label 和 index 对应的字典。

  • 上面这些字典,都是基于train_data创建的。

3. 模型

3.1 把文章分割为句子

上上一步得到的 3 个数据,都是一个listlist里的每个元素是 dict,每个 dict 包括 label 和 text。这 3 个数据会经过 get_examples函数。 get_examples函数里,会调用sentence_split函数,把每一篇文章分割成为句子。

然后,根据vocab,把 word 转换为对应的索引,这里使用了 2 个字典,转换为 2 份索引,分别是:word_idsextword_ids。最后返回的数据是一个 list,每个元素是一个 tuple: (label, 句子数量,doc)。其中doc又是一个 list,每个 元素是一个 tuple: (句子长度,word_ids, extword_ids)

​在迭代训练时,调用data_iter函数,生成每一批的batch_data。在data_iter函数里,会调用batch_slice函数生成每一个batch。拿到batch_data后,每个数据的格式仍然是上图中所示的格式,下面,调用batch2tensor函数。

3.2 生成训练数据

batch2tensor函数最后返回的数据是:(batch_inputs1, batch_inputs2, batch_masks), batch_labels。形状都是(batch_size, doc_len, sent_len)doc_len表示每篇新闻有几句话,sent_len表示每句话有多少个单词。

batch_masks在有单词的位置,值为1,其他地方为 0,用于后面计算 Attention,把那些没有单词的位置的 attention 改为 0。

batch_inputs1, batch_inputs2, batch_masks,形状是(batch_size, doc_len, sent_len),转换为(batch_size * doc_len, sent_len)

3.3 网络部分

下面,终于来到网络部分。模型结构图如下:

3.3.1 WordCNNEncoder

WordCNNEncoder 网络结构示意图如下:​

1. Embedding

batch_inputs1, batch_inputs2都输入到WordCNNEncoderWordCNNEncoder包括两个embedding层,分别对应batch_inputs1,embedding 层是可学习的,得到word_embedbatch_inputs2,读取的是外部训练好的词向量,因此是不可学习的,得到extword_embed。所以会分别得到两个词向量,将 2 个词向量相加,得到最终的词向量batch_embed,形状是(batch_size * doc_len, sent_len, 100),然后添加一个维度,变为(batch_size * doc_len, 1, sent_len, 100),对应 Pytorch 里图像的(B, C, H, W)

2. CNN

然后,分别定义 3 个卷积核,output channel 都是 100 维。

第一个卷积核大小为[2,100],得到的输出是(batch_size * doc_len, 100, sent_len-2+1, 1),定义一个池化层大小为[sent_len-2+1, 1],最终得到输出经过squeeze()的形状是(batch_size * doc_len, 100)

同理,第 2 个卷积核大小为[3,100],第 3 个卷积核大小为[4,100]。卷积+池化得到的输出形状也是(batch_size * doc_len, 100)

最后,将这 3 个向量在第 2 个维度上做拼接,得到输出的形状是(batch_size * doc_len, 300)

3.3.2 shape 转换

把上一步得到的数据的形状,转换为(batch_size , doc_len, 300)名字是sent_reps。然后,对mask进行处理。

batch_masks的形状是(batch_size , doc_len, 300),表示单词的 mask,经过sent_masks = batch_masks.bool().any(2).float()得到句子的 mask。含义是:在最后一个维度,判断是否有单词,只要有 1 个单词,那么整句话的 mask 就是 1,sent_masks的维度是:(batch_size , doc_len)

3.3.3 SentEncoder

SentEncoder 网络结构示意图如下:

SentEncoder包含了 2 层的双向 LSTM,输入数据sent_reps的形状是(batch_size , doc_len, 300),LSTM 的 hidden_size 为 256,由于是双向的,经过 LSTM 后的数据维度是(batch_size , doc_len, 512),然后和 mask 按位置相乘,把没有单词的句子的位置改为 0,最后输出的数据sent_hiddens,维度依然是(batch_size , doc_len, 512)

3.3.4 Attention

接着,经过AttentionAttention的输入是sent_hiddenssent_masks。在Attention里,sent_hiddens首先经过线性变化得到key,维度不变,依然是(batch_size , doc_len, 512)

然后keyquery相乘,得到outputsquery的维度是512,因此output的维度是(batch_size , doc_len),这个就是我们需要的attention,表示分配到每个句子的权重。下一步需要对这个attetionsoftmax,并使用sent_masks,把没有单词的句子的权重置为-1e32,得到masked_attn_scores

最后把masked_attn_scoreskey相乘,得到batch_outputs,形状是(batch_size, 512)

3.3.5 FC

最后经过FC层,得到分类概率的向量。

4. 完整代码+注释

4.1 数据处理

导入包

import random

import numpy as np
import torch
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(levelname)s: %(message)s')

查看本文全部内容,欢迎访问天池技术圈官方地址:【详细注释+流程讲解】基于深度学习的文本分类 TextCNN_天池技术圈-阿里云天池


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/791713.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

同步检查继电器 JT-1/200 100V 面板嵌入式安装,板后接线

系列型号 JT-1同步检查继电器; DT-1同步检查继电器; JT-3同步检查继电器; DT-3同步检查继电器; 一、应用范围 JT(DT)系列同步检查继电器用于两端供电线路的自动重合闸线路中,以检查线路上电压的存在及线路上和变电站汇…

基于Spring Boot的在线考试系统

开发语言:Java 框架:springboot JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7(一定要5.7版本) 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Maven…

Go语言时间编程

1.时间元素编程 时间是一个重要的编程元素,可用于计算、同步服务器以及测量。Go语言标准库提供了time包,其中包含用于同当前时间交互以及测量时间的函数和方法。 在编程中,时间通常被称为“实时” “过去的时间” 和 “壁挂钟” 。对于术语“壁挂钟”,可将其视为挂在墙上的…

element-ui breadcrumb 组件源码分享

今日简单分享 breadcrumb 组件的源码实现,主要从以下三个方面: 1、breadcrumb 组件页面结构 2、breadcrumb 组件属性 3、breadcrumb 组件 slot 一、breadcrumb 组件页面结构 二、breadcrumb 组件属性 2.1 separator 属性,分隔符&#xff…

程序员沟通之道:TCP与UDP之辩,窥见有效沟通的重要性(day19)

程序员沟通的重要性: 今天被师父骂了一顿,说我不及时回复他,连最起码的有效沟通都做不到怎么当好一个程序员,想想还挺有道理,程序员需要知道用户到底有哪些需求,用户与程序员之间的有效沟通就起到了关键性作…

PCL 计算直线到三角形的距离(3D)

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 这里的思路相对就简单很多了,主要有之前的积累: 1、首先,我们可以判断直线与三角形是否相交,相交则距离为0,这里可以参考之前的博客:PCL 计算一条射线与一个三角形的交点。 2、如果直线与三角形未相交,则只需…

Java多态世界(day18)

多态:重写的方法调用和执行 1.静态绑定:编译器在父类中找方法,如: 上面的eat()方法是先在父类中找方法,父类没有的话,就算子类有编译也会报错。(如果引用方法在父类中存…

UE4_普通贴图制作法线Normal材质

UE4 普通贴图制作法线Normal材质 2021-07-02 10:46 导入一张普通贴图: 搜索节点:NormalFromHeightmap 搜索节点:TextureObjectparameter,并修改成导入的普通贴图,连接至HeightMap中 创建参数normal,连接…

世界各国各行业全球价值链数据集(2007-2021年)

01、数据介绍 在全球化的浪潮下,世界各国各行业都深深地融入了全球价值链中,共同构建了一个复杂而精细的网络。这个网络不仅连接了各国的经济体系,也促进了全球贸易和合作的繁荣发展。 在发达国家,他们凭借先进的科技、高端人才…

Dockerfile详解构建镜像

Dockerfile构建企业级镜像 在服务器上可以通过源码或rpm方式部署Nginx服务,但不利于大规模的部署。为提高效率,可以通过Dockerfile的方式将Nginx服务封装到镜像中,然后Docker基于镜像快速启动容器,实现服务的快速部署。 Dockerf…

【Vue3】el-checkbox-group实现权限配置和应用

一. 需求 针对不同等级的用户,配置不同的可见项 配置效果如下 (1)新增,获取数据列表 (2)编辑,回显数据列表 应用效果如下 (1)父级配置 (2)子级…

Cloudflare Workers

Cloudflare Workers 一、由来和历史 Cloudflare Workers是Cloudflare提供的一项边缘计算服务,它允许开发者在全球分布的Cloudflare数据中心运行JavaScript代码,实现灵活的边缘逻辑处理。Workers最初于2017年推出,是Cloudflare推动边缘计算和…

提速又稳定:使用国内镜像源加速 pip 安装软件包

文章目录 前言国内镜像源使用方式个人简介 前言 当涉及到 Python 开发时,使用 pip 安装软件包已经成为家常便饭。但是很多开发者都会遇到一个共同的问题:国外源下载速度慢,不仅浪费时间,而且经常导致安装失败。为了解决这个问题&…

详解TCP/IP五层模型

目录 一、什么是TCP五层模型? 二、TCP五层模型的详细内容 1. 应用层 2. 传输层 3. 网络层 4. 数据链路层 5. 物理层 三、网络设备所在分层 封装和分⽤ 三、Java示例 引言: 在网络通信中,TCP/IP协议是至关重要的。为了更好地理解TCP协议的工…

详解设计模式:单例的进化之路

概念 单例模式(Singleton Pattern)是设计模式中一个重要的模式之一,是确保一个类在任何情况下都绝对只有一个实例。单例模式一般会屏蔽构造器,单例对象提供一个全局访问点,属于创建型模式。 根据初始化时间的不同,可以将单例模式…

文件操作讲解

目录 一.为什么使用文件 二.什么是文件 2.1程序文件 2.2数据文件 2.3文件名 三.文本文件和二进制文件 fwrite函数 fclose函数 四.文件的打开和关闭 4.1流和标准流 4.2文件指针 4.3文件的打开和关闭 五.文件的顺序读写 5.1文件的顺序读写函数 5.1.1fgetc函数…

》shader命令《--材质函数整理

》shader命令《--材质函数整理 2022-05-31 10:00 材质函数整理 Add 加法Subtract 剪法Multiply 乘法Divide 除法Append 向量合并Abs 绝对值Clamp 区间限定(限定高低值)Floor 舍去小数点Ceil 去掉小数1Fmod 计算余数Frac …

【软件工程】概要设计

1. 导言 1.1 目的 该文档的目的是描述学生成绩管理系统的概要设计,其主要内容包括: 系统功能简介 系统结构简介 系统接口设计 数据设计 模块设计 界面设计 本文的预期读者是: 项目开发人员 项目管理人员 项目评测人员(…

VS2022使用属性表快速设置OpenCV工程属性

1.创建C++控制台应用 2.配置工程 3.打开工程后,为工程添加属性表 打开属性管理器窗口,选择Debug|x64 然后右击选择添加新的项目属性表 并命名为opencv490_debug_x64 点击添加 Debug版本属性表添加成功 使用相同方法添加Release版本属性表

Windows通过git配置github代码仓库全流程

git git是代码的版本控制工具 git安装和github注册 这个默认弄过了 通过git和github之间的SSH配置 在github上面新建仓库,做好配置 git绑定GitHub账号 先cd到上传的文件所在的目录 git config --global user.name "你的github用户名"git config -…