Abstract:
This article records the process and methods of training, evaluating, and running inference with the ViT model on ImageNet image classification using the MindSpore AI framework. It covers environment setup, dataset download, dataset loading, model analysis and construction, and model training and inference.
一、Concepts
1. The ViT model
ViT (Vision Transformer) applies the self-attention (Self-Attention) based Transformer architecture to images.
The Transformer model can be trained at scales of more than 100B parameters and dominates the natural language processing field.
ViT brings the same architecture to computer vision, without relying on convolution operations.
2. Model structure
The main structure of the ViT model is shown in the figure below.

Reading the figure from bottom to top:
At the bottom is the input data: the original image is divided into multiple patches (image blocks), and each two-dimensional patch (ignoring the channel dimension) is converted into a one-dimensional vector.
The middle backbone is based on the Encoder part of the Transformer. The Multi-head Attention structure is retained, but the order of some sub-structures is adjusted: the Normalization layer sits in a different position.
At the top, the stacked Blocks are followed by a fully connected Head. Together with the extra class token attached to the input, it outputs the recognized classification result.
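To make the "patches to one-dimensional vectors" idea concrete, here is a minimal NumPy sketch (illustrative only; the actual MindSpore implementation does this with a convolution in the PatchEmbedding cell shown later):
import numpy as np

# Hypothetical example: one 224 x 224 RGB image, channel-last for simplicity
image = np.random.rand(224, 224, 3)
patch_size = 16

# Cut the image into a 14 x 14 grid of 16 x 16 patches, then flatten each patch
patches = image.reshape(14, patch_size, 14, patch_size, 3).transpose(0, 2, 1, 3, 4)
patches = patches.reshape(14 * 14, patch_size * patch_size * 3)
print(patches.shape)  # (196, 768): 196 "visual words", each a 768-dimensional vector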
二、Environment Setup
Make sure a Python environment and MindSpore are installed.
%%capture captured_output
# The experiment environment comes with mindspore==2.2.14 preinstalled; change the version number below to switch MindSpore versions
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# Check the current mindspore version
!pip show mindspore
Output:
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by:
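As an optional sanity check, recent MindSpore releases provide run_check(), which runs a small computation on the configured backend to verify that the installation works (assuming the default backend is set up correctly):
import mindspore
# Optional: verifies that MindSpore can run a simple computation on this machine
mindspore.run_check()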
三、Data Preparation
1. Download and extract the dataset
Download source: the ImageNet dataset at http://image-net.org
The dataset used in this case is a subset filtered from ImageNet.
from download import download
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)
Output:
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip (489.1 MB)file_sizes: 100%|█████████████████████████████| 513M/513M [00:02<00:00, 228MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./
2. Dataset directory structure:
./dataset/
├── ILSVRC2012_devkit_t12.tar.gz
├── train/
├── infer/
└── val/
3. Load the dataset
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)
trans_train = [
    transforms.RandomCropDecodeResize(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
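To confirm the pipeline produces the expected tensors, you can peek at one processed batch (a quick check; it assumes the dataset was extracted to ./dataset/train as above):
# Each batch should be (16, 3, 224, 224) images and (16,) labels after the transforms
for images, labels in dataset_train.create_tuple_iterator():
    print(images.shape, labels.shape)
    break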
四、Model Analysis
1. Transformer basics
The Transformer model is an encoder-decoder architecture built on the Attention mechanism. Its structure diagram:

It is composed of multiple Encoder and Decoder modules. The detailed Encoder and Decoder structure:

Each Encoder and Decoder consists of:
a Multi-Head Attention layer, built from several Self-Attention blocks running in parallel;
a Feed Forward layer;
a Normalization layer;
and a residual connection (Residual Connection), the "Add" in the figure.
2. The Attention module
The core of Self-Attention is to learn a weight for each element of the input sequence. Given a query vector (Query), it computes the similarity or relevance between the Query and each Key to obtain an attention distribution, i.e. a weight coefficient for the Value associated with each Key, and then takes a weighted sum over the Values to produce the final Attention output.
The Self-Attention mechanism:
(1) The initial input vector is first passed through an Embedding layer that maps it to a vector of size dim x 3, which is split into three vectors: Q (Query), K (Key) and V (Value). For an input sequence of one-dimensional vectors (x1, x2, x3), each vector is mapped by the Embedding layer to its own Q, K and V; only the embedding matrices differ, and their parameters are learned during training. The relations between the input vectors can then be computed from the Q, K and V matrices: the dot product of two of them (Q and K) produces the weights, while the third (V) carries the values that are combined in the weighted sum.


(2) The "self" in self-attention comes from the fact that Q, K and V are all derived from the input itself. The self-attention process extracts the relations and features among the vectors at different positions of the input; how closely two positions are related shows up in the result of the Q-K product after Softmax, which gives the weights among the Q, K, V vectors. Concretely, Q and K are multiplied (dot product), the result is divided by the square root of the dimension, and Softmax is applied over the results for all positions.



(3) Global self-attention: the Softmax result of Q and K is combined with the vector V as a weighted sum (weight sum). Every group of Q, K, V ultimately produces one output of the same shape as V: the result for the current position is obtained by combining the Values of all positions according to the current vector's relation weights to the other vectors.
The complete Self-Attention process:

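The computation described above fits in a few lines. Below is a minimal NumPy sketch of single-head scaled dot-product attention with toy shapes (illustrative only, not the MindSpore implementation that follows):
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

n, dim = 3, 4                                                 # 3 input tokens, embedding dim 4
x = np.random.rand(n, dim)
w_q, w_k, w_v = (np.random.rand(dim, dim) for _ in range(3))  # learned matrices in a real model

q, k, v = x @ w_q, x @ w_k, x @ w_v
attn = softmax(q @ k.T / np.sqrt(dim))   # (3, 3) attention weights between tokens
out = attn @ v                           # each output token is a weighted sum of the Values
print(out.shape)                         # (3, 4)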
Multi-head attention
The multi-head attention mechanism splits the vectors handled by Self-Attention into several Heads that are processed separately, which enables parallel acceleration while keeping the total number of parameters unchanged. The same query, key and value are mapped into the high-dimensional space (Q, K, V) but split across different subspaces (Q_0, K_0, V_0, ...); self-attention is computed in each subspace separately, and the attention information from the different subspaces is merged at the end. For a single input vector, multiple attention heads can thus be processed in parallel, and the vector's features are analysed and exploited more thoroughly. In the figure below, a_i and a_j are pieces split from the same vector.

The Multi-Head Attention code:
from mindspore import nn, ops
class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)

        # One dense layer projects the input to Q, K and V at once
        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        """Attention construct."""
        b, n, c = x.shape
        qkv = self.qkv(x)
        # Split into heads: (3, batch, num_heads, seq_len, head_dim)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        # Scaled dot-product attention weights
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        # Weighted sum of the Values, then merge the heads back
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)

        return out
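A quick way to sanity-check the cell above is to run it on a dummy tensor; the shapes below follow the ViT-B/16 configuration used later (a sketch, not part of the original tutorial):
import numpy as np

# Batch of 1, a sequence of 197 tokens (196 patches + 1 class token), embedding dim 768
dummy = ms.Tensor(np.ones((1, 197, 768)), ms.float32)
attention = Attention(dim=768, num_heads=12)
print(attention(dummy).shape)  # expected: (1, 197, 768), the sequence shape is preserved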
Transformer Encoder
Self-Attention, Feed Forward and Residual Connection are stitched together to form the basic Transformer structure. The Feed Forward and Residual Connection code:
from typing import Optional, Dict
class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)

    def construct(self, x):
        """Feed Forward construct."""
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)

        return x

class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        """ResidualCell construct."""
        return self.cell(x) + x
Self-Attention is used to build the TransformerEncoder part of the ViT model:

The ViT encoder differs from the original Transformer in that Normalization is placed before Self-Attention and Feed Forward; the rest of the structure is unchanged. Following the Transformer structure diagram, multiple sub-encoders are stacked to build the model's encoder, and in ViT the hyperparameter num_layers determines the number of stacked layers. The Residual Connection and Normalization structures ensure that information does not degrade as it passes through the deep layers and strengthen the model's generalization ability. The TransformerEncoder structure combined with the multilayer perceptron (MLP) forms the backbone of the ViT model.
class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []

        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)

            # Pre-norm layout: Normalization before Attention / Feed Forward,
            # each sub-block wrapped in a residual connection
            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ]))
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        """Transformer construct."""
        return self.layers(x)
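The stacked layers preserve the token shape, which can be checked with a small dummy run (a sketch; num_layers is reduced to 2 to keep it light):
import numpy as np

encoder = TransformerEncoder(dim=768, num_layers=2, num_heads=12, mlp_dim=3072)
tokens = ms.Tensor(np.ones((1, 197, 768)), ms.float32)
print(encoder(tokens).shape)  # expected: (1, 197, 768)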
The input to the ViT model
The traditional Transformer processes word vectors (Word Embedding or Word Vector) from the natural language domain, which stack into a sequence of one-dimensional vectors, whereas an image is a stack of two-dimensional matrices. When the multi-head attention mechanism processes a stack of one-dimensional word vectors, it extracts the relations between them, i.e. the contextual semantics.
In the ViT model, the input image is divided into patches on each channel by a convolution operation: with a patch size of 16 x 16, a 224 x 224 input image processed by the convolution yields a 14 x 14 grid, i.e. 196 patches, each 16 x 16 pixels in size. Each patch matrix is then stretched into a one-dimensional vector, which approximates the effect of a stack of word vectors: the 14 x 14 grid of patches becomes a sequence of 196 token vectors. This is the first processing step an image goes through when it enters the network.
The Patch Embedding code:
class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        # A conv with kernel_size = stride = patch_size cuts the image into patches
        # and projects each patch to an embed_dim-dimensional vector in one step
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        """Patch Embedding construct."""
        x = self.conv(x)
        b, c, h, w = x.shape
        # Flatten the patch grid into a token sequence: (batch, num_patches, embed_dim)
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))

        return x
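Running the cell on a dummy image confirms the shapes described above (a sketch with the default ViT-B/16 settings):
import numpy as np

patch_embedding = PatchEmbedding(image_size=224, patch_size=16, embed_dim=768, input_channels=3)
img = ms.Tensor(np.ones((1, 3, 224, 224)), ms.float32)
print(patch_embedding(img).shape)  # expected: (1, 196, 768) -- 14 x 14 = 196 patch tokens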
After the input image has been divided into patches, it goes through two further steps: class_embedding and pos_embedding.
The class_embedding borrows the idea BERT uses for text classification: a class value is added in front of the sequence of word vectors, usually at the first position, so the sequence of 196 patch tokens becomes 197 tokens once the class_embedding is added. The class_embedding is a learnable parameter; as the network is trained, the output at this first position of the output sequence is ultimately used to decide the final output class.
The pos_embedding is also a set of learnable parameters, which is added onto the patch matrix. There are four common schemes for pos_embedding; here the one-dimensional pos_embedding is used. Because the class_embedding is prepended before the pos_embedding is added, the pos_embedding has one more position than the number of flattened patches (197 instead of 196).
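In terms of shapes, prepending the class token and adding the positional embedding looks roughly like this (a sketch that mirrors the construct logic of the full model in the next section; the real class token and pos_embedding are learnable parameters):
import numpy as np

patches = ms.Tensor(np.ones((1, 196, 768)), ms.float32)         # patch tokens from PatchEmbedding
cls_token = ms.Tensor(np.zeros((1, 1, 768)), ms.float32)        # learnable in the real model
pos_embedding = ms.Tensor(np.zeros((1, 197, 768)), ms.float32)  # one position per token, incl. class token

tokens = ops.concat((cls_token, patches), axis=1)               # (1, 197, 768)
tokens = tokens + pos_embedding
print(tokens.shape)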
五、Building the Complete ViT
The code that builds the ViT model:
from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter
def init(init_type, shape, dtype, name, requires_grad):
    """Init."""
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)

class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls',
                 num_classes: int = 1000) -> None:
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        # Learnable class token, prepended to the patch sequence
        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        # Learnable 1-D positional embedding: one position per patch plus the class token
        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        # Classification head; num_classes defaults to the 1000 ImageNet classes
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]    # take the output at the class-token position
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)

        return x
The overall flow is shown below:

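An end-to-end shape check of the assembled model (a sketch; it builds the full ViT-B/16 with the default 1000-class head, so it may take a moment to run):
import numpy as np

vit = ViT()                                               # ViT-B/16 defaults
images = ms.Tensor(np.ones((2, 3, 224, 224)), ms.float32)
print(vit(images).shape)                                  # expected: (2, 1000), one logit per class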
六、Model Training and Inference
1. Model training
Before training starts, the loss function, the optimizer and the callback functions need to be set, and epoch_size can be adjusted as needed.
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train
# define hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()
# construct model
network = ViT()
# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)
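# nn.cosine_decay_lr returns a plain Python list with one learning-rate value per step,
# so it can be inspected directly (a quick optional check):
print(len(lr))        # epoch_size * step_size entries
print(lr[0], lr[-1])  # decays from max_lr towards min_lr following a cosine schedule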
# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
# define loss function
class CrossEntropySmooth(LossBase):
    """CrossEntropy."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss

network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)
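# With smooth_factor=0.1 and num_classes=1000, the smoothed one-hot target gives the
# true class 1 - 0.1 = 0.9 and every other class 0.1 / 999 ≈ 0.0001 (an illustrative check):
on_value = 1.0 - 0.1
off_value = 0.1 / (1000 - 1)
print(on_value, off_value, on_value + 999 * off_value)  # the target row still sums to 1.0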
# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")
# train model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False)
Output:
Downloading data from https://download-mindspore.osinfra.cn/vision/classification/vit_b_16_224.ckpt (330.2 MB)file_sizes: 100%|████████████████████████████| 346M/346M [00:26<00:00, 13.2MB/s]
Successfully downloaded file to ./ckpt/vit_b_16_224.ckpt
epoch: 1 step: 125, loss is 1.4842896
Train epoch time: 275011.631 ms, per step time: 2200.093 ms
epoch: 2 step: 125, loss is 1.3481578
Train epoch time: 23961.255 ms, per step time: 191.690 ms
epoch: 3 step: 125, loss is 1.3990085
Train epoch time: 24217.701 ms, per step time: 193.742 ms
epoch: 4 step: 125, loss is 1.1687485
Train epoch time: 23769.989 ms, per step time: 190.160 ms
epoch: 5 step: 125, loss is 1.209775
Train epoch time: 23603.390 ms, per step time: 188.827 ms
epoch: 6 step: 125, loss is 1.3151006
Train epoch time: 23977.132 ms, per step time: 191.817 ms
epoch: 7 step: 125, loss is 1.4682239
Train epoch time: 23898.189 ms, per step time: 191.186 ms
epoch: 8 step: 125, loss is 1.2927357
Train epoch time: 23681.583 ms, per step time: 189.453 ms
epoch: 9 step: 125, loss is 1.5348746
Train epoch time: 23521.045 ms, per step time: 188.168 ms
epoch: 10 step: 125, loss is 1.3726548
Train epoch time: 23719.398 ms, per step time: 189.755 ms
2. Model evaluation
Model evaluation uses the ImageFolderDataset interface to read the dataset, the CrossEntropySmooth interface to instantiate the loss function, and interfaces such as Model to compile the model.
Steps:
Apply data augmentation.
Define the ViT network structure.
Load the pretrained model parameters.
Set the loss function.
Set the evaluation metrics:
Top_1_Accuracy takes the class with the largest output as the prediction;
Top_5_Accuracy takes the five largest outputs as the prediction;
the larger these two metrics, the higher the model's accuracy (see the short sketch after this list).
Compile the model.
Evaluate.
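As referenced above, a toy NumPy sketch of how Top-1 and Top-5 hits differ for a single sample (hypothetical logits, not part of the tutorial):
import numpy as np

logits = np.random.rand(10)              # toy scores over 10 classes
label = 7                                # assume class 7 is the ground truth
top1 = int(logits.argmax())              # Top-1: the single most likely class
top5 = np.argsort(logits)[::-1][:5]      # Top-5: the five most likely classes
print(label == top1, label in top5)      # a Top-1 hit is stricter than a Top-5 hit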
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)
# construct model
network = ViT()
# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)
# define metric
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")
# evaluate model
result = model.eval(dataset_val)
print(result)
Output:
{'Top_1_Accuracy': 0.7495, 'Top_5_Accuracy': 0.928}
3. Model inference
The images used for inference are preprocessed with resize and normalize so that they match the training input data.
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)
For model inference, call the model's predict method, use index2label to look up the corresponding label, and use the custom show_result interface to write the result onto the corresponding image.
import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io
class Color(Enum):
    """define enum color."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)

def check_file_exist(file_name: str):
    """check_file_exist."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")

def color_val(color):
    """color_val."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')

def imread(image, mode=None):
    """imread."""
    if isinstance(image, pathlib.Path):
        image = str(image)

    if isinstance(image, np.ndarray):
        pass
    elif isinstance(image, str):
        check_file_exist(image)
        image = Image.open(image)
        if mode:
            image = np.array(image.convert(mode))
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")

    return image

def imwrite(image, image_path, auto_mkdir=True):
    """imwrite."""
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            os.makedirs(dir_name, mode=777, exist_ok=True)

    image = Image.fromarray(image)
    image.save(image_path)

def imshow(img, win_name='', wait_time=0):
    """imshow"""
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # prevent from hanging if windows was closed
        while True:
            ret = cv2.waitKey(1)
            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # if user closed window or if some key pressed
            if closed or ret != -1:
                break
    else:
        ret = cv2.waitKey(wait_time)

def show_result(img: str,
                result: Dict[int, float],
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """Mark the prediction results on the picture."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width
    if out_file:
        show = False
        imwrite(img, out_file)

    if show:
        imshow(img, win_name, wait_time)

def index2label():
    """Dictionary output for image numbers and categories of the ImageNet dataset."""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']

    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children)
            if num_children == 0]

    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])

    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping

# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    mapping = index2label()
    output = {int(label): mapping[int(label)]}
    print(output)
    show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
                result=output,
                out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")
Output:
{236: 'Doberman'}
After the inference process completes, the annotated image with the inference result can be found in the inference folder.
