【NLP练习】使用Word2Vec实现文本分类

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、数据预处理

1. 任务说明

本次加入Word2Vec使用PyTorch实现中文文本分类,Word2Vec则是其中的一种词嵌入方法,是一种用于生成词向量的浅层神经网络模型。Word2Vec通过学习大量的文本数据,将每个单词表示为一个连续的向量,这些向量可以捕捉单词之间的语义和句法关系。数据示例如下:
在这里插入图片描述

2. 加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")   #忽略警告信息device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

输出:

device(type='cpu')
import pandas as pd#加载自定义中文数据
train_data = pd.read_csv(r'D:\Personal Data\Learning Data\DL Learning Data\train.csv',sep ='\t',header = None)
train_data.head()

输出:
在这里插入图片描述

#构造数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,yx = train_data[0].values[:]
y = train_data[1].values[:]

3. 构建词典

from gensim.models.word2vec import Word2Vec
import numpy as np#训练Word2Vec浅层神经网络模型
w2v = Word2Vec(vector_size=100,   #特征向量的维度,默认为100min_count=3)        #对字典做截断,词频少于min_count次数的单词会被丢弃掉,默认值为5
w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=20)

输出:

(2732920, 3663560)
#将文本转化为向量
def average_vec(text):vec = np.zeros(100).reshape((1,100))for word in text:try:vec += w2v.wv[word].reshape((1,100))except KeyError:continuereturn vec#将词向量保存为Ndarray
x_vec = np.concatenate([average_vec(z) for z in x])#保存Word2Vec模型及词向量
w2v.save(r'D:\Personal Data\Learning Data\DL Learning Data\w2v.pkl')
train_iter = coustom_data_iter(x_vec, y)
len(x),len(x_vec)

输出:

(12100, 12100)
label_name = list(set(train_data[1].values[:]))
print(label_name)
['Music-Play', 'Travel-Query', 'Weather-Query', 'Audio-Play', 'Radio-Listen', 'Video-Play', 'Calendar-Query', 'HomeAppliance-Control', 'Alarm-Update', 'Other', 'TVProgram-Play', 'FilmTele-Play']

4. 生成数据批次和迭代器


text_pipeline = lambda x : average_vec(x)
label_pipeline = lambda x : label_name.index(x)print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

输出:

[[ -3.16253691  -1.9659146    3.77608298   1.06067566  -5.1883576-8.70868033   3.89949582  -2.18139926   6.70676575  -4.9919778316.07808281   9.24493882 -15.24484421  -6.60270358  -6.24634131-3.64680131  -2.53697125   2.8301437    7.22867384  -2.133602622.1341381    6.06681348  -4.65962007   1.23247945   4.331831732.15399135  -1.83306327  -2.49018155  -0.22937663   1.57925591-3.22308699   3.56521453   5.94520254   3.46486389   3.46772102-4.10725167   0.31579057   9.28542571   7.48527321  -2.930142968.39484799 -11.3110949    4.46019076  -0.64214947  -6.3485507-5.3710938    1.6277833   -1.44570495   7.21582842   3.292127360.79481401  10.0952674   -0.72304608  -0.46801499   6.08651663-0.67166806  10.56184006   1.74745524  -4.52621601   1.8375443-5.368839    10.54501078  -2.85536074  -4.55352878 -13.424223743.17138463   7.39386847  -2.24578104 -16.08510212  -5.7369401-2.90420356  -4.19321531   3.29097138  -9.36627482   3.67335742-0.80693699  -0.53749662  -3.67742246   0.48116201   5.517548480.82724179   4.13207588   0.86254621  13.13354776  -3.113592512.18450189   9.11669949  -4.88159943   2.01295654  11.02899793-5.33385142  -7.47531134  -4.02018939  -0.52363324  -1.799801854.00845213  -2.436053     0.16959296  -7.10417359  -0.55219389]]
5
#生成数据批次和迭代器
from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list = [],[]        for(_text, _label) in batch:#标签列表label_list.append(label_pipeline(_label))#文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)text_list.append(processed_text)label_list = torch.tensor(label_list,dtype=torch.int64)text_list = torch.cat(text_list)return text_list.to(device), label_list.to(device)#数据加载器
dataloader = DataLoader(train_iter,batch_size = 8,shuffle = False,collate_fn = collate_batch
)

二、构建模型

1. 搭建模型

#搭建模型
from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, num_class):super(TextClassificationModel,self).__init__()self.fc = nn.Linear(100, num_class)def forward(self, text):return self.fc(text)

2. 初始化模型

#初始化模型
#定义实例
num_class = len(label_name)
vocab_size = 100000
em_size = 12
model = TextClassificationModel(num_class).to(device)

3. 定义训练与评估函数

#定义训练与评估函数
import timedef train(dataloader):model.train()          #切换为训练模式total_acc, train_loss, total_count = 0,0,0log_interval = 50start_time = time.time()for idx, (text,label) in enumerate(dataloader):predicted_label = model(text)optimizer.zero_grad()                             #grad属性归零loss = criterion(predicted_label, label)          #计算网络输出和真实值之间的差距,label为真loss.backward()                                   #反向传播torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)  #梯度裁剪optimizer.step()                                  #每一步自动更新#记录acc与losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count = 0,0,0staet_time = time.time()def evaluate(dataloader):model.eval()      #切换为测试模式total_acc,train_loss,total_count = 0,0,0with torch.no_grad():for idx,(text,label) in enumerate(dataloader):predicted_label = model(text)loss = criterion(predicted_label,label)   #计算loss值#记录测试数据total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

三、训练模型

1. 拆分数据集并运行模型

#拆分数据集并运行模型
from torch.utils.data.dataset   import random_split
from torchtext.data.functional  import to_map_style_dataset# 超参数设定
EPOCHS      = 10   #epoch
LR          = 5    #learningRate
BATCH_SIZE  = 64   #batch size for training#设置损失函数、选择优化器、设置学习率调整函数
criterion   = torch.nn.CrossEntropyLoss()
optimizer   = torch.optim.SGD(model.parameters(), lr = LR)
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma = 0.1)
total_accu  = None# 构建数据集
train_iter = coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset   = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader    = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
valid_dataloader    = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)#获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:d} | time:{:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch,time.time() - epoch_start_time,val_acc,val_loss))print('-' * 69)

输出:

|epoch1|  50/ 152 batches|train_acc0.724 train_loss0.02592
|epoch1| 100/ 152 batches|train_acc0.820 train_loss0.01937
|epoch1| 150/ 152 batches|train_acc0.832 train_loss0.01843
---------------------------------------------------------------------
| epoch 1 | time:1.11s | valid_acc 0.827 valid_loss 0.019
---------------------------------------------------------------------
|epoch2|  50/ 152 batches|train_acc0.842 train_loss0.01750
|epoch2| 100/ 152 batches|train_acc0.831 train_loss0.01787
|epoch2| 150/ 152 batches|train_acc0.841 train_loss0.01953
---------------------------------------------------------------------
| epoch 2 | time:1.14s | valid_acc 0.780 valid_loss 0.029
---------------------------------------------------------------------
|epoch3|  50/ 152 batches|train_acc0.873 train_loss0.01189
|epoch3| 100/ 152 batches|train_acc0.884 train_loss0.00944
|epoch3| 150/ 152 batches|train_acc0.905 train_loss0.00763
---------------------------------------------------------------------
| epoch 3 | time:1.09s | valid_acc 0.886 valid_loss 0.009
---------------------------------------------------------------------
|epoch4|  50/ 152 batches|train_acc0.891 train_loss0.00794
|epoch4| 100/ 152 batches|train_acc0.894 train_loss0.00711
|epoch4| 150/ 152 batches|train_acc0.905 train_loss0.00646
---------------------------------------------------------------------
| epoch 4 | time:1.09s | valid_acc 0.874 valid_loss 0.009
---------------------------------------------------------------------
|epoch5|  50/ 152 batches|train_acc0.902 train_loss0.00593
|epoch5| 100/ 152 batches|train_acc0.909 train_loss0.00591
|epoch5| 150/ 152 batches|train_acc0.897 train_loss0.00687
---------------------------------------------------------------------
| epoch 5 | time:1.03s | valid_acc 0.890 valid_loss 0.008
---------------------------------------------------------------------
|epoch6|  50/ 152 batches|train_acc0.909 train_loss0.00592
|epoch6| 100/ 152 batches|train_acc0.900 train_loss0.00609
|epoch6| 150/ 152 batches|train_acc0.904 train_loss0.00607
---------------------------------------------------------------------
| epoch 6 | time:1.02s | valid_acc 0.890 valid_loss 0.008
---------------------------------------------------------------------
|epoch7|  50/ 152 batches|train_acc0.908 train_loss0.00559
|epoch7| 100/ 152 batches|train_acc0.906 train_loss0.00604
|epoch7| 150/ 152 batches|train_acc0.902 train_loss0.00623
---------------------------------------------------------------------
| epoch 7 | time:1.00s | valid_acc 0.888 valid_loss 0.008
---------------------------------------------------------------------
|epoch8|  50/ 152 batches|train_acc0.906 train_loss0.00558
|epoch8| 100/ 152 batches|train_acc0.904 train_loss0.00592
|epoch8| 150/ 152 batches|train_acc0.908 train_loss0.00602
---------------------------------------------------------------------
| epoch 8 | time:1.08s | valid_acc 0.888 valid_loss 0.008
---------------------------------------------------------------------
|epoch9|  50/ 152 batches|train_acc0.903 train_loss0.00566
|epoch9| 100/ 152 batches|train_acc0.911 train_loss0.00550
|epoch9| 150/ 152 batches|train_acc0.904 train_loss0.00630
---------------------------------------------------------------------
| epoch 9 | time:1.20s | valid_acc 0.889 valid_loss 0.008
---------------------------------------------------------------------
|epoch10|  50/ 152 batches|train_acc0.910 train_loss0.00564
|epoch10| 100/ 152 batches|train_acc0.912 train_loss0.00550
|epoch10| 150/ 152 batches|train_acc0.897 train_loss0.00633
---------------------------------------------------------------------
| epoch 10 | time:1.09s | valid_acc 0.889 valid_loss 0.008
---------------------------------------------------------------------
test_acc,test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))

输出:

模型准确率为:0.8843

2. 测试指定数据

def predict(text,text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text),dtype = torch.float32)print(text.shape)output = model(text)return output.argmax(1).item()ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
print("该文本的类别是:%s" %label_name[predict(ex_text_str,text_pipeline)])
torch.Size([1, 100])
该文本的类别是:Travel-Query

四、总结

Word2Vec 通过学习单词的上下文关系,将单词映射到向量空间。这使得语义上相似的单词在向量空间中具有相近的位置。因此,使用 Word2Vec 可以更好地捕获文本中的语义信息,从而提高文本分类的准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/663.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2001-2022上市公司数字化转型数据(含原始数据+计算代码+计算结果)

2001-2022上市公司数字化转型数据(含原始数据计算代码计算结果) 1、时间:2001-2022年 2、来源:原始数据整理自wind 3、指标:证券代码、证券简称、统计截止日期、是否发生ST或*ST或PT、是否发生暂停上市、行业代码、…

简单了解Element Plus

请简述Element Plus是什么,以及它与其他UI框架的主要区别是什么? 答案: Element Plus是一套为开发者、设计师和产品经理准备的基于Vue 3.0的桌面端组件库。它与其他UI框架的主要区别在于其高度的可定制性、丰富的组件库以及良好的性能表现。…

戴尔电脑怎么关闭开机密码?

1.同时按键盘上是“window键”(一般是键盘最下面一排第二个)和“R键“,并在弹出的窗口输入“netplwiz”然后确定。 2.然后会弹出的“用户账户”窗口,接下来取消勾选“要使用本计算机,用户必须输入用户名和密码” 3.上面…

MySQL之explain执行计划

一、explain作用 MySQL的EXPLAIN命令是开发者经常使用的一个强大的分析工具,帮助开发者了解查询的性能瓶颈和优化方向。 二、使用方法 只需要在要执行的sql语句前加explain关键字即可,如下 mysql> explain select * from user where id >60; -…

每日算法练习(1)

开一个新坑,记录下自己每天的算法练习,希望自己通过1个多月的学习,能够成为算法大神。 下面正式开始新坑。 两个数组的交集 这是牛客上的题,根据题意,我们有多种解法,这题用哈希比较好写。我们可以弄一个…

1.8、数位DP(算法提高课)

一、数字游戏 题目链接:http://ybt.ssoier.cn:8088/problem_show.php?pid1588 题意:求给定区间【a,b】中的不降数的个数,不降数的定义为从左到右各位数字成小于等于的关系。 思路:首先预处理出来 f[i][j] 为一共有…

pytorch环境配置踩坑记录

一、问题1 1.执行命令 conda create -n pytorch python3.62.报错如下 Solving environment: failedCondaHTTPError: HTTP 000 CONNECTION FAILED for url <https://repo.anaconda.com/pkgs/msys2/noarch/repodata.json.bz2> Elapsed: -An HTTP error occurred when tr…

Java 变得越来越像 Rust?

随着编程技术的增强和复杂性的提升&#xff0c;许多编程语言也纷纷效仿&#xff0c;Java 也不例外。 另一边&#xff0c;尽管社区内部问题重重&#xff0c;但 Rust 仍逐年获得开发人员的喜爱。这背后都是有原因的&#xff1a;Rust 的编译器让开发人员避免了各种问题。编译器对…

【GlobalMapper精品教程】074:从Lidar点云创建3D地形模型

本文基于地形点云数据,基于泊松方法、贪婪三角形测量方法和阿尔法形状创建3d地形模型。 文章目录 一、加载地形点云数据二、创建三维地形模型1. 泊松方法2. 贪婪三角形测量方法3. 阿尔法形状注意事项一、加载地形点云数据 加载配套案例数据包中的data074.rar中的地形点云数据…

3D地图大屏 附源码(Three.js + Vue3)

目录 &#x1f44b; 前言 &#x1f680; 项目包 ⚒️ 字体制作 &#x1f310; 地图制作 &#x1f4a1; 参考视频 & 项目 开源项目&#xff08;Vue3tsWindcssEchartThree.js大屏案例&#xff09; 开源&#xff08;教程&#xff09; UI风格学习&#xff08; www.shuzixs.com …

本地事务存在的问题

在微服务中&#xff0c;如果还是使用本地事务会出现问题 比如订单服务中先下订单再调用库存服务再调用用户服务增加积分&#xff0c;这时候如果调用库存服务出现假失败&#xff0c;也就是说实际上成功了&#xff0c;但是因为网络原因没有返回&#xff0c;没返回出错了&#xff…

java调用讯飞星火认知模型

前往讯飞开发平台选择产品&#xff0c;获取appId、apiKey、APISecret&#xff0c;这里我选择的是v3.0模型。 java后端实现 本项目以及实现了基本的会话功能&#xff0c;小伙伴可以自己扩充其他的例如绘画功能。 注意&#xff1a;星火模型的api使用的是websocket协议&#xf…

c 多文件编程

1.结构目录 声明类:用于声明方法,方便方法管理和调用&#xff1b; 实现类:用于实现声明的方法; 应用层:调用方法使用 写过java代码的兄弟们可以这么理解&#xff1a; 声明类 为service层 实现类 为serviceimpl层 应用层 为conlloter层 2.Dome 把函数声明放在头文件xxx.h中&…

vtk.vtkAssembly()用法解释

vtk.vtkAssembly 是 VTK库中的一个重要类&#xff0c;允许通过将多个vtkActor对象组合在一起来创建复杂的3D模型。每个 vtk.vtkAssembly 对象都可以包含其他 vtk.vtkAssembly 对象&#xff0c;构成一个层级的组合结构。 以下是创建并使用 vtk.vtkAssembly 的一个基本示例&…

与上级意见不合时如何恰当地表达自己的观点?

在工作中与上级意见不合时&#xff0c;恰当表达自己的观点并寻求共识是一个需要谨慎处理的问题。以下是一些建议&#xff1a; 1. **尊重与礼貌**&#xff1a;在任何情况下&#xff0c;都应保持对上级的尊重和礼貌。即使在意见不合时&#xff0c;也要避免情绪化&#xff0c;保持…

200页图解国标《数据分类分级规则》正式稿,强化重要数据识别

GB/T 43697-2024《数据安全技术 数据分类分级规则》正式稿发布&#xff0c;并于2024年10月1日实施。2024年4月17日&#xff0c;国家标准全文公开系统公布了国标最终版。《数据分类分级规则》是全国网安标委更名后&#xff0c;发布的第一部以“数据安全技术”命名的国家标准&…

Python-VBA函数之旅-enumerate函数

目录 1、enumerate函数&#xff1a; 1-1、Python&#xff1a; 1-2、VBA&#xff1a; 2、相关文章&#xff1a; 个人主页&#xff1a;非风V非雨-CSDN博客 enumerate函数在Python中是一个强大的内置函数&#xff0c;用于将一个可迭代对象转换为一个索引序列&#xff0c;同时返…

java-spring 图灵 04 doscan

01.本次的重点依旧是扫描函数&#xff0c;这次是spring中的源码&#xff1a; 02.第一步&#xff0c;构造AnnotationConfigApplicationContext 主方法&#xff1a; public static void main(String[] args) {// 创建一个Spring容器AnnotationConfigApplicationContext applica…

基于react native的android原生微信客服,微信支付以及判断是否安装微信

基于react native的android原生微信客服&#xff0c;微信支付以及判断是否安装微信 引入SDK&#xff08;Android Studio 环境下&#xff09;创建wxapi/WXPayEntryActivity.java&#xff08;用于接收微信响应返回信息&#xff09;CustomerServiceModule.javaCustomerServicePack…

C#基础|Debug程序调试学习和技巧总结

哈喽&#xff0c;你好啊&#xff0c;我是雷工&#xff01; 在程序的开发过程中&#xff0c;可能绝大部分时间是用来调试程序&#xff0c; 当完成了某个功能的编程&#xff0c;都需要调试一下程序&#xff0c;看编程是否存在问题。 01 为什么需要程序调试 无论是电气工程师还…