vLLM torch.compile Integration

This post was last updated on October 5, 2025.

There is very little material online about vLLM's torch.compile integration; the vLLM team has written an official blog post about it,

but most of the available write-ups are vague: they only give a high-level overview without concrete details.

While I was writing this post, BBuf published a similar article (https://zhuanlan.zhihu.com/p/1955402895890560120),

so I decided to write one of my own anyway.

torch.compile

Introduction

There are already plenty of introductions to torch.compile online, so I will not repeat them here.

What matters for this post is that torch.compile consists of two major components: Dynamo and Inductor.

The overall compilation pipeline looks roughly like the following (see this blog post for details):

  • TorchDynamo traces the Python function and captures an FX graph.
  • AOTAutograd lowers the captured graph to ATen ops (and, for training, splits out forward/backward graphs).
  • Inductor runs post_grad_passes to optimize the computation graph.
  • Inductor lowers the optimized graph and generates kernels (e.g. Triton on GPU, C++ on CPU).

Custom backends

torch.compile provides an interface for plugging in a custom backend; see this documentation page for details.

A backend function must have the following signature:

(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable
The custom backend is called by Dynamo with an FX graph, and it must return an optimized callable that is equivalent to that graph.

import torch

def my_custom_backend(gm, example_inputs):
    return gm.forward

def f(...):
    ...

f_opt = torch.compile(f, backend=my_custom_backend)

@torch.compile(backend=my_custom_backend)
def g(...):
    ...
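
To make the contract concrete, here is a small sketch of my own (the backend and function names are made up, not from the vLLM or PyTorch docs) of a backend that inspects the graph Dynamo hands it and then runs it unchanged:

import torch
from typing import List

def inspecting_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # Print the captured FX graph and a per-op histogram, then run it as-is.
    gm.graph.print_tabular()
    op_counts: dict[str, int] = {}
    for node in gm.graph.nodes:
        if node.op == "call_function":
            op_counts[str(node.target)] = op_counts.get(str(node.target), 0) + 1
    print("op counts:", op_counts)
    return gm.forward  # returning the original forward means no optimization

@torch.compile(backend=inspecting_backend)
def h(x):
    return torch.relu(x) + 1

h(torch.randn(8))  # compilation (and the printout) happens on the first call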

Custom passes

This section mainly draws on two upstream reference documents.

When the passes that torch.compile ships with do not meet your needs, you can add your own custom passes.

The official docs summarize the purposes of custom passes along these axes:

  • Axis A: 1. Creating one-to-X mapping (eg. decomposition) 2. Creating many-to-one mapping (eg. fusion)
  • Axis B: 1. Doing forwards iteration (eg. shape propagation) 2. Doing backwards iteration (eg. dead code elimination)
  • Axis C: 1. Dependent on local node information (eg. out-variant conversion) 2. Dependent on global graph information (eg. memory planning)

The most direct way to modify the computation graph is to manipulate the generated FX graph by hand.

def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.mul.Tensor
    gm.recompile()  # regenerate the module's code after editing the graph
    return gm

replace_add_with_mul replaces every add operation in the graph with mul.

Of course, manipulating the graph directly is tedious, so PyTorch also provides some more convenient helper utilities.

So far I have come across roughly three ways PyTorch supports writing custom passes:

  • torch.fx.Transformer
  • subgraph_rewriter
  • pattern matcher

torch.fx.Transformer

class ReplaceAddWithMul(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)
        return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)

transformed_graph_module = ReplaceAddWithMul(graph_module).transform()
class ReplaceAddWithMulSub(torch.fx.Transformer):
    """
    Original:
        def f(x, y):
            return x + y

    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)

        x, y = args

        mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
        return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})

transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()
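
Both snippets above assume a graph_module whose call_function targets are already ATen ops. As a quick way to try them out (this setup with make_fx is my own assumption, not something the original workflow prescribes), you can trace a small function and apply the transform:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    return x + y

# make_fx traces f into a GraphModule whose nodes are aten ops,
# so the comparisons against torch.ops.aten.add.Tensor above will match.
graph_module = make_fx(f)(torch.randn(4), torch.randn(4))

transformed = ReplaceAddWithMulSub(graph_module).transform()
print(transformed.graph)  # the add node is now a mul followed by a sub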

This approach is only convenient for transforming a single node at a time.

subgraph_rewriter

To implement fusion, you need the ability to map X nodes to one node, or X nodes to Y nodes.

from torch.fx import subgraph_rewriter

# This is an in-place operation.
def replace_patterns(graph_module):
    """
    Original:
        def f(x, y):
            x = x + y
            x = x * y
            return x

    After pass:
        def f(x, y):
            return x - y
    """
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x

    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)

    replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
        graph_module, pattern, replacement
    )
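
As with the Transformer example, a quick way to exercise this rewrite (again assuming make_fx as the tracing front end, which is my own choice for illustration) is:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    x = x + y
    x = x * y
    return x

traced = make_fx(f)(torch.randn(4), torch.randn(4))
replace_patterns(traced)   # the in-place rewrite defined above
print(traced.graph)        # the add + mul pair is collapsed into a single sub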

pattern matcher

The pattern matcher looks very similar to subgraph_rewriter; I have not figured out exactly how the two relate or why there are two separate APIs.

See this tutorial for a complete example.

def pattern(x, y):
    x = torch.ops.aten.add.Tensor(x, y)
    x = torch.ops.aten.mul.Tensor(x, y)
    return x

def replacement(x, y):
    return torch.ops.aten.sub.Tensor(x, y)

example_inputs = [
    torch.randn(1, 4),
    torch.randn(1, 1),
]

from torch._inductor.pattern_matcher import (
    PatternMatcherPass, fwd_only, register_replacement,
)
from torch._inductor import config

# Create a pattern matcher pass and register our pattern
patterns = PatternMatcherPass()

register_replacement(
    pattern,
    replacement,
    example_inputs,
    fwd_only,
    patterns,
)

# Create a custom pass function that applies our patterns
def fusion_pass(graph):
    return patterns.apply(graph)

# Set our custom pass in the config
config.post_grad_custom_post_pass = fusion_pass

I ran into some issues when using this API. For example, pattern arguments cannot be constants; they have to be tensors. So if an argument is a constant with several possible values, the only option is to register one pattern per value, and I do not know what to do when there are infinitely many possible values.
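
A workaround consistent with the above (a sketch of my own, not code from vLLM; the helper make_pattern and the scale values are illustrative) is to bake each constant into its own pattern/replacement pair and register them all into the same PatternMatcherPass:

import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass, fwd_only, register_replacement,
)

patterns = PatternMatcherPass()
example_inputs = [torch.empty(8, 8), torch.empty(8, 8)]

def make_pattern(scale: float):
    # The scalar is captured in the closure, so the traced pattern contains
    # it as a literal and only matches graphs that use that exact value.
    def pattern(x, y):
        return torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), scale)

    def replacement(x, y):
        # mathematically equivalent rewrite: (x + y) * s == x * s + y * s
        return torch.ops.aten.add.Tensor(
            torch.ops.aten.mul.Tensor(x, scale),
            torch.ops.aten.mul.Tensor(y, scale),
        )

    return pattern, replacement

# One registration per constant value we care about.
for scale in (0.5, 1.0, 2.0):
    pattern, replacement = make_pattern(scale)
    register_replacement(pattern, replacement, example_inputs, fwd_only, patterns)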

Pass Manager

Once a backend has multiple custom passes, they are managed together by a pass manager.

Illustrative code:

class PassManager:
    def __init__(self):
        self.passes: list[InductorPass] = []

    def __call__(self, graph: fx.Graph):
        for pass_ in self.passes:
            pass_(graph)
            VllmInductorPass.dump_prefix += 1
        self.post_cleanup(graph)

    def add(self, pass_: InductorPass):
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

How vLLM uses torch.compile

In the model definition files, the parts that support compilation are decorated with @support_torch_compile.

The concrete implementation lives in vllm/compilation/decorators.py.

The key point is in the function below: it installs a custom __init__, which in turn calls into TorchCompileWrapperWithCustomDispatcher.

def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, Union[int, list[int]]],
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__

    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        # ... (several lines omitted)
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_level=vllm_config.compilation_config.level)

    cls.__init__ = __init__

From vllm/compilation/wrapper.py:
class TorchCompileWrapperWithCustomDispatcher:
    def __init__(self,
                 compiled_callable: Optional[Callable] = None,
                 compilation_level: int = 0):
        if compiled_callable is None:
            # default compilation settings
            # compiling the forward method

            backend = vllm_config.compilation_config.init_backend(vllm_config)
            options = None
            if isinstance(backend, str) and backend == "inductor":
                options = get_current_vllm_config(
                ).compilation_config.inductor_compile_config

            compiled_callable = torch.compile(self.forward,
                                              fullgraph=True,
                                              backend=backend,
                                              options=options)

        self.compiled_callable = compiled_callable

vLLM implements VllmBackend, a custom backend that is passed to torch.compile at compile time.

VllmBackend

  • First, there is some bookkeeping related to torch.compile caching.

  • Then VllmBackend sets Inductor's post_grad_custom_post_pass, injecting the custom passes.

  • The graph is split into submodules according to splitting_ops (split_graph), yielding split_gm.

  • If CUDA graphs are not enabled, the split graph is returned directly; otherwise the runtime first copies the inputs into fixed-address buffers before calling the graph (copy_and_call), as shown below.

def copy_and_call(*args):
    list_args = list(args)
    for i, index in enumerate(self.sym_tensor_indices):
        runtime_tensor = list_args[index]
        runtime_shape = runtime_tensor.shape[0]
        static_tensor = self.input_buffers[i][:runtime_shape]

        # copy the tensor to the static buffer
        static_tensor.copy_(runtime_tensor)

        # replace the tensor in the list_args to the static buffer
        list_args[index] = static_tensor
    return self.split_gm(*list_args)
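
To tie this back to the earlier torch.compile sections, here is a heavily simplified, assumption-based sketch of the core idea behind a VllmBackend-style backend: install a custom pass as Inductor's post-grad pass for this compilation and let Inductor handle the rest. The names simplified_vllm_like_backend and my_post_grad_pass are hypothetical; the real VllmBackend additionally handles caching, graph splitting, and CUDA graph capture.

import torch
from torch._inductor.compile_fx import compile_fx

def my_post_grad_pass(graph: torch.fx.Graph) -> None:
    # stand-in for a pass manager like the one sketched earlier;
    # newer PyTorch versions may require a CustomGraphPass object here
    # so the pass can be hashed for the FX graph cache
    pass

def simplified_vllm_like_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hand the Dynamo-captured graph to Inductor, patching in our
    # post-grad custom pass for this compilation only.
    return compile_fx(
        gm,
        example_inputs,
        config_patches={"post_grad_custom_post_pass": my_post_grad_pass},
    )

compiled = torch.compile(lambda x: torch.relu(x) + 1,
                         backend=simplified_vllm_like_backend)
compiled(torch.randn(8))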

What the custom passes do

The following are the custom passes that currently exist in vLLM. For various reasons they are all disabled by default (a developer explained that this is a leftover of the V0-to-V1 transition: "This is a relic of the V1 upgrade, we're working on enabling them by default"). The related issue has more information and the community's current progress.

class PassConfig:
    """Configuration for custom Inductor passes.

    This is separate from general `CompilationConfig` so that inductor passes
    don't all have access to full configuration - that would create a cycle as
    the `PassManager` is set as a property of config."""

    enable_fusion: bool = False
    """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
    enable_attn_fusion: bool = False
    """Whether to enable the custom attention+quant fusion pass."""
    enable_noop: bool = False
    """Whether to enable the custom no-op elimination pass."""
    enable_sequence_parallelism: bool = False
    """Whether to enable sequence parallelism."""
    enable_async_tp: bool = False
    """Whether to enable async TP."""
    enable_fi_allreduce_fusion: bool = False
    """Whether to enable flashinfer allreduce fusion."""
    fi_allreduce_fusion_max_token_num: int = 16384
    """Max number of tokens to used in flashinfer allreduce fusion."""

All of these passes live under vllm/compilation. Below I briefly introduce a few of them.

FusionPass

  • attention + quant
  • norm + quant
  • compute + comm

These all use the pattern matcher approach described above to replace certain operations with custom ops or with ops from third-party libraries (such as FlashInfer).

NoOpEliminationPass removes redundant reshape/slice operations.

For example, the following graph:

mul_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
view_2: "f16[s0, 4096]" = torch.reshape(view_1, [-1, 4096])
view_3: "f16[s0, 128, 32]" = torch.reshape(view_2, [-1, 128, 32])

can be simplified to:
mul_1: "f16[s0, 4096]" = ...
view_3: "f16[s0, 128, 32]" = ...

This pass manipulates graph nodes directly, so its code is worth showing here:
class NoOpEliminationPass(VllmInductorPass):
    """
    This is an inductor pass that removes redundant reshape/slice operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case. Additionally, torch internal no-op elimination pass does
    not handle certain slice variants.

    Cases handled:
    1. A chain of reshapes is equivalent to the last reshape called on the
       base tensor (input of the first reshape).
    2. A reshape that produces the shape of the input is redundant
    3. A slice that produces the shape of the input is redundant
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph):
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                # Case 1: rewrite reshape chains to reshapes on the base tensor
                input = node.args[0]
                # If the input is a reshape, rebind to that node
                if is_func(input, torch.ops.aten.reshape.default):
                    # The new input is guaranteed not to be a reshape,
                    # because we process nodes in order
                    node.update_arg(0, input.args[0])
                    if len(input.users) == 0:
                        graph.erase_node(input)
                        count += 1

                # Case 2: remove this reshape if it produces the original shape
                input, shape = node.args[:2]
                input_shape = input.meta["val"].shape
                if len(shape) != len(input_shape):
                    # Reshape changing rank, skip
                    continue

                if shape.count(-1) > 1:
                    # Invalid reshape args, skip
                    continue

                if self.reshape_all_dims_equivalent(shape, input_shape):
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice.Tensor):
                # python slicing semantics are different from reshape
                # Don't treat -1 as inferred dimension
                input, dim_index, start, end = node.args[:4]
                input_shape = input.meta["val"].shape
                output_shape = node.meta["val"].shape

                if output_shape == input_shape:
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice_scatter.default):
                base, view, dim_index, start, end = node.args[:5]
                base_shape = base.meta["val"].shape
                view_shape = view.meta["val"].shape

                if base_shape == view_shape:
                    node.replace_all_uses_with(view)
                    graph.erase_node(node)
                    count += 1

        logger.debug("Removed %s no-op reshapes and slices", count)