Word2Vec实现文本识别分类

深度学习训练营之使用Word2Vec实现文本识别分类

  • 原文链接
  • 环境介绍
  • 前言
  • 前置工作
    • 设置GPU
    • 数据查看
    • 构建数据迭代器
  • Word2Vec的调用
  • 生成数据批次和迭代器
  • 模型训练
    • 初始化
    • 拆分数据集并进行训练
  • 预测

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第N4周:用Word2Vec实现文本分类
  • 🍖 原作者:K同学啊|接辅导、项目定制

环境介绍

  • 语言环境:Python3.9.12
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2

前言

本次内容我本来是使用miniconda的环境的,但是好像有文件发生了损坏,出现了如下报错,据我所了解应该是某个文件发生了损坏,应该是之前将anaconda误删有关,有所了解或者有同样问题的朋友可以一起进行探讨

前置工作

设置GPU

如果

# 先进行数据加载
import torch
import torch.nn as nn
import torchvision
import os,PIL,pathlib,warnings
import time
from torchvision import transforms, datasets
from torch import nn
from torch.utils.data.dataset import random_splitwarnings.filterwarnings("ignore")#忽略警告信息
device=torch.device("cuda"if torch.cuda.is_available()else "cpu")
device

device(type=‘cpu’)

数据查看

本次使用的数据集和之前中文文本识别分类的是一样的

import pandas as pd
train_data=pd.read_csv('train.csv',sep='\t',header=None)
train_data.head()

在这里插入图片描述

构建数据迭代器

#构建数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,yx=train_data[0].values[:]
y=train_data[1].values[:]    

添加数据迭代器是为了让数据的随机性增强,进行数据集的划分,可以有效的发挥内存的高利用率

Word2Vec的调用

对Word2Vec进行直接的调用

from gensim.models.word2vec import Word2Vec
import numpy as np
#训练浅层神经网络模型
w2v=Word2Vec(vector_size=100,min_count=3)w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=30)

build_vocab统计输入每一个词汇出现的次数

def average_vec(text):vec=np.zeros(100).reshape((1,100))#表示平均向量#(n,100),其中n表示x中的元素的数量 for word in text:try:vec+=w2v.wv[word].reshape((1,100))except KeyError:continue#未找到,再进行迭代下一个词return vecx_vec=np.concatenate([average_vec(z) for z in x])
w2v.save('w2v_model.pkl')

该步骤将输入的文本转变成了平均向量
对于输入进来的text当中的每一个单词都进行一个查询,确认是否当中有该词,如果有那么就将其添加到vector当中,否则跳出本层循环,查找下一个词.
最后通过np当中的concatenate方法进行一个向量的连接

train_iter=coustom_data_iter(x_vec,y)#训练迭代器
print(len(x),len(y))

12100 12100

设置训练的迭代器

label_name=list(set(train_data[1].values[:]))
print(label_name)
['FilmTele-Play', 'Weather-Query', 'Audio-Play', 'Radio-Listen', 'HomeAppliance-Control', 'Alarm-Update', 'Travel-Query', 'Video-Play', 'Calendar-Query', 'TVProgram-Play', 'Music-Play', 'Other']

生成数据批次和迭代器

text_pipeline=lambda x:average_vec(x)
label_pipeline=lambda x:label_name.index(x)
#lambda语法:lambda  arguments
text_pipeline("我想你了")

在这里插入图片描述

label_pipeline("Travel-Query")

6

这里的结果每次都会不太一样,具有一定的随机性

from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list= [], []for (_text,_label) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)text_list.append(processed_text)# 偏移量,即语句的总词汇量label_list = torch.tensor(label_list, dtype=torch.int64)text_list  = torch.cat(text_list)return text_list.to(device),label_list.to(device)# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,shuffle   =False,collate_fn=collate_batch)

和之前的不同在于没有了offset

模型训练

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, num_class):super(TextClassificationModel, self).__init__()self.fc = nn.Linear(100, num_class)def forward(self, text):return self.fc(text)

初始化

num_class  = len(label_name)
vocab_size = 100000
em_size    = 12
model      = TextClassificationModel(num_class).to(device)
import timedef train(dataloader):model.train()  # 切换为训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 50start_time   = time.time()for idx, (text,label) in enumerate(dataloader):predicted_label = model(text)optimizer.zero_grad()                    # grad属性归零loss = criterion(predicted_label, label) # 计算网络输出和真实值之间的差距,label为真实值loss.backward()                          # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 梯度裁剪optimizer.step()  # 每一步自动更新# 记录acc与losstotal_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()  # 切换为测试模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (text,label) in enumerate(dataloader):predicted_label = model(text)loss = criterion(predicted_label, label)  # 计算loss值# 记录测试数据total_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

拆分数据集并进行训练

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS     = 30 # epoch
LR         = 5  # 学习率
BATCH_SIZE = 64 # batch size for trainingcriterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None# 构建数据集
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:1d} | time: {:4.2f}s | ''valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,time.time() - epoch_start_time,val_acc,val_loss,lr))print('-' * 69)
| epoch 1 |   50/ 152 batches | train_acc 0.742 train_loss 0.02635
| epoch 1 |  100/ 152 batches | train_acc 0.820 train_loss 0.02033
| epoch 1 |  150/ 152 batches | train_acc 0.838 train_loss 0.01927
---------------------------------------------------------------------
| epoch 1 | time: 0.95s | valid_acc 0.819 valid_loss 0.023 | lr 5.000000
---------------------------------------------------------------------
| epoch 2 |   50/ 152 batches | train_acc 0.850 train_loss 0.01876
| epoch 2 |  100/ 152 batches | train_acc 0.849 train_loss 0.02012
| epoch 2 |  150/ 152 batches | train_acc 0.847 train_loss 0.01736
---------------------------------------------------------------------
| epoch 2 | time: 0.92s | valid_acc 0.869 valid_loss 0.016 | lr 5.000000
---------------------------------------------------------------------
| epoch 3 |   50/ 152 batches | train_acc 0.858 train_loss 0.01588
| epoch 3 |  100/ 152 batches | train_acc 0.833 train_loss 0.02008
| epoch 3 |  150/ 152 batches | train_acc 0.864 train_loss 0.01813
---------------------------------------------------------------------
| epoch 3 | time: 0.86s | valid_acc 0.835 valid_loss 0.023 | lr 5.000000
---------------------------------------------------------------------
| epoch 4 |   50/ 152 batches | train_acc 0.883 train_loss 0.01309
| epoch 4 |  100/ 152 batches | train_acc 0.899 train_loss 0.00996
| epoch 4 |  150/ 152 batches | train_acc 0.895 train_loss 0.00927
---------------------------------------------------------------------
| epoch 4 | time: 0.87s | valid_acc 0.888 valid_loss 0.011 | lr 0.500000
---------------------------------------------------------------------
| epoch 5 |   50/ 152 batches | train_acc 0.906 train_loss 0.00834
...
| epoch 30 |  150/ 152 batches | train_acc 0.900 train_loss 0.00717
---------------------------------------------------------------------
| epoch 30 | time: 0.92s | valid_acc 0.886 valid_loss 0.010 | lr 0.000000
---------------------------------------------------------------------
test_acc, test_loss = evaluate(valid_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

在这里插入图片描述

预测

def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text),dtype=torch.float32)print(text.shape)output = model(text)return output.argmax(1).item()ex_text_str = "随便播放一首专辑阁楼里的佛里的歌"
#ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to(device)print("该文本的类别是:%s" % label_name[predict(ex_text_str, text_pipeline)])
torch.Size([1, 100])
该文本的类别是:Music-Play

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/7024.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

探析国内数字孪生引擎技术现状

在数字孪生软件来发中,渲染引擎是一个关键点,国内大多数字孪生平台引擎通常使用的是自研的渲染引擎或者采用开源的渲染引擎。下面通过一些常见的渲染引擎在国内数字孪生引擎中的应用带大家了解数字孪生软件开发的方式。 自研渲染引擎:许多数…

HTTPS安全套接字层超文本传输协议

HTTPS安全套接字层超文本传输协议 HTTPS简介HTTPS和HTTP的主要区别客户端在使用HTTPS方式与Web服务器通信时的步骤SSL/TLS协议的加密(握手)过程为什么数据传输阶段使用对称加密HTTPS 的优点HTTPS 的缺点HTTPS 的优化证书优化会话复用 HTTPS简介 HTTP协议…

文件包含漏洞利用思路

简介 通过PHP函数引入文件时,传入的文件名没有经过合理的验证,从而操作了预想之外的文件,导致意外的文件泄漏甚至恶意代码注入。 常见的文件包含函数 php中常见的文件包含函数有以下四种: include()require()include_once()re…

苍穹外卖day05——Redis(被病毒入侵)+店铺营业状态设置

Redis被病毒入侵了 数据删光光然后只剩这四个玩意,乱下东西乱删东西,还好是docker部署,不然就寄了。 在服务器上部署redis记得一定要设置密码,不然被人扫肉鸡注入病毒整个服务器给你崩掉。 使用配置类的方式搭建相关程序 配置数…

实现简单Spring基于XML的配置程序

定义一个容器,使用ConcurrentHashMap 做为单例对象的容器 先解析beans.xml得到第一个bean对象的信息,id,class,属性和属性值使用反射生成对象,并赋值将创建好的bean对象放入到singletonObjects集合中提供getBean(id)方…

【Redis】剖析RDB和AOF持久化原理

文章目录 前言1、AOF日志1.1、概述1.2、日志文件1.3、写回策略1.4、策略实现原理1.5、重写机制1.6、AOF 后台重写1.6.1、介绍1.6.2、实现原理 1.7、优缺点 2、RDB快照2.1、概述2.2、实现方式2.3、实现原理2.4、极端情况2.5、优缺点 3、混合体实现4、大Key问题4.1、何为大key4.2…

profinet 调试记录

一、 树莓派运行codesys runtime 1. 用户名称要以 root 登录 若是普通用户,会提示:脚本必须以 root 身份运行 2. codesys报错: 在树莓派config.txt文件添加:arm_64bit0 3. 扫描设备需开启PLC 图标变红,则开启成…

【MATLAB第58期】基于MATLAB的PCA-Kmeans、PCA-LVQ与BP神经网络分类预测模型对比

【MATLAB第58期】基于MATLAB的PCA-Kmeans、PCA-LVQ与BP神经网络分类预测模型对比 一、数据介绍 基于UCI葡萄酒数据集进行葡萄酒分类及产地预测 共包含178组样本数据,来源于三个葡萄酒产地,每组数据包含产地标签及13种化学元素含量,即已知类…

STM32H5开发(1)----总览

STM32H5开发----1.总览 概述样品申请STM32H5-2MB 框图产品列表STM32H5-2MB 框图STM32H5-128KB框图功能对比STM32H5-128KB vs H5-2MB组员对比STM32H5 亮点 概述 STM32H5系列微控制器是意法半导体公司推出的一款高性能MCU, CortexM33内核的微控制器产品。 他和STM32F2、F4、F7、…

论文精度系列之详解图神经网络

论文地址:A Gentle Introduction to Graph Neural Networks 翻译:图表就在我们身边;现实世界的对象通常根据它们与其他事物的连接来定义。一组对象以及它们之间的连接自然地表示为图形。十多年来,研究人员已经开发了对图数据进行操作的神经网络(称为图神…

CentOS 7.9 安装 mydumper(RPM方式)

链接:https://pan.baidu.com/s/1sGhtiKPOmJw1xj0zv-djkA?pwdtaoz 码:taoz 开始正文啦: rpm -ivh mydumper-0.14.5-3-zstd.el7.x86_64.rpm 问题如下: 解决: yum -y install epel-release yum install -y libzstd …

zabbix安装Grafana

一、web访问 https://s3-us-west-2.amazonaws.com/grafana-releases/release/grafana-4.6.1-1.x86_64.rpm [rootserver ~] yum localinstall -y grafana-4.6.1-1.x86_64.rpm //yum方式安装本地rpm并自动解决依赖关系 [rootserver ~] grafana-cli plugins install alexanderzob…

利用 trait 实现多态

我在书上看到基于 std::io::Write 的示例,它是一个 trait 类型,内部声明了一些方法。和 go 语言不同,rust 中类型必须明确实现 trait 类型,而 go 语言属于 duck 模式。 std::io::Write下面的例子中调用 write_all 方式来演示&…

国标GB28181视频监控平台EasyGBS无法播放,抓包返回ICMP的排查过程

国标GB28181视频平台EasyGBS是基于国标GB/T28181协议的行业内安防视频流媒体能力平台,可实现的视频功能包括:实时监控直播、录像、检索与回看、语音对讲、云存储、告警、平台级联等功能。国标GB28181视频监控平台部署简单、可拓展性强,支持将…

1 请使用js、css、html技术实现以下页面,表格内容根据查询条件动态变化。

1.1 创建css文件,用于编辑style 注意: 1.背景颜色用ppt的取色器来获取: 先点击ppt的形状轮廓,然后点击取色器,吸颜色,然后再点击形状轮廓的其他轮廓颜色,即可获取到对应颜色。 2.表格间的灰色线…

【Spring Boot】Web开发 — 数据验证

Web开发 — 数据验证 对于应用系统而言,任何客户端传入的数据都不是绝对安全有效的,这就要求我们在服务端接收到数据时也对数据的有效性进行验证,以确保传入的数据安全正确。接下来介绍Spring Boot是如何实现数据验证的。 1.Hibernate Vali…

生态合作丨MemFireDB通过麒麟软件NeoCertify认证

近日,敏博科技“MemFireDB分布式关系数据库系统V2.8”与麒麟软件“银河麒麟高级服务器操作系统V10” 完成兼容性测试,获得麒麟软件 NeoCertify 认证证书。测试结果显示,MemFireDB数据库在国产操作系统上运行稳定,产品已经达到通用…

android studio(火烈鸟版本)使用protobuf

一、简介 Protobuf 全称:Protocol Buffers,是 Google 推出的一种与平台无关、语言无关、可扩展的轻便高效的序列化数据存储格式,类似于我们常用的 xml 和 json。 二、特点 Protobuf 用两个字总结:小,快。用 Protobu…

十大排序算法详解

目录 1. 冒泡排序 a. 思路 b. code 2. 插入排序 a. 思路 b. code 3. 希尔排序【插入排序plus】 a. 思路 b. code 4. 选择排序 a. 思路 b. code 5. 基数排序 a. 前置知识 b. 思路 c. code 6. 计数排序 a. 思路 b. code 7. 桶排序(计数排序plus &…

Could not resolve placeholder

本质原因:项目启动未扫描到该配置,一般来说是配置不对 检查方向 1、检查编译后的target包里是否有该配置所在的文件 如果不在就clear,重新编译启动再去检查 2、检查启动的环境是否匹配 编译后的target包下的配置文件名称是否跟启动类的环境…