AI 推理性能调优:动态批处理与连续批处理的调度策略
二、静态批处理的吞吐瓶颈:为什么 GPU 总是在"等请求"
大模型推理的 GPU 利用率通常只有 30-50%,原因在于请求到达的时间不均匀。静态批处理(Static Batching)等待固定数量的请求或超时后才组成一个 Batch 推理,在请求稀疏时 GPU 大量时间处于空闲等待状态。更严重的是,静态批处理要求 Batch 内所有请求同时完成——一个生成长度为 512 Token 的请求会阻塞整个 Batch,即使其他请求早已在 64 Token 处完成。动态批处理和连续批处理(Continuous Batching / In-Flight Batching)通过在推理过程中动态插入和移除请求,将 GPU 利用率提升到 80% 以上。
二、批处理策略的演进与原理
批处理策略经历了三代演进:静态批处理 → 动态批处理 → 连续批处理。静态批处理以固定窗口收集请求,整批推理;动态批处理以滑动窗口收集请求,新请求可加入下一个 Batch;连续批处理在每一步 Decode 时都可以插入新请求和移除已完成请求,实现真正的"流水线"式推理。
graph LR subgraph 静态批处理 A1[请求1<br/>128 tokens] --> A2[Batch 1<br/>等待所有请求完成] A3[请求2<br/>512 tokens] --> A2 A4[请求3<br/>64 tokens] --> A2 A2 --> A5[GPU 空闲等待<br/>请求3 提前完成但无法释放] end subgraph 连续批处理 B1[Step 1: 请求1,2,3] --> B2[Step 2: 请求1,2,3] B2 --> B3[Step 3: 请求1,2,3<br/>请求3 完成,移除] B3 --> B4[Step 4: 请求1,2,4<br/>请求4 新加入] B4 --> B5[Step 5: 请求1,2,4<br/>请求1 完成,移除] B5 --> B6[...持续调度] end style A5 fill:#ffcdd2 style B6 fill:#e8f5e9连续批处理的核心机制是 Iteration-Level Scheduling:每一步 Decode 后,调度器检查哪些请求已经生成 EOS Token,将其从 Batch 中移除;同时检查等待队列中是否有新请求,将其加入 Batch。这样 GPU 永远不会因为等待而空闲,Batch 中的请求数始终保持在最优水平。
三、连续批处理调度器的工程实现
3.1 请求状态管理
from dataclasses import dataclass, field from typing import List, Optional, Dict from enum import Enum import time class RequestStatus(Enum): WAITING = "waiting" # 等待调度 PREFILLING = "prefilling" # 预填充阶段 DECODING = "decoding" # 解码阶段 COMPLETED = "completed" # 已完成 ABORTED = "aborted" # 已中止 @dataclass class InferenceRequest: """推理请求:包含输入、参数和状态""" request_id: str prompt_tokens: List[int] max_tokens: int temperature: float = 0.7 top_p: float = 1.0 # 运行时状态 status: RequestStatus = RequestStatus.WAITING generated_tokens: List[int] = field(default_factory=list) kv_cache_ids: List[int] = field(default_factory=list) # KV Cache 页面 ID arrival_time: float = field(default_factory=time.time) first_token_time: Optional[float] = None @property def is_completed(self) -> bool: """检查请求是否完成:生成了 EOS 或达到最大长度""" if len(self.generated_tokens) == 0: return False if self.generated_tokens[-1] == 2: # EOS Token ID return True if len(self.generated_tokens) >= self.max_tokens: return True return False @property def total_tokens(self) -> int: return len(self.prompt_tokens) + len(self.generated_tokens) @property def ttft(self) -> Optional[float]: """Time To First Token:首 Token 延迟""" if self.first_token_time is None: return None return self.first_token_time - self.arrival_time class ContinuousBatchScheduler: """ 连续批处理调度器:每一步 Decode 后动态调整 Batch 组成 设计考量:调度器需要在吞吐和延迟之间平衡。 - 更大的 Batch 提高吞吐,但增加单请求延迟 - 更小的 Batch 降低延迟,但浪费 GPU 算力 - Prefill 阶段计算密集,应尽量与 Decode 阶段分开调度 """ def __init__( self, max_batch_size: int = 32, max_tokens_per_batch: int = 8192, scheduling_policy: str = "fcfs", # fcfs / priority / prefill_first ): self.max_batch_size = max_batch_size self.max_tokens_per_batch = max_tokens_per_batch self.scheduling_policy = scheduling_policy self._waiting_queue: List[InferenceRequest] = [] self._running_batch: List[InferenceRequest] = [] def add_request(self, request: InferenceRequest): """添加新请求到等待队列""" self._waiting_queue.append(request) def schedule(self) -> List[InferenceRequest]: """ 执行一轮调度: 1. 移除已完成的请求 2. 从等待队列中选择新请求加入 Batch 返回当前 Batch 中的请求列表 """ # 步骤 1:移除已完成的请求 completed = [r for r in self._running_batch if r.is_completed] for r in completed: r.status = RequestStatus.COMPLETED self._running_batch = [r for r in self._running_batch if not r.is_completed] # 步骤 2:计算当前 Batch 的剩余容量 current_tokens = sum(r.total_tokens for r in self._running_batch) remaining_slots = self.max_batch_size - len(self._running_batch) remaining_tokens = self.max_tokens_per_batch - current_tokens # 步骤 3:从等待队列中选择新请求 if self.scheduling_policy == "fcfs": self._schedule_fcfs(remaining_slots, remaining_tokens) elif self.scheduling_policy == "prefill_first": self._schedule_prefill_first(remaining_slots, remaining_tokens) return self._running_batch def _schedule_fcfs(self, remaining_slots: int, remaining_tokens: int): """FCFS 策略:先到先服务""" while remaining_slots > 0 and remaining_tokens > 0 and self._waiting_queue: request = self._waiting_queue[0] prompt_tokens = len(request.prompt_tokens) if prompt_tokens > remaining_tokens: # 单个请求的 Prompt 就超过剩余 Token 预算 # 如果 Batch 为空,仍需调度(否则请求永远无法开始) if len(self._running_batch) == 0: request.status = RequestStatus.PREFILLING self._running_batch.append(request) self._waiting_queue.pop(0) break request.status = RequestStatus.PREFILLING self._running_batch.append(request) self._waiting_queue.pop(0) remaining_slots -= 1 remaining_tokens -= prompt_tokens def _schedule_prefill_first(self, remaining_slots: int, remaining_tokens: int): """Prefill-First 策略:优先调度短 Prompt 的请求,减少 Prefill 对长请求的延迟影响""" # 按 Prompt 长度排序:短 Prompt 优先调度 self._waiting_queue.sort(key=lambda r: len(r.prompt_tokens)) self._schedule_fcfs(remaining_slots, remaining_tokens) def get_metrics(self) -> dict: """返回调度器指标""" return { "waiting_requests": len(self._waiting_queue), "running_requests": len(self._running_batch), "batch_utilization": len(self._running_batch) / self.max_batch_size, }3.2 Prefill 与 Decode 分离调度
@dataclass class ChunkedPrefillConfig: """分块 Prefill 配置:将长 Prompt 分成多个 Chunk 逐步处理""" chunk_size: int = 512 # 每个 Chunk 的 Token 数 max_prefill_chunks: int = 2 # 每步最多处理的 Prefill Chunk 数 class PrefillDecodeScheduler(ContinuousBatchScheduler): """ Prefill/Decode 分离调度器: 将 Prefill 和 Decode 分配到不同的步骤执行, 避免长 Prompt 的 Prefill 阻塞短请求的 Decode 设计考量:Prefill 是计算密集型(并行处理所有 Prompt Token), Decode 是访存密集型(逐 Token 生成,每步读取完整 KV Cache)。 混合调度会导致 Decode 请求的延迟抖动。 分离调度让 Decode 请求获得稳定的低延迟。 """ def __init__(self, prefill_config: ChunkedPrefillConfig, **kwargs): super().__init__(**kwargs) self.prefill_config = prefill_config self._prefill_pending: List[InferenceRequest] = [] def schedule_step(self) -> dict: """ 执行一步调度,返回 Prefill 和 Decode 的请求分组 """ # 移除已完成的请求 completed = [r for r in self._running_batch if r.is_completed] for r in completed: r.status = RequestStatus.COMPLETED self._running_batch = [r for r in self._running_batch if not r.is_completed] # 分离 Prefill 和 Decode 请求 prefill_requests = [r for r in self._running_batch if r.status == RequestStatus.PREFILLING] decode_requests = [r for r in self._running_batch if r.status == RequestStatus.DECODING] # 优先调度 Decode 请求(延迟敏感) # Prefill 请求使用剩余的 Batch 容量 decode_capacity = self.max_batch_size prefill_capacity = max(0, self.max_batch_size - len(decode_requests)) # 从等待队列中选择新请求进行 Prefill new_prefill = [] remaining_tokens = self.max_tokens_per_batch for req in self._waiting_queue[:prefill_capacity]: if len(req.prompt_tokens) <= remaining_tokens: req.status = RequestStatus.PREFILLING new_prefill.append(req) remaining_tokens -= len(req.prompt_tokens) for req in new_prefill: self._waiting_queue.remove(req) # Prefill 完成后转为 Decode for req in prefill_requests: req.status = RequestStatus.DECODING if req.first_token_time is None: req.first_token_time = time.time() return { "prefill": new_prefill, "decode": decode_requests + prefill_requests, # 已完成 Prefill 的加入 Decode "completed": completed, }四、连续批处理的边界与权衡
连续批处理的调度策略直接影响 TTFT(首 Token 延迟)和 Throughput(吞吐量)的平衡。FCFS 策略公平但可能导致长 Prompt 阻塞短请求;Prefill-First 策略优化了短请求的 TTFT,但长请求可能被饿死。生产环境通常采用混合策略:短请求优先调度,但为长请求设置最大等待时间,防止无限延迟。
KV Cache 的显存管理是连续批处理的另一个挑战。请求动态进出 Batch,KV Cache 需要频繁分配和释放。PagedAttention 通过按页分配 KV Cache 解决了碎片问题,但页面分配/释放本身也有开销。当 Batch 中请求的序列长度差异很大时,短请求释放的页面可能无法被长请求复用(页面大小固定),导致显存利用率下降。
在 Prefill/Decode 分离调度中,Prefill 的 Chunk 大小需要精心选择。Chunk 太大(如 2048 Token)会阻塞 Decode 步骤,增加延迟;Chunk 太小(如 64 Token)会增加 Prefill 的步骤数,降低计算效率。经验值是 256-512 Token,在延迟和效率之间取平衡。
五、总结
连续批处理通过 Iteration-Level Scheduling 实现请求的动态插入和移除,将 GPU 利用率从 30-50% 提升到 80% 以上。核心实践包括:FCFS 策略保证公平性,Prefill-First 策略优化短请求延迟,Prefill/Decode 分离调度消除延迟抖动,分块 Prefill 平衡计算效率与延迟。调度策略的选型应基于业务场景的延迟要求和请求分布特征,持续监控 TTFT 和 Throughput 指标进行调优。