news 2026/2/10 10:47:54

Swin Transformer整体架构

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Swin Transformer整体架构

完美!这张图清楚地展示了Swin Transformer的核心创新。我来详细解释架构和为什么必须交替使用W-MSA和SW-MSA


一、Swin Transformer整体架构

设计理念

Swin =Shifted WindowTransformer

  • 核心思想:在局部窗口内计算注意力,而非全局
  • 通过窗口移动实现跨窗口信息交互
  • 构建层次化特征金字塔(类似CNN)

完整流程(对应你的图)

输入图像 [H×W×3] ↓ Patch Partition (4×4切分) ↓ [H/4 × W/4 × 48] ← 每个patch flatten成48维 ↓ Linear Embedding ↓ [H/4 × W/4 × C] ← 投影到C维(如96) ↓ ┌──────────────────────────────────┐ │ Stage 1: 2× Swin Transformer Block │ │ - Block 1: W-MSA + MLP │ │ - Block 2: SW-MSA + MLP │ ← 交替使用! └──────────────────────────────────┘ ↓ Patch Merging (空间下采样) ↓ [H/8 × W/8 × 2C] ↓ Stage 2/3/4... (重复上述过程)

二、W-MSA:Window Multi-Head Self-Attention

传统ViT的问题

# ViT的全局注意力 输入: [H×W个patches] 计算复杂度: O((H×W)²) # 每个patch和所有patch计算注意力 # 例如224×224图像,patch=16×16 patches数量 = (224/16)² = 196 复杂度 = 196² = 38,416 次计算 # 如果高分辨率图像1024×1024 patches = 4096 复杂度 = 4096² = 16,777,216 次计算!💥 无法承受

W-MSA的解决方案

将特征图划分成不重叠的窗口(Window)

假设特征图 56×56,窗口大小 M=7 划分结果: ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ │ W00 │ │ │ │ W01 │ │ │ │ │ 7×7│ │ │ │ 7×7│ │ │ │ │ │ │ │ │ │ │ │ │ ├─────┴─────┴─────┴─────┼─────┴─────┴─────┴─────┤ │ Window 0 │ Window 1 │ └────────────────────────┴────────────────────────┘ 总共: (56/7) × (56/7) = 8×8 = 64个窗口 每个窗口: 7×7 = 49个patches

计算复杂度对比

# ViT全局注意力 complexity = (56×56)² ≈ 10M # W-MSA窗口注意力 num_windows = (56/7)² = 64 complexity_per_window = (7×7)² = 2,401 total_complexity = 64 × 2,401 ≈ 154K 加速比: 10M / 154K ≈ 65倍!

W-MSA的代码逻辑

def window_partition(x, window_size): """ x: (B, H, W, C) 返回: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() windows = windows.view(-1, window_size, window_size, C) return windows # 在每个窗口内独立计算注意力 def W_MSA(x): windows = window_partition(x, M=7) # 切分窗口 attn_windows = self_attention(windows) # 窗口内注意力 x = window_reverse(attn_windows) # 还原 return x

三、为什么W-MSA不够?致命缺陷!

问题:窗口之间完全隔离

窗口划分示例(M=4): ┌────┬────┬────┬────┐┌────┬────┬────┬────┐ │ 🔴 │ │ │ ││ 🔵 │ │ │ │ ├────┼────┼────┼────┤├────┼────┼────┼────┤ │ │ │ │ ││ │ │ │ │ ├────┼────┼────┼────┤├────┼────┼────┼────┤ │ │ │ │ ⚠️ ││ ⚠️ │ │ │ │ ├────┼────┼────┼────┤├────┼────┼────┼────┤ │ │ │ │ ││ │ │ │ │ └────┴────┴────┴────┘└────┴────┴────┴────┘ Window 0 Window 1 ⚠️ 问题: - 🔴 和 🔵 永远无法交互! - 窗口边界的信息被人为割裂 - 无法建模跨窗口的长距离依赖

具体例子

物体检测任务: 一只猫横跨两个窗口 ┌────────────┐┌────────────┐ │ 🐱头部 ││ 🐱身体 │ │ Window 0 ││ Window 1 │ └────────────┘└────────────┘ 纯W-MSA的结果: - Window 0只能看到头部特征 - Window 1只能看到身体特征 - 模型无法理解这是同一只猫!

四、SW-MSA:Shifted Window的解决方案

核心思想:移动窗口位置

第 l 层 (W-MSA): ┌────┬────┐┌────┬────┐ │ A │ B ││ C │ D │ ├────┼────┤├────┼────┤ │ E │ F ││ G │ H │ └────┴────┘└────┴────┘ 窗口0 窗口1 第 l+1 层 (SW-MSA): 向右下移动 M/2 ┌────┬────┐┌────┬────┐ │ ? │ A ││ B │ ? │ ├────┼────┤├────┼────┤ │ ? │ E ││ F │ ? │ └────┴────┘└────┴────┘ 新窗口0 新窗口1 关键变化: - 原来A和C在不同窗口,现在在同一窗口! - B、F可以交互了 - 实现了跨原始窗口边界的信息交换

详细图解(M=4的例子)

W-MSA (第l层)

原始窗口划分 (4×4) ┌─────────┬─────────┐ │ 0 1 2 │ 3 4 5 │ │ 6 7 8 │ 9 10 11 │ │12 13 14 │15 16 17 │ ├─────────┼─────────┤ ← 窗口边界 │18 19 20 │21 22 23 │ │24 25 26 │27 28 29 │ │30 31 32 │33 34 35 │ └─────────┴─────────┘ 问题: - patch 14 和 15 无法交互(被边界隔开) - patch 8 和 21 无法交互(被边界隔开)

SW-MSA (第l+1层)

窗口向右下移动2个位置 (M/2=2) ┌─────────┬─────────┐ │ 7 8 9 │10 11 │ │13 14 15 │16 17 │ │19 20 21 │22 23 │ ├─────────┼─────────┤ │25 26 27 │28 29 │ │31 32 33 │34 35 │ └─────────┴─────────┘ 现在: - ✅ patch 14 和 15 在同一窗口了! - ✅ patch 8 和 21 也可以交互了!

SW-MSA的实现技巧:Cyclic Shift

直接移动会产生不规则窗口(边界问题),Swin使用巧妙的循环移位+mask

# 第1步:Cyclic Shift(循环移位) def cyclic_shift(x, shift_size): return torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) # 移位前 ┌──┬──┬──┬──┐ │A │B │C │D │ ├──┼──┼──┼──┤ │E │F │G │H │ ├──┼──┼──┼──┤ │I │J │K │L │ ├──┼──┼──┼──┤ │M │N │O │P │ └──┴──┴──┴──┘ # 向右下移动2格后(循环) ┌──┬──┬──┬──┐ │K │L │I │J │ ← I、J循环到右边 ├──┼──┼──┼──┤ │O │P │M │N │ ├──┼──┼──┼──┤ │C │D │A │B │ ← A、B循环到右边 ├──┼──┼──┼──┤ │G │H │E │F │ └──┴──┴──┴──┘

第2步:使用Attention Mask

循环移位后,某些patch不应该交互(比如原本距离很远)

# 构造mask防止不相邻区域交互 mask矩阵示例: ┌────┬────┬────┬────┐ │ 0 │ -∞ │ 0 │ -∞ │ ← 0表示可见,-∞表示屏蔽 ├────┼────┼────┼────┤ │ -∞ │ 0 │ -∞ │ 0 │ ├────┼────┼────┼────┤ │ 0 │ -∞ │ 0 │ -∞ │ ├────┼────┼────┼────┤ │ -∞ │ 0 │ -∞ │ 0 │ └────┴────┴────┴────┘ # 注意力计算 attn = softmax(Q·K^T / √d + mask) # mask=-∞的位置,softmax后=0,不产生交互

完整SW-MSA代码

def SW_MSA(x, shift_size): B, H, W, C = x.shape # Step 1: Cyclic Shift shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) # Step 2: 窗口划分 x_windows = window_partition(shifted_x, window_size) # Step 3: 窗口内注意力 (带mask) attn_windows = self_attention(x_windows, mask=attn_mask) # Step 4: 窗口合并 shifted_x = window_reverse(attn_windows, window_size, H, W) # Step 5: Reverse Cyclic Shift(移回去) x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) return x

五、为什么要交替使用W-MSA和SW-MSA?

设计原理

Swin Block结构(你的图中): ┌─────────────────────────────┐ │ Block l (偶数层) │ │ Layer Norm │ │ W-MSA ← 固定窗口 │ │ Residual │ │ Layer Norm │ │ MLP │ │ Residual │ └─────────────────────────────┘ ↓ ┌─────────────────────────────┐ │ Block l+1 (奇数层) │ │ Layer Norm │ │ SW-MSA ← 移动窗口 │ │ Residual │ │ Layer Norm │ │ MLP │ │ Residual │ └─────────────────────────────┘

交替使用的三大原因

1. 建立跨窗口连接
2层的信息流动: Layer l (W-MSA): ┌────┬────┐┌────┬────┐ │ A │ B ││ C │ D │ └────┴────┘└────┴────┘ 窗口内交互: A↔B, C↔D Layer l+1 (SW-MSA): ┌────┬────┐ │ A │ B │ ← A现在可以和C交互了! │ C │ D │ └────┴────┘ 跨窗口交互: A↔C, B↔D 结果: 2层后,所有patch间接相连!
2. 高效建模长距离依赖
信息传播路径: A → B (layer l) ↓ B → D (layer l+1) ← 跨窗口 ↓ D → ... (layer l+2) 经过多层后: - 每个patch的感受野呈指数增长 - O(log(H/M))层后达到全局感受野 - 远比全局注意力O(n²)高效
3. 保持计算效率
复杂度对比: 纯全局注意力(ViT): 每层 = O((H×W)²) 4层 = 4 × O((H×W)²) Swin (W-MSA + SW-MSA交替): 每层 = O(M² × H/M × W/M) = O(M×H×W) ← 线性复杂度! 4层 = 4 × O(M×H×W) 当H=W=56, M=7: ViT: 4 × (56×56)² ≈ 40M Swin: 4 × 7 × 56 × 56 ≈ 88K

六、直观类比

城市交通网络类比

W-MSA= 社区内道路

┌─────────┐ ┌─────────┐ │ 社区A │ │ 社区B │ │ 🏠→🏠→🏠│ │🏠→🏠→🏠│ │ ↓ │ │ ↓ │ │ 🏠 │ │ 🏠 │ └─────────┘ └─────────┘ 社区内高效通行,但社区间隔离

SW-MSA= 跨社区快速路

┌─────────┐ │ 🏠 🏠 │ ← A和B的居民可以交流了 │ 🏠 🏠 │ └─────────┘ 定期调整道路布局,连接不同社区

交替使用= 完整交通网络

  • 局部有社区道路(高效)
  • 定期有快速路(连接)
  • 几次换乘可达任意地点(全局感受野)

七、与ViT的对比总结

特性ViTSwin Transformer
注意力范围全局局部窗口
计算复杂度O(n²)O(n)
特征层次单尺度多尺度金字塔
跨窗口交互-Shifted Window
适用任务分类分类+检测+分割
高分辨率❌ 不可行✅ 可行

八、实际效果验证

# ImageNet分类性能 Swin-T (28M参数): 81.3% Top-1 ViT-B (86M参数): 81.8% Top-1 → Swin用更少参数达到相近性能 # COCO目标检测 Swin-T: 50.5 mAP ResNet-50: 46.9 mAP → Swin的层次化特征更适合检测 # 推理速度(224×224图像) ViT-B: 18ms Swin-T: 12ms → Swin更高效

总结:为什么必须交替?

  1. W-MSA单独使用

    • ✅ 计算高效
    • ❌ 窗口隔离,无全局视野
  2. SW-MSA的作用

    • 打破窗口边界
    • 建立跨窗口连接
    • 扩大感受野
  3. 交替使用的智慧

    • 平衡局部效率和全局感受野
    • 2层就能让所有patch间接相连
    • 深层网络后达到全局理解能力

核心公式

局部效率 (W-MSA) + 全局连接 (SW-MSA) = 高效的全局建模

swin transform中的patch merging操作,是不是在通道维度合并,还是在空间维度合并

看你上传的图片!我来帮你详细分析Swin Transformer的Patch Merging操作。

答案:在空间维度合并,同时通道维度扩展

这个操作同时影响两个维度,但核心是空间下采样


Patch Merging的具体过程

操作步骤(结合你的图片)

输入特征图H × W × C(比如图中的7×7网格,每个patch有96通道)

第1步:空间划分(2×2分组)

将相邻的2×2个patches分成一组 ┌─────┬─────┐ │ P00 │ P01 │ → 一组 ├─────┼─────┤ │ P10 │ P11 │ └─────┴─────┘

从你的图片看:

  • 原始:7×7 = 49个patches
  • 分组后:每4个相邻patches作为一组
  • 结果:空间分辨率降低到大约 4×4(向下取整)

第2步:通道拼接

# 伪代码演示 x = x.reshape(B, H//2, 2, W//2, 2, C) # 重排形状 x = x.permute(0, 1, 3, 2, 4, 5) # 调整维度顺序 x = x.reshape(B, H//2, W//2, 4*C) # 拼接:4个patch的C通道 → 4C通道 # 具体例子 输入: [7, 7, 96] # H=7, W=7, C=96 ├─ 重排: [3, 2, 3, 2, 96] # 7//2=3 (向下取整) └─ 拼接: [3, 3, 4×96] = [3, 3, 384] # 4C = 4×96 = 384

关键理解(对应图中的标注):

  • 4C = 4 × 96:4个patches的通道拼接在一起
  • 空间:7×7 → 约4×4(实际会是 ⌊7/2⌋ × ⌊7/2⌋ = 3×3)
  • 通道:96 → 384

第3步:线性降维(可选)

x = Linear(4C, 2C)(x) # 384 → 192 # 最终输出 输入: [7, 7, 96] 输出: [3, 3, 192] # 空间减半,通道翻倍

通常会接一个线性层把4C降到2C,这样:

  • 空间分辨率降低2倍(H/2 × W/2)
  • 通道数翻倍(C → 2C)

图解对比

原始patches (你的图片): ┌───┬───┬───┬───┬───┬───┬───┐ │ │ │ │ │ │ │ │ 7×7网格 ├───┼───┼───┼───┼───┼───┼───┤ 每个格子=1个patch │ │🐱│ │ │🐱│ │ │ 通道数=96 ├───┼───┼───┼───┼───┼───┼───┤ │ │🐱│🐱│🐱│🐱│ │ │ └───┴───┴───┴───┴───┴───┴───┘ ↓ Patch Merging 合并后: ┌─────────┬─────────┬─────────┐ │ │ │ │ 约4×4网格 │ 4个合1 │ 4个合1 │ 4个合1 │ 每个大格=4个原始patch ├─────────┼─────────┼─────────┤ 通道数=384 (4×96) │ 🐱🐱 │ 🐱🐱 │ │ └─────────┴─────────┴─────────┘ ↓ Linear降维 (可选) 最终输出: 空间:约4×4 通道:192 (2×96)

为什么这样设计?

1. 模拟CNN的下采样金字塔

Stage 1: 56×56 patches, 96通道 ┐ ↓ Patch Merging │ 类似ResNet的 Stage 2: 28×28 patches, 192通道 │ 多尺度特征 ↓ Patch Merging │ Stage 3: 14×14 patches, 384通道 │ ↓ Patch Merging │ Stage 4: 7×7 patches, 768通道 ┘

2. 计算效率考虑

  • 空间减半:计算量降低到原来的1/4
  • 通道翻倍:增加模型容量,学习更复杂特征
  • 平衡:深层网络用更少的空间位置,但更丰富的通道特征

代码实现解析

class PatchMerging(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = nn.LayerNorm(4 * dim) def forward(self, x): """ x: (B, H, W, C) """ B, H, W, C = x.shape # 确保H和W是偶数(padding处理) pad_input = (H % 2 == 1) or (W % 2 == 1) if pad_input: x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) # 空间维度合并:取2×2网格的4个位置 x0 = x[:, 0::2, 0::2, :] # 左上角 [B, H/2, W/2, C] x1 = x[:, 1::2, 0::2, :] # 右上角 x2 = x[:, 0::2, 1::2, :] # 左下角 x3 = x[:, 1::2, 1::2, :] # 右下角 # 通道维度拼接:4个patch的通道拼在一起 x = torch.cat([x0, x1, x2, x3], dim=-1) # [B, H/2, W/2, 4C] x = self.norm(x) x = self.reduction(x) # [B, H/2, W/2, 2C] return x # 使用示例 input_tensor = torch.randn(1, 7, 7, 96) # [B, H, W, C] merge = PatchMerging(dim=96) output = merge(input_tensor) print(output.shape) # torch.Size([1, 4, 4, 192]) # ↑ ↑ ↑ # H/2 W/2 2C

与你图片的对应关系

你的图片展示的关键点:

  1. 7×7的网格:原始patches的空间布局

  2. 4C = 4 × 96

    • 4 = 2×2个相邻patches
    • 96 = 原始每个patch的通道数
    • 384 = 合并后的临时通道数
  3. 猫的图像被分割

    • 每个小格子 = 1个patch embedding
    • Merging后,4个小格子的信息融合到1个大格子

总结对比

维度操作类型变化目的
空间下采样H×W → (H/2)×(W/2)减少计算量,类似池化
通道拼接+降维C → 4C → 2C增加表达能力

核心答案

  • 主要在空间维度合并(2×2的patches变成1个)
  • 通道维度是副作用(4个patch的通道拼接导致4C,然后降到2C)

与ViT的Patch Embedding对比

ViT的Patch Embedding: 图像 [224×224×3] → 切分成patches [14×14个, 每个16×16×3] → 线性投影 [14×14, 768维] → 全程保持14×14,不做merging Swin的Patch Merging: 特征图 [56×56×96] → 每2×2 patches合并 [28×28×384] → 线性降维 [28×28×192] → 多次merging构建金字塔 [56→28→14→7]

Swin的优势多尺度特征层次,更适合密集预测任务(检测、分割)


版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/2/5 11:55:40

JVM(JAVA虚拟机内存不足)

这是本人第二次遇到这个问题,打开pycharm还没有Start就撒由那拉了,报出来一堆IDE错误,看都看不懂,上次遇到是大模型的指导改了配置文件依然没解决,这次跟着大佬操作,电脑十分丝滑 步骤(win11&am…

作者头像 李华
网站建设 2026/2/7 16:13:49

类变量和全局变量的生命周期分别是多久?

类变量和全局变量的生命周期核心差异在于创建时机、存活范围、销毁条件,本质由它们的 “归属对象”(类 vs 模块)决定,以下是分维度的精准解析:一、先明确核心前提Python 中变量的生命周期依附于其所属的命名空间对象&a…

作者头像 李华
网站建设 2026/2/5 18:21:36

前后端分离考试系统系统|SpringBoot+Vue+MyBatis+MySQL完整源码+部署教程

摘要 随着信息技术的快速发展,传统考试系统在效率、灵活性和用户体验方面逐渐暴露出局限性。传统系统通常采用前后端耦合的设计模式,导致系统维护困难、扩展性差,难以满足现代教育中对个性化考试和实时数据分析的需求。前后端分离架构通过将用…

作者头像 李华
网站建设 2026/2/6 9:29:12

Java SpringBoot+Vue3+MyBatis 考试系统系统源码|前后端分离+MySQL数据库

💡实话实说:用最专业的技术、最实惠的价格、最真诚的态度服务大家。无论最终合作与否,咱们都是朋友,能帮的地方我绝不含糊。买卖不成仁义在,这就是我的做人原则。摘要 随着信息技术的快速发展,传统的考试管…

作者头像 李华
网站建设 2026/2/9 0:52:29

我是如何成为一名拥有破碎英语的国际学生数据科学家的

原文:towardsdatascience.com/how-i-became-a-data-scientist-as-an-international-student-with-broken-english-cce01d6db840?sourcecollection_archive---------5-----------------------#2024-04-09 花了很长时间,但我找到了自己愿意做一辈子的事 …

作者头像 李华