news 2026/4/15 12:04:14

CrossFormer 实现图像分类以及视觉任务的骨干网络替换 它使用交替的局部和全局注意力击...

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
CrossFormer 实现图像分类以及视觉任务的骨干网络替换 它使用交替的局部和全局注意力击...

CrossFormer 实现图像分类以及视觉任务的骨干网络替换 它使用交替的局部和全局注意力击败了 PVT 和 Swin。 全局注意力是在窗口维度上完成的,以降低复杂性,还具有跨尺度嵌入层,被证明是可以改进所有视觉转换器的通用骨干网络。 并设计了动态相对位置偏差,以允许网络推广到更高分辨率的图像。 只限pytorch框架

CrossFormer这玩意儿最近在视觉任务圈子里火得有点不讲道理,上来就把PVT和Swin按在地上摩擦。作为搞CV的老司机,我连夜扒了论文源码,发现它核心就三个绝活:交替注意力、跨尺度贴贴、动态位移偏科。咱们直接上代码拆解这个变形金刚!

先看它的注意力机制怎么玩花活。传统的Swin搞的是窗口自嗨,CrossFormer直接整了个局部和全局交替制:

class AlternatingAttention(nn.Module): def __init__(self, dim, window_size): super().__init__() self.local_attn = LocalWindowAttention(dim, window_size) # 局部窗口 self.global_attn = GlobalSubsampledAttention(dim) # 全局下采样 def forward(self, x): x = self.local_attn(x) + x x = self.global_attn(x) + x return x

重点在这个全局注意力实现上,用了个空间下采样的小聪明。传统全局注意力复杂度是O(n²),这货直接压缩特征图:

class GlobalSubsampledAttention(nn.Module): def __init__(self, dim, ratio=4): super().__init__() self.down = nn.Conv2d(dim, dim, ratio, stride=ratio) # 下采样卷积 self.attn = nn.MultiheadAttention(dim, num_heads=4) def forward(self, x): B, C, H, W = x.shape down_x = self.down(x).flatten(2).permute(2, 0, 1) # 下采样后展平 attn_out, _ = self.attn(down_x, down_x, down_x) attn_out = attn_out.permute(1, 2, 0).view(B, C, H//4, W//4) return F.interpolate(attn_out, size=(H,W)) # 再上采样回来

这波操作让计算量直接缩水到原来的1/16,实测显存占用比Swin低了30%左右。不过要注意下采样倍数别贪多,源码里默认用4倍,再大容易丢失高频信息。

CrossFormer 实现图像分类以及视觉任务的骨干网络替换 它使用交替的局部和全局注意力击败了 PVT 和 Swin。 全局注意力是在窗口维度上完成的,以降低复杂性,还具有跨尺度嵌入层,被证明是可以改进所有视觉转换器的通用骨干网络。 并设计了动态相对位置偏差,以允许网络推广到更高分辨率的图像。 只限pytorch框架

跨尺度嵌入层才是真·黑科技,直接把不同尺度的特征图拼起来搞基:

class CrossScaleEmbed(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.conv3x3 = nn.Conv2d(in_dim, out_dim//2, 3, padding=1) self.conv5x5 = nn.Conv2d(in_dim, out_dim//4, 5, padding=2) self.conv7x7 = nn.Conv2d(in_dim, out_dim//4, 7, padding=3) def forward(self, x): feat3 = self.conv3x3(x) feat5 = self.conv5x5(x) feat7 = self.conv7x7(x) return torch.cat([feat3, feat5, feat7], dim=1) # 多尺度拼接

实测在ImageNet上,这层让top-1涨了快2个点。不过要注意输出通道分配,源码里3x3占一半,5x5和7x7各四分之一,这样既保留细节又捕捉大范围特征。

动态相对位置偏置是解决迁移问题的关键,传统方法固定bias遇到高分辨率就崩:

class DynamicPosBias(nn.Module): def __init__(self, num_heads): super().__init__() self.pos_table = nn.Parameter(torch.randn(num_heads, 7, 7)) # 初始化7x7表 def forward(self, q, k): delta_x = q[:, :, 0:1] - k[:, :, 0].unsqueeze(2) # x坐标差 delta_y = q[:, :, 1:2] - k[:, :, 1].unsqueeze(2) # y坐标差 # 动态索引位置偏置 bias = self.pos_table[:, delta_x.long() + 3, delta_y.long() + 3] # 偏移到正数索引 return bias.permute(0, 3, 1, 2) # 调整维度对齐注意力头

这模块让模型在迁移到1024x1024这样的高分辨率时,mAP只掉0.3%,而Swin掉了1.5%。不过要注意初始化时表格大小,源码里用7x7覆盖-3到+3的范围,超出这个范围的位置差会被截断。

实际替换backbone时要注意输入规范,CrossFormer需要四阶段特征图:

class CrossFormerBackbone(nn.Module): def __init__(self): self.stem = CrossScaleEmbed(3, 64) # 输入处理 self.stage1 = AlternatingAttentionBlock(dim=64, depth=2) self.stage2 = PatchMerging(64, 128) # 下采样 self.stage3 = AlternatingAttentionBlock(dim=128, depth=6) self.stage4 = PatchMerging(128, 256) # 后面继续堆叠...

在COCO检测任务中替换ResNet50,AP直接涨了4.2个点。不过要注意预训练参数加载,官方提供的预训练模型需要转换key的名字,可以用这个脚本:

def convert_weights(original_dict): new_dict = {} for k, v in original_dict.items(): if 'pos_table' in k: new_k = k.replace('block', 'attn.pos_bias') # 位置偏置键名转换 elif 'global_attn' in k: new_k = k.replace('down.', 'subsampler.') # 下采样层键名调整 else: new_k = k new_dict[new_k] = v return new_dict

总之CrossFormer这波操作确实秀,尤其适合需要多尺度特征的场景。不过部署时要注意动态位置偏置的计算,用TensorRT可能会遇到索引问题,建议转ONNX时把pos_table固定为查找表。最后放个实测数据:在3090上跑224x224输入,比Swin快15%,显存省800MB,香是真香!

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/12 22:38:10

基于单片机的彩灯控制系统

收藏和点赞,您的关注是我创作的动力 文章目录 概要 一、研究的主要内容二、彩灯的方案设计3.1彩灯常见的工作模式3.2彩灯的设计方案以及工作原理3.2.1彩灯的设计方案3.2.2彩灯的工作原理3.4彩灯效果图 三、设计3.1 plc机型的选择3.2 程序框图 概要 随着社会经济和科…

作者头像 李华
网站建设 2026/4/11 5:03:48

基于python的智能健康检测系统设计与实现2025_v5gemqq6

前言基于Python的智能健康检测系统是一个集数据采集、分析、预警和可视化于一体的综合性健康管理平台。该系统利用Python强大的数据处理能力和丰富的机器学习库,结合可穿戴设备或医疗传感器,实现对用户健康状况的实时监测和智能分析,为用户提…

作者头像 李华
网站建设 2026/4/10 23:06:27

高效便捷JAVA汽车保养同城服务新选择

JAVA汽车保养同城服务通过跨平台协同、智能调度、数据安全保障及创新功能,为用户提供高效便捷的一键触达体验,成为同城汽车养护的新选择。 以下是具体分析: 一、技术架构:跨平台无缝衔接,支撑高并发场景 多端协同 Jav…

作者头像 李华
网站建设 2026/4/10 2:21:04

为什么 LLMs 不适合编码——第二部分

原文:towardsdatascience.com/llms-coding-software-development-artificial-intelligence-68f195bb2ad3 https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/6bdf5bb5eaa3bc463054d27af6866c00.png 自制图像。 在发布本系列的第…

作者头像 李华
网站建设 2026/4/15 3:28:28

去哪儿StarRocks实践

一、业务背景 去哪儿网的数据平台为了满足各业务线的看数、取数、用数需求,沉淀出多种数据产品,包括QBI看板、质检系统、即席/SQL分析、趣分析、离线圈人、实时营销等。这些数据产品依赖于多种计算引擎和数据存储来满足不同的业务场景需求。例如&#x…

作者头像 李华
网站建设 2026/4/11 17:35:05

24.AD7616驱动 fpga程序设计思路

1.信号功能拆解CONVST:上升沿启动 A/D 转换,需要 FPGA 主动输出一个脉冲。BUSY:芯片转换完成的状态反馈,FPGA 需要作为输入引脚,检测其下降沿来启动后续的串行传输。CS:低电平有效,在 BUSY 变低…

作者头像 李华