news 2026/4/30 17:46:45

从NLP到CV:手把手教你用PyTorch复现ViT(Vision Transformer)图像分类模型

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
从NLP到CV:手把手教你用PyTorch复现ViT(Vision Transformer)图像分类模型

从NLP到CV:手把手教你用PyTorch复现ViT(Vision Transformer)图像分类模型

当Transformer在自然语言处理领域大放异彩时,计算机视觉领域的研究者们开始思考:这种基于自注意力机制的强大架构能否同样革新图像处理的方式?2020年,Google Research团队给出了肯定答案——Vision Transformer(ViT)的横空出世,证明了纯Transformer架构在图像分类任务上不仅能与CNN匹敌,甚至能超越传统卷积网络的性能。本文将带你从零开始,用PyTorch完整实现这个突破性模型。

1. 环境准备与数据预处理

在开始构建ViT之前,我们需要配置合适的开发环境。建议使用Python 3.8+和PyTorch 1.10+版本,这些版本对Transformer架构的支持最为成熟。以下是推荐的环境配置:

conda create -n vit python=3.8 conda activate vit pip install torch torchvision torchaudio pip install numpy matplotlib tqdm

ViT的核心创新在于将图像视为一系列patch的集合,这与NLP中将句子视为token序列的处理方式异曲同工。我们首先需要实现图像分块处理:

import torch import torch.nn as nn def image_to_patches(x, patch_size=16): """ 将图像分割为固定大小的patch 参数: x: 输入图像张量 [B, C, H, W] patch_size: 每个patch的宽度/高度 返回: patches: 展平的patch序列 [B, num_patches, patch_dim] """ B, C, H, W = x.shape x = x.unfold(2, patch_size, patch_size) # 沿高度维度展开 x = x.unfold(3, patch_size, patch_size) # 沿宽度维度展开 patches = x.contiguous().view(B, -1, C*patch_size*patch_size) return patches

对于CIFAR-10数据集,标准的32×32分辨率图像被分割为8×8=64个4×4的patch(当patch_size=4时)。每个patch将被线性投影到一个固定维度的向量空间,这个过程类似于NLP中的词嵌入:

class PatchEmbedding(nn.Module): def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=512): super().__init__() self.img_size = img_size self.patch_size = patch_size self.n_patches = (img_size // patch_size) ** 2 self.proj = nn.Linear(patch_size*patch_size*in_chans, embed_dim) def forward(self, x): patches = image_to_patches(x, self.patch_size) embeddings = self.proj(patches) return embeddings

2. ViT核心组件实现

2.1 位置编码与类别token

与原始Transformer不同,ViT需要处理2D图像的位置信息。我们采用可学习的位置编码,而非Transformer中的正弦位置编码:

class PositionalEncoding(nn.Module): def __init__(self, n_patches, embed_dim): super().__init__() self.pos_embed = nn.Parameter(torch.zeros(1, n_patches+1, embed_dim)) nn.init.trunc_normal_(self.pos_embed, std=0.02) def forward(self, x): return x + self.pos_embed

ViT引入了一个特殊的[class] token,其最终状态将作为整个图像的表示。这个设计借鉴了BERT中的[CLS] token:

class ClassToken(nn.Module): def __init__(self, embed_dim): super().__init__() self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) nn.init.trunc_normal_(self.cls_token, std=0.02) def forward(self, x): cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_tokens, x), dim=1) return x

2.2 Transformer编码器实现

ViT的核心是标准的Transformer编码器,由多头自注意力机制和前馈网络交替组成:

class TransformerBlock(nn.Module): def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1): super().__init__() self.norm1 = nn.LayerNorm(embed_dim) self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout) self.norm2 = nn.LayerNorm(embed_dim) self.mlp = nn.Sequential( nn.Linear(embed_dim, int(embed_dim * mlp_ratio)), nn.GELU(), nn.Dropout(dropout), nn.Linear(int(embed_dim * mlp_ratio), embed_dim), nn.Dropout(dropout) ) def forward(self, x): # 自注意力部分 x_norm = self.norm1(x) attn_output, _ = self.attn(x_norm, x_norm, x_norm) x = x + attn_output # 前馈网络部分 x_norm = self.norm2(x) mlp_output = self.mlp(x_norm) x = x + mlp_output return x

3. 完整ViT模型组装

现在我们可以将各个组件组合成完整的Vision Transformer:

class VisionTransformer(nn.Module): def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=512, depth=6, num_heads=8, mlp_ratio=4.0, num_classes=10, dropout=0.1): super().__init__() # 图像分块与嵌入 self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim) n_patches = (img_size // patch_size) ** 2 # 类别token和位置编码 self.cls_token = ClassToken(embed_dim) self.pos_embed = PositionalEncoding(n_patches, embed_dim) # Transformer编码器堆叠 self.blocks = nn.ModuleList([ TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth) ]) # 分类头 self.norm = nn.LayerNorm(embed_dim) self.head = nn.Linear(embed_dim, num_classes) def forward(self, x): # 图像分块与嵌入 x = self.patch_embed(x) # 添加类别token和位置编码 x = self.cls_token(x) x = self.pos_embed(x) # 通过Transformer编码器 for block in self.blocks: x = block(x) # 提取类别token状态用于分类 cls_token = x[:, 0] cls_token = self.norm(cls_token) logits = self.head(cls_token) return logits

4. 模型训练与调优技巧

4.1 训练配置

ViT的训练需要特别注意学习率调度和正则化策略。以下是推荐的训练配置:

from torch.optim import AdamW model = VisionTransformer( img_size=32, patch_size=4, in_chans=3, embed_dim=512, depth=6, num_heads=8, num_classes=10 ).to(device) optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.05) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) criterion = nn.CrossEntropyLoss()

4.2 数据增强策略

由于ViT缺乏CNN固有的平移不变性等归纳偏置,数据增强对模型性能至关重要:

from torchvision import transforms train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ])

4.3 常见问题排查

在ViT实现过程中,开发者常遇到以下问题及解决方案:

  1. 训练不稳定:降低初始学习率,增加warmup步数
  2. 过拟合:增强数据增强,增加dropout率
  3. 梯度爆炸:使用梯度裁剪(nn.utils.clip_grad_norm_
  4. 内存不足:减小batch size或使用梯度累积

5. 性能评估与对比分析

在CIFAR-10数据集上,我们实现的ViT模型可以达到约85%的测试准确率。与CNN相比,ViT展现出以下特点:

特性ViTCNN
计算效率中等
数据需求大量中等
可扩展性优秀一般
局部特征提取需学习内置
全局关系建模优秀有限

值得注意的是,ViT的性能高度依赖于预训练数据规模。在小规模数据集上,CNN通常表现更好;但在大数据集(如ImageNet-21k)上预训练后,ViT展现出更强的迁移学习能力。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/30 17:44:40

怎样高效掌握Python GUI开发:实用PyQt6实战手册

怎样高效掌握Python GUI开发:实用PyQt6实战手册 【免费下载链接】PyQt-Chinese-tutorial PyQt6中文教程 项目地址: https://gitcode.com/gh_mirrors/py/PyQt-Chinese-tutorial PyQt-Chinese-Tutorial是一份全面的PyQt6中文教程,专为Python开发者和…

作者头像 李华
网站建设 2026/4/30 17:41:05

如何轻松释放Windows内存:Mem Reduct完整使用指南

如何轻松释放Windows内存:Mem Reduct完整使用指南 【免费下载链接】memreduct Lightweight real-time memory management application to monitor and clean system memory on your computer. 项目地址: https://gitcode.com/gh_mirrors/me/memreduct 你是不…

作者头像 李华
网站建设 2026/4/30 17:40:48

Adobe-GenP:3分钟解锁Adobe全家桶的终极激活指南

Adobe-GenP:3分钟解锁Adobe全家桶的终极激活指南 【免费下载链接】Adobe-GenP Adobe CC 2019/2020/2021/2022/2023 GenP Universal Patch 3.0 项目地址: https://gitcode.com/gh_mirrors/ad/Adobe-GenP 还在为Adobe Creative Cloud的高昂订阅费发愁吗&#x…

作者头像 李华
网站建设 2026/4/30 17:38:36

Spectrimage 从图像创建调色板:四轮迭代,调色板更似人工挑选!

迭代 1:让它运行起来在第一个版本中,在 RGB 空间中进行中值切割量化,划分七个 ROYGBIV 区域,经三轮区域内颜色选择和跨区域去重。但代码杂乱,有十三个命名常量和六条判断灰色规则,逻辑难理解,还…

作者头像 李华
网站建设 2026/4/30 17:34:24

Crossref REST API架构设计与高性能元数据查询系统实现指南

Crossref REST API架构设计与高性能元数据查询系统实现指南 【免费下载链接】rest-api-doc Documentation for Crossrefs REST API. For questions or suggestions, see https://community.crossref.org/ 项目地址: https://gitcode.com/gh_mirrors/re/rest-api-doc 在学…

作者头像 李华