基于LSTM的文本分类1——模型搭建

源码

# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Config(object):"""配置参数类,用于存储模型和训练的超参数"""def __init__(self, dataset, embedding):self.model_name = 'TextRNN'  # 模型名称self.train_path = dataset + '/data/train.txt'  # 训练集路径self.dev_path = dataset + '/data/dev.txt'      # 验证集路径self.test_path = dataset + '/data/test.txt'    # 测试集路径self.class_list = [x.strip() for x in open(dataset + '/data/class.txt').readlines()]  # 类别列表self.vocab_path = dataset + '/data/vocab.pkl'  # 词表路径self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # 模型保存路径self.log_path = dataset + '/log/' + self.model_name  # 日志保存路径# 加载预训练词向量(若提供)self.embedding_pretrained = torch.tensor(np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32')) \if embedding != 'random' else Noneself.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 训练设备# 模型超参数self.dropout = 0.5              # 随机失活率self.require_improvement = 1000 # 若超过该batch数效果未提升,则提前终止训练self.num_classes = len(self.class_list)  # 类别数self.n_vocab = 0                # 词表大小(运行时赋值)self.num_epochs = 10            # 训练轮次self.batch_size = 128           # 批次大小self.pad_size = 32              # 句子填充/截断长度self.learning_rate = 1e-3       # 学习率# 词向量维度(使用预训练时与预训练维度对齐,否则设为300)self.embed = self.embedding_pretrained.size(1) \if self.embedding_pretrained is not None else 300self.hidden_size = 128          # LSTM隐藏层维度self.num_layers = 2             # LSTM层数'''基于LSTM的文本分类模型'''
class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()# 词嵌入层:加载预训练词向量或随机初始化if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)# 双向LSTM层self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,bidirectional=True, batch_first=True, dropout=config.dropout)# 全连接分类层self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)  # 双向LSTM输出维度翻倍def forward(self, x):x, _ = x  # 输入x为(padded_seq, seq_len),此处取padded_seqout = self.embedding(x)  # [batch_size, seq_len, embed_dim]out, _ = self.lstm(out)  # LSTM输出维度 [batch_size, seq_len, hidden_size*2]# 取最后一个时间步的输出作为句子表示out = self.fc(out[:, -1, :])  # [batch_size, num_classes]return out

数据集

上图是我们这次做的文本分类。一共十个话题领域,我们的目标是输入一句话,模型能够实现对话题领域的区分。

上图是我们使用的数据集。前面的汉字部分是模型学习的文本,后面接一个tab键是对该文本的分类。

配置类

配置的重点是模型的超参数,这里分析一下模型涉及的超参数。

Dropout随机失活率

self.dropout = 0.5

在LSTM层之间随机屏蔽部分神经元输出,强迫模型学习冗余特征表示。公式:hdrop=h⊙mhdrop​=h⊙m,其中mm是伯努利分布的0-1掩码。

早停阈值

elf.require_improvement = 1000

早停阈值的思想是:连续N个batch在验证集无精度提升则终止训练。首次训练数据的时候可能摸不清楚情况,设置了较大的epoch值,浪费掉大量训练时间。假设batch_size=128,数据集1万样本 , 每个epoch大约有78个batch。1000个batch的耐心期大约是13个epoch。

序列填充长度

self.pad_size = 32

序列填充长度的作用是,将变长文本序列处理为固定长度,满足神经网络批量处理的要求 。如果文本长度小于32,则填充特定的字符。如果文本长度大于32,则进行截断,保留32个字符。

序列填充长度通常使用95分位方式获得,获取代码如下

import numpy as np
lengths = [len(text.split()) for text in train_texts]
pad_size = int(np.percentile(lengths, 95))  # 覆盖95%样本

词向量维度

self.embed = 300

词向量的维度决定了语义空间的自由度 。假设我们使用字分割,每个文字对应一个300维的向量,将向量输入到模型中完成训练。

可以得出,向量维数越多,可以包含的信息数量就越多。但是并不是维度越高越好,下面的表是高维和低维的对比。

因子低维(d=50)高维(d=1024)
语义区分度相似词易混淆可学习细粒度差异
计算复杂度O(Vd) 内存占用低GPU显存需求高
训练数据需求1M+ tokens即可需100M+ tokens
下游任务适配性适合简单分类任务适合语义匹配任务

由于我们的数据量较小,所以使用较低的词向量维度。另外,如果使用预训练模型,词向量维度的值需要和预训练模型的值相同。

LSTM隐藏层维度

self.hidden_size = 128

隐藏层维度先卖个关子,下一章LSTM模型解析的时候讲。

模型搭建

Input Text → Embedding Layer → Bidirectional LSTM → Last Timestep Output → FC Layer → Class Probabilities

文本是无法直接被计算机识别的,所以文本需要映射为稠密向量才能输入给模型。因此在输入模型前要加一步向量映射。

class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()# 词嵌入层:加载预训练词向量或随机初始化if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)# 双向LSTM层self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,bidirectional=True, batch_first=True, dropout=config.dropout)# 全连接分类层self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)  # 双向LSTM输出维度翻倍def forward(self, x):x, _ = x  # 输入x为(padded_seq, seq_len),此处取padded_seqout = self.embedding(x)  # [batch_size, seq_len, embed_dim]out, _ = self.lstm(out)  # LSTM输出维度 [batch_size, seq_len, hidden_size*2]# 取最后一个时间步的输出作为句子表示out = self.fc(out[:, -1, :])  # [batch_size, num_classes]return out

词嵌入层

首先构建词嵌入层,将本地的预训练embedding加载到pytorch里面。

双向LSTM层

我们使用双向LSTM模型,即将文本从左到右训练一次,也从右到左(倒着来)训练一次。

参数名作用说明典型值
input_size输入特征维度(等于嵌入维度)300
hidden_size隐藏层维度128/256
num_layersLSTM堆叠层数2-4
bidirectional启用双向LSTMTrue
batch_first输入输出使用(batch, seq, *)格式True
dropout层间dropout概率(仅当num_layers>1时生效)0.5

全连接分类层

self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

 全连接的输入通道数是隐藏层维度的两倍,原因是我们的模型是双向的,双向的结果都需要输出给全连接层。

前向传播

def forward(self, x):x, _ = x  # 解包(padded_seq, seq_len)out = self.embedding(x)  # [batch, seq_len, embed_dim]out, _ = self.lstm(out)  # [batch, seq_len, 2*hidden_size]out = self.fc(out[:, -1, :])  # 取最后时刻的输出return out

首先提取输入x的填充张量。可以看到张量里有4760这种值,这个值是我们在文字长度不够时的填充内容。

 经过embedding映射后可以看到,张量out里的数据变成128*32*300的维度,300的维度就是词向量维度,可以看到data里的数据都由原来的整数映射成了向量。

经过lstm运算后,out张量数据变成了128*32*128的维度

 最终经过全连接层,out张量变成了128*10维度的张量。128是batch_size,10个维度即代表该条数据在10个分类中的概率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/74184.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小了 60,500 倍,但更强;AI 的“深度诅咒”

作者:Ignacio de Gregorio 图片来自 Unsplash 的 Bahnijit Barman 几周前,我们看到 Anthropic 尝试训练 Claude 去通关宝可梦。模型是有点进展,但离真正通关还差得远。 但现在,一个独立的小团队用一个只有一千万参数的模型通关了…

nextjs使用02

并行路由 同一个页面,放多个路由,, 目录前面加,layout中可以当作插槽引入 import React from "react";function layout({children,notifications,user}:{children:React.ReactNode,notifications:React.ReactNode,user:React.Re…

github 无法在shell里链接

当我在shell端git push时,我发现总是22 timeout的问题。 我就进行了以下步骤的尝试并最终得到了解决。 第一步,我先确定我可以curl github,也就是我网络没问题 curl -v https://github.com 如果这个时候不超时和报错,说明网络…

当前主流的大模型知识库软件对比分析

以下是当前主流的大模型知识库软件对比分析,涵盖功能特性、适用场景及优劣势,结合最新技术动态和行业实践提供深度选型参考: 一、企业级智能知识库平台 1. 阿里云百炼(Model Studio) 核心能力:基于RAG技…

Java的比较器 Comparable 和 Comparator

在 Java 中,Comparable 和 Comparator 是用于对象排序的重要接口。它们提供了不同的排序方式,适用于不同的需求,同时在 Java 底层排序算法中发挥着关键作用。本文将从基础概念、使用方法、排序实现(包括升序、降序)、底…

基于Qlearning强化学习的太赫兹信道信号检测与识别matlab仿真

目录 1.算法仿真效果 2.算法涉及理论知识概要 2.1 太赫兹信道特性 2.2 Q-learning强化学习基础 2.3 基于Q-learning 的太赫兹信道信号检测与识别系统 3.MATLAB核心程序 4.完整算法代码文件获得 1.算法仿真效果 matlab2024b仿真结果如下(完整代码运行后无水印…

力扣刷题————199.二叉树的右视图

给定一个二叉树的 根节点 root,想象自己站在它的右侧,按照从顶部到底部的顺序,返回从右侧所能看到的节点值。 示例 1: 输入:root [1,2,3,null,5,null,4] 输出:[1,3,4] 解题思路:我们可以想到这…

文件包含漏洞的小点总结

文件本地与远程包含: 文件包含有本地包含与远程包含的区别:本地包含只能包含服务器已经有的问题; 远程包含可以包含一切网络上的文件。 本地包含: ①无限制 感受一下使用phpstudy的文件上传,开启phpstudy的apache…

深度学习处理时间序列(5)

Keras中的循环层 上面的NumPy简单实现对应一个实际的Keras层—SimpleRNN层。不过,二者有一点小区别:SimpleRNN层能够像其他Keras层一样处理序列批量,而不是像NumPy示例中的那样只能处理单个序列。也就是说,它接收形状为(batch_si…

操作系统相关知识点

操作系统在进行线程切换时需要进行哪些动作? 保存当前线程的上下文 保存寄存器状态、保存栈信息。 调度器选择下一个线程 调度算法决策:根据策略(如轮转、优先级、公平共享)从就绪队列选择目标线程。 处理优先级:实时…

从0到1:Rust 如何用 FFmpeg 和 OpenGL 打造硬核视频特效

引言:视频特效开发的痛点,你中了几个? 视频特效如今无处不在:短视频平台的滤镜美化、直播间的实时美颜、影视后期的电影级调色,甚至 AI 生成内容的动态效果。无论是个人开发者还是团队,视频特效都成了吸引…

【并发编程 | 第一篇】线程相关基础知识

1.并发和并行有什么区别 并发是指多核CPU上的多任务处理,多个任务在同一时刻真正同时执行。 并行是指单核CPU上的多任务处理,多个任务在同一时间段内交替执行,通过时间片轮转实现交替执行,用于解决IO密集型瓶颈。 如何理解线程安…

Kafka 偏移量

在 Apache Kafka 中,偏移量(Offset)是一个非常重要的概念。它不仅用于标识消息的位置,还在多种场景中发挥关键作用。本文将详细介绍 Kafka 偏移量的核心概念及其使用场景。 一、偏移量的核心概念 1. 定义 偏移量是一个非负整数…

18.redis基本操作

Redis(Remote Dictionary Server)是一个开源的、高性能的键值对(Key-Value)存储数据库,广泛应用于缓存、消息队列、实时分析等场景。它以其极高的读写速度、丰富的数据结构和灵活的应用方式而受到开发者的青睐。 Redis 的主要特点 ​高性能: ​内存存储:Redis 将所有数…

历年跨链合约恶意交易详解(一)——THORChain退款逻辑漏洞

漏洞合约函数 function returnVaultAssets(address router, address payable asgard, Coin[] memory coins, string memory memo) public payable {if (router address(this)){for(uint i 0; i < coins.length; i){_adjustAllowances(asgard, coins[i].asset, coins[i].a…

通俗易懂的讲解SpringBean生命周期

&#x1f4d5;我是廖志伟&#xff0c;一名Java开发工程师、《Java项目实战——深入理解大型互联网企业通用技术》&#xff08;基础篇&#xff09;、&#xff08;进阶篇&#xff09;、&#xff08;架构篇&#xff09;清华大学出版社签约作家、Java领域优质创作者、CSDN博客专家、…

深入理解 `git pull --rebase` 与 `--allow-unrelated-histories`:区别、原理与实战指南

&#x1f680; git pull --rebase vs --allow-unrelated-histories 全面解析 在日常使用 Git 时&#xff0c;我们经常遇到两种拉取远程代码的方式&#xff1a;git pull --rebase 和 git pull --allow-unrelated-histories。它们的区别是什么&#xff1f;各自适用哪些场景&…

Matlab_Simulink中导入CSV数据与仿真实现方法

前言 在Simulink仿真中&#xff0c;常需将外部数据&#xff08;如CSV文件或MATLAB工作空间变量&#xff09;作为输入信号驱动模型。本文介绍如何高效导入CSV数据至MATLAB工作空间&#xff0c;并通过From Workspace模块实现数据到Simulink的精确传输&#xff0c;适用于运动控制…

Spring Boot 中 JdbcTemplate 处理枚举类型转换 和 减少数据库连接的方法 的详细说明,包含代码示例和关键要点

以下是 Spring Boot 中 JdbcTemplate 处理枚举类型转换 和 减少数据库连接的方法 的详细说明&#xff0c;包含代码示例和关键要点&#xff1a; 一、JdbcTemplate 处理枚举类型转换 1. 场景说明 假设数据库存储的是枚举的 String 或 int 值&#xff0c;但 Java 实体类使用 enu…

API 安全之认证鉴权

作者&#xff1a;半天 前言 API 作为企业的重要数字资源&#xff0c;在给企业带来巨大便利的同时也带来了新的安全问题&#xff0c;一旦被攻击可能导致数据泄漏重大安全问题&#xff0c;从而给企业的业务发展带来极大的安全风险。正是在这样的背景下&#xff0c;OpenAPI 规范…