第N4周:中文文本分类——Pytorch实现

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

数据集:train

 一.加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")device = torch.device("cuda"if torch.cuda.is_available() else"cpu")
device

#获取数据
import pandas as pd
train_data = pd.read_csv('D:/BaiduNetdiskDownload/train.csv',sep='\t',header=None)
train_data.head()

 

#构建数据的迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,ytrain_iter = coustom_data_iter(train_data[0].values[:],train_data[1].values[:])

 二.数据处理

1.构建词典

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba tokenizer =jieba.lcutdef yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)  
vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

vocab(['随便','播放','一','首','歌','我'])
out:[173, 4, 181, 0, 108, 2]
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)
label_name =list(set(train_data[1].values[:]))
print(label_name)


print(text_pipeline('我想看'))
print(label_pipeline('HomeAppliance-Control'))

创建了两个lambda函数,一个用于将文本转换成词汇索引,另一个用于将标签文本转换成它们在label_name列表中的索引。

除了一般使用def声明的函数外,Python中还支持lambda匿名函数,可以在任何场合替代def函数。
匿名函数,通常指的是运行时临时创建的,没有显示命名的函数,它允许快速定义简单的函数。

语法

lambda arguments :expression

lambda argument1,argument2...,argumentn : expression using arguments

        lambda 是关键字
  arguments是参数,可以是0个或多个,用逗号分割
  expression是一个表达式,描述了函数的返回值

lambda关键字用于创建小巧的匿名函数。lambda a, b: a+b 函数返回两个参数的和。Lambda 函数可用于任何需要函数对象的地方。在语法上,匿名函数只能是单个表达式。在语义上,它只是常规函数定义的语法糖。与嵌套函数定义一样,lambda 函数可以引用包含作用域中的变量。

优点:

  简洁:lambda表达式可以快速定义简单的函数,无需使用def语句。
  匿名:由于lambda表达式没有正式的函数名称,因此它们是匿名的,可以用于需要短生命周期函数的情况。
  轻量级:lambda表达式只包含一个表达式,因此它们占用内存空间较小,适合用于小型任务。
  可嵌套:Lambda表达式可以嵌套在其他函数或代码块中使用,使代码更加紧凑。

缺点:
  只能包含一个表达式:Lambda表达式只能包含一个表达式,这意味着它们不能包含多个语句或复杂逻辑。
  调试困难:由于Lambda表达式通常很短,因此很难在调试时设置断点和查看执行流程。
  不支持变量定义:Lambda表达式不能定义变量,只能使用已存在的变量。这意味着它们在处理复杂逻辑时可能会受到限制。
  性能问题:Lambda表达式可能在性能方面不如使用def语句定义的函数。由于它们通常是轻量级的,因此可能不会进行优化。

2.生成数据批次和迭代器

from torch.utils.data import DataLoaderdef collate_batch(batch):label_list,text_list,offsets=[],[],[0]for(_text,_label)in batch:label_list.append(label_pipeline(_label))processed_text= torch.tensor(text_pipeline(_text),dtype=torch.int64)text_list.append(processed_text)offsets.append(processed_text.size(0))label_list=torch.tensor(label_list,dtype=torch.int64)text_list=torch.cat(text_list)offsets=torch.tensor(offsets[:-1]).cumsum(dim=0)return text_list.to(device),label_list.to(device),offsets.to(device)dataloader=DataLoader(train_iter,batch_size=8,shuffle =False,collate_fn=collate_batch)

collate_batch函数用于处理数据加载器中的批次。它接收一个批次的数据,处理它,并返回适合模型训练的数据格式。在这个函数内部,它遍历批次中的每个文本和标签对,将标签添加到label_list,将文本通过text_pipeline函数处理后转换为tensor,并添加到text_list。

三.模型构建

1.搭建模型

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self,vocab_size,embed_dim,num_class):super(TextClassificationModel,self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,embed_dim,sparse=False)self.fc = nn.Linear(embed_dim,num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange,initrange)self.fc.weight.data.uniform_(-initrange,initrange)self.fc.bias.data.zero_()def forward(self,text,offsets):embedded=self.embedding(text,offsets)return self.fc(embedded)

Embedding 层 (nn.EmbeddingBag):

通过 nn.EmbeddingBag 定义了一个嵌入层,用于将文本数据嵌入到低维空间中。
vocab_size 参数指定词典的大小,embed_dim 参数指定嵌入的维度,sparse=False 表示不使用稀疏张量。
全连接层 (nn.Linear):

使用 nn.Linear 定义了一个全连接层,将嵌入后的文本表示映射到最终的类别分数。
输入维度为 embed_dim,输出维度为 num_class(类别的数量)。
初始化权重 (init_weights):

在模型初始化时调用,用于初始化 Embedding 层和全连接层的权重和偏置。
使用均匀分布初始化权重,偏置值被归零。
前向传播 (forward):

接受文本数据 text 和对应的偏移量 offsets 作为输入。
使用 Embedding 层将文本嵌入到低维空间,通过调用 self.embedding(text, offsets) 实现。
将嵌入后的文本表示传递给全连接层,得到最终的类别分数。

2.初始化模型

num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size,em_size,num_class).to(device)

3.定义训练和评估函数

import timedef train(dataloader):model.train()total_acc,train_loss,total_count =0,0,0log_interval = 50start_time =time.time()for idx,(text,label,offsets) in enumerate(dataloader):predicted_label = model(text,offsets)optimizer.zero_grad()loss= criterion(predicted_label,label)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)optimizer.step()total_acc +=(predicted_label.argmax(1)==label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx >0:elapsed = time.time() - start_timeprint('| epoch {:1d}|{:4d}/{:4d} batches''| train_acc {:4.3f} train_loss{:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count = 0,0,0start_time =time.time()def evaluate(dataloader):model.eval()total_acc,train_loss,total_count = 0,0,0with torch.no_grad():for idx, (text,label,offsets) in enumerate(dataloader):predicted_label =model(text,offsets)loss = criterion(predicted_label,label)total_acc +=(predicted_label.argmax(1) == label).sum().item()train_loss +=loss.item()total_count +=label.size(0)return total_acc/total_count,train_loss/total_count

train 和 evaluate分别用于训练和评估文本分类模型。

·训练函数 train 的工作流程如下:将模型设置为训练模式。初始化总准确率、训练损失和总计数变量。记录训练开始的时间。遍历数据加载器,对每个批次:进行预测、清零优化器的梯度、计算损失(使用一个损失函数,例如交叉熵)、反向传播计算梯度、通过梯度裁剪防止梯度爆炸、执行一步优化器更新模型权重、更新总准确率和总损失、每隔一定间隔,打印训练进度和统计信息。

·评估函数 evaluate 的工作流程如下:将模型设置为评估模式、初始化总准确率和总损失、不计算梯度(为了节省内存和计算资源)、遍历数据加载器,对每个批次:进行预测、计算损失、
更新总准确率和总损失、返回整体的准确率和平均损失。

四、训练模型 

1.拆分数据集并运行模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_datasetEPOCHS = 10
LR = 5
BATCH_SIZE= 64criterion = torch.nn.CrossEntropyLoss()
optimizer =torch.optim.SGD(model.parameters(),lr=LR)
scheduler =torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)
total_accu =Nonetrain_iter = coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_,split_valid_=random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader =DataLoader(split_train_,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_batch)
for epoch in range(1,EPOCHS + 1):epoch_start_time =time.time()train(train_dataloader)val_acc,val_loss = evaluate(valid_dataloader)lr= optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-'* 69)print('|epoch {:1d} |time: {:4.2f}s | ''valid_acc{:4.3f} valid_loss{:4.3f} | lr{:4.6f}'.format(epoch,time.time()- epoch_start_time,val_acc,val_loss,lr))print('-'* 69)

2.测试数据

def predict(text,text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output =model(text,torch.tensor([0]))return output.argmax(1).item()ex_text_str="还有双鸭山到淮阴的汽车票吗13号的"model =model.to("cpu")
print("该文本的类别是:%s"%label_name[predict(ex_text_str,text_pipeline)])该文本的类别是:Travel-Query

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/14153.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3的核心API功能:computed()API使用

常规使用方法: 这样是常规使用方法. 另一种使用方法: 这样分别定义computed的get回调函数和set回调函数, 上面例子定义了plusOne.value的值为1, 那么这时候就走了computed的set回调函数,而没有走get回调函数. 当我们打印plusOne.value的值的时候,走的是get的回调函数而不是…

ios 原生项目迁移flutter第一天环境

由于公司已经有第一个吃螃蟹的项目组&#xff0c;我在迁移的时候想着站在巨人的肩膀上&#xff0c;但是搭配环境一定要问清楚对方flutter版本&#xff0c;路径也要安排好&#xff0c;不然就不行。 对着自己的项目照着葫芦画瓢&#xff0c;我刚开始为了配置管理图个方便随便放&…

Unity3D读取Excel表格写入Excel表格

系列文章目录 unity工具 文章目录 系列文章目录&#x1f449;前言&#x1f449;一、读取Excel表格&#x1f449;二、写入Excel表格&#x1f449;三、Fileinfo和Directoryinfo的操作&#x1f449;四、壁纸分享&#x1f449;总结 &#x1f449;前言 有时候难免会遇到读取文件写…

提供一个c# winform的多语言框架源码,采用json格式作为语言包,使用简单易于管理加载且不卡UI,支持“语言分级”管理

提供一个c# winform的多语言框架源码&#xff0c;采用json格式作为语言包&#xff0c;不使用resx资源&#xff0c;当然本质一样的&#xff0c;你也可以改为resx 一、先看下测试界面 演示了基本的功能&#xff1a;切换语言&#xff0c;如何加载语言&#xff0c;如何分级加载语…

【webrtc】内置opus解码器的移植

m98 ,不知道是什么版本的opus,之前的交叉编译构建: 【mia】ffmpeg + opus 交叉编译 【mia】ubuntu22.04 : mingw:编译ffmpeg支持opus编解码 看起来是opus是1.3.1 只需要移植libopus和opus的webrtc解码部分即可。 linux构建的windows可运行的opus库 G:\NDDEV\aliply-0.4\C…

如何为社交feed场景设计缓存体系?no.35

Feed 流场景分析 Feed 流是很多移动互联网系统的重要一环&#xff0c;如微博、微信朋友圈、QQ 好友动态、头条/抖音信息流等。虽然这些产品形态各不相同&#xff0c;但业务处理逻辑却大体相同。用户日常的“刷刷刷”&#xff0c;就是在获取 Feed 流&#xff0c;这也是 Feed 流的…

达梦数据库详解

达梦认证是指针对中国数据库管理系统&#xff08;DBMS&#xff09;厂商达梦公司所推出的数据库产品&#xff0c;即达梦数据库&#xff08;DMDB&#xff09;&#xff0c;进行的一种官方认证体系。达梦认证旨在验证数据库管理人员对达梦数据库产品的掌握程度&#xff0c;及其在数…

【HUST】信道编码|基于LDPC码的物理层安全编码方案概述

本文对方案的总结是靠 Kimi 阅读相关论文后生成的&#xff0c;我只看了标题和摘要感觉确实是这么回事&#xff0c;并没有阅读原文。 行文逻辑&#xff1a;是我自己设定的&#xff0c;但我并不是这个研究领域的&#xff0c;所以如果章节划分时有问题&#xff0c;期待指出&#x…

FTP文件传输议

FTP是一种文件传输协议&#xff1a;用来上传和下载&#xff0c;实现远程共享文件&#xff0c;和统一管理文件 工作原理&#xff1a;用于互联网上的控制文件的双向传输是一个应用程序。工作在TCP/IP协议簇的&#xff0c;其传输协议是TCP协议提高文件传输的共享性和可靠性&#…

8.STL中Vector容器的常见操作(附习题)

目录 1.vector的介绍 2 vector的使用 2.1 vector的定义 2.2 vector iterator 的使用 2.3 vector 空间增长问题 2.3 vector 增删查改 2.4 vector 迭代器失效问题 2.5 vector 在OJ中的使用 1.vector的介绍 vector是表示可变大小数组的序列容器。 就像数组一样&#xff0…

【Unitydemo制作】音游制作—控制器与特效

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 秩沅 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a;就业…

儿童卧室灯品牌该如何挑选?几款专业儿童卧室灯品牌分享

近视在儿童中愈发普遍&#xff0c;许多家长开始认识到&#xff0c;除了学业成绩之外&#xff0c;孩子的视力健康同样重要。毕竟&#xff0c;学业的落后可以逐渐弥补&#xff0c;而一旦孩子近视&#xff0c;眼镜便可能成为长期伴随。因此&#xff0c;专业的护眼台灯对于每个家庭…

大泽动力应急排水方舱功能介绍

一、排水方舱简介及其应用 排水方舱&#xff0c;亦被称为扬水设备&#xff0c;主要用于排除船舶内的积水&#xff0c;保证船体内的稳定与干燥。它常与抽水设备结合使用&#xff0c;能将船体内的水抽离并排放到外部&#xff0c;从而确保船只的正常运行。 二、排水方舱的运作方式…

链表经典OJ问题【环形链表】

题目导入 题目一&#xff1a;给你一个链表的头节点 head &#xff0c;判断链表中是否有环 题目二&#xff1a;给定一个链表的头节点 head &#xff0c;返回链表开始入环的第一个节点。 如果链表无环&#xff0c;则返回 NULL。 题目一 给你一个链表的头节点 head &#xff0c;…

leetcode230 二叉搜索树中第K小的元素

题目 给定一个二叉搜索树的根节点 root &#xff0c;和一个整数 k &#xff0c;请你设计一个算法查找其中第 k 个最小元素&#xff08;从 1 开始计数&#xff09;。 示例 输入&#xff1a;root [5,3,6,2,4,null,null,1], k 3 输出&#xff1a;3 解析 这道题应该是能做出…

【HMGD】STM32/GD32 I2C DMA 主从通信

STM32 I2C配置 主机配置 主机只要配置速度就行 从机配置 从机配置相同速度&#xff0c;可以设置第二地址 因为我的板子上面已经有了上拉电阻&#xff0c;所以可以直接通信 STM32 I2C DMA 定长主从通信代码示例 int state 0; static uint8_t I2C_recvBuf[10] {0}; stat…

扭矩拧紧螺栓简便的估算方法

扭矩拧紧螺栓简便的估算方法。 计算公式&#xff1a; T K x D x P 其中&#xff1a;T为拧紧力矩&#xff1b;D为螺纹公称直径&#xff1b;P为预紧力&#xff1b;K为拧紧系数。 预紧力计算公式&#xff1a;P(0.75~0.9) σsAs&#xff1b;其中前面系数对可拆连接取0.75&#xff0…

NLP(18)--大模型发展(2)

前言 仅记录学习过程&#xff0c;有问题欢迎讨论 Transformer结构&#xff1a; LLM的结构变化&#xff1a; Muti-head 共享&#xff1a; Q继续切割为muti-head,但是K,V少切&#xff0c;比如切为2个&#xff0c;然后复制到n个muti-head减少参数量&#xff0c;加速训练 atte…

运维开发.索引引擎ElasticSearch.倒序索引的概念

运维开发.索引引擎ElasticSearch 倒序索引的概念 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn…

两步将 CentOS 6.0 原地升级并迁移至 RHEL 7.9

《OpenShift / RHEL / DevSecOps 汇总目录》 说明 本文介绍如何将一个 CentOS 6.0 的系统升级并转换迁移到 RHEL 7.9。 本文是《在离线环境中将 CentOS 7.X 原地升级并迁移至 RHEL 7.9》阶进篇。 所有被测软件的验证操作可参见上述前文中对应章节的说明。 准备 CentOS 6.…