Vision Transfomer系列第一节---从0到1的源码实现

本专栏主要是深度学习/自动驾驶相关的源码实现,获取全套代码请参考

这里写目录标题

  • 准备
  • 逐步源码实现
    • 数据集读取
    • VIt模型搭建
      • hand
      • 类别和位置编码
          • 类别编码
          • 位置编码
      • blocks
      • head
      • VIT整体
    • Runner(参考mmlab)
    • 可视化
  • 总结

准备

在这里插入图片描述
本博客完成Vision Transfomer(VIT)模型的搭建和flowers数据集的训练测试.整个源码包括如下几个任务:
1.读取flowers数据集的dataset类,对应文件dataset.py
2.VIT模型搭建,主要依赖于上几篇博客,对应model.py

1.transfomer中Multi-Head Attention的源码实现的MultiheadAttention类,用于搭建BaseTransformerLayer类,实现encoder和decoder功能
2.transfomer中Decoder和Encoder的base_layer的源码实现的BaseTransformerLayer类,帮助我们丝滑地搭建各类transformer网络
3.transfomer中正余弦位置编码的源码实现[可选]

3.设置优化器学习率和训练/验证模型,对应runner.py和train.py
4.可视化测试单个图片的预测结果,对应demo.py

逐步源码实现

源码结构如下
在这里插入图片描述

数据集读取

主要原理:根据dataset的路径,存储各个图片对应的路径,label隐藏在路径中.
在getitem函数中完成指定index图片和label的读取和数据增强功能

class Flowers(Dataset):# 用于读取flower数据集def __init__(self, dataset_path: str, transforms=None):'''存储所有数据 data路径和label:param dataset_path:'''super(Dataset, self).__init__()flowers = os.listdir(dataset_path)flowers = sorted(flowers) # 必须排序,否在每一次顺序不一样训练测试类别就会乱self.flower_paths = []self.class2label = {}  # 类别str 转 labellabel = 0for _, flower in enumerate(flowers):flowers_path = os.path.join(dataset_path, flower)if os.path.isdir(flowers_path):self.class2label[flower] = labellabel +=1sub_flowers = os.listdir(flowers_path)for sub_flower in sub_flowers:self.flower_paths.append(os.path.join(flowers_path, sub_flower))self.label2class = label2class(self.class2label)  # label 转 类别strself.transforms = transforms''''''def __getitem__(self, item):# 读取数据和labelimg = Image.open(self.flower_paths[item])label = self.class2label[self.flower_paths[item].split('/')[-2]]if self.transforms is not None:img = self.transforms(img)  # 数据增强return img, label

VIt模型搭建

将整个深度学习模型按照人体分为hand+backbone+neck+head 4个部分,Vit模型不同CNN模型,它的backbone+neck为多个MultiHeadAttention堆叠组成,称之为blocks.

hand

hand主用完成预处理,将数据用"手"揉捏成想要的类型.本处主要完成图片的patch操作,将图片分割成一个个小块,使用大核的卷积完成.然后把w和h拉平后shape就和NLP(b,n,d)一样了.

class PatchLayer(nn.Module):def __init__(self, img_size, patch_size=20, embeding_dim=64):super(PatchLayer, self).__init__()self.grid_size = (img_size[0] // patch_size, img_size[1] // patch_size)self.num_patches = self.grid_size[0] * self.grid_size[1]self.proj = nn.Conv2d(in_channels=3,out_channels=embeding_dim,kernel_size=(patch_size, patch_size),stride=patch_size,padding=0)self.norm = nn.LayerNorm(normalized_shape=embeding_dim)def forward(self, img):img = self.proj(img)  # 图片分割img = img.flatten(start_dim=2)  # wh拉平img = img.permute(0, 2, 1)  # [b wh c]img = self.norm(img)return img

类别和位置编码

类别编码

直接cat到input上面,那么最后也取出对应的那一列作为类别输出.这是transformer类型网络的常用手段.
个人解释:训练出类别的访问者,这个访问者可以从特征信息(原input)中提取类别信息.训练访问者方法就是类别loss回归,训练时候先推出,推理时推出

位置编码

add到input上,可以使用可学习式的位置编码也可以使用正余弦位置编码.这是transformer类型网络的常用手段,还要特征层编码等
个人解释:训练出位置的标记者

        # 类别编码self.cls_token = nn.Parameter(torch.zeros(size=[1, 1, embed_dim]))# 固定位置编码和可学习位置编码# self.pos_embed = posemb_sincos_1d(len=num_patches + 1, dim=embed_dim,temperature=1000).unsqueeze(0)self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

blocks

blocks使用注意力机制完成特征提取,
个人解释:
input线性映射为[query,key,value],需求侧(query)从供给侧(value)中取值,取值的根据是qurey@key转置生成的注意力矩阵(需求侧和供给侧每个像素之间的相似度),最后输出与输入shape相同.所以我们重复depth次,多次特征提取.
源码直接调用:transfomer中Decoder和Encoder的base_layer的源码实现的BaseTransformerLayer类

head

主要对transfomer输出的类别特征进行映射,embed维度映射为num_class维度

self.head = nn.Linear(embed_dim, num_classes)

VIT整体

主要是上述几个模块的集合及其正向传播过程:
完成二维图片变一维特征,一维特征transfomer特征提取,分类头输出.

class Vit(nn.Module):def __init__(self, img_size=[224, 224], patch_size=16, num_classes=1000,embed_dim=768, depth=12, num_heads=12):super(Vit, self).__init__()self.patch_embed = PatchLayer(img_size, patch_size, embed_dim)num_patches = self.patch_embed.num_patchesself.blocks = nn.Sequential(*[BaseTransformerLayer(attn_cfgs=[dict(embed_dim=embed_dim, num_heads=num_heads)],fnn_cfg=dict(embed_dim=embed_dim, feedforward_channels=4 * embed_dim, act_cfg='ReLU',ffn_drop=0.),operation_order=('self_attn', 'norm', 'ffn', 'norm'))for _ in range(depth)])# 类别编码self.cls_token = nn.Parameter(torch.zeros(size=[1, 1, embed_dim]))# 固定位置编码和可学习位置编码# self.pos_embed = posemb_sincos_1d(len=num_patches + 1, dim=embed_dim,temperature=1000).unsqueeze(0)self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))# 分类头self.head = nn.Linear(embed_dim, num_classes)self.loss_class = nn.CrossEntropyLoss()  # 内置softmaxself.init_weights()''''''def forward(self, img):query = self.hand(img)query = self.extract_feature(query)cls_fea = query[:, -1, :]  # 刚刚class_token被cat到了dim1的最后一个数x = self.head(cls_fea)return x

Runner(参考mmlab)

建立优化前,设置学习率,根据指定的work_flow顺序进行训练的测试,并保留最优权重

class Runner:def __init__(self, arg, model, device):self.arg = arg# 建立优化器params = [p for p in model.parameters() if p.requires_grad]self.optimizer = torch.optim.SGD(params=params, lr=arg.lr, momentum=0.9, weight_decay=5E-5)lf = lambda x: ((1 + math.cos(x * math.pi / arg.epochs)) / 2) * (1 - arg.lrf) + arg.lrf  # cosineself.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lf)self.model = model.to(device)self.device = deviceif arg.load_from is not None and arg.load_from != '':weight_dict = torch.load(arg.load_from, map_location=device)model.load_state_dict(weight_dict)def run(self, dataloaders: dict):# 开始训练和验证assert 'train' in self.arg.work_flow.keys(), '必须要用训练任务'epoch_start = 0best_accuracy = 0.0while epoch_start < self.arg.epochs:for task, times in self.arg.work_flow.items():if task == 'train':  # 开始训练for _ in range(times):epoch_start += 1  # epoch只记录训练轮self.model.train()loss_sum = 0.0data_loader = tqdm(dataloaders['train'], file=sys.stdout)for step, data_dict in enumerate(data_loader):img, label = data_dictinstance = {'data': img.to(self.device),'label': label.to(self.device)}loss = self.model.loss(**instance)loss_sum += loss.detach()  # 要十分注意 避免往计算图中引入新的东西loss.backward()self.optimizer.step()self.optimizer.zero_grad()data_loader.desc = "[train epoch {}] loss: {:.3f}".\format(epoch_start,loss_sum.item() / (step + 1))self.scheduler.step()print('train: epoch={}, loss={}'.format(epoch_start, loss_sum / (step + 1.0)))elif task == 'val':  # 开始验证''''''else:raise ValueError('task must be in [train, val, test]')

可视化

读取单张图片,转换格式输入模型,输出的label,转化为class名和置信度,显示图像,class名和置信度.

if __name__ == '__main__':# 建立数据集data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])img = Image.open('*****daisy/21652746_cc379e0eea_m.jpg')input = data_transform(img).unsqueeze(0)label2class = Flowers(dataset_path='../datasets/flower_photos-mini').label2classdevice = torch.device('cuda:0')# 建立模型model = Vit(img_size=[224, 224],patch_size=16,embed_dim=768,depth=12,num_heads=12,num_classes=5).to(device)weight_dict = torch.load('weights/vit.pth', map_location=device)model.load_state_dict(weight_dict)model.eval()with torch.no_grad():output = model(input.to(device))output = output.detach().cpu()label = output[0].numpy().argmax()cnf = torch.softmax(output[0],dim=0).numpy().max()*100.0cnf = np.around(cnf, decimals=2) #保留2位小数plt.imshow(img)plt.title('{} : {}%'.format(label2class[label],cnf))plt.show()

总结

vit是视觉transfomer最经典的模型,复现一次代码十分有必要,中间会产生很多思考和问题.
后面章节将会更有价值,我将会:

1.利用本次的代码进行很多思考和trick的验证
2.总结本次代码的BUG们,及其产生的原理和解决方法

如需获取全套代码请参考

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/671077.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

股票K线简介

股票K线&#xff08;K-Line&#xff09;是用于表示股票价格走势的图形&#xff0c;主要由四个关键价格点组成&#xff1a;开盘价、收盘价、最高价和最低价。K线图广泛应用于股票市场技术分析中&#xff0c;它提供了丰富的信息&#xff0c;帮助分析师和投资者理解市场的行情走势…

一周学会Django5 Python Web开发-Django5介绍及安装

锋哥原创的Python Web开发 Django5视频教程&#xff1a; 2024版 Django5 Python web开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 Django5 Python web开发 视频教程(无废话版) 玩命更新中~共计10条视频&#xff0c;包括&#xff1a;2024版 Django5 Python we…

算法练习-四数之和(思路+流程图+代码)

难度参考 难度&#xff1a;中等 分类&#xff1a;数组 难度与分类由我所参与的培训课程提供&#xff0c;但需要注意的是&#xff0c;难度与分类仅供参考。且所在课程未提供测试平台&#xff0c;故实现代码主要为自行测试的那种&#xff0c;以下内容均为个人笔记&#xff0c;旨在…

“过年买年货,花光了我的年终奖”

【潮汐商业评论/原创】 前脚刚进家门&#xff0c;后脚快递电话一个接着一个。 临近春节&#xff0c;Julia是提前批回家的年轻人&#xff0c;与Julia一同到家的还有她的年货。上至大小家电&#xff0c;下到坚果零食&#xff0c;短短几天快递就堆满了客厅。 吃的喝的还能理解&…

MySQL的优化

优化MySQL的几点&#xff1a; 从设计上优化 从查询上优化 从索引上优化 从存储上优化 1&#xff0c;SQL的执行频率 MySQL客户端连接成功后&#xff0c;通过show [session/global] status命令可以查看服务器状态信息。通过查看状态信息可以查看对当前数据库的主要操作类型…

【01】判断素数/质数(C语言)

目录 &#xff08;1&#xff09;素数特点&#xff1a;只能被1和本身整除 &#xff08;2&#xff09;代码如下&#xff1a; &#xff08;3&#xff09;运行结果如下 ​编辑 &#xff08;4&#xff09;函数引申 &#xff08;1&#xff09;素数特点&#xff1a;只能被1和本身…

数字图像处理(实践篇)四十三 OpenCV-Python 使用SURF算法检测图像上的特征点的实践

目录 一 SURF算法概述 1 积分图 2 SURF算法 3 SIFT与SURF 二 涉及的函数 三 实践 一 SURF算法概述

在容器中使用buildah构建镜像

简介 buildah是一个构建OCI标准镜像的工具&#xff0c;可以用来替代docker build 在常见的linux发行版中可直接通过包管理工具安装使用 # centos yum install buildah# ubuntu/debian apt install buildah# alpine apk add buildah其他发行版安装方法详见 github&#xff0c…

Django通过Json配置文件分配多个定时任务

def load_config():with open("rule.json", rb)as f:config json.load(f)return configdef job(task_name, config, time_interval):# ... 通过task_name判断进行操作if task_name get_data_times:passdef main():config load_config()for task_name, task_value…

C++——stack与queue与容器适配器

1.stack和queue的使用 1.1stack的使用 栈这种数据结构我们应该挺熟了&#xff0c;先入后出&#xff0c;只有一个出口(出口靠栈顶近)嘛 stack的底层容器可以是任何标准的容器类模板或者一些其他特定的容器类&#xff0c;这些容器类应该支持以操作&#xff1a; empty&#xff1…

专业知识库:中小型企业必备的高效工具

在如今这个信息爆炸的时代&#xff0c;知识管理已经成为了企业运营的重要环节。特别是对于中小型企业来说&#xff0c;如何有效地管理公司内部的知识&#xff0c;提高工作效率&#xff0c;已经成为了一个亟待解决的问题。在这篇文章中&#xff0c;我将向大家介绍一种能够帮助企…

Python轴承故障诊断入门教学

目录 往期精彩内容&#xff1a; 1 工作室实验平台介绍 2 轴承故障诊断教程—数据集 3 轴承故障诊断教程—算法模型 3.1 振动分析方法 3.2 频域特征提取 3.3 时域特征提取 3.4 模型基础的机器学习方法 3.5 深度学习方法 3.6 时频域融合方法 3.7 信号重构方法 3.8 基…

Linux-----文本三剑客补充~

一、模糊匹配 模糊匹配用 ~ 表示包含&#xff0c;!~表示不包含 1、匹配含有root的列 [rootlocalhost ~]#awk -F: /root/ /etc/passwd root:x:0:0:root:/root:/bin/bash operator:x:11:0:operator:/root:/sbin/nologin [rootlocalhost ~]#awk -F: $1~ /root/ /etc/passw…

知名开发工具RubyMine全新发布v2023.3——支持AI Assistant

RubyMine 是一个为Ruby 和 Rails开发者准备的 IDE&#xff0c;其带有所有开发者必须的功能&#xff0c;并将之紧密集成于便捷的开发环境中。 RubyMine v2023.3正式版下载 新版本改进AI Assistant支持、Rails应用程序和引擎的自定义路径、对Rails 7.1严格locals的代码洞察、RB…

人胰岛素样生长因子-1 ELISA试剂盒IGF-1 (human), ELISA kit

高灵敏ELISA试剂盒&#xff0c;4小时内可得结果&#xff0c;最低可检测34.2 pg/ml的IGF-1 胰岛素样生长因子-1&#xff08;IGF-1&#xff09;是一种多肽激素&#xff0c;在结构上与胰岛素相似。它参与调节中枢和周围神经系统的神经元生长和发育。IGF-1是一种有效的神经元凋亡抑…

【Zookeeper】what is Zookeeper?

官网地址&#xff1a;https://zookeeper.apache.org/https://zookeeper.apache.org/ 以下来自官网的介绍 ZooKeeper is a centralized service for maintaining configuration information, naming, providing distributed synchronization, and providing group services. A…

机试复习-3

前言&#xff1a;前面耽误太多时间&#xff0c;2月份是代码月&#xff0c;一定抓紧赶上&#xff0c;每天至少两道题 day1 2024.2.6 1.排序开启&#xff1a; 1.机试考试&#xff1a;排序应用考察 c的qsort c的sort 作用&#xff1a;对数组&#xff0c;vector排序&#…

c#读取csv文件中的某一列的数据

chat8 (chat779.com) 上面试GPT-3.5,很好的浏览网站&#xff0c;输入问题&#xff0c;可得到答案。 问题1&#xff1a;c#如何在csv中读取某一列数据 解答方案&#xff1a;在 C#中&#xff0c;你可以使用File.ReadAllLines来读取CSV中的所有行&#xff0c;然后逐行解析每一行…

flask+vue+python跨区通勤人员健康体检预约管理系统

跨区通勤人员健康管理系统设计的目的是为用户提供体检项目等功能。 与其它应用程序相比&#xff0c;跨区通勤人员健康的设计主要面向于跨区通勤人员&#xff0c;旨在为管理员和用户提供一个跨区通勤人员健康管理系统。用户可以通过系统及时查看体检预约等。 跨区通勤人员健康管…

Linux 分析指定JAVA服务进程所占内存CPU详情

1、获取服务进程PID [rootVM-32-26-centos ~]# service be3Service status Application is running as root (UID 0). This is considered insecure. Running [25383]2、获取进程占用详情 [rootVM-32-26-centos ~]# cat /proc/25383/status Name: java Umask: 0022 State: S…