vLLM Breakable CUDA Graph

本文最后更新于 2026年6月7日

最近笔者在 vLLM 引入了 breakable cudagraph: vllm-project/vllm#42304,核心入口是eager_break_during_capture这个装饰器。它的思路和现有的 piecewise cudagraph 完全不同,笔者觉得挺有意思的,所以写一篇介绍一下它是怎么实现的。

背景:为什么 cudagraph 需要被“打断”

CUDA Graph 的好处是把一连串 kernel launch 录制下来,之后 replay 时几乎没有 host 端开销, 对 decode 这种 launch-bound 的场景收益很大。

但是 LLM 里有一类算子不适合被 capture 进 graph:attention

所以我们希望:把 attention 排除在 cudagraph 之外,让它 eager 执行,其余部分照常 capture。

vLLM 原本的做法是 piecewise cudagraph:依赖 torch.compile,在 FX graph 层面按 splitting_ops(也就是 attention 这些算子) 把整个模型切成很多段,每段单独编译、单独 capture cudagraph,attention 节点夹在段与段之间 eager 跑。

breakable cudagraph 是另一条路(灵感来自 sgl-project/sglang#19102):

不在编译期切图,而是用一个 capture context 驱动整个 forward,在 dispatcher 层拦截 attention 临时结束当前的 stream capture,把这个 op eager 执行掉,再开一段新的 capture 接着录。

一句话总结对比:

  • piecewise:编译期静态切图,靠 torch.compile 的 FX split
  • breakable:运行期动态打断,靠 CUDA stream capture 的 begin/end

后者最大的好处是完全不依赖 torch.compile

整体结构

整个机制可以拆成三块:

  1. eager_break_during_capture:装饰在 attention 这类 custom op 上的“断点”装饰器
  2. BreakableCUDAGraphCapture:真正干活的 capture context,负责分段 capture + 记录 eager 段 + replay
  3. BreakableCUDAGraphWrapper:套在整个 model 外面的 wrapper,对标原来的 CUDAGraphWrapper,负责按 batch 形状缓存与调度

断点装饰器

eager_break_during_capture这个装饰器贴在那些“不能被录进 graph”的 op 上。最典型的就是 attention:

1
2
3
4
5
# vllm/model_executor/layers/attention/mla_attention.py
@eager_break_during_capture
@maybe_transfer_kv_layer
def unified_mla_attention_with_output(q, kv_c_normed, k_pe, output, layer_name, ...):
...

eager_break_during_capture返回一个 wrapper。wrapper 在运行时判断当前要不要打断:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@functools.wraps(fn)
def wrapper(*args, **kwargs):
capture = BreakableCUDAGraphCapture.current()
if capture is None:
return fn(*args, **kwargs) # 没有活跃 capture,正常 eager 执行
if not capture._capturing:
return fn(*args, **kwargs)
if is_forward_context_available():
mode = get_forward_context().cudagraph_runtime_mode
if mode == CUDAGraphMode.FULL:
return fn(*args, **kwargs) # FULL graph 模式不需要打断

# 真正在 capture 中:弱引用参数,交给 capture 处理
weak_args = tuple(
weak_ref_tensor(a) if isinstance(a, torch.Tensor) else a for a in args
)
weak_kwargs = {
k: weak_ref_tensor(v) if isinstance(v, torch.Tensor) else v
for k, v in kwargs.items()
}
return capture.add_eager(lambda: fn(*weak_args, **weak_kwargs))

几个值得注意的点:

为什么要 weak-ref 参数? 因为 replay 用的 lambda 会一直持有这些参数。如果是强引用, 就会把 cudagraph memory pool 里对应的 slot 永久 pin 住,导致不同 batch descriptor 之间没法复用内存。 而 cudagraph pool 本身已经拥有这块内存,replay 时这些 tensor 地址是稳定的,所以用 weak ref 去 deref 是安全的。

FULL 模式直接放行。 如果 forward context 说当前是 CUDAGraphMode.FULL(整个 forward 录一张大图), 那就不应该打断——breakable 解决的是 PIECEWISE 场景。

BreakableCUDAGraphCapture:分段 capture 的核心

这是真正干活的类。它是 thread-local 的,一个线程同时只能有一个活跃 capture。

它维护一个关键的数据结构:self.segments,一个 zero-arg callable 列表。 列表里每个元素要么是某段 cudagraph 的 replay 方法,要么是一个 eager 段的函数。replay 时按顺序挨个调用就行。

最后的效果

1
2
3
4
capture = BreakableCUDAGraphCapture(pool=self.graph_pool)
with capture:
output = self.runnable(*args, **kwargs)
output = weak_ref_tensors(output)

我们希望最后设计出来的API还是和torch提供的with torch.cuda.graph(g):类似,下面讲讲如何做到。

分段

__enter__ 时开第一段,__exit__ 时结束最后一段:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def _begin_segment(self):
g = torch.cuda.CUDAGraph()
if self.pool is not None:
g.capture_begin(pool=self.pool)
else:
g.capture_begin()
self._current_graph = g
self._capturing = True

def _end_segment(self):
if not self._capturing:
return
self._current_graph.capture_end()
self.segments.append(self._current_graph.replay) # 把这段的 replay 记下来
self._current_graph = None
self._capturing = False

注意这里是直接调 CUDAGraph.capture_begin() / capture_end(), 而不是用 torch.cuda.graph() 这个上下文管理器。

eager 打断

核心就是 add_eagereager_break_during_capture 最后调的就是它:

1
2
3
4
5
6
def add_eager(self, fn):
self._end_segment() # 1. 结束当前 graph 段
result = fn() # 2. 在 capture stream 上 eager 执行
self.segments.append(fn) # 3. 把 fn 本身记进 segments,留给 replay
self._begin_segment() # 4. 开新的一段接着录
return result

第 1 步 self._end_segment()

先结束 capture,把 graph 记录下来

第 2 步 result = fn()

capture 结束后,fn()(也就是被弱引用包好的 attention 调用: fn = lambda: attention(*weak_args, **weak_kwargs))就以普通 eager 模式跑一遍。

第 3 步 self.segments.append(fn)

  • graph 段 append 的是 cudagraph.replay——replay 时重放录好的 kernel,输入输出都在固定地址
  • eager 段 append 的是 fn 本身——replay 时重新跑一遍这个 Python 函数

第 4 步 self._begin_segment():开新的一段接着录。

attention 跑完了,后面还有 MLP / norm 等可以继续录进 graph 的算子,所以马上重新 capture_begin(), 开一段新的 graph,等下一个 attention 再来打断。如此往复,整个 forward 就被切成 [graph段, eager段, graph段, eager段, ...] 这样交替的序列。

配一张时序图大概是这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
capture 开始

├─ [graph segment 0] ← MLP / norm / ... 全部录进 graph 0

├─ attention 调用 → add_eager
│ end graph 0 → eager 跑 attention → begin graph 1

├─ [graph segment 1] ← 下一层的 MLP / norm ...

├─ attention 调用 → add_eager
│ end graph 1 → eager 跑 attention → begin graph 2

├─ [graph segment 2]

capture 结束(end 最后一段)

segments = [g0.replay, attn0, g1.replay, attn1, g2.replay]

可以看到 segments 最终是 graph 段和 eager 段交替排列的列表,replay 时按这个顺序挨个跑就还原了整个 forward。

replay

replay 简单到不能再简单:

1
2
3
def replay(self):
for r in self.segments:
r()

graph 段就是 cudagraph.replay(),eager 段就是重新跑一遍那个 fn。注意 replay 不返回值—— 所有的数据流都通过 static output buffer 传递,调用方拿的是固定地址的 output tensor。

总结

breakable cudagraph 把“在 attention 边界处打断 cudagraph”这件事,从编译期的 FX 切图搬到了运行期的 stream capture 打断

  • eager_break_during_capture 贴在 attention 这类 op 上当断点,capture 中遇到就 add_eager
  • BreakableCUDAGraphCapture 把整个 forward 录成 [graph段, eager段, graph段, ...] 的 callable 列表,replay 时顺序跑
  • BreakableCUDAGraphWrapper 按 batch 形状缓存 capture,对标原来的 CUDAGraphWrapper

好处则是彻底摆脱了 torch.compile,对于 DeepSeek V4 这类模型不用再走那条又长又重的编译管线。