vLLM Breakable CUDA Graph
本文最后更新于 2026年6月7日
最近笔者在 vLLM 引入了 breakable cudagraph: vllm-project/vllm#42304,核心入口是eager_break_during_capture这个装饰器。它的思路和现有的
piecewise cudagraph
完全不同,笔者觉得挺有意思的,所以写一篇介绍一下它是怎么实现的。
背景:为什么 cudagraph 需要被“打断”
CUDA Graph 的好处是把一连串 kernel launch 录制下来,之后 replay 时几乎没有 host 端开销, 对 decode 这种 launch-bound 的场景收益很大。
但是 LLM 里有一类算子不适合被 capture 进 graph:attention。
所以我们希望:把 attention 排除在 cudagraph 之外,让它 eager 执行,其余部分照常 capture。
vLLM 原本的做法是 piecewise cudagraph:依赖
torch.compile,在 FX graph 层面按
splitting_ops(也就是 attention 这些算子)
把整个模型切成很多段,每段单独编译、单独 capture cudagraph,attention
节点夹在段与段之间 eager 跑。
breakable cudagraph 是另一条路(灵感来自 sgl-project/sglang#19102):
不在编译期切图,而是用一个 capture context 驱动整个 forward,在 dispatcher 层拦截 attention 临时结束当前的 stream capture,把这个 op eager 执行掉,再开一段新的 capture 接着录。
一句话总结对比:
- piecewise:编译期静态切图,靠
torch.compile的 FX split - breakable:运行期动态打断,靠 CUDA stream capture 的 begin/end
后者最大的好处是完全不依赖 torch.compile。
整体结构
整个机制可以拆成三块:
eager_break_during_capture:装饰在 attention 这类 custom op 上的“断点”装饰器BreakableCUDAGraphCapture:真正干活的 capture context,负责分段 capture + 记录 eager 段 + replayBreakableCUDAGraphWrapper:套在整个 model 外面的 wrapper,对标原来的CUDAGraphWrapper,负责按 batch 形状缓存与调度
断点装饰器
eager_break_during_capture这个装饰器贴在那些“不能被录进
graph”的 op 上。最典型的就是 attention:
1 | |
eager_break_during_capture返回一个 wrapper。wrapper
在运行时判断当前要不要打断:
1 | |
几个值得注意的点:
为什么要 weak-ref 参数? 因为 replay 用的 lambda 会一直持有这些参数。如果是强引用, 就会把 cudagraph memory pool 里对应的 slot 永久 pin 住,导致不同 batch descriptor 之间没法复用内存。 而 cudagraph pool 本身已经拥有这块内存,replay 时这些 tensor 地址是稳定的,所以用 weak ref 去 deref 是安全的。
FULL 模式直接放行。 如果 forward context 说当前是
CUDAGraphMode.FULL(整个 forward 录一张大图),
那就不应该打断——breakable 解决的是 PIECEWISE 场景。
BreakableCUDAGraphCapture:分段
capture 的核心
这是真正干活的类。它是 thread-local 的,一个线程同时只能有一个活跃 capture。
它维护一个关键的数据结构:self.segments,一个
zero-arg callable 列表。 列表里每个元素要么是某段
cudagraph 的 replay 方法,要么是一个 eager 段的函数。replay
时按顺序挨个调用就行。
最后的效果
1 | |
我们希望最后设计出来的API还是和torch提供的with torch.cuda.graph(g):类似,下面讲讲如何做到。
分段
__enter__ 时开第一段,__exit__
时结束最后一段:
1 | |
注意这里是直接调 CUDAGraph.capture_begin() /
capture_end(), 而不是用 torch.cuda.graph()
这个上下文管理器。
eager 打断
核心就是
add_eager,eager_break_during_capture
最后调的就是它:
1 | |
第 1 步 self._end_segment()
先结束 capture,把 graph 记录下来
第 2 步 result = fn()
capture 结束后,fn()(也就是被弱引用包好的 attention
调用:
fn = lambda: attention(*weak_args, **weak_kwargs))就以普通
eager 模式跑一遍。
第 3 步 self.segments.append(fn)
- graph 段 append 的是
cudagraph.replay——replay 时重放录好的 kernel,输入输出都在固定地址 - eager 段 append 的是
fn本身——replay 时重新跑一遍这个 Python 函数
第 4 步
self._begin_segment():开新的一段接着录。
attention 跑完了,后面还有 MLP / norm 等可以继续录进 graph
的算子,所以马上重新 capture_begin(), 开一段新的
graph,等下一个 attention 再来打断。如此往复,整个 forward 就被切成
[graph段, eager段, graph段, eager段, ...]
这样交替的序列。
配一张时序图大概是这样:
1 | |
可以看到 segments 最终是 graph 段和 eager
段交替排列的列表,replay 时按这个顺序挨个跑就还原了整个 forward。
replay
replay 简单到不能再简单:
1 | |
graph 段就是 cudagraph.replay(),eager
段就是重新跑一遍那个 fn。注意 replay 不返回值——
所有的数据流都通过 static output buffer 传递,调用方拿的是固定地址的
output tensor。
总结
breakable cudagraph 把“在 attention 边界处打断 cudagraph”这件事,从编译期的 FX 切图搬到了运行期的 stream capture 打断:
eager_break_during_capture贴在 attention 这类 op 上当断点,capture 中遇到就add_eagerBreakableCUDAGraphCapture把整个 forward 录成[graph段, eager段, graph段, ...]的 callable 列表,replay 时顺序跑BreakableCUDAGraphWrapper按 batch 形状缓存 capture,对标原来的CUDAGraphWrapper
好处则是彻底摆脱了 torch.compile,对于 DeepSeek V4
这类模型不用再走那条又长又重的编译管线。