政安晨: [Keras Machine Learning Essentials] (30) - Image Classification Using Swin Transformers

Table of Contents

Setup

Configure the hyperparameters

Prepare the data

Helper functions

Window-based multi-head self-attention

The complete Swin Transformer model

Model training and evaluation

Extract and embed patches

Prepare the tf.data.Dataset

Build the model

Train on CIFAR-100



Goal of this article: image classification using Swin Transformers, a general-purpose backbone for computer vision.

This example implements the Swin Transformer (Hierarchical Vision Transformer using Shifted Windows) for image classification and demonstrates it on the CIFAR-100 dataset.

The Swin Transformer (Shifted Window Transformer) can serve as a general-purpose backbone for computer vision. It is a hierarchical Transformer whose representation is computed with shifted windows. The shifted-window scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while still allowing cross-window connections. This architecture has the flexibility to model information at various scales, and its computational complexity is linear with respect to image size.

This example requires TensorFlow 2.5 or higher.

Setup

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf  # For tf.data and preprocessing only.
import keras
from keras import layers
from keras import ops

Configure the hyperparameters

A key parameter to pick is patch_size, the size of the input patches. To use each pixel as an individual input, you can set patch_size to (1, 1). Below, we take inspiration from the original paper's training settings on ImageNet-1K and keep most of those settings in this example.

num_classes = 100
input_shape = (32, 32, 3)

patch_size = (2, 2)  # 2-by-2 sized patches
dropout_rate = 0.03  # Dropout rate
num_heads = 8  # Attention heads
embed_dim = 64  # Embedding dimension
num_mlp = 256  # MLP layer size
# Convert embedded patches to query, key, and values with a learnable additive
# value
qkv_bias = True
window_size = 2  # Size of attention window
shift_size = 1  # Size of shifting window
image_dimension = 32  # Initial image size

num_patch_x = input_shape[0] // patch_size[0]
num_patch_y = input_shape[1] // patch_size[1]

learning_rate = 1e-3
batch_size = 128
num_epochs = 40
validation_split = 0.1
weight_decay = 0.0001
label_smoothing = 0.1

Prepare the data

We load the CIFAR-100 dataset through keras.datasets, normalize the images, and convert the integer labels to one-hot encoded vectors.

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
num_train_samples = int(len(x_train) * (1 - validation_split))
num_val_samples = len(x_train) - num_train_samples
x_train, x_val = np.split(x_train, [num_train_samples])
y_train, y_val = np.split(y_train, [num_train_samples])
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")plt.figure(figsize=(10, 10))
for i in range(25):plt.subplot(5, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(x_train[i])
plt.show()

Output:

x_train shape: (45000, 32, 32, 3) - y_train shape: (45000, 100)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 100)

Helper functions


We create two helper functions: window_partition splits a feature map into a batch of non-overlapping windows, and window_reverse merges those windows back into the original feature map.

def window_partition(x, window_size):
    _, height, width, channels = x.shape
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = ops.reshape(
        x,
        (
            -1,
            patch_num_y,
            window_size,
            patch_num_x,
            window_size,
            channels,
        ),
    )
    x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
    windows = ops.reshape(x, (-1, window_size, window_size, channels))
    return windows


def window_reverse(windows, window_size, height, width, channels):
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = ops.reshape(
        windows,
        (
            -1,
            patch_num_y,
            patch_num_x,
            window_size,
            window_size,
            channels,
        ),
    )
    x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
    x = ops.reshape(x, (-1, height, width, channels))
    return x
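
As a quick sanity check (this snippet is my own addition, not part of the original example), partitioning a dummy feature map and then reversing the partition should reproduce the original shape:

# Hypothetical sanity check, not in the original example.
dummy = ops.ones((1, 8, 8, 3))
windows = window_partition(dummy, window_size=2)
restored = window_reverse(windows, window_size=2, height=8, width=8, channels=3)
print(windows.shape, restored.shape)  # (16, 2, 2, 3) (1, 8, 8, 3)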

Window-based multi-head self-attention


Usually, Transformer models perform global self-attention, computing the relationships between one token and all other tokens. This global computation leads to complexity that is quadratic in the number of tokens. Here, as the original paper suggests, we compute self-attention within local windows, in a non-overlapping manner. While global self-attention is quadratic in the number of patches, window-based self-attention has linear complexity and scales easily.
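
To make the difference concrete, here is a back-of-the-envelope count of attention scores per layer using the hyperparameters configured above (a hypothetical illustration I added, not from the original example):

# Hypothetical illustration of the complexity argument, not in the original example.
num_patches = num_patch_x * num_patch_y  # 16 * 16 = 256 patch tokens
global_scores = num_patches**2  # global attention: 65,536 pairwise scores
num_windows = num_patches // window_size**2  # 64 non-overlapping 2x2 windows
window_scores = num_windows * (window_size**2) ** 2  # window attention: 64 * 16 = 1,024 scores
print(global_scores, window_scores)  # 65536 1024

Doubling the image side quadruples num_patches, which only quadruples window_scores (linear in the patch count) but multiplies global_scores by sixteen (quadratic).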

class WindowAttention(layers.Layer):
    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
        self.dropout = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(dim)

        num_window_elements = (2 * self.window_size[0] - 1) * (
            2 * self.window_size[1] - 1
        )
        self.relative_position_bias_table = self.add_weight(
            shape=(num_window_elements, self.num_heads),
            initializer=keras.initializers.Zeros(),
            trainable=True,
        )
        coords_h = np.arange(self.window_size[0])
        coords_w = np.arange(self.window_size[1])
        coords_matrix = np.meshgrid(coords_h, coords_w, indexing="ij")
        coords = np.stack(coords_matrix)
        coords_flatten = coords.reshape(2, -1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.transpose([1, 2, 0])
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)

        self.relative_position_index = keras.Variable(
            initializer=relative_position_index,
            shape=relative_position_index.shape,
            dtype="int",
            trainable=False,
        )

    def call(self, x, mask=None):
        _, size, channels = x.shape
        head_dim = channels // self.num_heads
        x_qkv = self.qkv(x)
        x_qkv = ops.reshape(x_qkv, (-1, size, 3, self.num_heads, head_dim))
        x_qkv = ops.transpose(x_qkv, (2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
        q = q * self.scale
        k = ops.transpose(k, (0, 1, 3, 2))
        attn = q @ k

        num_window_elements = self.window_size[0] * self.window_size[1]
        relative_position_index_flat = ops.reshape(self.relative_position_index, (-1,))
        relative_position_bias = ops.take(
            self.relative_position_bias_table,
            relative_position_index_flat,
            axis=0,
        )
        relative_position_bias = ops.reshape(
            relative_position_bias,
            (num_window_elements, num_window_elements, -1),
        )
        relative_position_bias = ops.transpose(relative_position_bias, (2, 0, 1))
        attn = attn + ops.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            nW = mask.shape[0]
            mask_float = ops.cast(
                ops.expand_dims(ops.expand_dims(mask, axis=1), axis=0),
                "float32",
            )
            attn = ops.reshape(attn, (-1, nW, self.num_heads, size, size)) + mask_float
            attn = ops.reshape(attn, (-1, self.num_heads, size, size))
            attn = keras.activations.softmax(attn, axis=-1)
        else:
            attn = keras.activations.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        x_qkv = attn @ v
        x_qkv = ops.transpose(x_qkv, (0, 2, 1, 3))
        x_qkv = ops.reshape(x_qkv, (-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)
        return x_qkv
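
A minimal smoke test of the layer (my own addition, assuming the hyperparameters defined earlier): feeding a dummy batch of flattened 2x2 windows should return a tensor of the same shape.

# Hypothetical smoke test, not in the original example.
dummy_windows = ops.ones((4, window_size * window_size, embed_dim))
attn_layer = WindowAttention(
    dim=embed_dim,
    window_size=(window_size, window_size),
    num_heads=num_heads,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
)
print(attn_layer(dummy_windows).shape)  # (4, 4, 64)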

The complete Swin Transformer model

We put together the complete Swin Transformer model by replacing the standard multi-head attention (MHA) with shifted-window attention.

As suggested in the original paper, we create a model consisting of a shifted-window-based MHA layer, followed by a 2-layer MLP with GELU nonlinearity in between, applying LayerNormalization before each MSA layer and each MLP, and a residual connection after each of these layers.

Note that we only create a simple MLP with 2 Dense layers and 2 Dropout layers.

Often you will see ResNet-50 used as the MLP in the literature, which is quite standard. In this paper, however, the authors use a 2-layer MLP with GELU nonlinearity in between.

class SwinTransformer(layers.Layer):
    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.dim = dim  # number of input dimensions
        self.num_patch = num_patch  # number of embedded patches
        self.num_heads = num_heads  # number of attention heads
        self.window_size = window_size  # size of window
        self.shift_size = shift_size  # size of window shift
        self.num_mlp = num_mlp  # number of MLP nodes

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = layers.Dropout(dropout_rate)
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)

        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

    def build(self, input_shape):
        if self.shift_size == 0:
            self.attn_mask = None
        else:
            height, width = self.num_patch
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            mask_array = np.zeros((1, height, width, 1))
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask_array[:, h, w, :] = count
                    count += 1
            mask_array = ops.convert_to_tensor(mask_array)

            # mask array to windows
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = ops.reshape(
                mask_windows, [-1, self.window_size * self.window_size]
            )
            attn_mask = ops.expand_dims(mask_windows, axis=1) - ops.expand_dims(
                mask_windows, axis=2
            )
            attn_mask = ops.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = ops.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = keras.Variable(
                initializer=attn_mask,
                shape=attn_mask.shape,
                dtype=attn_mask.dtype,
                trainable=False,
            )

    def call(self, x, training=False):
        height, width = self.num_patch
        _, num_patches_before, channels = x.shape
        x_skip = x
        x = self.norm1(x)
        x = ops.reshape(x, (-1, height, width, channels))
        if self.shift_size > 0:
            shifted_x = ops.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = ops.reshape(
            x_windows, (-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = ops.reshape(
            attn_windows,
            (-1, self.window_size, self.window_size, channels),
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            x = ops.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x
        x = ops.reshape(x, (-1, height * width, channels))
        x = self.drop_path(x, training=training)
        x = x_skip + x
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = x_skip + x
        return x
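
A hypothetical shape check I added (not part of the original example): one Swin block maps a (batch, num_patches, dim) token sequence to a tensor of the same shape, which is what lets us stack blocks freely below.

# Hypothetical shape check, not in the original example.
block = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=0,  # no shift, so no attention mask is built
    num_mlp=num_mlp,
)
tokens = ops.ones((1, num_patch_x * num_patch_y, embed_dim))
print(block(tokens).shape)  # (1, 256, 64)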

Model training and evaluation

Extract and embed patches

We first create three building blocks to help us extract, embed, and merge patches from the images, on top of which we will later use the Swin Transformer class we built.

# Using tf ops since it is only used in tf.data.
def patch_extract(images):
    batch_size = tf.shape(images)[0]
    patches = tf.image.extract_patches(
        images=images,
        sizes=(1, patch_size[0], patch_size[1], 1),
        strides=(1, patch_size[0], patch_size[1], 1),
        rates=(1, 1, 1, 1),
        padding="VALID",
    )
    patch_dim = patches.shape[-1]
    patch_num = patches.shape[1]
    return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))


class PatchEmbedding(layers.Layer):
    def __init__(self, num_patch, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patch = num_patch
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        pos = ops.arange(start=0, stop=self.num_patch)
        return self.proj(patch) + self.pos_embed(pos)


class PatchMerging(keras.layers.Layer):
    def __init__(self, num_patch, embed_dim):
        super().__init__()
        self.num_patch = num_patch
        self.embed_dim = embed_dim
        self.linear_trans = layers.Dense(2 * embed_dim, use_bias=False)

    def call(self, x):
        height, width = self.num_patch
        _, _, C = x.shape
        x = ops.reshape(x, (-1, height, width, C))
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = ops.concatenate((x0, x1, x2, x3), axis=-1)
        x = ops.reshape(x, (-1, (height // 2) * (width // 2), 4 * C))
        return self.linear_trans(x)
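
To see how the shapes evolve, here is a hypothetical walk-through I added, assuming the hyperparameters above: a 32x32x3 image yields 16x16 = 256 patches of 2*2*3 = 12 raw values each, embedding lifts them to 64 dimensions, and merging halves the spatial grid while doubling the channels.

# Hypothetical shape walk-through, not in the original example.
sample = tf.ones((1, image_dimension, image_dimension, 3))
patches = patch_extract(sample)
print(patches.shape)  # (1, 256, 12)

tokens = ops.ones((1, num_patch_x * num_patch_y, 12))
embedded = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)(tokens)
merged = PatchMerging((num_patch_x, num_patch_y), embed_dim=embed_dim)(embedded)
print(embedded.shape, merged.shape)  # (1, 256, 64) (1, 64, 128)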

Prepare the tf.data.Dataset


We do all the steps that have no trainable weights with tf.data.

Prepare the training, validation, and test datasets.

def augment(x):
    x = tf.image.random_crop(x, size=(image_dimension, image_dimension, 3))
    x = tf.image.random_flip_left_right(x)
    return x


dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(lambda x, y: (augment(x), y))
    .batch(batch_size=batch_size)
    .map(lambda x, y: (patch_extract(x), y))
    .prefetch(tf.data.experimental.AUTOTUNE)
)

dataset_val = (
    tf.data.Dataset.from_tensor_slices((x_val, y_val))
    .batch(batch_size=batch_size)
    .map(lambda x, y: (patch_extract(x), y))
    .prefetch(tf.data.experimental.AUTOTUNE)
)

dataset_test = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(batch_size=batch_size)
    .map(lambda x, y: (patch_extract(x), y))
    .prefetch(tf.data.experimental.AUTOTUNE)
)
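
As a quick check (my own addition, not in the original example), one batch from the validation pipeline should already have the (batch, 256, 12) patch layout that the model input below expects:

# Hypothetical sanity check, not in the original example.
for patch_batch, label_batch in dataset_val.take(1):
    print(patch_batch.shape, label_batch.shape)  # (128, 256, 12) (128, 100)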

Build the model

We put together the Swin Transformer model.

input = layers.Input(shape=(256, 12))
x = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)(input)
x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=0,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
)(x)
x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=shift_size,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
)(x)
x = PatchMerging((num_patch_x, num_patch_y), embed_dim=embed_dim)(x)
x = layers.GlobalAveragePooling1D()(x)
output = layers.Dense(num_classes, activation="softmax")(x)

Output:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Train on CIFAR-100


We train the model on CIFAR-100.

In this example, to keep the training time short, we train the model for only 40 epochs. In practice, you should train for 150 epochs to reach convergence.

model = keras.Model(input, output)
model.compile(
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
    optimizer=keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    metrics=[
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

history = model.fit(
    dataset,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_data=dataset_val,
)

Output:

Epoch 1/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 644s 2s/step - accuracy: 0.0517 - loss: 4.3948 - top-5-accuracy: 0.1816 - val_accuracy: 0.1396 - val_loss: 3.7930 - val_top-5-accuracy: 0.3922
Epoch 2/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 626s 2s/step - accuracy: 0.1606 - loss: 3.7267 - top-5-accuracy: 0.4209 - val_accuracy: 0.1946 - val_loss: 3.5560 - val_top-5-accuracy: 0.4862
Epoch 3/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - accuracy: 0.2160 - loss: 3.4910 - top-5-accuracy: 0.5076 - val_accuracy: 0.2440 - val_loss: 3.3946 - val_top-5-accuracy: 0.5384
Epoch 4/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 620s 2s/step - accuracy: 0.2599 - loss: 3.3266 - top-5-accuracy: 0.5628 - val_accuracy: 0.2730 - val_loss: 3.2732 - val_top-5-accuracy: 0.5812
Epoch 5/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - accuracy: 0.2841 - loss: 3.2082 - top-5-accuracy: 0.5988 - val_accuracy: 0.2878 - val_loss: 3.1837 - val_top-5-accuracy: 0.6050
Epoch 6/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 617s 2s/step - accuracy: 0.3049 - loss: 3.1199 - top-5-accuracy: 0.6262 - val_accuracy: 0.3110 - val_loss: 3.0970 - val_top-5-accuracy: 0.6292
Epoch 7/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 620s 2s/step - accuracy: 0.3271 - loss: 3.0387 - top-5-accuracy: 0.6501 - val_accuracy: 0.3292 - val_loss: 3.0374 - val_top-5-accuracy: 0.6488
Epoch 8/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 615s 2s/step - accuracy: 0.3454 - loss: 2.9764 - top-5-accuracy: 0.6679 - val_accuracy: 0.3480 - val_loss: 2.9921 - val_top-5-accuracy: 0.6598
Epoch 9/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 617s 2s/step - accuracy: 0.3571 - loss: 2.9272 - top-5-accuracy: 0.6801 - val_accuracy: 0.3522 - val_loss: 2.9585 - val_top-5-accuracy: 0.6746
Epoch 10/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 624s 2s/step - accuracy: 0.3658 - loss: 2.8809 - top-5-accuracy: 0.6924 - val_accuracy: 0.3562 - val_loss: 2.9364 - val_top-5-accuracy: 0.6784
Epoch 11/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - accuracy: 0.3796 - loss: 2.8425 - top-5-accuracy: 0.7021 - val_accuracy: 0.3654 - val_loss: 2.9100 - val_top-5-accuracy: 0.6832
Epoch 12/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 622s 2s/step - accuracy: 0.3884 - loss: 2.8113 - top-5-accuracy: 0.7103 - val_accuracy: 0.3740 - val_loss: 2.8808 - val_top-5-accuracy: 0.6948
Epoch 13/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 621s 2s/step - accuracy: 0.3994 - loss: 2.7718 - top-5-accuracy: 0.7239 - val_accuracy: 0.3778 - val_loss: 2.8637 - val_top-5-accuracy: 0.6994
Epoch 14/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - accuracy: 0.4072 - loss: 2.7491 - top-5-accuracy: 0.7271 - val_accuracy: 0.3848 - val_loss: 2.8533 - val_top-5-accuracy: 0.7002
Epoch 15/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 614s 2s/step - accuracy: 0.4142 - loss: 2.7180 - top-5-accuracy: 0.7344 - val_accuracy: 0.3880 - val_loss: 2.8383 - val_top-5-accuracy: 0.7080
Epoch 16/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 614s 2s/step - accuracy: 0.4231 - loss: 2.6918 - top-5-accuracy: 0.7392 - val_accuracy: 0.3934 - val_loss: 2.8323 - val_top-5-accuracy: 0.7072
Epoch 17/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 617s 2s/step - accuracy: 0.4339 - loss: 2.6633 - top-5-accuracy: 0.7484 - val_accuracy: 0.3972 - val_loss: 2.8237 - val_top-5-accuracy: 0.7138
Epoch 18/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 617s 2s/step - accuracy: 0.4388 - loss: 2.6436 - top-5-accuracy: 0.7506 - val_accuracy: 0.3984 - val_loss: 2.8119 - val_top-5-accuracy: 0.7144
Epoch 19/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 610s 2s/step - accuracy: 0.4439 - loss: 2.6251 - top-5-accuracy: 0.7552 - val_accuracy: 0.4020 - val_loss: 2.8044 - val_top-5-accuracy: 0.7178
Epoch 20/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 611s 2s/step - accuracy: 0.4540 - loss: 2.5989 - top-5-accuracy: 0.7652 - val_accuracy: 0.4012 - val_loss: 2.7969 - val_top-5-accuracy: 0.7246
Epoch 21/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 618s 2s/step - accuracy: 0.4586 - loss: 2.5760 - top-5-accuracy: 0.7684 - val_accuracy: 0.4092 - val_loss: 2.7807 - val_top-5-accuracy: 0.7254
Epoch 22/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 618s 2s/step - accuracy: 0.4607 - loss: 2.5624 - top-5-accuracy: 0.7724 - val_accuracy: 0.4158 - val_loss: 2.7721 - val_top-5-accuracy: 0.7232
Epoch 23/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - accuracy: 0.4658 - loss: 2.5407 - top-5-accuracy: 0.7786 - val_accuracy: 0.4180 - val_loss: 2.7767 - val_top-5-accuracy: 0.7280
Epoch 24/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 617s 2s/step - accuracy: 0.4744 - loss: 2.5233 - top-5-accuracy: 0.7840 - val_accuracy: 0.4164 - val_loss: 2.7707 - val_top-5-accuracy: 0.7300
Epoch 25/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 615s 2s/step - accuracy: 0.4758 - loss: 2.5129 - top-5-accuracy: 0.7847 - val_accuracy: 0.4196 - val_loss: 2.7677 - val_top-5-accuracy: 0.7294
Epoch 26/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 610s 2s/step - accuracy: 0.4853 - loss: 2.4954 - top-5-accuracy: 0.7863 - val_accuracy: 0.4188 - val_loss: 2.7571 - val_top-5-accuracy: 0.7362
Epoch 27/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 610s 2s/step - accuracy: 0.4858 - loss: 2.4785 - top-5-accuracy: 0.7928 - val_accuracy: 0.4186 - val_loss: 2.7615 - val_top-5-accuracy: 0.7348
Epoch 28/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 613s 2s/step - accuracy: 0.4889 - loss: 2.4691 - top-5-accuracy: 0.7945 - val_accuracy: 0.4208 - val_loss: 2.7561 - val_top-5-accuracy: 0.7350
Epoch 29/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - accuracy: 0.4940 - loss: 2.4592 - top-5-accuracy: 0.7992 - val_accuracy: 0.4244 - val_loss: 2.7546 - val_top-5-accuracy: 0.7398
Epoch 30/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - accuracy: 0.4989 - loss: 2.4391 - top-5-accuracy: 0.8025 - val_accuracy: 0.4180 - val_loss: 2.7861 - val_top-5-accuracy: 0.7302
Epoch 31/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 610s 2s/step - accuracy: 0.4994 - loss: 2.4354 - top-5-accuracy: 0.8032 - val_accuracy: 0.4264 - val_loss: 2.7608 - val_top-5-accuracy: 0.7394
Epoch 32/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 607s 2s/step - accuracy: 0.5011 - loss: 2.4238 - top-5-accuracy: 0.8090 - val_accuracy: 0.4292 - val_loss: 2.7625 - val_top-5-accuracy: 0.7384
Epoch 33/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - accuracy: 0.5065 - loss: 2.4144 - top-5-accuracy: 0.8085 - val_accuracy: 0.4288 - val_loss: 2.7517 - val_top-5-accuracy: 0.7328
Epoch 34/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 612s 2s/step - accuracy: 0.5094 - loss: 2.4099 - top-5-accuracy: 0.8093 - val_accuracy: 0.4260 - val_loss: 2.7550 - val_top-5-accuracy: 0.7390
Epoch 35/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 612s 2s/step - accuracy: 0.5109 - loss: 2.3980 - top-5-accuracy: 0.8115 - val_accuracy: 0.4278 - val_loss: 2.7496 - val_top-5-accuracy: 0.7396
Epoch 36/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 615s 2s/step - accuracy: 0.5178 - loss: 2.3868 - top-5-accuracy: 0.8139 - val_accuracy: 0.4296 - val_loss: 2.7519 - val_top-5-accuracy: 0.7404
Epoch 37/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 618s 2s/step - accuracy: 0.5151 - loss: 2.3842 - top-5-accuracy: 0.8150 - val_accuracy: 0.4308 - val_loss: 2.7504 - val_top-5-accuracy: 0.7424
Epoch 38/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 613s 2s/step - accuracy: 0.5169 - loss: 2.3798 - top-5-accuracy: 0.8159 - val_accuracy: 0.4360 - val_loss: 2.7522 - val_top-5-accuracy: 0.7464
Epoch 39/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 618s 2s/step - accuracy: 0.5228 - loss: 2.3641 - top-5-accuracy: 0.8201 - val_accuracy: 0.4374 - val_loss: 2.7386 - val_top-5-accuracy: 0.7452
Epoch 40/40
352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - accuracy: 0.5232 - loss: 2.3633 - top-5-accuracy: 0.8212 - val_accuracy: 0.4266 - val_loss: 2.7614 - val_top-5-accuracy: 0.7410

Let's visualize the training progress of the model.

plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()

Let's display the final results of training on CIFAR-100.

loss, accuracy, top_5_accuracy = model.evaluate(dataset_test)
print(f"Test loss: {round(loss, 2)}")
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

Output:

 79/79 ━━━━━━━━━━━━━━━━━━━━ 26s 325ms/step - accuracy: 0.4474 - loss: 2.7119 - top-5-accuracy: 0.7556
Test loss: 2.7
Test accuracy: 44.8%
Test top 5 accuracy: 75.23%

The Swin Transformer model we just trained has only ~152K parameters, and it reaches ~75% top-5 test accuracy within just 40 epochs without any signs of overfitting, as the output above shows.

This means we can train this network longer (perhaps with a bit more regularization) and obtain even better performance. The performance can be further improved by additional techniques such as a cosine-decay learning-rate schedule and other data augmentation techniques. While experimenting, I tried training the model for 150 epochs with a slightly higher dropout rate and a greater embedding dimension, which pushed the performance up to ~72% test accuracy on CIFAR-100.
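
As one hedged sketch of such an improvement (hypothetical settings, not the configuration used in the run above), the fixed learning rate could be replaced with a cosine-decay schedule before compiling:

# Hypothetical improvement sketch, not the configuration used in the run above.
cosine_lr = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=learning_rate,
    decay_steps=num_epochs * (num_train_samples // batch_size),
)
optimizer = keras.optimizers.AdamW(learning_rate=cosine_lr, weight_decay=weight_decay)

Passing this optimizer to model.compile() decays the learning rate smoothly over the full training run instead of holding it constant.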

As the authors report, the model achieves 87.3% top-1 accuracy on ImageNet. The paper also presents a series of experiments studying how input size, optimizers, and other factors affect the final performance of this model, and further demonstrates using the model for object detection, semantic segmentation, and instance segmentation, reporting competitive results on those tasks.

