深度学习PyTorch 之 transformer-中文多分类

transformer的原理部分在前面基本已经介绍完了,接下来就是代码部分,因为transformer可以做的任务有很多,文本的分类、时序预测、NER、文本生成、翻译等,其相关代码也会有些不同,所以会分别进行介绍

但是对于不同的任务其流程是一样的,所以一些重复的步骤就不过多解释了。

1、 前期准备

数据和之前LSTM是一样的,同时我们还使用上次训练好的词嵌入模型

以下是代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from gensim.models import KeyedVectors
from sklearn.model_selection import train_test_split
import pandas as pd
import jieba
import re
from sklearn.preprocessing import LabelEncoder# 加载数据
file_path = './data/news.csv'
data = pd.read_csv(file_path)# 显示数据的前几行
data.head()# 文本清洗和分词函数
def clean_and_cut(text):# 删除特殊字符和数字text = re.sub(r'[^a-zA-Z\u4e00-\u9fff]', '', text)# 使用jieba进行分词words = jieba.cut(text)return ' '.join(words)X_train_cut = data["text"].apply(clean_and_cut)
# 显示处理后的文本
data.head()# 将标签转换为数值形式
label_encoder = LabelEncoder()
data["label"] = label_encoder.fit_transform(data["label"])
# 加载保存的word vectors
loaded_wv = KeyedVectors.load('word_vector', mmap='r') class Word2VecDataset(Dataset):def __init__(self, texts, labels, word2vec, max_len=100):self.texts = textsself.labels = labelsself.word2vec = word2vecself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]label = self.labels[idx]embeds = [self.word2vec[word] if word in self.word2vec else np.zeros(self.word2vec.vector_size) for word in text]if len(embeds) > self.max_len:embeds = embeds[:self.max_len]else:embeds += [np.zeros(self.word2vec.vector_size) for _ in range(self.max_len - len(embeds))]return torch.tensor(embeds, dtype=torch.float), torch.tensor(label, dtype=torch.long)# texts和labels是数据集中的文本和标签列表
texts = X_train_cut.tolist()
labels = data['label'].tolist()# 划分数据集
train_texts, test_texts, train_labels, test_labels = train_test_split(texts, labels, test_size=0.2)

2、位置编码和主模型

import mathclass PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=100):super(PositionalEncoding, self).__init__()# 创建一个位置编码矩阵pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)  # (1, max_len, d_model)self.register_buffer('pe', pe)def forward(self, x):# x: (batch_size, max_len, d_model)x = x + self.pe.expand(x.size(0), -1, -1)return x

2.1 PositionalEncoding 类

这个类用于创建和提供位置编码。位置编码是 Transformer 模型中用于注入序列中单词的位置信息的机制。这种位置信息对于模型理解单词的顺序很重要。

初始化方法 __init__
  • d_model:模型的维度,也是词嵌入的维度。
  • max_len:序列的最大长度。
  • pe:位置编码矩阵,大小为 (1, max_len, d_model)。这个矩阵被注册为一个缓冲区,这意味着它会被保存和加载与模型的其他参数一起。
前向传播方法 forward
  • 输入 x 的形状是 (batch_size, max_len, d_model)
  • self.pe.expand(x.size(0), -1, -1):这个操作将位置编码矩阵扩展为 (batch_size, max_len, d_model),以便它可以与输入数据相加。
  • 最后,将扩展后的位置编码矩阵加到输入数据上,并返回结果。
#修改Transformer模型以添加位置编码
class TransformerClassifierWithPE(nn.Module):def __init__(self, num_classes, d_model=100, nhead=2, num_layers=2, dim_feedforward=2048, dropout=0.1):super(TransformerClassifierWithPE, self).__init__()# 位置编码self.pos_encoder = PositionalEncoding(d_model)# Transformer编码器层encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)# 分类器self.classifier = nn.Linear(d_model, num_classes)def forward(self, x):# x: (batch_size, max_len, d_model)x = self.pos_encoder(x)x = x.permute(1, 0, 2)  # (max_len, batch_size, d_model)x = self.transformer_encoder(x)  # (max_len, batch_size, d_model)x = x.mean(dim=0)  # (batch_size, d_model)x = self.classifier(x)  # (batch_size, num_classes)return x

2.2 TransformerClassifierWithPE 类

这个类定义了一个带有位置编码的 Transformer 分类器模型。

初始化方法 __init__
  • num_classes:分类任务的类别数量。
  • d_model:模型的维度,也是词嵌入的维度。
  • nhead:多头注意力的头数。
  • num_layers:Transformer 编码器层的数量。
  • dim_feedforward:前馈网络中的隐藏层维度。
  • dropout:Dropout 的概率。
  • pos_encoder:PositionalEncoding 实例,用于位置编码。
  • transformer_encoder:Transformer 编码器,由多个 TransformerEncoderLayer 组成。
  • classifier:线性分类器,用于生成最终的分类结果。
前向传播方法 forward
  • 输入 x 的形状是 (batch_size, max_len, d_model)
  • 首先,使用 self.pos_encoder(x) 获取位置编码后的输入。
  • 然后,将输入的维度从 (batch_size, max_len, d_model) 转换为 (max_len, batch_size, d_model),这是因为 PyTorch 的 Transformer 编码器期望的输入维度是这样的。
  • 接下来,通过 self.transformer_encoder(x) 应用 Transformer 编码器。
  • 然后,使用 x.mean(dim=0) 获取每个序列的平均表示。
  • 最后,通过 self.classifier(x) 应用线性分类器,得到最终的分类结果。
    这个模型可以用于文本分类任务,其中输入是文本序列的词嵌入表示。

3、训练模型


# 模型参数
d_model = 512
nhead = 8
num_encoder_layers = 3
dim_feedforward = 2048
num_classes = len(data.label.unique())  # 假设label_dict是我们的标签字典
max_len = 256model = TransformerClassifierWithPE( d_model=d_model, nhead=nhead, num_layers=num_encoder_layers, dim_feedforward=dim_feedforward, num_classes=num_classes, max_len=max_len,dropout=0.1)-----------------------------
TransformerModel((pos_encoder): PositionalEncoding()(transformer_encoder): TransformerEncoder((layers): ModuleList((0-2): 3 x TransformerEncoderLayer((self_attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True))(linear1): Linear(in_features=512, out_features=2048, bias=True)(dropout): Dropout(p=0.1, inplace=False)(linear2): Linear(in_features=2048, out_features=512, bias=True)(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(dropout1): Dropout(p=0.1, inplace=False)(dropout2): Dropout(p=0.1, inplace=False))))(decoder): Linear(in_features=512, out_features=10, bias=True)
)
# 训练模型
num_epochs = 20
for epoch in range(num_epochs):for inputs, labels in train_loader:# 清除梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新参数optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')
# 在测试集上评估模型
model.eval()
with torch.no_grad():correct = 0total = 0for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the test set: {100 * correct / total}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/755738.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【RabbitMQ | 第七篇】RabbitMQ实现JSON、Map格式数据的发送与接收

文章目录 7.RabbitMQ实现JSON、Map格式数据的发送与接收7.1消息发送端7.1.1引入依赖7.1.2yml配置7.1.3RabbitMQConfig配置类——(非常重要)(1)创建交换器方法(2)创建队列方法(3)绑定…

代码随想录算法训练营第27天|93.复原IP地址、78.子集、90.子集二

目录 一、力扣93.复原IP地址1.1 题目1.2 思路1.3 代码1.4 总结 二、力扣78.子集2.1 题目2.2 思路2.3 代码2.4 总结 三、力扣90.子集二3.1 题目3.2 思路3.3 代码3.4 总结 一、力扣93.复原IP地址 (比较困难,做起来很吃力) 1.1 题目 1.2 思路 …

【数据结构练习题】栈——1.括号匹配 2.逆波兰表达式求值 3.出栈入栈次序匹配 4.最小栈

♥♥♥♥♥个人主页♥♥♥♥♥ ♥♥♥♥♥数据结构练习题总结专栏♥♥♥♥♥ 文件目录 前言1.括号匹配1.1问题描述1.2解题思路1.3画图解释1.4代码实现2.逆波兰表达式求值 2.1问题描述2.2解题思路2.3画图解释2.4代码解释3.出栈入栈次序匹配 3.1问题描述3.2思路分析3.3画图解释3.…

【No.13】蓝桥杯二分查找|整数二分|实数二分|跳石头|M次方根|分巧克力(C++)

二分查找算法 知识点 二分查找原理讲解在单调递增序列 a 中查找 x 或 x 的后继在单调递增序列 a 中查找 x 或 x 的前驱 二分查找算法讲解 枚举查找即顺序查找, 实现原理是逐个比较数组 a[0:n-1] 中的元素,直到找到元素 x 或搜索整个数组后确定 x 不在…

CPU设计实战—异常处理指令

异常类型以及精确异常的处理 异常有点像中断,处理完还要回到原来的状态,所以需要对之前的状态进行保存。本CPU主要实现对以下异常的处理: 1.外部硬件中断 2.复位异常 3.系统调用异常(发生在译码阶段) 4.溢出异常&…

Linux下磁盘分区类型及文件系统扩容

本篇文章基础知识点较多,文章偏长。建议收藏~ 之前介绍过一篇文章 重新构建KVM虚拟机基础镜像,当中有个待优化的点。 Centos 官方的镜像中默认的系统盘(/dev/vda)的大小是8G空间 但是实际使用时,8G的系统盘肯定不满足需求。这个时候我们就需…

做好外贸网站SEO优化,拓展海外市场

随着全球贸易的发展和互联网的普及,越来越多的外贸企业将目光投向了网络,希望通过建立网站来拓展海外市场。然而,在竞争激烈的外贸市场中,要让自己的网站脱颖而出,吸引更多的目标客户,就需要进行有效的SEO优…

openGauss学习笔记-246 openGauss性能调优-SQL调优-经验总结:SQL语句改写规则

文章目录 openGauss学习笔记-246 openGauss性能调优-SQL调优-经验总结:SQL语句改写规则246.1 使用union all代替union246.2 join列增加非空过滤条件246.3 not in转not exists246.4 选择hashagg246.5 尝试将函数替换为case语句246.6 避免对索引使用函数或表达式运算2…

PyTorch学习笔记之基础函数篇(十三)

文章目录 7.7 torch.ceil() 函数7.8 torch.floor() 函数7.9 torch.clamp() 函数7.10 torch.neg() 函数7.11 torch.reciprocal() 函数7.12 torch.rsqrt() 函数7.13 torch.sqrt() 函数 7.7 torch.ceil() 函数 在PyTorch中,torch.ceil 函数用于对张量(tens…

面试算法-50-二叉树的最大深度

题目 给定一个二叉树 root ,返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:3 解 class Solution {public int maxDepth(TreeNo…

《算法设计与分析第二版》100行 C语言实现 广度度优先算法 BFS——最短距离

抄录自课本P157页。 #include <stdio.h> #define MAXQ 100 // 队列大小 #define MAxN 10 // 最大迷宫大小 int n8; // 迷宫大小 char Maze [MAxN][MAxN] {{O,X,X,X,X,X,X,X,},{O,O,O,X,O,X,O,X,},{X,X,O,O,O,X,O,X,},{X,X,O,X,O,X,X,X,},…

Git ignore: 忽略与清除

一、vs、vc #.svn .clang-format .gitignore Src/[Dd]ebug/ Src/[Rr]elease/ .vs .vs/* ​ # other file *.txt *.log *LOG/ *log/ ​ # Compiled Object files *.slo *.lo #*.o *.obj ​ # Precompiled Headers *.gch *.pch ​ # Compiled Dynamic libraries #*.so *.dylib #…

LightDB24.1 Sequence支持设置minvalue小于INT64_MIN

背景介绍 Oracle数据库支持设置sequence的minvalue为-1000000000000000000000000000&#xff0c;在用户迁移到LightDB时&#xff0c;sequence设置minvalue为-1000000000000000000000000000会报错。为了兼容Oracle数据库的使用习惯&#xff0c;在LightDB24.1版本中&#xff0c;…

HDFS概述及常用shell操作

HDFS 一、HDFS概述1.1 HDFS适用场景1.2 HDFS优缺点1.3 HDFS文件块大小 二、HDFS的shell操作2.1 上传2.2 下载2.3 HDFS直接操作 一、HDFS概述 1.1 HDFS适用场景 因为HDFS里所有的文件都是维护在磁盘里的 在磁盘中对文件的历史内容进行修改 效率极其低(但是追加可以) 1.2 HDF…

Linux电源管理——系统Suspend/Resume流程

本篇文章主要是自己的学习笔记&#xff0c;主要内容是分析linux系统中设备的Suspend和Resume流程&#xff0c;用到的内核版本为 linux-4.14。 目录 1、Linux 内核的Suspend方法 2、__device_suspend 函数 3、pm_op 函数 4、suspend_enter 函数 5、resume流程 1、Linux 内…

dockerfile更改docker镜像源

方法一&#xff1a; ## 更换源 RUN sed -i s/deb.debian.org//mirrors.aliyun.com/g /etc/apt/sources.list \ && apt-get update 方法二&#xff1a; RUN echo "deb http://mirrors.tuna.tsinghua.edu.cn/debian/ buster main contrib non-free" >/…

js中副作用的消除还解决了并行计算带来的竞争问题,具体是如何解决的

在JavaScript中&#xff0c;副作用是指对外部环境产生的可观察的变化&#xff0c;例如修改全局变量、修改DOM元素等。副作用的存在可能导致代码的可维护性和可测试性下降&#xff0c;并且在并行计算中可能引发竞争问题。 不纯的函数有可能访问同一块资源&#xff0c;如果先后调…

走近 AI Infra 架构师:在高速飞驰的大模型“赛车”上“换轮子”的人

如果把大模型训练比作 F1 比赛&#xff0c;长凡所在的团队就是造车的人&#xff0c;也是在比赛现场给赛车换轮子的人。1% 的训练提速&#xff0c;或者几秒之差的故障恢复时间&#xff0c;累积起来&#xff0c;都能影响到几百万的成本。长凡说&#xff1a;“大模型起来的时候&am…

算法详解——选择排序和冒泡排序

一、选择排序 选择排序算法的执行过程是这样的&#xff1a;首先&#xff0c;算法遍历整个列表以确定最小的元素&#xff0c;接着&#xff0c;这个最小的元素被置换到列表的开头&#xff0c;确保它被放置在其应有的有序位置上。接下来&#xff0c;从列表的第二个元素开始&#x…

事件高级、

文章目录 1.注册事件&#xff08;绑定事件&#xff09;addEventListener 事件监听方式attachEvent 事件监听方式、兼容性解决方案 * 2.删除事件&#xff08;解绑事件&#xff09;删除事件的方式删除事件兼容性解决方案 * 3.DOM事件流4.事件对象使用语法兼容性方案*常见属性和方…