Pytorch 文本情感分类案例

一共六个脚本,分别是:

        ①generateDictionary.py用于生成词典

        ②datasets.py定义了数据集加载的方法

        ③models.py定义了网络模型

        ④configs.py配置一些参数

        ⑤run_train.py训练模型

        ⑥run_test.py测试模型

数据集icon-default.png?t=N7T8https://download.csdn.net/download/Victor_Li_/88486959?spm=1001.2014.3001.5501停用词表icon-default.png?t=N7T8https://download.csdn.net/download/Victor_Li_/88486973?spm=1001.2014.3001.5501

generateDictionary.py如下

import jiebadata_path = "./weibo_senti_100k.csv"
data_stop_path = "./hit_stopwords.txt"
data_list = open(data_path,encoding='utf-8').readlines()[1:]
stops_word = open(data_stop_path,encoding='utf-8').readlines()
stops_word = [line.strip() for line in stops_word]
stops_word.append(" ")
stops_word.append("\n")voc_dict = {}
min_seq = 1
top_n = 1000
UNK = "UNK"
PAD = "PAD"
for item in data_list:label = item[0]content = item[2:].strip()seg_list = jieba.cut(content,cut_all=False)seg_res = []for seg_item in seg_list:if seg_item in stops_word:continueseg_res.append(seg_item)if seg_item in voc_dict.keys():voc_dict[seg_item] += 1else:voc_dict[seg_item] = 1# print(content)# print(seg_res)voc_list = sorted([_ for _ in voc_dict.items() if _[1] > min_seq],key=lambda x:x[1],reverse=True)[:top_n]voc_dict = {word_count[0]:idx for idx,word_count in enumerate(voc_list)}voc_dict.update({UNK:len(voc_dict),PAD:len(voc_dict)+1})ff = open("./dict","w")
for item in voc_dict.keys():ff.writelines("{},{}\n".format(item,voc_dict[item]))
ff.close()

datasets.py如下

from torch.utils.data import Dataset, DataLoader
import jieba
import numpy as npdef read_dict(voc_dict_path):voc_dict = {}with open(voc_dict_path, 'r') as f:for line in f:line = line.strip()if line == '':continueword, index = line.split(",")voc_dict[word] = int(index)return voc_dictdef load_data(data_path, data_stop_path,isTest):data_list = open(data_path, encoding='utf-8').readlines()[1:]stops_word = open(data_stop_path, encoding='utf-8').readlines()stops_word = [line.strip() for line in stops_word]stops_word.append(" ")stops_word.append("\n")voc_dict = {}data = []max_len_seq = 0for item in data_list:label = item[0]content = item[2:].strip()seg_list = jieba.cut(content, cut_all=False)seg_res = []for seg_item in seg_list:if seg_item in stops_word:continueseg_res.append(seg_item)if seg_item in voc_dict.keys():voc_dict[seg_item] += 1else:voc_dict[seg_item] = 1if len(seg_res) > max_len_seq:max_len_seq = len(seg_res)if isTest:data.append([label, seg_res,content])else:data.append([label, seg_res])return data, max_len_seqclass text_ClS(Dataset):def __init__(self, data_path, data_stop_path,voc_dict_path,isTest=False):self.isTest = isTestself.data_path = data_pathself.data_stop_path = data_stop_pathself.voc_dict = read_dict(voc_dict_path)self.data, self.max_len_seq = load_data(self.data_path, self.data_stop_path,isTest)np.random.shuffle(self.data)def __len__(self):return len(self.data)def __getitem__(self, item):data = self.data[item]label = int(data[0])word_list = data[1]if self.isTest:content = data[2]input_idx = []for word in word_list:if word in self.voc_dict.keys():input_idx.append(self.voc_dict[word])else:input_idx.append(self.voc_dict["UNK"])if len(input_idx) < self.max_len_seq:input_idx += [self.voc_dict["PAD"] for _ in range(self.max_len_seq - len(input_idx))]data = np.array(input_idx)if self.isTest:return label,data,contentelse:return label, datadef data_loader(dataset,config):return DataLoader(dataset,batch_size=config.batch_size,shuffle=config.is_shuffle,num_workers=4,pin_memory=True)

models.py如下

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Model(nn.Module):def __init__(self,config):super(Model,self).__init__()self.embeding = nn.Embedding(config.n_vocab,config.embed_size,padding_idx=config.n_vocab - 1)self.lstm = nn.LSTM(config.embed_size,config.hidden_size,config.num_layers,batch_first=True,bidirectional=True,dropout=config.dropout)self.maxpool = nn.MaxPool1d(config.pad_size)self.fc = nn.Linear(config.hidden_size * 2 + config.embed_size,config.num_classes)self.softmax = nn.Softmax(dim=1)def forward(self,x):embed = self.embeding(x)out, _ = self.lstm(embed)out = torch.cat((embed, out), 2)out = F.relu(out)out = out.permute(0, 2, 1)out = self.maxpool(out).reshape(out.size()[0],-1)out = self.fc(out)out = self.softmax(out)return out

configs.py如下

import torch.typesclass Config():def __init__(self):self.n_vocab = 1002self.embed_size = 256self.hidden_size = 256self.num_layers = 5self.dropout = 0.8self.num_classes = 2self.pad_size = 32self.batch_size = 32self.is_shuffle = Trueself.learning_rate = 0.001self.num_epochs = 100self.devices = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

run_train.py如下

import torch
import torch.nn as nn
from torch import optim
from models import Model
from datasets import data_loader,text_ClS
from configs import Config
import time
import torch.multiprocessing as mpif __name__ == '__main__':mp.freeze_support()cfg = Config()data_path = "./weibo_senti_100k.csv"data_stop_path = "./hit_stopwords.txt"dict_path = "./dict"dataset = text_ClS(data_path, data_stop_path, dict_path)train_dataloader = data_loader(dataset,cfg)cfg.pad_size = dataset.max_len_seqmodel_text_cls = Model(cfg)model_text_cls.to(cfg.devices)loss_func = nn.CrossEntropyLoss()optimizer = optim.Adam(model_text_cls.parameters(), lr=cfg.learning_rate)scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)for epoch in range(cfg.num_epochs):running_loss = 0correct = 0total = 0epoch_start_time = time.time()for i,(labels,datas) in enumerate(train_dataloader):datas = datas.to(cfg.devices)labels = labels.to(cfg.devices)pred = model_text_cls.forward(datas)loss_val = loss_func(pred,labels)running_loss += loss_val.item()loss_val.backward()if ((i + 1) % 4 == 0) or (i + 1 == len(train_dataloader)):optimizer.step()optimizer.zero_grad()_, predicted = torch.max(pred.data, 1)correct += (predicted == labels).sum().item()total += labels.size(0)scheduler.step()accuracy_train = 100 * correct / totalepoch_end_time = time.time()epoch_time = epoch_end_time - epoch_start_timetain_loss = running_loss / len(train_dataloader)print("Epoch [{}/{}],Time: {:.4f}s,Loss: {:.4f},Acc: {:.2f}%".format(epoch + 1, cfg.num_epochs, epoch_time, tain_loss,accuracy_train))torch.save(model_text_cls.state_dict(),"./text_cls_model/text_cls_model{}.pth".format(epoch))

run_test.py如下

import torch
import torch.nn as nn
from torch import optim
from models import Model
from datasets import data_loader,text_ClS
from configs import Config
import time
import torch.multiprocessing as mpif __name__ == '__main__':mp.freeze_support()cfg = Config()data_path = "./test.csv"data_stop_path = "./hit_stopwords.txt"dict_path = "./dict"cfg.batch_size = 1dataset = text_ClS(data_path, data_stop_path, dict_path,isTest=True)dataloader = data_loader(dataset,cfg)cfg.pad_size = dataset.max_len_seqmodel_text_cls = Model(cfg)model_text_cls.load_state_dict(torch.load('./text_cls_model/text_cls_model0.pth'))model_text_cls.to(cfg.devices)classes_name = ['负面的','正面的']for i,(label,input,content) in enumerate(dataloader):label = label.to(cfg.devices)input = input.to(cfg.devices)pred = model_text_cls.forward(input)_, predicted = torch.max(pred.data, 1)print("内容:{}, 实际结果:{}, 预测结果:{}".format(content,classes_name[label],classes_name[predicted[0]]))

测试结果如下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/125774.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

narak靶机攻略

narak靶机攻略 扫描 渗透 cewl http://10.4.7.158 > use1.txthydra -L use1.txt -P use1.txt http-get://10.4.7.158/webdav -V -t 50 -fyamdoot:Swargcadaver http://10.4.7.158/webdav<?php $ip10.4.7.158; $port12138; $sock fsockopen($ip, $port); $descriptors…

识别flink的反压源头

背景 flink中最常见的问题就是反压&#xff0c;这种情况下我们要正确的识别导致反压的真正的源头&#xff0c;本文就简单看下如何正确识别反压的源头 反压的源头 首先我们必须意识到现实中轻微的反压是没有必要去优化的&#xff0c;因为这种情况下是由于偶尔的流量峰值,Task…

JavaScript的高级概述

还记得我们刚刚开始的时候给JavaScript的定义吗&#xff1f; JavaScript是一种高级的&#xff0c;面向对象的&#xff0c;多范式变成语言&#xff01; 这种定义JavaScript只是冰山一角&#xff01; JavaScript的高级定义 JavaScript是一种高级的、基于原型的、面向对象、多范…

网络协议--TCP的保活定时器

23.1 引言 许多TCP/IP的初学者会很惊奇地发现可以没有任何数据流通过一个空闲的TCP连接。也就是说&#xff0c;如果TCP连接的双方都没有向对方发送数据&#xff0c;则在两个TCP模块之间不交换任何信息。例如&#xff0c;没有可以在其他网络协议中发现的轮询。这意味着我们可以…

Node.js的基本概念node -v 和npm -v 这两个命令的作用

Node.js 是一个开源且跨平台的 JavaScript 运行时环境&#xff0c;它可以让你在服务器端运行 JavaScript 代码。Node.js 使用了 Chrome 的 V8 JavaScript 引擎来执行代码&#xff0c;非常高效。 在 Node.js 出现之前&#xff0c;JavaScript 通常只在浏览器中运行&#xff0c;用…

第 369 场 LeetCode 周赛题解

A 找出数组中的 K-or 值 模拟 class Solution { public:int findKOr(vector<int> &nums, int k) {vector<int> cnt(32);for (auto x: nums)for (int i 0; i < 32; i)if (x >> i & 1)cnt[i];int res 0;for (int i 0; i < 32; i)if (cnt[i] &…

uniapp 开发微信小程序 v-bind给子组件传递函数,该函数中的this不是父组件的二是子组件的this

解决办法&#xff1a;子组件通过缓存子组件this然后&#xff0c;用bind改写this 这个方法因为定义了全局变量that 那么该变量就只能用一次&#xff0c;不然会有赋值覆盖的情况。 要么就弃用v-bind传入函数,改为emit传入自定义事件 [uniapp] uview(1.x) 二次封装u-navbar 导致…

ubuntu配置 Conda 更改默认环境路径

我的需求是以后凡是新建一个虚拟环境都需要安装在一个挂载了大容量的分区/data里面 /home里面的是即将爆满但是还能塞点东西的硬盘. 如果您想要永久更改 Conda 的默认环境路径&#xff0c;可以编辑 Conda 的配置文件。首先&#xff0c;找到 Conda 的配置文件通常是 .condarc 文…

如何读懂深度学习python项目,以`Multi-label learning from single positive label`为例

Paper : Multi-label learning from single positive label Code 先读一读README.md 可能有意想不到的收获&#xff1b; 实验环境设置要仔细看哦&#xff01; 读论文 如何读论文&#xff0c;Readpaper经典十问 &#xff08;可能在我博客里有写&#xff09; How to read a …

html获取网络数据,列表展示 一

html获取网络数据&#xff0c;列表展示 js遍历json数组中的json对象 image.png || - 判断数据是否为空&#xff0c;为空就显示 - <!DOCTYPE html> <html><head><meta charset"utf-8"><title>网页列表</title></head><b…

省钱兄短剧短视频视频滑动播放模块源码支持微信小程序h5安卓IOS

# 开源说明 开源省钱兄短剧系统的播放视频模块&#xff08;写了测试弄了好久才弄出来、最核心的模块、已经实战了&#xff09;&#xff0c;使用uniapp技术&#xff0c;提供学习使用&#xff0c;支持IOSAndroidH5微信小程序&#xff0c;使用Hbuilder导入即可运行 #注意&#xff…

JavaScript从入门到精通系列第二十六篇:详解JavaScript中的Math对象

大神链接&#xff1a;作者有幸结识技术大神孙哥为好友&#xff0c;获益匪浅。现在把孙哥视频分享给大家。 孙哥连接&#xff1a;孙哥个人主页 作者简介&#xff1a;一个颜值99分&#xff0c;只比孙哥差一点的程序员 本专栏简介&#xff1a;话不多说&#xff0c;让我们一起干翻J…

无人机真机搭建问题记录文档(待续)

搭建问题 问题1 高飞课程中的飞控停产&#xff0c;更换飞控&#xff08;pixhawx 6c&#xff09;出现如下问题 1、飞控太大造成安装机载电脑的碳板上的孔被挡住。 2、课程提供的飞控固件&#xff0c;与更换的飞控不匹配 解决办法 1、现在的无人机碳板上只安装三个螺纹孔&…

GPS学习(一):在ROS2中将GPS经纬度数据转换为机器人ENU坐标系,在RVIZ中显示坐标轨迹

文章目录 一、GPS模块介绍二、坐标转换转换原理参数解释&#xff1a; 增加回调函数效果演示 本文记录在Ubuntu22.04-Humbel中使用NMEA协议GPS模块的过程&#xff0c;使用国产ROS开发板鲁班猫(LubanCat )进行调试。 一、GPS模块介绍 在淘宝找了款性价比较高的轮趣科技GPS北斗双…

Linux 远程桌面软件

为您的 IT 管理员配备最好的 Linux 远程桌面软件至关重要。原因如下&#xff1f;Linux 是一个开源和免费的操作系统&#xff0c;它提供了一个非常灵活和可定制的软件内核。由于其开源性质&#xff0c;Linux 被认为是市场上最安全的操作系统之一&#xff0c;它拥有一个全球用户社…

Qt 中model/View 架构 详解,以及案例实现相薄功能

model/View 架构 导读 ​ 我们的系统需要显示大量数据,比如从数据库中读取数据,以自己的方式显示在自己的应用程序的界面中。早期的 Qt 要实现这个功能,需要定义一个组件,在这个组件中保存一个数据对象,比如一个列表。我们对这个列表进行查找、插入等的操作,或者把修改…

PHP自定义文件缓存实现

文件缓存&#xff1a;可以将PHP脚本的执行结果缓存到文件中。当一个PHP脚本被请求时&#xff0c;先查看是否存在缓存文件&#xff0c;如果存在且未过期&#xff0c;则直接读取缓存文件内容返回给客户端&#xff0c;而无需执行脚本 1、文件缓存写法一&#xff0c;每个文件缓存一…

语音信号处理给音乐信号增加房间混响效果

语音信号处理给音乐信号增加房间混响效果 是否需要申请加入数字音频系统研究开发交流答疑群(课题组)?可加我微信hezkz17, 本群提供音频技术答疑服务 1 源码布局 2 源文件与音频文件和生成文件 3 编译方法

「Qt中文教程指南」如何创建基于Qt Widget的应用程序(四)

Qt 是目前最先进、最完整的跨平台C开发工具。它不仅完全实现了一次编写&#xff0c;所有平台无差别运行&#xff0c;更提供了几乎所有开发过程中需要用到的工具。如今&#xff0c;Qt已被运用于超过70个行业、数千家企业&#xff0c;支持数百万设备及应用。 本文描述了如何使用…

点云配准--对称式ICP

对称式ICP 写在前面的话 针对于局部平面不完美的情况&#xff0c;提出了一种对称式ICP目标函数&#xff0c;相较于传统的ICP方法&#xff0c;增大了收敛域&#xff0c;提高了收敛速度。论文理论说明不甚清楚&#xff0c;实验较少&#xff0c;但代码开源。 理论 对称目标函数…