Google的MLP-MIXer的复现(pytorch实现)

Google的MLP-MIXer的复现(pytorch实现)

该模型原论文实现用的jax框架实现,先贴出原论文的代码实现:

# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.from typing import Any, Optionalimport einops
import flax.linen as nn
import jax
import jax.numpy as jnpclass MlpBlock(nn.Module):mlp_dim: int@nn.compactdef __call__(self, x):y = nn.Dense(self.mlp_dim)(x)y = nn.gelu(y)return nn.Dense(x.shape[-1])(y)class MixerBlock(nn.Module):"""Mixer block layer."""tokens_mlp_dim: intchannels_mlp_dim: int@nn.compactdef __call__(self, x):y = nn.LayerNorm()(x)y = jnp.swapaxes(y, 1, 2)y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y) #  (32, 512, 196)y = jnp.swapaxes(y, 1, 2)x = x + yy = nn.LayerNorm()(x)return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)class MlpMixer(nn.Module):"""Mixer architecture."""patches: Anynum_classes: intnum_blocks: inthidden_dim: inttokens_mlp_dim: intchannels_mlp_dim: intmodel_name: Optional[str] = None@nn.compactdef __call__(self, inputs, *, train):del trainx = nn.Conv(self.hidden_dim, self.patches.size,strides=self.patches.size, name='stem')(inputs)x = einops.rearrange(x, 'n h w c -> n (h w) c')  # 从(32,512,14,14)变成了(32,196,512)for _ in range(self.num_blocks):x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)x = nn.LayerNorm(name='pre_head_layer_norm')(x)x = jnp.mean(x, axis=1)if self.num_classes:x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,name='head')(x)return xmodel_params = {'patches': {'size': (16, 16), 'stride': (16, 16)}, # 这里需要一个描述patch大小和步长的对象,例如Flax的stem模块初始化参数'num_classes': 10,  # 分类任务的类别数'num_blocks': 8,  # Mixer Block的重复次数'hidden_dim': 512,  # 隐藏层维度'tokens_mlp_dim': 256,  # token mixing的MLP维度'channels_mlp_dim': 2048,  # channel mixing的MLP维度
}# 准备输入数据,例如一批32张图片,每张图片尺寸为512x14x14(假设已经按要求预处理)# 初始化模型
seed=0
key = jax.random.PRNGKey(seed)
model = MlpMixer.apply(key, **model_params)input_data = jnp.ones((4096, 224, 224, 3))  # 示例输入数据
# 调用模型进行前向传播
output = model(input_data)print("Output shape:", output)  # 打印输出形状,预期是(32, 10)如果num_classes=10

该模型的总体框架图如下所示:

在这里插入图片描述

对该框架的讲解,网上已经很多了,就不在此赘述。

实现的pytorch代码如下所示:

class MlpBlock(nn.Module):def __init__(self, in_mlp_dim=196, out_mlp_dim=256):super(MlpBlock, self).__init__()self.mlp_dim = out_mlp_dimself.dense1 = nn.Linear(in_mlp_dim, out_mlp_dim)  # 若输入的向量为[32,196, 512]则输入的也应该是512,输出可以自己定self.gelu = nn.GELU()self.dense2 = nn.Linear(out_mlp_dim, in_mlp_dim)def forward(self, x):y = self.dense1(x)y = self.gelu(y)y = self.dense2(y)return yclass MixerBlock(nn.Module):def __init__(self, tokens_mlp_dim=256, channels_mlp_dim=2048, batch_size=32):super(MixerBlock, self).__init__()self.batch_size = batch_sizeself.norm1 = nn.LayerNorm(512)  # 对512维的做归一化,默认给最后一个维度做归一化self.token_Mixing = MlpBlock(out_mlp_dim=tokens_mlp_dim)self.norm2 = nn.LayerNorm(512)      # 对512维的做归一化self.channel_mixing = MlpBlock(in_mlp_dim=512, out_mlp_dim=channels_mlp_dim)def forward(self, x):y = self.norm1(x)y = y.permute(0, 2, 1)y = self.token_Mixing(y)y = y.permute(0, 2, 1)x = x + yy = self.norm2(x)return x + self.channel_mixing(y)class MlpMixer(nn.Module):def __init__(self, patches, num_classes, num_blocks, hidden_dim, tokens_mlp_dim, channels_mlp_dim):super(MlpMixer, self).__init__()self.stem = nn.Conv2d(3, hidden_dim, kernel_size=patches, stride=patches)self.mixer_block_1 = MixerBlock()self.mixer_blocks = nn.ModuleList([MixerBlock(tokens_mlp_dim, channels_mlp_dim) for _ in range(num_blocks)])self.pre_head_norm = nn.LayerNorm(hidden_dim)self.head = nn.Linear(hidden_dim, num_classes) if num_classes > 0 else nn.Identity()def forward(self, x):x = self.stem(x)b, c, h, w = x.shapex = x.view(b, c, -1).permute(0, 2, 1)for mixer_block in self.mixer_blocks:x = mixer_block(x)x = self.pre_head_norm(x)x = x.mean(dim=1)x = self.head(x)return x# model = MlpMixer(16, 10, 6, 512, 256, 2048)
# input_tensor = torch.randn(32, 3, 224, 224)  # (batch size, num_patches, input_dim)
# output = model(input_tensor)
# print(output)

在将flax框架的代码改为pytorch实现的时候,还是踩了不少的坑,在此讲一下,希望后面做的人,可以避免。

1.在flax框架的nn.linear层中没有输入维度,只有一个输出维度。

2.在处理两个差异的时候,如输入维度[32,196,512],其中代表的意思分别为batch_size为32,196为图片在经过patch之后的224*224输入之后经过patch=16,变为14 * 14即196,512会在二维卷积处理之后输出的channel类似。

1.在flax框架的nn.linear层中没有输入维度,只有一个输出维度。

2.在处理两个差异的时候,如输入维度[32,196,512],其中代表的意思分别为batch_size为32,196为图片在经过patch之后的224*224输入之后经过patch=16,变为14 * 14即196,512会在二维卷积处理之后输出的channel类似。

在nn.linear那儿的in_channel与第三个维度保持一致,就可以不必将其三维的转换为二维的。同时在对layernorm那儿转换的时候,默认也是对最后一个维度进行正则化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/16474.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GEC210编译环境搭建

一、下载编译工具链 下载:点击跳转 二、解压到 /usr/local/arm 目录 sudo mv gec210.zip /usr/local/arm cd /usr/local/arm sudo unzip gec210.zip 三、添加到环境变量 PATH/usr/local/arm/arm-cortex_a8-linux-gnueabi-4.7.3/bin:$PATH 四、测试验证 在终端…

python数据分析-基于数据挖掘对APP评分的预测

前言 当我们谈论关于APP用户分析与电子商务之间的联系时,机器学习在这两个领域的应用变得至关重要。App用户分析和电子商务之间存在着密切的关联,因为用户行为和偏好的深入理解对于提高用户体验、增加销售以及优化产品功能至关重要。故本文基于K-近邻模…

OFDM 802.11a的FPGA实现(二十)使用AXI-Stream FIFO进行跨时钟(含代码)

目录 1.前言 2.AXI-Stream FIFO时序 3.AXI-Stream FIFO配置信息 4.时钟控制模块MMCM 5.ModelSim仿真 6.总结 1.前言 至此,通过前面的文章讲解,对于OFDM 802.11a的发射基带的一个完整的PPDU帧的所有处理已经全部完成,其结构如下图所示&…

opencv-C++ VS2019配置安装

最新opencv-c安装及配置教程(VS2019 C & opencv4.4.0)_c opencv配置-CSDN博客

夜雨触花感怀

夜雨触花感怀 雨落有轨迹,业成无坦途。 ​鸡毛飞虚空,寻德问心路。 ​恰如求耕耘,大话量寸土。 ​好吃品五味,难得评真俗。

CAN总线简介

1. CAN总线概述 1.1 CAN定义与历史背景 CAN,全称为Controller Area Network,是一种基于消息广播的串行通信协议。它最初由德国Bosch公司在1983年为汽车行业开发,目的是实现汽车内部电子控制单元(ECUs)之间的可靠通信。…

用Vuex存储可配置下载的ip地址(用XML进行ajax请求配置文件)

1.在public文件夹下创建一个名为Configuration的文件在创建一个Configuration.txt里面就放IP地址(这里的名字可以随便命名一定性的被人解读文件含义) 例如: http://172.171.208.1:80032.在store文件夹中创建一个名为 ajaxModule.js 的 Vuex …

2. CSS选择器与伪类

2.1 基本选择器回顾 在开始介绍CSS3选择器之前&#xff0c;我们先回顾一下CSS的基本选择器。这些选择器是所有CSS开发的基础。 2.1.1 元素选择器 元素选择器用于选中指定类型的HTML元素。 /* 选中所有的<p>元素 */ p {color: blue; }2.1.2 类选择器 类选择器用于选中…

03自动辅助导航驾驶NOP其实就是NOA

蔚来NOP是什么意思&#xff1f;蔚来NOP是啥 蔚来NOP的意思就是NavigateonPilot智能辅助导航驾驶&#xff0c;也就是大家俗称的高阶辅助驾驶&#xff0c;在车主设定好导航路线&#xff0c;并且符合开启NOP条件的前提下&#xff0c;蔚来NOP可以代替驾驶员完成从A点到B点的智能辅助…

深入理解数仓开发(二)数据技术篇之数据同步

1、数据同步 数据同步我们之前在数仓当中使用了多种工具&#xff0c;比如使用 Flume 将日志文件从服务器采集到 Kafka&#xff0c;再通过 Flume 将 Kafka 中的数据采集到 HDFS。使用 MaxWell 实时监听 MySQL 的 binlog 日志&#xff0c;并将采集到的变更日志&#xff08;json 格…

【二叉树】:LeetCode:100.相同的数(分治)

&#x1f381;个人主页&#xff1a;我们的五年 &#x1f50d;系列专栏&#xff1a;初阶初阶结构刷题 &#x1f389;欢迎大家点赞&#x1f44d;评论&#x1f4dd;收藏⭐文章 1.问题描述&#xff1a; 2.问题分析&#xff1a; 二叉树是区分结构的&#xff0c;即左右子树是不一…

[JDK工具-6] jmap java内存映射工具

文章目录 1. 介绍2. 主要选项3. 生成java堆转储快照 jmap -dump4. 显示堆详细信息 jmap -heap pid5. 显示堆中对象统计信息 jmap -histo pid jmap(Memory Map for Java) 1. 介绍 位置&#xff1a;jdk\bin 作用&#xff1a; jdk安装后会自带一些小工具&#xff0c;jmap命令(Mem…

PySide6升级导致的Fatal Python error: could not initialize part 2问题及其解决方法

问题出现 把PySide6从6.6.1升级到6.7.1&#xff0c;结果运行程序的时候就报如下错误&#xff1a; Traceback (most recent call last): File "signature_bootstrap.py", line 77, in bootstrap File "signature_bootstrap.py", line 93, in find_inc…

Kafka SASL_SSL集群认证

背景 公司需要对kafka环境进行安全验证,目前考虑到的方案有Kerberos和SSL和SASL_SSL,最终考虑到安全和功能的丰富度,我们最终选择了SASL_SSL方案。处于知识积累的角度,记录一下kafka SASL_SSL安装部署的步骤。 机器规划 目前测试环境公搭建了三台kafka主机服务,现在将详…

H3CNE-7-TCP和UDP协议

TCP和UDP协议 TCP&#xff1a;可靠传输&#xff0c;面向连接 -------- 速度慢&#xff0c;准确性高 UDP&#xff1a;不可靠传输&#xff0c;非面向连接 -------- 速度快&#xff0c;但准确性差 面向连接&#xff1a;如果某应用层协议的四层使用TCP端口&#xff0c;那么正式的…

智能家居完结 -- 整体设计

系统框图 前情提要: 智能家居1 -- 实现语音模块-CSDN博客 智能家居2 -- 实现网络控制模块-CSDN博客 智能家居3 - 实现烟雾报警模块-CSDN博客 智能家居4 -- 添加接收消息的初步处理-CSDN博客 智能家居5 - 实现处理线程-CSDN博客 智能家居6 -- 配置 ini文件优化设备添加-CS…

【MySQL】聊聊count的相关操作

在平时的操作中&#xff0c;经常使用count进行操作&#xff0c;计算统计的数据。那么具体的原理是如何的&#xff1f;为什么有时候执行count很慢。 count的实现方式 select count(*) from student;对于MyISAM引擎来说&#xff0c;会把一个表的总行数存储在磁盘上&#xff0c;…

Linux下Vision Mamba环境配置+多CUDA版本切换

上篇文章大致讲了下Vision Mamba的相关知识&#xff0c;网上关于Vision Mamba的配置博客太多&#xff0c;笔者主要用来整合下。 笔者在Win10和Linux下分别尝试配置相关环境。 Win10下配置 失败 \textcolor{red}{失败} 失败&#xff0c;最后出现的问题如下&#xff1a; https://…

基于物联网架构的电子小票服务系统

1.电子小票物联网架构 采用感知层、网络层和应用层的3层物联网体系架构模型&#xff0c;电子小票物联网的架构见图1。 图1 电子小票物联网架构 感知层的小票智能硬件能够取代传统的小票打印机&#xff0c;在不改变商家原有收银系统的前提下&#xff0c;采集收音机待打印的购物…