第N8周:使用Word2vec实现文本分类

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

本周任务:

  • 结合Word2Vec文本内容预测文本标签

加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings
import pandas as pdwarnings.filterwarnings('ignore')device = torch.device("cuda" if torch.cuda.is_available() else  "cpu") 
print(device)
# 从本地CSV文件中读取文本内容和标签
train_data = pd.read_csv("D:/桌面/365/train.csv", sep='\t', header=None)
train_data.head()
def coustom_data_iter(texts, labels):for x, y in zip(texts, labels):yield x, yx = train_data[0].values[:]
y = train_data[1].values[:]

构建词典

from gensim.models.word2vec import Word2Vec
import numpy as npw2v = Word2Vec(vector_size=100, min_count=3)w2v.build_vocab(x)
w2v.train(x, total_examples=w2v.corpus_count, epochs=20)
def average_vec(text):vec = np.zeros(100).reshape((1,100))for word in text:try:vec += w2v.wv[word].reshape((1,100))except KeyError:continuereturn vecx_vec = np.concatenate([average_vec(z) for z in x])w2v.save('./w2c_model.pkl')

生成数据批次和迭代器

text_pipeline = lambda x: average_vec(x)
label_pipeline = lambda x: label_name.index(x)
text_pipeline('你在干嘛')
array([[ 0.78121352,  1.93111382,  0.96291968,  0.39362412, -1.67714586,-0.55152619,  1.7284598 ,  0.69204517,  1.1396839 , -0.9755076 ,-0.55864345, -3.68676656,  1.41707338, -0.44626126,  0.2580443 ,1.09325009,  2.28043211, -2.26334408,  3.32311766, -1.24760717,2.2325974 , -0.48408172, -0.55063696,  0.36853465, -1.32127168,-0.53377433, -1.48909409, -0.5050023 ,  1.42371842, -0.4252875 ,2.52355766,  0.60818394, -1.68924798, -0.16912293,  1.26915893,-0.4575564 ,  0.02507078,  3.33139969, -2.1995108 ,  0.44307417,-0.41596803,  1.39861814, -0.58643346,  0.91654699, -0.08089826,0.08773175,  1.51611513, -0.22212304, -3.55333737,  1.93851076,0.42497785, -1.47862379, -0.96684674,  1.20408788, -0.86870126,-1.12228102,  1.67186388, -1.11024326, -0.18936946,  1.0811481 ,1.82965288, -0.78202841,  2.17574303, -1.03871018, -0.51042572,0.40746585, -1.70572275,  1.3409467 ,  1.38298857,  1.11757374,-0.8333215 ,  0.04856796,  1.43110101, -0.02333559,  0.82732772,-0.9469737 , -4.43783602, -0.20290428,  1.04759257, -1.21757071,-1.30356295,  0.50049417, -1.87846385,  2.47995635, -2.41918275,-1.72291106,  2.65663178, -0.96948189, -1.30033612, -0.37353188,0.53420451, -1.99955091,  0.12223354,  1.74861516,  0.99491888,-1.43117569,  0.063243  ,  0.84598846, -2.79536995,  0.02697589]])
from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list = [],[]for (_text, _label) in batch:label_list.append(label_pipeline(_label))processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)text_list.append(processed_text)label_list = torch.tensor(label_list, dtype=torch.int64)text_list  = torch.cat(text_list)return text_list.to(device), label_list.to(device)datalodaer = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

模型构建

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, num_class):super(TextClassificationModel, self).__init__()self.fc = nn.Linear(100, num_class)def forward(self, text):text = text.float()return self.fc(text)
在这组词汇中不匹配的词汇:书

初始化模型

num_class = len(label_name)
vocab_size = 100000
em_size = 12
model = TextClassificationModel(num_class).to(device)

定义训练及评估函数

import timedef train(dataloader):model.train()total_acc, train_loss, total_count = 0,0,0log_interval = 50start_time = time.time()for idx, (text,label) in enumerate(dataloader): # text, label的顺序不能反,否则会报错predicted_label = model(text)optimizer.zero_grad()loss = criterion(predicted_label, label)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)optimizer.step()total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0,0,0start_time = time.time()def evaluate(dataloader):model.eval()total_acc,train_loss, total_count = 0,0,0with torch.no_grad():for idx, (text,label) in enumerate(dataloader):predicted_label = model(text)loss = criterion(predicted_label, label)total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count     

训练模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_datasetEPOCHS = 10
LR = 5
BATCH_SIZE = 64criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = Nonetrain_iter = coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset, [int(len(train_dataset) * 0.8), int(len(train_dataset) * 0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)for epoch in range(1,EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('|epoch {:1d} | time: {:4.2f}s |''valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch, time.time() - epoch_start_time, val_acc, val_loss,lr))print('-' * 69)
| epoch 1 |   50/ 152 batches| train_acc 0.752 train_loss 0.02433
| epoch 1 |  100/ 152 batches| train_acc 0.836 train_loss 0.01740
| epoch 1 |  150/ 152 batches| train_acc 0.831 train_loss 0.01821
---------------------------------------------------------------------
|epoch 1 | time: 2.83s |valid_acc 0.847 valid_loss 0.016 | lr 5.000000
---------------------------------------------------------------------
| epoch 2 |   50/ 152 batches| train_acc 0.843 train_loss 0.01709
| epoch 2 |  100/ 152 batches| train_acc 0.835 train_loss 0.01863
| epoch 2 |  150/ 152 batches| train_acc 0.854 train_loss 0.01577
---------------------------------------------------------------------
|epoch 2 | time: 1.28s |valid_acc 0.852 valid_loss 0.017 | lr 5.000000
---------------------------------------------------------------------
| epoch 3 |   50/ 152 batches| train_acc 0.854 train_loss 0.01663
| epoch 3 |  100/ 152 batches| train_acc 0.855 train_loss 0.01743
| epoch 3 |  150/ 152 batches| train_acc 0.846 train_loss 0.01738
---------------------------------------------------------------------
|epoch 3 | time: 1.34s |valid_acc 0.862 valid_loss 0.017 | lr 5.000000
---------------------------------------------------------------------
| epoch 4 |   50/ 152 batches| train_acc 0.862 train_loss 0.01514
| epoch 4 |  100/ 152 batches| train_acc 0.854 train_loss 0.01638
| epoch 4 |  150/ 152 batches| train_acc 0.854 train_loss 0.01920
---------------------------------------------------------------------
|epoch 4 | time: 1.18s |valid_acc 0.847 valid_loss 0.018 | lr 5.000000
---------------------------------------------------------------------
| epoch 5 |   50/ 152 batches| train_acc 0.898 train_loss 0.00902
| epoch 5 |  100/ 152 batches| train_acc 0.897 train_loss 0.00885
| epoch 5 |  150/ 152 batches| train_acc 0.900 train_loss 0.00893
---------------------------------------------------------------------
|epoch 5 | time: 1.37s |valid_acc 0.879 valid_loss 0.011 | lr 0.500000
---------------------------------------------------------------------
| epoch 6 |   50/ 152 batches| train_acc 0.900 train_loss 0.00788
| epoch 6 |  100/ 152 batches| train_acc 0.904 train_loss 0.00703
| epoch 6 |  150/ 152 batches| train_acc 0.901 train_loss 0.00681
---------------------------------------------------------------------
|epoch 6 | time: 1.33s |valid_acc 0.883 valid_loss 0.010 | lr 0.500000
---------------------------------------------------------------------
| epoch 7 |   50/ 152 batches| train_acc 0.922 train_loss 0.00573
| epoch 7 |  100/ 152 batches| train_acc 0.901 train_loss 0.00728
| epoch 7 |  150/ 152 batches| train_acc 0.894 train_loss 0.00702
---------------------------------------------------------------------
|epoch 7 | time: 1.12s |valid_acc 0.879 valid_loss 0.009 | lr 0.500000
---------------------------------------------------------------------
| epoch 8 |   50/ 152 batches| train_acc 0.908 train_loss 0.00630
| epoch 8 |  100/ 152 batches| train_acc 0.905 train_loss 0.00593
| epoch 8 |  150/ 152 batches| train_acc 0.911 train_loss 0.00526
---------------------------------------------------------------------
|epoch 8 | time: 1.11s |valid_acc 0.881 valid_loss 0.009 | lr 0.050000
---------------------------------------------------------------------
| epoch 9 |   50/ 152 batches| train_acc 0.911 train_loss 0.00580
| epoch 9 |  100/ 152 batches| train_acc 0.905 train_loss 0.00611
| epoch 9 |  150/ 152 batches| train_acc 0.917 train_loss 0.00516
---------------------------------------------------------------------
|epoch 9 | time: 1.12s |valid_acc 0.881 valid_loss 0.009 | lr 0.005000
---------------------------------------------------------------------
| epoch 10 |   50/ 152 batches| train_acc 0.912 train_loss 0.00564
| epoch 10 |  100/ 152 batches| train_acc 0.905 train_loss 0.00575
| epoch 10 |  150/ 152 batches| train_acc 0.916 train_loss 0.00565
---------------------------------------------------------------------
|epoch 10 | time: 1.12s |valid_acc 0.881 valid_loss 0.009 | lr 0.000500
---------------------------------------------------------------------

测试指定数据

def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text), dtype=torch.float32)print(text.shape)output = model(text)return output.argmax(1).item()ex_text_str = '还有双鸭山到淮阴的汽车票吗13号的'model = model.to('cpu')print('该文本的类别是: %s' %label_name[predict(ex_text_str, text_pipeline)])
torch.Size([1, 100])
该文本的类别是: Travel-Query

总结

  • 本周是结合前几周的内容,使用Word2Vec进行词嵌入之后,再实现中文文本分类
  • 本次自己的错误:将for idx, (text,label) in enumerate(dataloader): 中的text、label搞反了,导致输入和模型的输出无法匹配,因此花费了很多时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/61573.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何在 UniApp 中实现 iOS 版本更新检测

随着移动应用的不断发展,保持应用程序的更新是必不可少的,这样用户才能获得更好的体验。本文将帮助你在 UniApp 中实现 iOS 版的版本更新检测和提示,适合刚入行的小白。我们将分步骤进行说明,每一步所需的代码及其解释都会一一列出…

移动充储机器人“小奥”的多场景应用(上)

一、高速公路服务区应用 在高速公路服务区,新能源汽车的充电需求得到“小奥”机器人的及时响应。该机器人配备有储能电池和自动驾驶技术,能够迅速定位至指定充电点,为待充电的新能源汽车提供服务。得益于“小奥”的机动性,其服务…

Redis 的代理类注入失败,连不上 redis

在测试 redis 是否成功连接时&#xff0c;发现 bean 没有被创建成功&#xff0c;导致报错 根据报错提示&#xff0c;需要我们添加依赖&#xff1a; <dependency><groupId>org.apache.commons</groupId><artifactId>commons-pool2</artifactId>&l…

桌面怎么快速添加便签?适合桌面记事的便签小工具

在数字化时代&#xff0c;我们每天面对电脑处理大量任务&#xff0c;无论是工作计划、会议纪要还是个人生活琐事&#xff0c;都需要一个可靠的桌面记事工具来帮助我们记录和整理。因此&#xff0c;一款适合桌面使用的便签软件成为了我们不可或缺的助手。 敬业签就是这样一款功…

UE5 腿部IK 解决方案 footplacement

UE5系列文章目录 文章目录 UE5系列文章目录前言一、FootPlacement 是什么&#xff1f;二、具体实现 前言 在Unreal Engine 5 (UE5) 中&#xff0c;腿部IK&#xff08;Inverse Kinematics&#xff0c;逆向运动学&#xff09;是一个重要的动画技术&#xff0c;用于实现角色脚部准…

KLV6008固态继电器:高压应用的理想紧凑方案

在当今快节奏的电子领域&#xff0c;找到平衡性能、可靠性和安全性的组件至关重要。CRIA Semiconductor的KLV6008固态继电器(SSR)正是满足了这一要求。这款紧凑型继电器专为高压、低电流切换而设计&#xff0c;是适用于各种应用的多功能解决方案。 为什么选择KLV6008&#xff1…

在 Swift 中实现字符串分割问题:以字典中的单词构造句子

文章目录 前言摘要描述题解答案题解代码题解代码分析示例测试及结果时间复杂度空间复杂度总结 前言 本题由于没有合适答案为以往遗留问题&#xff0c;最近有时间将以往遗留问题一一完善。 LeetCode - #140 单词拆分 II 不积跬步&#xff0c;无以至千里&#xff1b;不积小流&…

HarmonyOs鸿蒙开发实战(21)=>组件间通信@ohos/liveeventbus

1.简介 LiveEventBus是一款消息总线&#xff0c;具有生命周期感知能力&#xff0c;支持Sticky&#xff0c;支持跨进程&#xff0c;支持跨APP发送消息。 2.下载安装 ohpm install ohos/liveeventbus 3.订阅&#xff0c;注册监听 4.发送事件 5. 完成 > 记得关注博主&#xff…

OpenCV和Qt坐标系不一致问题

“ OpenCV和QT坐标系导致绘图精度下降问题。” OpenCV和Qt常用的坐标系都是笛卡尔坐标系&#xff0c;但是细微处有些不同。 01 — OpenCV坐标系 OpenCV是图像处理库&#xff0c;是以图像像素为一个坐标位置&#xff0c;即一个像素对应一个坐标&#xff0c;所以其坐标系也叫图像…

nohup java -jar supporterSys.jar --spring.profiles.active=prod

文章目录 1、ps -ef | grep java2、kill 13713、ps -ef | grep java4、nohup java -jar supporterSys.jar --spring.profiles.activeprod &5、ps -ef | grep java1. 启动方式进程 1371进程 19994 2. 主要区别3. 可能的原因4. 建议 1、ps -ef | grep java rootshipper:~# p…

Ubuntu上安装MySQL并且实现远程登录

目录 下载网络工具 查看网络连接 更新系统软件包&#xff1b; 安装mysql数据库 查看mysql数据库状态 以数字ip形式显示mysql的监听状态。&#xff08;默认监听端口是3306&#xff09; 查看安装mysql数据库时系统创建的目录信息。 根据查询到的系统用户名以及随机密码&a…

shell编写——脚本传参与运算

shell编写——脚本传参与运算 声明&#xff01; 学习视频来自B站up主 泷羽sec 有兴趣的师傅可以关注一下&#xff0c;如涉及侵权马上删除文章&#xff0c;笔记只是方便各位师傅的学习和探讨&#xff0c;文章所提到的网站以及内容&#xff0c;只做学习交流&#xff0c;其他均与本…

设计模式之 观察者模式

观察者模式&#xff08;Observer Pattern&#xff09;是一种行为型设计模式&#xff0c;它定义了一种一对多的依赖关系&#xff0c;让多个观察者对象同时监听一个主题对象&#xff08;Subject&#xff09;。当主题对象的状态发生变化时&#xff0c;所有依赖于它的观察者都会得到…

深入了解 Linux htop 命令:功能、用法与示例

文章目录 深入了解 Linux htop 命令&#xff1a;功能、用法与示例什么是 htop&#xff1f;htop 的安装htop的基本功能A区&#xff1a;系统资源使用情况B区&#xff1a;系统概览信息C区&#xff1a;进程列表D区&#xff1a;功能键快捷方式 与 top 的对比常见用法与示例实际场景应…

【深度学习】【RKNN】【C++】模型转化、环境搭建以及模型部署的详细教程

【深度学习】【RKNN】【C】模型转化、环境搭建以及模型部署的详细教程 提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 文章目录 【深度学习】【RKNN】【C】模型转化、环境搭建以及模型部署的详细教程前言模型转换--pytorch转rknnpytorch转onnxonnx转rkn…

【spark】远程debug spark任务(含有pyspark)

--master yarn和--master client都是可以的。 spark-submit \ --master yarn \ --deploy-mode client \ --name "test-remote-debug" \ --conf "spark.driver.extraJavaOptions-agentlib:jdwptransportdt_socket,servery,suspendn,address5005" \ --conf …

如何使用 Vivado 从源码构建 Infinite-ISP FPGA 项目

如约介绍源码构建 Infinite-ISP 项目&#xff0c;其实大家等的是源码&#xff0c;所以中间过程简洁略过&#xff0c;可以直接翻到文末获取链接。 开源ISP&#xff08;Infinite-ISP&#xff09;介绍 构建工程 第一步&#xff0c;从文末或者下面链接获取源码 https://github.com/…

彻底理解Redis的持久化方式

一.由来 因为Redis之所以能够提供高效读写的操作&#xff0c;是因为它是基于内存的&#xff0c;但是这样也会带来一个问题&#xff0c;及在服务器宕机或者重启的情况下&#xff0c;内存里面的数据就会被丢失掉&#xff0c;所以为了解决这个问题&#xff0c;Redis就提供了持久化…

Bug Fix 20241122:缺少lib文件错误

今天有朋友提醒才突然发现 gitee 上传的代码存在两个很严重&#xff0c;同时也很低级的错误。 因为gitee的默认设置不允许二进制文件的提交&#xff0c; 所以PH47框架下的库文件&#xff08;各逻辑层的库文件&#xff09;&#xff0c;以及Stm32Cube驱动的库文件都没上传到Gi…

NVR管理平台EasyNVR多个NVR同时管理:全方位安防监控视频融合云平台方案

EasyNVR是基于端-边-云一体化架构的安防监控视频融合云平台&#xff0c;具有简单轻量的部署方式与多样的功能&#xff0c;支持多种协议&#xff08;如GB28181、RTSP、Onvif、RTMP&#xff09;和设备类型&#xff08;IPC、NVR等&#xff09;&#xff0c;提供视频直播、录像、回放…