vllm-torch-compile-integration

本文最后更新于 2026年3月19日

vLLM的torch.compile集成在网上资料非常少,官方写的有一篇博客文档

但是很多相关的介绍都是语焉不详,只有宏观上的介绍,没有具体的细节。

笔者正在写这篇的时候,BBuf 发了一篇类似的文章[https://zhuanlan.zhihu.com/p/1955402895890560120]

所以笔者决定自己写一篇。

torch.compile

介绍

这个网上的介绍已经非常多了,在此就不再赘述。

需要明确的是,torch.compile包括两个重要的组成部分:DynamoInductor

整个编译过程大概是下面这样,具体可以参考这篇博客

  • 使用 TorchDynamo 来 trace 函数
  • xxxx
  • Inductor 运行 post_grad_passes,优化计算图。
  • xxxx

自定义backend

torch.compile提供了使用自定义backend的的接口,具体参考这篇文档

backend函数需要符合这样的接口

1
(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable
自定义的后端会被 Dynamo 调用,接收一个fx graph,后端需要返回优化后的、等价的fx graph

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch

def my_custom_backend(gm, example_inputs):
return gm.forward

def f(...):
...

f_opt = torch.compile(f, backend=my_custom_backend)

@torch.compile(backend=my_custom_backend)
def g(...):
...

自定义pass

本节主要参考了下面两篇文档

torch.compile官方提供的pass不满足要求的时候可以选择添加自定义pass

官方总结了这么几种添加pass的目的

  • Axis A: 1. Creating one-to-X mapping (eg. decomposition) 2. Creating many-to-one mapping (eg. fusion)
  • Axis B: 1. Doing forwards iteration (eg. shape propagation) 2. Doing backwards iteration (eg. dead code elimination)
  • Axis C: 1. Dependent on local node information (eg. out-variant conversion) 2. Dependent on global graph information (eg. memory planning)

修改计算图最直接的方式就是直接操作生成的fx graph

1
2
3
4
def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
node.target = torch.ops.aten.mul.Tensor
replace_add_with_mul会把所有的add操作都替换为mul

当然直接操作计算图是比较繁琐的,所以pytorch也提供了一些比较方便的 helper utilities。

我目前接触过的,大概有三种pytorch提供的自定义pass的方式

  • torch.fx.Transformer
  • subgraph_rewriter
  • pattern matcher

torch.fx.Transformer

1
2
3
4
5
6
7
class ReplaceAddWithMul(torch.fx.Transformer):
def call_function(self, target, args, kwargs):
if target != torch.ops.aten.add.Tensor:
return super().call_function(target, args, kwargs)
return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)

transformed_graph_module = ReplaceAddWithMul(graph_module).transform()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class ReplaceAddWithMulSub(torch.fx.Transformer):
"""
Original:
def f(x, y):
return x + y

After pass:
def f(x, y):
z = x * y
return z - y
"""
def call_function(self, target, args, kwargs):
if target != torch.ops.aten.add.Tensor:
return super().call_function(target, args, kwargs)

x, y = args

mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})

transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()

这种方式只能比较方便地对一个 node 做变换。

subgraph_rewriter

如果想要实现 fusion,就需要 X->one 或者 X->Y 的能力

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from torch.fx import subgraph_rewriter
# This is an inplace operation.
def replace_patterns(graph_module):
"""
Original:
def f(x, y):
x = x + y
x = x * y
return x

After pass:
def f(x, y):
return x - y
"""
def pattern(x, y):
x = torch.ops.aten.add.Tensor(x, y)
x = torch.ops.aten.mul.Tensor(x, y)
return x

def replacement(x, y):
return torch.ops.aten.sub.Tensor(x, y)

replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
traced_module, pattern, replacement
)

pattern matcher

看起来pattern matchersubgraph_rewriter非常相似,我也没搞明白他们俩的区别和联系是什么,为什么要搞两套api

具体可以参考这篇教程,有一个完整的例子。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def pattern(x, y):
x = torch.ops.aten.add.Tensor(x, y)
x = torch.ops.aten.mul.Tensor(x, y)
return x

def replacement(x, y):
return torch.ops.aten.sub.Tensor(x, y)

example_inputs = [
torch.randn(1, 4),
torch.randn(1, 1),
]

from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._inductor import config

# Create a pattern matcher pass and register our pattern
patterns = PatternMatcherPass()

register_replacement(
pattern,
replacement,
example_inputs,
pm.fwd_only,
patterns,
)

# Create a custom pass function that applies our patterns
def fusion_pass(graph):
return patterns.apply(graph)

# Set our custom pass in the config
config.post_grad_custom_post_pass = fusion_pass

这个 api 我用的时候还是有些问题,比如说 pattern 的参数不能是常量,一定要是 tensor。 所以如果某个参数是常量,并且有多个可能的取值,那就只能针对每个取值都生成一个pass,取值有无限多个的情况我也不知道怎么办。

Pass Manager

当后端有了多个自定义pass的时候,他们会统一在pass manager处理。

示意代码如下:

1
2
3
4
5
6
7
8
9
10
11
class PassManager:
def __init__(self):
self.passes: list[InductorPass] = []
def __call__(self, graph: fx.Graph):
for pass_ in self.passes:
pass_(graph)
VllmInductorPass.dump_prefix += 1
self.post_cleanup(graph)
def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass)
self.passes.append(pass_)

vLLM如何利用torch.compile

上面那部分更偏torch.compile本身,这一节专门讲 vLLM 源码是怎么把这套机制用起来的。

  1. @support_torch_compile 给模型类注入编译能力
  2. TorchCompileWithNoGuardsWrapper 负责真正调用 torch.compile
  3. CompilationConfig.init_backend()VLLM_COMPILE 模式下返回 VllmBackend
  4. VllmBackend 负责 cache、图切分、按 range 预编译、挂载 post-grad passes
  5. PiecewiseBackend 在运行时按 token 数选择合适的已编译子图

入口:@support_torch_compile

它主要做了几件事:

  • 推断或者接收 dynamic_arg_dims
  • TorchCompileWithNoGuardsWrapper 动态加到类的基类里
  • 在新的 __init__ 里根据 CompilationModeenable_ifignore_torch_compile 等条件决定要不要真的编译
  • 在第一次调用时给输入打 dynamic shape 标记
  • 在 AOT compile 打开时尝试从磁盘直接加载已编译产物

简化一下,入口大概是这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def _support_torch_compile(cls, dynamic_arg_dims, ...):
cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)
old_init = cls.__init__

def __init__(self, *, vllm_config=None, prefix="", **kwargs):
old_init(self, ...)
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.do_not_compile = (
self.compilation_config.mode in [CompilationMode.NONE,
CompilationMode.STOCK_TORCH_COMPILE]
or ...
)
if not self.do_not_compile:
TorchCompileWithNoGuardsWrapper.__init__(self)

cls.__init__ = __init__

这里有两个很关键的点。

第一,vLLM 不是完全依赖 Dynamo 自己去猜哪些维度是动态的,而是会在第一次编译前显式调用 torch._dynamo.mark_dynamic / mark_unbacked 去标记输入维度。

第二,vLLM 会在 tracing 期间收集 forward 相关源码文件。做法有点“黑科技”:它会 patch InliningInstructionTranslator.inline_call_,把 Dynamo 内联到的函数所在文件都记下来,后面把这些文件内容也纳入 compile cache 的 hash。这样一旦模型相关 Python 源码改了,cache 就会失效并触发重新编译。

TorchCompileWithNoGuardsWrapper:核心目标是 trace once

从实现目标上看,vLLM 并不只是“套了一层 torch.compile”,而是在围绕下面几个目标组织运行时:

  • 尽量只做一次 Dynamo tracing
  • 尽量把 guard 去掉,避免运行时开销
  • 在需要的时候自己接管不同 shape range 的编译与调度

TorchCompileWithNoGuardsWrapper.__init__ 里最终还是会调用 torch.compile

1
2
3
4
5
6
7
self._compiled_callable = torch.compile(
compiled_ptr,
fullgraph=True,
dynamic=False,
backend=backend,
options=options,
)

但它会根据配置额外做几件事:

  • STOCK_TORCH_COMPILE 模式下,通过 guard_filter_fn 跳过大部分 guards
  • UNBACKED dynamic shapes 时改走 check_invariants_and_forward
  • 可选启用 AOT compile
  • 可选用 bytecode hook 保存 Dynamo 变换后的 code object,后续直接 dispatch 到编译后的 bytecode

如果还是有点抽象,可以把它理解成“把 torch.compile 拆成了两层”:

  • Dynamo 这一层:尽量只 trace 一次,别反复回到 Python 重新分析字节码
  • 后端执行这一层:可以继续按不同 shape range 准备不同的已编译版本

举个简化例子。假设有这样一个模块:

1
2
3
class Toy(torch.nn.Module):
def forward(self, x):
return x * 2

如果直接写:

1
2
3
toy = torch.compile(Toy())
toy(torch.randn(4, 128))
toy(torch.randn(8, 128))

PyTorch 默认会保留一批 guards。只要它判断新输入不满足旧 guards,就可能重新走一遍 Dynamo tracing,生成新的编译结果。

而 vLLM 的思路更像是:

1
2
3
4
5
6
7
8
9
10
11
model = TorchCompileWithNoGuardsWrapper(...)

# 第一次调用时:
# 1. 标记动态维度
# 2. Dynamo trace 一次
# 3. 尽量丢掉 guards
# 4. 把后续真正执行的逻辑交给 backend

model(x_bs_4)
model(x_bs_8)
model(x_bs_16)

后面 batch size 变了,vLLM 希望避免“因为 guards 不匹配而重新回到 Dynamo 重新 trace Python”。如果需要针对不同 shape 做不同优化,它更倾向于在 backend 这一层处理,例如把 [1, 8][9, 16] 编成不同 range 的子图,然后运行时只做 dispatch。

所以 TorchCompileWithNoGuardsWrapper 的职责不是自己完成所有优化,而是先把最上面那层“trace once”稳住,把后面的 shape-specialization、cache、piecewise compile 交给 VllmBackend 这一层。

CompilationMode.VLLM_COMPILE 会返回 VllmBackend

现在 CompilationConfig.init_backend() 里,三种模式的分工已经很清楚了:

  • STOCK_TORCH_COMPILE / DYNAMO_TRACE_ONCE:直接返回普通 PyTorch backend 或自定义 backend
  • VLLM_COMPILE:返回 VllmBackend(vllm_config)

VllmBackend 在做什么

VllmBackend 是这一节最关键的类。它的工作可以概括成下面几步。

1. 配置 cache

它会综合下面这些因素计算 cache key:

  • 环境变量相关因素
  • VllmConfig 的哈希
  • compiler 自身相关因素
  • Dynamo tracing 过程中记录到的源码文件内容

然后把 cache 放到类似下面的目录里:

1
VLLM_CACHE_ROOT/torch_compile_cache/<hash>/rank_<rank>_<dp_rank>/<prefix>/

并且 CompilerManager 会把缓存索引存成一个 Python 文件,key 形如:

1
(compile_range, graph_index, compiler_name) -> handle

2. 把 post-grad pass manager 接到 Inductor 上

VllmBackend.configure_post_pass() 会创建 PostGradPassManager,再把它挂到当前平台对应的 inductor hook 上。这样之后每个 FX graph 在进 Inductor 时,都会先跑 vLLM 自己的 post-grad passes。

这个阶段还支持把用户自己通过 inductor_compile_config[pass_key] 传入的 pass,追加到 pass manager 里。

这里的“挂载 post-grad passes”可以直接理解成:把 vLLM 自己的图变换函数注册到 Inductor 的一个回调槽位里。

流程大概是这样:

1
2
3
4
5
fx_graph = Dynamo(...).trace(...)

inductor_config[pass_key] = PostGradPassManager(...)

compiled = Inductor.compile(fx_graph, config=inductor_config)

当 Inductor 真正开始处理这个 fx_graph 时,会先调用:

1
PostGradPassManager(graph)

vLLM里也就是:

1
2
3
4
5
6
NoOpEliminationPass(graph)
SequenceParallelismPass(graph)
RMSNormQuantFusionPass(graph)
...
PostCleanupPass(graph)
FixFunctionalizationPass(graph)

跑完这些 pass 之后,Inductor 才继续做后面的 lowering 和 codegen。

“post-grad”这个名字本身比较容易让人误会,好像是在 backward 之后再处理一遍图。这里其实主要是沿用 PyTorch/Inductor 里的术语,表示“这是挂在 Inductor 的 post-grad custom pass hook 上的 pass”。对于 vLLM 这种 inference 场景,更好理解的说法是:

  • Dynamo 先把 Python 代码变成 FX graph
  • AOTAutograd / Inductor 会拿到一个 functionalized graph
  • vLLM 在 Inductor 正式生成 kernel 之前,先对这张 graph 做一轮自己的改写

所以“挂载 post-grad passes”本质上就是:在 Inductor 开始代码生成之前,先插入一层 vLLM 自己的 FX graph rewrite。

3. 决定是在 FX 层切图,还是让 Inductor 自己 partition

如果 use_inductor_graph_partition=False,vLLM 会在 FX 层按 splitting_ops 手工切图:

1
self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)

如果 use_inductor_graph_partition=True,那就不在 FX 层切,而是把 splitting_ops 注册成 Inductor partition rule,让 partition 在更晚的 codegen 阶段发生。这样 custom passes 看到的是更完整的图,更容易做 fusion。

4. 为每个子图创建 PiecewiseBackend

PiecewiseCompileInterpreter 会遍历切出来的子模块,把真正需要编译的 submodule 替换成 PiecewiseBackendPiecewiseBackend 会在初始化时把所有 compile range 先编完:

  • compile_sizes 里的单点 size 做静态 shape 编译
  • compile_ranges_endpoints 生成的 range 做通用 shape 编译

运行时则根据第一个 symbolic shape 参数来 dispatch:

1
2
3
runtime_shape = args[self.sym_shape_indices[0]]
range_entry = self._find_range_for_shape(runtime_shape)
return range_entry.runnable(*args)

也就是说,vLLM 不是等到线上跑到某个 batch size 才临时编译,而是更偏向“冷启动时把需要的几个 range 预编译好,运行时只做选择”。

5. 按需叠加 piecewise cudagraph

如果当前配置需要 piecewise cudagraph,wrap_with_cudagraph_if_needed() 会再给每个 PiecewiseBackend 外面包一层 cudagraph wrapper。

所以这套结构可以理解成:

  • Dynamo 只 trace 一次
  • VllmBackend 把图拆成若干段
  • 每段由 PiecewiseBackend 预编译多个 shape 版本
  • 每段外面还可以再套 cudagraph

compile_sizescompile_ranges_endpoints

这里有一个很值得单独强调的点。

vLLM 里同时存在两种“编译粒度”:

  • compile_sizes:对某些固定 token 数做精确编译,例如 [1, 2, 4, 8]
  • compile_ranges_endpoints:对一个区间做通用编译,例如 [1, 8][9, 16]

而且单点 size 的优先级更高。如果同时存在 [1, 8] 和 size 4,那运行到 4 时会优先走单点 size 对应的已编译图,而不是通用 range 图。

pass manager 顺序

当前实现里的实际执行顺序是:

  1. PassConfig 打开的内建 passes
  2. 用户额外注入的 custom pass
  3. PostCleanupPass
  4. FixFunctionalizationPass

这里最好不要把这四项都理解成同一类东西。更准确地说,它们分成两组:

  • 前两项是“真正做图变换的 pass”
  • 后两项是“收尾 pass”

先看前两项。

1. PassConfig 打开的内建 passes

这一类 pass 是 vLLM 自己实现、自己维护的 pass。它们都在 vllm/compilation/passes/ 下面,由 PostGradPassManager.configure() 根据 PassConfig 的开关决定要不要加入 self.passes

例如源码里会按条件加入这些 pass:

  • NoOpEliminationPass
  • SequenceParallelismPass
  • AsyncTPPass
  • AllReduceFusionPass
  • RMSNormQuantFusionPass
  • ActivationQuantFusionPass
  • AttnFusionPass
  • RopeKVCacheFusionPass
  • QKNormRoPEFusionPass

2. 用户额外注入的 custom pass

这一项不是 vLLM 默认内建的 pass,而是给使用者留的扩展口。

VllmBackend.configure_post_pass() 会检查当前平台对应的 pass_key。如果用户在 inductor_compile_config[pass_key] 里额外放了一个 pass,那么这个 pass 也会被追加到 PostGradPassManager 里。

可以把它理解成:

1
2
3
pass_manager = PostGradPassManager()
pass_manager.add(vllm_builtin_passes...)
pass_manager.add(user_custom_pass)

也就是说,vLLM 先把自己那套 pass 配好,然后再给用户一个“在同一个 hook 里再插一个 pass”的机会。

3. PostCleanupPass

从这个 pass 开始,就不是“做某个具体优化”的 pass 了,而是收尾。

PostCleanupPass 做的事情很简单:

  • 把 graph 重新做一遍稳定的拓扑排序
  • 删除已经没有用户的 dead code

为什么需要它?因为前面的 pattern matcher / graph rewrite 在改图之后,不保证图还是干净、拓扑有序的。比如某个 pattern 被替换掉之后,原来的一些中间节点可能已经没人用了;再比如替换后的节点顺序可能不够规整。

所以 PostCleanupPass 的作用可以概括成一句话:

前面的 pass 负责“改对图”,它负责“把改完的图整理干净”。

4. FixFunctionalizationPass

这个 pass 也是收尾,但它做的事情比 PostCleanupPass 更特殊。

在 Inductor 这一套流水线里,很多自定义算子在图里会先以 auto_functionalized(...) 的形式出现。这样做有利于 tracing 和 pattern matching,因为 functionalized graph 更规整、更容易改写。

但问题是,functionalized 形式往往会引入额外的中间 tensor、copy,或者把“原地更新”的语义改写成一串看起来更绕的函数式节点。如果直接拿这种图去做最终代码生成,性能和图结构都不理想。

FixFunctionalizationPass 的作用就是在最后把这些关键节点重新 defunctionalize,恢复成更接近真实执行语义的形式。它会处理的对象包括:

  • rotary embedding
  • fused_add_rms_norm
  • rms_norm_dynamic_per_token_quant
  • silu_and_mul
  • fused_qk_norm_rope
  • fused_rope_and_unified_kv_cache_update

它必须最后执行,原因也很简单:

  • 前面的很多 fusion / matcher 更依赖 functionalized graph 的规整形态
  • 一旦 defunctionalize 完,图就不再适合继续跑普通的 cleanup / DCE 了

所以顺序上可以把这四项理解成:

  1. 先跑 vLLM 自己的优化 pass
  2. 再跑用户自己额外插进来的 pass
  3. PostCleanupPass 把图清理干净
  4. 最后用 FixFunctionalizationPass 把 graph 从“便于改写的形式”收回到“便于执行的形式”

vLLM 有哪些 pass

vllm/config/vllm.py 会按 optimization level 给 PassConfig 注入默认值:

  • O0:基本都关
  • O1:会按条件打开 fuse_norm_quantfuse_act_quantfuse_act_paddingfuse_rope_kvcache
  • O2/O3:进一步按模型类型、量化方式、平台能力打开 fuse_attn_quantenable_spfuse_gemm_commsfuse_allreduce_rms

下面按类别挑几个比较重要的 pass 讲一下。

1. NoOpEliminationPass

这个 pass 还是非常重要,但它现在已经不是一个孤立的小优化,而是很多 fusion pass 的前置清理步骤。

它会删除三类 no-op:

  • reshape 链里中间多余的 reshape
  • 输出 shape 和输入 shape 等价的 reshape
  • 输出 shape 和输入 shape 等价的 slice / slice_scatter

例如:

1
2
3
4
mul_1 = ...
view_1 = torch.reshape(mul_1, [-1, 128, 32])
view_2 = torch.reshape(view_1, [-1, 4096])
view_3 = torch.reshape(view_2, [-1, 128, 32])

可以直接化简成:

1
2
mul_1 = ...
view_3 = torch.reshape(mul_1, [-1, 128, 32])

源码里的注释明确写了:RMSNorm-quant fusion 依赖这个 pass,因为某些路径里 apply_fp8_linear 会额外引入冗余 reshape。SequenceParallelismPass 也会在自己做完替换后,调用一次 NoOpEliminationPass 清掉临时插进去的 slice。

2. RMSNormQuantFusionPass

这个 pass 负责把

  • rms_norm + quant
  • fused_add_rms_norm + quant

替换成单个 fused custom op。

而且它支持的量化模式也比较多,至少包括:

  • static fp8 quant
  • dynamic per-token fp8 quant
  • group fp8 quant(CUDA 上还有不同 group shape、scale layout、e8m0 / TMA 对齐组合)

它的核心思路还是 pattern matcher,但模式数量已经很多了。比如同一个 epsilon,它会把:

  • 普通 RMSNorm
  • FusedAddRMSNorm
  • 静态量化
  • 动态量化
  • group quant

这些组合全部注册成 pattern。

3. ActivationQuantFusionPass

这个 pass 做的是 MLP 激活这边的 fusion,把:

  • silu_and_mul + fp8 quant
  • silu_and_mul + nvfp4 quant

替换成 fused op,例如 silu_and_mul_quant

相比 RMSNormQuantFusionPass,它的结构更简单一些,但思路完全一样:先匹配 unfused pattern,再用 auto_functionalized 的 fused custom op 替换。

4. AttnFusionPass

这个 pass 是把 attention 后面的量化 fuse 进 attention 本身。

也就是说它不是简单地做“attention 节点旁边的小修小补”,而是直接尝试把:

  • attention
  • attention output quant

合并成一个支持 fused output quant 的 attention kernel。

这里有一个比较有意思的实现细节:它不是写一个完全泛化的 pattern 去匹配所有 attention,而是先从 CompilationConfig.static_forward_context 里把具体的 Attention 层找出来,再按 layer 注册 pattern。这样做的原因是 attention 层里有一些字符串参数和 layer-specific 元信息,直接 wildcard 不太现实。

5. SequenceParallelismPass

这个 pass 不是单纯“把通信算子 fuse 一下”,而是先改图的并行策略。

先约定一下张量形状。下面把某一层 hidden states 记成:

1
hidden_states.shape = [num_tokens, hidden_size]

在 tensor parallel 下,Transformer block 里经常会有一类“row parallel”线性层。它的特点是:

  • 每个 rank 只持有一部分权重
  • 每个 rank 先各自算出一个 partial output
  • 然后通过 AllReduce(sum) 把这些 partial output 求和,恢复成完整的 hidden states

所以在图上常见的结构会是:

1
partial_output(rank0/rank1/...) -> AllReduce -> RMSNorm -> ...

如果把它放回到一个更具体的 LLM block 里,可以把它想成下面这个局部结构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
                        Transformer block 的一小段

+--------------------------------------+
hidden_states --->| attention / MLP 的 row-parallel 投影 |---> partial output
+--------------------------------------+ (每个 rank 一份)
|
v
AllReduce(sum)
|
v
full hidden states
shape = [T, H]
|
v
RMSNorm
|
v
下一层

这里的关键点在于:RMSNorm 是按 token 独立做的。

更准确地说,RMSNorm 对每个 token 的 hidden vector 单独计算均方根并归一化。也就是说它在 hidden dimension 上做归一化,但不同 token 之间互不依赖。于是:

  • 它需要每个 token 的完整 hidden dimension
  • 但它不需要“所有 token 必须都在同一个 rank 上”

这正是 sequence parallelism 能成立的原因。

它会把:

1
AllReduce -> RMSNorm -> Output

改写成:

1
ReduceScatter -> local RMSNorm -> AllGather -> Output

如果后面还接了 quant,也会一起考虑。

这里顺便解释一下 ReduceScatter。它可以理解成“AllReduce 的前半段 + 把结果切开分给各个 rank”。

例如两个 rank 上分别有:

1
2
rank0: [1, 2, 3, 4]
rank1: [10, 20, 30, 40]

先做 reduce(sum) 得到:

1
[11, 22, 33, 44]

再 scatter 成两半:

1
2
rank0: [11, 22]
rank1: [33, 44]

也就是说,AllReduce 是“每个 rank 都拿到完整结果”,而 ReduceScatter 是“每个 rank 只拿到结果的一部分”。

这个改写为什么是对的,可以直接从张量形状来理解。

假设 AllReduce 之后得到的完整输出是 [T, H]。如果 TP world size = 2,那么 ReduceScatter(dim=0) 做的事情可以理解成:

  1. 先完成原本 AllReduce(sum) 该做的求和
  2. 但不是把完整的 [T, H] 都留在每个 rank 上
  3. 而是把 token 维切开,让 rank0 拿到前一半 token,rank1 拿到后一半 token

于是每个 rank 上拿到的是:

1
2
rank0: [T/2, H]
rank1: [T/2, H]

注意这里 hidden dimension H 还是完整的,只是 token 维 T 被分片了。

RMSNorm 恰好只要求“每个 token 的 hidden vector 是完整的”,并不要求“所有 token 都在当前 rank 上”。所以它可以完全本地执行:

1
2
rank0: local RMSNorm on [T/2, H]
rank1: local RMSNorm on [T/2, H]

如果后面的算子又需要完整的 sequence,再用 AllGather 把 token 维拼回来即可:

1
2
3
4
5
6
7
8
9
10
11
Before:
每个 rank 都持有 [T, H]

After ReduceScatter:
每个 rank 只持有自己的 token shard [T/TP, H]

After local RMSNorm:
结果仍然是 [T/TP, H]

After AllGather:
再恢复成每个 rank 都有完整的 [T, H]

可以画成下面这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
原始做法

rank0 partial ----\
rank1 partial ----- AllReduce(sum) -> [T, H] -> RMSNorm([T, H]) -> 下一步
rank2 partial ----/


Sequence Parallelism 做法

rank0 partial ----\
rank1 partial ----- ReduceScatter(sum, dim=token) -> rank0:[T/TP,H]
rank2 partial ----/ rank1:[T/TP,H]
rank2:[T/TP,H]
|
v
每个 rank 本地做 RMSNorm([T/TP, H])
|
v
AllGather(dim=token) -> [T, H]
|
v
下一步

所以这个 pass 的收益,重点不在于“RMSNorm 本身更快了”,而在于它把一段原本夹在 AllReduce 后面的计算,改成了“通信之后每个 rank 先各算各的”。这样图里就显式出现了:

  • ReduceScatter
  • 本地计算
  • AllGather

后续 AsyncTPPass 才有机会继续识别出:

  • GEMM + ReduceScatter
  • AllGather + GEMM

这样的模式,并进一步 fuse / overlap compute 和 communication。

换句话说,SequenceParallelismPass 更像是在给后续的通信融合“铺路”。

从性能角度看,它的收益大致有三层。

第一层是减少重复计算。

原来的做法里,AllReduce 之后每个 rank 都拿到完整的 [T, H],于是每个 rank 都会各自做一遍完整的 RMSNorm([T, H])。如果 TP world size 是 N,那这类逐 token 的小算子实际上被重复执行了 N 次。

而改成 sequence parallelism 之后,每个 rank 只持有自己的 token shard [T/N, H],于是每个 rank 只需要做:

1
local RMSNorm([T/N, H])

也就是说,像 RMSNorm、residual add、量化这类本来就不需要跨 token 交互的算子,不再需要在每个 rank 上对完整 sequence 重复做一遍。

第二层是降低本地内存带宽压力。

这类算子很多时候不是 compute-bound,而是 memory-bound。把输入从 [T, H] 变成 [T/N, H] 之后:

  • 读 hidden states 的流量变少
  • 写中间结果的流量变少
  • 访问 residual / quant scale 等辅助张量的流量也会一起下降

所以就算不考虑后续 fusion,仅仅把这些小算子下沉到 token shard 上本地执行,也经常能带来可观收益。

第三层也是最重要的一层:给后续 compute + communication 融合创造条件。

单独看 AllReduceReduceScatter + AllGather,通信量本身并不会凭空消失。真正大的收益通常来自图结构改变之后,后续 pass 可以继续把:

  • GEMM + ReduceScatter
  • AllGather + GEMM

识别出来,并替换成更底层的 fused collective kernel。这样才能进一步做到:

  • 让通信更早开始
  • 让通信和计算 overlap
  • 减少中间张量落地

所以更准确地说,SequenceParallelismPass 的性能收益来源是:

  1. 少做重复的逐 token 小算子
  2. 降低这些小算子的内存访问量
  3. 为后续 AsyncTPPass 提供更适合融合的图结构

因此它本身既有直接收益,也有“铺路型”的间接收益。

这个 pass 还不是无脑开启的,它会看:

  • 当前平台是不是 CUDA
  • hidden size 是否大到值得做 sequence parallelism
  • compile range 是否满足阈值
  • 在 piecewise compilation 场景下,当前 size 是否是单点并且能被 TP size 整除

6. AsyncTPPass

这个 pass 可以看成是 SequenceParallelismPass 的后续配套 pass。

它会继续把:

  • GEMM + ReduceScatter
  • AllGather + GEMM
  • scaled_mm + ReduceScatter

等模式替换成更底层的 fused collective / symmetric memory 算子,比如:

  • torch.ops.symm_mem.fused_matmul_reduce_scatter
  • torch.ops.symm_mem.fused_all_gather_matmul

所以从 pass 依赖关系上看,可以粗略理解成:

  1. SequenceParallelismPass 先把图改造成适合 overlap compute / comm 的样子
  2. AsyncTPPass 再把这些新结构 fuse 成更高效的 kernel

7. AllReduceFusionPass

这个 pass 走的是另一条路线:不改成 ReduceScatter/AllGather,而是直接把

  • AllReduce + RMSNorm
  • AllReduce + FusedAddRMSNorm
  • AllReduce + RMSNorm + Quant

等模式替换成 FlashInfer 提供的 fused kernel。

8. QKNormRoPEFusionPass

这个 pass 把 Q/K 上的 RMSNorm 和 RoPE 合并成一个 fused custom op:fused_qk_norm_rope

它要求的信息比普通 pattern 多一些,比如:

  • head_dim
  • num_heads
  • num_kv_heads
  • epsilon
  • is_neox

所以实现上也是先从实际的 Attention 层对象里把这些 metadata 拿出来,再注册 pattern。

另外它前面常常会配合一个 utility pass:

  • SplitCoalescingPass

这个 pass 会把重复的 split_with_sizes 节点合并掉。因为如果 Q/K/V 是从同一个张量 split 出来的,但图里重复生成了多个 split_with_sizes,后面的 fusion pattern 往往就看不出“它们其实来自同一个 qkv”。

9. RopeKVCacheFusionPass

这个 pass 主要针对的是:

  • rotary embedding
  • KV cache update

把这两个步骤合成一个 fused kernel,减少中间 tensor 和 kernel launch。

它的实现还会依赖另外两个 utility passes 先把图整理干净:

  • SplitCoalescingPass
  • ScatterSplitReplacementPass

ScatterSplitReplacementPass 的作用特别适合举例说明。functionalization 之后,rotary embedding 经常会变成:

  1. 先从 qkv 里 split 出 qk
  2. 跑 functionalized 的 rotary op
  3. 再通过 slice_scatter 写回 qkv
  4. 然后又 split 一次拿出 q/k/v

如果后面没有别的用户真的需要“写回后的整个 qkv”,那这套 slice_scatter + split 就纯属绕远路。ScatterSplitReplacementPass 会把这段图改回更直接的形式,这样 RopeKVCacheFusionPass 才更容易匹配出融合模式。

另外这个 pass 还会根据 compile_range 控制是否启用,默认只在较小 token 数下生效,因为源码注释里明确写了:它更偏向 small-batch decode,prefill 大 batch 下未必占优。

10. PostCleanupPassFixFunctionalizationPass

PostCleanupPass 做的事情很朴素:

  • stable topological sort
  • dead code elimination

原因是 pattern matcher 改图之后,不保证图还是拓扑有序的,也可能留下无用节点。

FixFunctionalizationPass 则更关键。因为 vLLM 里很多 custom op 在 tracing / matching 时会先以 auto_functionalized(...) 的形式出现,但真正执行时如果一直保留 functionalized 版本,往往会引入额外 copy。这个 pass 会把一些重要算子重新 defunctionalize 回去,例如:

  • rotary embedding
  • fused_add_rms_norm
  • rms_norm_dynamic_per_token_quant
  • silu_and_mul
  • fused_qk_norm_rope
  • fused_rope_and_unified_kv_cache_update

源码里甚至专门强调了一句:FixFunctionalizationPass 必须最后执行,而且执行完之后不要再跑 DCE,因为 defunctionalized 图里可能会出现“看起来像 dead code、但实际上不能随便删”的节点。

小结

vLLM 不是简单地“调用了一次 torch.compile”,而是在 Dynamo 只 trace 一次的前提下,围绕 cache、piecewise compile、compile ranges、cudagraph 和一组 post-grad custom passes,搭了一个自己的编译执行框架。

最后放一个vLLM-compile slides