verl old policy计算:log_prob获取过程梳理
在强化学习驱动的大语言模型后训练中,old policy(即当前策略的旧版本)的对数概率log_prob是PPO及其变体(如GRPO)算法的核心中间量。它不仅用于构建重要性采样权重,还直接参与优势函数估计、KL散度约束与策略梯度更新。然而,verl框架中log_prob的实际计算路径并非线性直白——它横跨数据分片、设备映射、rollout生成、模型并行与日志概率重计算等多个抽象层级,且受FSDP、vLLM、tensor parallelism等多重机制影响。
本文不讲抽象公式,也不堆砌配置参数,而是以一条真实 rollout sample 为线索,从ray_trainer.fit()中的一次compute_log_prob(batch)调用出发,逐层下钻至 GPU 显存中的张量级操作,完整还原π_old(τ_i^(t) | τ_i^{<t})这一关键值是如何被切分、调度、前向、收集并最终组装成 batch-level 张量的。目标是让读者合上文章时,能清晰回答:
- 为什么
log_prob不在 rollout 生成时一并算好,而要单独再跑一遍? log_prob_micro_batch_size_per_gpu=8到底控制哪一段的批处理大小?- 当
data.train_batch_size=60、rollout.n=12、n_gpus=6时,GPU 上真正并发执行 log_prob 计算的 sequence 数量是多少? ActorRolloutRefWorker如何协调 actor 模型、rollout 引擎与 ref 模型三者的 log_prob 计算节奏?
我们不预设你熟悉 FSDP 内部或 vLLM 调度器,所有技术点均以“发生了什么 + 为什么这样设计”双视角展开。
1. 问题起点:为什么需要单独 compute_log_prob?
在verl/verl/trainer/ppo/ray_trainer.py的主训练循环中,你会看到这样一段代码:
# 使用 actor_rollout_wg计算每一个rollout sample中每个token的 old policy log_prob with _timer('old_log_prob', timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) batch = batch.union(old_log_prob)注意:此时batch已不是原始的 60 条 prompt,而是经过generate_sequences()后膨胀至720 条 rollout sequences的结果(60 × 12)。这 720 条序列已包含input_ids、attention_mask、prompt_lengths等字段,但唯独没有log_probs。
这看似低效——既然 actor 模型刚生成了这些 token,为何不顺手把每个 token 的log_prob也缓存下来?答案藏在两个工程现实里:
1.1 rollout 引擎与 policy 模型的职责分离
verl将 rollout(采样)与 log_prob 计算解耦为两个独立阶段,核心原因在于引擎选型自由度:
rollout可以是vLLM(高吞吐、支持 PagedAttention)、SGLang(低延迟、支持 FP8)、或HuggingFace(调试友好、全兼容);- 但
log_prob计算必须由当前训练中的 actor 模型(FSDP 包装的 PyTorch 模块)执行,以保证梯度可回传、参数状态一致、KL 散度可监督。
vLLM/SGLang 的推理引擎虽快,但它们:
- 不暴露 token-level logits 接口(或需 hack);
- 不与 FSDP 的参数分片、梯度同步机制对齐;
- 若强制在 rollout 阶段输出 logits,则无法复用 FSDP 的 offload、sharding、gradient checkpointing 等优化。
因此,verl采用“rollout 只管生成 token ID,log_prob 另起炉灶、用 actor 模型重跑一次前向”的设计。这不是冗余,而是为训练稳定性、梯度一致性与框架可扩展性付出的必要代价。
1.2 log_prob 计算需严格匹配训练时的模型状态
PPO/GRPO 的策略梯度公式中,log_prob必须对应π_old—— 即本次 update 步骤开始时冻结的 actor 模型参数。而 rollout 过程可能持续数毫秒到数百毫秒,在此期间若 actor 模型被其他 worker 更新(如异步训练),则log_prob就不再代表π_old。
verl的解决方案是:在compute_log_prob()被调用的瞬间,确保 actor 模型参数处于 freeze 状态,并使用与 rollout 生成时完全相同的 tokenizer、padding 策略、attention mask 构造逻辑。这种强一致性只能通过统一入口(即ActorRolloutRefWorker.compute_log_prob)来保障。
关键结论:
compute_log_prob不是性能瓶颈的补救,而是 RL 训练语义正确性的基础设施。它本质是一次“带参数快照的、确定性重放”。
2. 数据流拆解:从 720 条 sequence 到 GPU 上的 micro-batch
假设当前训练 step 的batch包含 720 条 rollout sequences,每条 sequence 是一个DataProto对象,内含input_ids: [seq_len]、attention_mask: [seq_len]、prompt_length: int等字段。compute_log_prob(batch)的目标,是为每条 sequence 的每个 token(除 prompt 部分外)计算log π_old(token | context)。
这个过程绝非简单model(input_ids)。它必须适配verl的混合并行架构。我们按执行顺序梳理:
2.1 入口:ActorRolloutRefWorker.compute_log_prob
该方法定义在verl/verl/workers/fsdp_workers.py,其签名如下:
def compute_log_prob(self, batch: DataProto) -> DataProto: # batch: shape [720, max_seq_len] ...第一件事是设备对齐与分片准备:
batch = batch.to(torch.cuda.current_device()) # 移至当前 GPU if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) # 若参数被 offload,先拉回 GPU接着,它会检查self._is_rollout和self._is_actor标志位(回忆:role='actor_rollout'→ 二者均为True),并进入真正的计算分支。
2.2 分片策略:谁负责哪部分数据?
verl不会把全部 720 条 sequence 塞给单个 GPU。它依据device_mesh和log_prob_micro_batch_size_per_gpu进行两级切分:
第一级:按 FSDP world size 切分(数据并行维度)
self.device_mesh.size() == 6(6 张 GPU)self.config.rollout.log_prob_micro_batch_size_per_gpu = 8- 因此,每张 GPU 最多同时处理 8 条 sequence 的 log_prob 计算
这意味着:720 条 sequence 需要被划分为ceil(720 / 6) = 120个“GPU-local batch”,每个 batch 大小为 6(因为 720 ÷ 6 = 120,但实际调度更细粒度)。
但注意:log_prob_micro_batch_size_per_gpu=8是硬性上限,而非固定大小。verl会动态将 720 条数据按8为单位分组,形成720 // 8 = 90个 micro-batch,再将这 90 个 micro-batch轮询(round-robin)分配给 6 张 GPU。
所以,每张 GPU 实际处理90 // 6 = 15个 micro-batch,每个 micro-batch 含 8 条 sequence →每卡共处理 120 条 sequence(15 × 8),6 卡合计 720 条。
验证:
720 ÷ 6 = 120,120 ÷ 8 = 15→ 完全整除,无 padding。这是verl配置校验逻辑(见fsdp_workers.py中assert ... % ... == 0)所保障的。
第二级:按 sequence length 动态填充(避免显存浪费)
每条 sequence 长度不同(prompt + generated tokens)。verl不做 static padding 到max_seq_len,而是采用dynamic batching + left-padding:
- 在每个 micro-batch(8 条 sequence)内,找出
max_seq_len; - 将所有 8 条 sequence左对齐填充至该长度(prompt 在前,generated 在后);
- 构造
input_ids: [8, max_seq_len]、attention_mask: [8, max_seq_len]; position_ids依实际 token 位置生成(非全 0);prompt_lengths: [8]单独保存,用于后续 mask prompt token 的 log_prob。
这种策略显著降低显存占用,尤其当 sequence 长度方差大时。
2.3 前向计算:如何得到 token-level log_prob?
拿到[8, max_seq_len]的 batch 后,compute_log_prob调用 actor 模型前向:
with torch.no_grad(): outputs = self.actor_module_fsdp( input_ids=input_ids, attention_mask=attention_mask, return_dict=True ) logits = outputs.logits # [8, max_seq_len, vocab_size]关键来了:logits是每个位置的未归一化分数,而我们需要的是log π_old(token | context),即:
- 对于第
i条 sequence 的第j个 token(j >= prompt_length[i]),取logits[i, j, input_ids[i, j]]; - 然后减去
logsumexp(logits[i, j, :])(即 softmax 分母的 log)。
verl用高效向量化方式实现:
# logits: [B, T, V], input_ids: [B, T] log_probs = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) # [B, T] # 减去 logsumexp log_probs = log_probs - torch.logsumexp(logits, dim=-1) # [B, T]最后,用prompt_lengthsmask 掉 prompt 部分(只保留 generated token 的 log_prob):
mask = torch.arange(max_seq_len).expand(B, -1) >= prompt_lengths.unsqueeze(1) # [B, T] log_probs = log_probs.masked_fill(~mask, 0.0) # prompt 位置置 0,不影响后续 sum输出log_probs: [8, max_seq_len],连同prompt_lengths、mask一起打包进DataProto,返回给上层。
2.4 汇总:从 6 卡结果到全局 batch
6 张 GPU 并行完成各自 120 条 sequence 的 log_prob 计算后,compute_log_prob需将结果按原始顺序拼接回 720 条。这通过torch.distributed.all_gather实现:
- 每卡输出
local_log_probs: [120, max_seq_len_local](max_seq_len_local可能不同); all_gather收集所有卡的结果,得到global_log_probs: [720, max_seq_len_global];- 同时 gather
prompt_lengths_global: [720]和mask_global: [720, max_seq_len_global]; - 最终封装为
DataProto,字段包括:log_probs: [720, max_seq_len_global]prompt_lengths: [720]mask: [720, max_seq_len_global]
至此,old_log_prob = compute_log_prob(batch)完成,batch.union(old_log_prob)将其注入原始 batch,供后续compute_advantage使用。
3. ref policy 的 log_prob:复用同一套流水线
ref policy的 log_prob 计算(compute_ref_log_prob)与old policy高度对称,区别仅三点:
3.1 模型不同:ref model 替代 actor model
ref policy通常是一个冻结的、早先 checkpoint 的 LLM(如 SFT 后模型);- 它不参与梯度更新,故无需 FSDP 的梯度同步,但依然需要
FSDP加载(因模型可能超大); verl通过self.ref_module_fsdp引用它,并在compute_ref_log_prob中调用其前向。
3.2 分片配置独立:ref 有自己的 micro-batch size
配置项actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8与 rollout 的log_prob_micro_batch_size_per_gpu完全解耦。你可以为 ref 设置=4(更小 batch,更稳显存)或=16(更大 batch,更高吞吐),只要满足整除约束。
这意味着:ref 的 720 条 sequence 也会被切成720 // 4 = 180个 micro-batch(若设为 4),再分发给 6 卡,每卡处理180 // 6 = 30个 micro-batch →每卡 120 条 sequence(30 × 4),与 actor 一致。
3.3 无 offload 开销:ref 通常不启用 param/optimizer offload
回顾fsdp_workers.py初始化逻辑:
elif self._is_ref: self._is_offload_param = self.config.ref.fsdp_config.get('param_offload', False)默认param_offload=False,所以 ref model 始终驻留 GPU,省去了load/offload的 PCIe 传输开销,计算更快。
关键结论:ref policy 的 log_prob 流水线是 actor 的“镜像副本”,仅模型与配置参数不同。
verl的模块化设计让两者共享 90% 以上代码,极大降低维护成本。
4. GRPO 特殊性:log_prob 在无 critic 场景下的角色跃迁
GRPO(Generalized Reward-based Policy Optimization)作为 PPO 的轻量变体,其最大特点是移除 critic model 与 reward model,直接用规则函数reward_fn(batch)输出token_level_scores,并视其为Value。
这使得log_prob的作用发生质变:
| 维度 | PPO(标准) | GRPO(verl 实现) |
|---|---|---|
| log_prob 主要用途 | 构建重要性权重ρ = π_new/π_old,用于裁剪梯度 | 同左;额外承担 KL 散度监控职责(若启用use_kl_loss) |
| KL 散度计算 | 依赖π_old与π_ref的 log_prob 差异,常用于 early stopping | verl中apply_kl_penalty直接基于old_log_prob与ref_log_prob计算 per-token KL,并加权到token_level_rewards |
| advantage 计算 | A = r + γV_{t+1} - V_t,需 critic 输出V | A = r_t(即时 reward),log_prob不参与 advantage 构建,但决定梯度方向 |
看一段verl中 GRPO 的 KL 处理逻辑:
if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False): batch, kl_metrics = apply_kl_penalty( batch, kl_ctrl=self.kl_ctrl, # KL 控制器(如 AdaptiveKLController) kl_penalty=self.config.algorithm.kl_penalty ) metrics.update(kl_metrics) else: # 若 use_kl_loss=True,则 KL 作为 loss 项,而非 penalty batch.batch['token_level_rewards'] = batch.batch['token_level_scores']apply_kl_penalty内部正是用batch['old_log_prob']与batch['ref_log_prob']计算:
kl_div = (ref_log_prob - old_log_prob) * mask # mask 掉 prompt kl_penalty = kl_penalty_coef * kl_div.sum(dim=-1) # per-sequence KL # 加到 reward 上,抑制 policy 过度偏离 ref batch.batch['token_level_rewards'] = batch.batch['token_level_scores'] - kl_penalty.unsqueeze(-1) * mask可见,在 GRPO 中,log_prob不仅是梯度基石,更是策略保守性的实时仪表盘。verl通过将 KL 计算下沉至log_prob层,实现了 zero-shot 的策略约束,无需额外 critic head。
5. 工程实践建议:如何调优 log_prob 性能与显存
理解原理后,落地调优才有依据。以下是基于verl源码与实测经验的四条硬核建议:
5.1 优先调log_prob_micro_batch_size_per_gpu,而非ppo_mini_batch_size
很多用户误以为增大ppo_mini_batch_size能提升吞吐,但log_prob的瓶颈常在per-GPU batch size:
- 设
log_prob_micro_batch_size_per_gpu=8,n_gpus=6→ 每步处理 48 条 sequence; - 若显存充足,可安全增至
=16→ 每步处理 96 条,吞吐翻倍; ppo_mini_batch_size增大会导致 FSDP 分片数变化,可能引发通信开销激增,收益不确定。
行动项:监控nvidia-smi,若 GPU memory usage < 85%,且log_prob阶段耗时 > 200ms,立即将*_micro_batch_size_per_gpu提升 2x。
5.2 rollout 与 log_prob 的 tensor parallelism 必须一致
actor_rollout_ref.rollout.tensor_model_parallel_size=2表示 rollout 用 2 卡协同处理一个 sequence(vLLM 的 TP)。但log_prob计算走的是 actor 的 FSDP 分片,若 actor 的 FSDP 分片数 ≠ rollout TP size,会导致数据错位。
verl的校验逻辑在_build_rollout中:
assert self.world_size % infer_tp == 0 # 6 % 2 == 0 → OK但若你设infer_tp=3,则6 % 3 == 0仍成立,却与 actor 的device_mesh.size()=6不匹配(FSDP 默认按 world_size 分片)。此时log_prob输入的input_idsshape 可能错乱。
行动项:保持rollout.tensor_model_parallel_size == 1(禁用 TP)或== n_gpus(全卡 TP),最稳妥;若必须用=2或=3,请确保n_gpus是其整数倍,且 actor 的 FSDPfsdp_size与之对齐。
5.3 关闭不必要的 offload,尤其对 ref policy
如前所述,ref policy 默认不启用 offload。但若你在ref.fsdp_config中手动开启param_offload=True,则每次compute_ref_log_prob都会触发 CPU↔GPU 参数搬运,拖慢 3–5 倍。
行动项:检查ppo_trainer.yaml,确保actor_rollout_ref.ref.fsdp_config.param_offload=False(默认值),切勿修改。
5.4 日志概率 debug:用print_shape=True快速定位截断
当log_prob结果异常(如全 0、NaN、长度不匹配),在compute_log_prob调用前插入:
print("DEBUG batch shape:", batch.batch['input_ids'].shape) print("DEBUG prompt_lengths:", batch.batch['prompt_lengths'][:5])若发现input_ids.shape[0] != 720,说明上游generate_sequences未正确膨胀;若prompt_lengths有负值或超长,说明 tokenizer 或 padding 逻辑出错。
行动项:verl的DataProto支持.debug_print()方法,一行调用即可 dump 全量字段 shape 与 dtype,比 print 更高效。
6. 总结:log_prob 是 verl 的策略心跳
verl中old policy log_prob的获取,远不止一次模型前向。它是verl混合并行架构的缩影:
- 数据层面,它被
FSDP切分、被vLLM生成、被dynamic batching填充; - 计算层面,它由
actor模型重放、与ref模型对比、为GRPO的 KL 控制提供燃料; - 工程层面,它用
micro_batch_size_per_gpu作杠杆,平衡吞吐与显存,用offload开关控制延迟,用mask精确界定策略作用域。
理解这一过程,你就掌握了verlRL 训练的脉搏。下次看到compute_log_prob耗时飙升,你知道该查micro_batch_size;看到 KL 散度失控,你知道该验ref_log_prob与old_log_prob的对齐;看到 OOM 报错,你知道该关offload或调小tensor_model_parallel_size。
log_prob不是黑箱输出,而是verl将强化学习理论严谨落地的工程宣言。
--- > **获取更多AI镜像** > > 想探索更多AI镜像和应用场景?访问 [CSDN星图镜像广场](https://ai.csdn.net/?utm_source=mirror_blog_end),提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。