深度学习笔记39_Pytorch文本分类入门

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

一、我的环境

1.语言环境:Python 3.8

2.编译器:Pycharm

3.深度学习环境:

  • torch==1.12.1+cu113
  • torchvision==0.13.1+cu113

、导入数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")             #忽略警告信息
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(split='train')      # 加载 AG News 数据集

、构建词典

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iteratortokenizer  = get_tokenizer('basic_english') # 返回分词器函数def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # 设置默认索引,如果找不到单词,则会选择默认索引
print(vocab(['here', 'is', 'an', 'example']))

结果: [475, 21, 30, 5297]

text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
print(text_pipeline('here is the an example'))
结果:[475, 21, 2, 30, 5297]
print(label_pipeline('10'))
结果:10

生成数据批次和迭代器

from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_label, _text) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)# 偏移量,即语句的总词汇量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list  = torch.cat(text_list)offsets    = torch.tensor(offsets[:-1]).cumsum(dim=0) #返回维度dim中输入元素的累计和return label_list.to(device), text_list.to(device), offsets.to(device)# 数据加载器
dataloader = DataLoader(train_iter,batch_size=8,shuffle   =False,collate_fn=collate_batch)

定义模型

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,   # 词典大小embed_dim,    # 嵌入的维度sparse=False) # self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)

定义实例

num_class  = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size     = 64
model      = TextClassificationModel(vocab_size, em_size, num_class).to(device)

定义训练函数与评估函数

import timedef train(dataloader):model.train()  # 切换为训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 500start_time   = time.time()for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()                    # grad属性归零loss = criterion(predicted_label, label) # 计算网络输出和真实值之间的差距,label为真实值loss.backward()                          # 反向传播optimizer.step()  # 每一步自动更新# 记录acc与losstotal_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()  # 切换为测试模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 计算loss值# 记录测试数据total_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

 结果:

| epoch 1 | 500/1782 batches| train_acc 0.901 train_loss 0.00458
| epoch 1 | 1000/1782 batches| train_acc 0.905 train_loss 0.00438
| epoch 1 | 1500/1782 batches| train_acc 0.908 train_loss 0.00437
---------------------------------------------------------------------
| epoch 1 | time:6.30s |valid_acc 0.907 | valid_loss 0.004
---------------------------------------------------------------------
| epoch 2 | 500/1782 batches| train_acc 0.917 train_loss 0.00381
| epoch 2 | 1000/1782 batches| train_acc 0.917 train_loss 0.00383
| epoch 2 | 1500/1782 batches| train_acc 0.917 train_loss 0.00386
---------------------------------------------------------------------
| epoch 2 | time:6.26s |valid_acc 0.911 | valid_loss 0.004
---------------------------------------------------------------------
| epoch 3 | 500/1782 batches| train_acc 0.929 train_loss 0.00330
| epoch 3 | 1000/1782 batches| train_acc 0.927 train_loss 0.00340
| epoch 3 | 1500/1782 batches| train_acc 0.923 train_loss 0.00354
---------------------------------------------------------------------
| epoch 3 | time:6.21s |valid_acc 0.935 | valid_loss 0.003
---------------------------------------------------------------------
| epoch 4 | 500/1782 batches| train_acc 0.933 train_loss 0.00306
| epoch 4 | 1000/1782 batches| train_acc 0.932 train_loss 0.00311
| epoch 4 | 1500/1782 batches| train_acc 0.929 train_loss 0.00318
---------------------------------------------------------------------
| epoch 4 | time:6.22s |valid_acc 0.916 | valid_loss 0.003
---------------------------------------------------------------------
| epoch 5 | 500/1782 batches| train_acc 0.948 train_loss 0.00253
| epoch 5 | 1000/1782 batches| train_acc 0.949 train_loss 0.00242
| epoch 5 | 1500/1782 batches| train_acc 0.951 train_loss 0.00238
---------------------------------------------------------------------
| epoch 5 | time:6.23s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------
| epoch 6 | 500/1782 batches| train_acc 0.951 train_loss 0.00241
| epoch 6 | 1000/1782 batches| train_acc 0.952 train_loss 0.00236
| epoch 6 | 1500/1782 batches| train_acc 0.952 train_loss 0.00235
---------------------------------------------------------------------
| epoch 6 | time:6.26s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------
| epoch 7 | 500/1782 batches| train_acc 0.954 train_loss 0.00228
| epoch 7 | 1000/1782 batches| train_acc 0.951 train_loss 0.00238
| epoch 7 | 1500/1782 batches| train_acc 0.954 train_loss 0.00228
---------------------------------------------------------------------
| epoch 7 | time:6.26s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------
| epoch 8 | 500/1782 batches| train_acc 0.953 train_loss 0.00227
| epoch 8 | 1000/1782 batches| train_acc 0.955 train_loss 0.00224
| epoch 8 | 1500/1782 batches| train_acc 0.954 train_loss 0.00224
---------------------------------------------------------------------
| epoch 8 | time:6.32s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------
| epoch 9 | 500/1782 batches| train_acc 0.955 train_loss 0.00218
| epoch 9 | 1000/1782 batches| train_acc 0.953 train_loss 0.00227
| epoch 9 | 1500/1782 batches| train_acc 0.955 train_loss 0.00227
---------------------------------------------------------------------
| epoch 9 | time:6.24s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------
| epoch 10 | 500/1782 batches| train_acc 0.952 train_loss 0.00229
| epoch 10 | 1000/1782 batches| train_acc 0.955 train_loss 0.00220
| epoch 10 | 1500/1782 batches| train_acc 0.956 train_loss 0.00220
---------------------------------------------------------------------
| epoch 10 | time:6.29s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------

定义训练函数与评估函数

print('Checking the results of test dataset.')
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))
Checking the results of test dataset.
test accuracy    0.905

总结: 

  1. 预训练词向量:使用GloVe、FastText等预训练词向量能显著提升性能

  2. 正则化:合理使用dropout、权重衰减等技术防止过拟合

  3. 超参数调优:学习率、批大小、隐藏层维度等对模型性能影响很大

  4. 迁移学习:对于小数据集,考虑使用BERT等预训练模型进行微调

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/76217.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二分查找-LeetCode

题目 给定一个 n 个元素有序的&#xff08;升序&#xff09;整型数组 nums 和一个目标值 target&#xff0c;写一个函数搜索 nums 中的 target&#xff0c;如果目标值存在返回下标&#xff0c;否则返回 -1。 示例 1: 输入: nums [-1,0,3,5,9,12], target 9 输出: 4 解释: …

从 Ext 到 F2FS,Linux 文件系统与存储技术全面解析

与 Windows 和 macOS 操作系统不同&#xff0c;Linux 是由爱好者社区开发的大型开源项目。它的代码始终可供那些想要做出贡献的人使用&#xff0c;任何人都可以根据个人需求自由调整它&#xff0c;或在其基础上创建自己的发行版本。这就是为什么 Linux 存在如此多的变体&#x…

leetcode:3210. 找出加密后的字符串(python3解法)

难度&#xff1a;简单 给你一个字符串 s 和一个整数 k。请你使用以下算法加密字符串&#xff1a; 对于字符串 s 中的每个字符 c&#xff0c;用字符串中 c 后面的第 k 个字符替换 c&#xff08;以循环方式&#xff09;。 返回加密后的字符串。 示例 1&#xff1a; 输入&#xff…

JVM详解(曼波脑图版)

(✪ω✪)&#xff89; 好哒&#xff01;曼波会用最可爱的比喻给小白同学讲解JVM&#xff0c;准备好开启奇妙旅程了吗&#xff1f;(๑˃̵ᴗ˂̵)و &#x1f4cc; 思维导图 ━━━━━━━━━━━━━━━━━━━ &#x1f34e; JVM是什么&#xff1f;&#xff08;苹果式比…

ZStack文档DevOps平台建设实践

&#xff08;一&#xff09;前言 对于软件产品而言&#xff0c;文档是不可或缺的一环。文档能帮助用户快速了解并使用软件&#xff0c;包括不限于特性概览、用户手册、API手册、安装部署以及场景实践教程等。由于软件与文档紧密耦合&#xff0c;面对业务的瞬息万变以及软件的飞…

Git创建分支操作指南

1. 创建新分支但不切换&#xff08;仅创建&#xff09; git branch <分支名>示例&#xff1a;创建一个名为 new-feature 的分支git branch new-feature2. 创建分支并立即切换到该分支 git checkout -b <分支名> # 传统方式 # 或 git switch -c <分支名&g…

package.json 中的那些版本数字前面的符号是什么意思?

1. 语义化版本&#xff08;SemVer&#xff09; 语义化版本的格式是 MAJOR.MINOR.PATCH&#xff0c;其中&#xff1a; MAJOR&#xff1a;主版本号&#xff0c;表示不兼容的 API 修改。MINOR&#xff1a;次版本号&#xff0c;表示新增功能但保持向后兼容。PATCH&#xff1a;修订号…

如何有效防止服务器被攻击

首先&#xff0c;我们要明白服务器被攻击的危害有多大。据不完全统计&#xff0c;每年因服务器遭受攻击而导致的经济损失高达数十亿。这可不是一个小数目&#xff0c;就好比您辛苦积攒的财富&#xff0c;瞬间被人偷走了一大半。 要有效防止服务器被攻击&#xff0c;第一步就是…

Chainlit 快速构建Python LLM应用程序

背景 chainlit 是一款简单易用的Web UI goggle&#xff0c;它支持使用 Python 语言快速构建 LLM 应用程序&#xff0c;提供了丰富的功能&#xff0c;包括文本分析&#xff0c;情感分析等。 这里我们以官网openai提供的例子&#xff0c;快速的开发一个带有UI的聊天界面&#xf…

华为OD机试真题——硬件产品销售方案(2025A卷:100分)Java/python/JavaScript/C++/C语言/GO六种最佳实现

2025 A卷 100分 题型 本文涵盖详细的问题分析、解题思路、代码实现、代码详解、测试用例以及综合分析&#xff1b; 并提供Java、python、JavaScript、C、C语言、GO六种语言的最佳实现方式&#xff01; 2025华为OD真题目录全流程解析/备考攻略/经验分享 华为OD机试真题《硬件产品…

【数据结构_6】双向链表的实现

一、实现MyDLinkedList&#xff08;双向链表&#xff09; package LinkedList;public class MyDLinkedList {//首先我们要创建节点&#xff08;因为双向链表和单向链表的节点不一样&#xff01;&#xff01;&#xff09;static class Node{public String val;public Node prev…

做Data+AI的长期主义者,加速全球化战略布局

在Data与AI深度融合的新纪元&#xff0c;唯有秉持长期主义方能真正释放数智化的深层价值。2025年是人工智能从技术爆发转向规模化落地的关键节点&#xff0c;也是标志着袋鼠云即将迎来十周年的重要里程碑。2025年4月16日&#xff0c;袋鼠云成功举办了“做DataAI的长期主义者——…

构建基于PHP和MySQL的解梦系统:设计与实现

引言 梦境解析一直是人类心理学和文化研究的重要领域。随着互联网技术的发展,构建一个在线的解梦系统能够帮助更多人理解自己梦境的含义。本文将详细介绍如何使用PHP和MySQL构建一个功能完整的解梦系统,包括系统架构设计、数据库模型、核心功能实现以及优化策略。 本文源码下…

【桌面】【系统应用】Samba共享文件夹

目录 场景一&#xff1a;银河麒麟桌面与银河麒麟桌面之间共享文件夹 环境准备 实现目标 操作步骤 &#xff08;一&#xff09;配置主机A共享文件夹 1、环境准备 2、在主机A创建共享文件夹 3、设置共享文件密码 &#xff08;二&#xff09;主机B访问主机A 场景二&…

OpenCV 图形API(37)图像滤波-----分离过滤器函数sepFilter()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 应用一个可分离的线性滤波器到一个矩阵&#xff08;图像&#xff09;。 该函数对矩阵应用一个可分离的线性滤波器。也就是说&#xff0c;首先&a…

webpack理解与使用

一、背景 webpack的最初目标是实现前端工程的模块化&#xff0c;旨在更高效的管理和维护项目中的每一个资源。 最早的时候&#xff0c;我们通过文件划分的方式实现模块化&#xff0c;也就是将每个功能及其相关状态数据都放在一个JS文件中&#xff0c;约定每个文件就是一个独立…

rac环境下,增加一个控制文件controlfile

先关闭节点二&#xff0c;在节点一上操作 1、查看控制文件个数和路径 SQL> show parameter control 2、备份参数文件 SQL> create pfile/home/oracle/orcl.pfile20250417 from spfile; 3、修改控制文件参数 SQL> alter system set contr…

git安装(windows)

通过网盘分享的文件&#xff1a;资料(1) 链接: https://pan.baidu.com/s/1MAenYzcQ436MlKbIYQidoQ 提取码: evu6 点击next 可修改安装路径 默认就行 一般从命令行调用&#xff0c;所以不用创建。 用vscode&#xff0c;所以这么选择。

Spring Boot整合难点?AI一键生成全流程解决方案

在当今的软件开发领域&#xff0c;Spring Boot 凭借其简化开发流程、快速搭建项目的优势&#xff0c;成为了众多开发者的首选框架。然而&#xff0c;Spring Boot 的整合过程并非一帆风顺&#xff0c;常常会遇到各种难点。而飞算 JavaAI 的出现&#xff0c;为解决这些问题提供了…

Python批量处理PDF图片详解(插入、压缩、提取、替换、分页、旋转、删除)

目录 一、概述 二、 使用工具 三、Python 在 PDF 中插入图片 3.1 插入图片到现有PDF 3.2 插入图片到新建PDF 3.3 批量插入多张图片到PDF 四、Python 提取 PDF 图片及其元数据 五、Python 替换 PDF 图片 5.1 使用图片替换图片 5.2 使用文字替换图片 六、Python 实现 …