第N6周:使用Word2vec实现文本分类

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings
#忽略警告信息
warnings.filterwarnings("ignore")
# win10系统
device = torch.device("cuda"if torch.cuda.is_available()else"cpu")
deviceimport pandas as pd
# 加载自定义中文数据
train_data= pd.read_csv('./data/train2.csv',sep='\t',header=None)
train_data.head()# 构造数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,y
x = train_data[0].values[:]
#多类标签的one-hot展开
y = train_data[1].values[:]from gensim.models.word2vec import Word2Vec
import numpy as np
#训练word2Vec浅层神经网络模型
w2v=Word2Vec(vector_size=100#是指特征向量的维度,默认为100。,min_count=3)#可以对字典做截断。词频少于min_count次数的单词会被丢弃掉,默认为5w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=20)# 将文本转化为向量
def average_vec(text):vec =np.zeros(100).reshape((1,100))for word in text:try:vec +=w2v.wv[word].reshape((1,100))except KeyError:continuereturn vec
#将词向量保存为Ndarray
x_vec= np.concatenate([average_vec(z)for z in x])
#保存Word2Vec模型及词向量
w2v.save('data/w2v_model.pk1')train_iter= coustom_data_iter(x_vec,y)
len(x),len(x_vec)label_name =list(set(train_data[1].values[:]))
print(label_name)text_pipeline =lambda x:average_vec(x)
label_pipeline =lambda x:label_name.index(x)text_pipeline("你在干嘛")
label_pipeline("Travel-Query")from torch.utils.data import DataLoader
def collate_batch(batch):label_list,text_list=[],[]for(_text,_label)in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text),dtype=torch.float32)text_list.append(processed_text)label_list = torch.tensor(label_list,dtype=torch.int64)text_list = torch.cat(text_list)return text_list.to(device),label_list.to(device)
# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,
shuffle =False,
collate_fn=collate_batch)from torch import nn
class TextclassificationModel(nn.Module):def __init__(self,num_class):super(TextclassificationModel,self).__init__()self.fc = nn.Linear(100,num_class)def forward(self,text):return self.fc(text)num_class =len(label_name)
vocab_size =100000
em_size=12
model= TextclassificationModel(num_class).to(device)import time
def train(dataloader):model.train()#切换为训练模式total_acc,train_loss,total_count =0,0,0log_interval=50start_time= time.time()for idx,(text,label)in enumerate(dataloader):predicted_label= model(text)# grad属性归零optimizer.zero_grad()loss=criterion(predicted_label,label)#计算网络输出和真实值之间的差距,labelloss.backward()#反向传播torch.nn.utils.clip_grad_norm(model.parameters(),0.1)#梯度裁剪optimizer.step()#每一步自动更新#记录acc与losstotal_acc+=(predicted_label.argmax(1)==label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval==0 and idx>0:elapsed =time.time()-start_timeprint('Iepoch {:1d}I{:4d}/{:4d} batches''|train_acc {:4.3f} train_loss {:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count =0,0,0start_time = time.time()
def evaluate(dataloader):model.eval()#切换为测试模式total_acc,train_loss,total_count =0,0,0with torch.no_grad():for idx,(text,label)in enumerate(dataloader):predicted_label= model(text)loss = criterion(predicted_label,label)# 计算loss值# 记录测试数据total_acc+=(predicted_label.argmax(1)== label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count,train_loss/total_countfrom torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS=10#epoch
LR=5 #学习率
BATCH_SIZE=64 # batch size for training
criterion = torch.nn.CrossEntropyLoss()
optimizer= torch.optim.SGD(model.parameters(),lr=LR)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)
total_accu = None
# 构建数据集
train_iter= coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_,split_valid_= random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])
train_dataloader =DataLoader(split_train_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
for epoch in range(1,EPOCHS+1):epoch_start_time = time.time()train(train_dataloader)val_acc,val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr =optimizer.state_dict()['param_groups'][0]['1r']if total_accu is not None and total_accu>val_acc:scheduler.step()else:total_accu = val_accprint('-'*69)print('|epoch {:1d}|time:{:4.2f}s |''valid_acc {:4.3f} valid_loss {:4.3f}I1r {:4.6f}'.format(epoch,time.time()-epoch_start_time,val_acc,val_loss,lr))print('-'*69)# test_acc,test_loss =evaluate(valid_dataloader)
# print('模型准确率为:{:5.4f}'.format(test_acc))
#
#
# def predict(text,text_pipeline):
#     with torch.no_grad():
#         text = torch.tensor(text_pipeline(text),dtype=torch.float32)
#         print(text.shape)
#         output = model(text)
#         return output.argmax(1).item()
# # ex_text_str="随便播放一首专辑阁楼里的佛里的歌"
# ex_text_str="还有双鸭山到淮阴的汽车票吗13号的"
# model=model.to("cpu")
# print("该文本的类别是:%s"%label_name[predict(ex_text_str,text_pipeline)])

以上是文本识别基本代码

输出:

[[-0.85472693  0.96605204  1.5058695  -0.06065784 -2.10079319 -0.120211511.41170089  2.00004494  0.90861696 -0.62710127 -0.62408304 -3.805954991.02797993 -0.45584389  0.54715634  1.70490362  2.33389823 -1.996075184.34822938 -0.76296186  2.73265275 -1.15046433  0.82106878 -0.32701646-0.50515595 -0.37742117 -2.02331601 -1.365334    1.48786476 -1.63949711.59438308  2.23569647 -0.00500725 -0.65070192  0.07377997  0.01777986-1.35580809  3.82080549 -2.19764423  1.06595343  0.99296588  0.58972518-0.33535255  2.15471306 -0.52244038  1.00874437  1.28869729 -0.72208139-2.81094289  2.2614549   0.20799019 -2.36187895 -0.94019454  0.49448857-0.68613767 -0.79071895  0.47535057 -0.78339124 -0.71336574 -0.279315671.0514895  -1.76352624  1.93158554 -0.85853558 -0.65540617  1.3612217-1.39405773  1.18187538  1.31730198 -0.02322496  0.14652854  0.222498812.01789951 -0.40144247 -0.39880068 -0.16220299 -2.85221207 -0.277228682.48236791 -0.51239379 -1.47679498 -0.28452797 -2.64497767  2.12093259-1.2326943  -1.89571355  2.3295732  -0.53244872 -0.67313893 -0.808146040.86987564 -1.31373079  1.33797717  1.02223087  0.5817025  -0.835356470.97088164  2.09045361 -2.57758138  0.07126901]]
6

输出结果并非为0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/792180.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于springboot+vue+Mysql的招生管理系统

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

震惊!!原来阻塞队列消息队列这样理解会更简单!!!

震惊!!原来阻塞队列&&消息队列这样理解会更简单!!! 一:阻塞队列二:消息队列2.1:生产者消费者模型2.1.1:解耦合:2.1.2:削峰填谷: 三:消息队列代码3.1.13.1.2:3.1.3:生产慢,消费快,消费阻塞3.1.3:生产快,消费慢,生产阻塞 二级目录二级目录 一:阻塞队列 阻塞队列:先进先出…

gitcode 配置 SSH 公钥

在 gitcode 上配置SSH公钥后,可以通过SSH协议安全地访问远程仓库,无需每次都输入用户名和密码。以下是配置SSH公钥的步骤: 5分钟解决方案 用 OpenSSH公钥生成器 生成 公钥和私钥,私钥文件(id_rsa)下载&am…

【Leetcode】top 100 图论

基础知识补充 1.图分为有向图和无向图,有权图和无权图; 2.图的表示方法:邻接矩阵适合表示稠密图,邻接表适合表示稀疏图; 邻接矩阵: 邻接表: 基础操作补充 1.邻接矩阵: class GraphAd…

Open3D(C++) 鲁棒损失函数优化的ICP算法

目录 一、损失函数1、关于2、损失函数3、Open3D实现二、代码实现三、结果展示1、配准前1、配准后本文由CSDN点云侠原创,

C语言----数据在内存中的存储

文章目录 前言1.整数在内存中的存储2.大小端字节序和字节序判断2.1 什么是大小端?2.2 练习 3.浮点数在内存中的存储3.1.引子3.2.浮点数的存储3.2.2 浮点数取的过程 前言 下面给大家介绍一下数据在内存中的存储,这个是一个了解c语言内部的知识点&#xf…

【Linux学习】Linux 的虚拟化和容器化技术

˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好,我是xiaoxie.希望你看完之后,有不足之处请多多谅解,让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN 如…

MySQL 导入库/建表时/出现乱码

问题描述: 新建不久的项目在使用Navicat for MySQL进行查看数据,发现表中注释的部分乱码,但是项目中获取的数据使用不会。 猜测因为是数据库编码和项目中使用的不一样,又因为项目的连接语句定义了需要编码,故项目运行…

浅述安防视频监控平台EasyCVR视频汇聚管理系统运维管理能力

智慧安防监控EasyCVR视频管理平台能在复杂的网络环境中,将前端设备统一集中接入与汇聚管理。国标GB28181协议视频监控/视频汇聚EasyCVR平台可以提供实时远程视频监控、视频录像、录像回放与存储、告警、语音对讲、云台控制、平台级联、磁盘阵列存储、视频集中存储、…

在云端遇见雨云:一位服务器寻觅者的指南

引言:寻觅一座云端归宿 当我踏入数字世界的边缘,带着对网络的探索与期待,我迫切需要一座安全可靠的数字栖息地。云计算技术正如一场魔法般的变革,而在这片广袤的云端中,雨云就像是一位友善的向导,引领我穿越…

30.多个线程交替执行

线程一输出a,5次; 线程二输出b,5次; 线程三输出c,5次; 现在要求输出abcabcabcabcabc怎么实现? 采用wait和notifyAll实现 public class ThreadTest {public static void main(String[] args) {WaitNotify waitNotify new Wai…

3DGS实时高质量大规模场景渲染最新SOTA!

作者:小柠檬 | 来源:3DCV 在公众号「3DCV」后台,回复「原论文」可获取论文pdf 添加微信:dddvision,备注:3D高斯,拉你入群。文末附行业细分群 详细内容请关注3DCV 3D视觉精品课程:…

【Java EE】Maven jar 包下载失败问题的解决方法

文章目录 1. 配置好国内的Maven源1.1配置当前项⽬setting1.2设置新项⽬的setting 2.重新下载jar包3.其他问题⭕总结 1. 配置好国内的Maven源 因为中央仓库在国外, 所以下载起来会⽐较慢, 所以咱们选择借助国内⼀些公开的远程仓库来下载资源 接下来介绍, 如何设置国内源 1.1配…

【JAVAEE学习】探究Java中多线程的使用和重点及考点

˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好,我是xiaoxie.希望你看完之后,有不足之处请多多谅解,让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN 如…

Arcgis研究区图经纬度(南北)切换为英文字体(SN)

只在做英文论文研究区图的时候用,平常为了方便还是切换为中文

面试题:JVM 调优

一、JVM 参数设置 1. tomcat 的设置 vm 参数 修改 TOMCAT_HOME/bin/catalina.sh 文件,如下图 JAVA_OPTS"-Xms512m -Xmx1024m" 2. springboot 项目 jar 文件启动 通常在linux系统下直接加参数启动springboot项目 nohup java -Xms512m -Xmx1024m -jar…

作业3:计算机体系结构属性优选

作业3:计算机体系结构属性优选 一. 单选题(共11题,55分) (单选题)下列哪个选项属于非线性结构( )? A. 线性表 B. 栈 C. 树 D. 队列 正确答案: C:树; (单选题) 浮点数在机器中的表示形式如下所…

JS详解-fetch核心语法

document.querySelector(.btn).addEventListener(click,async () > {const p new URLSearchParams({pname:浙江省,cname:杭州市})//1、如何请求?默认为get,参数1 url地址,返回promiseconst res await fetch(http://hmajax.itheima.net/…

给你一个网站如何测试?

主要围绕,功能,页面 UI ,兼容,性能,安全,这几个方面去聊,首先是制定测试计划,确定测试范围和测试策略,一般包括以下几个部分:功能性测试;界面测试…

【打印SQL执行日志】⭐️Mybatis-Plus通过配置在控制台打印执行日志

目录 前言 一、Mybatis-Plus 开启日志的方式 二、测试 三、日志分析 章末 前言 小伙伴们大家好,相信大家平时在处理问题时都有各自的方式,最常用以及最好用的感觉还是断点调试,但是涉及到操作数据库的执行时,默认的话在控制台…