【NLP练习】中文文本分类-Pytorch实现

中文文本分类-Pytorch实现

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、准备工作

1. 任务说明

本次使用Pytorch实现中文文本分类。主要代码与文本分类代码基本一致,不同的是本次任务使用了本地的中文数据,数据示例如下:
在这里插入图片描述

2.加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")   #忽略警告信息device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

输出:

device(type='cpu')
import pandas as pd#加载自定义中文数据
train_data = pd.read_csv('./train.csv',sep='\t',header = None)
train_data.head()

输出:
在这里插入图片描述

#构造数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,ytrain_iter =coustom_data_iter(train_data[0].values[:],train_data[1].values[:])

二、数据预处理

#构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba#中文分词方法
tokenizer = jieba.lcut
def yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])   #设置默认索引,如果找不到单词,则会选择默认索引
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])

输出:

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
label_name = list(set(train_data[1].values[:]))
print(label_name)

输出:

['FilmTele-Play', 'Alarm-Update', 'Weather-Query', 'Audio-Play', 'Radio-Listen', 'Travel-Query', 'Music-Play', 'Video-Play', 'HomeAppliance-Control', 'Calendar-Query', 'TVProgram-Play', 'Other']
text_pipeline = lambda x : vocab(tokenizer(x))
label_pipeline = lambda x : label_name.index(x)print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

输出:

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
7

lambda表达式的语法为:lambda arguments: expression
其中arguments是函数的参数,可以有多个参数,用逗号分隔。expression是一个表达式,它定义了函数的返回值。

  • text_pipeline函数: 将原始文本数据转换为整数列表,使用了之前构建的vocab词表和tokenizer分词器函数。具体步骤:
  1. 接受一个字符串x作为输入
  2. 使用tokenizer将其分词
  3. 将每个词在vocab词表中的索引放入一个列表返回
  • label_pipeline函数: 将原始标签数据转换为整数,它接受一个字符串x作为输入,并使用 label_index.index(x) 方法获取x在label_name列表中的索引作为输出。

2.生成数据批次和迭代器

#生成数据批次和迭代器
from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [],[],[0]         for(_text, _label) in batch:#标签列表label_list.append(label_pipeline(_label))#文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)#偏移量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list,dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)       #返回维度dim中输入元素的累计和return text_list.to(device), label_list.to(device), offsets.to(device)#数据加载器
dataloader = DataLoader(train_iter,batch_size = 8,shuffle = False,collate_fn = collate_batch
)

三、模型构建

1. 搭建模型

#搭建模型
from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel,self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,      #词典大小embed_dim,        # 嵌入的维度sparse=False)     #self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)

2. 初始化模型

#初始化模型
#定义实例
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

3. 定义训练与评估函数

#定义训练与评估函数
import timedef train(dataloader):model.train()          #切换为训练模式total_acc, train_loss, total_count = 0,0,0log_interval = 50start_time = time.time()for idx, (text,label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()                             #grad属性归零loss = criterion(predicted_label, label)          #计算网络输出和真实值之间的差距,label为真loss.backward()                                   #反向传播torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)  #梯度裁剪optimizer.step()                                  #每一步自动更新#记录acc与losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count = 0,0,0staet_time = time.time()def evaluate(dataloader):model.eval()      #切换为测试模式total_acc,train_loss,total_count = 0,0,0with torch.no_grad():for idx,(text,label,offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label,label)   #计算loss值#记录测试数据total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

四、训练模型

1. 拆分数据集并运行模型

#拆分数据集并运行模型
from torch.utils.data.dataset   import random_split
from torchtext.data.functional  import to_map_style_dataset# 超参数设定
EPOCHS      = 10   #epoch
LR          = 5    #learningRate
BATCH_SIZE  = 64   #batch size for training#设置损失函数、选择优化器、设置学习率调整函数
criterion   = torch.nn.CrossEntropyLoss()
optimizer   = torch.optim.SGD(model.parameters(), lr = LR)
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma = 0.1)
total_accu  = None# 构建数据集
train_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset   = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader    = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
valid_dataloader    = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)#获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:d} | time:{:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch,time.time() - epoch_start_time,val_acc,val_loss))print('-' * 69)

输出:

['还有双鸭山到淮阴的汽车票吗13号的' '从这里怎么回家' '随便播放一首专辑阁楼里的佛里的歌' ...'黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席' '百事盖世群星星光演唱会有谁' '下周一视频会议的闹钟帮我开开']
|epoch1|  50/ 152 batches|train_acc0.953 train_loss0.00282
|epoch1| 100/ 152 batches|train_acc0.953 train_loss0.00271
|epoch1| 150/ 152 batches|train_acc0.952 train_loss0.00292
---------------------------------------------------------------------
| epoch 1 | time:5.50s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch2|  50/ 152 batches|train_acc0.961 train_loss0.00231
|epoch2| 100/ 152 batches|train_acc0.967 train_loss0.00204
|epoch2| 150/ 152 batches|train_acc0.963 train_loss0.00228
---------------------------------------------------------------------
| epoch 2 | time:5.06s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch3|  50/ 152 batches|train_acc0.975 train_loss0.00173
|epoch3| 100/ 152 batches|train_acc0.973 train_loss0.00177
|epoch3| 150/ 152 batches|train_acc0.972 train_loss0.00166
---------------------------------------------------------------------
| epoch 3 | time:5.07s | valid_acc 0.948 valid_loss 0.003
---------------------------------------------------------------------
|epoch4|  50/ 152 batches|train_acc0.984 train_loss0.00137
|epoch4| 100/ 152 batches|train_acc0.987 train_loss0.00123
|epoch4| 150/ 152 batches|train_acc0.983 train_loss0.00119
---------------------------------------------------------------------
| epoch 4 | time:5.07s | valid_acc 0.950 valid_loss 0.003
---------------------------------------------------------------------
|epoch5|  50/ 152 batches|train_acc0.985 train_loss0.00125
|epoch5| 100/ 152 batches|train_acc0.987 train_loss0.00119
|epoch5| 150/ 152 batches|train_acc0.986 train_loss0.00120
---------------------------------------------------------------------
| epoch 5 | time:5.03s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch6|  50/ 152 batches|train_acc0.985 train_loss0.00118
|epoch6| 100/ 152 batches|train_acc0.989 train_loss0.00114
|epoch6| 150/ 152 batches|train_acc0.985 train_loss0.00120
---------------------------------------------------------------------
| epoch 6 | time:5.40s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch7|  50/ 152 batches|train_acc0.984 train_loss0.00119
|epoch7| 100/ 152 batches|train_acc0.986 train_loss0.00119
|epoch7| 150/ 152 batches|train_acc0.989 train_loss0.00112
---------------------------------------------------------------------
| epoch 7 | time:5.71s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch8|  50/ 152 batches|train_acc0.985 train_loss0.00115
|epoch8| 100/ 152 batches|train_acc0.986 train_loss0.00128
|epoch8| 150/ 152 batches|train_acc0.989 train_loss0.00107
---------------------------------------------------------------------
| epoch 8 | time:5.22s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch9|  50/ 152 batches|train_acc0.988 train_loss0.00114
|epoch9| 100/ 152 batches|train_acc0.983 train_loss0.00127
|epoch9| 150/ 152 batches|train_acc0.989 train_loss0.00109
---------------------------------------------------------------------
| epoch 9 | time:5.28s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch10|  50/ 152 batches|train_acc0.986 train_loss0.00115
|epoch10| 100/ 152 batches|train_acc0.987 train_loss0.00117
|epoch10| 150/ 152 batches|train_acc0.986 train_loss0.00119
---------------------------------------------------------------------
| epoch 10 | time:5.22s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
test_acc,test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))

输出:

模型准确率为:0.9492

2. 测试指定数据

#测试指定的数据
def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text, torch.tensor([0]))return output.argmax(1).item()ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to("cpu")print("该文本的类别是: %s" %label_name[predict(ex_text_str,text_pipeline)])

输出:

该文本的类别是: Travel-Query

五、总结

训练神经网络时,可使用梯度裁剪的方法来防止梯度爆炸,使得模型训练更加稳定

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/793913.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FAS-Net

感想 图的下标弄不好&#xff0c;且作者未提供代码。AAAI的质量也就这样吧

Web Component 组件库有什么优势

前言 前端目前比较主流的框架有 react&#xff0c;vuejs&#xff0c;angular 等。 我们通常去搭建组件库的时候都是基于某一种框架去搭建&#xff0c;比如 ant-design 是基于 react 搭建的UI组件库&#xff0c;而 element-plus 则是基于 vuejs 搭建的组件库。 可能你有这种体…

Transformer的代码实现 day03(Positional Encoding)

Positional Encoding的理论部分 注意力机制是不含有位置信息&#xff0c;这也就表明&#xff1a;“我爱你”&#xff0c;“你爱我”这两者没有区别&#xff0c;而在现实世界中&#xff0c;这两者有区别。所以位置编码是在进行注意力计算之前&#xff0c;给输入加上一个位置信息…

【RISC-V 指令集】RISC-V 向量V扩展指令集介绍(五)- 向量加载和存储

1. 引言 以下是《riscv-v-spec-1.0.pdf》文档的关键内容&#xff1a; 这是一份关于向量扩展的详细技术文档&#xff0c;内容覆盖了向量指令集的多个关键方面&#xff0c;如向量寄存器状态映射、向量指令格式、向量加载和存储操作、向量内存对齐约束、向量内存一致性模型、向量…

Redis -- 缓存穿透问题解决思路

缓存穿透 &#xff1a;缓存穿透是指客户端请求的数据在缓存中和数据库中都不存在&#xff0c;这样缓存永远不会生效&#xff0c;这些请求都会打到数据库。 常见的解决方案有两种&#xff1a; 缓存空对象 优点&#xff1a;实现简单&#xff0c;维护方便 缺点&#xff1a; 额外…

【JavaSE】接口 详解(上)

前言 本篇会讲到Java中接口内容&#xff0c;概念和注意点可能比较多&#xff0c;需要耐心多看几遍&#xff0c;我尽可能的使用经典的例子帮助大家理解~ 欢迎关注个人主页&#xff1a;逸狼 创造不易&#xff0c;可以点点赞吗~ 如有错误&#xff0c;欢迎指出~ 目录 前言 接口 语法…

pta 1086 就不告诉你

1086 就不告诉你 分数 15 全屏浏览 切换布局 作者 CHEN, Yue 单位 浙江大学 做作业的时候&#xff0c;邻座的小盆友问你&#xff1a;“五乘以七等于多少&#xff1f;”你应该不失礼貌地围笑着告诉他&#xff1a;“五十三。”本题就要求你&#xff0c;对任何一对给定的正整数…

新手开抖店:选品过后如何有效对接达人?这些方法100%有效!

哈喽~我是电商月月 要说做抖音小店最主要的是什么&#xff1f;那当然是找品了 那出单最快的方法是什么&#xff1f;无疑是达人带货了&#xff01; 但新手店铺没销量&#xff0c;没体验分&#xff0c;没好评怎么能让达人同意帮我们带货呢&#xff1f; 方法其实很简单&#x…

“双碳”目标下资源环境中的可计算一般均衡(CGE)模型应用

我国政府承诺在2030年实现“碳达峰”&#xff0c;2060年实现“碳中和”&#xff0c;这就是“双碳”目标。为了实现这一目标就必须应用各种二氧化碳排放量很高技术的替代技术&#xff0c;不仅需要考虑技术上的可靠性&#xff0c;也需要考虑经济上的可行性。可计算一般均衡模型&a…

AI预测福彩3D第26弹【2024年4月4日预测--第4套算法重新开始计算第11次测试】

今天清明节假日&#xff0c;一会要外出&#xff0c;可能要晚点回来。咱们尽早先把预测数据跑完&#xff0c;把结果发出来供各位彩友参考。合并下算法&#xff0c;3D的预测以后将重点测试本套算法&#xff0c;因为本套算法的命中率较高。以后有时间的话会在第二篇文章中发布排列…

UTONMOS:AI+Web3+元宇宙数字化“三位一体”将触发经济新爆点

人工智能、元宇宙、Web3&#xff0c;被称为数字化的“三位一体”&#xff0c;如何看待这三大技术所扮演的角色&#xff1f; 3月24日&#xff0c;2024全球开发者先锋大会“数字化的三位一体——人工智能、元宇宙、Web3.0”论坛在上海漕河泾开发区举行&#xff0c;首次提出&…

深入探索MySQL:成本模型解析与查询性能优化,及未来深度学习与AI模型的应用展望

码到三十五 &#xff1a; 个人主页 在数据库管理系统中&#xff0c;查询优化器是一个至关重要的组件&#xff0c;它负责将用户提交的SQL查询转换为高效的执行计划。在MySQL中&#xff0c;查询优化器使用了一个称为“成本模型”的机制来评估不同执行计划的优劣&#xff0c;并选择…

网络安全 | 什么是负载均衡器?

关注WX&#xff1a; CodingTechWork 介绍 负载均衡是在多个服务器之间有效分配网络流量的过程。负载均衡的目的是优化应用程序的可用性&#xff0c;并确保良好的终端用户体验。负载均衡可协助高流量网站和云计算应用程序应对数百万个用户请求&#xff0c;从而保证客户请求不会…

2012年认证杯SPSSPRO杯数学建模C题(第二阶段)碎片化趋势下的奥运会商业模式全过程文档及程序

2012年认证杯SPSSPRO杯数学建模 C题 碎片化趋势下的奥运会商业模式 原题再现&#xff1a; 从 1984 年的美国洛杉矶奥运会开始&#xff0c;奥运会就不在成为一个“非卖品”&#xff0c;它在向观众诠释更高更快更强的体育精神的同时&#xff0c;也在攫取着巨大的商业价值&#…

颜色空间/模型(RGB, YUV,CMY/CMYK, HSI, HSV等)

什么是颜色 颜色是通过眼、脑和我们的生活经验所产生的对光的视觉感受&#xff0c;我们肉眼所见到的光线&#xff0c;是由波长范围很窄的电磁波产生的&#xff0c;不同波长的电磁波表现为不同的颜色&#xff0c;对色彩的辨认是肉眼受到电磁波辐射能刺激后所引起的视觉神经感觉…

51单片机实验02- P0口流水灯实验

目录 一、实验的背景和意义 二、实验目的 三、实验步骤 四、实验仪器 五、实验任务及要求 1&#xff0c;从led4开始右移 1&#xff09;思路 ①起始灯 &#xff08;led4&#xff09; ②右移 2&#xff09;效果 3&#xff09;代码 2&#xff0c;从其他小灯并向右依次…

面向C++程序员的Rust教程(二)

先序文章请看&#xff1a; 面向C程序员的Rust教程&#xff08;一&#xff09; 所有权与移动语义 要说Rust语言跟其他语言最大的区别&#xff0c;那笔者觉得非数这个所有权和移动语义莫属。 深浅复制 对于绝大多数语言来说&#xff0c;变量/对象之间的赋值通常都是复制语义。…

微信开发工具——进行网页授权

微信开发工具——进行网页授权 微信公众平台设置 1.在首页创建好自己的订阅号 网站&#xff1a;https://mp.weixin.qq.com/ 点击立即注册,在选择订阅号&#xff08;个人创建使用&#xff09; 之后按流程填写后&#xff0c;点击设置与开发-------->基本配置&#xff0c;这…

JAVA八股--redis

JAVA八股--redis 如何保证Redis和数据库数据一致性redisson实现的分布式锁的主从一致性Redis脑裂现象及解决方案介绍I/O多路复用模型undo log 和 redo log&#xff08;没掌握MyISAM 和 InnoDB 有什么区别&#xff1f; 如何保证Redis和数据库数据一致性 关于异步通知中消息队列…

Kubernetes(k8s):精通 Pod 操作的关键命令

Kubernetes&#xff08;k8s&#xff09;&#xff1a;精通 Pod 操作的关键命令 1、查看 Pod 列表2、 查看 Pod 的详细信息3、创建 Pod4、删除 Pod5、获取 Pod 日志6、进入 Pod 执行命令7、暂停和启动 Pod8、改变 Pod 副本数量9、查看当前部署中使用的镜像版本10、滚动更新 Pod11…