好的,我们来详细解释一下MegatronInferStrategy类中的get_data_input函数,并举例说明其在分布式环境下的工作方式以及batch的形状。
一、核心目标 (Core Goal)
get_data_input函数的核心目标是:在复杂的分布式并行环境中,确保所有需要数据的 GPU 进程(Rank)都能正确接收到输入数据(DataProto对象)。
在 Megatron 的 3D 并行(数据并行 DP、张量并行 TP、流水线并行 PP)设置中,数据通常只由一个或少数几个"领导"进程加载。此函数负责将这份数据广播(Broadcast)到其他协同工作的进程。
- 对于张量并行 (TP) 和上下文并行 (CP):在同一个流水线阶段(Pipeline Stage)内,所有 TP/CP 进程需要处理相同的 micro-batch。因此,数据需要从该阶段的"领导"进程(如 TP Rank 0)广播到其他所有 TP/CP 进程。
- 对于流水线并行 (PP):虽然只有第一个流水线阶段(PP Rank 0)直接使用原始输入张量(如
input_ids),但后续阶段的进程可能也需要访问数据中的元信息(meta_info),例如批次大小、控制标志等。因此,包含元信息的整个数据对象也需要广播到所有流水线阶段。
二、DataProto结构回顾
在分析函数之前,我们先回顾一下DataProto的结构。它是一个自定义的数据容器,通常包含:
batch: 一个字典,存储了所有的张量数据,例如{'input_ids': ..., 'attention_mask': ...}。non_tensor_batch: 一个字典,存储非张量数据。meta_info: 一个字典,存储元数据,例如{'global_step': 100, 'micro_batch_size': 4}。
三、函数分步解析 (get_data_input)
defget_data_input(self,batch:DataProto):# 1. 定义一个辅助函数,用于广播 Python 对象defbroadcast_obj(obj,group):# 只有指定 group 内的 rank 0 才持有对象,其他 rank 持有 Noneobj_list=[objifdist.get_rank(group)==0elseNone]# 获取 group 内 rank 0 的全局 rank 作为广播源src_rank=dist.get_process_group_ranks(group)[0]# 从 src_rank 将 obj_list[0] 广播到 group 内的所有其他 rankdist.broadcast_object_list(obj_list,src=src_rank,group=group)# 返回广播后的对象returnobj_list[0]# 2. 检查是否需要广播非张量数据broadcast_non_tensor_batch=batch.meta_info.get("_broadcast_non_tensor_batch",False)# 3. 第一层广播:在流水线第一阶段内部,进行 TP/CP 广播ifmpu.get_pipeline_model_parallel_rank()==0andmpu.get_tensor_and_context_parallel_world_size()>1:# 这个条件确保了:# a. 当前进程位于第一个流水线阶段 (PP Rank 0)# b. 存在张量并行或上下文并行 (TP/CP world size > 1)ifbroadcast_non_tensor_batch:# 如果标志为真,广播整个 batch 对象tmp_batch=broadcast_obj(batch,mpu.get_tensor_and_context_parallel_group())batch.batch=tmp_batch.batch batch.non_tensor_batch=tmp_batch.non_tensor_batchelse:# 默认只广播张量部分batch.batch=broadcast_obj(batch.batch,mpu.get_tensor_and_context_parallel_group())# 作用:将数据从 PP0-TP0-CP0 广播到 PP0-TP1-CP0, PP0-TP0-CP1, ... 等,确保第一阶段的所有工作进程都有相同的输入张量。# 4. 第二层广播:跨流水线阶段,进行 PP 广播ifmpu.get_pipeline_model_parallel_world_size()>1:# 这个条件确保了:# a. 流水线并行被启用 (PP world size > 1)ifbroadcast_non_tensor_batch:# 广播整个 batch 对象tmp_batch=broadcast_obj(batch,mpu.get_pipeline_model_parallel_group())batch.batch=tmp_batch.batch batch.non_tensor_batch=tmp_batch.non_tensor_batchelse:# 默认只广播张量部分batch.batch=broadcast_obj(batch.batch,mpu.get_pipeline_model_parallel_group())# 作用:将数据从 PP0 的某个进程广播到 PP1, PP2, ... 的对应进程。# 这确保了即使后续阶段不使用输入张量,它们也能访问到 DataProto 中的 meta_info,# 从而保持所有阶段的行为一致性(例如,知道总共有多少个 micro-batch)。returnbatch四、举例说明:8卡 3D 并行场景
假设我们有8个GPU,并行配置如下:
- 数据并行 (DP) = 2
- 流水线并行 (PP) = 2
- 张量并行 (TP) = 2
我们可以将8个全局Rank这样分组((DP, PP, TP) 坐标):
- DP 组 0:
Rank 0: (0, 0, 0) -> 第1个数据副本,第1个流水线阶段,第1个张量分片Rank 1: (0, 0, 1) -> 第1个数据副本,第1个流水线阶段,第2个张量分片Rank 2: (0, 1, 0) -> 第1个数据副本,第2个流水线阶段,第1个张量分片Rank 3: (0, 1, 1) -> 第1个数据副本,第2个流水线阶段,第2个张量分片
- DP 组 1:
Rank 4: (1, 0, 0) -> 第2个数据副本,第1个流水线阶段,第1个张量分片Rank 5: (1, 0, 1) -> 第2个数据副本,第1个流水线阶段,第2个张量分片Rank 6: (1, 1, 0) -> 第2个数据副本,第2个流水线阶段,第1个张量分片Rank 7: (1, 1, 1) -> 第2个数据副本,第2个流水线阶段,第2个张量分片
数据加载:通常,数据加载器只会将数据发送到每个数据并行组的"领导"进程,即Rank 0和Rank 4。我们以Rank 0为例,它收到了一个DataProto对象。
get_data_input执行流程 (以 DP 组 0 为例):
Rank 0(0, 0, 0) 执行:- 进入第三步 (TP/CP 广播):
mpu.get_pipeline_model_parallel_rank()返回 0,满足条件。mpu.get_tensor_and_context_parallel_world_size()返回 2,满足条件。mpu.get_tensor_and_context_parallel_group()包含Rank 0和Rank 1。Rank 0是这个 group 的 rank 0,它将batch.batch广播给Rank 1。
- 进入第四步 (PP 广播):
mpu.get_pipeline_model_parallel_world_size()返回 2,满足条件。mpu.get_pipeline_model_parallel_group()包含Rank 0和Rank 2。Rank 0是这个 group 的 rank 0,它将batch.batch广播给Rank 2。
- 进入第三步 (TP/CP 广播):
Rank 1(0, 0, 1) 执行:- 进入第三步 (TP/CP 广播):
- PP Rank 是 0,TP/CP Size 是 2,满足条件。
- 它属于
{Rank 0, Rank 1}group,但它不是 group 内的 rank 0。因此,它接收来自Rank 0的广播数据。
- 进入第四步 (PP 广播):
- PP Size 是 2,满足条件。
- 它的 PP group 是
{Rank 1, Rank 3}。Rank 1是这个 group 的 rank 0。 - 它将刚刚从
Rank 0收到的数据再广播给Rank 3。
- 进入第三步 (TP/CP 广播):
Rank 2(0, 1, 0) 执行:- 跳过第三步:
mpu.get_pipeline_model_parallel_rank()返回 1,不满足条件。
- 进入第四步 (PP 广播):
- PP Size 是 2,满足条件。
- 它属于
{Rank 0, Rank 2}group,但它不是 group 内的 rank 0。因此,它接收来自Rank 0的广播数据。
- 跳过第三步:
Rank 3(0, 1, 1) 执行:- 跳过第三步:PP Rank 是 1。
- 进入第四步 (PP 广播):
- PP Size 是 2,满足条件。
- 它属于
{Rank 1, Rank 3}group,但它不是 group 内的 rank 0。因此,它接收来自Rank 1的广播数据。
最终结果:经过get_data_input函数后,DP 组 0 内的Rank 0, 1, 2, 3都拥有了完全相同的batch.batch数据。同样的逻辑也适用于 DP 组 1(由Rank 4发起广播)。这样,所有的进程都准备好了处理各自的计算任务。
五、Batch Shape 示例
现在,我们来看batch的具体形状。假设我们的模型处理一个批次大小为16,序列长度为2048的数据。
在get_data_input之前,Rank 0和Rank 4上的DataProto对象可能如下:
# DataProto object on Rank 0 DataProto( batch={ 'input_ids': torch.LongTensor of shape [16, 2048], 'attention_mask': torch.LongTensor of shape [16, 2048], # 其他可能的张量... }, meta_info={ 'global_step': 100, 'micro_batch_size': 4, # 假设每个 micro-batch 是 4 # ... } )input_ids: 形状为[16, 2048]。16是整个 global batch 在这个 DP replica 上的大小,2048是序列长度。attention_mask: 形状与input_ids相同,[16, 2048]。
在get_data_input之后,Rank 0, 1, 2, 3, 4, 5, 6, 7所有8个进程的batch.batch属性都会指向一个与上述batch字典内容和形状完全相同的副本。
后续的inner_forward_step函数会根据并行策略对这个完整的batch进行切分:
- Micro-batch 切分:
[16, 2048]的数据会被切分成 16 / 4 = 4 个 micro-batch,每个形状为[4, 2048]。 - TP/CP 切分:在
inner_forward_step内部,_get_feature_on_this_cp_rank等函数会进一步对[4, 2048]的 micro-batch 进行张量切分。例如,对于张量并行,一个[4, 2048, vocab_size]的logits张量在Rank 0上会是[4, 2048, vocab_size/2],在Rank 1上是另一半。
总结来说,get_data_input函数是数据从"加载点"到"计算点"的关键桥梁,它通过两层广播机制,巧妙地解决了在复杂 3D 并行下的数据分发问题。
好的,我们来详细拆解broadcast_obj这个辅助函数,并通过一个具体的例子来解释它的工作原理。
函数目标
broadcast_obj的目标是:在一个指定的分布式进程组 (group) 中,将一个Python对象(不一定是Tensor)从该组的第0个进程(rank 0)广播给组内所有其他进程。
这对于同步非张量数据(如配置字典、字符串、或者像DataProto这样的自定义对象)非常有用,因为torch.distributed.broadcast只支持张量。
函数逐行解析
defbroadcast_obj(obj,group):# 1. 创建一个列表,其中只有 group 内的 rank 0 持有对象# dist.get_rank(group) 获取当前进程在指定 group 内的相对排名。# 如果当前进程是 group 内的 rank 0,obj_list 就是 [obj]。# 如果不是,obj_list 就是 [None]。obj_list=[objifdist.get_rank(group)==0elseNone]# 2. 获取广播的源头(source rank)# dist.get_process_group_ranks(group) 返回一个列表,包含了 group 内所有进程的全局 rank。# 例如,一个 group 可能由全局的 Rank 4, 5, 6, 7 组成,这个函数就返回 [4, 5, 6, 7]。# [0] 表示我们取这个列表的第一个元素,也就是 group 内 rank 0 对应的全局 rank。# 这确保了广播源是固定的,即组内的领导者。src_rank=dist.get_process_group_ranks(group)[0]# 3. 执行广播# torch.distributed.broadcast_object_list 是 PyTorch 的一个函数,专门用于广播 Python 对象列表。# - obj_list: 这是输入/输出参数。在调用前,只有 src_rank 上有对象,其他进程是 [None]。# 调用后,所有进程的 obj_list 都会被 src_rank 上的 obj_list[0] 覆盖。# - src: 指定哪个全局 rank 是广播的源头。# - group: 限定广播只在这个进程组内发生。dist.broadcast_object_list(obj_list,src=src_rank,group=group)# 4. 返回结果# 因为广播后,所有进程的 obj_list 都变成了 [obj],# 所以 obj_list[0] 就是从源头广播过来的那个对象。returnobj_list[0]举例说明:TP 组内的广播
让我们回到之前的 8 卡 3D 并行场景,并聚焦于TP 广播这一步。
场景:
- 我们正在执行
get_data_input函数。 - 当前进程组是
mpu.get_tensor_and_context_parallel_group(),我们简化一下,只考虑 TP 组。 - 在
DP 组 0的PP 阶段 0,这个 TP 组由全局 Rank 0和全局 Rank 1组成。 - 数据加载后,只有
Rank 0持有batch.batch这个字典对象,Rank 1没有。
broadcast_obj调用:Rank 0和Rank 1都会调用broadcast_obj(batch.batch, tp_group)。
执行流程:
在全局 Rank 0上:
dist.get_rank(group):Rank 0在这个 TP 组{Rank 0, Rank 1}中的相对排名是0。obj_list = ...:dist.get_rank(group) == 0为True。Rank 0持有batch.batch对象(我们称之为B)。所以obj_list变成了[B]。src_rank = ...:dist.get_process_group_ranks(group)返回[0, 1](TP组内所有成员的全局Rank)。[0]取第一个元素,所以src_rank被设置为0。
dist.broadcast_object_list(...):Rank 0调用此函数,它作为源(src=0),将自己的obj_list[0](也就是B)广播给组内所有其他成员(这里是Rank 1)。- 调用结束后,
Rank 0的obj_list仍然是[B]。
return obj_list[0]: 函数返回B。
在全局 Rank 1上:
dist.get_rank(group):Rank 1在这个 TP 组{Rank 0, Rank 1}中的相对排名是1。obj_list = ...:dist.get_rank(group) == 0为False。Rank 1此时没有batch.batch对象,所以obj_list变成了[None]。src_rank = ...:dist.get_process_group_ranks(group)返回[0, 1]。[0]取第一个元素,所以src_rank同样被设置为0。Rank 1也知道了广播的源头是Rank 0。
dist.broadcast_object_list(...):Rank 1调用此函数,它作为接收方。- 它会等待
src=0(即Rank 0) 发送数据。 - 当它收到
Rank 0广播过来的对象B后,它会用B覆盖自己的obj_list。所以Rank 1上的obj_list从[None]变成了[B]。
return obj_list[0]: 函数返回B。
最终结果
调用broadcast_obj之后:
Rank 0返回了它本来就有的batch.batch对象。Rank 1返回了它从Rank 0那里接收到的batch.batch对象。
现在,Rank 0和Rank 1都拥有了完全相同的batch.batch字典,数据同步完成。
为什么这么设计?
- 健壮性:
src_rank = dist.get_process_group_ranks(group)[0]这种写法比直接写src=0更健壮。它不依赖于全局 Rank 0 一定是某个组的领导者。它动态地找出任何一个给定group的领导者(即该组中全局 Rank 值最小的那个进程)。 - 通用性: 这个函数可以广播任何可被
pickle序列化的 Python 对象,使其非常通用。 - 简洁性: 将复杂的分布式通信逻辑封装在一个简单的函数中,使得上层代码(如
get_data_input)更加清晰易读。它隐藏了"谁是源"、“谁是目标”、"如何创建占位符"等细节。
不,这些mpu.get_..._rank()函数返回的不是全局 Rank 编号。
它们返回的是当前进程在其特定并行维度上的局部(或相对)Rank 编号。这是一个非常关键的区别,理解它对于理解 Megatron 的并行机制至关重要。
让我们逐一解释,并用我们之前的 8 卡例子来说明。
关键概念
- 全局 Rank (Global Rank): 在整个分布式任务中,每个进程都有一个从
0到N-1(N是总进程数)的唯一标识符。这就是全局 Rank。通常由torch.distributed.get_rank()获取。 - 局部/相对 Rank (Local/Relative Rank): 在一个特定的进程组(如数据并行组、张量并行组)内,每个进程会有一个从
0到GroupSize-1的排名。这就是局部 Rank。
Megatron 的mpu(Model Parallel Unit) 模块就是为了方便地管理和查询这些不同并行维度上的局部 Rank。
mpu函数解析与举例
假设我们还是用这个 8 卡的配置:
- DP Size = 2
- PP Size = 2
- TP Size = 2
全局 Rank 分布如下 (坐标(dp_rank, pp_rank, tp_rank)):
Rank 0: (0, 0, 0)Rank 1: (0, 0, 1)Rank 2: (0, 1, 0)Rank 3: (0, 1, 1)Rank 4: (1, 0, 0)Rank 5: (1, 0, 1)Rank 6: (1, 1, 0)Rank 7: (1, 1, 1)
现在,我们来看在全局 Rank 3这个进程上,各个mpu函数的返回值是什么:
当前进程:全局 Rank 3
mpu.get_data_parallel_rank()- 含义: 返回当前进程在其所属的数据并行组中的局部 Rank。
- 分析: 全局 Rank 3 属于数据并行组 0(成员是 Ranks 0, 1, 2, 3)。在这个组内,它的排名是第 3 个(从0开始)。然而,数据并行 Rank 是其在数据并行维度上的坐标。
- 坐标:
(dp=0, pp=1, tp=1)。它的dp_rank坐标是0。 - 返回值:
0
mpu.get_pipeline_model_parallel_rank()- 含义: 返回当前进程在其所属的流水线并行组中的局部 Rank。
- 分析: 全局 Rank 3 属于流水线并行组
{Rank 1, Rank 3}(因为它们的dp_rank和tp_rank相同,都是(0, 1))。在这个组内,Rank 1是局部 Rank 0,Rank 3是局部 Rank 1。 - 坐标:
(dp=0, pp=1, tp=1)。它的pp_rank坐标是1。 - 返回值:
1
mpu.get_tensor_model_parallel_rank()- 含义: 返回当前进程在其所属的张量并行组中的局部 Rank。
- 分析: 全局 Rank 3 属于张量并行组
{Rank 2, Rank 3}(因为它们的dp_rank和pp_rank相同,都是(0, 1))。在这个组内,Rank 2是局部 Rank 0,Rank 3是局部 Rank 1。 - 坐标:
(dp=0, pp=1, tp=1)。它的tp_rank坐标是1。 - 返回值:
1
再举一个例子:全局 Rank 4
当前进程:全局 Rank 4
mpu.get_data_parallel_rank()- 坐标:
(dp=1, pp=0, tp=0)。 - 返回值:
1
- 坐标:
mpu.get_pipeline_model_parallel_rank()- 坐标:
(dp=1, pp=0, tp=0)。 - 返回值:
0
- 坐标:
mpu.get_tensor_model_parallel_rank()- 坐标:
(dp=1, pp=0, tp=0)。 - 返回值:
0
- 坐标:
总结
mpu函数 | 返回值含义 |
|---|---|
mpu.get_data_parallel_rank() | 当前进程的数据并行局部 Rank (它的dp_rank坐标) |
mpu.get_pipeline_model_parallel_rank() | 当前进程的流水线并行局部 Rank (它的pp_rank坐标) |
mpu.get_tensor_model_parallel_rank() | 当前进程的张量并行局部 Rank (它的tp_rank坐标) |
mpu.get_context_parallel_rank() | 当前进程的上下文并行局部 Rank (它的cp_rank坐标) |
这些函数是 Megatron 中进行条件判断和逻辑分支的基石。例如,代码中常见的模式:
if mpu.get_pipeline_model_parallel_rank() == 0:--> “只有流水线的第一阶段才执行此代码块”。if mpu.get_tensor_model_parallel_rank() == 0:--> “只有每个张量并行组的领导者(局部Rank 0)才执行此代码块”(常用于打印日志或保存非切分的权重)。if mpu.get_data_parallel_rank() == 0:--> “只有每个数据并行组的领导者才执行此代码块”(常用于优化器更新或梯度累积)。
通过组合使用这些局部 Rank 查询,Megatron 可以精确地控制每个进程在复杂的并行拓扑中应该扮演的角色和执行的任务。
是的,你说得非常对!这两个dist函数返回的都是全局 Rank 编号。这与mpu模块的函数形成了鲜明的对比,理解这个区别是掌握这段代码的关键。
我们来详细解析这两个函数。
1.dist.get_rank(group)
- 官方文档:
torch.distributed.get_rank(group=group) - 作用: 返回当前进程在指定的进程组
group内的局部(或相对)Rank 编号。 - 关键点:返回值是局部的。
- 示例:
- 假设 TP 组由全局 Rank
{2, 3}组成。 - 在全局 Rank 2 上调用
dist.get_rank(tp_group),返回0。 - 在全局 Rank 3 上调用
dist.get_rank(tp_group),返回1。
- 假设 TP 组由全局 Rank
在broadcast_obj函数中,if dist.get_rank(group) == 0:这个判断就是用来确定“我是否是这个特定group的领导者(局部 Rank 0)?”。
2.dist.get_process_group_ranks(group)
- 官方文档:
torch.distributed.get_process_group_ranks(group) - 作用: 返回一个列表,其中包含了指定进程组
group中所有成员的全局 Rank 编号。 - 关键点:返回的列表内容是全局的。这个列表总是按照局部 Rank 的顺序排列的(即,列表的第
i个元素是该组局部 Ranki对应的全局 Rank)。 - 示例:
- 假设 TP 组由全局 Rank
{2, 3}组成。 - 在任何一个进程(无论是 Rank 2, 3 还是其他 Rank)上调用
dist.get_process_group_ranks(tp_group),都会返回同一个列表[2, 3]。
- 假设 TP 组由全局 Rank
在broadcast_obj函数中,src_rank = dist.get_process_group_ranks(group)[0]的作用就是:
- 获取
group的所有成员的全局 Rank 列表。 - 取列表的第一个元素
[0]。 - 因为列表是按局部 Rank 排序的,所以第一个元素就是这个
group的领导者(局部 Rank 0)的全局 Rank 编号。 - 这个全局 Rank 编号被用作
dist.broadcast_object_list的src参数,告诉 PyTorch 应该从哪个具体的全局进程开始广播。
总结与对比
| 函数调用 | 返回值是全局还是局部? | 作用 |
|---|---|---|
torch.distributed.get_rank()(无参数) | 全局 | 获取当前进程的全局 Rank。 |
torch.distributed.get_rank(group) | 局部 | 获取当前进程在指定group内的局部 Rank。 |
mpu.get_..._rank()(e.g.,mpu.get_tensor_model_parallel_rank()) | 局部 | 获取当前进程在特定并行维度(如 TP)上的局部 Rank。 |
dist.get_process_group_ranks(group) | 全局(列表内容是全局的) | 获取指定group所有成员的全局 Rank 列表。 |
所以,broadcast_obj的逻辑可以这样通俗地理解:
dist.get_rank(group) == 0: “我是不是这个小团体的老大?” -> 如果是,准备好要分享的数据。dist.get_process_group_ranks(group)[0]: “我们这个小团体的老大,他的全局身份证号是多少?” -> 拿到这个号码。dist.broadcast_object_list(..., src=...): “大家注意,请从持有这个全局身份证号的人那里接收数据!”
这种结合使用局部 Rank 进行判断和全局 Rank 进行通信的模式,是 PyTorch 分布式编程中非常标准和健壮的做法。