
How to Train a Vision Transformer (ViT) from Scratch


Original: towardsdatascience.com/how-to-train-a-vision-transformer-vit-from-scratch-f26641f26af2

Hi everyone! For those who don't know me yet, my name is François and I'm a research scientist at Meta. I'm passionate about explaining advanced AI concepts and making them more accessible.

Today, let's dive into one of the most significant contributions to computer vision: the Vision Transformer (ViT).

This article focuses on a state-of-the-art implementation of the Vision Transformer as it has evolved since its release. To fully understand how the ViT works, I strongly recommend reading my other post on the theoretical foundations: The Ultimate Guide to Vision Transformers.

How do you train a ViT from scratch?

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/cb4b6029679c9aa4d031080f4ea9ad42.png

The ViT architecture. Image from the original paper.

1. The Attention Layer

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/ef0e7fbdbbec052c3f903192be6218f0.png

The attention layer. Image by the author.

Let's start with the most famous building block of the Transformer encoder: the attention layer.

import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        # Calculate the total inner dimension based on the number of attention heads and the dimension per head
        inner_dim = dim_head * heads
        # Determine if a final projection layer is needed based on the number of heads and dimension per head
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads              # Store the number of attention heads
        self.scale = dim_head ** -0.5   # Scaling factor for the attention scores (inverse of sqrt(dim_head))

        self.norm = nn.LayerNorm(dim)        # Layer normalization to stabilize training and improve convergence
        self.attend = nn.Softmax(dim=-1)     # Softmax layer to compute attention weights (along the last dimension)
        self.dropout = nn.Dropout(dropout)   # Dropout layer for regularization during training

        # Linear layer to project the input tensor into queries, keys, and values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # Conditional projection layer after attention, to project back to the original dimension if required
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),   # Project concatenated head outputs back to the original input dimension
            nn.Dropout(dropout)          # Dropout layer for regularization
        ) if project_out else nn.Identity()   # Use Identity (no change) if no projection is needed

    def forward(self, x):
        x = self.norm(x)   # Apply normalization to the input tensor

        # Apply the linear layer to get queries, keys, and values, then split into 3 separate tensors
        qkv = self.to_qkv(x).chunk(3, dim=-1)   # Chunk into 3 parts along the last dimension: (query, key, value)

        # Reshape each chunk from (batch_size, num_patches, inner_dim) to (batch_size, num_heads, num_patches, dim_head)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        # Dot products between queries and keys, scaled by the inverse square root of the head dimension
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale   # Shape: (batch_size, num_heads, num_patches, num_patches)

        # Apply softmax to get attention weights
        attn = self.attend(dots)    # Shape: (batch_size, num_heads, num_patches, num_patches)
        attn = self.dropout(attn)

        # Multiply attention weights by values to get the output
        out = torch.matmul(attn, v)   # Shape: (batch_size, num_heads, num_patches, dim_head)

        # Rearrange the output tensor to (batch_size, num_patches, inner_dim)
        out = rearrange(out, 'b h n d -> b n (h d)')   # Combine the heads dimension with the output dimension

        # Project back to the original input dimension if needed
        out = self.to_out(out)   # Shape: (batch_size, num_patches, dim)
        return out
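As a quick sanity check (a minimal sketch with arbitrary example sizes, assuming the Attention class above), the layer maps tokens back to their original dimension:

attention = Attention(dim=128, heads=8, dim_head=64, dropout=0.1)
tokens = torch.randn(4, 197, 128)   # (batch_size, num_patches + 1, dim)
print(attention(tokens).shape)      # torch.Size([4, 197, 128]) -- same shape as the input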

Key points:

  • inner_dim is the product of dim_head and the number of heads. To vectorize and speed up computation, we merge these two dimensions before the tensor products.

  • For computational speed: instead of initializing Q, K, and V separately, we concatenate them into one large tensor, self.to_qkv. This way everything is computed in a single matrix multiplication.

  • einops is a very useful library for rearranging tensor dimensions by naming them explicitly. It is very intuitive.

  • For example, if you have a tensor of shape (batch_size, n_tokens, num_heads x head_dim) and you want to split the last dimension into (batch_size, n_tokens, num_heads, head_dim), you can use einops.rearrange(qkv, 'b n (h d) -> b n h d', h=num_heads). This makes it much easier to keep track of the dimensions you are manipulating; see the short sketch below.
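Here is a minimal sketch of that rearrange pattern (the sizes are arbitrary, chosen only for illustration):

import torch
from einops import rearrange

num_heads, head_dim = 8, 64
x = torch.randn(2, 197, num_heads * head_dim)        # (batch_size, n_tokens, num_heads * head_dim)

split = rearrange(x, 'b n (h d) -> b n h d', h=num_heads)
print(split.shape)     # torch.Size([2, 197, 8, 64])

# The pattern used inside the attention layer also moves the heads next to the batch dimension:
per_head = rearrange(x, 'b n (h d) -> b h n d', h=num_heads)
print(per_head.shape)  # torch.Size([2, 8, 197, 64])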

2. The Feed-Forward Network (FFN)

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/fdb66c1f8fd915973cc8a08b90a4b45d.png

The feed-forward network. Image by the author.

Next, we add the second Transformer building block: the feed-forward network.

class FFN(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        # norm -> linear -> activation -> dropout -> linear -> dropout
        self.net = nn.Sequential(
            nn.LayerNorm(dim),            # we first normalize with a layer norm
            nn.Linear(dim, hidden_dim),   # we project into a higher dimension hidden_dim
            nn.GELU(),                    # we apply the GELU activation function
            nn.Dropout(dropout),          # we apply dropout
            nn.Linear(hidden_dim, dim),   # we project back to the original dimension dim
            nn.Dropout(dropout)           # we apply dropout
        )

    def forward(self, x):
        return self.net(x)

There is nothing complicated here. You just need to understand that the FFN is two MLPs applied in sequence: the first typically projects the data into a higher dimension, and the second projects it back to the input dimension, which is why we have both dim and hidden_dim. A quick shape check follows the key points below.

Key points:

  • dim: the dimension of the input tokens.

  • hidden_dim: the intermediate dimension of the FFN.

  • GELU: an activation function. Although the original paper used ReLU, GELU has become more popular because of its smoother transition.
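As with the attention layer, a quick shape check (a minimal sketch with arbitrary sizes, assuming the FFN class above):

ffn = FFN(dim=128, hidden_dim=512, dropout=0.1)   # hidden_dim is typically a multiple (e.g. 4x) of dim
tokens = torch.randn(4, 197, 128)
print(ffn(tokens).shape)                          # torch.Size([4, 197, 128]) -- shape is preserved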

3. The Transformer Encoder: a Stack of L Transformer Layers

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/5a2259b75e0d5a76dd274c09ef65da7e.png

The Transformer encoder. Image by the author.

With the attention layer and the feed-forward network in place, we can assemble a Transformer layer. The Transformer encoder is essentially a stack of L such layers.

Remember that Transformer layers are like Lego bricks: the input dimension is the same as the output dimension, so you can stack as many as you want (or as many as your memory allows).

  • Don't forget the residual connections; they are crucial for keeping gradients flowing and making optimization smoother.
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim_ratio, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        mlp_dim = mlp_dim_ratio * dim
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim=dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FFN(dim=dim, hidden_dim=mlp_dim, dropout=dropout)
            ]))

    def forward(self, x):
        for attn, ffn in self.layers:
            x = attn(x) + x   # residual connection around the attention layer
            x = ffn(x) + x    # residual connection around the feed-forward network
        return self.norm(x)
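Because every layer maps a tensor of shape (batch_size, num_tokens, dim) back to the same shape, stacking is purely a matter of depth. A minimal sketch (arbitrary sizes, assuming the classes above):

encoder = Transformer(dim=128, depth=6, heads=8, dim_head=64, mlp_dim_ratio=4, dropout=0.1)
tokens = torch.randn(4, 197, 128)
print(encoder(tokens).shape)   # torch.Size([4, 197, 128]) -- unchanged, regardless of depth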

Assembling the Final ViT

We have already done the hardest part; now we can assemble the complete Vision Transformer.

We mainly need to add three components:

  • Splitting the image into patches and converting each patch into a vector

  • Adding positional embeddings

  • Adding the CLS token

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/793214ef1880c42e62377c2910750d7e.png

Image patching. Image by the author.

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f9b65adf7b4694ec9e1c4d981b4df801.png

Converting an image into patches. Image by the author.

First, we define a small utility function that converts a scalar into a tuple.

def pair(t):
    """
    Converts a single value into a tuple of two values.
    If t is already a tuple, it is returned as is.

    Args:
        t: A single value or a tuple.

    Returns:
        A tuple where both elements are t if t is not a tuple.
    """
    return t if isinstance(t, tuple) else (t, t)
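A quick illustration of the helper:

print(pair(224))          # (224, 224)
print(pair((224, 128)))   # (224, 128) -- tuples are passed through unchanged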

Now we are ready to write the code for the ViT!

Let's start with a few sanity checks:

  • We need to check that the image is split into a whole number of patches. In other words, image_height and image_width must be divisible by the patch size. A concrete example follows the code below.
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim_ratio,
                 pool='cls', channels=3, dim_head=64, dropout=0.):
        """
        Initializes a Vision Transformer (ViT) model.

        Args:
            image_size (int or tuple): Size of the input image (height, width).
            patch_size (int or tuple): Size of each patch (height, width).
            num_classes (int): Number of output classes.
            dim (int): Dimension of the embedding space.
            depth (int): Number of transformer layers.
            heads (int): Number of attention heads.
            mlp_dim_ratio (int): Ratio between the FFN hidden dimension and dim.
            pool (str): Pooling strategy ('cls' or 'mean').
            channels (int): Number of input channels (e.g., 3 for RGB images).
            dim_head (int): Dimension of each attention head.
            dropout (float): Dropout rate.
        """
        super().__init__()

        # Convert image size and patch size to tuples if they are single values
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # Ensure that the image dimensions are divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        # Calculate the number of patches and the dimension of each patch
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
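To make these numbers concrete (a hypothetical configuration, chosen only for illustration): with a 224 x 224 RGB image and 16 x 16 patches, num_patches = (224 / 16) * (224 / 16) = 196 and patch_dim = 3 * 16 * 16 = 768.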

The next step is to turn the patches into embeddings. Remember that here an image has C = 3 channels. We need to unfold that dimension and flatten each patch into a vector of dimension patch_size x patch_size x C.

        # Define the patch embedding layer
        self.to_patch_embedding = nn.Sequential(
            # Rearrange the input tensor to (batch_size, num_patches, patch_dim)
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),     # Normalize each patch
            nn.Linear(patch_dim, dim),   # Project patches to embedding dimension
            nn.LayerNorm(dim)            # Normalize the embedding
        )
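To see what the Rearrange pattern does in isolation, here is a minimal sketch using the same hypothetical 224 / 16 configuration:

import torch
from einops.layers.torch import Rearrange

to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
images = torch.randn(4, 3, 224, 224)
print(to_patches(images).shape)   # torch.Size([4, 196, 768]) -- 196 patches of dimension 768 each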

Then we define the CLS token and the positional embeddings. The CLS token helps represent the whole image as a single vector, and the positional embeddings give the model spatial awareness of the tokens. Both are learned parameters (randomly initialized). Note that there are num_patches + 1 positional embeddings: one per patch plus one for the CLS token.

        # Ensure the pooling strategy is valid
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Define CLS token and positional embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))                     # Learnable class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))   # Positional embeddings for patches and class token

        self.dropout = nn.Dropout(dropout)   # Dropout for regularization

Now we just need to instantiate the Transformer encoder we defined earlier and add a classification head.

        # Define the transformer encoder
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim_ratio, dropout)

        # Pooling strategy ('cls' token or mean of patches)
        self.pool = pool

        # Identity layer (no change to the tensor)
        self.to_latent = nn.Identity()

        # Classification head
        self.mlp_head = nn.Linear(dim, num_classes)

The Forward Pass

We have initialized all the components of the ViT; now we just need to call them in the right order for the forward pass.

  • We first split the input image into patches and flatten each patch into a vector.

  • Then we repeat the CLS token along the batch dimension and concatenate it along dimension 1, which is the sequence length. We learn a single parameter vector, but it needs to be attached to every image in the batch, which is why we expand along the batch dimension.

  • Then we add the positional embeddings to each token.

    def forward(self, img):
        """
        Forward pass through the Vision Transformer model.

        Args:
            img (Tensor): Input image tensor of shape (batch_size, channels, height, width).

        Returns:
            dict: A dictionary containing the class token, feature map, and classification result.
        """
        # Convert image to patch embeddings
        x = self.to_patch_embedding(img)   # Shape: (batch_size, num_patches, dim)
        b, n, _ = x.shape                  # Get batch size, number of patches, and embedding dimension

        # Repeat class token for each image in the batch
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)

        # Concatenate class token with patch embeddings
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings to the input
        x += self.pos_embedding[:, :(n + 1)]

        # Apply dropout for regularization
        x = self.dropout(x)

Then we apply the Transformer encoder. We use its output to build a result containing three things:

  • The CLS token (a single vector representation of the image).

  • The feature map (a vector representation of each image patch).

  • The classification head logits (optional): these are used for classification tasks. Note that a Vision Transformer can be used for different tasks, but classification is the task it was originally used for.

        # Pass through transformer encoder
        x = self.transformer(x)   # Shape: (batch_size, num_patches + 1, dim)

        # Extract class token and feature map
        cls_token = x[:, 0]       # Extract class token
        feature_map = x[:, 1:]    # Remaining tokens are the feature map

        # Apply pooling operation: 'cls' token or mean of patches
        pooled_output = cls_token if self.pool == 'cls' else feature_map.mean(dim=1)

        # Apply the identity transformation (no change to the tensor)
        pooled_output = self.to_latent(pooled_output)

        # Apply the classification head to the pooled output
        classification_result = self.mlp_head(pooled_output)

        # Return a dictionary with the required components
        return {
            'cls_token': cls_token,                                # Class token
            'feature_map': feature_map,                            # Feature map (patch embeddings)
            'classification_head_logits': classification_result   # Final classification result
        }

To sum up, here is the final code for the ViT. You can find an updated version in this GitHub repository:

GitHub – FrancoisPorcher/awesome-ai-tutorials: The best collection of AI tutorials to make you a…

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim_ratio,
                 pool='cls', channels=3, dim_head=64, dropout=0.):
        """
        Initializes a Vision Transformer (ViT) model.

        Args:
            image_size (int or tuple): Size of the input image (height, width).
            patch_size (int or tuple): Size of each patch (height, width).
            num_classes (int): Number of output classes.
            dim (int): Dimension of the embedding space.
            depth (int): Number of transformer layers.
            heads (int): Number of attention heads.
            mlp_dim_ratio (int): Ratio between the FFN hidden dimension and dim.
            pool (str): Pooling strategy ('cls' or 'mean').
            channels (int): Number of input channels (e.g., 3 for RGB images).
            dim_head (int): Dimension of each attention head.
            dropout (float): Dropout rate.
        """
        super().__init__()

        # Convert image size and patch size to tuples if they are single values
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # Ensure that the image dimensions are divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        # Calculate the number of patches and the dimension of each patch
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        # Define the patch embedding layer
        self.to_patch_embedding = nn.Sequential(
            # Rearrange the input tensor to (batch_size, num_patches, patch_dim)
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),     # Normalize each patch
            nn.Linear(patch_dim, dim),   # Project patches to embedding dimension
            nn.LayerNorm(dim)            # Normalize the embedding
        )

        # Ensure the pooling strategy is valid
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Define CLS token and positional embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))                     # Learnable class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))   # Positional embeddings for patches and class token

        self.dropout = nn.Dropout(dropout)   # Dropout for regularization

        # Define the transformer encoder
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim_ratio, dropout)

        # Pooling strategy ('cls' token or mean of patches)
        self.pool = pool

        # Identity layer (no change to the tensor)
        self.to_latent = nn.Identity()

        # Classification head
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """
        Forward pass through the Vision Transformer model.

        Args:
            img (Tensor): Input image tensor of shape (batch_size, channels, height, width).

        Returns:
            dict: A dictionary containing the class token, feature map, and classification result.
        """
        # Convert image to patch embeddings
        x = self.to_patch_embedding(img)   # Shape: (batch_size, num_patches, dim)
        b, n, _ = x.shape                  # Get batch size, number of patches, and embedding dimension

        # Repeat class token for each image in the batch
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)

        # Concatenate class token with patch embeddings
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings to the input
        x += self.pos_embedding[:, :(n + 1)]

        # Apply dropout for regularization
        x = self.dropout(x)

        # Pass through transformer encoder
        x = self.transformer(x)   # Shape: (batch_size, num_patches + 1, dim)

        # Extract class token and feature map
        cls_token = x[:, 0]       # Extract class token
        feature_map = x[:, 1:]    # Remaining tokens are the feature map

        # Apply pooling operation: 'cls' token or mean of patches
        pooled_output = cls_token if self.pool == 'cls' else feature_map.mean(dim=1)

        # Apply the identity transformation (no change to the tensor)
        pooled_output = self.to_latent(pooled_output)

        # Apply the classification head to the pooled output
        classification_result = self.mlp_head(pooled_output)

        # Return a dictionary with the required components
        return {
            'cls_token': cls_token,                                # Class token
            'feature_map': feature_map,                            # Feature map (patch embeddings)
            'classification_head_logits': classification_result   # Final classification result
        }

Congratulations, you have built a Vision Transformer from scratch!

Thanks for reading! Before you go:

For more great tutorials, check out my compilation of AI tutorials on GitHub.

GitHub – FrancoisPorcher/awesome-ai-tutorials: The best collection of AI tutorials to make you a…

You should get my articles in your inbox. Subscribe here.

If you want access to premium articles on Medium, membership is only $5 a month. If you sign up through my link, part of that fee supports me at no extra cost to you.


If you found this article insightful and helpful, please consider following me and clapping for more in-depth content! Your support helps me keep creating content that deepens our shared understanding.

References

"An Image is Worth 16×16 Words" by Alexey Dosovitskiy et al. (2021). You can read the full paper on arXiv.
