从NLP到CV:手把手教你用PyTorch复现ViT(Vision Transformer)图像分类模型
当Transformer在自然语言处理领域大放异彩时,计算机视觉领域的研究者们开始思考:这种基于自注意力机制的强大架构能否同样革新图像处理的方式?2020年,Google Research团队给出了肯定答案——Vision Transformer(ViT)的横空出世,证明了纯Transformer架构在图像分类任务上不仅能与CNN匹敌,甚至能超越传统卷积网络的性能。本文将带你从零开始,用PyTorch完整实现这个突破性模型。
1. 环境准备与数据预处理
在开始构建ViT之前,我们需要配置合适的开发环境。建议使用Python 3.8+和PyTorch 1.10+版本,这些版本对Transformer架构的支持最为成熟。以下是推荐的环境配置:
conda create -n vit python=3.8 conda activate vit pip install torch torchvision torchaudio pip install numpy matplotlib tqdmViT的核心创新在于将图像视为一系列patch的集合,这与NLP中将句子视为token序列的处理方式异曲同工。我们首先需要实现图像分块处理:
import torch import torch.nn as nn def image_to_patches(x, patch_size=16): """ 将图像分割为固定大小的patch 参数: x: 输入图像张量 [B, C, H, W] patch_size: 每个patch的宽度/高度 返回: patches: 展平的patch序列 [B, num_patches, patch_dim] """ B, C, H, W = x.shape x = x.unfold(2, patch_size, patch_size) # 沿高度维度展开 x = x.unfold(3, patch_size, patch_size) # 沿宽度维度展开 patches = x.contiguous().view(B, -1, C*patch_size*patch_size) return patches对于CIFAR-10数据集,标准的32×32分辨率图像被分割为8×8=64个4×4的patch(当patch_size=4时)。每个patch将被线性投影到一个固定维度的向量空间,这个过程类似于NLP中的词嵌入:
class PatchEmbedding(nn.Module): def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=512): super().__init__() self.img_size = img_size self.patch_size = patch_size self.n_patches = (img_size // patch_size) ** 2 self.proj = nn.Linear(patch_size*patch_size*in_chans, embed_dim) def forward(self, x): patches = image_to_patches(x, self.patch_size) embeddings = self.proj(patches) return embeddings2. ViT核心组件实现
2.1 位置编码与类别token
与原始Transformer不同,ViT需要处理2D图像的位置信息。我们采用可学习的位置编码,而非Transformer中的正弦位置编码:
class PositionalEncoding(nn.Module): def __init__(self, n_patches, embed_dim): super().__init__() self.pos_embed = nn.Parameter(torch.zeros(1, n_patches+1, embed_dim)) nn.init.trunc_normal_(self.pos_embed, std=0.02) def forward(self, x): return x + self.pos_embedViT引入了一个特殊的[class] token,其最终状态将作为整个图像的表示。这个设计借鉴了BERT中的[CLS] token:
class ClassToken(nn.Module): def __init__(self, embed_dim): super().__init__() self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) nn.init.trunc_normal_(self.cls_token, std=0.02) def forward(self, x): cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_tokens, x), dim=1) return x2.2 Transformer编码器实现
ViT的核心是标准的Transformer编码器,由多头自注意力机制和前馈网络交替组成:
class TransformerBlock(nn.Module): def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1): super().__init__() self.norm1 = nn.LayerNorm(embed_dim) self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout) self.norm2 = nn.LayerNorm(embed_dim) self.mlp = nn.Sequential( nn.Linear(embed_dim, int(embed_dim * mlp_ratio)), nn.GELU(), nn.Dropout(dropout), nn.Linear(int(embed_dim * mlp_ratio), embed_dim), nn.Dropout(dropout) ) def forward(self, x): # 自注意力部分 x_norm = self.norm1(x) attn_output, _ = self.attn(x_norm, x_norm, x_norm) x = x + attn_output # 前馈网络部分 x_norm = self.norm2(x) mlp_output = self.mlp(x_norm) x = x + mlp_output return x3. 完整ViT模型组装
现在我们可以将各个组件组合成完整的Vision Transformer:
class VisionTransformer(nn.Module): def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=512, depth=6, num_heads=8, mlp_ratio=4.0, num_classes=10, dropout=0.1): super().__init__() # 图像分块与嵌入 self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim) n_patches = (img_size // patch_size) ** 2 # 类别token和位置编码 self.cls_token = ClassToken(embed_dim) self.pos_embed = PositionalEncoding(n_patches, embed_dim) # Transformer编码器堆叠 self.blocks = nn.ModuleList([ TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth) ]) # 分类头 self.norm = nn.LayerNorm(embed_dim) self.head = nn.Linear(embed_dim, num_classes) def forward(self, x): # 图像分块与嵌入 x = self.patch_embed(x) # 添加类别token和位置编码 x = self.cls_token(x) x = self.pos_embed(x) # 通过Transformer编码器 for block in self.blocks: x = block(x) # 提取类别token状态用于分类 cls_token = x[:, 0] cls_token = self.norm(cls_token) logits = self.head(cls_token) return logits4. 模型训练与调优技巧
4.1 训练配置
ViT的训练需要特别注意学习率调度和正则化策略。以下是推荐的训练配置:
from torch.optim import AdamW model = VisionTransformer( img_size=32, patch_size=4, in_chans=3, embed_dim=512, depth=6, num_heads=8, num_classes=10 ).to(device) optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.05) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) criterion = nn.CrossEntropyLoss()4.2 数据增强策略
由于ViT缺乏CNN固有的平移不变性等归纳偏置,数据增强对模型性能至关重要:
from torchvision import transforms train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ])4.3 常见问题排查
在ViT实现过程中,开发者常遇到以下问题及解决方案:
- 训练不稳定:降低初始学习率,增加warmup步数
- 过拟合:增强数据增强,增加dropout率
- 梯度爆炸:使用梯度裁剪(
nn.utils.clip_grad_norm_) - 内存不足:减小batch size或使用梯度累积
5. 性能评估与对比分析
在CIFAR-10数据集上,我们实现的ViT模型可以达到约85%的测试准确率。与CNN相比,ViT展现出以下特点:
| 特性 | ViT | CNN |
|---|---|---|
| 计算效率 | 中等 | 高 |
| 数据需求 | 大量 | 中等 |
| 可扩展性 | 优秀 | 一般 |
| 局部特征提取 | 需学习 | 内置 |
| 全局关系建模 | 优秀 | 有限 |
值得注意的是,ViT的性能高度依赖于预训练数据规模。在小规模数据集上,CNN通常表现更好;但在大数据集(如ImageNet-21k)上预训练后,ViT展现出更强的迁移学习能力。