【深度学习】解析Vision Transformer (ViT): 从基础到实现与训练

之前介绍:

https://qq742971636.blog.csdn.net/article/details/132061304

文章目录

  • 背景
      • 实现代码示例
      • 解释
  • 训练
      • 数据准备
      • 模型定义
      • 训练和评估
      • 总结

在这里插入图片描述

Vision Transformer(ViT)是一种基于transformer架构的视觉模型,它最初是由谷歌研究团队在论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》中提出的。ViT将图像分割成固定大小的patches(例如16x16),并将每个patch视为一个词(类似于NLP中的单词)进行处理。以下是ViT的详细讲解:

背景

在计算机视觉领域,传统的卷积神经网络(CNNs)一直是处理图像的主流方法。然而,CNNs存在一些局限性,如在处理长距离依赖关系时表现不佳。ViT引入了transformer架构,通过全局注意力机制,有效地处理图像中的长距离依赖关系。

实现代码示例

ViT代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeatclass PatchEmbedding(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):super().__init__()self.img_size = img_sizeself.patch_size = patch_sizeself.grid_size = img_size // patch_sizeself.num_patches = self.grid_size ** 2self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):x = self.proj(x)  # [B, embed_dim, H, W]x = x.flatten(2)  # [B, embed_dim, num_patches]x = x.transpose(1, 2)  # [B, num_patches, embed_dim]return xclass Attention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):super().__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)def forward(self, x):B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return xclass MLP(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = nn.GELU()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.):super().__init__()self.norm1 = nn.LayerNorm(dim)self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)self.drop_path = nn.Identity() if drop_path == 0 else nn.Dropout(drop_path)self.norm2 = nn.LayerNorm(dim)self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)def forward(self, x):x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return xclass VisionTransformer(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):super().__init__()self.num_classes = num_classesself.num_features = self.embed_dim = embed_dimself.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)num_patches = self.patch_embed.num_patchesself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))self.pos_drop = nn.Dropout(p=drop_rate)dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]self.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i])for i in range(depth)])self.norm = nn.LayerNorm(embed_dim)self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()nn.init.trunc_normal_(self.pos_embed, std=0.02)nn.init.trunc_normal_(self.cls_token, std=0.02)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):nn.init.trunc_normal_(m.weight, std=0.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, nn.LayerNorm):nn.init.ones_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):B = x.shape[0]x = self.patch_embed(x)cls_tokens = self.cls_token.expand(B, -1, -1)x = torch.cat((cls_tokens, x), dim=1)x = x + self.pos_embedx = self.pos_drop(x)for blk in self.blocks:x = blk(x)x = self.norm(x)cls_token_final = x[:, 0]x = self.head(cls_token_final)return x# 示例输入
img = torch.randn(1, 3, 224, 224)
model = VisionTransformer()
output = model(img)
print(output.shape)  # 输出大小为 [1, 1000]

解释

  1. PatchEmbedding:将输入图像分割为不重叠的patches,并通过卷积操作将其转换为embedding。
  2. Attention:实现自注意力机制。
  3. MLP:实现多层感知器(MLP),包括GELU激活函数和Dropout。
  4. Block:包含一个注意力层和一个MLP层,每层都有残差连接和层归一化。
  5. VisionTransformer:组合上述模块,形成完整的ViT模型。包含位置嵌入和分类头。

训练

为了在GPU上训练ViT模型,你可以使用PyTorch中的DataLoader来处理数据,并确保模型和数据都在GPU上。以下是一个详细的代码示例,包括数据准备、模型定义、训练和评估。

数据准备

假设你的数据结构如下:

dataset/class1/img1.jpgimg2.jpg...class2/img1.jpgimg2.jpg......

你可以使用 torchvision.datasets.ImageFolder 来加载数据。

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm# 数据转换和增强
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载数据
data_dir = 'dataset'
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)# 获取类别数
num_classes = len(train_dataset.classes)

模型定义

定义ViT模型并将其移动到GPU上。

# VisionTransformer定义(使用上面的定义)
model = VisionTransformer(num_classes=num_classes).cuda()# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)# 如果有多个GPU,使用DataParallel
if torch.cuda.device_count() > 1:model = nn.DataParallel(model)

训练和评估

定义训练和评估函数,并进行训练。

def train_one_epoch(model, criterion, optimizer, data_loader, device):model.train()running_loss = 0.0running_corrects = 0for inputs, labels in tqdm(data_loader):inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(data_loader.dataset)epoch_acc = running_corrects.double() / len(data_loader.dataset)return epoch_loss, epoch_accdef evaluate(model, criterion, data_loader, device):model.eval()running_loss = 0.0running_corrects = 0with torch.no_grad():for inputs, labels in data_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(data_loader.dataset)epoch_acc = running_corrects.double() / len(data_loader.dataset)return epoch_loss, epoch_acc# 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_epochs = 25for epoch in range(num_epochs):train_loss, train_acc = train_one_epoch(model, criterion, optimizer, train_loader, device)val_loss, val_acc = evaluate(model, criterion, val_loader, device)print(f'Epoch {epoch}/{num_epochs - 1}')print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')# 保存模型
torch.save(model.state_dict(), 'vit_model.pth')

总结

这段代码展示了如何使用PyTorch在GPU上训练Vision Transformer模型。包括数据加载、模型定义、训练和评估步骤。请根据你的实际需求调整批量大小、学习率和训练轮数等参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/28143.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

blender bpy将顶点颜色转换为UV纹理vertex color to texture

一、关于环境 安装blender的bpy,不需要额外再安装blender软件。在python控制台中直接输入pip install bpy即可。 二、关于代码 本文所给出代码仅为参考,禁止转载和引用,仅供个人学习。 本文所给出的例子是https://download.csdn.net/downl…

BerkeleyDB练习

代码; #include <db.h> #include <stdio.h>int main() {DB *dbp;db_create(&dbp, NULL, 0);printf("Berkeley DB version: %s\n", db_version(NULL, NULL, NULL));dbp->close(dbp, 0);return 0; } 编译运行

4-异常-log4j配置日志滚动覆盖出现日志丢失问题

4-异常-log4j配置日志打印滚动覆盖出现日志丢失问题(附源码分析) 更多内容欢迎关注我&#xff08;持续更新中&#xff0c;欢迎Star✨&#xff09; Github&#xff1a;CodeZeng1998/Java-Developer-Work-Note 技术公众号&#xff1a;CodeZeng1998&#xff08;纯纯技术文&…

XGBoost预测及调参过程(+变量重要性)--血友病计数数据

所使用的数据是血友病数据&#xff0c;如有需要&#xff0c;可在主页资源处获取&#xff0c;数据信息如下&#xff1a; 读取数据及数据集区分 数据预处理及区分数据集代码如下&#xff08;详细预处理说明见上篇文章--随机森林&#xff09;&#xff1a; import pandas as pd im…

异常封装类统一后端响应的数据格式

异常封装类 如何统一后端响应的数据格式 1. 背景 后端作为数据的处理和响应&#xff0c;如何才能和前端配合好&#xff0c;能够高效的完成任务&#xff0c;其中一个比较重要的点就是后端返回的数据格式。 没有统一的响应格式&#xff1a; // 第一种&#xff1a; {"dat…

探索开源世界:2024年值得关注的热门开源项目推荐

文章目录 每日一句正能量前言GitCode成立背景如何使用GitCode如何把你现有的项目迁移至 GitCode&#xff1f;热门开源项目推荐actions-poetry - 管理 Python 依赖项的 GitLab CI/CD 工具项目概述技术分析应用场景特点项目地址 Spider - 网络爬虫框架项目简介技术分析应用场景项…

【RabbitMQ】异步消息及Rabbitmq安装

https://blog.csdn.net/weixin_73077810/article/details/133836287 https://www.bilibili.com/video/BV1mN4y1Z7t9/ 同步调用和异步调用 如果我们的业务需要实时得到服务提供方的响应&#xff0c;则应该选择同步通讯&#xff08;同步调用&#xff09;。 如果我们追求更高的效…

Jupyter Notebook简介

目录 1.概述 2.诞生背景 3.历史版本 4.安装 5.卸载 6.如何使用 7.菜单和菜单项 8.示例 9.未来展望 10.总结 1.概述 Jupyter Notebook是一种基于Web的交互式计算环境&#xff0c;主要用于数据分析、数据科学、机器学习以及探索性编程等领域。允许用户在单个文档中编写…

批量文本编辑神器:一键拆分每行内容,高效实现批量处理与保存,让文本编辑更高效快捷!

在信息化快速发展的今天&#xff0c;文本编辑已经成为我们工作、学习和生活中不可或缺的一部分。然而&#xff0c;面对大量的文本内容&#xff0c;如何高效地进行编辑和处理&#xff0c;成为了许多人面临的难题。今天&#xff0c;我要向大家介绍一款批量文本编辑神器&#xff0…

【C#】图形图像编程

实验目标和要求&#xff1a; 掌握C#图形绘制基本概念&#xff1b;掌握C#字体处理&#xff1b;能进行C#图形图像综合设计。 运行效果如下所示&#xff1a; 1.功能说明与核心代码 使用panel为画板&#xff0c;完成以下设计内容&#xff1a; 使用pen绘制基础图形&#xff1b;使…

【MYSQL】MYSQL操作库

1.数据库字符编码集/数据库校验集 当我们在数据库中保存数据时&#xff0c;需要存和取时候编码一致&#xff0c;比方说你用汉语保存的数据&#xff0c;当你读的时候为了避免乱码问题&#xff0c;也必须用汉语读&#xff0c;这就叫做数据库字符编码集一致。 当我们进行查找&…

C语言的结构体与联合体

引言 C语言提供了结构体和联合体两种聚合数据类型&#xff0c;使得程序员可以创建包括多个数据类型的复杂数据结构。结构体用于将不同类型的数据组合成一个单元&#xff0c;而联合体用于在同一存储空间中存储不同类型的数据。本篇文章将详细介绍C语言中的结构体和联合体&#x…

快消品经销商如何进行有效的团队激励?

很多经销商会面临员工工作不积极、吃大锅饭的现象&#xff0c;导致企业人力成本浪费严重&#xff0c;工作效率也得不到提升&#xff0c;因此经销商老板们必须进行一些绩效考核&#xff0c;然后开展一些有效的激励政策&#xff0c;这样通过提成激励来提高员工的积极性。 1、梳理…

探地雷达正演模拟,基于时域有限差分方法,四

突然发现第三章后半部分已经讲了使用接收记录成像的问题&#xff0c;所以这一章只讲解简单的数据分析。 &#xff08;均以宽角法数据为例子&#xff0c;剖面法数据处理方式都是相同的&#xff09;假设&#xff0c;我们现在已经获得了一个GPR记录&#xff0c;可以是常用的.sgy格…

有关排序的算法

目录 选择法排序 冒泡法排序 qsort排序&#xff08;快速排序&#xff09; qsort排序整型 qsort排序结构体类型 排序是我们日常生活中比较常见的问题&#xff0c;这里我们来说叨几个排序的算法。 比如有一个一维数组 arr[8] {2,5,3,1,7,6,4,8},我们想要把它排成升序&#…

StarNet实战:使用StarNet实现图像分类任务(一)

文章目录 摘要安装包安装timm 数据增强Cutout和MixupEMA项目结构计算mean和std生成数据集 摘要 https://arxiv.org/pdf/2403.19967 论文主要集中在介绍和分析一种新兴的学习范式——星操作&#xff08;Star Operation&#xff09;&#xff0c;这是一种通过元素级乘法融合不同子…

VS2022 使用C++访问 mariadb 数据库

首先,下载 MariaDB Connector/C++ 库 MariaDB Products & Tools Downloads | MariaDB 第二步,安装后 第三步,写代码 #include <iostream> #include <cstring> #include <memory> #include <windows.h>#include <mariadb/conncpp.hpp>…

使用 Python 进行测试(6)Fake it...

总结 如果我有: # my_life_work.py def transform(param):return param * 2def check(param):return "bad" not in paramdef calculate(param):return len(param)def main(param, option):if option:param transform(param)if not check(param):raise ValueError(…

winform 应用程序 添加 wpf控件后影响窗体DPI改变

第一步&#xff1a;添加 应用程序清单文件 app.manifest 第二步&#xff1a;把这段配置 注释放开&#xff0c;第一个配置true 改成false

Wifi通信协议:WEP,WPA,WPA2,WPA3,WPS

前言 无线安全性是保护互联网安全的重要因素。连接到安全性低的无线网络可能会带来安全风险&#xff0c;包括数据泄露、账号被盗以及恶意软件的安装。因此&#xff0c;利用合适的Wi-Fi安全措施是非常重要的&#xff0c;了解WEP、WPA、WPA2和WPA3等各种无线加密标准的区别也是至…