1.Ulysses Context Parallel (上下文并行)原理
下面的例子主要展示的是 Image Tokens(最复杂的部分)。
TeleTron 中 DiT 模型处理长序列的核心机制:如何通过SeqAllToAll4D在“序列并行”和“头并行”之间转换。
4 个大框 (GPU 0 - GPU 3):代表参与并行计算的 4 张显卡。
小色块:代表数据张量(Tensor)。
颜色:代表数据最初属于哪个 GPU(即属于序列的哪一部分)。
文字:
H0-H3代表注意力头(Heads),S0-S3代表序列片段(Sequence Chunks)。
1.1 初始状态 (Sequence Parallel)
长序列被切分为 4 段(S0, S1, S2, S3)。
GPU 0 只有 S0,但它拥有 S0 的所有注意力头 (Head 0-3)。
问题:Attention 需要计算全局关联,GPU 0 只有 S0,看不到 S1-S3,无法直接计算全局 Attention。
1.2 动态Padding
如果序列长度是 101,GPU 是 4 个。101 % 4≠ 0。pad.py 会在 All-to-All 之前将序列补齐到 104(能被 4 整除),确保每个 GPU 分到的数据块大小一致,否则通信原语会报错。
1.3 SeqAllToAll4D (Scatter)
这是 Ulysses 的核心魔法。它执行了一个转置 (Transpose)操作。
观察动画:GPU 0 把 Head 1 发给 GPU 1,Head 2 发给 GPU 2,Head 3 发给 GPU 3。同时它也接收了别人的 Head 0。
结果:现在 GPU 0 拥有了S0, S1, S2, S3的全部数据,但只包含 Head 0
1.4 Attention 计算
因为 GPU 0 现在拥有了全序列(S0-S3)的 Head 0 数据,它可以直接进行标准的 Attention 计算(Q、K、V 都在本地了)。这就是为什么叫“Context Parallel”——每个 GPU 处理一部分 Context(这里是按 Head 划分)。
1.5 SeqAllToAll4D (Gather)
计算完 Attention 后,数据是按 Head 划分的。为了进行下一层网络(如 MLP)的计算,必须还原回按 Sequence 划分的布局。
执行逆向操作,数据飞回各自原本的 GPU。
2. 为什么需要 All-to-All?(核心矛盾)
在计算下列公式时,存在一个矛盾:
输入状态(为了存下长序列):我们把超长的序列切几段,每张卡存一段。
卡1拥有: 序列的第 0~100 个字,所有的注意力头。
卡2拥有: 序列的第 101~200 个字,所有的注意力头。
计算需求(Self-Attention):第 0 个字(卡1)需要和第 199 个字(卡2)计算相关性。
问题:如果不通信,卡1根本看不到卡2里的数据,无法计算全局 Attention。
3. All-to-All 做了什么?(维度交换)
SeqAllToAll4D 在这里执行了一个“洗牌”操作。它让所有显卡互相交换数据,从而改变数据的切分维度。
操作前(序列并行):
卡1:拿着部分序列(Seq 1/N),但是有全部特征头(Heads All)。
卡2:拿着部分序列(Seq 2/N),但是有全部特征头(Heads All)。
(此时无法做全局 Attention)
⬇️执行 All-to-All 通信⬇️
(卡1把它的第2个头的数据发给卡2,卡2把它的第1个头的数据发给卡1...)操作后(头并行):
卡1:拿着全部序列(Seq All),但是只有部分特征头(Head 1/N)。
卡2:拿着全部序列(Seq All),但是只有部分特征头(Head 2/N)。
(此时可以做 Attention 了!因为卡1拥有所有序列的 Head 1 信息,它可以在 Head 1 的维度上计算第 0 个字和第 199 个字的关系
4. TeleTron 的区别对待
理解了 All-to-All 是在做“切分维度的转换”,我们就能看懂 TeleTron 为什么要对文本特殊处理。
4.1 图像 Tokens (Long Sequence) - 必须做 All-to-All
图像序列太长了(比如 4K 分辨率),单卡显存根本存不下完整的序列 (𝑄,𝐾,𝑉)
All-to-All #1:把切散的序列拼凑回来,同时把 Heads 切散分给各卡。
计算 Attention:每张卡算自己那部分 Heads 的全局 Attention。
All-to-All #2:再次通信,把 Heads 拼凑回来,重新把序列切散(为了省显存)。
代价:两次巨大的通信开销。
4.2 . 文本 Tokens (Short Sequence) - 优化方案
文本通常很短(几十到几百个 Token),或者是作为 Condition(条件)。
现状:单卡完全存得下完整的文本序列。
TeleTron 的逻辑:既然单卡能存下完整的文本序列,为什么还要浪费时间做那两次昂贵的 All-to-All 来切分 Heads 呢?
优化操作:
Split Forward (切分):仅仅是将文本数据简单地切分或复制,使其在物理位置上与图像数据所在的设备对齐(Align),以便后续做 Cross-Attention 或者拼接。
Skip All-to-All:直接跳过了图像那种“序列换头”的复杂通信过程。
计算:直接利用本地数据参与计算。
Gather Backward (聚合):反向传播时,简单地把梯度收集起来即可。
5. 总结
All-to-All 在这里就是“为了让原本被切碎的长序列能看到彼此,暂时把注意力头(Heads)切碎来换取视野”的操作。
TeleTron 对文本的优化在于:文本不够长,不需要“切碎头换视野”,直接处理即可,从而省下了昂贵的通信费。