Scheduler Overlap:CPU-GPU 调度级重叠(--disable-overlap-schedule)
一、这是哪一层的 Overlap?
SGLang 里有三层 overlap,它们彼此独立,针对不同瓶颈:
| 层次 | 选项 | 隐藏的是什么 |
|---|---|---|
| Scheduler Overlap(本篇) | --disable-overlap-schedule(默认开启) | 每次 iteration 之间的 CPU 结果处理开销 |
| SBO | --enable-single-batch-overlap | 单批内 Down GEMM 与 Combine Send 的通信延迟 |
| TBO | --enable-two-batch-overlap | 两批之间 MoE dispatch/combine 的通信延迟 |
二、动机:CPU 在哪里慢?
每次 GPU forward 结束后,CPU 需要处理结果(process_batch_result):
① copy_done.synchronize() ← 等 GPU→CPU 数据传输完成(强制阻塞同步点)
② next_token_ids.tolist() ← tensor 转 Python list
③ for req in batch:
req.output_ids.append() ← 更新输出
req.check_finished() ← 检查 EOS / max_tokens 停止条件
release_kv_cache() ← 释放完成请求的 KV slots(树操作)
tree_cache.cache_unfinished_req() ← 更新前缀树
req.grammar.accept_token() ← grammar 状态机推进(按需)
④ send_to_tokenizer() ← 发送输出到 detokenizer 进程
普通模式下,GPU 必须等这些全做完才能开始下一批:
三、解决方案:两个 CUDA Stream + GPU-side 等待
schedule_stream(优先级 0):整个 event_loop 跑在这个 stream context 里
└─ 凡在此 context 内发出的 GPU 操作,全部入队到 schedule_stream
└─ 包含:CPU 调度逻辑 + 调度期间触发的 GPU kernel/copy
forward_stream:GPU model forward 专用
└─ forward_batch_generation()
└─ copy_to_cpu(non_blocking=True)
└─ copy_done.record()
copy_stream:GPU→CPU 异步拷贝
└─ non_blocking=True 的 .to("cpu") 实际走这里(PyTorch 内部管理)
关键机制:forward_stream.wait_stream(schedule_stream) 是 GPU-side 等待(不阻塞 CPU),GPU 要等 schedule_stream 上所有已入队的操作执行完,才开始 forward。
初始化(scheduler.py:1231-1061):
# run_event_loop()
self.schedule_stream = self.device_module.Stream(priority=0)
with CudaStreamContext(self.schedule_stream): # event_loop 全跑在此 context
dispatch_event_loop(self)
# init_overlap()
self.forward_stream_ctx = self.device_module.stream(self.forward_stream)
self.copy_stream = self.device_module.Stream()
self.future_map = FutureMap(...) # 投机解码专用,见 5.2
self.batch_record_buf = [None] * 2 # 防 GC,见 5.1
四、schedule_stream 上的真实 CUDA 操作
这是理解
wait_stream的关键:schedule_stream 上不只有 CPU 代码,还有真实的 GPU 操作。
CudaStreamContext(self.schedule_stream) 意味着这个 context 里发出的任何 PyTorch CUDA op 都会入队到 schedule_stream。
get_next_batch_to_run() → prepare_for_decode() → alloc_for_decode() 链路上的 GPU ops:
# mem_cache/common.py:440-460(alloc_for_decode 内)
# ① GPU gather:从 req_to_token 表读上一个 token 的 slot 位置
last_loc = batch.req_to_token_pool.req_to_token[
batch.req_pool_indices, batch.seq_lens - 1 # ← GPU 索引读取
]
# ② GPU 加法:计算下一个 token 的序列长度
seq_lens_next = batch.seq_lens + token_per_req # ← GPU tensor 加法
# ③ GPU clone:复制当前 seq_lens 作为写入位置
locs = batch.seq_lens.clone() # ← GPU clone
# ④ GPU dtype 转换
out_cache_loc.to(torch.int32) # ← GPU cast
# ⑤ GPU scatter write:把新 token 的 KV slot 写进映射表
batch.req_to_token_pool.write(
(batch.req_pool_indices, locs), out_cache_loc.to(torch.int32)
)
# memory_pool.py:150:self.req_to_token[indices] = values ← GPU 索引写入
Prefill 时还有:
# ⑥ Triton kernel launch(写 req_to_token 表前缀段)
write_req_to_token_pool_triton(...) # ← GPU kernel
# ⑦ H2D 拷贝(新请求的 token ids 传到 GPU)
input_ids = torch.tensor(tokens, device="cuda") # ← CPU→GPU 拷贝
schedule_batch.py:2026 里的加法也是 GPU op:
self.seq_lens = self.seq_lens + 1 # ← GPU 加法(在 schedule_stream 上执行)
所以 forward_stream.wait_stream(schedule_stream) 等的是:
req_to_token映射表写入完成(forward 需要读这张表定位 KV cache)seq_lens新 tensor 创建完成(forward 用它计算 attention mask)out_cache_loc准备好(forward 用它写入新 token 的 KV)- Prefill 时的 H2D 拷贝完成(input_ids 到位)
五、完整 Pipeline 调用链与时序图
5.1 每一步在做什么(稳态:batch N-1 与 batch N 并发)
┌──────────────────────────────────────────────────────────────────────────┐
│ SCHEDULE STREAM (CPU 线程 + 调度期间触发的 GPU ops) │
│ │
│ ① recv_requests() │
│ ├─ [CPU] zmq/ipc poll 收新请求(TokenizedGenerateReqInput) │
│ └─ [CPU] TP rank0 收后通过 cpu_group 广播到其他 rank │
│ │
│ ② process_input_requests(recv_reqs) │
│ ├─ [CPU] 新推理请求 → 加入 waiting_queue │
│ └─ [CPU] abort/flush 等控制命令 → 立刻处理 │
│ │
│ ③ get_next_batch_to_run() ← CPU 重活 + 多个 GPU ops │
│ │ │
│ ├─ [CPU] merge last prefill batch into running_batch │
│ │ last_batch.filter_batch() + running_batch.merge_batch() │
│ │ │
│ ├─ Prefill 路径:get_new_batch_prefill() │
│ │ ├─ [CPU] 从 waiting_queue 选请求(贪心/优先级策略) │
│ │ ├─ [CPU] match_prefix() → 前缀树匹配,确定 cache hit 长度 │
│ │ ├─ [CPU] alloc KV token slots → free_slots 列表操作(纯 CPU) │
│ │ └─ prepare_for_extend(): │
│ │ ├─ [GPU→CPU H2D] input_ids.to("cuda") ← 新请求 token │
│ │ ├─ [GPU] write_req_to_token_pool_triton ← Triton kernel │
│ │ └─ [GPU] req_to_token scatter write │
│ │ │
│ └─ Decode 路径:update_running_batch() → prepare_for_decode() │
│ ├─ [CPU] input_ids = output_ids(指针赋值,此时是 future 占位) │
│ ├─ [CPU] alloc free_slots index(纯 CPU 列表操作) │
│ ├─ [GPU] req_to_token[req_pool_indices, seq_lens-1] ← gather │
│ ├─ [GPU] seq_lens_next = seq_lens + token_per_req ← 加法 │
│ ├─ [GPU] locs = seq_lens.clone() ← clone │
│ ├─ [GPU] req_to_token[indices] = out_cache_loc ← scatter │
│ ├─ [GPU] seq_lens = seq_lens + 1 (新 tensor,非 in-place) │
│ └─ [GPU] seq_lens_cpu = seq_lens_cpu + 1 │
│ │
│ ④ is_disable_overlap_for_batch() [CPU] │
│ └─ 判断是否强制串行(连续 prefill / spec+grammar) │
│ │
│ ⑤ run_batch(batch) │
│ ├─ [CPU] batch.get_model_worker_batch() ← 打包字段引用,无 copy │
│ ├─ [CPU] record_batch_in_overlap() ← 防 GC,引用存 buf │
│ ├─ [CPU] sampling_info.copy_for_forward() ← dataclass replace │
│ ├─ [CPU] future_map.alloc_future_indices(bs) ← 分配占位 indices │
│ │ │
│ └─ 切换到 forward_stream_ctx ────────────────────────────────────┐ │
│ │ │
│ ╔══════════════════════════ FORWARD STREAM ═════════════════════╗ │ │
│ ║ ║ │ │
│ ║ wait_stream(schedule_stream) ║ │ │
│ ║ ↑ GPU 在这里等 schedule_stream 上所有操作完成(③的 GPU ops) ║ │ │
│ ║ ║ │ │
│ ║ future_map.resolve_future(model_worker_batch) ║ │ │
│ ║ └─ [GPU] 把上批 next_token_ids 写进当前批 input_ids ║ │ │
│ ║ └─ 必须在此 stream:保证上批 forward 产生的 token 已就绪 ║ │ │
│ ║ ║ │ │
│ ║ forward_batch_generation(model_worker_batch) ← GPU forward ║ │ │
│ ║ └─ 非阻塞,入队后 CPU 立刻返回 ║ │ │
│ ║ ║ │ │
│ ║ copy_done = Event() ║ │ │
│ ║ copy_to_cpu(non_blocking=True): ║ │ │
│ ║ next_token_ids.to("cpu", non_blocking=True) ║ │ │
│ ║ logprobs.to("cpu", non_blocking=True)(按需) ║ │ │
│ ║ copy_done.record() ← 在 forward_stream 打事件标记 ║ │ │
│ ╚═══════════════════════════════════════════════════════════════╝ │ │
│ │ │
│ ◄────────────────────────────────────────────────────────────────┘ │
│ ⑥ result_queue.append((batch.copy(), batch_result)) [CPU] │
│ │
│ ⑦ pop_and_process() ← 处理 batch N-1,此时 GPU 正在跑 batch N │
│ └─ process_batch_result(batch_N-1, result_N-1): │
│ ├─ [CPU 阻塞] copy_done.synchronize() ← 唯一强制同步点 │
│ │ 等 GPU→CPU 的 H2D 传输完成 │
│ ├─ [CPU] next_token_ids.tolist() ← tensor → Python list │
│ ├─ [CPU] for req in batch: │
│ │ req.output_ids.append(next_token_id) │
│ │ req.check_finished() ← EOS / max_tokens 判断 │
│ │ release_kv_cache() ← 释放 KV slots(树操作) │
│ │ tree_cache.cache_unfinished_req() ← 更新前缀树 │
│ │ req.grammar.accept_token() ← grammar 状态机(按需) │
│ └─ [CPU] send_to_tokenizer() ← 输出发到 detokenizer 进程 │
│ │
│ ⑧ launch_batch_sample_if_needed() ← spec 专用,依赖上批 grammar 状态 │
└──────────────────────────────────────────────────────────────────────────┘
5.2 时序并发图(两个 iteration 对比)
时间轴 ──────────────────────────────────────────────────────────────────►
schedule_stream:
[①②③④⑤ sched(N)] [⑥⑦ process(N-1)] [①②③④⑤ sched(N+1)] [⑦ process(N)] ...
↑ ↑
⑦ 处理 N-1 时 ⑦ 处理 N 时
GPU 在跑 N GPU 在跑 N+1
forward_stream:
[wait] [resolve] [forward(N) ─────────] [copy(N)→CPU]
[wait] [resolve] [forward(N+1) ───]
[copy(N+1)→CPU]
←────────────────────────────────────────────────────────────────────────►
copy_done.synchronize() 是 CPU 等 GPU 的唯一阻塞点(⑦ 最开头)
forward_stream.wait_stream() 是 GPU 等 GPU/CPU ops 的同步点(非阻塞 CPU)
六、三个关键设计细节
六、三个关键设计细节
6.1 seq_lens 为什么不能 in-place
prepare_for_decode() 里:
if self.enable_overlap:
self.seq_lens = self.seq_lens + 1 # ← 创建新 tensor
else:
self.seq_lens += 1 # ← 普通模式可以 in-place
原因:self.seq_lens 已打包进上一批的 model_worker_batch,该 batch 此时正在 forward_stream 上运行(GPU 正在读它)。
如果 += 1(in-place)→ schedule_stream 上的 GPU 加法原地修改了 forward_stream 正在读取的 tensor → 两个 GPU stream 的数据竞争(无同步)→ 结果不确定。
6.2 FutureMap:解决 input_ids 的时序矛盾
矛盾:
正常逻辑: process_batch_result(N) 把 next_token_id 写入 req.output_ids
→ prepare_for_decode(N+1) 读 output_ids 作为 input_ids
Overlap 逻辑:run_batch(N+1) 要在 process_batch_result(N) 完成前启动!
→ req.output_ids 还没有 token N 的结果
→ input_ids 怎么来?
解决方案:GPU 产生 token 后立刻存进 FutureMap,下一批从 FutureMap 取值(都在 forward_stream 上,保证顺序)。
6.3 forward_stream.wait_stream(schedule_stream) 等的是什么
| GPU op | 来源 | forward 为什么需要它 |
|---|---|---|
req_to_token[indices] = out_cache_loc | alloc_for_decode | forward 需读此表定位 KV cache |
seq_lens = seq_lens + 1 | prepare_for_decode | forward 用于计算 attention mask |
out_cache_loc(新 token slot) | alloc_for_decode | forward 将新 token KV 写入此处 |
input_ids.to("cuda") | prepare_for_extend | forward 的输入 |
6.4 copy_done.synchronize() 同时是 pipeline 深度的天然节流阀
这个 sync 点不只是"等数据",它还把 CPU 的 lookahead 限制在恰好 1 batch。
没有 sync 点会怎样(假设):
CPU 不停 launch,从不等 GPU 结果 → req 状态不更新 → KV cache 不释放 → OOM。
有 sync 点(实际):CPU 永远只超前 1 step
偏移:CPU 的 proc(K)·SYNC 恰好对应 GPU 的 fw(K) 结束,差恰好 1 step。
sync 点同时承担两个职责:
- 正确性:必须等 GPU 结果落到 CPU 上,才能更新 req 状态、释放 KV cache
- 流控:天然把 pipeline depth 压在 1,防止 CPU 无限超前
七、强制串行的场景
is_disable_overlap_for_batch() 返回 True 时,先处理上批结果再启动 GPU:
① 连续 prefill(需要环境变量开启)
disable_overlap_for_batch = (
envs.SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP.get()
and batch.forward_mode.is_extend()
and self.last_batch.forward_mode.is_extend()
)
原因:overlap 模式下 batch N-1 的 process_batch_result 被推迟到 batch N 启动后,导致 N-1 的第一个 token(TTFT)多等了一个 iteration 时间。默认关闭(即默认允许 overlap),需要 SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP=1 才触发串行。
② spec v2 + grammar + decode
need_grammar_sync = (
batch.is_spec_v2 and batch.has_grammar
and batch.forward_mode.is_decode()
and len(self.result_queue) > 0
)
原因:spec v2 draft 生成依赖上批更新后的 grammar 状态(req.grammar.accept_token() 在 process_batch_result 里调用)。overlap 时状态尚未更新 → draft 基于错误状态。
③ 架构层面强制禁用(server_args.py)
pp_size > 1:Pipeline Parallelism 跨 stage 需要严格同步 hidden states- Mamba no_buffer 模式:状态更新有严格的顺序依赖
- Mixed chunked prefill:与 overlap 逻辑不兼容
八、关键文件索引
| 文件 | 行号 | 关键内容 |
|---|---|---|
server_args.py | 617, 5054 | disable_overlap_schedule: bool = False |
managers/scheduler.py | 321 | self.enable_overlap = not server_args.disable_overlap_schedule |
managers/scheduler.py | 1231-1241 | run_event_loop() → schedule_stream context |
managers/scheduler.py | 1272-1324 | event_loop_overlap() → overlap 核心逻辑 |
managers/schedule_batch.py | 1970-2027 | prepare_for_decode() → seq_lens 非 in-place |
managers/overlap_utils.py | 全文 | FutureMap → 解决 input_ids 时序问题 |
九、Token Pool 设计:两级间接寻址
这是理解 schedule_stream 上 GPU op 的基础——为什么 alloc_for_decode 里有那么多 scatter/gather。
9.1 两张表的职责分工
ReqToTokenPool(地址翻译表)
shape: [max_num_reqs, max_context_len] dtype: int32 device: CUDA
语义: req_pool_indices[i], pos → slot_idx(哪个 token slot 存了该位置的 KV)
free_slots: CPU Python list(alloc 直接 list.pop,CPU 逻辑)
MHATokenToKVPool(实际 KV 数据)
per layer: k_buffer[layer], v_buffer[layer]
shape: [total_token_slots + page_size, num_heads, head_dim] dtype: fp16/bf16 device: CUDA
free_slots: GPU tensor(batch alloc 用 GPU slice 更快)
两级寻址链:
(req_idx, pos) → req_to_token[req_idx, pos] = slot_idx ← ReqToTokenPool(GPU scatter/gather)
slot_idx → k_buffer[layer][slot_idx, :, :] ← MHATokenToKVPool(直接 GPU 索引)
9.2 Decode 时的 alloc_for_decode(为什么有这么多 GPU op)
# mem_cache/common.py:423-462
# ① CPU: 从 free_slots(Python list)pop 出 bs 个空闲 slot
out_cache_loc = free_slots[:bs] # CPU list → GPU tensor(alloc 时已经处理)
# ② GPU gather:读取各请求上一个 token 的 slot(用于 attention 中的 KV 定位)
last_loc = req_to_token[req_pool_indices, seq_lens - 1] # GPU gather
# ③ GPU 算出下一个位置的 seq_len
seq_lens_next = batch.seq_lens + token_per_req # GPU 加法
# ④ GPU clone:复制当前 seq_lens 作为写入目标位置
locs = batch.seq_lens.clone() # GPU clone
# ⑤ GPU scatter write:把新 slot 填进地址翻译表
req_to_token_pool.write((req_pool_indices, locs), out_cache_loc.to(torch.int32))
# 等价于:req_to_token[req_pool_indices, locs] = out_cache_loc ← GPU scatter write
forward 为什么需要等这些 GPU op?
- ②:gather 读上一步的 slot 位置,用于 attention 时定位已有 KV
- ⑤:scatter write 把新 token 的 slot 写进翻译表,forward 结束后要往这个 slot 写 KV
9.3 设计直觉
为什么 free_slots 在两个 pool 里类型不同?
ReqToTokenPool.free_slots = Python list
→ 原因:每次 alloc 数量就是 batch_size(数量可预测),Python list 的 pop/append 够用
→ list.pop() 是 O(1),CPU 上直接操作
MHATokenToKVPool.free_slots = GPU tensor
→ 原因:KV slot 的数量可达几百万,每次 alloc 可能是几千个(prefill 长序列)
→ GPU tensor 的 slice 操作(free_slots[:n])远快于 CPU list,且结果直接是 GPU tensor
十、设计直觉总结
- schedule_stream 不只是"CPU stream":它是 CPU 调度逻辑跑在的 CUDA stream context,凡在其中发出的 GPU op(tensor 加法、scatter write、Triton kernel)都入队到这个 stream。
- wait_stream = 跨 stream 的数据依赖声明:forward_stream 依赖 schedule_stream 写好的
req_to_token、seq_lens等,用wait_stream在 GPU 侧声明这个依赖,不阻塞 CPU。 - non_blocking + Event = 异步数据搬运:GPU→CPU 拷贝不阻塞,CPU 在需要数据时才
synchronize,这是唯一的 CPU 等 GPU 的阻塞点。 - FutureMap = 解耦 GPU token 生产与 CPU req 消费:GPU 产生 token、CPU 更新 req 对象,两者时序不同,通过 FutureMap 作为中间层解耦。
- 非 in-place = 防跨 stream 数据竞争:两个 stream 共享同一个 GPU tensor 时,有修改的一方必须创建新 tensor,不能原地改。
问答复盘
Q1: forward_stream.wait_stream(schedule_stream) 是 CPU 在等还是 GPU 在等?等的是什么?
A:GPU 在等,不阻塞 CPU。等的是 schedule_stream 上的 GPU ops:req_to_token scatter write(KV cache 地址表更新)、seq_lens + 1 加法(新 tensor 创建)、Prefill 时的 H2D input_ids 拷贝等。这些都是 forward 的前置数据依赖。
Q2: 为什么 future_map.resolve_future() 要放在 forward_stream 上?
A:因为它要读上一批 forward 产生的 next_token_ids(GPU tensor)。放在同一个 forward_stream 上,保证上批 forward 执行完(stream 内顺序执行)再读。如果放 schedule_stream,需要额外 wait_stream(forward_stream) 同步。
Q3: seq_lens += 1 为什么在 overlap 模式下不能用?
A:seq_lens 已打包进上批 model_worker_batch,forward_stream 上的 GPU 正在读它。如果 schedule_stream 上做 in-place 修改,两个 stream 对同一 GPU tensor 并发读写、无同步 → 数据竞争。必须 = + 1 创建新 tensor,旧 tensor 由 batch_record_buf 保持引用到 GPU 用完。
Q4: process_batch_result 里最慢的一步是什么?
A:copy_done.synchronize(),唯一的 CPU 阻塞等待点,等 GPU→CPU 的 PCIe 传输完成。开了 logprob 返回时数据量变大,这一步更慢。
Q5: 为什么连续 prefill 会影响 TTFT?
A:overlap 模式下 batch N-1 的 process_batch_result 被推到 batch N 启动后执行。如果 N-1 是 prefill,它的第一个 output token 要等 N 启动之后才处理完发给客户端,比串行模式多延迟了一个 iteration 的 GPU forward 时间。
Q6: FutureMap 解决了什么根本矛盾?
A:overlap 要求 run_batch(N+1) 在 process_batch_result(N) 之前启动,但 N+1 的 input_ids 来自 N 的输出。FutureMap 让 GPU 直接从 N 的 next_token_ids GPU tensor 拿数据(通过 resolve_future),绕过 CPU 侧的 req.output_ids 更新,解开这个循环依赖。
Q7: copy_done(N).synchronize() 会阻止 prepare(N+1) 与 forward(N) 并发吗?
A:不会。一次 iteration 的 CPU 顺序是 prepare(N) → run_batch(N) → process(N-1),sync 在末尾。所以:
- iter K+1 开头执行
prepare(N+1)时,GPU 正在跑forward(N)(iter K 入队) sync(N)在 iter K+1 的process(N)里,是 iter K+1 的最后一步prepare(N+1)比sync(N)早运行,完全不被阻塞
正确并发关系:prepare(N+1) ∥ forward(N)。
Q8: copy_done.synchronize() 只是等数据吗?它还有什么作用?
A:它同时是 pipeline 深度的天然节流阀。每次 iteration 末尾,CPU 必须等上一批结果落地才能进入下一轮,把 lookahead 天然限制在 1 step。如果没有这个 sync,CPU 会无限超前 launch,导致:KV cache 从不释放(check_finished 不运行)→ OOM;req 状态从不更新 → 下一批 input_ids 全错;客户端永远收不到输出。