Original article:
towardsdatascience.com/how-to-train-a-vision-transformer-vit-from-scratch-f26641f26af2
Hi everyone! For those who don't know me yet, my name is François, and I am a research scientist at Meta. I am passionate about explaining advanced AI concepts and making them more accessible.
Today, let's dive into one of the most significant contributions to computer vision: the Vision Transformer (ViT).
This article focuses on a state-of-the-art implementation of the Vision Transformer as it has evolved since its release. To fully understand how the ViT works, I strongly recommend reading my other post on the theoretical foundations: The Ultimate Guide to Vision Transformers.
How to Train a ViT from Scratch?
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/cb4b6029679c9aa4d031080f4ea9ad42.png
The ViT architecture, image from the original paper
1. The Attention Layer
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/ef0e7fbdbbec052c3f903192be6218f0.png
The attention layer, image by author
Let's start with the most famous building block of the Transformer encoder: the attention layer.
```python
import torch
from torch import nn
from einops import rearrange, repeat           # repeat is used later for the CLS token
from einops.layers.torch import Rearrange      # used later for the patch embedding


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        # Calculate the total inner dimension based on the number of attention heads and the dimension per head
        inner_dim = dim_head * heads
        # Determine if a final projection layer is needed based on the number of heads and dimension per head
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads                # Store the number of attention heads
        self.scale = dim_head ** -0.5     # Scaling factor for the attention scores (inverse of sqrt(dim_head))

        self.norm = nn.LayerNorm(dim)     # Layer normalization to stabilize training and improve convergence
        self.attend = nn.Softmax(dim=-1)  # Softmax layer to compute attention weights (along the last dimension)
        self.dropout = nn.Dropout(dropout)  # Dropout layer for regularization during training

        # Linear layer to project the input tensor into queries, keys, and values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # Conditional projection layer after attention, to project back to the original dimension if required
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),    # Project concatenated head outputs back to the original input dimension
            nn.Dropout(dropout)           # Dropout layer for regularization
        ) if project_out else nn.Identity()  # Use Identity (no change) if no projection is needed

    def forward(self, x):
        x = self.norm(x)                  # Apply normalization to the input tensor

        # Apply the linear layer to get queries, keys, and values, then split into 3 separate tensors
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # Chunk the tensor into 3 parts along the last dimension: (query, key, value)

        # Reshape each chunk from (batch_size, num_patches, inner_dim) to (batch_size, num_heads, num_patches, dim_head)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        # Calculate dot products between queries and keys, scaled by the inverse square root of the head dimension
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # Shape: (batch_size, num_heads, num_patches, num_patches)

        # Apply softmax to get attention weights
        attn = self.attend(dots)          # Shape: (batch_size, num_heads, num_patches, num_patches)
        attn = self.dropout(attn)

        # Multiply attention weights by values to get the output
        out = torch.matmul(attn, v)       # Shape: (batch_size, num_heads, num_patches, dim_head)

        # Rearrange the output tensor to (batch_size, num_patches, inner_dim)
        out = rearrange(out, 'b h n d -> b n (h d)')  # Combine the heads dimension with the output dimension

        # Project the output back to the original input dimension if needed
        out = self.to_out(out)            # Shape: (batch_size, num_patches, dim)
        return out                        # Return the final output tensor
```

Key points:
- inner_dim is the product of dim_head and the number of heads. To vectorize and speed up computation, we merge these two dimensions before the tensor products. Also for speed, we do not initialize Q, K, and V separately; instead we concatenate them into one large tensor, self.to_qkv, so that everything is computed in a single pass.
- einops is a very useful library for rearranging tensors simply by naming their dimensions, and it is very intuitive. For example, if you have a tensor of shape (batch_size, n_tokens, num_heads x head_dim) and you want to split the last dimension into (batch_size, n_tokens, num_heads, head_dim), you can write einops.rearrange(qkv, 'b n (h d) -> b n h d', h=num_heads). This makes it much easier to keep track of the dimensions you are manipulating (see the short sketch right after these key points).
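To make these two key points concrete, here is a small sketch of the single qkv projection and the einops split used in the attention layer above. The shapes are hypothetical, chosen only for illustration:

```python
import torch
from torch import nn
from einops import rearrange

batch_size, n_tokens, dim = 2, 65, 128     # hypothetical sizes
heads, dim_head = 8, 64
inner_dim = heads * dim_head               # 512

# One linear layer produces Q, K and V in a single matrix multiplication
to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
x = torch.randn(batch_size, n_tokens, dim)

q, k, v = to_qkv(x).chunk(3, dim=-1)       # each chunk: (2, 65, 512)

# Split the merged (heads x dim_head) axis and move the heads next to the batch axis,
# exactly as in Attention.forward above
q = rearrange(q, 'b n (h d) -> b h n d', h=heads)
print(q.shape)                             # torch.Size([2, 8, 65, 64])
```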
2. The Feed-Forward Network (FFN)
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/fdb66c1f8fd915973cc8a08b90a4b45d.png
The feed-forward network, image by author
Next, we add the second building block of the Transformer: the feed-forward network.
```python
class FFN(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        # norm -> linear -> activation -> dropout -> linear -> dropout
        self.net = nn.Sequential(
            nn.LayerNorm(dim),           # we first normalize with a layer norm
            nn.Linear(dim, hidden_dim),  # we project into a higher dimension hidden_dim
            nn.GELU(),                   # we apply the GELU activation function
            nn.Dropout(dropout),         # we apply dropout
            nn.Linear(hidden_dim, dim),  # we project back to the original dimension dim
            nn.Dropout(dropout)          # we apply dropout
        )

    def forward(self, x):
        return self.net(x)
```

Nothing complicated here. You just need to understand that the FFN is two MLPs in a row: the first one usually projects the data into a higher dimension, and the second one projects it back to the input dimension. This is why we have both dim and hidden_dim.
Key points:
- dim: the dimension of the input tokens.
- hidden_dim: the intermediate dimension of the FFN.
- GELU: an activation function. Although the original paper used ReLU, GELU has become more popular thanks to its smoother transition.
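To see what "smoother transition" means, here is a tiny comparison of ReLU and GELU around zero (the input values are purely illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
print(F.relu(x))  # hard cutoff: every negative input becomes exactly 0
print(F.gelu(x))  # slightly negative inputs keep small negative values, giving a smooth curve around 0
```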
3. The Transformer Encoder: A Stack of L Transformer Layers
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/5a2259b75e0d5a76dd274c09ef65da7e.png
The Transformer encoder, image by author
With the attention layer and the feed-forward network in place, we can assemble a Transformer layer. The Transformer encoder is essentially a stack of L Transformer layers.
Remember, Transformer layers are like Lego bricks: the input dimension is the same as the output dimension, so you can stack as many as you want (or as many as your memory allows), as the quick shape check below illustrates.
- Don't forget the residual connections; they are essential for keeping gradients flowing and making optimization smoother.
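As a quick illustration (with hypothetical sizes), both blocks map a (batch_size, num_tokens, dim) tensor to a tensor of the same shape, which is exactly what makes the stacking and the residual additions possible:

```python
import torch

x = torch.randn(2, 65, 128)                        # (batch_size, num_tokens, dim), hypothetical sizes
attn_out = Attention(dim=128, heads=8, dim_head=64)(x)
ffn_out = FFN(dim=128, hidden_dim=512)(x)
print(attn_out.shape, ffn_out.shape)               # both torch.Size([2, 65, 128]), same as the input
```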
```python
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim_ratio, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        mlp_dim = mlp_dim_ratio * dim
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim=dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FFN(dim=dim, hidden_dim=mlp_dim, dropout=dropout)
            ]))

    def forward(self, x):
        for attn, ffn in self.layers:
            x = attn(x) + x   # attention block with a residual connection
            x = ffn(x) + x    # feed-forward block with a residual connection
        return self.norm(x)
```

Assembling the Final ViT
We have done the hardest part; now we can assemble the complete Vision Transformer.
We mainly need to add 3 components:
- Convert the image into patches, and the patches into vectors.
- Add the positional embeddings.
- Add the CLS token.
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/793214ef1880c42e62377c2910750d7e.png
Patchifying the image, image by author
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f9b65adf7b4694ec9e1c4d981b4df801.png
Converting the image into patches, image by author
First, we define a small utility function that converts a scalar into a tuple.
```python
def pair(t):
    """
    Converts a single value into a tuple of two values.
    If t is already a tuple, it is returned as is.

    Args:
        t: A single value or a tuple.

    Returns:
        A tuple where both elements are t if t is not a tuple.
    """
    return t if isinstance(t, tuple) else (t, t)
```

Now we are ready to write the code for the ViT!
Let's start with a few sanity checks:
- We need to check that the image splits into a whole number of patches. In other words, we need to check that image_height and image_width are divisible by the patch size.
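As a quick worked example (224 and 16 are the standard ViT settings, used here purely for illustration), a 224×224 image with 16×16 patches passes the check and yields 14 × 14 = 196 patches, whereas a patch size of 15 would fail it:

```python
image_height, image_width = pair(224)   # standard ViT image size, for illustration
patch_height, patch_width = pair(16)    # standard ViT patch size, for illustration

assert image_height % patch_height == 0 and image_width % patch_width == 0
num_patches = (image_height // patch_height) * (image_width // patch_width)
print(num_patches)                      # 196
```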
```python
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads,
                 mlp_dim_ratio, pool='cls', channels=3, dim_head=64, dropout=0.):
        """
        Initializes a Vision Transformer (ViT) model.

        Args:
            image_size (int or tuple): Size of the input image (height, width).
            patch_size (int or tuple): Size of each patch (height, width).
            num_classes (int): Number of output classes.
            dim (int): Dimension of the embedding space.
            depth (int): Number of transformer layers.
            heads (int): Number of attention heads.
            mlp_dim_ratio (int): Ratio between the FFN hidden dimension and dim.
            pool (str): Pooling strategy ('cls' or 'mean').
            channels (int): Number of input channels (e.g., 3 for RGB images).
            dim_head (int): Dimension of each attention head.
            dropout (float): Dropout rate.
        """
        super().__init__()
        # Convert image size and patch size to tuples if they are single values
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # Ensure that the image dimensions are divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        # Calculate the number of patches and the dimension of each patch
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
```

The next step is to turn the patches into embeddings. Remember that here an image has C = 3 channels. We need to flatten this dimension and squeeze each patch into a vector of dimension patch_size x patch_size x C.
```python
        # Define the patch embedding layer
        self.to_patch_embedding = nn.Sequential(
            # Rearrange the input tensor to (batch_size, num_patches, patch_dim)
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),    # Normalize each patch
            nn.Linear(patch_dim, dim),  # Project patches to the embedding dimension
            nn.LayerNorm(dim)           # Normalize the embedding
        )
```

Then we define the CLS token and the positional embeddings. The CLS token helps represent the whole image as a single vector, and the positional embeddings give the model spatial awareness of the tokens. Both are learned parameters (randomly initialized).
```python
        # Ensure the pooling strategy is valid
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Define the CLS token and positional embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))                    # Learnable class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # Positional embeddings for patches and class token
        self.dropout = nn.Dropout(dropout)  # Dropout for regularization (used in the forward pass below)
```

Now we just need to instantiate the Transformer layers we defined earlier and add a classification head.
```python
        # Define the transformer encoder
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim_ratio, dropout)

        # Pooling strategy ('cls' token or mean of patches)
        self.pool = pool

        # Identity layer (no change to the tensor)
        self.to_latent = nn.Identity()

        # Classification head
        self.mlp_head = nn.Linear(dim, num_classes)
```

The Forward Pass
We have initialized all the components of the ViT; now we just need to call them in the right order for the forward pass.
We first convert the input image into patches and flatten each patch into a vector.
Then we repeat the CLS token along the batch dimension and concatenate it along dimension 1, the sequence length. Indeed, we learn a single parameter vector, but it needs to be attached to every image in the batch, which is why we expand that dimension. We then add the positional embeddings to each token.
```python
    def forward(self, img):
        """
        Forward pass through the Vision Transformer model.

        Args:
            img (Tensor): Input image tensor of shape (batch_size, channels, height, width).

        Returns:
            dict: A dictionary containing the class token, feature map, and classification result.
        """
        # Convert image to patch embeddings
        x = self.to_patch_embedding(img)  # Shape: (batch_size, num_patches, dim)
        b, n, _ = x.shape                 # Get batch size, number of patches, and embedding dimension

        # Repeat the class token for each image in the batch
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)

        # Concatenate the class token with the patch embeddings
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings to the input
        x += self.pos_embedding[:, :(n + 1)]

        # Apply dropout for regularization
        x = self.dropout(x)
```

Then we apply the Transformer encoder. We mainly use it to build an output containing 3 things:
- The CLS token (a single vector representation of the image).
- The feature map (a vector representation of each patch of the image).
- The classification head logits (optional): these are used for classification tasks. Note that a Vision Transformer can be used for different tasks, but classification is the task it was originally designed for.
```python
        # Pass through the transformer encoder
        x = self.transformer(x)  # Shape: (batch_size, num_patches + 1, dim)

        # Extract the class token and the feature map
        cls_token = x[:, 0]      # Class token
        feature_map = x[:, 1:]   # Remaining tokens are the feature map

        # Apply the pooling operation: 'cls' token or mean of patches
        pooled_output = cls_token if self.pool == 'cls' else feature_map.mean(dim=1)

        # Apply the identity transformation (no change to the tensor)
        pooled_output = self.to_latent(pooled_output)

        # Apply the classification head to the pooled output
        classification_result = self.mlp_head(pooled_output)

        # Return a dictionary with the required components
        return {
            'cls_token': cls_token,                               # Class token
            'feature_map': feature_map,                           # Feature map (patch embeddings)
            'classification_head_logits': classification_result  # Final classification result
        }
```

To sum up, here is the final code for the ViT. You can find an updated version of it in this GitHub repository:
GitHub – FrancoisPorcher/awesome-ai-tutorials: The best collection of AI tutorials to make you a…
```python
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads,
                 mlp_dim_ratio, pool='cls', channels=3, dim_head=64, dropout=0.):
        """
        Initializes a Vision Transformer (ViT) model.

        Args:
            image_size (int or tuple): Size of the input image (height, width).
            patch_size (int or tuple): Size of each patch (height, width).
            num_classes (int): Number of output classes.
            dim (int): Dimension of the embedding space.
            depth (int): Number of transformer layers.
            heads (int): Number of attention heads.
            mlp_dim_ratio (int): Ratio between the FFN hidden dimension and dim.
            pool (str): Pooling strategy ('cls' or 'mean').
            channels (int): Number of input channels (e.g., 3 for RGB images).
            dim_head (int): Dimension of each attention head.
            dropout (float): Dropout rate.
        """
        super().__init__()
        # Convert image size and patch size to tuples if they are single values
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # Ensure that the image dimensions are divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        # Calculate the number of patches and the dimension of each patch
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        # Define the patch embedding layer
        self.to_patch_embedding = nn.Sequential(
            # Rearrange the input tensor to (batch_size, num_patches, patch_dim)
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),    # Normalize each patch
            nn.Linear(patch_dim, dim),  # Project patches to the embedding dimension
            nn.LayerNorm(dim)           # Normalize the embedding
        )

        # Ensure the pooling strategy is valid
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Define the CLS token and positional embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))                    # Learnable class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # Positional embeddings for patches and class token
        self.dropout = nn.Dropout(dropout)                                       # Dropout for regularization

        # Define the transformer encoder
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim_ratio, dropout)

        # Pooling strategy ('cls' token or mean of patches)
        self.pool = pool

        # Identity layer (no change to the tensor)
        self.to_latent = nn.Identity()

        # Classification head
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """
        Forward pass through the Vision Transformer model.

        Args:
            img (Tensor): Input image tensor of shape (batch_size, channels, height, width).

        Returns:
            dict: A dictionary containing the class token, feature map, and classification result.
        """
        # Convert image to patch embeddings
        x = self.to_patch_embedding(img)  # Shape: (batch_size, num_patches, dim)
        b, n, _ = x.shape                 # Get batch size, number of patches, and embedding dimension

        # Repeat the class token for each image in the batch
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)

        # Concatenate the class token with the patch embeddings
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings to the input
        x += self.pos_embedding[:, :(n + 1)]

        # Apply dropout for regularization
        x = self.dropout(x)

        # Pass through the transformer encoder
        x = self.transformer(x)  # Shape: (batch_size, num_patches + 1, dim)

        # Extract the class token and the feature map
        cls_token = x[:, 0]      # Class token
        feature_map = x[:, 1:]   # Remaining tokens are the feature map

        # Apply the pooling operation: 'cls' token or mean of patches
        pooled_output = cls_token if self.pool == 'cls' else feature_map.mean(dim=1)

        # Apply the identity transformation (no change to the tensor)
        pooled_output = self.to_latent(pooled_output)

        # Apply the classification head to the pooled output
        classification_result = self.mlp_head(pooled_output)

        # Return a dictionary with the required components
        return {
            'cls_token': cls_token,                               # Class token
            'feature_map': feature_map,                           # Feature map (patch embeddings)
            'classification_head_logits': classification_result  # Final classification result
        }
```

Congratulations, you have built a Vision Transformer from scratch!
Thanks for reading! Before you leave:
For more great tutorials, check out my compilation of AI tutorials on GitHub.
GitHub – FrancoisPorcher/awesome-ai-tutorials: The best collection of AI tutorials to make you a…
You should get my articles in your inbox. Subscribe here.
If you want access to premium articles on Medium, membership is only $5 a month. If you sign up through my link, you support me with a share of your fee at no extra cost to you.
If you found this article insightful and helpful, please consider following me and leaving a clap for more in-depth content! Your support helps me keep creating content that advances our collective understanding.
References
• "An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale" by Alexey Dosovitskiy et al. (2021). You can read the full paper on arXiv.