第N4周:中文文本分类-Pytorch实现

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

一、准备工作

1.任务说明

 文本分类流程图:

 2.加载数据

​编辑 二、数据的预处理

1.构建词典

2.生成数据批次和迭代器

三、模型构建

四、训练模型

五、小结


一、准备工作

1.任务说明

本次将使用PyTorch实现中文文本分类。主要代码与N1周基本一致,不同的是本次任务中使用了本地的中文数据,数据示例如下:

本周任务:

1.学习如何进行中文本文预处理

2.根据文本内容(第1列)预测文本标签(第2列)

进阶任务:

1.尝试根据第一周的内容独立实现,尽可能的不看本文的代码

2.构建更复杂的网络模型,将准确率提升至91% 

 文本分类流程图:

 2.加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")   #忽略警告信息device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)import pandas as pd#加载自定义中文数据
train_data = pd.read_csv('./train.csv',sep='\t',header = None)
#构造数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,ytrain_iter =coustom_data_iter(train_data[0].values[:],train_data[1].values[:])

输出: 

 二、数据的预处理

1.构建词典

#构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba#中文分词方法
tokenizer = jieba.lcut
def yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])   #设置默认索引,如果找不到单词,则会选择默认索引
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])
label_name = list(set(train_data[1].values[:]))
print(label_name)
text_pipeline = lambda x : vocab(tokenizer(x))
label_pipeline = lambda x : label_name.index(x)print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

输出:

2.生成数据批次和迭代器

#生成数据批次和迭代器
from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [],[],[0]         for(_text, _label) in batch:#标签列表label_list.append(label_pipeline(_label))#文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)#偏移量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list,dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)       #返回维度dim中输入元素的累计和return text_list.to(device), label_list.to(device), offsets.to(device)#数据加载器
dataloader = DataLoader(train_iter,batch_size = 8,shuffle = False,collate_fn = collate_batch
)

三、模型构建

#搭建模型
from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel,self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,      #词典大小embed_dim,        # 嵌入的维度sparse=False)     #self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)
#初始化模型
#定义实例
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)
#定义训练与评估函数
import timedef train(dataloader):model.train()          #切换为训练模式total_acc, train_loss, total_count = 0,0,0log_interval = 50start_time = time.time()for idx, (text,label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()                             #grad属性归零loss = criterion(predicted_label, label)          #计算网络输出和真实值之间的差距,label为真loss.backward()                                   #反向传播torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)  #梯度裁剪optimizer.step()                                  #每一步自动更新#记录acc与losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count = 0,0,0staet_time = time.time()def evaluate(dataloader):model.eval()      #切换为测试模式total_acc,train_loss,total_count = 0,0,0with torch.no_grad():for idx,(text,label,offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label,label)   #计算loss值#记录测试数据total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

四、训练模型

#拆分数据集并运行模型
from torch.utils.data.dataset   import random_split
from torchtext.data.functional  import to_map_style_dataset# 超参数设定
EPOCHS      = 10   #epoch
LR          = 5    #learningRate
BATCH_SIZE  = 64   #batch size for training#设置损失函数、选择优化器、设置学习率调整函数
criterion   = torch.nn.CrossEntropyLoss()
optimizer   = torch.optim.SGD(model.parameters(), lr = LR)
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma = 0.1)
total_accu  = None# 构建数据集
train_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset   = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader    = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
valid_dataloader    = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)#获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:d} | time:{:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch,time.time() - epoch_start_time,val_acc,val_loss))print('-' * 69)
test_acc,test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))
#测试指定的数据
def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text, torch.tensor([0]))return output.argmax(1).item()ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to("cpu")print("该文本的类别是: %s" %label_name[predict(ex_text_str,text_pipeline)])

输出:

|epoch1|  50/ 152 batches|train_acc0.431 train_loss0.03045
|epoch1| 100/ 152 batches|train_acc0.700 train_loss0.01936
|epoch1| 150/ 152 batches|train_acc0.768 train_loss0.01370
---------------------------------------------------------------------
| epoch 1 | time:1.58s | valid_acc 0.789 valid_loss 0.012
---------------------------------------------------------------------
|epoch2|  50/ 152 batches|train_acc0.818 train_loss0.01030
|epoch2| 100/ 152 batches|train_acc0.831 train_loss0.00932
|epoch2| 150/ 152 batches|train_acc0.850 train_loss0.00811
---------------------------------------------------------------------
| epoch 2 | time:1.47s | valid_acc 0.837 valid_loss 0.008
---------------------------------------------------------------------
|epoch3|  50/ 152 batches|train_acc0.870 train_loss0.00688
|epoch3| 100/ 152 batches|train_acc0.887 train_loss0.00658
|epoch3| 150/ 152 batches|train_acc0.893 train_loss0.00575
---------------------------------------------------------------------
| epoch 3 | time:1.46s | valid_acc 0.866 valid_loss 0.007
---------------------------------------------------------------------
|epoch4|  50/ 152 batches|train_acc0.906 train_loss0.00507
|epoch4| 100/ 152 batches|train_acc0.918 train_loss0.00468
|epoch4| 150/ 152 batches|train_acc0.915 train_loss0.00478
---------------------------------------------------------------------
| epoch 4 | time:1.47s | valid_acc 0.886 valid_loss 0.006
---------------------------------------------------------------------
|epoch5|  50/ 152 batches|train_acc0.938 train_loss0.00378
|epoch5| 100/ 152 batches|train_acc0.935 train_loss0.00379
|epoch5| 150/ 152 batches|train_acc0.932 train_loss0.00376
---------------------------------------------------------------------
| epoch 5 | time:1.51s | valid_acc 0.890 valid_loss 0.006
---------------------------------------------------------------------
|epoch6|  50/ 152 batches|train_acc0.951 train_loss0.00310
|epoch6| 100/ 152 batches|train_acc0.952 train_loss0.00287
|epoch6| 150/ 152 batches|train_acc0.950 train_loss0.00289
---------------------------------------------------------------------
| epoch 6 | time:1.50s | valid_acc 0.894 valid_loss 0.006
---------------------------------------------------------------------
|epoch7|  50/ 152 batches|train_acc0.963 train_loss0.00233
|epoch7| 100/ 152 batches|train_acc0.963 train_loss0.00244
|epoch7| 150/ 152 batches|train_acc0.965 train_loss0.00222
---------------------------------------------------------------------
| epoch 7 | time:1.49s | valid_acc 0.898 valid_loss 0.005
---------------------------------------------------------------------
|epoch8|  50/ 152 batches|train_acc0.975 train_loss0.00183
|epoch8| 100/ 152 batches|train_acc0.976 train_loss0.00176
|epoch8| 150/ 152 batches|train_acc0.971 train_loss0.00188
---------------------------------------------------------------------
| epoch 8 | time:1.67s | valid_acc 0.900 valid_loss 0.005
---------------------------------------------------------------------
|epoch9|  50/ 152 batches|train_acc0.982 train_loss0.00145
|epoch9| 100/ 152 batches|train_acc0.982 train_loss0.00139
|epoch9| 150/ 152 batches|train_acc0.980 train_loss0.00141
---------------------------------------------------------------------
| epoch 9 | time:2.05s | valid_acc 0.901 valid_loss 0.006
---------------------------------------------------------------------
|epoch10|  50/ 152 batches|train_acc0.990 train_loss0.00108
|epoch10| 100/ 152 batches|train_acc0.984 train_loss0.00119
|epoch10| 150/ 152 batches|train_acc0.986 train_loss0.00105
---------------------------------------------------------------------
| epoch 10 | time:1.98s | valid_acc 0.900 valid_loss 0.005
---------------------------------------------------------------------
模型准确率为:0.8996
该文本的类别是: Travel-Query

五、小结

  • 数据加载
    • 定义一个生成器函数,将文本和标签成对迭代。这是为了后续的数据处理和加载做准备。
  • 分词与词汇表
    • 使用jieba进行中文分词,jieba.lcut可以将中文文本切割成单个词语列表。
    • 使用torchtextbuild_vocab_from_iterator从分词后的文本中构建词汇表,并设置默认索引为<unk>,表示未知词汇。这对处理未见过的词汇非常重要。
  • 数据管道:创建文本和标签处理管道。
    • 创建两个处理管道:

    • text_pipeline:将文本转换为词汇表中的索引。
    • label_pipeline:将标签转换为索引。
  • 模型构建:定义带嵌入层和全连接层的文本分类模型。
    • 定义一个文本分类模型TextClassificationModel,包括一个嵌入层nn.EmbeddingBag和一个全连接层nn.Linearnn.EmbeddingBag在处理变长序列时性能较好,因为它不需要明确的填充操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/852744.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

conda添加镜像源与channels

文章目录 一、conda下添加国内镜像源&#xff08;window下&#xff09;二、pip配置国内镜像源&#xff08;window下&#xff0c;临时修改&#xff09;三、conda源的定义 一、conda下添加国内镜像源&#xff08;window下&#xff09; 1、为【channels】配置清华镜像通道 直接在…

【Pandas】已完美解决:AttributeError: ‘DataFrame‘ object has no attribute ‘ix‘

文章目录 一、问题背景二、可能出错的原因三、错误代码示例四、正确代码示例&#xff08;结合实战场景&#xff09;五、注意事项 一、问题背景 在Pandas的早期版本中&#xff0c;ix 是一个方便的索引器&#xff0c;允许用户通过标签和整数位置来索引DataFrame的行和列。然而&a…

Flask-Logging

Flask-Logging 教程 概述 flask-logging 是一个用于在 Flask 应用中实现高级日志记录功能的库。它能够帮助开发者轻松地配置和管理日志&#xff0c;适用于开发和生产环境。通过使用 flask-logging&#xff0c;可以更好地监控应用的运行状态和调试问题。 官方文档 Flask-Log…

Pixi.js学习 (四)鼠标跟随、元素组合与图片位控

目录 一、鼠标移动跟随 1.1 获取鼠标坐标 1.2 鼠标跟随 二、锚点、元素组合 2.1 锚点 2.2 元素组合 三、图片图层 四、实战 例题一&#xff1a;完成合金弹头人物交互 例题二&#xff1a;反恐重击瞄准和弹痕 例题一代码&#xff1a; 例题二代码&#xff1a; 总结 前言 为了提高作…

ADS基础教程20 - 电磁仿真(EM)参数化

EM介绍 一、引言二、参数化设置1.参数定义2.参数赋值3.创建EM模型和符号 四、总结 一、引言 参数化EM仿真&#xff0c;是在Layout环境下创建参数&#xff0c;相当于在原理图中声明变量。 二、参数化设置 1.参数定义 1&#xff09;在Layout视图&#xff0c;菜单栏中选中EM&g…

QMap使用详解

QMap使用详解 1. 实例化 QMap 对象2. 插入数据3. 移除数据4. 遍历数据5. 由键查找对应键值6. 由键值查找键7. 修改键值8. 查找是否包含某个键9. 获取所有的键和键值10.清除数据11.一个键对应多个值12.QMultiMap 遍历数据13.完整示例代码14.使用自定义键类型的 QMap示例&#xf…

大模型出现的不断重复的现象

无论是大语言模型还是多模态模型,都遇到过这个问题,该如何解决呢? 1.调整推理参数 [BUG] 返回重复的内容 Issue #277 QwenLM/Qwen GitHub是否已有关于该错误的issue或讨论? | Is there an existing issue / discussion for this? 我已经搜索过已有的issues和讨论 | I…

【Linux】基础IO——系统文件IO

我之前是讲过c语言的文件操作的&#xff0c;但是说实话我压根就不知道它在干什么&#xff0c;后面c语言/c,数据结构的学习过程中也没用过文件操作&#xff0c;今天我们就来会会这个文件操作 1.回顾c语言文件接口 1.1.fopen r &#xff1a;只读模式打开&#xff0c;文件流指针…

Java程序员英语单词通关:

Java程序员英语单词通关&#xff1a; abstract - 抽象的 boolean - 布尔值 break - 打断 byte - 字节 case - 情况&#xff0c;实例 catch - 捕获 char - 字符 class - 类 continue - 继续 default - 默认&#xff0c;通常 do - 做&#xff0c;运行 double - 双精度…

【LeetCode 92.】 反转链表 II

1.题目 虽然本题很好拆解&#xff0c;但是实现起来还是有一些难度的。 2. 分析 尽可能抽象问题&#xff0c;然后简化代码 我在写本题的时候&#xff0c;遇到了下面这两个问题&#xff1a; 没有把[left,right] 这个区间的链表给断开&#xff0c;所以导致反转起来非常麻烦。…

【iOS】KVO相关总结

目录 1. 什么是KVO&#xff1f;2. KVO的基本使用3. KVO的进阶使用observationInfo属性context 的使用KVO触发监听方法的方式自动触发手动触发 KVO新旧值相等时不触发KVO的从属关系一对一关系一对多关系 4. KVO使用注意5. KVO本质原理分析伪代码保留伪代码下的类并编译运行对比添…

JVM垃圾回收的普遍步骤

JVM&#xff08;Java Virtual Machine&#xff09;进行垃圾回收时&#xff0c;通常遵循以下步骤。不同的垃圾收集器可能会有一些不同的实现细节&#xff0c;但基本步骤和思想大致相同。以下是一般的垃圾回收过程的主要步骤&#xff1a; 1. 标记阶段&#xff08;Marking Phase&…

小白都能看懂的 “栈”

什么是栈&#xff1f;首先引用维基百科的解释&#xff1a; 栈&#xff08;stack&#xff09;是计算机科学中的一种抽象资料类型&#xff0c;只允许在有序的线性资料集合的一端&#xff08;称为堆栈顶端&#xff0c;top&#xff09;进行加入数据&#xff08;push&#xff09;和移…

Go语言结构体内嵌接口

前言 在golang中&#xff0c;结构体内嵌结构体&#xff0c;接口内嵌接口都很常见&#xff0c;但是结构体内嵌接口很少见。它是做什么用的呢&#xff1f; 当我们需要重写实现了某个接口的结构体的(该接口)的部分方法&#xff0c;可以使用结构体内嵌接口。 作用 继承赋值给接口…

信号与系统实验MATLAB-实验1-信号的MATLAB表示及信号运算

实验1-信号的MATLAB表示及信号运算 一、实验目的 1、掌握MATLAB的使用&#xff1b; 2、掌握MATLAB生成信号波形&#xff1b; 3、掌握MATLAB分析常用连续信号&#xff1b; 4、掌握信号运算的MATLAB实现。 二、实验内容 编写程序实现下列常用函数&#xff0c;并显示波形。…

PyTorch -- Visdom 快速实践

安装&#xff1a;pip install visdom 注&#xff1a;如果安装后启动报错可能是 visdom 版本选择问题 启动&#xff1a;python -m visdom.server 之后打开出现的链接 http://localhost:8097Checking for scripts. Its Alive! INFO:root:Application Started INFO:root:Working…

数据网格和视图入门

WinForms数据网格&#xff08;GridControl类&#xff09;是一个数据感知控件&#xff0c;可以以各种格式&#xff08;视图&#xff09;显示数据。本主题包含以下部分&#xff0c;这些部分将指导您如何使用网格控件及其视图和列&#xff08;字段&#xff09;。 Grid Control’s…

BUUCTF-Web题目1

目录 [HCTF 2018]admin 1、题目 2、知识点 3、思路 [极客大挑战 2019]BuyFlag 1、题目 2、知识点 3、思路 [HCTF 2018]admin 1、题目 2、知识点 BP暴力破解密码 3、思路 打开题目&#xff0c;查看页面源代码&#xff0c;发现需要admin用户才可以登录 这一台有很多解法…

redis清空list

redis list清空 要清空Redis中的list&#xff0c;您可以使用LTRIM命令。Redis Ltrim 对一个列表进行修剪(trim)&#xff0c;就是说&#xff0c;让列表只保留指定区间内的元素&#xff0c;不在指定区间之内的元素都将被删除。 下标 0 表示列表的第一个元素&#xff0c;以 1 表示…

LeetCode | 20.有效的括号

这道题就是栈这种数据结构的应用&#xff0c;当我们遇到左括号的时候&#xff0c;比如{,(,[&#xff0c;就压栈&#xff0c;当遇到右括号的时候&#xff0c;比如},),]&#xff0c;就把栈顶元素弹出&#xff0c;如果不匹配&#xff0c;则返回False&#xff0c;当遍历完所有元素后…