引言
如果你对人工智能(AI)或深度学习(Deep Learning)感兴趣,可能听说过“Transformer”这个词。它最初在自然语言处理(NLP)领域大放异彩,比如在翻译、聊天机器人和文本生成中表现出色。但你知道吗?Transformer不仅能处理文字,还能用来分类图像!这听起来是不是有点神奇?别担心,这篇博客将带你从零开始,了解Transformer的基本概念、它如何被应用到图像分类,以及通过一个简单的例子让你直观理解它的运作原理。无论你是AI新手还是好奇的技术爱好者,这篇文章都会尽量用通俗的语言为你解锁Transformer的奥秘。
第一部分:Transformer是什么?
Transformer是一种深度学习模型,最早由Vaswani等人在2017年的论文《Attention is All You Need》中提出。它的核心思想是“注意力机制”(Attention Mechanism),这是一种让模型学会“关注”输入中重要部分的能力。传统的模型,比如卷积神经网络(CNN)和循环神经网络(RNN),在处理图像或序列数据时有局限性,而Transformer通过注意力机制突破了这些限制。
1.1 为什么叫“Transformer”?
“Transformer”这个名字听起来很酷,但它其实反映了模型的功能:它能将输入数据“转换”(Transform)成更有意义的表示形式。比如,把一句话翻译成另一种语言,或者把一张图片“翻译”成一个分类标签(比如“猫”或“狗”)。它的核心在于通过计算输入数据之间的关系,生成更有用的输出。
1.2 Transformer的基本结构
Transformer由两个主要部分组成:编码器(Encoder)和解码器(Decoder)。不过,在图像分类任务中,我们通常只用到编码器部分。让我们简单看看它的组成:
- 输入嵌入(Input Embedding):把输入数据(比如单词或图像块)转换成数字向量。
- 注意力机制(Attention):让模型关注输入中最重要的部分。
- 前馈神经网络(Feed-Forward Network):对数据进一步处理。
- 层归一化和残差连接(Layer Normalization & Residual Connection):帮助模型稳定训练,避免“梯度消失”等问题。
这些组件堆叠在一起,形成多层结构,每一层都让模型对数据的理解更深一层。
1.3 注意力机制:Transformer的“超能力”
注意力机制是Transformer的核心。想象你在读一本书,当你看到“猫”这个词时,你会自动想到整句话的上下文,比如“猫在睡觉”还是“猫在跑”。注意力机制让模型也能做到这一点:它会计算输入中每个部分对其他部分的“重要性”,然后根据这些关系调整输出。
具体来说,Transformer使用的是“自注意力”(Self-Attention)。它会为输入的每个部分(比如图像的一个小块)生成三个向量:
- 查询(Query):我想知道什么?
- 键(Key):我有哪些信息?
- 值(Value):这些信息有多重要?
通过计算查询和键之间的相似度,模型决定每个值的权重,然后把它们加权组合起来。这种方式让Transformer能捕捉全局关系,而不是像CNN那样只关注局部区域。
第二部分:从NLP到图像分类:Vision Transformer (ViT)
Transformer最初是为NLP设计的,那它是怎么“跨界”到图像分类的呢?这要归功于2020年提出的Vision Transformer(简称ViT)。让我们看看它是如何工作的。
2.1 图像怎么变成Transformer的输入?
图像和文字完全不同,对吧?图像是一堆像素,而文字是一串单词。要让Transformer处理图像,第一步就是把图像“翻译”成它能理解的形式。ViT的做法是:
- 切分图像:把一张图片(比如224x224像素)切成固定大小的小块(比如16x16像素),就像把一张大拼图拆成小碎片。
- 展平并嵌入:把每个小块展平成一个向量(就像把拼图碎片摊平),然后通过一个线性层把它们变成嵌入向量(Embedding)。
- 加上位置信息:因为Transformer不像CNN有固定的空间感知能力,我们需要手动告诉它每个小块在图像中的位置。这通过“位置编码”(Positional Encoding)实现。
经过这些步骤,一张图像就变成了一个序列(Sequence),就像NLP中的一句话,只不过这里的“单词”是图像块。
2.2 Transformer处理图像的过程
一旦图像被转换成序列,Transformer的编码器就开始工作:
- 自注意力:计算每个图像块和其他图像块之间的关系。比如,在一张猫的图片中,耳朵和眼睛的图像块可能会被关联起来。
- 多层堆叠:通过多层编码器,模型逐渐提取更高层次的特征。
分类头:在最后一层,添加一个简单的分类层(比如全连接层),输出图像的类别(比如“猫”或“狗”)。
2.3 ViT的优势和挑战
相比传统的CNN,ViT有几个优点:
- 全局视野:它能一次性看到整张图像的关系,而不像CNN只关注局部。
- 灵活性:同一个模型可以轻松处理不同大小的输入。
但它也有挑战:
- 计算量大:自注意力机制需要大量计算,尤其当图像块很多时。
- 数据需求高:ViT需要大量标注数据才能训练得好。
第三部分:一个简单的例子:用ViT分类猫和狗
为了让新手更容易理解,我们通过一个具体的例子来说明Transformer如何进行图像分类。假设我们要训练一个模型,区分CIFAR-10数据集中的“猫”和“狗”图片(CIFAR-10是PyTorch内置的一个小型图像数据集,包含10类32x32像素的图像)。下面我们逐步拆解过程,并新增代码实现。
3.1 数据准备
CIFAR-10中的每张图片是32x32像素,RGB格式。我们将它切成4x4的小块(为了简化示例),总共有64个块(32 ÷ 4 = 8,8x8 = 64)。每个小块有48个数值(4x4x3,因为RGB有3个通道)。
3.2 嵌入过程
- 把每个小块展平成一个48维向量。
- 通过一个线性层,把48维映射到一个固定维度(比如32维),得到嵌入向量。
- 加上位置编码,告诉模型每个块的位置。
现在,这张图片变成了一个64x32的矩阵,就像一个有64个“单词”的序列。
3.3 自注意力计算
假设猫咪的耳朵在第10个块,眼睛在第20个块。Transformer会:
- 为每个块生成查询、键和值向量。
- 计算第10个块的查询和第20个块的键之间的相似度,发现它们关系密切。
- 根据相似度加权组合值向量,生成一个新的表示。
经过多层自注意力,模型学会关联猫的特征。
3.4 分类输出
在最后一层,ViT取一个特殊的“分类标记”(CLS Token),通过全连接层输出10个类别的概率(CIFAR-10有10类),比如“猫”的概率是0.8,“狗”是0.1。
3.5 代码实现
下面我们提供两种代码实现方式,帮助你直观感受ViT的运作。代码基于PyTorch,使用CIFAR-10数据集。
实现方式1:从头实现一个简化的ViT
这个实现简化了ViT的核心组件,适合理解原理。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 超参数
patch_size = 4 # 切分图像为4x4的小块
embed_dim = 32 # 每个小块的嵌入维度
num_heads = 4 # 注意力头的数量
num_classes = 10 # CIFAR-10有10个类别
num_patches = (32 // patch_size) ** 2 # 64个小块 (32x32图像)# 数据加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)# 简化的ViT模型
class SimpleViT(nn.Module):def __init__(self):super(SimpleViT, self).__init__()# 将图像块映射到嵌入空间self.patch_to_embedding = nn.Linear(patch_size * patch_size * 3, embed_dim)# 位置编码self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))# CLS Tokenself.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))# Transformer编码器self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads), num_layers=2)# 分类头self.fc = nn.Linear(embed_dim, num_classes)def forward(self, x):b, c, h, w = x.shape # [batch_size, 3, 32, 32]# 切分成小块并展平x = x.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)x = x.permute(0, 2, 4, 1, 3, 5).contiguous() # [b, 8, 8, 3, 4, 4]x = x.view(b, num_patches, -1) # [b, 64, 48]# 映射到嵌入空间x = self.patch_to_embedding(x) # [b, 64, 32]# 添加CLS Tokencls_tokens = self.cls_token.expand(b, -1, -1) # [b, 1, 32]x = torch.cat((cls_tokens, x), dim=1) # [b, 65, 32]# 加上位置编码x = x + self.pos_embedding# 通过Transformerx = self.transformer(x) # [b, 65, 32]# 取CLS Token的输出进行分类x = self.fc(x[:, 0]) # [b, 10]return x# 训练模型
model = SimpleViT()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)for epoch in range(5): # 训练5个epochfor images, labels in trainloader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')
代码解释:
- 数据加载:从CIFAR-10加载32x32的图像,归一化处理。
- 图像切分:将32x32图像切成64个4x4的小块,展平后映射到32维嵌入。
- CLS Token:添加一个特殊标记,用于最终分类。
- Transformer:使用PyTorch内置的Transformer编码器,包含2层,每层有4个注意力头。
- 训练:简单训练5个epoch,优化分类损失。
实现方式2:使用预训练ViT模型(Hugging Face)
这个实现利用Hugging Face的预训练ViT模型,适合快速上手。
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 数据加载
transform = transforms.Compose([transforms.Resize((224, 224)), # ViT需要224x224输入transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=16, shuffle=True)# 加载预训练ViT模型和特征提取器
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10) # 修改分类头为10类# 训练设置
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)# 训练模型
model.train()
for epoch in range(3): # 训练3个epochfor images, labels in trainloader:inputs = feature_extractor(images=[img.permute(1, 2, 0).numpy() for img in images], return_tensors="pt")inputs = {k: v for k, v in inputs.items()} # 转换为模型输入格式optimizer.zero_grad()outputs = model(**inputs).logits # 获取分类输出loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')
代码解释:
- 数据预处理:将CIFAR-10图像调整到224x224(ViT预训练模型的要求)。
- 预训练模型:加载Google的vit-base-patch16-224,替换分类头为10类。
- 特征提取器:自动处理图像输入,切分并嵌入。
- 训练:微调模型,适应CIFAR-10任务。
注意:运行第二种方式需要安装transformers库(pip install transformers)。
第四部分:新手常见问题解答
4.1 Transformer和CNN有什么不同?
CNN像一个放大镜,逐步扫描图像的局部特征;而Transformer像一个全景相机,一次性捕捉全局关系。两者各有千秋,ViT证明了Transformer也能在图像任务中大放异彩。
4.2 我需要多强的编程基础才能用Transformer?
好消息是,你不需要从头写Transformer!开源工具(如PyTorch和Hugging Face)提供了预训练模型。你只需要学会加载模型、准备数据和微调,就能上手。
4.3 ViT适合所有图像任务吗?
不完全是。ViT在大数据集(如ImageNet)上表现很好,但在小数据集或需要精细局部特征的任务上,CNN可能更合适。
第五部分
Transformer通过注意力机制和全局视野,为图像分类带来了新思路。Vision Transformer(ViT)展示了它如何将图像切分成块,像处理句子一样处理图片,最终实现分类。对于新手来说,理解它的关键在于:
- 图像如何变成序列。
- 自注意力如何捕捉关系。
- 分类如何通过简单输出实现。
通过上面的代码示例,你可以看到:
- 从头实现ViT帮助理解原理。
- 使用预训练模型能快速应用到实际任务。