VIT总结

关于transformer、VIT和Swin T的总结

1.transformer

1.1.注意力机制

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.[1]
输入是query和 key-value,注意力机制首先计算query与每个key的关联性(compatibility)每个关联性作为每个value的权重(weight),各个权重与value的乘积相加得到输出

Attention Is All You Need 中用到的attention叫做“Scaled Dot-Product Attention”,具体过程如下图所示:
在这里插入图片描述
代码实现:

import torch
import torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embed size needs  to  be div by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0]  # the number of training examplesvalue_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split embedding into self.heads piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])# queries shape: (N, query_len, heads, heads_dim)# keys shape: (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))# Fills elements of self tensor with value where mask is Trueattention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim)# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, head_dim)# after einsum (N, query_len, heads, head_dim) then flatten last two dimensionsout = self.fc_out(out)return out

1.为什么有mask?
NLP处理不定长文本需要padding,但是padding的内容无意义,所以处理时需要mask.
2.关于qkv
qkv是相同的,需要查询的q,与每一个key相乘得到权重信息,权重与v相乘,这样结果受权重大的v影响
3.为什么除以根号dk

We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients 4. To counteract this effect, we scale the dot products by 1 √dk
点积过大,经过softmax,进入饱和区,梯度很小

4.为什么需要多头
在这里插入图片描述
不同头部的output就是从不同层面(representation subspace)考虑关联性而得到的输出。

1.2.TransformerBlock

解码端的后面两部分和编码段一样,所以打包成一个类
在这里插入图片描述

class TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion * embed_size),nn.ReLU(),nn.Linear(forward_expansion * embed_size, embed_size))self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return out

1.3.Encoder

关键的就是位置编码

class Encoder(nn.Module):def __init__(self,src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_lengh = x.shapepositions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))for layer in self.layers:out = layer(out, out, out, mask)return out

2.VIT

在这里插入图片描述

Reference:

[1].Attention Is All You Need
[2].https://zhuanlan.zhihu.com/p/366592542
[3].代码实现:https://zhuanlan.zhihu.com/p/653170203
[4].An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/190452.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv8优化策略:SENetV2,squeeze和excitation全面升级,效果优于SENet | 2023年11月最新成果

🚀🚀🚀本文改进: SENetV2,squeeze和excitation全面升级,作为注意力机制引入到YOLOv8,放入不同网络位置实现涨点 🚀🚀🚀YOLOv8改进专栏:http://t.csdnimg.cn/hGhVK 学姐带你学习YOLOv8,从入门到创新,轻轻松松搞定科研; 1.SENetV2 论文:https://arxiv.org/…

C#中GDI+图形图像绘制(直线、矩形、圆、椭圆、圆弧、扇形、多边形)

目录 一、直线 二、矩形 三、椭圆 四、圆 五、圆弧 六、扇形 七、多边形 八、示例源码 一、直线 调用Graphics类中的DrawLine()方法,结合Pen对象可以绘制直线。DrawLine()方法有以下两种构造函数。 第一种用于绘制一条连接两个Point结构的线。当参数pt1的值…

状态类算法复杂排序输出

对于目标检测任务中对某一类的检测结果进行输出的时候,一般都是无序的,很明显这样子很难满足的我们的需求,我们更喜欢他是这样子输出的: 👇 我们可以看到——”按顺序输出结果“中的字段是完美的和上面图片中的识别结…

大三上oracle数据库期末复习

1、创建表空间 2、创建用户 3、用户授权 oracle数据库逻辑存储结构: 1、表空间(最大的逻辑存储单元) 创建表空间 2、段 3、盘区(最小的磁盘空间分配单元) 4、数据块(最小的数据读写单元) 用…

thinkphp 5.1 对数据库查出来的字段进行预处理

比如数据库的设计是下面这样子&#xff1a; 我想展示的是这个样子&#xff1a; 前端可以处理。 Think PHP的处理方式&#xff1a; 定义属性 &#xff1a; $this->customize 任意值;//这里的之没有作用 <?phpnamespace app\hs\controller\shop;use app\daogou\mo…

分享4个工具,轻松搞定PDF和图像中提取文本

大型语言模型已经席卷了互联网&#xff0c;导致更多的人没有认真关注使用这些模型最重要的部分&#xff1a;高质量的数据&#xff01; 本文旨在提供一些有效从任何类型文档中提取文本的技术。 Python库 本文专注于Pytesseract、easyOCR、PyPDF2和LangChain库。实验数据是一个…

计算机网络TCP篇①

目录 一、TCP 基本信息 1.1、TCP 的头格式 1.2、什么是 TCP 1.3、什么是 TCP 连接 1.4、TCP 与 UDP 的区别 1.2、TCP 连接建立 1.2.1、TCP 三次握手的过程 1.2.2、为什么是三次握手&#xff1f;不是两次&#xff1f;四次&#xff1f;&#xff08;这个问题真是典中典&am…

深度学习实战63-利用自适应混合金字塔网络实现人脸皮肤美颜效果,快速部署与实现一键美颜功能

大家好,我是微学AI,今天给大家介绍一下深度学习实战63-利用自适应混合金字塔网络实现人脸皮肤美颜效果,快速部署与实现一键美颜功能。在本文中,我将介绍一种新颖的自适应混合金字塔网络(ABPN),该网络可以实现对超高分辨率照片的快速局部修饰。该网络主要由两个组件组成:一…

你知道Canary金丝雀版本的由来吗

Canary金丝雀版本是一种软件开发中常见的概念&#xff0c;它作为一种测试和试用版&#xff0c;旨在保护用户安全性和隐私&#xff0c;同时促进创新和改进。本文主要介绍Canary版本的由来。 随着技术的不断进步&#xff0c;软件开发变得越来越复杂且困难。为了满足用户需求并提…

【渗透】记录阿里云CentOS一次ddos攻击

文章目录 发现防御 发现 防御 流量清洗 使用高防

io基础入门

压缩的封装 参考&#xff1a;https://blog.csdn.net/qq_29897369/article/details/120407125?utm_mediumdistribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-0-120407125-blog-120163063.235v38pc_relevant_sort_base3&spm1001.2101.3001.…

【数据结构(五)】递归

文章目录 1. 递归的概念2. 递归能解决什么问题3. 递归的规则4. 递归实际应用案例4.1. 迷宫问题4.2. 八皇后问题4.2.1. 思路分析4.2.1. 代码实现 1. 递归的概念 简单的说: 递归就是方法自己调用自己&#xff0c;每次调用时传入不同的变量。递归有助于编程者解决复杂的问题&…

数据结构 - 堆:TOP-K问题

问题描述 TOP-K问题&#xff1a;即求数据结合中前K个最大的元素或者最小的元素&#xff0c;一般情况下数据量都比较大 比如&#xff1a;专业前10名、世界500强、富豪榜、游戏中前100的活跃玩家等 对于Top-K问题&#xff0c;能想到的最简单直接的方式就是排序&#xff0c;但是&…

Linux部署elasticsearch集群

文章目录 一、集群规划二、安装前准备(所有节点操作)创建数据目录修改系统配置文件/etc/sysctl.conf创建用户组设置limits.conf 三、初始化配置(在节点1上操作)下载安装包解压安装包修改jvm.options文件下配置的所占内存修改集群配置文件elasticsearch.yml将安装包传到另外两个…

00后卷王真的很卷吗?

前言 都在传00后躺平、整顿职场&#xff0c;但该说不说&#xff0c;是真的卷&#xff0c;感觉我都要被卷废了... 前段时间&#xff0c;公司招了一个年轻人&#xff0c;其中有一个是00后&#xff0c;工作才一年多&#xff0c;直接跳槽到我们公司&#xff0c;薪资据说有18K&…

Linux学习——模拟实现mybash小程序

目录 一&#xff0c;跟正宗的bash见个面 二&#xff0c;实现一个山寨的bash 1.提示符 2.输入命令与回显命令 3.解析命令 4.执行命令 5.执行逻辑 三&#xff0c;全部代码 一&#xff0c;跟正宗的bash见个面 在这篇文章中&#xff0c;我会写一个myshell小程序。这个小程序…

logback-spring.xml详解

《springboot使用logback日志框架超详细教程》文中&#xff0c;filter中最重要的两个过滤器LevelFilter&#xff08;日志级别精确匹配&#xff09;、ThresholdFilter&#xff08;阈值过滤&#xff09; 的描述非常准确&#xff1a; springboot使用logback日志框架超详细教程_sp…

SQL Server数据库部署

数据库简介 使用数据库的必要性 使用数据库可以高效且条理分明地存储数据&#xff0c;使人们能够更加迅速、方便地管理数据。数据库 具有以下特点。 》可以结构化存储大量的数据信息&#xff0c;方便用户进行有效的检索和访问。 》 可以有效地保持数据信息的一致性&#xff0c…

【Casbin】一篇文章入门Casbin

Casbin Casbin模型基础&#xff08;PERM&#xff09;Policy定义Request定义MatchersEffect ACL模型RBAC模型Go语言实战使用前先下载casbin包新建一个Casbin enforcer判断是否能通过增加Policy删除Policy更新Policy获取Policy Casbin 权限管理在几乎每个系统中都是必备的模块。…

java设计模式学习之【桥接模式】

文章目录 引言桥接模式简介定义与用途&#xff1a;实现方式 使用场景优势与劣势桥接模式在Spring中的应用绘图示例代码地址 引言 想象你正在开发一个图形界面应用程序&#xff0c;需要支持多种不同的窗口操作系统。如果每个系统都需要写一套代码&#xff0c;那将是多么繁琐&am…