Transformer实战-系列教程4:Vision Transformer 源码解读2

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

4、Embbeding类

self.embeddings = Embeddings(config, img_size=img_size)
class Embeddings(nn.Module):"""Construct the embeddings from patch, position embeddings."""def __init__(self, config, img_size, in_channels=3):super(Embeddings, self).__init__()self.hybrid = Noneimg_size = _pair(img_size)if config.patches.get("grid") is not None:grid_size = config.patches["grid"]patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])n_patches = (img_size[0] // 16) * (img_size[1] // 16)self.hybrid = Trueelse:patch_size = _pair(config.patches["size"])n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])self.hybrid = Falseif self.hybrid:self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,width_factor=config.resnet.width_factor)in_channels = self.hybrid_model.width * 16self.patch_embeddings = Conv2d(in_channels=in_channels,out_channels=config.hidden_size,kernel_size=patch_size,stride=patch_size)self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))self.dropout = Dropout(config.transformer["dropout_rate"])def forward(self, x):# print(x.shape)B = x.shape[0]cls_tokens = self.cls_token.expand(B, -1, -1)# print(cls_tokens.shape)if self.hybrid:x = self.hybrid_model(x)x = self.patch_embeddings(x)#Conv2d: Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))# print(x.shape)x = x.flatten(2)# print(x.shape)x = x.transpose(-1, -2)# print(x.shape)x = torch.cat((cls_tokens, x), dim=1)# print(x.shape)embeddings = x + self.position_embeddings# print(embeddings.shape)embeddings = self.dropout(embeddings)# print(embeddings.shape)return embeddings

接上前面的debug模式,在构造模型部分一直步入到Embbeding类中:

  1. 构造函数,传入了图像大小224*224,通道数3,以及配置参数
  2. patch_size=[16,16],16*16的区域选出一份特征,这个参数自己定义
  3. n_patches,224224的图像能够切分出1616的格子数量,(224/16)(224/16)=1414=196个
  4. 196就是我们要定义的序列的长度了
  5. patch_embeddings,是一个二维卷积,输入通道为3,输出通道为768,卷积核为patch_size=1616,步长为1616,步长为1616就表明原本224224的图像卷积后的长宽就为14*14了
  6. position_embeddings,初始化参数全部为0 ,形状为[1,197,768],197=196+1,加一的原因是在Transformer模型中,通常会在序列的开始添加一个可学习的类标记(class token),它在训练过程中帮助模型捕获全局信息以用于分类任务。position_embeddings是用来记录位置信息的
  7. cls_token,初始化参数全部为0,形状为[1,1,768]
  8. 因为要涉及到全连接层,所以加上Dropout

5、Encoder类

self.encoder = Encoder(config, vis)
class Encoder(nn.Module):def __init__(self, config, vis):super(Encoder, self).__init__()self.vis = visself.layer = nn.ModuleList()self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)for _ in range(config.transformer["num_layers"]):layer = Block(config, vis)self.layer.append(copy.deepcopy(layer))def forward(self, hidden_states):# print(hidden_states.shape)attn_weights = []for layer_block in self.layer:hidden_states, weights = layer_block(hidden_states)if self.vis:attn_weights.append(weights)encoded = self.encoder_norm(hidden_states)return encoded, attn_weights

接上前面的debug模式,在构造模型部分步入到Encoder类中:

  1. 构造函数传进配置参数
  2. vis,设置可视化
  3. layer,设置PyTorch的一个列表
  4. encoder_norm,LayerNorm,Batch Normalization是对Batch做归一化,LayerNorm对层
  5. 循环添加Block:循环config.transformer["num_layers"]次,每次都创建一个Block实例并添加到self.layer中。这里的Block是一个定义了Transformer编码器层的类,它包括自注意力机制和前馈网络。copy.deepcopy(layer)确保每次都是向ModuleList添加一个新的、独立的Block副本

之前ConvNet的任务中,都是使用Batch 做归一化,为什么Transformer是对Layer做归一化呢,Transformer是在NLP任务中提出来的,每一句话的单词个数都不一样,太长的阶段,短的补0,如果是对batch做归一化,长句子的后面一些地方要和短句子补0的地方做归一化,改用Layer归一化实现显著提升效果的情况。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/665800.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(已解决)spingboot 后端发送QQ邮箱验证码

打开QQ邮箱pop3请求服务&#xff1a;&#xff08;按照QQ邮箱引导操作&#xff09; 导入依赖&#xff08;不是maven项目就自己添加jar包&#xff09;&#xff1a; <!-- 邮件发送--><dependency><groupId>org.springframework.boot</groupId><…

vite打包原理

vite 工程化开发&#xff1a;打包工具 启动速度很快 核心原理还是webpack 把webpack封装了&#xff0c;把webpack对象封装了 和vue2整体结构几乎一致 webpack两种模式&#xff1a;开发&生产 代码打包编译&#xff0c;本地起一个web服务器实时预览编译后的结果 build 命令模…

2024.2.3

单向循环链表的头插 头删 尾插和尾删 //头结点插入 Linklist insere_element(Linklist head,datatype element) {Linklist screat();s->dataelement;if(NULLhead){heads;}else{Linklist phead;while(p->next!head){pp->next;}s->nexthead;heads;p->nexthead;}r…

太强了,AI数字人从制作到变现一次搞定

AI数字人从制作到变现 如果说GPT类大模型是我们人类的第二大脑&#xff0c;数字人就是我们人类在互联网上的第二个身体。随着 AI 的迅速发展&#xff0c;2024 年 AI 模型开始从大型语言模型向大型视觉模型转变。数字人技术作为其分支之一&#xff0c;正日益成为科技、娱乐、教…

Unity项目从built-in升级到URP(包含早期版本和2023版本)

unity不同版本的升级URP的方式不一样&#xff0c;但是大体流程是相似的 首先是加载URP包 Windows -> package manager,在unity registry中找到Universal RP 2023版本&#xff1a; 更早的版本&#xff1a; 创建URP资源和渲染器​​ 有些版本在加载时会自动创建&#…

ProcessSlot构建流程分析

ProcessorSlot ProcessorSlot构建流程 // com.alibaba.csp.sentinel.CtSph#lookProcessChain private Entry entryWithPriority(ResourceWrapper resourceWrapper, int count, boolean prioritized, Object... args)throws BlockException {// 省略创建 Context 的代码// 黑盒…

Optimizer:基于.Net开发的、提升Windows系统性能的终极开源工具

我们电脑使用久了后&#xff0c;就会产生大量的垃圾文件、无用的配置等&#xff0c;手动删除非常麻烦&#xff0c;今天推荐一个开源工具&#xff0c;可以快速帮助我们更好的优化Windos电脑。 01 项目简介 Optimizer是一个面向Windows系统的优化工具&#xff0c;旨在提升计算机…

Qt应用软件【数据篇】大小端数据转换

文章目录 大小端数据介绍大小端数据在内存中的样子C大小端数据转换QtAPI大小端转换 大小端数据介绍 大端&#xff08;Big Endian&#xff09;和小端&#xff08;Little Endian&#xff09;是一种描述计算机存储多字节数据的方式。 想象一下&#xff0c;你有一串数字&#xff0c…

vulhub中spring的CVE-2022-22947漏洞复现

Spring Cloud Gateway是Spring中的一个API网关。其3.1.0及3.0.6版本&#xff08;包含&#xff09;以前存在一处SpEL表达式注入漏洞&#xff0c;当攻击者可以访问Actuator API的情况下&#xff0c;将可以利用该漏洞执行任意命令。 参考链接&#xff1a; https://tanzu.vmware.c…

【用Unity开发一款横板跳跃游戏部分需要学习的技术点指南】

*** 用Unity开发一款横板跳跃游戏部分需要学习的技术点指南 空洞骑士是一款基于横板平台跳跃的传统风格2D动作冒险游戏&#xff0c;庞大的游戏世界交错相通&#xff0c;玩家控制小虫子去探索幽深黑暗的洞穴&#xff0c;成为了一代人茶余饭后的惦念&#xff0c;深受玩家喜爱。 …

类银河恶魔城学习记录1-6 Flip基本设置源代码 P33

Alex教程每一P的教程原代码加上我自己的理解初步理解写的注释&#xff0c;可供学习Alex教程的人参考 【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili Player.cs using System.Collections; using System.Collections.Generic; using Unity.VisualScripting; us…

JAVA 反射和动态管理(二十-完)

反射和动态管理&#xff08;二十-完&#xff09; 反射 反射允许对字段&#xff08;成员变量&#xff09;&#xff0c;成员方法&#xff0c;构造方法的信息进行编程访问。 反射操作可分为获取和解刨。 获取不是从java文件获取&#xff0c;而是从class字节码文件获取。 作用…

MySQL全表扫描:性能杀手的隐患与优化策略

MySQL全表扫描&#xff1a;性能杀手的隐患与优化策略 MySQL数据库作为常用的关系型数据库管理系统之一&#xff0c;全表扫描问题一直困扰着开发者。本文将深入剖析MySQL全表扫描的原理、其对性能的严重影响&#xff0c;同时提供一系列优化策略&#xff0c;助您高效应对MySQL性能…

【NodeJS】fs 模块 (2)

流式文件写入 & 读取 流式文件写入 / 读取适合操作大文件 流式写入 ① 创建可写流&#xff1a;fs.createWriteStream(path[, options]) path&#xff1a;文件路径options&#xff1a;配置对象 flags&#xff1a;文件系统标志&#xff0c;默认值为 wencoding&#xff1a;…

Android battery saver 简单记录

目录 一. battery saver模式的policy (1) DEFAULT_FULL_POLICY 对应的配置和解释: (2) OFF_POLICY 对应的配置也就说不使用policy (3) 获取省电模式下的policy: 二. 对各个参数代码讲解 (1) adjustBrightnessFactor (2) enableAdjustBrightness (3) advertiseIsEnabled…

ctfshow——文件包含

文章目录 web 78——php伪协议第一种方法——php://input第二种方法——data://text/plain第三种方法——远程包含&#xff08;http://协议&#xff09; web 78——str_replace过滤字符php第一种方法——远程包含&#xff08;http://协议&#xff09;第二种方法——data://&…

070:vue中provide、inject的使用方法(图文示例)

第070个 查看专栏目录: VUE 本文章目录 示例背景示例效果图示例源代码父组件代码子组件代码孙组件代码 基本使用步骤 示例背景 本教程是介绍如何在vue中使用provide和inject。在 Vue 中&#xff0c;provide 和 inject 是用于实现祖先组件向后代组件传递数据的一种方式。 在这个…

oracle 触发器事前触发和事后触发区别

Oracle触发器的事前触发和事后触发主要在触发的时机和触发器内部的操作上有所区别。 触发时机&#xff1a;事前触发器是在触发事件发生之前运行&#xff0c;而事后触发器则在触发事件发生之后运行。 获取的数据&#xff1a;事前触发器通常可以获取到事件发生前和新的字段值。O…

Docker存储空间清理

不知不觉服务器存储空间被Docker掏空了… 查看Docker空间占用情况 使用docker system df命令&#xff0c;可以加 -v 查看详情 清理Docker不需要的内容 使用docker system prune -a命令清理Docker 所有停止的容器所有没有被使用的networks所有没容器的镜像所有build cache …

公共用例库计划--个人版(六)典型Bug页面设计与开发

1、任务概述 本次计划的核心任务是开发一个&#xff0c;个人版的公共用例库&#xff0c;旨在将各系统和各类测试场景下的通用、基础以及关键功能的测试用例进行系统性地归纳整理&#xff0c;并以提高用例的复用率为目标&#xff0c;力求最大限度地减少重复劳动&#xff0c;提升…