Google的MLP-MIXer的复现(pytorch实现)

Google的MLP-MIXer的复现(pytorch实现)

该模型原论文实现用的jax框架实现,先贴出原论文的代码实现:

# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.from typing import Any, Optionalimport einops
import flax.linen as nn
import jax
import jax.numpy as jnpclass MlpBlock(nn.Module):mlp_dim: int@nn.compactdef __call__(self, x):y = nn.Dense(self.mlp_dim)(x)y = nn.gelu(y)return nn.Dense(x.shape[-1])(y)class MixerBlock(nn.Module):"""Mixer block layer."""tokens_mlp_dim: intchannels_mlp_dim: int@nn.compactdef __call__(self, x):y = nn.LayerNorm()(x)y = jnp.swapaxes(y, 1, 2)y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y) #  (32, 512, 196)y = jnp.swapaxes(y, 1, 2)x = x + yy = nn.LayerNorm()(x)return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)class MlpMixer(nn.Module):"""Mixer architecture."""patches: Anynum_classes: intnum_blocks: inthidden_dim: inttokens_mlp_dim: intchannels_mlp_dim: intmodel_name: Optional[str] = None@nn.compactdef __call__(self, inputs, *, train):del trainx = nn.Conv(self.hidden_dim, self.patches.size,strides=self.patches.size, name='stem')(inputs)x = einops.rearrange(x, 'n h w c -> n (h w) c')  # 从(32,512,14,14)变成了(32,196,512)for _ in range(self.num_blocks):x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)x = nn.LayerNorm(name='pre_head_layer_norm')(x)x = jnp.mean(x, axis=1)if self.num_classes:x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,name='head')(x)return xmodel_params = {'patches': {'size': (16, 16), 'stride': (16, 16)}, # 这里需要一个描述patch大小和步长的对象,例如Flax的stem模块初始化参数'num_classes': 10,  # 分类任务的类别数'num_blocks': 8,  # Mixer Block的重复次数'hidden_dim': 512,  # 隐藏层维度'tokens_mlp_dim': 256,  # token mixing的MLP维度'channels_mlp_dim': 2048,  # channel mixing的MLP维度
}# 准备输入数据,例如一批32张图片,每张图片尺寸为512x14x14(假设已经按要求预处理)# 初始化模型
seed=0
key = jax.random.PRNGKey(seed)
model = MlpMixer.apply(key, **model_params)input_data = jnp.ones((4096, 224, 224, 3))  # 示例输入数据
# 调用模型进行前向传播
output = model(input_data)print("Output shape:", output)  # 打印输出形状,预期是(32, 10)如果num_classes=10

该模型的总体框架图如下所示:

在这里插入图片描述

对该框架的讲解,网上已经很多了,就不在此赘述。

实现的pytorch代码如下所示:

class MlpBlock(nn.Module):def __init__(self, in_mlp_dim=196, out_mlp_dim=256):super(MlpBlock, self).__init__()self.mlp_dim = out_mlp_dimself.dense1 = nn.Linear(in_mlp_dim, out_mlp_dim)  # 若输入的向量为[32,196, 512]则输入的也应该是512,输出可以自己定self.gelu = nn.GELU()self.dense2 = nn.Linear(out_mlp_dim, in_mlp_dim)def forward(self, x):y = self.dense1(x)y = self.gelu(y)y = self.dense2(y)return yclass MixerBlock(nn.Module):def __init__(self, tokens_mlp_dim=256, channels_mlp_dim=2048, batch_size=32):super(MixerBlock, self).__init__()self.batch_size = batch_sizeself.norm1 = nn.LayerNorm(512)  # 对512维的做归一化,默认给最后一个维度做归一化self.token_Mixing = MlpBlock(out_mlp_dim=tokens_mlp_dim)self.norm2 = nn.LayerNorm(512)      # 对512维的做归一化self.channel_mixing = MlpBlock(in_mlp_dim=512, out_mlp_dim=channels_mlp_dim)def forward(self, x):y = self.norm1(x)y = y.permute(0, 2, 1)y = self.token_Mixing(y)y = y.permute(0, 2, 1)x = x + yy = self.norm2(x)return x + self.channel_mixing(y)class MlpMixer(nn.Module):def __init__(self, patches, num_classes, num_blocks, hidden_dim, tokens_mlp_dim, channels_mlp_dim):super(MlpMixer, self).__init__()self.stem = nn.Conv2d(3, hidden_dim, kernel_size=patches, stride=patches)self.mixer_block_1 = MixerBlock()self.mixer_blocks = nn.ModuleList([MixerBlock(tokens_mlp_dim, channels_mlp_dim) for _ in range(num_blocks)])self.pre_head_norm = nn.LayerNorm(hidden_dim)self.head = nn.Linear(hidden_dim, num_classes) if num_classes > 0 else nn.Identity()def forward(self, x):x = self.stem(x)b, c, h, w = x.shapex = x.view(b, c, -1).permute(0, 2, 1)for mixer_block in self.mixer_blocks:x = mixer_block(x)x = self.pre_head_norm(x)x = x.mean(dim=1)x = self.head(x)return x# model = MlpMixer(16, 10, 6, 512, 256, 2048)
# input_tensor = torch.randn(32, 3, 224, 224)  # (batch size, num_patches, input_dim)
# output = model(input_tensor)
# print(output)

在将flax框架的代码改为pytorch实现的时候,还是踩了不少的坑,在此讲一下,希望后面做的人,可以避免。

1.在flax框架的nn.linear层中没有输入维度,只有一个输出维度。

2.在处理两个差异的时候,如输入维度[32,196,512],其中代表的意思分别为batch_size为32,196为图片在经过patch之后的224*224输入之后经过patch=16,变为14 * 14即196,512会在二维卷积处理之后输出的channel类似。

1.在flax框架的nn.linear层中没有输入维度,只有一个输出维度。

2.在处理两个差异的时候,如输入维度[32,196,512],其中代表的意思分别为batch_size为32,196为图片在经过patch之后的224*224输入之后经过patch=16,变为14 * 14即196,512会在二维卷积处理之后输出的channel类似。

在nn.linear那儿的in_channel与第三个维度保持一致,就可以不必将其三维的转换为二维的。同时在对layernorm那儿转换的时候,默认也是对最后一个维度进行正则化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/16474.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GEC210编译环境搭建

一、下载编译工具链 下载:点击跳转 二、解压到 /usr/local/arm 目录 sudo mv gec210.zip /usr/local/arm cd /usr/local/arm sudo unzip gec210.zip 三、添加到环境变量 PATH/usr/local/arm/arm-cortex_a8-linux-gnueabi-4.7.3/bin:$PATH 四、测试验证 在终端…

python数据分析-基于数据挖掘对APP评分的预测

前言 当我们谈论关于APP用户分析与电子商务之间的联系时,机器学习在这两个领域的应用变得至关重要。App用户分析和电子商务之间存在着密切的关联,因为用户行为和偏好的深入理解对于提高用户体验、增加销售以及优化产品功能至关重要。故本文基于K-近邻模…

OFDM 802.11a的FPGA实现(二十)使用AXI-Stream FIFO进行跨时钟(含代码)

目录 1.前言 2.AXI-Stream FIFO时序 3.AXI-Stream FIFO配置信息 4.时钟控制模块MMCM 5.ModelSim仿真 6.总结 1.前言 至此,通过前面的文章讲解,对于OFDM 802.11a的发射基带的一个完整的PPDU帧的所有处理已经全部完成,其结构如下图所示&…

CAN总线简介

1. CAN总线概述 1.1 CAN定义与历史背景 CAN,全称为Controller Area Network,是一种基于消息广播的串行通信协议。它最初由德国Bosch公司在1983年为汽车行业开发,目的是实现汽车内部电子控制单元(ECUs)之间的可靠通信。…

03自动辅助导航驾驶NOP其实就是NOA

蔚来NOP是什么意思?蔚来NOP是啥 蔚来NOP的意思就是NavigateonPilot智能辅助导航驾驶,也就是大家俗称的高阶辅助驾驶,在车主设定好导航路线,并且符合开启NOP条件的前提下,蔚来NOP可以代替驾驶员完成从A点到B点的智能辅助…

【二叉树】:LeetCode:100.相同的数(分治)

🎁个人主页:我们的五年 🔍系列专栏:初阶初阶结构刷题 🎉欢迎大家点赞👍评论📝收藏⭐文章 1.问题描述: 2.问题分析: 二叉树是区分结构的,即左右子树是不一…

[JDK工具-6] jmap java内存映射工具

文章目录 1. 介绍2. 主要选项3. 生成java堆转储快照 jmap -dump4. 显示堆详细信息 jmap -heap pid5. 显示堆中对象统计信息 jmap -histo pid jmap(Memory Map for Java) 1. 介绍 位置:jdk\bin 作用: jdk安装后会自带一些小工具,jmap命令(Mem…

Kafka SASL_SSL集群认证

背景 公司需要对kafka环境进行安全验证,目前考虑到的方案有Kerberos和SSL和SASL_SSL,最终考虑到安全和功能的丰富度,我们最终选择了SASL_SSL方案。处于知识积累的角度,记录一下kafka SASL_SSL安装部署的步骤。 机器规划 目前测试环境公搭建了三台kafka主机服务,现在将详…

H3CNE-7-TCP和UDP协议

TCP和UDP协议 TCP:可靠传输,面向连接 -------- 速度慢,准确性高 UDP:不可靠传输,非面向连接 -------- 速度快,但准确性差 面向连接:如果某应用层协议的四层使用TCP端口,那么正式的…

智能家居完结 -- 整体设计

系统框图 前情提要: 智能家居1 -- 实现语音模块-CSDN博客 智能家居2 -- 实现网络控制模块-CSDN博客 智能家居3 - 实现烟雾报警模块-CSDN博客 智能家居4 -- 添加接收消息的初步处理-CSDN博客 智能家居5 - 实现处理线程-CSDN博客 智能家居6 -- 配置 ini文件优化设备添加-CS…

【MySQL】聊聊count的相关操作

在平时的操作中,经常使用count进行操作,计算统计的数据。那么具体的原理是如何的?为什么有时候执行count很慢。 count的实现方式 select count(*) from student;对于MyISAM引擎来说,会把一个表的总行数存储在磁盘上,…

Linux下Vision Mamba环境配置+多CUDA版本切换

上篇文章大致讲了下Vision Mamba的相关知识,网上关于Vision Mamba的配置博客太多,笔者主要用来整合下。 笔者在Win10和Linux下分别尝试配置相关环境。 Win10下配置 失败 \textcolor{red}{失败} 失败,最后出现的问题如下: https://…

基于物联网架构的电子小票服务系统

1.电子小票物联网架构 采用感知层、网络层和应用层的3层物联网体系架构模型,电子小票物联网的架构见图1。 图1 电子小票物联网架构 感知层的小票智能硬件能够取代传统的小票打印机,在不改变商家原有收银系统的前提下,采集收音机待打印的购物…

修改 ant design tour 漫游式导航的弹窗边框样式

一 说明 应项目要求,调整ant design tour 弹窗边框的样式。tour 原本样式是有遮罩层,因此没有边框看起来也不突兀。原图如下: 但是UI设计是取消遮罩层,并设置边框样式。当 取消 了遮罩层,没有设置边框样式的图片如下&a…

python考试成绩管理与分析:从列表到方差

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、考试成绩的输入与列表管理 二、成绩的总分与平均成绩计算 三、成绩方差的计算 四、成…

人工智能场景下的网络负载均衡技术

AI技术驱动智能应用井喷,智能算力增速远超通用算力。IDC预测,未来五年,我国智能算力规模年复合增长率将超50%,开启数据中心算力新纪元。随着需求激增,数据中心或智算网络亟需扩容、增速、减时延,确保网络稳…

rockylinux 利用nexus 搭建私服yum仓库

简单说下为啥弄这个私服,因为自己要学习一些东西,比如新版的k8s等,其中会涉及到一些yum的安装,为了防止因网络问题导致yum安装失败,和重复下载,所以弄个私服,当然也有为了意外保障的想法&#x…

【实战JVM】-基础篇-01-JVM通识-字节码详解

【实战JVM】-基础篇-01-JVM通识-字节码详解-类的声明周期-加载器 1 初识JVM1.1 什么是JVM1.2 JVM的功能1.2.1 即时编译 1.3 常见JVM 2 字节码文件详解2.1 Java虚拟机的组成2.2 字节码文件的组成2.2.1 正确打开字节码文件2.2.2 字节码组成2.2.3 基础信息2.2.3.1 魔数2.2.3.1 主副…

【C++】右值引用 移动语义

目录 前言一、右值引用与移动语义1.1 左值引用和右值引用1.2 右值引用使用场景和意义1.3 右值引用引用左值及其一些更深入的使用场景分析1.3.1 完美转发 二、新的类功能三、可变参数模板 前言 本篇文章我们继续来聊聊C11新增的一些语法——右值引用,我们在之前就已…