第N4周:中文文本分类-Pytorch实现

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

一、准备工作

1.任务说明

 文本分类流程图:

 2.加载数据

​编辑 二、数据的预处理

1.构建词典

2.生成数据批次和迭代器

三、模型构建

四、训练模型

五、小结


一、准备工作

1.任务说明

本次将使用PyTorch实现中文文本分类。主要代码与N1周基本一致,不同的是本次任务中使用了本地的中文数据,数据示例如下:

本周任务:

1.学习如何进行中文本文预处理

2.根据文本内容(第1列)预测文本标签(第2列)

进阶任务:

1.尝试根据第一周的内容独立实现,尽可能的不看本文的代码

2.构建更复杂的网络模型,将准确率提升至91% 

 文本分类流程图:

 2.加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")   #忽略警告信息device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)import pandas as pd#加载自定义中文数据
train_data = pd.read_csv('./train.csv',sep='\t',header = None)
#构造数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,ytrain_iter =coustom_data_iter(train_data[0].values[:],train_data[1].values[:])

输出: 

 二、数据的预处理

1.构建词典

#构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba#中文分词方法
tokenizer = jieba.lcut
def yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])   #设置默认索引,如果找不到单词,则会选择默认索引
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])
label_name = list(set(train_data[1].values[:]))
print(label_name)
text_pipeline = lambda x : vocab(tokenizer(x))
label_pipeline = lambda x : label_name.index(x)print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

输出:

2.生成数据批次和迭代器

#生成数据批次和迭代器
from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [],[],[0]         for(_text, _label) in batch:#标签列表label_list.append(label_pipeline(_label))#文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)#偏移量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list,dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)       #返回维度dim中输入元素的累计和return text_list.to(device), label_list.to(device), offsets.to(device)#数据加载器
dataloader = DataLoader(train_iter,batch_size = 8,shuffle = False,collate_fn = collate_batch
)

三、模型构建

#搭建模型
from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel,self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,      #词典大小embed_dim,        # 嵌入的维度sparse=False)     #self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)
#初始化模型
#定义实例
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)
#定义训练与评估函数
import timedef train(dataloader):model.train()          #切换为训练模式total_acc, train_loss, total_count = 0,0,0log_interval = 50start_time = time.time()for idx, (text,label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()                             #grad属性归零loss = criterion(predicted_label, label)          #计算网络输出和真实值之间的差距,label为真loss.backward()                                   #反向传播torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)  #梯度裁剪optimizer.step()                                  #每一步自动更新#记录acc与losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count = 0,0,0staet_time = time.time()def evaluate(dataloader):model.eval()      #切换为测试模式total_acc,train_loss,total_count = 0,0,0with torch.no_grad():for idx,(text,label,offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label,label)   #计算loss值#记录测试数据total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

四、训练模型

#拆分数据集并运行模型
from torch.utils.data.dataset   import random_split
from torchtext.data.functional  import to_map_style_dataset# 超参数设定
EPOCHS      = 10   #epoch
LR          = 5    #learningRate
BATCH_SIZE  = 64   #batch size for training#设置损失函数、选择优化器、设置学习率调整函数
criterion   = torch.nn.CrossEntropyLoss()
optimizer   = torch.optim.SGD(model.parameters(), lr = LR)
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma = 0.1)
total_accu  = None# 构建数据集
train_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset   = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader    = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
valid_dataloader    = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)#获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:d} | time:{:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch,time.time() - epoch_start_time,val_acc,val_loss))print('-' * 69)
test_acc,test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))
#测试指定的数据
def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text, torch.tensor([0]))return output.argmax(1).item()ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to("cpu")print("该文本的类别是: %s" %label_name[predict(ex_text_str,text_pipeline)])

输出:

|epoch1|  50/ 152 batches|train_acc0.431 train_loss0.03045
|epoch1| 100/ 152 batches|train_acc0.700 train_loss0.01936
|epoch1| 150/ 152 batches|train_acc0.768 train_loss0.01370
---------------------------------------------------------------------
| epoch 1 | time:1.58s | valid_acc 0.789 valid_loss 0.012
---------------------------------------------------------------------
|epoch2|  50/ 152 batches|train_acc0.818 train_loss0.01030
|epoch2| 100/ 152 batches|train_acc0.831 train_loss0.00932
|epoch2| 150/ 152 batches|train_acc0.850 train_loss0.00811
---------------------------------------------------------------------
| epoch 2 | time:1.47s | valid_acc 0.837 valid_loss 0.008
---------------------------------------------------------------------
|epoch3|  50/ 152 batches|train_acc0.870 train_loss0.00688
|epoch3| 100/ 152 batches|train_acc0.887 train_loss0.00658
|epoch3| 150/ 152 batches|train_acc0.893 train_loss0.00575
---------------------------------------------------------------------
| epoch 3 | time:1.46s | valid_acc 0.866 valid_loss 0.007
---------------------------------------------------------------------
|epoch4|  50/ 152 batches|train_acc0.906 train_loss0.00507
|epoch4| 100/ 152 batches|train_acc0.918 train_loss0.00468
|epoch4| 150/ 152 batches|train_acc0.915 train_loss0.00478
---------------------------------------------------------------------
| epoch 4 | time:1.47s | valid_acc 0.886 valid_loss 0.006
---------------------------------------------------------------------
|epoch5|  50/ 152 batches|train_acc0.938 train_loss0.00378
|epoch5| 100/ 152 batches|train_acc0.935 train_loss0.00379
|epoch5| 150/ 152 batches|train_acc0.932 train_loss0.00376
---------------------------------------------------------------------
| epoch 5 | time:1.51s | valid_acc 0.890 valid_loss 0.006
---------------------------------------------------------------------
|epoch6|  50/ 152 batches|train_acc0.951 train_loss0.00310
|epoch6| 100/ 152 batches|train_acc0.952 train_loss0.00287
|epoch6| 150/ 152 batches|train_acc0.950 train_loss0.00289
---------------------------------------------------------------------
| epoch 6 | time:1.50s | valid_acc 0.894 valid_loss 0.006
---------------------------------------------------------------------
|epoch7|  50/ 152 batches|train_acc0.963 train_loss0.00233
|epoch7| 100/ 152 batches|train_acc0.963 train_loss0.00244
|epoch7| 150/ 152 batches|train_acc0.965 train_loss0.00222
---------------------------------------------------------------------
| epoch 7 | time:1.49s | valid_acc 0.898 valid_loss 0.005
---------------------------------------------------------------------
|epoch8|  50/ 152 batches|train_acc0.975 train_loss0.00183
|epoch8| 100/ 152 batches|train_acc0.976 train_loss0.00176
|epoch8| 150/ 152 batches|train_acc0.971 train_loss0.00188
---------------------------------------------------------------------
| epoch 8 | time:1.67s | valid_acc 0.900 valid_loss 0.005
---------------------------------------------------------------------
|epoch9|  50/ 152 batches|train_acc0.982 train_loss0.00145
|epoch9| 100/ 152 batches|train_acc0.982 train_loss0.00139
|epoch9| 150/ 152 batches|train_acc0.980 train_loss0.00141
---------------------------------------------------------------------
| epoch 9 | time:2.05s | valid_acc 0.901 valid_loss 0.006
---------------------------------------------------------------------
|epoch10|  50/ 152 batches|train_acc0.990 train_loss0.00108
|epoch10| 100/ 152 batches|train_acc0.984 train_loss0.00119
|epoch10| 150/ 152 batches|train_acc0.986 train_loss0.00105
---------------------------------------------------------------------
| epoch 10 | time:1.98s | valid_acc 0.900 valid_loss 0.005
---------------------------------------------------------------------
模型准确率为:0.8996
该文本的类别是: Travel-Query

五、小结

  • 数据加载
    • 定义一个生成器函数,将文本和标签成对迭代。这是为了后续的数据处理和加载做准备。
  • 分词与词汇表
    • 使用jieba进行中文分词,jieba.lcut可以将中文文本切割成单个词语列表。
    • 使用torchtextbuild_vocab_from_iterator从分词后的文本中构建词汇表,并设置默认索引为<unk>,表示未知词汇。这对处理未见过的词汇非常重要。
  • 数据管道:创建文本和标签处理管道。
    • 创建两个处理管道:

    • text_pipeline:将文本转换为词汇表中的索引。
    • label_pipeline:将标签转换为索引。
  • 模型构建:定义带嵌入层和全连接层的文本分类模型。
    • 定义一个文本分类模型TextClassificationModel,包括一个嵌入层nn.EmbeddingBag和一个全连接层nn.Linearnn.EmbeddingBag在处理变长序列时性能较好,因为它不需要明确的填充操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/852744.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

conda添加镜像源与channels

文章目录 一、conda下添加国内镜像源&#xff08;window下&#xff09;二、pip配置国内镜像源&#xff08;window下&#xff0c;临时修改&#xff09;三、conda源的定义 一、conda下添加国内镜像源&#xff08;window下&#xff09; 1、为【channels】配置清华镜像通道 直接在…

【Pandas】已完美解决:AttributeError: ‘DataFrame‘ object has no attribute ‘ix‘

文章目录 一、问题背景二、可能出错的原因三、错误代码示例四、正确代码示例&#xff08;结合实战场景&#xff09;五、注意事项 一、问题背景 在Pandas的早期版本中&#xff0c;ix 是一个方便的索引器&#xff0c;允许用户通过标签和整数位置来索引DataFrame的行和列。然而&a…

Pixi.js学习 (四)鼠标跟随、元素组合与图片位控

目录 一、鼠标移动跟随 1.1 获取鼠标坐标 1.2 鼠标跟随 二、锚点、元素组合 2.1 锚点 2.2 元素组合 三、图片图层 四、实战 例题一&#xff1a;完成合金弹头人物交互 例题二&#xff1a;反恐重击瞄准和弹痕 例题一代码&#xff1a; 例题二代码&#xff1a; 总结 前言 为了提高作…

ADS基础教程20 - 电磁仿真(EM)参数化

EM介绍 一、引言二、参数化设置1.参数定义2.参数赋值3.创建EM模型和符号 四、总结 一、引言 参数化EM仿真&#xff0c;是在Layout环境下创建参数&#xff0c;相当于在原理图中声明变量。 二、参数化设置 1.参数定义 1&#xff09;在Layout视图&#xff0c;菜单栏中选中EM&g…

大模型出现的不断重复的现象

无论是大语言模型还是多模态模型,都遇到过这个问题,该如何解决呢? 1.调整推理参数 [BUG] 返回重复的内容 Issue #277 QwenLM/Qwen GitHub是否已有关于该错误的issue或讨论? | Is there an existing issue / discussion for this? 我已经搜索过已有的issues和讨论 | I…

【Linux】基础IO——系统文件IO

我之前是讲过c语言的文件操作的&#xff0c;但是说实话我压根就不知道它在干什么&#xff0c;后面c语言/c,数据结构的学习过程中也没用过文件操作&#xff0c;今天我们就来会会这个文件操作 1.回顾c语言文件接口 1.1.fopen r &#xff1a;只读模式打开&#xff0c;文件流指针…

【LeetCode 92.】 反转链表 II

1.题目 虽然本题很好拆解&#xff0c;但是实现起来还是有一些难度的。 2. 分析 尽可能抽象问题&#xff0c;然后简化代码 我在写本题的时候&#xff0c;遇到了下面这两个问题&#xff1a; 没有把[left,right] 这个区间的链表给断开&#xff0c;所以导致反转起来非常麻烦。…

【iOS】KVO相关总结

目录 1. 什么是KVO&#xff1f;2. KVO的基本使用3. KVO的进阶使用observationInfo属性context 的使用KVO触发监听方法的方式自动触发手动触发 KVO新旧值相等时不触发KVO的从属关系一对一关系一对多关系 4. KVO使用注意5. KVO本质原理分析伪代码保留伪代码下的类并编译运行对比添…

小白都能看懂的 “栈”

什么是栈&#xff1f;首先引用维基百科的解释&#xff1a; 栈&#xff08;stack&#xff09;是计算机科学中的一种抽象资料类型&#xff0c;只允许在有序的线性资料集合的一端&#xff08;称为堆栈顶端&#xff0c;top&#xff09;进行加入数据&#xff08;push&#xff09;和移…

Go语言结构体内嵌接口

前言 在golang中&#xff0c;结构体内嵌结构体&#xff0c;接口内嵌接口都很常见&#xff0c;但是结构体内嵌接口很少见。它是做什么用的呢&#xff1f; 当我们需要重写实现了某个接口的结构体的(该接口)的部分方法&#xff0c;可以使用结构体内嵌接口。 作用 继承赋值给接口…

信号与系统实验MATLAB-实验1-信号的MATLAB表示及信号运算

实验1-信号的MATLAB表示及信号运算 一、实验目的 1、掌握MATLAB的使用&#xff1b; 2、掌握MATLAB生成信号波形&#xff1b; 3、掌握MATLAB分析常用连续信号&#xff1b; 4、掌握信号运算的MATLAB实现。 二、实验内容 编写程序实现下列常用函数&#xff0c;并显示波形。…

PyTorch -- Visdom 快速实践

安装&#xff1a;pip install visdom 注&#xff1a;如果安装后启动报错可能是 visdom 版本选择问题 启动&#xff1a;python -m visdom.server 之后打开出现的链接 http://localhost:8097Checking for scripts. Its Alive! INFO:root:Application Started INFO:root:Working…

数据网格和视图入门

WinForms数据网格&#xff08;GridControl类&#xff09;是一个数据感知控件&#xff0c;可以以各种格式&#xff08;视图&#xff09;显示数据。本主题包含以下部分&#xff0c;这些部分将指导您如何使用网格控件及其视图和列&#xff08;字段&#xff09;。 Grid Control’s…

BUUCTF-Web题目1

目录 [HCTF 2018]admin 1、题目 2、知识点 3、思路 [极客大挑战 2019]BuyFlag 1、题目 2、知识点 3、思路 [HCTF 2018]admin 1、题目 2、知识点 BP暴力破解密码 3、思路 打开题目&#xff0c;查看页面源代码&#xff0c;发现需要admin用户才可以登录 这一台有很多解法…

LeetCode | 20.有效的括号

这道题就是栈这种数据结构的应用&#xff0c;当我们遇到左括号的时候&#xff0c;比如{,(,[&#xff0c;就压栈&#xff0c;当遇到右括号的时候&#xff0c;比如},),]&#xff0c;就把栈顶元素弹出&#xff0c;如果不匹配&#xff0c;则返回False&#xff0c;当遍历完所有元素后…

K8s 卷快照类

卷快照类 卷快照类 这个警告信息通常出现在使用 kubectl 删除 Kubernetes 集群资源时&#xff0c;如果尝试删除的是集群作用域&#xff08;cluster-scoped&#xff09;的资源&#xff0c;但指定了命名空间&#xff08;namespace&#xff09;&#xff0c;就会出现这个警告。 集…

基于PointNet / PointNet++深度学习模型的激光点云语义分割

一、场景要素语义分割部分的文献阅读笔记 1.1 PointNet PointNet网络模型开创性地实现了直接将点云数据作为输入的高效深度学习方法&#xff08;端到端学习&#xff09;。最大池化层、全局信息聚合结构以及联合对齐结构是该网络模型的三大关键模块&#xff0c;最大池化层解决了…

72、AndroidStudio 导入项目Connect timed out错误解决

一、背景&#xff1a; 开发过程中难免会 clone 其他的项目&#xff0c;clone 或者下载成功之后。使用 android studio 打开项目时经常遇到 Connect timed out错误如图所示&#xff1a; 二、分析原因&#xff1a; 1、既然链接超时&#xff0c;肯定是 android studio 在运行…

包装类的应用

一.什么是包装类 基本数据类型所对应的引用数据类型 二.集合中不能存储基本数据类型 三.JDK5以后对包装类新增了什么特性&#xff1f; // 自动装箱:把基本数据类型会自动的变成对应的包装类 // 自动拆箱:把包装类自动的变成其对象的基本数据类型 四.我们以后如何获取包…

02-MybatisPlus批量插入性能够吗?

1 前言 “不要用 mybatis-plus 的批量插入&#xff0c;它其实也是遍历插入&#xff0c;性能很差的”。真的吗&#xff1f;他们的立场如下&#xff1a; 遍历插入&#xff0c;反复创建。这是一个重量级操作&#xff0c;所以性能差。这里不用看源码也知道&#xff0c;因为这个和…