transfomer中Decoder和Encoder的base_layer的源码实现

简介

Encoder和Decoder共同组成transfomer,分别对应图中左右浅绿色框内的部分.
在这里插入图片描述
Encoder:
目的:将输入的特征图转换为一系列自注意力的输出。
工作原理:首先,通过卷积神经网络(CNN)提取输入图像的特征。然后,这些特征通过一系列自注意力的变换层进行处理,每个变换层都会将特征映射进行编码并产生一个新的特征映射。这个过程旨在捕捉图像中的空间和通道依赖关系。
作用:通过处理输入特征,提取图像特征并进行自注意力操作,为后续的目标检测任务提供必要的特征信息。
Decoder:
目的:接受Encoder的输出,并生成对目标类别和边界框的预测。
工作原理:首先,它接收Encoder的输出,然后使用一系列解码器层对目标对象之间的关系和全局图像上下文进行推理。这些解码器层将最终的目标类别和边界框的预测作为输出。
作用:基于Encoder的输出和全局上下文信息,生成目标类别和边界框的预测结果。
总结:Encoder就是特征提取类似卷积;Decoder用于生成box,类似head

源码实现:

Encoder 通常是6个encoder_layer组成,Decoder 通常是6个decoder_layer组成
我实现了核心的BaseTransformerLayer层,可以用来定义encoder_layer和decoder_layer

具体源码及其注释如下,配好环境可直接运行(运行依赖于上一个博客的代码):

import torch
from torch import nn
from ZMultiheadAttention import MultiheadAttention  # 来自上一次写的attensionclass FFN(nn.Module):def __init__(self,embed_dim=256,feedforward_channels=1024,act_cfg='ReLU',ffn_drop=0.,):super(FFN, self).__init__()self.l1 = nn.Linear(in_features=embed_dim, out_features=feedforward_channels)if act_cfg == 'ReLU':self.act1 = nn.ReLU(inplace=True)else:self.act1 = nn.SiLU(inplace=True)self.d1 = nn.Dropout(p=ffn_drop)self.l2 = nn.Linear(in_features=feedforward_channels, out_features=embed_dim)self.d2 = nn.Dropout(p=ffn_drop)def forward(self, x):tmp = self.d1(self.act1(self.l1(x)))tmp = self.d2(self.l2(tmp))x = tmp + xreturn x# transfomer encode和decode的最小循环单元,用于打包self_attention或者cross_attention
class BaseTransformerLayer(nn.Module):def __init__(self,attn_cfgs=[dict(embed_dim=64, num_heads=4), dict(embed_dim=64, num_heads=4)],fnn_cfg=dict(embed_dim=64, feedforward_channels=128, act_cfg='ReLU', ffn_drop=0.),operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')):super(BaseTransformerLayer, self).__init__()self.attentions = nn.ModuleList()# 搭建att层for attn_cfg in attn_cfgs:self.attentions.append(MultiheadAttention(**attn_cfg))self.embed_dims = self.attentions[0].embed_dim# 统计norm数量 并搭建self.norms = nn.ModuleList()num_norms = operation_order.count('norm')for _ in range(num_norms):self.norms.append(nn.LayerNorm(normalized_shape=self.embed_dims))# 统计ffn数量 并搭建self.ffns = nn.ModuleList()self.ffns.append(FFN(**fnn_cfg))self.operation_order = operation_orderdef forward(self, query, key=None, value=None, query_pos=None, key_pos=None):attn_index = 0norm_index = 0ffn_index = 0for order in self.operation_order:if order == 'self_attn':temp_key = temp_value = query  # 不用担心三个值一样,在attention里面会重映射qkvquery, attention = self.attentions[attn_index](query,temp_key,temp_value,query_pos=query_pos,key_pos=query_pos)attn_index += 1elif order == 'cross_attn':query, attention = self.attentions[attn_index](query,key,value,query_pos=query_pos,key_pos=key_pos)attn_index += 1elif order == 'norm':query = self.norms[norm_index](query)norm_index += 1elif order == 'ffn':query = self.ffns[ffn_index](query)ffn_index += 1return queryif __name__ == '__main__':query = torch.rand(size=(10, 2, 64))key = torch.rand(size=(5, 2, 64))value = torch.rand(size=(5, 2, 64))query_pos = torch.rand(size=(10, 2, 64))key_pos = torch.rand(size=(5, 2, 64))# encoder 通常是6个encoder_layer组成 每个encoder_layer['self_attn', 'norm', 'ffn', 'norm']encoder_layer = BaseTransformerLayer(attn_cfgs=[dict(embed_dim=64, num_heads=4)],fnn_cfg=dict(embed_dim=64, feedforward_channels=1024, act_cfg='ReLU',ffn_drop=0.),operation_order=('self_attn', 'norm', 'ffn', 'norm'))encoder_layer_output = encoder_layer(query=query, query_pos=query_pos, key_pos=key_pos)# decoder 通常是6个decoder_layer组成 每个decoder_layer['self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm']decoder_layer = BaseTransformerLayer(attn_cfgs=[dict(embed_dim=64, num_heads=4), dict(embed_dim=64, num_heads=4)],fnn_cfg=dict(embed_dim=64, feedforward_channels=1024, act_cfg='ReLU',ffn_drop=0.),operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'))decoder_layer_output = decoder_layer(query=query, key=key, value=value, query_pos=query_pos, key_pos=key_pos)pass

具体流程说明:

Encoder 通常是6个encoder_layer组成,每个encoder_layer[‘self_attn’, ‘norm’, ‘ffn’, ‘norm’]
Decoder 通常是6个decoder_layer组成,每个decoder_layer[‘self_attn’, ‘norm’, ‘cross_attn’, ‘norm’, ‘ffn’, ‘norm’]
按照以上方式搭建网络即可
其中norm为LayerNorm,在样本内部进行归一化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/625932.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

构建未来教育:在线培训系统开发的技术探讨

随着远程学习的崛起和数字化教育的普及,在线培训系统的开发成为了现代教育的核心。本文将深入讨论在线培训系统的关键技术要点,涵盖前后端开发、数据库管理、以及安全性和身份验证等关键方面。 前端开发:提供交互性与用户友好体验 在构建在…

京东ES支持ZSTD压缩算法上线了:高性能,低成本 | 京东云技术团队

1 前言 在《ElasticSearch降本增效常见的方法》一文中曾提到过zstd压缩算法[1],一步一个脚印我们终于在京东ES上线支持了zstd;我觉得促使目标完成主要以下几点原因: Elastic官方原因:zstd压缩算法没有在Elastic官方的开发计划中&…

最新智能AI系统ChatGPT网站程序源码+详细图文搭建部署教程,Midjourney绘画,GPT语音对话+ChatFile文档对话总结+DALL-E3文生图

一、前言 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作Ch…

如何增加服务器的高并发

随着互联网的快速发展和普及,越来越多的应用程序需要支持高并发的请求处理。在这种情况下增加服务器的高并发能力成为了一个热门的话题。下面简单的介绍如果提高服务器的高并发能力。 负载均衡 是把请求分发到多个服务器上,来实现请求的平衡和分担。负…

(一)环境部署

Python虚拟环境 安装virtualenv pip install virtualenv 创建环境 virtualenv -p D:\python\python.exe(python解释器目录) env-py3.6(虚拟环境目录,名称随意) 在当前目录下生成env-py3.6目录。 激活环境 ...\env-py3.6\Scripts> .\activate 关闭&#xf…

STM32 CubeIDE 使用 CMSIS-DAP烧录 (方法2--外部小工具)

前言: 本篇所用方法,需要借助一个外部的工具小软件。 优点:烧录更稳定; 缺点:不能在线仿真调试。 下面链接,是另一种方法:修改CubeIDE调试文件。能在CubeIDE直接烧录、仿真,但不稳定。…

Bazel

简介: Bazel 是 google 研发的一款开源构建和测试工具,也是一种简单、易读的构建工具。 Bazel 支持多种编程语言的项目,并针对多个平台构建输出。 高级构建语言:Bazel 使用一种抽象的、人类可读的语言在高语义级别上描述项目的构建属性。与其…

uniapp 简易自定义日历

1、组件代码 gy-calendar-self.vue <template><view class"calendar"><view class"selsct-date">请选择预约日期</view><!-- 日历头部&#xff0c;显示星期 --><view class"weekdays"><view v-for"…

Linux常用命令大全(三)

系统权限 用户组 1. 创建组groupadd 组名 2. 删除组groupdel 组名 3. 查找系统中的组cat /etc/group | grep -n “组名”说明&#xff1a;系统每个组信息都会被存放在/etc/group的文件中1. 创建用户useradd -g 组名 用户名 2. 设置密码passwd 用户名 3. 查找系统账户说明&am…

openssl快速生成自签名证书

系统&#xff1a;Centos 7.6 确保已安装openssl openssl version生成私钥文件 private.key &#xff08;文件名自定义&#xff09; openssl genpkey -algorithm RSA -out private.key -pkeyopt rsa_keygen_bits:2048-out private.key&#xff1a;生成的私钥文件-algorithm RS…

探索设计模式的魅力:工厂方法模式

工厂方法模式是一种创建型设计模式&#xff0c;它提供了一种创建对象的接口&#xff0c;但将具体实例化对象的工作推迟到子类中完成。这样做的目的是创建对象时不用依赖于具体的类&#xff0c;而是依赖于抽象&#xff0c;这提高了系统的灵活性和可扩展性。 以下是工厂方法模式的…

学习视频一些杂乱的东西

文章目录 ref获取dom元素监听深层的某个属性? 可选链操作符 和 ?? 双问号表达式v-slot 语法糖作用域插槽动态插槽 初始化数组骚操作数字滚动 -> gsapstyle妙招新奇的原型链 object.createB站笔记链接JS相关设计模式ajaxsvgvue3scsswebpack内存泄漏 ref获取dom元素 直接给…

基于深度学习的实例分割的Web应用

基于深度学习的实例分割的Web应用 1. 项目简介1.1 模型部署1.2 Web应用 2. Web前端开发3. Web后端开发4. 总结 1. 项目简介 这是一个基于深度学习的实例分割Web应用的项目介绍。该项目使用PaddlePaddle框架&#xff0c;并以PaddleSeg训练的图像分割模型为例。 1.1 模型部署 …

【iOS】数据存储方式总结(持久化)沙盒结构

在iOS开发中&#xff0c;我们经常性地需要存储一些状态和数据&#xff0c;比如用户对于App的相关设置、需要在本地缓存的数据等等&#xff0c;本篇文章将介绍六个主要的数据存储方式 iOS中数据存储方式&#xff08;数据持久化&#xff09; 根据要存储的数据大小、存储数据以及…

案例:应用内字体大小调节

文章目录 介绍相关概念完整实例 代码结构解读保存默认大小获取字体大小修改字体大小 介绍 本篇Codelab将介绍如何使用基础组件Slider&#xff0c;通过拖动滑块调节应用内字体大小。要求完成以下功能&#xff1a; 实现两个页面的UX&#xff1a;主页面和字体大小调节页面。拖动…

基于物联网设计的智能储物柜(4G+华为云IOT+微信小程序)

一、项目介绍 在游乐场、商场、景区等人流量较大的地方&#xff0c;往往存在用户需要临时存放物品的情况&#xff0c;例如行李箱、外套、购物袋等。为了满足用户的储物需求&#xff0c;并提供更加便捷的服务体验&#xff0c;当前设计了一款物联网智能储物柜。 该智能储物柜通…

git提交报错:remote: Please remove the file from history and try again.

1. 报错信息 remote: error: File: fba7046b22fd74b77425aa3e4eae0ea992d44998 500.28 MB, exceeds 100.00 MB. remote: Please remove the file from history and try again. git rev-list --objects --all | grep fba7046b22fd74b77425aa3e4eae0ea992d44998 2. 分析原因 e…

打架识别摄像机

随着社会治安问题的增加&#xff0c;打架事件在公共场所频繁发生&#xff0c;给社会治安带来了一定程度的威胁。因此&#xff0c;为了提高公共场所的安全性&#xff0c;可以利用现代科技&#xff0c;如人工智能和摄像技术&#xff0c;开发一种打架识别摄像机。 这种摄像机可以通…

基于 IDEA 进行 Maven 工程构建

一、构建概念和构建过程 项目构建是指将源代码、依赖库和资源文件等转换成可执行或可部署的应用程序的过程&#xff0c;在这个过程中包括编译源代码、链接依赖库、打包和部署等多个步骤。 项目构建是软件开发过程中至关重要的一部分&#xff0c;它能够大大提高软件开发效率&…

【Docker】CentOS stream 上安装 Docker 环境详细指南

文章目录 1. 定义2. 优势3. 安装1&#xff09;Linux 上安装&#xff08;强烈推荐&#xff09;2&#xff09;Windows 和 MAC 上安装 4. 验证1&#xff09;查看版本2&#xff09;运行 Hello World 总结 Docker 是一种轻量级的容器化技术&#xff0c;提供了一种在不同环境中快速、…