政安晨:【Keras机器学习实践要点】(二十七)—— 使用感知器进行图像分类

目录

简介

设置

准备数据

配置超参数

使用数据增强

实施前馈网络(FFN)

将创建修补程序作为一个层

实施补丁编码层

建立感知器模型

变换器模块

感知器模型

编译、培训和评估模式


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:实施用于图像分类的感知器模型。

简介


本文实现了安德鲁-杰格(Andrew Jaegle)等人的 Perceiver

图像分类模型,在 CIFAR-100 数据集上进行演示。

Perceiver 模型利用非对称注意力机制,将输入信息迭代提炼成一个紧密的潜在瓶颈,使其能够扩展以处理非常大的输入信息。

换句话说:假设你的输入数据数组(如图像)有 M 个元素(即补丁),其中 M 很庞大。在标准变换器模型中,会对 M 个元素执行自注意操作。这一操作的复杂度为 O(M^2)。然而,感知器模型会创建一个大小为 N 个元素(其中 N << M)的潜在数组,并迭代执行两个操作:

1. 潜数组和数据数组之间的交叉注意变换器 - 此操作的复杂度为 O(M.N)。
2. 潜数组上的自注意变换器 - 此操作的复杂度为 O(N^2)。


本示例需要 Keras 3.0 或更高版本。

设置

import keras
from keras import layers, activations, ops

准备数据

num_classes = 100
input_shape = (32, 32, 3)(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

演绎展示:

配置超参数

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 64
num_epochs = 2  # You should actually use 50 epochs!
dropout_rate = 0.2
image_size = 64  # We'll resize input images to this size.
patch_size = 2  # Size of the patches to be extract from the input images.
num_patches = (image_size // patch_size) ** 2  # Size of the data array.
latent_dim = 256  # Size of the latent array.
projection_dim = 256  # Embedding size of each element in the data and latent arrays.
num_heads = 8  # Number of Transformer heads.
ffn_units = [projection_dim,projection_dim,
]  # Size of the Transformer Feedforward network.
num_transformer_blocks = 4
num_iterations = 2  # Repetitions of the cross-attention and Transformer modules.
classifier_units = [projection_dim,num_classes,
]  # Size of the Feedforward network of the final classifier.print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")
print(f"Latent array shape: {latent_dim} X {projection_dim}")
print(f"Data array shape: {num_patches} X {projection_dim}")

演绎展示:

请注意,为了将每个像素作为数据数组中的单独输入,请将 patch_size 设置为 1。

使用数据增强

data_augmentation = keras.Sequential([layers.Normalization(),layers.Resizing(image_size, image_size),layers.RandomFlip("horizontal"),layers.RandomZoom(height_factor=0.2, width_factor=0.2),],name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

实施前馈网络(FFN)

def create_ffn(hidden_units, dropout_rate):ffn_layers = []for units in hidden_units[:-1]:ffn_layers.append(layers.Dense(units, activation=activations.gelu))ffn_layers.append(layers.Dense(units=hidden_units[-1]))ffn_layers.append(layers.Dropout(dropout_rate))ffn = keras.Sequential(ffn_layers)return ffn

将创建修补程序作为一个层

class Patches(layers.Layer):def __init__(self, patch_size):super().__init__()self.patch_size = patch_sizedef call(self, images):batch_size = ops.shape(images)[0]patches = ops.image.extract_patches(image=images,size=(self.patch_size, self.patch_size),strides=(self.patch_size, self.patch_size),dilation_rate=1,padding="valid",)patch_dims = patches.shape[-1]patches = ops.reshape(patches, [batch_size, -1, patch_dims])return patches

实施补丁编码层

PatchEncoder 层会通过投射到大小为 latent_dim 的矢量中对补丁进行线性变换。此外,它还会为投影向量添加可学习的位置嵌入。

请注意,最初的 Perceiver 论文使用的是傅立叶特征位置编码。

class PatchEncoder(layers.Layer):def __init__(self, num_patches, projection_dim):super().__init__()self.num_patches = num_patchesself.projection = layers.Dense(units=projection_dim)self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)def call(self, patches):positions = ops.arange(start=0, stop=self.num_patches, step=1)encoded = self.projection(patches) + self.position_embedding(positions)return encoded

建立感知器模型

感知器由两个模块组成:一个交叉注意模块和一个自注意标准变换器。

交叉注意模块


交叉注意模块将(latent_dim, projection_dim)潜在数组和(data_dim, projection_dim)数据数组作为输入,以产生(latent_dim, projection_dim)潜在数组作为输出。为了应用交叉关注,查询向量由潜在数组生成,而键向量和值向量则由编码图像生成。

请注意,本例中的数据数组是图像,其中 data_dim 设置为 num_patches。

def create_cross_attention_module(latent_dim, data_dim, projection_dim, ffn_units, dropout_rate
):inputs = {# Recieve the latent array as an input of shape [1, latent_dim, projection_dim]."latent_array": layers.Input(shape=(latent_dim, projection_dim), name="latent_array"),# Recieve the data_array (encoded image) as an input of shape [batch_size, data_dim, projection_dim]."data_array": layers.Input(shape=(data_dim, projection_dim), name="data_array"),}# Apply layer norm to the inputslatent_array = layers.LayerNormalization(epsilon=1e-6)(inputs["latent_array"])data_array = layers.LayerNormalization(epsilon=1e-6)(inputs["data_array"])# Create query tensor: [1, latent_dim, projection_dim].query = layers.Dense(units=projection_dim)(latent_array)# Create key tensor: [batch_size, data_dim, projection_dim].key = layers.Dense(units=projection_dim)(data_array)# Create value tensor: [batch_size, data_dim, projection_dim].value = layers.Dense(units=projection_dim)(data_array)# Generate cross-attention outputs: [batch_size, latent_dim, projection_dim].attention_output = layers.Attention(use_scale=True, dropout=0.1)([query, key, value], return_attention_scores=False)# Skip connection 1.attention_output = layers.Add()([attention_output, latent_array])# Apply layer norm.attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)# Apply Feedforward network.ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)outputs = ffn(attention_output)# Skip connection 2.outputs = layers.Add()([outputs, attention_output])# Create the Keras model.model = keras.Model(inputs=inputs, outputs=outputs)return model

变换器模块

转换器将交叉注意模块输出的潜向量作为输入,对其 latent_dim 元素应用多头自注意,然后通过前馈网络,生成另一个(latent_dim,projection_dim)潜数组。

def create_transformer_module(latent_dim,projection_dim,num_heads,num_transformer_blocks,ffn_units,dropout_rate,
):# input_shape: [1, latent_dim, projection_dim]inputs = layers.Input(shape=(latent_dim, projection_dim))x0 = inputs# Create multiple layers of the Transformer block.for _ in range(num_transformer_blocks):# Apply layer normalization 1.x1 = layers.LayerNormalization(epsilon=1e-6)(x0)# Create a multi-head self-attention layer.attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=0.1)(x1, x1)# Skip connection 1.x2 = layers.Add()([attention_output, x0])# Apply layer normalization 2.x3 = layers.LayerNormalization(epsilon=1e-6)(x2)# Apply Feedforward network.ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)x3 = ffn(x3)# Skip connection 2.x0 = layers.Add()([x3, x2])# Create the Keras model.model = keras.Model(inputs=inputs, outputs=x0)return model

感知器模型


感知器模型重复交叉注意模块和变换器模块的迭代次数--通过共享权重和跳转连接--使潜在阵列能够根据需要从输入图像中迭代提取信息。

class Perceiver(keras.Model):def __init__(self,patch_size,data_dim,latent_dim,projection_dim,num_heads,num_transformer_blocks,ffn_units,dropout_rate,num_iterations,classifier_units,):super().__init__()self.latent_dim = latent_dimself.data_dim = data_dimself.patch_size = patch_sizeself.projection_dim = projection_dimself.num_heads = num_headsself.num_transformer_blocks = num_transformer_blocksself.ffn_units = ffn_unitsself.dropout_rate = dropout_rateself.num_iterations = num_iterationsself.classifier_units = classifier_unitsdef build(self, input_shape):# Create latent array.self.latent_array = self.add_weight(shape=(self.latent_dim, self.projection_dim),initializer="random_normal",trainable=True,)# Create patching module.warself.patch_encoder = PatchEncoder(self.data_dim, self.projection_dim)# Create cross-attenion module.self.cross_attention = create_cross_attention_module(self.latent_dim,self.data_dim,self.projection_dim,self.ffn_units,self.dropout_rate,)# Create Transformer module.self.transformer = create_transformer_module(self.latent_dim,self.projection_dim,self.num_heads,self.num_transformer_blocks,self.ffn_units,self.dropout_rate,)# Create global average pooling layer.self.global_average_pooling = layers.GlobalAveragePooling1D()# Create a classification head.self.classification_head = create_ffn(hidden_units=self.classifier_units, dropout_rate=self.dropout_rate)super().build(input_shape)def call(self, inputs):# Augment data.augmented = data_augmentation(inputs)# Create patches.patches = self.patcher(augmented)# Encode patches.encoded_patches = self.patch_encoder(patches)# Prepare cross-attention inputs.cross_attention_inputs = {"latent_array": ops.expand_dims(self.latent_array, 0),"data_array": encoded_patches,}# Apply the cross-attention and the Transformer modules iteratively.for _ in range(self.num_iterations):# Apply cross-attention from the latent array to the data array.latent_array = self.cross_attention(cross_attention_inputs)# Apply self-attention Transformer to the latent array.latent_array = self.transformer(latent_array)# Set the latent array of the next iteration.cross_attention_inputs["latent_array"] = latent_array# Apply global average pooling to generate a [batch_size, projection_dim] repesentation tensor.representation = self.global_average_pooling(latent_array)# Generate logits.logits = self.classification_head(representation)return logits

编译、培训和评估模式

def run_experiment(model):# Create ADAM instead of LAMB optimizer with weight decay. (LAMB isn't supported yet)optimizer = keras.optimizers.Adam(learning_rate=learning_rate)# Compile the model.model.compile(optimizer=optimizer,loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),],)# Create a learning rate scheduler callback.reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.2, patience=3)# Create an early stopping callback.early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=15, restore_best_weights=True)# Fit the model.history = model.fit(x=x_train,y=y_train,batch_size=batch_size,epochs=num_epochs,validation_split=0.1,callbacks=[early_stopping, reduce_lr],)_, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)print(f"Test accuracy: {round(accuracy * 100, 2)}%")print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")# Return history to plot learning curves.return history

请注意,在 V100 GPU 上以当前设置训练感知器模型大约需要 200 秒。

perceiver_classifier = Perceiver(patch_size,num_patches,latent_dim,projection_dim,num_heads,num_transformer_blocks,ffn_units,dropout_rate,num_iterations,classifier_units,
)history = run_experiment(perceiver_classifier)

演绎展示: 

Test accuracy: 0.91%
Test top 5 accuracy: 5.2%

经过 40 次历时后,Perceiver 模型在测试数据上达到了约 53% 的准确率和 81% 的前五名准确率。

正如 Perceiver 论文的消融部分所述,通过增加潜在阵列大小、增加潜在阵列和数据阵列元素的(投影)维度、增加变换器模块中的块数以及增加交叉注意和潜在变换器模块的迭代次数,可以获得更好的结果。您还可以尝试增加输入图像的大小,并使用不同的补丁尺寸。

Perceiver 可以从增大模型尺寸中获益。

不过,更大的模型需要更大的加速器来适应和有效训练。这就是 Perceiver 论文中使用 32 个 TPU 内核进行实验的原因。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/811997.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot集成Graphql快速入门Demo

1.Graphql介绍 GraphQL 是一个用于 API 的查询语言&#xff0c;是一个使用基于类型系统来执行查询的服务端运行时&#xff08;类型系统由你的数据定义&#xff09;。GraphQL 并没有和任何特定数据库或者存储引擎绑定&#xff0c;而是依靠你现有的代码和数据支撑。 优势 GraphQL…

npm install 报 ERESOLVE unable to resolve dependency tree 异常解决方法

问题 在安装项目依赖时&#xff0c;很大可能会遇到安装不成功的问题&#xff0c;其中有一个很大的原因&#xff0c;可能就是因为你的npm版本导致的。 1.npm ERR! code ERESOLVE npm ERR! ERESOLVE unable to resolve dependency tree 2.ERESOLVE unable to resolve dependenc…

【C++之queue的应用及模拟实现】

C学习笔记---014 C之queue的应用及模拟实现1、queue的简单介绍2、queue的简单接口应用3、queue的模拟实现3.1、queue的结构一般的构建3.2、queue的适配器模式构建3.3、queue的主要接口函数 4、queue的模拟实现完整代码4.1、一般方式4.2、泛型模式 5、queue巩固练习题5.1、最小栈…

VSCode中 task.json 和 launch.json 的作用和参数解释以及配置教程

前言 由于 VS Code 并不是一个传统意义上的 IDE&#xff0c;所以初学者可能在使用过程中会有很多的疑惑&#xff0c;其中比较常见的一个问题就是 tasks.json和 launch.json两个文件分别有什么作用以及如何配置 tasks.json VSCode 官网提供的 tasks.json 配置教程 使用不同的…

UE4_导入内容_Alembic文件导入器

Alembic文件导入器 Alembic文件格式(.abc)是一个开放的计算机图形交换框架&#xff0c;它将复杂的动画化场景浓缩成一组非过程式的、与应用程序无关的烘焙几何结果。虚幻引擎4(UE4)允许你通过 Alembic导入器 导入你的Alembic文件&#xff0c;这让你可以在外部自由地创建复杂的…

什么是态势感知?

什么是态势感知&#xff1f; 同学&#xff0c;听说过态势感知吗&#xff1f;啥&#xff1f;不知道&#xff1f;不知道很正常&#xff0c;因为态势感知是一个比较小众、比较神秘的概念。为什么态势感知很神秘&#xff0c;首先是因为这是来自军事情报领域的概念&#xff0c;然后…

008Node.js模块、自定义模块和CommonJs

CommonJS API定义很多普通应用程序(主要指非浏览器的应用)使用的API&#xff0c;从而填补了这个空白。它的终极目标是提供一个类似Python&#xff0c;Ruby和Java标 准库。这样的话&#xff0c;开发者可以使用CommonJS API编写应用程序&#xff0c;然后这些应用可以运行在不同的…

【尝试】域名验证:配置github二级目录下的txt文件

【尝试】域名验证&#xff1a;配置github二级目录下的txt文件 写在最前面一、初始化本地仓库二、设置远程仓库1. 远程仓库 URL 没有设置或设置错误添加远程仓库修改远程仓库 2. 访问权限问题3. 仓库不存在步骤 1: 在你的仓库中添加文件步骤 2: 确认GitHub Pages设置步骤 3: 访问…

lua学习笔记21完结篇(lua中的垃圾回收)

print("*****************************lua中的垃圾回收*******************************") text{id24,name"仙贝"} --垃圾回收关键字collectgarbag --获取当前lua占用内存数 k字节 返回值*1024就可以得到具体占用字节数 print(collectgarbage("count&…

本地部署开源免费文件传输工具LocalSend并实现公网快速传送文件

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-X4xB3gSR3z2VUfmN {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

C++项目——集群聊天服务器项目(十四)客户端业务

大家好~前段时间有些事情需要处理&#xff0c;没来得及更新&#xff0c;实在不好意思。 今天来继续更新集群聊天服务器项目的客户端功能&#xff0c;主要实现客户端业务&#xff0c;包括添加好友、点对点聊天、创建群组、添加群组、群组聊天业务&#xff0c;接下来我们一起来敲…

Prompt 工程技术提问的艺术,如何向 ChatGPT 提问?

Prompt 工程技术简介 什么是 Prompt 工程&#xff1f; Prompt 工程是创建提示或指导像 ChatGPT 这样的语言模型输出的过程。它允许用户控制模型的输出并 生成符合其特定需求的文本。ChatGPT 是一种先进的语言模型&#xff0c;能够生成类似于人类的文本。它建立在 Transformer 架…

FPGA基于VCU的H265视频解压缩,解码后HDMI2.0输出,支持4K60帧,提供工程源码+开发板+技术支持

目录 1、前言免责声明 2、相关方案推荐我这里已有的视频图像编解码方案4K60帧HDMI2.0输入&#xff0c;H265视频压缩方案 3、详细设计方案设计框图FPGA开发板解压视频源Zynq UltraScale VCUVideo Frame Buffer ReadVideo MixerHDMI 1.4/2.0 Transmitter SubsystemVideo PHY Cont…

316_C++_xml文件解析成map,可以放到表格上 + xml、xlsx文件互相解析

xml文件例如&#xff1a; <?xml version"1.0" encoding"UTF-8" standalone"yes"?> <TrTable> <tr id"0" label"TR_PB_CH" text"CH%2"/> <tr id"4" label"TR_PB_CHN"…

必应bing搜索广告推广国内能开户吗?

随着互联网广告市场的不断进化和细分化&#xff0c;必应Bing搜索广告已逐渐成为中国企业拓展国内市场、精准触达目标客户的重要渠道之一。2024年&#xff0c;必应Bing在国内市场的进一步布局&#xff0c;不仅彰显了其对本土企业的强大吸引力&#xff0c;更带来了全新的开户政策…

【网站项目】驾校预约管理系统小程序

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

【opencv】示例-image_alignment.cpp 利用ECC 算法进行图像对齐

affine imshow("image", target_image); imshow("template", template_image); imshow("warped image", warped_image); imshow("error (black: no error)", abs(errorImage) * 255 / max_of_error); homography 这段代码是一个利用EC…

「51媒体网」汽车类媒体有哪些?车展媒体宣传

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 汽车类媒体有很多&#xff0c;具体如下&#xff1a; 汽车之家&#xff1a;提供全面的汽车新闻、评测、导购等内容。 爱卡汽车&#xff1a;同样是一个综合性的汽车信息平台&#xff0c;涵…

2024 年适用于 Mac 电脑的最佳 SD 卡恢复软件

D 卡体积很小&#xff0c;广泛用于数码相机、摄像机、行车记录仪、无人机等。通常&#xff0c;在使用设备拍照、拍摄视频后&#xff0c;您会将文件移动到 Mac 进行进一步编辑或作为备份。大多数时候&#xff0c;应该存在问题。但是&#xff0c;您的 SD 卡仍然会出现一些问题并导…

Harmony鸿蒙南向驱动开发-Regulator接口使用

功能简介 Regulator模块用于控制系统中某些设备的电压/电流供应。在嵌入式系统&#xff08;尤其是手机&#xff09;中&#xff0c;控制耗电量很重要&#xff0c;直接影响到电池的续航时间。所以&#xff0c;如果系统中某一个模块暂时不需要使用&#xff0c;就可以通过Regulato…