vllm-torch-compile-integration
本文最后更新于 2026年3月19日
vLLM的torch.compile集成在网上资料非常少,官方写的有一篇博客 和文档
但是很多相关的介绍都是语焉不详,只有宏观上的介绍,没有具体的细节。
笔者正在写这篇的时候,BBuf 发了一篇类似的文章[https://zhuanlan.zhihu.com/p/1955402895890560120]
所以笔者决定自己写一篇。
torch.compile
介绍
这个网上的介绍已经非常多了,在此就不再赘述。
需要明确的是,torch.compile包括两个重要的组成部分:Dynamo和Inductor
整个编译过程大概是下面这样,具体可以参考这篇博客
- 使用 TorchDynamo 来 trace 函数
- xxxx
- Inductor 运行 post_grad_passes,优化计算图。
- xxxx
自定义backend
torch.compile提供了使用自定义backend的的接口,具体参考这篇文档
backend函数需要符合这样的接口 1
(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> CallableDynamo
调用,接收一个fx graph,后端需要返回优化后的、等价的fx graph。
1 | |
自定义pass
本节主要参考了下面两篇文档
- 官方文档:Writing Graph Transformations on ATen IR
- 官方教程: Building a Convolution/Batch Norm fuser with torch.compile
当torch.compile官方提供的pass不满足要求的时候可以选择添加自定义pass。
官方总结了这么几种添加pass的目的
- Axis A: 1. Creating one-to-X mapping (eg. decomposition) 2. Creating many-to-one mapping (eg. fusion)
- Axis B: 1. Doing forwards iteration (eg. shape propagation) 2. Doing backwards iteration (eg. dead code elimination)
- Axis C: 1. Dependent on local node information (eg. out-variant conversion) 2. Dependent on global graph information (eg. memory planning)
修改计算图最直接的方式就是直接操作生成的fx graph
1
2
3
4def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
node.target = torch.ops.aten.mul.Tensorreplace_add_with_mul会把所有的add操作都替换为mul。
当然直接操作计算图是比较繁琐的,所以pytorch也提供了一些比较方便的
helper utilities。
我目前接触过的,大概有三种pytorch提供的自定义pass的方式
torch.fx.Transformersubgraph_rewriterpattern matcher
torch.fx.Transformer
1 | |
1 | |
这种方式只能比较方便地对一个 node 做变换。
subgraph_rewriter
如果想要实现 fusion,就需要 X->one 或者 X->Y 的能力
1 | |
pattern matcher
看起来pattern matcher和subgraph_rewriter非常相似,我也没搞明白他们俩的区别和联系是什么,为什么要搞两套api
具体可以参考这篇教程,有一个完整的例子。
1 | |
这个 api 我用的时候还是有些问题,比如说 pattern
的参数不能是常量,一定要是 tensor。
所以如果某个参数是常量,并且有多个可能的取值,那就只能针对每个取值都生成一个pass,取值有无限多个的情况我也不知道怎么办。
Pass Manager
当后端有了多个自定义pass的时候,他们会统一在pass manager处理。
示意代码如下: 1
2
3
4
5
6
7
8
9
10
11class PassManager:
def __init__(self):
self.passes: list[InductorPass] = []
def __call__(self, graph: fx.Graph):
for pass_ in self.passes:
pass_(graph)
VllmInductorPass.dump_prefix += 1
self.post_cleanup(graph)
def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass)
self.passes.append(pass_)
vLLM如何利用torch.compile
上面那部分更偏torch.compile本身,这一节专门讲 vLLM
源码是怎么把这套机制用起来的。
@support_torch_compile给模型类注入编译能力TorchCompileWithNoGuardsWrapper负责真正调用torch.compileCompilationConfig.init_backend()在VLLM_COMPILE模式下返回VllmBackendVllmBackend负责 cache、图切分、按 range 预编译、挂载 post-grad passesPiecewiseBackend在运行时按 token 数选择合适的已编译子图
入口:@support_torch_compile
它主要做了几件事:
- 推断或者接收
dynamic_arg_dims - 把
TorchCompileWithNoGuardsWrapper动态加到类的基类里 - 在新的
__init__里根据CompilationMode、enable_if、ignore_torch_compile等条件决定要不要真的编译 - 在第一次调用时给输入打 dynamic shape 标记
- 在 AOT compile 打开时尝试从磁盘直接加载已编译产物
简化一下,入口大概是这样:
1 | |
这里有两个很关键的点。
第一,vLLM 不是完全依赖 Dynamo
自己去猜哪些维度是动态的,而是会在第一次编译前显式调用
torch._dynamo.mark_dynamic / mark_unbacked
去标记输入维度。
第二,vLLM 会在 tracing 期间收集 forward
相关源码文件。做法有点“黑科技”:它会 patch
InliningInstructionTranslator.inline_call_,把 Dynamo
内联到的函数所在文件都记下来,后面把这些文件内容也纳入 compile cache 的
hash。这样一旦模型相关 Python 源码改了,cache
就会失效并触发重新编译。
TorchCompileWithNoGuardsWrapper:核心目标是
trace once
从实现目标上看,vLLM 并不只是“套了一层
torch.compile”,而是在围绕下面几个目标组织运行时:
- 尽量只做一次 Dynamo tracing
- 尽量把 guard 去掉,避免运行时开销
- 在需要的时候自己接管不同 shape range 的编译与调度
TorchCompileWithNoGuardsWrapper.__init__
里最终还是会调用 torch.compile:
1 | |
但它会根据配置额外做几件事:
- 非
STOCK_TORCH_COMPILE模式下,通过guard_filter_fn跳过大部分 guards UNBACKEDdynamic shapes 时改走check_invariants_and_forward- 可选启用 AOT compile
- 可选用 bytecode hook 保存 Dynamo 变换后的 code object,后续直接 dispatch 到编译后的 bytecode
如果还是有点抽象,可以把它理解成“把 torch.compile
拆成了两层”:
- Dynamo 这一层:尽量只 trace 一次,别反复回到 Python 重新分析字节码
- 后端执行这一层:可以继续按不同 shape range 准备不同的已编译版本
举个简化例子。假设有这样一个模块:
1 | |
如果直接写:
1 | |
PyTorch 默认会保留一批 guards。只要它判断新输入不满足旧 guards,就可能重新走一遍 Dynamo tracing,生成新的编译结果。
而 vLLM 的思路更像是:
1 | |
后面 batch size 变了,vLLM 希望避免“因为 guards 不匹配而重新回到
Dynamo 重新 trace Python”。如果需要针对不同 shape
做不同优化,它更倾向于在 backend 这一层处理,例如把
[1, 8]、[9, 16] 编成不同 range
的子图,然后运行时只做 dispatch。
所以 TorchCompileWithNoGuardsWrapper
的职责不是自己完成所有优化,而是先把最上面那层“trace once”稳住,把后面的
shape-specialization、cache、piecewise compile 交给
VllmBackend 这一层。
CompilationMode.VLLM_COMPILE
会返回 VllmBackend
现在 CompilationConfig.init_backend()
里,三种模式的分工已经很清楚了:
STOCK_TORCH_COMPILE/DYNAMO_TRACE_ONCE:直接返回普通 PyTorch backend 或自定义 backendVLLM_COMPILE:返回VllmBackend(vllm_config)
VllmBackend 在做什么
VllmBackend
是这一节最关键的类。它的工作可以概括成下面几步。
1. 配置 cache
它会综合下面这些因素计算 cache key:
- 环境变量相关因素
VllmConfig的哈希- compiler 自身相关因素
- Dynamo tracing 过程中记录到的源码文件内容
然后把 cache 放到类似下面的目录里:
1 | |
并且 CompilerManager 会把缓存索引存成一个 Python
文件,key 形如:
1 | |
2. 把 post-grad pass manager 接到 Inductor 上
VllmBackend.configure_post_pass() 会创建
PostGradPassManager,再把它挂到当前平台对应的 inductor hook
上。这样之后每个 FX graph 在进 Inductor 时,都会先跑 vLLM 自己的
post-grad passes。
这个阶段还支持把用户自己通过
inductor_compile_config[pass_key] 传入的 pass,追加到 pass
manager 里。
这里的“挂载 post-grad passes”可以直接理解成:把 vLLM 自己的图变换函数注册到 Inductor 的一个回调槽位里。
流程大概是这样:
1 | |
当 Inductor 真正开始处理这个 fx_graph 时,会先调用:
1 | |
vLLM里也就是:
1 | |
跑完这些 pass 之后,Inductor 才继续做后面的 lowering 和 codegen。
“post-grad”这个名字本身比较容易让人误会,好像是在 backward 之后再处理一遍图。这里其实主要是沿用 PyTorch/Inductor 里的术语,表示“这是挂在 Inductor 的 post-grad custom pass hook 上的 pass”。对于 vLLM 这种 inference 场景,更好理解的说法是:
- Dynamo 先把 Python 代码变成 FX graph
- AOTAutograd / Inductor 会拿到一个 functionalized graph
- vLLM 在 Inductor 正式生成 kernel 之前,先对这张 graph 做一轮自己的改写
所以“挂载 post-grad passes”本质上就是:在 Inductor 开始代码生成之前,先插入一层 vLLM 自己的 FX graph rewrite。
3. 决定是在 FX 层切图,还是让 Inductor 自己 partition
如果 use_inductor_graph_partition=False,vLLM 会在 FX
层按 splitting_ops 手工切图:
1 | |
如果 use_inductor_graph_partition=True,那就不在 FX
层切,而是把 splitting_ops 注册成 Inductor partition
rule,让 partition 在更晚的 codegen 阶段发生。这样 custom passes
看到的是更完整的图,更容易做 fusion。
4. 为每个子图创建
PiecewiseBackend
PiecewiseCompileInterpreter
会遍历切出来的子模块,把真正需要编译的 submodule 替换成
PiecewiseBackend。PiecewiseBackend
会在初始化时把所有 compile range 先编完:
- 对
compile_sizes里的单点 size 做静态 shape 编译 - 对
compile_ranges_endpoints生成的 range 做通用 shape 编译
运行时则根据第一个 symbolic shape 参数来 dispatch:
1 | |
也就是说,vLLM 不是等到线上跑到某个 batch size 才临时编译,而是更偏向“冷启动时把需要的几个 range 预编译好,运行时只做选择”。
5. 按需叠加 piecewise cudagraph
如果当前配置需要 piecewise
cudagraph,wrap_with_cudagraph_if_needed() 会再给每个
PiecewiseBackend 外面包一层 cudagraph wrapper。
所以这套结构可以理解成:
- Dynamo 只 trace 一次
VllmBackend把图拆成若干段- 每段由
PiecewiseBackend预编译多个 shape 版本 - 每段外面还可以再套 cudagraph
compile_sizes
和 compile_ranges_endpoints
这里有一个很值得单独强调的点。
vLLM 里同时存在两种“编译粒度”:
compile_sizes:对某些固定 token 数做精确编译,例如[1, 2, 4, 8]compile_ranges_endpoints:对一个区间做通用编译,例如[1, 8]、[9, 16]
而且单点 size 的优先级更高。如果同时存在 [1, 8] 和 size
4,那运行到 4 时会优先走单点 size
对应的已编译图,而不是通用 range 图。
pass manager 顺序
当前实现里的实际执行顺序是:
PassConfig打开的内建 passes- 用户额外注入的 custom pass
PostCleanupPassFixFunctionalizationPass
这里最好不要把这四项都理解成同一类东西。更准确地说,它们分成两组:
- 前两项是“真正做图变换的 pass”
- 后两项是“收尾 pass”
先看前两项。
1. PassConfig
打开的内建 passes
这一类 pass 是 vLLM 自己实现、自己维护的 pass。它们都在
vllm/compilation/passes/ 下面,由
PostGradPassManager.configure() 根据
PassConfig 的开关决定要不要加入
self.passes。
例如源码里会按条件加入这些 pass:
NoOpEliminationPassSequenceParallelismPassAsyncTPPassAllReduceFusionPassRMSNormQuantFusionPassActivationQuantFusionPassAttnFusionPassRopeKVCacheFusionPassQKNormRoPEFusionPass
2. 用户额外注入的 custom pass
这一项不是 vLLM 默认内建的 pass,而是给使用者留的扩展口。
VllmBackend.configure_post_pass() 会检查当前平台对应的
pass_key。如果用户在
inductor_compile_config[pass_key] 里额外放了一个
pass,那么这个 pass 也会被追加到 PostGradPassManager
里。
可以把它理解成:
1 | |
也就是说,vLLM 先把自己那套 pass 配好,然后再给用户一个“在同一个 hook 里再插一个 pass”的机会。
3. PostCleanupPass
从这个 pass 开始,就不是“做某个具体优化”的 pass 了,而是收尾。
PostCleanupPass 做的事情很简单:
- 把 graph 重新做一遍稳定的拓扑排序
- 删除已经没有用户的 dead code
为什么需要它?因为前面的 pattern matcher / graph rewrite 在改图之后,不保证图还是干净、拓扑有序的。比如某个 pattern 被替换掉之后,原来的一些中间节点可能已经没人用了;再比如替换后的节点顺序可能不够规整。
所以 PostCleanupPass 的作用可以概括成一句话:
前面的 pass 负责“改对图”,它负责“把改完的图整理干净”。
4.
FixFunctionalizationPass
这个 pass 也是收尾,但它做的事情比 PostCleanupPass
更特殊。
在 Inductor 这一套流水线里,很多自定义算子在图里会先以
auto_functionalized(...) 的形式出现。这样做有利于 tracing
和 pattern matching,因为 functionalized graph 更规整、更容易改写。
但问题是,functionalized 形式往往会引入额外的中间 tensor、copy,或者把“原地更新”的语义改写成一串看起来更绕的函数式节点。如果直接拿这种图去做最终代码生成,性能和图结构都不理想。
FixFunctionalizationPass
的作用就是在最后把这些关键节点重新
defunctionalize,恢复成更接近真实执行语义的形式。它会处理的对象包括:
- rotary embedding
fused_add_rms_normrms_norm_dynamic_per_token_quantsilu_and_mulfused_qk_norm_ropefused_rope_and_unified_kv_cache_update
它必须最后执行,原因也很简单:
- 前面的很多 fusion / matcher 更依赖 functionalized graph 的规整形态
- 一旦 defunctionalize 完,图就不再适合继续跑普通的 cleanup / DCE 了
所以顺序上可以把这四项理解成:
- 先跑 vLLM 自己的优化 pass
- 再跑用户自己额外插进来的 pass
- 用
PostCleanupPass把图清理干净 - 最后用
FixFunctionalizationPass把 graph 从“便于改写的形式”收回到“便于执行的形式”
vLLM 有哪些 pass
vllm/config/vllm.py 会按 optimization level 给
PassConfig 注入默认值:
O0:基本都关O1:会按条件打开fuse_norm_quant、fuse_act_quant、fuse_act_padding、fuse_rope_kvcacheO2/O3:进一步按模型类型、量化方式、平台能力打开fuse_attn_quant、enable_sp、fuse_gemm_comms、fuse_allreduce_rms等
下面按类别挑几个比较重要的 pass 讲一下。
1. NoOpEliminationPass
这个 pass 还是非常重要,但它现在已经不是一个孤立的小优化,而是很多 fusion pass 的前置清理步骤。
它会删除三类 no-op:
- reshape 链里中间多余的 reshape
- 输出 shape 和输入 shape 等价的 reshape
- 输出 shape 和输入 shape 等价的 slice / slice_scatter
例如:
1 | |
可以直接化简成:
1 | |
源码里的注释明确写了:RMSNorm-quant fusion 依赖这个
pass,因为某些路径里 apply_fp8_linear 会额外引入冗余
reshape。SequenceParallelismPass
也会在自己做完替换后,调用一次 NoOpEliminationPass
清掉临时插进去的 slice。
2.
RMSNormQuantFusionPass
这个 pass 负责把
rms_norm + quantfused_add_rms_norm + quant
替换成单个 fused custom op。
而且它支持的量化模式也比较多,至少包括:
- static fp8 quant
- dynamic per-token fp8 quant
- group fp8 quant(CUDA 上还有不同 group shape、scale layout、e8m0 / TMA 对齐组合)
它的核心思路还是
pattern matcher,但模式数量已经很多了。比如同一个
epsilon,它会把:
- 普通
RMSNorm FusedAddRMSNorm- 静态量化
- 动态量化
- group quant
这些组合全部注册成 pattern。
3.
ActivationQuantFusionPass
这个 pass 做的是 MLP 激活这边的 fusion,把:
silu_and_mul + fp8 quantsilu_and_mul + nvfp4 quant
替换成 fused op,例如 silu_and_mul_quant。
相比
RMSNormQuantFusionPass,它的结构更简单一些,但思路完全一样:先匹配
unfused pattern,再用 auto_functionalized 的 fused custom
op 替换。
4. AttnFusionPass
这个 pass 是把 attention 后面的量化 fuse 进 attention 本身。
也就是说它不是简单地做“attention 节点旁边的小修小补”,而是直接尝试把:
attentionattention output quant
合并成一个支持 fused output quant 的 attention kernel。
这里有一个比较有意思的实现细节:它不是写一个完全泛化的 pattern
去匹配所有 attention,而是先从
CompilationConfig.static_forward_context 里把具体的
Attention 层找出来,再按 layer 注册 pattern。这样做的原因是
attention 层里有一些字符串参数和 layer-specific 元信息,直接 wildcard
不太现实。
5.
SequenceParallelismPass
这个 pass 不是单纯“把通信算子 fuse 一下”,而是先改图的并行策略。
先约定一下张量形状。下面把某一层 hidden states 记成:
1 | |
在 tensor parallel 下,Transformer block 里经常会有一类“row parallel”线性层。它的特点是:
- 每个 rank 只持有一部分权重
- 每个 rank 先各自算出一个 partial output
- 然后通过
AllReduce(sum)把这些 partial output 求和,恢复成完整的 hidden states
所以在图上常见的结构会是:
1 | |
如果把它放回到一个更具体的 LLM block 里,可以把它想成下面这个局部结构:
1 | |
这里的关键点在于:RMSNorm 是按 token 独立做的。
更准确地说,RMSNorm 对每个 token 的 hidden vector
单独计算均方根并归一化。也就是说它在 hidden dimension 上做归一化,但不同
token 之间互不依赖。于是:
- 它需要每个 token 的完整 hidden dimension
- 但它不需要“所有 token 必须都在同一个 rank 上”
这正是 sequence parallelism 能成立的原因。
它会把:
1 | |
改写成:
1 | |
如果后面还接了 quant,也会一起考虑。
这里顺便解释一下
ReduceScatter。它可以理解成“AllReduce 的前半段
+ 把结果切开分给各个 rank”。
例如两个 rank 上分别有:
1 | |
先做 reduce(sum) 得到:
1 | |
再 scatter 成两半:
1 | |
也就是说,AllReduce 是“每个 rank 都拿到完整结果”,而
ReduceScatter 是“每个 rank 只拿到结果的一部分”。
这个改写为什么是对的,可以直接从张量形状来理解。
假设 AllReduce 之后得到的完整输出是
[T, H]。如果 TP world size = 2,那么
ReduceScatter(dim=0) 做的事情可以理解成:
- 先完成原本
AllReduce(sum)该做的求和 - 但不是把完整的
[T, H]都留在每个 rank 上 - 而是把 token 维切开,让 rank0 拿到前一半 token,rank1 拿到后一半 token
于是每个 rank 上拿到的是:
1 | |
注意这里 hidden dimension H 还是完整的,只是 token 维
T 被分片了。
而 RMSNorm 恰好只要求“每个 token 的 hidden vector
是完整的”,并不要求“所有 token 都在当前 rank
上”。所以它可以完全本地执行:
1 | |
如果后面的算子又需要完整的 sequence,再用 AllGather 把
token 维拼回来即可:
1 | |
可以画成下面这样:
1 | |
所以这个 pass 的收益,重点不在于“RMSNorm
本身更快了”,而在于它把一段原本夹在 AllReduce
后面的计算,改成了“通信之后每个 rank
先各算各的”。这样图里就显式出现了:
ReduceScatter- 本地计算
AllGather
后续 AsyncTPPass 才有机会继续识别出:
GEMM + ReduceScatterAllGather + GEMM
这样的模式,并进一步 fuse / overlap compute 和 communication。
换句话说,SequenceParallelismPass
更像是在给后续的通信融合“铺路”。
从性能角度看,它的收益大致有三层。
第一层是减少重复计算。
原来的做法里,AllReduce 之后每个 rank 都拿到完整的
[T, H],于是每个 rank 都会各自做一遍完整的
RMSNorm([T, H])。如果 TP world size 是
N,那这类逐 token 的小算子实际上被重复执行了 N
次。
而改成 sequence parallelism 之后,每个 rank 只持有自己的 token shard
[T/N, H],于是每个 rank 只需要做:
1 | |
也就是说,像 RMSNorm、residual
add、量化这类本来就不需要跨 token 交互的算子,不再需要在每个 rank
上对完整 sequence 重复做一遍。
第二层是降低本地内存带宽压力。
这类算子很多时候不是 compute-bound,而是 memory-bound。把输入从
[T, H] 变成 [T/N, H] 之后:
- 读 hidden states 的流量变少
- 写中间结果的流量变少
- 访问 residual / quant scale 等辅助张量的流量也会一起下降
所以就算不考虑后续 fusion,仅仅把这些小算子下沉到 token shard 上本地执行,也经常能带来可观收益。
第三层也是最重要的一层:给后续 compute + communication 融合创造条件。
单独看 AllReduce 和
ReduceScatter + AllGather,通信量本身并不会凭空消失。真正大的收益通常来自图结构改变之后,后续
pass 可以继续把:
GEMM + ReduceScatterAllGather + GEMM
识别出来,并替换成更底层的 fused collective kernel。这样才能进一步做到:
- 让通信更早开始
- 让通信和计算 overlap
- 减少中间张量落地
所以更准确地说,SequenceParallelismPass
的性能收益来源是:
- 少做重复的逐 token 小算子
- 降低这些小算子的内存访问量
- 为后续
AsyncTPPass提供更适合融合的图结构
因此它本身既有直接收益,也有“铺路型”的间接收益。
这个 pass 还不是无脑开启的,它会看:
- 当前平台是不是 CUDA
- hidden size 是否大到值得做 sequence parallelism
- compile range 是否满足阈值
- 在 piecewise compilation 场景下,当前 size 是否是单点并且能被 TP size 整除
6. AsyncTPPass
这个 pass 可以看成是 SequenceParallelismPass 的后续配套
pass。
它会继续把:
GEMM + ReduceScatterAllGather + GEMMscaled_mm + ReduceScatter
等模式替换成更底层的 fused collective / symmetric memory 算子,比如:
torch.ops.symm_mem.fused_matmul_reduce_scattertorch.ops.symm_mem.fused_all_gather_matmul
所以从 pass 依赖关系上看,可以粗略理解成:
SequenceParallelismPass先把图改造成适合 overlap compute / comm 的样子AsyncTPPass再把这些新结构 fuse 成更高效的 kernel
7. AllReduceFusionPass
这个 pass 走的是另一条路线:不改成
ReduceScatter/AllGather,而是直接把
AllReduce + RMSNormAllReduce + FusedAddRMSNormAllReduce + RMSNorm + Quant
等模式替换成 FlashInfer 提供的 fused kernel。
8. QKNormRoPEFusionPass
这个 pass 把 Q/K 上的 RMSNorm 和 RoPE 合并成一个 fused custom
op:fused_qk_norm_rope。
它要求的信息比普通 pattern 多一些,比如:
head_dimnum_headsnum_kv_headsepsilonis_neox
所以实现上也是先从实际的 Attention 层对象里把这些
metadata 拿出来,再注册 pattern。
另外它前面常常会配合一个 utility pass:
SplitCoalescingPass
这个 pass 会把重复的 split_with_sizes
节点合并掉。因为如果 Q/K/V 是从同一个张量 split
出来的,但图里重复生成了多个 split_with_sizes,后面的
fusion pattern 往往就看不出“它们其实来自同一个 qkv”。
9.
RopeKVCacheFusionPass
这个 pass 主要针对的是:
- rotary embedding
- KV cache update
把这两个步骤合成一个 fused kernel,减少中间 tensor 和 kernel launch。
它的实现还会依赖另外两个 utility passes 先把图整理干净:
SplitCoalescingPassScatterSplitReplacementPass
ScatterSplitReplacementPass
的作用特别适合举例说明。functionalization 之后,rotary embedding
经常会变成:
- 先从
qkv里 split 出q和k - 跑 functionalized 的 rotary op
- 再通过
slice_scatter写回qkv - 然后又 split 一次拿出
q/k/v
如果后面没有别的用户真的需要“写回后的整个 qkv”,那这套
slice_scatter + split
就纯属绕远路。ScatterSplitReplacementPass
会把这段图改回更直接的形式,这样 RopeKVCacheFusionPass
才更容易匹配出融合模式。
另外这个 pass 还会根据 compile_range
控制是否启用,默认只在较小 token
数下生效,因为源码注释里明确写了:它更偏向 small-batch decode,prefill
大 batch 下未必占优。
10.
PostCleanupPass 和
FixFunctionalizationPass
PostCleanupPass 做的事情很朴素:
- stable topological sort
- dead code elimination
原因是 pattern matcher 改图之后,不保证图还是拓扑有序的,也可能留下无用节点。
FixFunctionalizationPass 则更关键。因为 vLLM 里很多
custom op 在 tracing / matching 时会先以
auto_functionalized(...)
的形式出现,但真正执行时如果一直保留 functionalized 版本,往往会引入额外
copy。这个 pass 会把一些重要算子重新 defunctionalize 回去,例如:
- rotary embedding
fused_add_rms_normrms_norm_dynamic_per_token_quantsilu_and_mulfused_qk_norm_ropefused_rope_and_unified_kv_cache_update
源码里甚至专门强调了一句:FixFunctionalizationPass
必须最后执行,而且执行完之后不要再跑 DCE,因为
defunctionalized 图里可能会出现“看起来像 dead
code、但实际上不能随便删”的节点。
小结
vLLM 不是简单地“调用了一次 torch.compile”,而是在 Dynamo
只 trace 一次的前提下,围绕 cache、piecewise compile、compile
ranges、cudagraph 和一组 post-grad custom
passes,搭了一个自己的编译执行框架。
最后放一个vLLM-compile slides