N3 中文文本分类

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊# 前言

前言

前面学习了相关自然语言编码,这周进行相关实战

导入依赖库和设置设备

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warningswarnings.filterwarnings("ignore")  # 忽略警告
# win10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

这段代码导入了必要的库并设置了设备(GPU或CPU)。

数据预处理和词汇表构建

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import AG_NEWStrain_iter = AG_NEWS(split='train')
tokenizer = get_tokenizer('basic_english')  # 返回分词器函数def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])  # 设置默认索引,如果找不到单词,则会选择默认索引

这里使用torchtext库加载AG_NEWS数据集,定义了一个分词器并构建了词汇表。

数据处理管道

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
text_pipeline('here is the an example')

定义了两个数据处理管道:text_pipeline用于将文本转化为词汇表中的索引序列,label_pipeline用于将标签转化为整数索引。

定义数据加载器

from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_label, _text) in batch:label_list.append(label_pipeline(_label))processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)  # 返回维度dim中输入元素的累计和return label_list.to(device), text_list.to(device), offsets.to(device)dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

定义了一个collate_batch函数用于将一个批次的数据整合在一起,并创建了一个数据加载器。

定义模型

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

定义了一个文本分类模型TextClassificationModel,包括初始化函数、权重初始化和前向传播函数。模型由一个嵌入层和一个线性层组成。

训练和评估函数

import timedef train(dataloader):model.train()  # 切换为训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 500start_time = time.time()for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()  # grad属性归零loss = criterion(predicted_label, label)  # 计算网络输出和真实值之间的差距,label为真实值loss.backward()  # 反向传播optimizer.step()  # 每一步自动更新total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc / total_count, train_loss / total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()  # 切换为测试模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 计算loss值total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc / total_count, train_loss / total_count

定义了训练和评估函数,用于训练模型和评估模型性能。

数据集分割和数据加载器创建

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_datasetEPOCHS = 10  # epoch
LR = 5  # 学习率
BATCH_SIZE = 64  # batch size for trainingcriterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = Nonetrain_iter, test_iter = AG_NEWS()  # 加载数据
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)split_train_, split_valid_ = random_split(train_dataset,[num_train, len(train_dataset) - num_train])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)

加载数据集并将其转换为适用于随机访问的数据集,分割训练集和验证集,并创建相应的数据加载器。

训练和验证模型

for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:1d} | time: {:4.2f}s | ''valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch,time.time() - epoch_start_time,val_acc, val_loss))print('-' * 69)

进行训练和验证,在每个epoch结束时打印验证准确率和损失,并根据验证结果调整学习率。

测试模型

print('Checking the results of test dataset.')
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

在测试集上评估模型性能并打印测试准确率。

结果

在这里插入图片描述

总结

这个案例实现了一个完整的文本分类流程,从数据预处理、模型定义到训练和评估。使用torchtext加载数据,并利用PyTorch构建和训练深度学习模型,实现了对AG_NEWS数据集的文本分类任务,达到了90.1%的精度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/27884.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

虚函数机制-动态绑定的应用

虚函数使得程序在运行的时候根据指针指向对象的类型来确定调用哪个函数。 下图中&#xff1a;都为静态绑定。因为在编译器就确定了可以调用的函数 此时当基类指针指向派生类对象时&#xff0c;因为没有virtual关键字&#xff0c;所以在编译阶段就根据指针类型确定了要指向的函…

秋招突击——第四弹——Java的SSN框架快速入门——Maven

文章目录 引言Maven分模块开发与设计分模块开发的过程 依赖管理可选依赖与排除依赖 继承与聚合聚合继承 属性和版本管理属性扩大集中管理的范围版本管理 多环境开发多环境开发 私服简介安装私服资源操作流程分析上传和下载 总结 引言 前一个部分花了太多时间&#xff0c;后续得…

【Pandas驯化-02】pd.read_csv读取中文出现error解决方法

【Pandas】驯化-02pd.read_csv读取中文出现error解决方法 本次修炼方法请往下查看 &#x1f308; 欢迎莅临我的个人主页 &#x1f448;这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合&#xff0c;智慧小天地&#xff01; &#x1f387; 相关内容文档获取 微信公众号 &…

MEMS:Lecture 17 Noise MDS

讲义 Minimum Detectable Signal (MDS) Minimum Detectable Signal&#xff08;最小可检测信号&#xff09;是指当信号-噪声比&#xff08;Signal-to-Noise Ratio, SNR&#xff09;等于1时的输入信号水平。简单来说&#xff0c;MDS 是一个系统能够分辨出信号存在的最低输入信号…

视频网站下载利器yt-dlp参数详解

yt-dlp 是一个强大的命令行工具&#xff0c;用来下载 YouTube 和其他网站上的视频和音频。它拥有丰富的参数&#xff0c;可以定制下载行为&#xff0c;满足各种需求。本文将详细介绍 yt-dlp 的参数使用。 一、基本参数 -f, –format FORMAT: 指定下载格式&#xff0c;可以用视…

mysql:1205-Lock wait timeout exceeded;try restarting transaction

1.现象 2.分析 使用下面sql在自带数据库的information_schema中查询,注意观察那些长时间开启事务又没完成的进程,然后根据进程的db、操作人、主机、事务开启时间和状态,来排查是什么情况导致的事务未完成(代码异常、执行时间超时等等);我这里是异步作业事务执行时间过长导致的 …

H5拟态个人主页

演示地址&#xff1a;科技语者个人主页 (chgskj.cn) 文末有该项目的源码~ 这张图片的效果你是不是非常想要get同款&#xff1f; 源码就是这个样子 这段HTML代码构建了一个个人主页&#xff0c;结合了CSS样式和JavaScript功能。 下面是对代码的主要组成部分的详细解释&#x…

苏姿丰回忆IBM工作经历 曾参与PS3 Cell处理器开发

AMD首席执行官苏姿丰博士曾在IBM工作了13年&#xff0c;先后担任IBM纽约半导体研发中心的副主管、研发部门主管和CEO特别助理。1998年苹果发布的iMac G3里&#xff0c;使用的PowerPC 750是首个采用铜互连技术的处理器&#xff0c;取代了铝互连技术。此前相关报道中曾提及&#…

深入理解计算机系统 CSAPP 家庭作业6.37

S256 N64时: sumA:这个很简单了,不说了 sumB:如下表. i递增时一直不命中 读到j1,i0 即读a[0][1]时 组0存放的是a[48][0] -a[48][3] 接着读a[1][1]时,组16放的是a[49][0]-a[49][3],j递增之后还是一直不命中 组0:a[0][0]a[16][0]a[32][0]a[48][0]a[0][1]组16:a[1][0]a[17][…

Windows下的zip压缩包版Mysql8.3.0数据迁移到Mysql8.4.0可以用拷贝data文件夹的方式

Windows下的zip压缩包版Mysql8.3.0数据迁移到Mysql8.4.0可以用拷贝data文件夹的方式 拷贝后, 所有账户和数据都是一样的 步骤 停止MySQL服务 net stop mysql 或 sc.exe stop mysql net stop mysqlsc.exe stop mysql卸载 Mysql8.3.0 的服务 mysqld remove 或 mysqld remove m…

idea的java代码引用proto文件报错

尝试了四种办法&#xff0c;感觉第一个和第二个比较有效。 前提是要先安装了 proto 的idea插件。 1.修改idea配置文件编译大文件的限制 proto生成的源文件有数万行&#xff0c;源文件过大导致 idea 拒绝编译过大的源文件。 解决方案&#xff1a; 如果 protoc 生成的 class 文…

C++语法05 浮点型/实数类型

什么是实数类型 实数类型是一种数据类型&#xff0c;实数类型变量里能存放小数和整数。 定义格式&#xff1a;double a; 赋值&#xff1a;a0.4; 输入&#xff1a;cin>>a; 输出&#xff1a;cout<<a; 训练&#xff1a;尺子的价格 小知在文具店买铅笔&#xff…

RIP路由协议汇总(华为)

#交换设备 RIP路由协议汇总 一、原理概述 当网络中路由器的路由条目非常多时&#xff0c;可以通过路由汇总&#xff08;又称路由汇聚或路由聚合&#xff09;来减少路由条目数&#xff0c;加快路由收敛时间和增强网络稳定性。路由汇总的原理是&#xff0c;同一个自然网段内的不…

基数和基数转换

目录 一、定义&#xff1a; 二、各个进制&#xff1a; 1、二进制&#xff1a; 2、八进制&#xff1a; 3、十进制&#xff1a; 4、十六进制&#xff1a; 三、基数转换&#xff1a; 1、各类基数转十进制&#xff1a; 二转十&#xff1a; 八转十&#xff1a; 十六转八&a…

Maven 项目的创建(导入依赖、仓库、maven的配置、配置国内源、以及可能遇到的问题)

一、创建Maven项目 使用的编译软件&#xff1a;idea 软件版本&#xff1a; 社区版 2021.1 - 2022.4&#xff08;为什么选择这个版本&#xff0c;因为只有这个版本里有一些插件是可以安装的&#xff09; 专业版不限制&#xff08;专业版功能是最全的&#xff0c;但是收费&am…

【操作与配置】Pytorch环境搭建

安装显卡驱动 显卡驱动是一种软件程序&#xff0c;用于控制显卡硬件与操作系统之间的通信和交互。显卡驱动负责向操作系统提供有关显卡硬件的信息&#xff0c;以及使操作系统能够正确地控制和管理显卡的各种功能和性能。显卡驱动还包含了针对不同应用程序和游戏的优化&#xff…

C语言入门学习系列:基本语法

目录 引言1. 标准库与头文件2. 语句3. 表达式3.1 表达式在赋值语句中3.2 表达式在控制结构中3.3 表达式作为函数参数3.4 表达式和语句的区别 4. 语句块5. 空格6. 注释7. printf() 函数7.1 基本用法7.2 占位符7.3 输出格式 引言 #include <stdio.h>int main(void) {int a…

能耗分析与远程抄表是什么?

一、引言 在21世纪的数字化时代&#xff0c;能耗分析和远程抄表已成为现代能源管理的重要组成部分。这两项技术不仅提高了能源效率&#xff0c;还为企业和个人提供了更精细的能源使用数据&#xff0c;从而实现更科学的节能减排。 二、能耗分析的深度洞察 能耗分析是通过收集…

深入理解计算机系统 CSAPP 家庭作业6.36

A:100% 数组x的大小是缓存的两倍, x[0][0]-x[0][127]刚好存满512字节,那就意味着x[1][0]映射在缓存的组0,那就意味着x[0][i]和x[1][i]总是读到缓存后又互相替换. B:25% 缓存变为1024字节,意味着x[1][0]被映射在缓存的组128 (组0到127存放x[0][0]到x[0][127]),所以每次读一行…

cs与msf权限传递,以及mimikatz抓取明文密码

cs与msf权限传递&#xff0c;以及mimikatz抓取win10明文密码 1、环境准备2、Cobalt Strike ------> MSF2.1 Cobalt Strike拿权限2.2 将CS权限传递给msf 3、MSF ------> Cobalt Strike3.1 msf拿权限3.2 将msf权限传递给CS 4、使用mimikatz抓取明文密码 1、环境准备 攻击&…