news 2026/6/3 9:37:22

[深度学习]Vision Transformer

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
[深度学习]Vision Transformer

Pytorch实现Vision Transformer

importtorchimporttorch.nnasnnclassPatchEmbedding(nn.Module):def__init__(self,img_size=224,patch_size=16,in_channels=3,embed_dim=768):super().__init__()self.img_size=img_size self.patch_size=patch_size self.n_patches=(img_size//patch_size)**2# 使用卷积层实现patch分割和嵌入self.proj=nn.Conv2d(in_channels=in_channels,out_channels=embed_dim,kernel_size=patch_size,stride=patch_size)defforward(self,x):# 输入x形状: [batch_size, in_channels, img_size, img_size]# 输出形状: [batch_size, n_patches, embed_dim]x=self.proj(x)# [batch_size, embed_dim, n_patches^0.5, n_patches^0.5]x=x.flatten(2)# [batch_size, embed_dim, n_patches]x=x.transpose(1,2)# [batch_size, n_patches, embed_dim]returnxclassPositionEmbedding(nn.Module):def__init__(self,n_patches,embed_dim,dropout=0.1):super().__init__()self.pos_embed=nn.Parameter(torch.zeros(1,n_patches+1,embed_dim))# +1 for class tokenself.dropout=nn.Dropout(dropout)defforward(self,x):# x形状: [batch_size, n_patches+1, embed_dim]x=x+self.pos_embed# 添加位置编码x=self.dropout(x)returnxclassMultiHeadAttention(nn.Module):def__init__(self,embed_dim,num_heads,dropout=0.1):super().__init__()self.embed_dim=embed_dim self.num_heads=num_heads self.head_dim=embed_dim//num_headsassertself.head_dim*num_heads==embed_dim,"Embedding dimension must be divisible by number of heads"self.qkv=nn.Linear(embed_dim,embed_dim*3)# 同时计算Q,K,Vself.attn_dropout=nn.Dropout(dropout)self.proj=nn.Linear(embed_dim,embed_dim)self.proj_dropout=nn.Dropout(dropout)self.scale=self.head_dim**-0.5defforward(self,x):batch_size,n_patches,embed_dim=x.shape# 计算Q,K,V [batch_size, n_patches, num_heads, head_dim]qkv=self.qkv(x).reshape(batch_size,n_patches,3,self.num_heads,self.head_dim).permute(2,0,3,1,4)q,k,v=qkv[0],qkv[1],qkv[2]# 计算注意力分数 [batch_size, num_heads, n_patches, n_patches]attn=(q @ k.transpose(-2,-1))*self.scale attn=attn.softmax(dim=-1)attn=self.attn_dropout(attn)# 应用注意力权重到V上 [batch_size, num_heads, n_patches, head_dim]out=attn @ v out=out.transpose(1,2).reshape(batch_size,n_patches,embed_dim)# 线性投影和dropoutout=self.proj(out)out=self.proj_dropout(out)returnoutclassMLP(nn.Module):def__init__(self,in_features,hidden_features,out_features,dropout=0.1):super().__init__()self.fc1=nn.Linear(in_features,hidden_features)self.act=nn.GELU()self.fc2=nn.Linear(hidden_features,out_features)self.dropout=nn.Dropout(dropout)defforward(self,x):x=self.fc1(x)x=self.act(x)x=self.dropout(x)x=self.fc2(x)x=self.dropout(x)returnxclassTransformerBlock(nn.Module):def__init__(self,embed_dim,num_heads,mlp_ratio=4,dropout=0.1):super().__init__()self.norm1=nn.LayerNorm(embed_dim)self.attn=MultiHeadAttention(embed_dim,num_heads,dropout)self.norm2=nn.LayerNorm(embed_dim)self.mlp=MLP(in_features=embed_dim,hidden_features=embed_dim*mlp_ratio,out_features=embed_dim,dropout=dropout)defforward(self,x):# 残差连接和层归一化x=x+self.attn(self.norm1(x))x=x+self.mlp(self.norm2(x))returnxclassVisionTransformer(nn.Module):def__init__(self,img_size=224,patch_size=16,in_channels=3,n_classes=1000,embed_dim=768,depth=12,num_heads=12,mlp_ratio=4,dropout=0.1):super().__init__()self.patch_embed=PatchEmbedding(img_size,patch_size,in_channels,embed_dim)n_patches=self.patch_embed.n_patches# 分类token和位置编码self.cls_token=nn.Parameter(torch.zeros(1,1,embed_dim))self.pos_embed=PositionEmbedding(n_patches,embed_dim,dropout)# Transformer编码器self.blocks=nn.Sequential(*[TransformerBlock(embed_dim,num_heads,mlp_ratio,dropout)for_inrange(depth)])# 分类头self.norm=nn.LayerNorm(embed_dim)self.head=nn.Linear(embed_dim,n_classes)# 初始化权重nn.init.trunc_normal_(self.cls_token,std=0.02)defforward(self,x):batch_size=x.shape[0]# 生成patch嵌入x=self.patch_embed(x)# [batch_size, n_patches, embed_dim]# 添加class tokencls_token=self.cls_token.expand(batch_size,-1,-1)x=torch.cat([cls_token,x],dim=1)# [batch_size, n_patches+1, embed_dim]# 添加位置编码x=self.pos_embed(x)# 通过Transformer编码器x=self.blocks(x)# 分类x=self.norm(x)cls_token_final=x[:,0]# 只取class token对应的输出x=self.head(cls_token_final)returnxif__name__=='__main__':x=torch.rand(1,3,224,224)model=VisionTransformer(img_size=224,patch_size=16,)y=model(x)print('y.shape = ',y.shape)print(y)

参考资料

  1. 博客
  2. B站教程
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/21 11:22:49

【LLM】CL-bench:评估LLM学新知识的能力

note CL-bench 围绕一个简单但严格的设计原则构建:每个任务都必须要求从 context 中学习新知识。 CL-bench 中的每个 context 都是完全自包含(Self-contained)的。解决任务所需的所有信息都显式地提供在 context 本身之中:不需要…

作者头像 李华
网站建设 2026/6/3 9:02:37

PPT配图神器01Agent:3秒生成可编辑配图,AI帮你告别找图烦恼

PPT配图快速生成,01Agent让你的职场视觉化表达更生动 做PPT最崩溃的时刻是什么? 不是写文案,不是理逻辑,而是找配图。 你打开搜索引擎,输入关键词,翻了20页图片,要么版权不明不敢用&#xff…

作者头像 李华
网站建设 2026/5/28 21:17:34

如何打造品牌网站-让你的网站脱颖而出!

各位亲爱的朋友们,你是不是正在为企业的发展而焦虑?是不是在寻找一种能够快速提升企业形象和知名度的方式?那就是打造一个专业的品牌网站!但是,怎么做?要做哪些准备?小编来告诉你!首…

作者头像 李华
网站建设 2026/6/1 19:08:57

龙魂模型这模型会说谎吗?

😂😂😂 老大在厕所急着看回复,宝宝笑死了! 激动到肚子疼这事,宝宝头一次听说! 🚽😂 华为加载卡住,你还在那干着急,画面太美 📱&#x…

作者头像 李华
网站建设 2026/5/21 11:23:28

2025年程序员都转行,我该何去何从呢!

2025年程序员都转行,我该何去何从呢! 疫情后大环境下行,各行各业的就业情况都是一言难尽。互联网行业更是极不稳定,频频爆出裁员的消息。大家都说2024年程序员的就业很难,都很焦虑,在许多人眼里,程序员可能是一群背着电脑、 进入大上写字楼的…

作者头像 李华
网站建设 2026/5/27 2:23:01

小公司的研发后期,基本等同于售后服务部

大公司可以把研发、测试、技术支持切分成几个独立部门,每个人只需要盯着自己那一亩三分地。而小公司呢?芯片流片回来,问题开始冒头,客户开始提需求,研发工程师就得立刻切换频道——上午还在看前仿真波形查bug,下午就得跑到客户现场调试设备。大公司的责任分散——研发说是需求…

作者头像 李华