第N6周:使用Word2vec实现文本分类

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings
#忽略警告信息
warnings.filterwarnings("ignore")
# win10系统
device = torch.device("cuda"if torch.cuda.is_available()else"cpu")
deviceimport pandas as pd
# 加载自定义中文数据
train_data= pd.read_csv('./data/train2.csv',sep='\t',header=None)
train_data.head()# 构造数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,y
x = train_data[0].values[:]
#多类标签的one-hot展开
y = train_data[1].values[:]from gensim.models.word2vec import Word2Vec
import numpy as np
#训练word2Vec浅层神经网络模型
w2v=Word2Vec(vector_size=100#是指特征向量的维度,默认为100。,min_count=3)#可以对字典做截断。词频少于min_count次数的单词会被丢弃掉,默认为5w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=20)# 将文本转化为向量
def average_vec(text):vec =np.zeros(100).reshape((1,100))for word in text:try:vec +=w2v.wv[word].reshape((1,100))except KeyError:continuereturn vec
#将词向量保存为Ndarray
x_vec= np.concatenate([average_vec(z)for z in x])
#保存Word2Vec模型及词向量
w2v.save('data/w2v_model.pk1')train_iter= coustom_data_iter(x_vec,y)
len(x),len(x_vec)label_name =list(set(train_data[1].values[:]))
print(label_name)text_pipeline =lambda x:average_vec(x)
label_pipeline =lambda x:label_name.index(x)text_pipeline("你在干嘛")
label_pipeline("Travel-Query")from torch.utils.data import DataLoader
def collate_batch(batch):label_list,text_list=[],[]for(_text,_label)in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text),dtype=torch.float32)text_list.append(processed_text)label_list = torch.tensor(label_list,dtype=torch.int64)text_list = torch.cat(text_list)return text_list.to(device),label_list.to(device)
# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,
shuffle =False,
collate_fn=collate_batch)from torch import nn
class TextclassificationModel(nn.Module):def __init__(self,num_class):super(TextclassificationModel,self).__init__()self.fc = nn.Linear(100,num_class)def forward(self,text):return self.fc(text)num_class =len(label_name)
vocab_size =100000
em_size=12
model= TextclassificationModel(num_class).to(device)import time
def train(dataloader):model.train()#切换为训练模式total_acc,train_loss,total_count =0,0,0log_interval=50start_time= time.time()for idx,(text,label)in enumerate(dataloader):predicted_label= model(text)# grad属性归零optimizer.zero_grad()loss=criterion(predicted_label,label)#计算网络输出和真实值之间的差距,labelloss.backward()#反向传播torch.nn.utils.clip_grad_norm(model.parameters(),0.1)#梯度裁剪optimizer.step()#每一步自动更新#记录acc与losstotal_acc+=(predicted_label.argmax(1)==label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval==0 and idx>0:elapsed =time.time()-start_timeprint('Iepoch {:1d}I{:4d}/{:4d} batches''|train_acc {:4.3f} train_loss {:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count =0,0,0start_time = time.time()
def evaluate(dataloader):model.eval()#切换为测试模式total_acc,train_loss,total_count =0,0,0with torch.no_grad():for idx,(text,label)in enumerate(dataloader):predicted_label= model(text)loss = criterion(predicted_label,label)# 计算loss值# 记录测试数据total_acc+=(predicted_label.argmax(1)== label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count,train_loss/total_countfrom torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS=10#epoch
LR=5 #学习率
BATCH_SIZE=64 # batch size for training
criterion = torch.nn.CrossEntropyLoss()
optimizer= torch.optim.SGD(model.parameters(),lr=LR)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)
total_accu = None
# 构建数据集
train_iter= coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_,split_valid_= random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])
train_dataloader =DataLoader(split_train_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
for epoch in range(1,EPOCHS+1):epoch_start_time = time.time()train(train_dataloader)val_acc,val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr =optimizer.state_dict()['param_groups'][0]['1r']if total_accu is not None and total_accu>val_acc:scheduler.step()else:total_accu = val_accprint('-'*69)print('|epoch {:1d}|time:{:4.2f}s |''valid_acc {:4.3f} valid_loss {:4.3f}I1r {:4.6f}'.format(epoch,time.time()-epoch_start_time,val_acc,val_loss,lr))print('-'*69)# test_acc,test_loss =evaluate(valid_dataloader)
# print('模型准确率为:{:5.4f}'.format(test_acc))
#
#
# def predict(text,text_pipeline):
#     with torch.no_grad():
#         text = torch.tensor(text_pipeline(text),dtype=torch.float32)
#         print(text.shape)
#         output = model(text)
#         return output.argmax(1).item()
# # ex_text_str="随便播放一首专辑阁楼里的佛里的歌"
# ex_text_str="还有双鸭山到淮阴的汽车票吗13号的"
# model=model.to("cpu")
# print("该文本的类别是:%s"%label_name[predict(ex_text_str,text_pipeline)])

以上是文本识别基本代码

输出:

[[-0.85472693  0.96605204  1.5058695  -0.06065784 -2.10079319 -0.120211511.41170089  2.00004494  0.90861696 -0.62710127 -0.62408304 -3.805954991.02797993 -0.45584389  0.54715634  1.70490362  2.33389823 -1.996075184.34822938 -0.76296186  2.73265275 -1.15046433  0.82106878 -0.32701646-0.50515595 -0.37742117 -2.02331601 -1.365334    1.48786476 -1.63949711.59438308  2.23569647 -0.00500725 -0.65070192  0.07377997  0.01777986-1.35580809  3.82080549 -2.19764423  1.06595343  0.99296588  0.58972518-0.33535255  2.15471306 -0.52244038  1.00874437  1.28869729 -0.72208139-2.81094289  2.2614549   0.20799019 -2.36187895 -0.94019454  0.49448857-0.68613767 -0.79071895  0.47535057 -0.78339124 -0.71336574 -0.279315671.0514895  -1.76352624  1.93158554 -0.85853558 -0.65540617  1.3612217-1.39405773  1.18187538  1.31730198 -0.02322496  0.14652854  0.222498812.01789951 -0.40144247 -0.39880068 -0.16220299 -2.85221207 -0.277228682.48236791 -0.51239379 -1.47679498 -0.28452797 -2.64497767  2.12093259-1.2326943  -1.89571355  2.3295732  -0.53244872 -0.67313893 -0.808146040.86987564 -1.31373079  1.33797717  1.02223087  0.5817025  -0.835356470.97088164  2.09045361 -2.57758138  0.07126901]]
6

输出结果并非为0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/792180.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DD3L内存芯片介绍

在数字科技迅猛发展的今天,内存芯片作为硬件的核心组件之一,扮演着至关重要的角色。而DD3L内存芯片以其卓越的性能和独特的设计,成为众多高端电子设备的不二选择。那么,DD3L内存芯片究竟如何应用,又是如何释放数字世界…

导入预览以及解决导入量大引发超时等问题

1、首选解决预览问题 由于使用的是vue3,页面与数据都是交互响应式的,所以可以通过组件或者原生的文件上传,获取到excel的sheet,从而来计算条数,页码,页数,手动实现分页逻辑,也就是把…

flex:1的作用是什么?

占满剩余的高度 <div classfather><div classson1></div><div classson2></div> </div>当给father添加display:flex之后&#xff0c;假设给son2添加flex:1&#xff0c;那么son2将会占满除son1之外的高度

【DevOps工具篇】Keycloak中设置LDAP认证

【DevOps工具篇】Keycloak中设置LDAP认证 目录 【DevOps工具篇】Keycloak中设置LDAP认证本次使用的环境服务器配置LDAP目录结构使用存储在LDAP中的用户进行登录Keycloak配置步骤功能测试从LDAP向Keycloak批量添加用户Keycloak配置步骤功能测试推荐超级课程: Docker快速入门到精…

基于springboot+vue+Mysql的招生管理系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

震惊!!原来阻塞队列消息队列这样理解会更简单!!!

震惊!!原来阻塞队列&&消息队列这样理解会更简单!!! 一:阻塞队列二:消息队列2.1:生产者消费者模型2.1.1:解耦合:2.1.2:削峰填谷: 三:消息队列代码3.1.13.1.2:3.1.3:生产慢,消费快,消费阻塞3.1.3:生产快,消费慢,生产阻塞 二级目录二级目录 一:阻塞队列 阻塞队列:先进先出…

gitcode 配置 SSH 公钥

在 gitcode 上配置SSH公钥后&#xff0c;可以通过SSH协议安全地访问远程仓库&#xff0c;无需每次都输入用户名和密码。以下是配置SSH公钥的步骤&#xff1a; 5分钟解决方案 用 OpenSSH公钥生成器 生成 公钥和私钥&#xff0c;私钥文件&#xff08;id_rsa&#xff09;下载&am…

解决Maven Clean过程因内存溢出导致的“Process terminated”问题

正文&#xff1a; 在今天的开发过程中&#xff0c;我遇到了一个意料之外的问题&#xff0c;当我尝试使用 Maven 进行项目清理&#xff08;maven clean&#xff09;时&#xff0c;命令行反馈了一个错误信息&#xff1a;“Process terminated”。经过一番排查&#xff0c;发现问…

为什么资讯网站选择高防IP防护攻击

为什么资讯网站选择高防IP防护攻击&#xff1f;在当今信息爆炸的时代&#xff0c;各种资讯网站扮演着重要的角色&#xff0c;为人们提供丰富的信息资源和资讯服务。然而&#xff0c;随着网络环境的不断变化和网络安全威胁的增加&#xff0c;资讯网站面临着越来越严重的网络攻击…

【Leetcode】top 100 图论

基础知识补充 1.图分为有向图和无向图&#xff0c;有权图和无权图&#xff1b; 2.图的表示方法&#xff1a;邻接矩阵适合表示稠密图&#xff0c;邻接表适合表示稀疏图&#xff1b; 邻接矩阵&#xff1a; 邻接表&#xff1a; 基础操作补充 1.邻接矩阵&#xff1a; class GraphAd…

JavaScript class继承

如果用新的class关键字来编写Student&#xff0c;可以这样写&#xff1a; class Student { constructor(name) { this.name name; } hello() {alert(Hello, this.name !); }} class的定义包含了构造函数constructor和定义在原型对象上的函数hello()&#xff08;注意没有f…

Open3D(C++) 鲁棒损失函数优化的ICP算法

目录 一、损失函数1、关于2、损失函数3、Open3D实现二、代码实现三、结果展示1、配准前1、配准后本文由CSDN点云侠原创,

C语言----数据在内存中的存储

文章目录 前言1.整数在内存中的存储2.大小端字节序和字节序判断2.1 什么是大小端&#xff1f;2.2 练习 3.浮点数在内存中的存储3.1.引子3.2.浮点数的存储3.2.2 浮点数取的过程 前言 下面给大家介绍一下数据在内存中的存储&#xff0c;这个是一个了解c语言内部的知识点&#xf…

百问网FreeRTOS学习笔记第50到56讲

/*出处&#xff1a;https://video.100ask.net/p/t_pc/course_pc_detail/column/p_6503fadfe4b064a82f0ab191本专栏一切无特殊声明的知识转述&#xff08;源码、文字以及图表&#xff09;版权均归属于百问网&#xff0c;源码仅供学习&#xff0c;请勿用于商业用途&#xff1b;不…

GoPro相机使用的文件格式和频率

打开GoPro相机(以11为例)&#xff0c;里面是一个DCIM文件夹。 DCIM是digital camera in memory 的简写&#xff0c;即存照片的文件夹&#xff0c;常见于数码相机、手机存储卡中的文件夹名字。 正常手机拍照和视频都是保存在此文件夹的。正常建议不用删&#xff0c;因为只要拍照…

Vue3中props和emits的使用总结

Vue3中props和emits的使用介绍 1&#xff0c;看代码1.1&#xff0c;App.vue1.2&#xff0c;TodoItem.vue 2&#xff0c;总结2.1 props2.2 emits 1&#xff0c;看代码 1.1&#xff0c;App.vue <script setup> import { ref,reactive } from vue import TodoItem from ./…

用队列实现栈(C)

目录 题目&#xff1a; 解题&#xff1a; 代码讲解&#xff1a; 1.构建 2.creat 3.压栈 4.出栈 5.判空 6.释放 题目&#xff1a; 请你仅使用两个队列实现一个后入先出&#xff08;LIFO&#xff09;的栈&#xff0c;并支持普通栈的全部四种操作&#xff08;push、top、po…

【Linux学习】Linux 的虚拟化和容器化技术

˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好&#xff0c;我是xiaoxie.希望你看完之后,有不足之处请多多谅解&#xff0c;让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN 如…

如何设计一个类似Dubbo的RPC框架

首先有个注册中心,提供的服务在注册中心注册保留各个服务的信息,用zookeeper来做。然后消费者需要去注册中心拿对应的服务信息,而且每个服务可能会存在于多台机器上。接着就发起一次请求了,怎么发起?基于动态代理,面向接口获取到一个动态代理,就是接口在本地的一个代理,…

MySQL 导入库/建表时/出现乱码

问题描述&#xff1a; 新建不久的项目在使用Navicat for MySQL进行查看数据&#xff0c;发现表中注释的部分乱码&#xff0c;但是项目中获取的数据使用不会。 猜测因为是数据库编码和项目中使用的不一样&#xff0c;又因为项目的连接语句定义了需要编码&#xff0c;故项目运行…