完美!这张图清楚地展示了Swin Transformer的核心创新。我来详细解释架构和为什么必须交替使用W-MSA和SW-MSA。
一、Swin Transformer整体架构
设计理念
Swin =Shifted WindowTransformer
- 核心思想:在局部窗口内计算注意力,而非全局
- 通过窗口移动实现跨窗口信息交互
- 构建层次化特征金字塔(类似CNN)
完整流程(对应你的图)
输入图像 [H×W×3] ↓ Patch Partition (4×4切分) ↓ [H/4 × W/4 × 48] ← 每个patch flatten成48维 ↓ Linear Embedding ↓ [H/4 × W/4 × C] ← 投影到C维(如96) ↓ ┌──────────────────────────────────┐ │ Stage 1: 2× Swin Transformer Block │ │ - Block 1: W-MSA + MLP │ │ - Block 2: SW-MSA + MLP │ ← 交替使用! └──────────────────────────────────┘ ↓ Patch Merging (空间下采样) ↓ [H/8 × W/8 × 2C] ↓ Stage 2/3/4... (重复上述过程)二、W-MSA:Window Multi-Head Self-Attention
传统ViT的问题
# ViT的全局注意力 输入: [H×W个patches] 计算复杂度: O((H×W)²) # 每个patch和所有patch计算注意力 # 例如224×224图像,patch=16×16 patches数量 = (224/16)² = 196 复杂度 = 196² = 38,416 次计算 # 如果高分辨率图像1024×1024 patches = 4096 复杂度 = 4096² = 16,777,216 次计算!💥 无法承受W-MSA的解决方案
将特征图划分成不重叠的窗口(Window)
假设特征图 56×56,窗口大小 M=7 划分结果: ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ │ W00 │ │ │ │ W01 │ │ │ │ │ 7×7│ │ │ │ 7×7│ │ │ │ │ │ │ │ │ │ │ │ │ ├─────┴─────┴─────┴─────┼─────┴─────┴─────┴─────┤ │ Window 0 │ Window 1 │ └────────────────────────┴────────────────────────┘ 总共: (56/7) × (56/7) = 8×8 = 64个窗口 每个窗口: 7×7 = 49个patches计算复杂度对比:
# ViT全局注意力 complexity = (56×56)² ≈ 10M # W-MSA窗口注意力 num_windows = (56/7)² = 64 complexity_per_window = (7×7)² = 2,401 total_complexity = 64 × 2,401 ≈ 154K 加速比: 10M / 154K ≈ 65倍!W-MSA的代码逻辑
def window_partition(x, window_size): """ x: (B, H, W, C) 返回: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() windows = windows.view(-1, window_size, window_size, C) return windows # 在每个窗口内独立计算注意力 def W_MSA(x): windows = window_partition(x, M=7) # 切分窗口 attn_windows = self_attention(windows) # 窗口内注意力 x = window_reverse(attn_windows) # 还原 return x三、为什么W-MSA不够?致命缺陷!
问题:窗口之间完全隔离
窗口划分示例(M=4): ┌────┬────┬────┬────┐┌────┬────┬────┬────┐ │ 🔴 │ │ │ ││ 🔵 │ │ │ │ ├────┼────┼────┼────┤├────┼────┼────┼────┤ │ │ │ │ ││ │ │ │ │ ├────┼────┼────┼────┤├────┼────┼────┼────┤ │ │ │ │ ⚠️ ││ ⚠️ │ │ │ │ ├────┼────┼────┼────┤├────┼────┼────┼────┤ │ │ │ │ ││ │ │ │ │ └────┴────┴────┴────┘└────┴────┴────┴────┘ Window 0 Window 1 ⚠️ 问题: - 🔴 和 🔵 永远无法交互! - 窗口边界的信息被人为割裂 - 无法建模跨窗口的长距离依赖具体例子:
物体检测任务: 一只猫横跨两个窗口 ┌────────────┐┌────────────┐ │ 🐱头部 ││ 🐱身体 │ │ Window 0 ││ Window 1 │ └────────────┘└────────────┘ 纯W-MSA的结果: - Window 0只能看到头部特征 - Window 1只能看到身体特征 - 模型无法理解这是同一只猫!四、SW-MSA:Shifted Window的解决方案
核心思想:移动窗口位置
第 l 层 (W-MSA): ┌────┬────┐┌────┬────┐ │ A │ B ││ C │ D │ ├────┼────┤├────┼────┤ │ E │ F ││ G │ H │ └────┴────┘└────┴────┘ 窗口0 窗口1 第 l+1 层 (SW-MSA): 向右下移动 M/2 ┌────┬────┐┌────┬────┐ │ ? │ A ││ B │ ? │ ├────┼────┤├────┼────┤ │ ? │ E ││ F │ ? │ └────┴────┘└────┴────┘ 新窗口0 新窗口1 关键变化: - 原来A和C在不同窗口,现在在同一窗口! - B、F可以交互了 - 实现了跨原始窗口边界的信息交换详细图解(M=4的例子)
W-MSA (第l层):
原始窗口划分 (4×4) ┌─────────┬─────────┐ │ 0 1 2 │ 3 4 5 │ │ 6 7 8 │ 9 10 11 │ │12 13 14 │15 16 17 │ ├─────────┼─────────┤ ← 窗口边界 │18 19 20 │21 22 23 │ │24 25 26 │27 28 29 │ │30 31 32 │33 34 35 │ └─────────┴─────────┘ 问题: - patch 14 和 15 无法交互(被边界隔开) - patch 8 和 21 无法交互(被边界隔开)SW-MSA (第l+1层):
窗口向右下移动2个位置 (M/2=2) ┌─────────┬─────────┐ │ 7 8 9 │10 11 │ │13 14 15 │16 17 │ │19 20 21 │22 23 │ ├─────────┼─────────┤ │25 26 27 │28 29 │ │31 32 33 │34 35 │ └─────────┴─────────┘ 现在: - ✅ patch 14 和 15 在同一窗口了! - ✅ patch 8 和 21 也可以交互了!SW-MSA的实现技巧:Cyclic Shift
直接移动会产生不规则窗口(边界问题),Swin使用巧妙的循环移位+mask
# 第1步:Cyclic Shift(循环移位) def cyclic_shift(x, shift_size): return torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) # 移位前 ┌──┬──┬──┬──┐ │A │B │C │D │ ├──┼──┼──┼──┤ │E │F │G │H │ ├──┼──┼──┼──┤ │I │J │K │L │ ├──┼──┼──┼──┤ │M │N │O │P │ └──┴──┴──┴──┘ # 向右下移动2格后(循环) ┌──┬──┬──┬──┐ │K │L │I │J │ ← I、J循环到右边 ├──┼──┼──┼──┤ │O │P │M │N │ ├──┼──┼──┼──┤ │C │D │A │B │ ← A、B循环到右边 ├──┼──┼──┼──┤ │G │H │E │F │ └──┴──┴──┴──┘第2步:使用Attention Mask
循环移位后,某些patch不应该交互(比如原本距离很远)
# 构造mask防止不相邻区域交互 mask矩阵示例: ┌────┬────┬────┬────┐ │ 0 │ -∞ │ 0 │ -∞ │ ← 0表示可见,-∞表示屏蔽 ├────┼────┼────┼────┤ │ -∞ │ 0 │ -∞ │ 0 │ ├────┼────┼────┼────┤ │ 0 │ -∞ │ 0 │ -∞ │ ├────┼────┼────┼────┤ │ -∞ │ 0 │ -∞ │ 0 │ └────┴────┴────┴────┘ # 注意力计算 attn = softmax(Q·K^T / √d + mask) # mask=-∞的位置,softmax后=0,不产生交互完整SW-MSA代码
def SW_MSA(x, shift_size): B, H, W, C = x.shape # Step 1: Cyclic Shift shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) # Step 2: 窗口划分 x_windows = window_partition(shifted_x, window_size) # Step 3: 窗口内注意力 (带mask) attn_windows = self_attention(x_windows, mask=attn_mask) # Step 4: 窗口合并 shifted_x = window_reverse(attn_windows, window_size, H, W) # Step 5: Reverse Cyclic Shift(移回去) x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) return x五、为什么要交替使用W-MSA和SW-MSA?
设计原理
Swin Block结构(你的图中): ┌─────────────────────────────┐ │ Block l (偶数层) │ │ Layer Norm │ │ W-MSA ← 固定窗口 │ │ Residual │ │ Layer Norm │ │ MLP │ │ Residual │ └─────────────────────────────┘ ↓ ┌─────────────────────────────┐ │ Block l+1 (奇数层) │ │ Layer Norm │ │ SW-MSA ← 移动窗口 │ │ Residual │ │ Layer Norm │ │ MLP │ │ Residual │ └─────────────────────────────┘交替使用的三大原因
1. 建立跨窗口连接
2层的信息流动: Layer l (W-MSA): ┌────┬────┐┌────┬────┐ │ A │ B ││ C │ D │ └────┴────┘└────┴────┘ 窗口内交互: A↔B, C↔D Layer l+1 (SW-MSA): ┌────┬────┐ │ A │ B │ ← A现在可以和C交互了! │ C │ D │ └────┴────┘ 跨窗口交互: A↔C, B↔D 结果: 2层后,所有patch间接相连!2. 高效建模长距离依赖
信息传播路径: A → B (layer l) ↓ B → D (layer l+1) ← 跨窗口 ↓ D → ... (layer l+2) 经过多层后: - 每个patch的感受野呈指数增长 - O(log(H/M))层后达到全局感受野 - 远比全局注意力O(n²)高效3. 保持计算效率
复杂度对比: 纯全局注意力(ViT): 每层 = O((H×W)²) 4层 = 4 × O((H×W)²) Swin (W-MSA + SW-MSA交替): 每层 = O(M² × H/M × W/M) = O(M×H×W) ← 线性复杂度! 4层 = 4 × O(M×H×W) 当H=W=56, M=7: ViT: 4 × (56×56)² ≈ 40M Swin: 4 × 7 × 56 × 56 ≈ 88K六、直观类比
城市交通网络类比
W-MSA= 社区内道路
┌─────────┐ ┌─────────┐ │ 社区A │ │ 社区B │ │ 🏠→🏠→🏠│ │🏠→🏠→🏠│ │ ↓ │ │ ↓ │ │ 🏠 │ │ 🏠 │ └─────────┘ └─────────┘ 社区内高效通行,但社区间隔离SW-MSA= 跨社区快速路
┌─────────┐ │ 🏠 🏠 │ ← A和B的居民可以交流了 │ 🏠 🏠 │ └─────────┘ 定期调整道路布局,连接不同社区交替使用= 完整交通网络
- 局部有社区道路(高效)
- 定期有快速路(连接)
- 几次换乘可达任意地点(全局感受野)
七、与ViT的对比总结
| 特性 | ViT | Swin Transformer |
|---|---|---|
| 注意力范围 | 全局 | 局部窗口 |
| 计算复杂度 | O(n²) | O(n) |
| 特征层次 | 单尺度 | 多尺度金字塔 |
| 跨窗口交互 | - | Shifted Window |
| 适用任务 | 分类 | 分类+检测+分割 |
| 高分辨率 | ❌ 不可行 | ✅ 可行 |
八、实际效果验证
# ImageNet分类性能 Swin-T (28M参数): 81.3% Top-1 ViT-B (86M参数): 81.8% Top-1 → Swin用更少参数达到相近性能 # COCO目标检测 Swin-T: 50.5 mAP ResNet-50: 46.9 mAP → Swin的层次化特征更适合检测 # 推理速度(224×224图像) ViT-B: 18ms Swin-T: 12ms → Swin更高效总结:为什么必须交替?
W-MSA单独使用:
- ✅ 计算高效
- ❌ 窗口隔离,无全局视野
SW-MSA的作用:
- 打破窗口边界
- 建立跨窗口连接
- 扩大感受野
交替使用的智慧:
- 平衡局部效率和全局感受野
- 2层就能让所有patch间接相连
- 深层网络后达到全局理解能力
核心公式:
局部效率 (W-MSA) + 全局连接 (SW-MSA) = 高效的全局建模swin transform中的patch merging操作,是不是在通道维度合并,还是在空间维度合并
看你上传的图片!我来帮你详细分析Swin Transformer的Patch Merging操作。
答案:在空间维度合并,同时通道维度扩展
这个操作同时影响两个维度,但核心是空间下采样:
Patch Merging的具体过程
操作步骤(结合你的图片)
输入特征图:H × W × C(比如图中的7×7网格,每个patch有96通道)
第1步:空间划分(2×2分组)
将相邻的2×2个patches分成一组 ┌─────┬─────┐ │ P00 │ P01 │ → 一组 ├─────┼─────┤ │ P10 │ P11 │ └─────┴─────┘从你的图片看:
- 原始:7×7 = 49个patches
- 分组后:每4个相邻patches作为一组
- 结果:空间分辨率降低到大约 4×4(向下取整)
第2步:通道拼接
# 伪代码演示 x = x.reshape(B, H//2, 2, W//2, 2, C) # 重排形状 x = x.permute(0, 1, 3, 2, 4, 5) # 调整维度顺序 x = x.reshape(B, H//2, W//2, 4*C) # 拼接:4个patch的C通道 → 4C通道 # 具体例子 输入: [7, 7, 96] # H=7, W=7, C=96 ├─ 重排: [3, 2, 3, 2, 96] # 7//2=3 (向下取整) └─ 拼接: [3, 3, 4×96] = [3, 3, 384] # 4C = 4×96 = 384关键理解(对应图中的标注):
- 4C = 4 × 96:4个patches的通道拼接在一起
- 空间:7×7 → 约4×4(实际会是 ⌊7/2⌋ × ⌊7/2⌋ = 3×3)
- 通道:96 → 384
第3步:线性降维(可选)
x = Linear(4C, 2C)(x) # 384 → 192 # 最终输出 输入: [7, 7, 96] 输出: [3, 3, 192] # 空间减半,通道翻倍通常会接一个线性层把4C降到2C,这样:
- 空间分辨率降低2倍(H/2 × W/2)
- 通道数翻倍(C → 2C)
图解对比
原始patches (你的图片): ┌───┬───┬───┬───┬───┬───┬───┐ │ │ │ │ │ │ │ │ 7×7网格 ├───┼───┼───┼───┼───┼───┼───┤ 每个格子=1个patch │ │🐱│ │ │🐱│ │ │ 通道数=96 ├───┼───┼───┼───┼───┼───┼───┤ │ │🐱│🐱│🐱│🐱│ │ │ └───┴───┴───┴───┴───┴───┴───┘ ↓ Patch Merging 合并后: ┌─────────┬─────────┬─────────┐ │ │ │ │ 约4×4网格 │ 4个合1 │ 4个合1 │ 4个合1 │ 每个大格=4个原始patch ├─────────┼─────────┼─────────┤ 通道数=384 (4×96) │ 🐱🐱 │ 🐱🐱 │ │ └─────────┴─────────┴─────────┘ ↓ Linear降维 (可选) 最终输出: 空间:约4×4 通道:192 (2×96)为什么这样设计?
1. 模拟CNN的下采样金字塔
Stage 1: 56×56 patches, 96通道 ┐ ↓ Patch Merging │ 类似ResNet的 Stage 2: 28×28 patches, 192通道 │ 多尺度特征 ↓ Patch Merging │ Stage 3: 14×14 patches, 384通道 │ ↓ Patch Merging │ Stage 4: 7×7 patches, 768通道 ┘2. 计算效率考虑
- 空间减半:计算量降低到原来的1/4
- 通道翻倍:增加模型容量,学习更复杂特征
- 平衡:深层网络用更少的空间位置,但更丰富的通道特征
代码实现解析
class PatchMerging(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = nn.LayerNorm(4 * dim) def forward(self, x): """ x: (B, H, W, C) """ B, H, W, C = x.shape # 确保H和W是偶数(padding处理) pad_input = (H % 2 == 1) or (W % 2 == 1) if pad_input: x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) # 空间维度合并:取2×2网格的4个位置 x0 = x[:, 0::2, 0::2, :] # 左上角 [B, H/2, W/2, C] x1 = x[:, 1::2, 0::2, :] # 右上角 x2 = x[:, 0::2, 1::2, :] # 左下角 x3 = x[:, 1::2, 1::2, :] # 右下角 # 通道维度拼接:4个patch的通道拼在一起 x = torch.cat([x0, x1, x2, x3], dim=-1) # [B, H/2, W/2, 4C] x = self.norm(x) x = self.reduction(x) # [B, H/2, W/2, 2C] return x # 使用示例 input_tensor = torch.randn(1, 7, 7, 96) # [B, H, W, C] merge = PatchMerging(dim=96) output = merge(input_tensor) print(output.shape) # torch.Size([1, 4, 4, 192]) # ↑ ↑ ↑ # H/2 W/2 2C与你图片的对应关系
你的图片展示的关键点:
7×7的网格:原始patches的空间布局
4C = 4 × 96:
- 4 = 2×2个相邻patches
- 96 = 原始每个patch的通道数
- 384 = 合并后的临时通道数
猫的图像被分割:
- 每个小格子 = 1个patch embedding
- Merging后,4个小格子的信息融合到1个大格子
总结对比
| 维度 | 操作类型 | 变化 | 目的 |
|---|---|---|---|
| 空间 | 下采样 | H×W → (H/2)×(W/2) | 减少计算量,类似池化 |
| 通道 | 拼接+降维 | C → 4C → 2C | 增加表达能力 |
核心答案:
- 主要在空间维度合并(2×2的patches变成1个)
- 通道维度是副作用(4个patch的通道拼接导致4C,然后降到2C)
与ViT的Patch Embedding对比
ViT的Patch Embedding: 图像 [224×224×3] → 切分成patches [14×14个, 每个16×16×3] → 线性投影 [14×14, 768维] → 全程保持14×14,不做merging Swin的Patch Merging: 特征图 [56×56×96] → 每2×2 patches合并 [28×28×384] → 线性降维 [28×28×192] → 多次merging构建金字塔 [56→28→14→7]Swin的优势:多尺度特征层次,更适合密集预测任务(检测、分割)