Pytorch文本分类入门

🍨 本文为[🔗365天深度学习训练营学习记录博客

🍦 参考文章:365天深度学习训练营

🍖 原作者:[K同学啊 | 接辅导、项目定制]\n🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)

一、加载数据

import os
import sys
import PIL
from PIL import Image
import time
import copy
import random
import pathlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.datasets import AG_NEWS
import torchvision
from torchinfo import summary
import torchsummary
import matplotlib.pyplot as plt
import numpy as np
import warnings''' 下载或读取AG News数据集中的训练集与测试集 '''
def getDataset(root, dataset):if not os.path.exists(root) or not os.path.isdir(root):os.makedirs(root)if not os.path.exists(dataset) or not os.path.isdir(dataset):print('Downloading dataset...\n')# 下载AG News数据集 直接运行会报网络错误 无法下载  train_ds, test_ds = AG_NEWS(root=root, split=("train", "test"))else:print('Dataset already downloaded, reading...\n')# 读取本地AG News数据集 手动下载了train.csv和test.csv后可从本地加载数据train_ds, test_ds = AG_NEWS(root=dataset, split=("train", "test"))#print("Train:", next(train_ds), len(list(train_ds))+1)#print("Test :", next(test_ds), len(list(test_ds))+1)return train_ds, test_ds''' 设置GPU '''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device\n".format(device))
''' 加载数据 '''
root = './data/'
data_dir = os.path.join(root, 'AG_NEWS.data')
train_ds, test_ds = getDataset(root, data_dir)

 运行结果:

Using cuda deviceDataset already downloaded, reading...Train: (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.") 120000
Test : (3, "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.") 7600

二、构建词典

''' 构建词典 '''
def buildDict(train_ds):tokenizer  = get_tokenizer('basic_english') # 返回分词器函数def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_ds))text_pipeline  = lambda x: vocab.lookup_indices(tokenizer(x))label_pipeline = lambda x: int(x)#print(vocab.UNK, vocab._default_unk_index())# 打印默认索引,如果找不到单词,则会选择默认索引#print(vocab.lookup_indices(['here', 'is', 'an', 'example']))#print(text_pipeline('here is the an example'))#print(label_pipeline('10'))return vocab, text_pipeline, label_pipeline# 构建词典
text_pipeline, label_pipeline = buildDict(train_ds)

运行结果: 

120001lines [00:04, 27817.88lines/s]
<unk> 0
[471, 22, 31, 5177]
[471, 22, 3, 31, 5177]
10

三、生成数据批次和迭代器

''' 加载数据,并设置batch_size '''
def loadData(train_ds, test_ds, batch_size=8, device='cpu'):# 构建词典vocab, text_pipeline, label_pipeline = buildDict(train_ds)# 生成数据批次和迭代器def collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_label, _text) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)# 偏移量,即语句的总词汇量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list  = torch.cat(text_list)offsets    = torch.tensor(offsets[:-1]).cumsum(dim=0) #返回维度dim中输入元素的累计和return label_list.to(device), text_list.to(device), offsets.to(device)# 从 train_ds 加载训练集train_dl = torch.utils.data.DataLoader(train_ds,batch_size=batch_size,shuffle=False,collate_fn=collate_batch,num_workers=0)# 从 test_ds 加载测试集test_dl  = torch.utils.data.DataLoader(test_ds,batch_size=batch_size,shuffle=False,collate_fn=collate_batch,num_workers=0)# 取一个批次查看数据格式#data = train_dl.__iter__()#print(type(data), data, '\n')return vocab, train_dl, test_dl# 生成数据批次和迭代器
batch_size = 64
train_dl, test_dl = loadData(train_ds, test_ds, batch_size=batch_size, device=device)

运行结果:

120001lines [00:04, 27749.13lines/s]
<class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'> <torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x00000266556204C0>

四、构建模型

class TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)      # 将tensor用从均匀分布中抽样得到的值填充self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)        # torch.Size([64, 64])output = self.fc(embedded)      # torch.Size([64, 4])return output
''' 定义实例 '''
train_iter = AG_NEWS(root='./data/AG_NEWS.data', split=("train"))
num_class  = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size    = 64
model      = TextClassificationModel(vocab_size, em_size, num_class).to(device)
print('num_class', num_class)
print('vocab_size', vocab_size)
print(model)
def train(dataloader):model.train()       # 训练模式total_acc, total_count = 0, 0log_interval = 500start_time = time.time()for idx, (label, text, offsets) in enumerate(dataloader):optimizer.zero_grad()predited_label = model(text, offsets)loss = criterion(predited_label, label)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)     # 规定了最大不能超过的max_normoptimizer.step()total_acc += (predited_label.argmax(1) == label).sum().item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:3d} | {:5d}/{:5d} batches, accuracy {:8.3f}'.format(epoch, idx, len(dataloader), total_acc / total_count))total_acc, total_count = 0, 0start_time = time.time()
def evaluate(dataloader):model.eval()total_acc, total_count = 0, 0with torch.no_grad():for idx, (label, text, offsets) in enumerate(dataloader):predited_label = model(text, offsets)# loss = criterion(predited_label, label)total_acc += (predited_label.argmax(1) == label).sum().item()total_count += label.size(0)return total_acc / total_count

五、拆分数据集和运行模型

if __name__ == '__main__':# 超参数(Hyperparameters)EPOCHS = 10  # epochLR = 5  # learning rateBATCH_SIZE = 64  # batch size for trainingcriterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD(model.parameters(), lr=LR)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)total_accu = Nonetrain_iter, test_iter = AG_NEWS(root=path)train_dataset = list(train_iter)test_dataset = list(test_iter)num_train = int(len(train_dataset) * 0.95)split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)      # shuffle表示随机打乱valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)accu_val = evaluate(valid_dataloader)if total_accu is not None and total_accu > accu_val:scheduler.step()else:total_accu = accu_valprint('-' * 59)print('| end of epoch {:3d} | time: {:5.2f}s | ''valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val))print('-' * 59)torch.save(model.state_dict(), 'output\\model_TextClassification.pth')
| epoch   1 |   500/ 1782 batches, accuracy    0.687
| epoch   1 |  1000/ 1782 batches, accuracy    0.856
| epoch   1 |  1500/ 1782 batches, accuracy    0.875
-----------------------------------------------------------
| end of epoch   1 | time: 23.15s | valid accuracy    0.881
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches, accuracy    0.898
| epoch   2 |  1000/ 1782 batches, accuracy    0.898
| epoch   2 |  1500/ 1782 batches, accuracy    0.903
-----------------------------------------------------------
| end of epoch   2 | time: 16.20s | valid accuracy    0.897
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches, accuracy    0.917
| epoch   3 |  1000/ 1782 batches, accuracy    0.915
| epoch   3 |  1500/ 1782 batches, accuracy    0.914
-----------------------------------------------------------
| end of epoch   3 | time: 15.98s | valid accuracy    0.902
-----------------------------------------------------------
| epoch   4 |   500/ 1782 batches, accuracy    0.924
| epoch   4 |  1000/ 1782 batches, accuracy    0.924
| epoch   4 |  1500/ 1782 batches, accuracy    0.922
-----------------------------------------------------------
| end of epoch   4 | time: 16.63s | valid accuracy    0.901
-----------------------------------------------------------
| epoch   5 |   500/ 1782 batches, accuracy    0.937
| epoch   5 |  1000/ 1782 batches, accuracy    0.937
| epoch   5 |  1500/ 1782 batches, accuracy    0.938
-----------------------------------------------------------
| end of epoch   5 | time: 16.37s | valid accuracy    0.912
-----------------------------------------------------------
| epoch   6 |   500/ 1782 batches, accuracy    0.938
| epoch   6 |  1000/ 1782 batches, accuracy    0.939
| epoch   6 |  1500/ 1782 batches, accuracy    0.940
-----------------------------------------------------------
| end of epoch   6 | time: 16.17s | valid accuracy    0.912
-----------------------------------------------------------
| epoch   7 |   500/ 1782 batches, accuracy    0.940
| epoch   7 |  1000/ 1782 batches, accuracy    0.938
| epoch   7 |  1500/ 1782 batches, accuracy    0.943
-----------------------------------------------------------
| end of epoch   7 | time: 16.20s | valid accuracy    0.911
-----------------------------------------------------------
| epoch   8 |   500/ 1782 batches, accuracy    0.941
| epoch   8 |  1000/ 1782 batches, accuracy    0.940
| epoch   8 |  1500/ 1782 batches, accuracy    0.942
-----------------------------------------------------------
| end of epoch   8 | time: 16.46s | valid accuracy    0.911
-----------------------------------------------------------
| epoch   9 |   500/ 1782 batches, accuracy    0.941
| epoch   9 |  1000/ 1782 batches, accuracy    0.941
| epoch   9 |  1500/ 1782 batches, accuracy    0.943
-----------------------------------------------------------
| end of epoch   9 | time: 17.50s | valid accuracy    0.912
-----------------------------------------------------------
| epoch  10 |   500/ 1782 batches, accuracy    0.940
| epoch  10 |  1000/ 1782 batches, accuracy    0.942
| epoch  10 |  1500/ 1782 batches, accuracy    0.942
-----------------------------------------------------------
| end of epoch  10 | time: 16.12s | valid accuracy    0.912
-----------------------------------------------------------

实验目的

  • 构建一个文本分类模型,用于对AG News数据集中的新闻文章进行分类。

数据集

  • 使用的是AG News数据集,包括新闻文章及其相应类别标签。
  • 数据集被分为训练集和测试集。

数据预处理

  • 构建了一个词典(vocab),用于将文本转换为数字表示。
  • 定义了文本和标签的处理流程(text_pipelinelabel_pipeline)。

模型构建

  • 使用了EmbeddingBagLinear层构建了一个简单的文本分类模型。
  • 模型包含词嵌入层,将文本转换为固定大小的向量,随后通过一个全连接层进行分类。

训练过程

  • 使用交叉熵损失函数(CrossEntropyLoss)和随机梯度下降优化器(SGD)。
  • 实现了训练(train)和评估(evaluate)函数。
  • 训练了10个epoch,每个epoch结束后在验证集上评估模型。

结果和调优

  • 在训练过程中,如果验证集上的准确率没有提升,则减小学习率。
  • 每个epoch结束后打印了时间和验证集上的准确率。
  • 最终模型被保存为model_TextClassification.pth

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/635328.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络(第六版)复习提纲5

SS2.2 有关信道的几个基本概念 2.通信模型 三个主要部分&#xff1a;信源、信道、信宿 3.通信方式: a)术语&#xff1a;消息&#xff08;传递的内容&#xff09;、数据&#xff08;传递的形式&#xff09;、信号&#xff08;数据表现形式&#xff0c;有模拟信号和数字信号两种&…

前端打同一个包可以从测试晋升到生产的配置方案

前端打同一个包从测试晋升到生产环境的方案&#xff0c;是一种高效、可靠且易于维护的部署方式。在这种方案中&#xff0c;前端代码在开发完成后&#xff0c;经过测试验证无误后&#xff0c;可以直接打包部署到生产环境&#xff0c;无需进行额外的配置或修改。这样可以减少部署…

面试题:40亿个QQ号,限制1G内存,如何去重?

文章目录 概要什么是BitMap&#xff1f;有什么用&#xff1f;什么是布隆过滤器&#xff0c;实现原理是什么&#xff1f;应用场景如何使用 概要 40亿个unsigned int&#xff0c;如果直接用内存存储的话&#xff0c;需要&#xff1a; 4*4000000000 /1024/1024/1024 14.9G &…

关于datagrip的一个错误。Unexpected update count received (Actual: 3, Expected: 1).

这一行原本的值是<null><null><null>,现在我们把它修改为1,114&#xff0c;无名氏&#xff0c;但却报错。 这是对应的sql语句&#xff0c;原因在于有三行全为 <null><null><null>&#xff0c;where无法指定是哪一行&#xff0c;所以看起来…

Vue和React的区别 | | React函数式写法和类写法的区别

Vue 和 React 都是流行的前端框架&#xff0c;它们各自有着独特的特点和适用场景。在这篇文章中&#xff0c;我们将探讨它们的区别&#xff0c;并且给出一些代码实例和解释。 Vue 和 React 的区别: 模板语法与 JSX: 在 Vue 中&#xff0c;我们使用模板语法&#xff0c;它类似…

科普大语言模型中的Embedding技术

什么是大语言模型&#xff1f; 大语言模型是指使用大量的文本数据来训练的深度神经网络&#xff0c;它们可以学习语言的规律和知识&#xff0c;并且可以生成自然的文本。大语言模型的代表有GPT-3、BERT、XLNet等&#xff0c;它们在各种自然语言处理任务中都取得了很好的效果&a…

工程师职称评审的流程

职称评审是对专业技术人员的专业考核评级&#xff0c;通过公平、工作的评审工作选拔优秀且专业的人才。职称评审的流程通常包括以下几个步骤&#xff1a; 公告评审标准和要求&#xff1a;评审机构根据不同行业、专业和职业领域的要求&#xff0c;制定相应的评审标准和要求&…

Visual Studio中,每次新建文件都会自动出现提前设置好的头文件配置方法

主要是修改 newcfile.cpp 文件&#xff0c;可以用everything或者Listary等软件直接搜索文件&#xff0c;直接跳到第4步 1.图标右击——>打开文件所在位置 2.到达IDE地址后在当前目录下找VC文件夹 3.再找 VCProjectItems 文件夹——newcfile.cpp文件 4.用记事本打开&#xff…

市场复盘总结 20240119

仅用于记录当天的市场情况&#xff0c;用于统计交易策略的适用情况&#xff0c;以便程序回测 短线核心&#xff1a;不参与任何级别的调整&#xff0c;采用龙空龙模式 昨日主题投资 连板进级率 11/39 28.2% 二进三&#xff1a; 进级率低 43% 最常用的二种方法&#xff1a; 方…

AWS 专题学习 P5 (Classic SA、S3)

文章目录 Classic Solutions Architecture无状态 Web 应用程序&#xff1a;WhatIsTheTime.com背景 & 目标架构演进Well-Architected 5 pillars 有状态的 Web 应用程序&#xff1a;MyClothes.com背景 & 目标架构演进总结 有状态的 Web 应用程序&#xff1a;MyWordPress.…

springMvc的Aop解析并修改参数

在前后端接口开发过程中&#xff0c;我们常常需要对某些字段进行加解密。以下是使用Aop对接口的get参数做修改的过程&#xff1a; 自定义注解 AesMethod&#xff1a;只能用于方法 Retention(RetentionPolicy.RUNTIME) Target(ElementType.METHOD) public interface AesMetho…

安捷伦E8361C 网络分析仪67GHz

安捷伦E8361C 网络分析仪 E8361C 是 Agilent 的 67 GHz 网络分析仪。网络分析仪是一种功能强大的仪器&#xff0c;可以以无与伦比的精度测量射频设备的线性特性。许多行业使用网络分析仪来测试设备、测量材料和监控信号的完整性。附加功能&#xff1a; 10 MHz 至 67 GHz 94 dB…

强缓存、协商缓存(浏览器的缓存机制)是么子?

文章目录 一.为什么要用强缓存和协商缓存&#xff1f;二.什么是强缓存&#xff1f;三.什么是协商缓存&#xff1f;四.总结 一.为什么要用强缓存和协商缓存&#xff1f; 为了减少资源请求次数&#xff0c;加快资源访问速度&#xff0c;浏览器会对资源文件如图片、css文件、js文…

vue3-侦听器

侦听器 计算属性允许我们声明性地计算衍生值。 需求在状态变化时进行一些操作&#xff0c;比如更改 Dom,根据异步操作结果去修改另外的数据状态。 watch 监听异步请求结果 <script lang"ts" setup> import { ref, watch } from "vue"const ques…

unity 编辑器开发一些记录(遇到了更新)

1、封装Toggle组件 在用toggle等会状态改变的组件时&#xff0c;通过select GUILayout.Toggle(select, text, options)通常是这样做&#xff0c;但是往往有些复杂编辑器需求&#xff0c;当select变化时需要进行复杂的计算&#xff0c;所以不希望每帧去计算select应该的信息。…

虹科分享 | 汽车技术的未来:Netropy如何测试和确保汽车以太网的性能

文章速览&#xff1a; 什么是汽车以太网&#xff1f;汽车以太网的用途是什么&#xff1f;汽车以太网的测试要求是什么&#xff1f;流量生成如何帮助测试汽车以太网&#xff1f; 如今汽车不再是单纯的代步工具&#xff0c;把人从A点带到B点&#xff0c;同时还配备了车载信息娱乐…

java打包及上传到私服务

一、准备Maven私服Nexus 添加saas.maven 仓库地址&#xff1a;http://192.168.31.109:8081/repository/saas.maven 二、新建SpringBoot项目com.saas.pdf 添加类&#xff1a;PdfUtil.java package com.saas.pdf;public class PdfUtil {public static void Save(String fileP…

Qt之使用图片填充QLabel

文章目录 前言实现步骤 前言 本文记录一下使用 QLabel 实现在我们设计的 ui 界面上显示指定的图片&#xff0c;即使用 label 插入图片。 实现步骤 1、右键项目&#xff0c;选择 Add New 2、在弹出对话框中选择“Qt Resource File” 3、命名 qrc 文件并选择添加的文件路径。…

springboot3.2+jdk21 虚拟线程 使用MDC traceId追踪日志

springboot3.2发布了&#xff0c;配合jdk21使用虚拟线程&#xff0c;使用MDC traceId追踪日志方法 关于虚拟线程和MDC traceId这里就不多说了&#xff0c;如果不清楚请自行查询资料 第一步&#xff0c;创建MdcVirtualThreadTaskExecutor /*** author xxley* date 2022/7/25 …

Qt QCustomPlot 绘制子轴

抄大神杰作&#xff1a;QCustomplot&#xff08;五&#xff09;QCPAxisRect进行子绘图-CSDN博客 需求来源&#xff1a;试验数据需要多轴对比。 实现多Y轴、单X轴、X轴是时间轴、X轴range联动、rect之间的间距是0&#xff0c;每个图上有legend(这里有个疑问&#xff0c;每添加…