ViT的极简pytorch实现及其即插即用

先放一张ViT的网络图
在这里插入图片描述
可以看到是把图像分割成小块,像NLP的句子那样按顺序进入transformer,经过MLP后,输出类别。每个小块是16x16,进入Linear Projection of Flattened Patches, 在每个的开头加上cls token和位置信息,也就是position embedding。
去掉数据读取部分,直接上一个极简的ViT代码:

import torch
from torch import nnfrom einops import rearrange, repeat
from einops.layers.torch import Rearrange# helpersdef pair(t):return t if isinstance(t, tuple) else (t, t)# classesclass PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):qkv = self.to_qkv(x).chunk(3, dim = -1)## 对tensor张量分块 x :1 197 1024   qkv 最后是一个元祖,tuple,长度是3,每个元素形状:1 197 1024q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn xclass ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):super().__init__()image_height, image_width = pair(image_size)   # 224*224patch_height, patch_width = pair(patch_size)   # 16 * 16assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_height // patch_height) * (image_width // patch_width)patch_dim = channels * patch_height * patch_widthassert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'self.to_patch_embedding = nn.Sequential(# (b,3,224,224) -> (b,196,768)    14*14=196  16*16*3=768Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.Linear(patch_dim, dim),    # (b,196,1024))self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_classes))def forward(self, img):x = self.to_patch_embedding(img)        # img 1 3 224 224  输出形状x : 1 196 1024b, n, _ = x.shape                       # 1 196cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)    # (1,1,1024)x = torch.cat((cls_tokens, x), dim=1)   # (1,197,1024)x += self.pos_embedding[:, :(n + 1)]    # (1,197,1024)x = self.dropout(x)                     # (1,197,1024)x = self.transformer(x)                 # (1,197,1024)x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]     # (1,1024)x = self.to_latent(x)      # (1,1024)return self.mlp_head(x)    # (1,1000)if __name__ == '__main__':v = ViT(image_size = 224,patch_size = 16,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1)img = torch.randn(1, 3, 224, 224)preds = v(img)        # (1, 1000)print(preds.shape)

去掉cls和最后的全连接分类头,变成即插即用的模块:

import torch
from torch import nnfrom einops import rearrange
from einops.layers.torch import Rearrange# helpersdef pair(t):return t if isinstance(t, tuple) else (t, t)# classesclass PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):qkv = self.to_qkv(x).chunk(3, dim = -1)## 对tensor张量分块 x :1 197 1024   qkv 最后是一个元祖,tuple,长度是3,每个元素形状:1 197 1024q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn xclass ViT(nn.Module):def __init__(self, *, image_size, patch_size, dim = 1024, depth = 3, heads = 16, mlp_dim = 2048, dim_head = 64, dropout = 0.1, emb_dropout = 0.1):super().__init__()channels, image_height, image_width = image_size   # 256,64,80patch_height, patch_width = pair(patch_size)       # 4*4assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_height // patch_height) * (image_width // patch_width)     # 16*20patch_dim = 64 * patch_height * patch_width    # 64*8*10self.conv1 = nn.Conv2d(256, 64, 1)self.to_patch_embedding = nn.Sequential(# (b,64,64,80) -> (b,320,1024)    16*20=320  4*4*64=1024Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.Linear(patch_dim, dim),    # (b,320,1024))self.to_img = nn.Sequential(# b c (h p1) (w p2) -> (b,64,64,80)      16*20=320  4*4*64=1024Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', \p1 = patch_height, p2 = patch_width, h = image_height // patch_height, w = image_width // patch_width),nn.Conv2d(64, 256, 1),      # (b,64,64,80) -> (b,256,64,80))# 位置编码self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)def forward(self, img):x = self.conv1(img)                     # img 1 256 64 80 -> 1 64 64 80x = self.to_patch_embedding(x)          # 1 320 1024b, n, _ = x.shape                       # 1 320x += self.pos_embedding[:, :(n + 1)]    # (1,320,1024)x = self.dropout(x)                     # (1,320,1024)x = self.transformer(x)                 # (1,320,1024)x = self.to_img(x)return x                                # (1 256 64 80)if __name__ == '__main__':v = ViT(image_size = (256,64,80), patch_size = 4)img = torch.randn(1, 256, 64, 80)preds = v(img)         # (1, 256, 64, 80)print(preds.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/585096.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mysql5.7主从数据库同步失败(日记文件错误)解决记录

记录一次Mysql主从数据库同步失败(日记文件错误)解决记录 查看同步状态: 具体错误: 检查mysql数据库日记 2021-06-10T03:45:43.522398Z 1 [ERROR] Error reading packet from server for channel : event read from binlog did not pass crc check; the…

Oracle 拼接字符串

语法 使用||拼接如果内容中有单引号,则可在该单引号前面再加一个单引号进行转义 例子 比如有一个业务是根据需要生成多条插入语句 select insert into des_account_des_role(des_account_id, roles_id) values( || id || , || (select id from des_role where wo…

Ps:八大混合模式及其在色彩渲染上的运用

在所有的图层混合模式中,有八种比较特别。 特别之处在于,其它的混合模式在修改图层的“不透明度”或“填充”时,效果是一样的。 而这八种混合模式使用“填充”比使用“不透明度”可带来更好的效果,有时甚至可以说是惊艳。 提示&am…

ubuntu下编译obs-studio遇到的问题记录

参考的是这篇文档:Build Instructions For Linux obsproject/obs-studio Wiki GitHub 在安装OBS dependencies时, sudo apt install libavcodec-dev libavdevice-dev libavfilter-dev libavformat-dev libavutil-dev libswresample-dev libswscale-d…

[ 云计算 | AWS ] 对比分析:Amazon SNS 与 SQS 消息服务的异同与选择

文章目录 一、前言二、Amazon SNS 服务(Amazon Simple Notification Service)三、Amazon SQS 服务(Amazon Simple Queue Service)四、SNS 与 SQS 的区别(本文重点)4.1 基于推送和轮询区别4.2 消费者数量对应…

HBuilder常用的快捷键

查看专栏目录 Network 灰鸽宝典专栏主要关注服务器的配置,前后端开发环境的配置,编辑器的配置,网络服务的配置,网络命令的应用与配置,windows常见问题的解决等。 文章目录 常用快捷键分9项快捷键1.文件(4)2.编辑(13)3.…

网络安全应急响应工具之-流量安全取证NetworkMiner

在前面的一些文章中,用了很多的章节介绍流量分析和捕获工具wireshark。Wireshark是一款通用的网络协议分析工具,非常强大,关于wireshark的更多介绍,请关注专栏,wireshark从入门到精通。本文将介绍一个专注于网络流量取…

计算机网络【DHCP动态主机配置协议】

DHCP 出现 电脑或手机需要 IP 地址才能上网。大刘有两台电脑和两台手机,小美有一台笔记本电脑、一台平板电脑和两台手机,老王、阿丽、敏敏也有几台终端设备。如果为每台设备手动配置 IP 地址,那会非常繁琐,一点儿也不方便。特别是…

k8s:kubernets

自动部署、自动扩展和管理的容器化部署的应用程序的一个开源系统 k8s负责自动化运维管理多个容器化程序的集群,是一个功能强大的容器编排工具 可以以分布式和集群化的方式进行容器管理 1.18版本,目前最多的是1.20版本,最新的是1.29版本&am…

Java IDEA JUnit 单元测试

JUnit是一个开源的 Java 单元测试框架,它使得组织和运行测试代码变得非常简单,利用JUnit可以轻松地编写和执行单元测试,并且可以清楚地看到哪些测试成功,哪些失败 JUnit 还提供了生成测试报告的功能,报告不仅包含测试…

什么是PD快充诱骗芯片?它的原理是什么?

PD快充诱骗芯片,顾名思义,就是通过LDR6328Q PD取电芯片把pd适配器的电压给诱骗出来固定给后端设备供电。 PD诱骗芯片是受电端的一种PD协议芯片,它内置了PD通讯模块,通过与供电端(如PD充电器)的PD协议芯片握…

Ubuntu 安装MySQL以及基本使用

前言 MySQL是一个开源数据库管理系统,通常作为流行的LAMP(Linux,Apache,MySQL,PHP / Python / Perl)堆栈的一部分安装。它使用关系数据库和SQL(结构化查询语言)来管理其数据。 安装…

k8s之陈述式资源管理

1.kubectl命令 kubectl version 查看k8s的版本 kubectl api-resources 查看所有api的资源对象的名称 kubectl cluster-info 查看k8s的集群信息 kubectl get cs 查看master节点的状态 kubectl get pod 查看默认命名空间内的pod的信息 kubectl get ns 查看当前集群所有的命…

HTML实战演练之贪吃蛇美食大作战

导入: 一 :粉丝要求 今天一位小伙伴私信我说,想玩HTML贪吃蛇美食大作战,自己也是学HTML的,希望我能安排一下,那么好它来了 需知: 一:别着急先看需要知道的 要用HTML开发贪吃蛇美食…

很实用的ChatGPT网站——httpchat-zh.com

很实用的ChatGPT网站——http://chat-zh.com/ 今天介绍一个好兄弟开发的ChatGPT网站,网址[http://chat-zh.com/]。这个网站功能模块很多,包含生活、美食、学习、医疗、法律、经济等很多方面。下面简单介绍一些部分功能与大家一起分享。 登录和注册页面…

【Java EE初阶五】wait及notify关键字

1. wait和notify的概念 所谓的wait和notify其实就是等待、通知机制;该机制的作用域join类似;由于多个线程之间是随机调度的,引入wait和notify就是为了能够从应用层面上,干预到多个不同线程代码的执行顺序,此处的干预&a…

Java动态代理机制 代码示例demo

文章目录 JDK动态代理代码实现示例1.定义发送短信的接口2.实现发送短信的接口3.定义一个 JDK 动态代理类4.获取代理对象的工厂类5.实际使用 JDK 动态代理只能代理实现了接口的类CGLIB动态代理代码实现示例1.实现一个使用阿里云发送短信的类2.自定义 MethodInterceptor&#xff…

Java——猫猫图鉴微信小程序(前后端分离版)

目录 一、开源项目 二、项目来源 三、使用框架 四、小程序功能 1、用户功能 2、管理员功能 五、使用docker快速部署 六、更新信息 审核说明 一、开源项目 猫咪信息点-ruoyi-cat: 1、一直想做点项目进行学习与练手,所以做了一个对自己来说可以完成的…

HCIP:rip综合实验

实验要求: 【R1-R2-R3-R4-R5运行RIPV2】 【R6-R7运行RIPV1】 1.使用合理IP地址规划网络,各自创建环回接口 2.R1创建环回 172.16.1.1/24 172.16.2.1/24 172.16.3.1/24 3.要求R3使用R2访问R1环回 4.加快网络收敛,减少路由条目数量,增…

Unity中URP下的添加雾效支持

文章目录 前言一、URP下Shader支持雾效的步骤1、添加雾效变体2、在Varying结构体中添加雾效因子3、在顶点着色器中,我们使用内置函数得到雾效因子4、在片元着色器中,把输出颜色 和 雾效因子混合输出 二、在Unity中打开雾效三、测试代码 前言 我们使用之…