使用paddlepaddle框架构建ViT用于CIFAR10图像分类

使用paddlepaddle框架构建ViT用于CIFAR10图像分类

硬件环境:GPU (1 * NVIDIA T4)
运行时间:一个epoch大概一分钟

import paddle
import time
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
from paddle.io import DataLoader
import numpy as np
import paddle.optimizer.lr as lrScheduler
from paddle.vision.transforms import BaseTransform
import math
from tqdm import tqdmpaddle.seed(1024)
np.random.seed(1234)# 设置使用的设备为GPU
paddle.set_device('gpu')# 通过AutoTransforms实现随机数据增强
class AutoTransforms(BaseTransform):def __init__(self, transforms=None, keys=None):super(AutoTransforms, self).__init__(keys)self.transforms = transformsdef _apply_image(self, image):if self.transforms is None:return imagechoose=np.random.randint(0, len(self.transforms))return self.transforms[choose](image)# 训练集数据增强
mean = [0.5071, 0.4867, 0.4408]
std = [0.2675, 0.2565, 0.2761]transforms_list= [transforms.BrightnessTransform(0.5),  # 亮度变换transforms.SaturationTransform(0.5),  # 饱和度变换transforms.ContrastTransform(0.5),    # 对比度变换transforms.HueTransform(0.5),         # 色调变换transforms.RandomRotation(15,expand=True,fill=128),   # 随机旋转transforms.ColorJitter(0.5,0.5,0.5,0.5),transforms.Grayscale(3)     # 转换为灰度图
]train_tx = transforms.Compose([transforms.RandomHorizontalFlip(),AutoTransforms(transforms_list),transforms.RandomCrop(32),transforms.RandomVerticalFlip(),transforms.Transpose(),transforms.Normalize(0.0, 255.0),transforms.Normalize(mean, std)
])val_tx = transforms.Compose([transforms.Transpose(),transforms.Normalize(0.0, 255.0),transforms.Normalize(mean, std)
])cifar10_train = paddle.vision.datasets.Cifar10(mode='train', transform=train_tx, download=True)
cifar10_test = paddle.vision.datasets.Cifar10(mode='test', transform=val_tx, download=True)# 训练集数量50000,测试集数量10000
print('训练集数量:', len(cifar10_train), '训练集图像尺寸', cifar10_train[0][0].shape)
print('测试集数量:', len(cifar10_test), '测试集图像尺寸', cifar10_test[0][0].shape)def anti_normalize(image):# 将图像转换为张量image = paddle.to_tensor(image)# 处理均值和标准差t_mean = paddle.to_tensor(mean).reshape([3, 1, 1]).expand([3, 32, 32])t_std = paddle.to_tensor(std).reshape([3, 1, 1]).expand([3, 32, 32])# 反归一化return (image * t_std + t_mean).transpose([1, 2, 0])# ViT模型组网部分包含图像切片(Patches),多层感知机(MLP),多头自注意力机制(MultiHeadSelfAttention)以及Transformer编码器(Transformer Encoder)。
# Patches的目的是实现图像切块,将整张图像分割成一个个小块(patch),以方便后续将图像编码成一个个tokens。
class Patches(paddle.nn.Layer):def __init__(self, patch_size):super(Patches, self).__init__()self.patch_size = patch_sizedef forward(self, images):patches = F.unfold(images, self.patch_size, self.patch_size)return patches.transpose([0,2,1])# 多层感知机包含线性层,激活层(GELU),DropOut层。线性层将输入扩增指定维度,再缩减回去,MLP不改变输入输出维度。
class Mlp(nn.Layer):def __init__(self, feats, mlp_hidden, dropout=0.1):super().__init__()self.fc1 = nn.Linear(feats, mlp_hidden)self.fc2 = nn.Linear(mlp_hidden, feats)self.act = nn.GELU()self.dropout = nn.Dropout(dropout)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.dropout(x)x = self.fc2(x)x = self.dropout(x)return x# 多头自注意力机制。
class MultiHeadSelfAttention(nn.Layer):def __init__(self, feats, head=8, dropout=0., attn_dropout=0.0):super(MultiHeadSelfAttention, self).__init__()self.head = headself.feats = featsself.sqrt_d = self.feats ** 0.5self.qkv = nn.Linear(feats,feats * 3)self.out = nn.Linear(feats, feats)self.dropout = nn.Dropout(dropout)self.attn_dropout = nn.Dropout(attn_dropout)def transpose_multi_head(self, x):new_shape = x.shape[:-1] + [self.head, self.feats//self.head]x = x.reshape(new_shape)x = x.transpose([0, 2, 1, 3])return xdef forward(self, x):b, n, f = x.shapeqkv = self.qkv(x).chunk(3, -1)q, k, v = map(self.transpose_multi_head, qkv)attn = F.softmax(paddle.einsum("bhif, bhjf->bhij", q, k) / self.sqrt_d, axis=-1)attn = self.attn_dropout(attn)attn = paddle.einsum("bhij, bhjf->bihf", attn, v)out = self.dropout(self.out(attn.flatten(2))) # 使用flatten函数将多头输出恢复为原始的特征维度,并通过out线性层进行映射return out# 一个Transformer Encoder包括LayerNorm层,MultiHeadSelfAttention以及MLP,将输入进来的token编码输出。
class TransformerEncoder(nn.Layer):def __init__(self, feats, mlp_hidden, head=8, dropout=0., attn_dropout=0.):super(TransformerEncoder, self).__init__()self.layer1 = nn.LayerNorm(feats)self.msa = MultiHeadSelfAttention(feats, head=head, dropout=dropout, attn_dropout=attn_dropout)self.layer2 = nn.LayerNorm(feats)self.mlp = Mlp(feats, mlp_hidden)def forward(self, x):out = self.msa(self.layer1(x)) + xout = self.mlp(self.layer2(out)) + outreturn out# 将Patches,MLP,MultiHeadSelfAttention以及TransformerEncoder组合,实现ViT。
class ViT(nn.Layer):# in_c:输入图像的通道数;num_classes:分类任务的类别数;img_size:输入图像的尺寸;patch:将图像分割为的块的大小# dropout 和 attn_dropout:分别用于MLP和自注意力机制的Dropout比例;num_layers:Transformer编码器层的数量# hidden:Transformer的隐藏层维度;mlp_hidden:MLP的隐藏层维度# head:多头自注意力模块的头数;is_cls_token:是否添加分类令牌def __init__(self, in_c=3, num_classes=10, img_size=32, patch=8, dropout=0., attn_dropout=0.0, num_layers=7, hidden=384, mlp_hidden=384*4, head=8, is_cls_token=True):super(ViT, self).__init__()self.patch = patchself.is_cls_token = is_cls_tokenself.patch_size = img_size // self.patchself.patches = Patches(self.patch_size)f = (img_size // self.patch) ** 2 * 3num_tokens = (self.patch ** 2) + 1 if self.is_cls_token else (self.patch ** 2)# emb:线性层,用于将块的特征映射到隐藏层维度self.emb = nn.Linear(f, hidden)self.cls_token  = paddle.create_parameter(shape = [1, 1, hidden],dtype = 'float32',default_initializer=nn.initializer.Assign(paddle.randn([1, 1, hidden]))) if is_cls_token else None# pos_embedding:位置嵌入,用于为每个块(包括分类令牌)提供位置信息self.pos_embedding  = paddle.create_parameter(shape = [1,num_tokens, hidden],dtype = 'float32',default_initializer=nn.initializer.Assign(paddle.randn([1,num_tokens, hidden])))encoder_list = [TransformerEncoder(hidden, mlp_hidden=mlp_hidden, dropout=dropout, attn_dropout=attn_dropout, head=head) for _ in range(num_layers)]self.encoder = nn.Sequential(*encoder_list)self.fc = nn.Sequential(nn.LayerNorm(hidden),nn.Linear(hidden, num_classes) # for cls_token)# 使用patches将输入图像x分割为块,并展平这些块。# 将展平的块通过emb线性层映射到隐藏层维度。# 如果is_cls_token为True,则在输入序列的开始处添加一个分类令牌。# 将位置嵌入pos_embedding添加到输入序列中。# 将输入序列传递给encoder进行Transformer编码。# 如果is_cls_token为True,则只取分类令牌的输出;否则,取所有块的输出的平均值。# 将最终输出传递给fc全连接层以进行分类。# 返回分类结果。def forward(self, x):out = self.patches(x)out = self.emb(out)if self.is_cls_token:out = paddle.concat([self.cls_token.tile([out.shape[0],1,1]), out], axis=1)out = out + self.pos_embeddingout = self.encoder(out)if self.is_cls_token:out = out[:,0]else:out = out.mean(1)out = self.fc(out)return out# 构建LabelSmoothingCrossEntropyLoss作为损失函数,并采用LinearWarmup和CosineAnnealingDecay构建带有Warmup的Cosine学习率衰减方式。
# 标签平滑的交叉熵损失函数,正则化方法,提高模型的泛化能力
class LabelSmoothingCrossEntropyLoss(nn.Layer):def __init__(self, classes, smoothing=0.0, dim=-1):super(LabelSmoothingCrossEntropyLoss, self).__init__()self.confidence = 1.0 - smoothingself.smoothing = smoothingself.cls = classesself.dim = dimdef forward(self, pred, target):pred = F.log_softmax(pred, axis=self.dim)with paddle.no_grad():true_dist = paddle.ones_like(pred)true_dist.fill_(self.smoothing / (self.cls - 1))true_dist.put_along_axis_(target.unsqueeze(1), self.confidence, 1)return paddle.mean(paddle.sum(-true_dist * pred, axis=self.dim))def get_scheduler(epochs, warmup_epochs, learning_rate):base_scheduler = lrScheduler.CosineAnnealingDecay(learning_rate=learning_rate, T_max=epochs, eta_min=1e-5, verbose=False)scheduler = lrScheduler.LinearWarmup(base_scheduler, warmup_epochs, 1e-5, learning_rate, last_epoch=-1, verbose=False)return scheduler# 模型构建
Model = ViT(in_c=3, num_classes=10, img_size=32, patch=8, dropout=0.5, attn_dropout=0.1, num_layers=7, hidden=384, head=12, mlp_hidden=384, is_cls_token=True)
# 输出模型结构
paddle.summary(Model, (1, 3, 32, 32))# 定义训练的超参数、优化器、损失函数和学习率衰减方式,构建数据迭代器。
EPOCHS = 100    # 训练的总轮数
BATCH_SIZE = 128    # 批处理大小
NUM_CLASSES = 10    # 类别总数
WARMUP_EPOCHS = 5   # 学习率预热阶段的轮数
LR = 1e-3   # 初始学习率# 学习率调度器
scheduler = get_scheduler(epochs=EPOCHS, warmup_epochs=WARMUP_EPOCHS, learning_rate=LR)
# Adam优化器
optim = paddle.optimizer.Adam(learning_rate=scheduler, parameters=Model.parameters(), weight_decay=5e-5)
# 损失函数
criterion = LabelSmoothingCrossEntropyLoss(NUM_CLASSES, smoothing=0.1)# 加载训练集,打乱顺序
train_loader = DataLoader(cifar10_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=False)
# 加载测试集,不打乱顺序
test_loader = DataLoader(cifar10_test, batch_size=BATCH_SIZE * 16, shuffle=False, num_workers=0, drop_last=False)# 定义模型训练函数train_epoch,在模型训练过程中,打印训练过程中的学习率,损失值以及模型在训练集上的精度。
def train_epoch(model, epoch, interval=20):acc_num = 0  # acc_num用于记录正确预测的数量total_samples = 0   # total_samples用于记录已经处理过的样本总数nb = len(train_loader)pbar = enumerate(train_loader)# 用tqdm库来创建一个进度条pbar = tqdm(pbar, total=nb, colour='red', disable=((epoch + 1) % interval != 0))pbar.set_description(f'EPOCH: {epoch:3d}')for _, (_, data) in enumerate(pbar):x_data = data[0]        # 从数据批次中提取特征和标签。y_data = data[1]predicts = model(x_data)    # 使用模型对特征进行预测。loss = criterion(predicts, y_data)  # 计算预测损失。loss_item = loss.item()acc_num += paddle.sum(predicts.argmax(1) == y_data).item()total_samples += y_data.shape[0]    # 更新正确预测的数量和总样本数,从而计算总的准确率。total_acc = acc_num / total_samplescurrent_lr = optim.get_lr() # 获取当前的学习率。loss.backward() # 反向传播损失以更新模型的权重。pbar.set_postfix(train_loss=f'{loss_item:5f}', train_acc=f'{total_acc:5f}', train_lr=f'{current_lr:5f}')optim.step()     # 使用优化器进行一步优化。optim.clear_grad()  # 清除已计算的梯度,为下一个批次的优化做准备。scheduler.step()    # 更新进度条的信息,显示当前的损失、准确率和学习率。# 定义模型评估函数validation,在模型验证过程中,输出模型在验证集上的精度。
@paddle.no_grad()
def validation(model, epoch, interval=20):model.eval()acc_num = 0total_samples = 0nb = len(test_loader)pbar = enumerate(test_loader)pbar = tqdm(pbar, total=nb, colour='green', disable=((epoch + 1) % interval != 0))pbar.set_description(f'EVAL')for _, (_, data) in enumerate(pbar):x_data = data[0]y_data = data[1]predicts = model(x_data)acc_num += paddle.sum(predicts.argmax(1) == y_data).item()total_samples += y_data.shape[0]batch_acc = paddle.metric.accuracy(predicts, y_data.unsqueeze(1)).item()total_acc = acc_num / total_samplespbar.set_postfix(eval_batch_acc=f'{batch_acc:4f}', total_acc=f'{total_acc:4f}')# 每20轮打印一次模型训练和评估信息,每50轮保存一次模型参数。
start = time.time()
print(start)
for epoch in range(EPOCHS):train_epoch(Model, epoch)validation(Model, epoch)if (epoch + 1) % 50 == 0:paddle.save(Model.state_dict(), str(epoch + 1) + '.pdparams')
paddle.save(Model.state_dict(), 'finished.pdparams')
end = time.time()
print('Training Cost ', (end-start) / 60, 'minutes')state_dict = paddle.load('finished.pdparams')   # 加载模型的权重
Model.set_state_dict(state_dict)
Model.eval()
top1_num = 0    # 记录Top1预测正确的样本数
top5_num = 0    # 记录Top5预测正确的样本数
total_samples = 0
nb = len(test_loader)
pbar = enumerate(test_loader)
pbar = tqdm(pbar, total=nb, colour='green')
pbar.set_description(f'EVAL')
with paddle.no_grad():for _, (_, data) in enumerate(pbar):x_data = data[0]y_data = data[1]predicts = Model(x_data)total_samples += y_data.shape[0]# paddle.metric.accuracy计算Top1和Top5的准确率,并更新相应的计数器。top1_num += paddle.metric.accuracy(predicts, y_data.unsqueeze(1), k=1).item() * y_data.shape[0]top5_num += paddle.metric.accuracy(predicts, y_data.unsqueeze(1), k=5).item() * y_data.shape[0]TOP1 = top1_num / total_samplesTOP5 = top5_num / total_samplespbar.set_postfix(TOP1=f'{TOP1:4f}', TOP5=f'{TOP5:4f}')

预测结果:TOP1=0.800800, TOP5=0.963500

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/15958.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CCF-GESP 等级考试 2023年3月认证C++一级真题解析

2024年03月真题 一、单选题(每题2分,共30分) 第 1 题 以下不属于计算机输入设备的有( )。 A. 键盘B. 音箱C. 鼠标D. 传感器 正确答案:B. 音箱 解析: A. 键盘:键盘是输入设备。B. …

第六节:带你全面理解vue3 浅层响应式API: shallowRef, shallowReactive, shallowReadonly

前言 前面两章,给大家讲解了vue3中ref, reactive,readonly创建响应式数据的API, 以及常用的计算属性computed, 侦听器watch,watchEffect的使用 其中reactive, ref, readonly创建的响应式数据都是深层响应. 而本章主要给大家讲解以上三个API 对应的创建浅层响应式数据的 API,…

Java面试题:Executor框架在Java并发编程中扮演什么角色?如何使用它?

在Java并发编程中,Executor框架扮演着核心角色,它提供了一种高级的、线程安全的机制来异步执行任务。Executor框架的主要目的是将任务的提交与任务的执行分离,从而简化了多线程编程的复杂性。 Executor框架的角色: 任务与线程分离…

持续总结中!2024年面试必问 20 道 Redis面试题(八)

上一篇地址:持续总结中!2024年面试必问 20 道 Redis面试题(七)-CSDN博客 十五、使用过Redis做异步队列么,你是怎么用的? Redis作为一个高性能的键值存储系统,非常适合用来实现异步队列。异步队…

【STM32单片机】----实现LED灯闪烁实战

🎩 欢迎来到技术探索的奇幻世界👨‍💻 📜 个人主页:一伦明悦-CSDN博客 ✍🏻 作者简介: C软件开发、Python机器学习爱好者 🗣️ 互动与支持:💬评论 &…

【机器学习-23】关联规则(Apriori)算法:介绍、应用与实现

在现代数据分析中,经常需要从大规模数据集中挖掘有用的信息。关联规则挖掘是一种强大的技术,可以揭示数据中的隐藏关系和规律。本文将介绍如何使用Python进行关联规则挖掘,以帮助您发现数据中的有趣模式。 一、引言 1. 简要介绍关联规则学习…

[处理器芯片]-5 超标量CPU实现之ALU

ALU(Arithmetic Logic Unit,算术逻辑单元),是CPU执行单元中最主要的组成部分。 1 主要功能 算术运算:执行加法、减法、乘法和除法等算术运算。 逻辑运算:执行与、或、非、异或等逻辑运算。 移位运算&am…

动态路由实验—OSPF

动态路由协议实验-------OSPF 链路状态路由选择协议又被称为最短路径优先协议&#xff0c;它基SPF&#xff08;shortest path first &#xff09;算法 实验要求&#xff1a;各个PC之间能够互通 1.四台PC配置如下 PC1 PC2 PC3 PC4 2.配置各个交换机的口子的IP R1 <HUAWE…

Room注解无效原因

在Android项目中&#xff0c;如果父模块使用Kotlin&#xff0c;而子模块用Java编写&#xff0c;并且在子模块中使用了Room库&#xff0c;那么你会发现需要使用kapt而不是annotationProcessor来处理Room注解。这里有几个原因和背景知识&#xff1a; 1. 项目配置的影响 父模块的…

spiderfoot一键扫描IP信息(KALI工具系列九)

目录 1、KALI LINUX简介 2、spiderfoot工具简介 3、在KALI中使用spiderfoot 3.1 目标主机IP&#xff08;win&#xff09; 3.2 KALI的IP 4、命令示例 4.1 web访问 4.2 扫描并进行DNS解析 4.3 全面扫描 5、总结 1、KALI LINUX简介 Kali Linux 是一个功能强大、多才多…

YOLOv8+PyQt:实时检测(摄像头、视频)

1.YOLO&#xff1a;CPU实时检测&#xff08;摄像头、视频&#xff09;https://blog.csdn.net/qq_45445740/article/details/106557451 2.YOLOv8PyQt&#xff0c;实现摄像头或视频的实时检测 需要安装 PySide6 和 ultralytics pip install PySide6 pip install ultralyticsfr…

基于docxtpl的模板生成Word

docxtpl是一个用于生成Microsoft Word文档的模板引擎库。它结合了docx模块和Jinja2模板引擎&#xff0c;使用户能够使用Microsoft Word模板文件并在其中填充动态数据。这个库提供了一种方便的方式来生成个性化的Word文档&#xff0c;并支持条件语句、循环语句和变量等控制结构&…

如何在 Elasticsearch 中选择精确 kNN 搜索和近似 kNN 搜索

作者&#xff1a;来自 Elastic Carlos Delgado kNN 是什么&#xff1f; 语义搜索&#xff08;semantic search&#xff09;是相关性排名的强大工具。 它使你不仅可以使用关键字&#xff0c;还可以考虑文档和查询的实际含义。 语义搜索基于向量搜索&#xff08;vector search&…

Angular Ivy:新渲染引擎的性能提升与优化

Angular Ivy是Angular 9及更高版本中引入的默认渲染引擎&#xff0c;它取代了以前的View Engine。Ivy的目标是提高Angular的性能、减少包大小和提高开发者的生产力。 1. AOT编译的改进&#xff1a; 在Ivy中&#xff0c;Angular使用了更早的AOT&#xff08;Ahead-of-Time&…

在AnolisOS8.9系统安装docker-compose

在AnolisOS8.9系统安装docker-compose 下载docker-compose之前请先确保docker已经安装完&#xff0c;教程可以参考 在阿里Anolis OS 8.9龙蜥操作系统安装docker 下载最新版的docker-compose文件 sudo curl -L https://github.com/docker/compose/releases/download/v2.21.0…

大数据工具之HIVE-参数调优,调度乱码(二)

一、调度乱码 在利用HUE工具,搭建WORKFLOW流程的过程中,如果直接执行hivesql数据正常,不会出现乱码现象,如果利用WORKFLOW搭建的流程,进行数据的拉取,会出现数据中文乱码现象,这些乱码主要是由于select 中的硬编码中文导致出现的现象 具体现象如下: select case when …

百度 提前批 国际化广告部 (深圳-机器学习/数据挖掘/自然语言处理工程师) 一面+二面面经

文章目录 0、面试情况1、一面1.1、简历上的项目介绍了个遍1.2、dbscan原理1.3、为什么梯度的负方向就是损失函数下降最快的方向&#xff1f;1.4、bn原理&#xff0c;为什么bn能解决过拟合&#xff0c;1.5、auc原理&#xff0c;为什么ctr或你的广告推荐里用auc指标&#xff1f;1…

TG5032CGN TCXO 超高稳定10pin端子型适用于汽车动力转向控制器

TG5032CGN TCXO / VC-TCXO是一款应用广泛的晶振&#xff0c;具有超高稳定性&#xff0c;CMOS输出和使用晶体基振的削波正弦波输出形式。且有低相位噪声优势&#xff0c;是温补晶体振荡器(TCXO)和压控晶体振荡器(VCXO)结合的产物&#xff0c;具有TCXO和VCXO的共同优点&#xff0…

后台接口返回void但是response有设置合适的相关信息,前端调用接口解析Blob数据下载excel文件

1、pom.xml文件增加依赖&#xff1a; <dependency><groupId>org.apache.poi</groupId><artifactId>poi-ooxml</artifactId></dependency> 2、接口代码如下&#xff1a; /*** 企业列表--导出*/GetMapping(value "/downloadTenantL…

微信小程序上线必备:SSL证书申请以及安装

一、认识ssl证书 1、ssl证书是什么&#xff1f; SSL证书&#xff0c;全称Secure Socket Layer Certificate&#xff0c;是一种数字证书&#xff0c;它遵循SSL&#xff08;现在通常指TLS&#xff0c;Transport Layer Security&#xff09;协议标准&#xff0c;用于在客户端&…