vLLM torch.compile Integration
This article was last updated on October 5, 2025.
There is very little material online about vLLM's torch.compile integration. Officially there is a blog post and some documentation, but most related write-ups are vague: they give only a high-level overview with no concrete details, so I decided to write my own. (While I was working on this post, BBuf published a similar article: https://zhuanlan.zhihu.com/p/1955402895890560120.)
torch.compile
Introduction
There are already plenty of introductions to torch.compile online, so I will not repeat them here. What matters for this post is that torch.compile consists of two important components: Dynamo and Inductor.
The overall compilation flow looks roughly like the following; see this blog post for details:
- TorchDynamo traces the function
- xxxx
- Inductor runs post_grad_passes to optimize the computation graph
- xxxx
Custom backend
torch.compile provides an interface for plugging in a custom backend; see this document for details. The backend function must conform to the following signature:

```python
(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable
```

It is called by Dynamo with an fx graph, and it must return an optimized but equivalent callable.
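To make the interface concrete, here is a minimal, self-contained sketch of a custom backend. The name my_backend and the toy function are mine, not vLLM's; a real backend would return an optimized callable rather than the graph's own forward.

```python
from typing import Callable, List

import torch
from torch import fx


def my_backend(gm: fx.GraphModule,
               example_inputs: List[torch.Tensor]) -> Callable:
    # Dynamo hands us the captured FX graph plus example inputs.
    gm.graph.print_tabular()
    # Returning gm.forward means "run the captured graph unchanged";
    # an optimizing backend would return a faster equivalent callable.
    return gm.forward


@torch.compile(backend=my_backend)
def toy_fn(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


toy_fn(torch.randn(4))  # triggers tracing and invokes my_backend
```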
Custom passes
This section mainly draws on the following two documents:
- Official documentation: Writing Graph Transformations on ATen IR
- Official tutorial: Building a Convolution/Batch Norm fuser with torch.compile
When the passes that torch.compile provides out of the box do not meet your needs, you can add custom passes. The official documentation classifies the purposes of a pass along these axes:
- Axis A: creating a one-to-X mapping (e.g. decomposition) vs. creating a many-to-one mapping (e.g. fusion)
- Axis B: doing forwards iteration (e.g. shape propagation) vs. doing backwards iteration (e.g. dead code elimination)
- Axis C: dependent on local node information (e.g. out-variant conversion) vs. dependent on global graph information (e.g. memory planning)
The most direct way to modify the computation graph is to operate on the generated fx graph itself:

```python
def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.mul.Tensor
    # Regenerate the GraphModule's code after editing nodes
    gm.recompile()
    return gm
```

replace_add_with_mul replaces every add operation with mul.
Of course, manipulating the computation graph directly is tedious, so PyTorch also provides some more convenient helper utilities. So far I have come across roughly three PyTorch-provided ways to write a custom pass:
- torch.fx.Transformer
- subgraph_rewriter
- pattern matcher
torch.fx.Transformer
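As a minimal sketch (generic torch.fx usage, not code from vLLM), a Transformer subclass rewrites individual nodes by overriding call_function; here add is swapped for mul, assuming the graph is already at the ATen level.

```python
import torch
from torch import fx


class AddToMul(fx.Transformer):
    """Re-emits every node, swapping aten.add for aten.mul."""

    def call_function(self, target, args, kwargs):
        if target == torch.ops.aten.add.Tensor:
            return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)
        return super().call_function(target, args, kwargs)


# Usage inside a backend/pass, where `gm` is an ATen-level fx.GraphModule:
#   new_gm = AddToMul(gm).transform()
```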
This approach is only convenient for transforming a single node at a time.
subgraph_rewriter
To implement fusion, however, you need the ability to map X->one or X->Y.
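A minimal sketch using torch.fx.subgraph_rewriter (again plain torch.fx, not vLLM code): you describe the subgraph to find and the subgraph to put in its place, which gives you the many-to-one rewrites needed for fusion. The replacement here is a stand-in; a real pass would call a fused custom kernel.

```python
import torch
from torch.fx import subgraph_rewriter, symbolic_trace


def pattern(x, y):
    # Subgraph to search for: add followed by relu
    return torch.relu(torch.add(x, y))


def replacement(x, y):
    # Stand-in for a fused kernel (a real pass would call a custom op here)
    return torch.clamp(x + y, min=0.0)


def model(a, b):
    return torch.relu(torch.add(a, b)) * 2


gm = symbolic_trace(model)
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(len(matches), "subgraph(s) replaced")
```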
pattern matcher
The pattern matcher looks very similar to subgraph_rewriter; I have not figured out exactly how the two differ or why there are two separate APIs. See this tutorial for a complete example.
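For reference, here is a minimal sketch of the inductor pattern matcher, the mechanism vLLM's fusion passes are built on; the pattern/replacement functions and example shapes are made up for illustration.

```python
import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                              register_replacement)

my_patterns = PatternMatcherPass()


def search(x, y):
    # Subgraph to match in the post-grad ATen graph
    return torch.add(x, y).relu()


def replace(x, y):
    # Stand-in for a fused custom op
    return torch.clamp(x + y, min=0.0)


# The example inputs are only used to trace the pattern; note that they
# must be tensors, which is exactly the constant-argument limitation
# described below.
example_inputs = [torch.empty(8, 8), torch.empty(8, 8)]
register_replacement(search, replace, example_inputs, fwd_only, my_patterns)

# Inside a custom inductor pass:
#   my_patterns.apply(graph)
```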
I ran into some issues when using this API. For example, the pattern's arguments cannot be constants; they must be tensors. So if a parameter is a constant with several possible values, you have to generate a separate pass for each value, and if the set of possible values is unbounded, I do not know what to do.
Pass Manager
Once the backend has multiple custom passes, they are handled centrally by a pass manager.
Illustrative code:

```python
class PassManager:
    def __init__(self):
        self.passes: list[InductorPass] = []

    def __call__(self, graph: fx.Graph):
        for pass_ in self.passes:
            pass_(graph)
            VllmInductorPass.dump_prefix += 1
        self.post_cleanup(graph)

    def add(self, pass_: InductorPass):
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)
```
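Usage is then roughly as follows; this wiring is schematic, and the constructor arguments are illustrative (the real vLLM passes take the compilation config).

```python
# Schematic: collect the custom passes, then let inductor invoke the
# manager on the post-grad graph (vllm_config stands for the engine's
# VllmConfig; exact constructor signatures vary by version).
pass_manager = PassManager()
pass_manager.add(NoOpEliminationPass(vllm_config))
pass_manager.add(FusionPass(vllm_config))

# inductor later calls: pass_manager(graph)
```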
How vLLM uses torch.compile
In the model definition files, the parts that support compilation are wrapped with the @support_torch_compile decorator. The concrete implementation lives in vllm/compilation/decorators.py. The key point is that the function below installs a custom __init__ method, which in turn enters TorchCompileWrapperWithCustomDispatcher:
```python
def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, Union[int, list[int]]],
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__

    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        # ... (several lines omitted)
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_level=vllm_config.compilation_config.level)

    cls.__init__ = __init__
```
```python
class TorchCompileWrapperWithCustomDispatcher:

    def __init__(self,
                 compiled_callable: Optional[Callable] = None,
                 compilation_level: int = 0):
        vllm_config = get_current_vllm_config()
        if compiled_callable is None:
            # default compilation settings
            # compiling the forward method
            backend = vllm_config.compilation_config.init_backend(vllm_config)
            options = None
            if isinstance(backend, str) and backend == "inductor":
                options = get_current_vllm_config(
                ).compilation_config.inductor_compile_config

            compiled_callable = torch.compile(self.forward,
                                              fullgraph=True,
                                              backend=backend,
                                              options=options)

        self.compiled_callable = compiled_callable
```
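For reference, this is roughly what opting in looks like from a model file's perspective. The model below is a schematic toy, not an actual vLLM model; the public support_torch_compile decorator wraps the _support_torch_compile function shown above.

```python
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ToyModel(nn.Module):
    # The decorator swaps in the custom __init__ shown above and routes
    # forward() through torch.compile with vLLM's backend.
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
```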
vLLM implements VllmBackend, which is passed in as the custom backend when compiling. VllmBackend:
- first performs some torch.compile caching-related bookkeeping;
- then sets inductor's post_grad_custom_post_pass to inject the custom passes;
- splits the graph according to splitting_ops (split_graph), producing split_gm;
- if CUDA graphs are not enabled, it simply returns the split graph; otherwise, at runtime the inputs must first be copied to fixed addresses before the graph is run (copy_and_call).
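To see the hook VllmBackend relies on in isolation, here is a minimal example of installing a post-grad custom pass in inductor. The pass itself is a toy of my own; vLLM wires its pass manager into this same post_grad_custom_post_pass option.

```python
import torch
import torch._inductor.config as inductor_config


def my_post_grad_pass(graph: torch.fx.Graph) -> None:
    # Runs on inductor's post-grad ATen graph; a real pass (e.g. vLLM's
    # pass manager) would rewrite nodes here instead of just counting them.
    print("post-grad graph has", len(graph.nodes), "nodes")


# One way to install the hook globally; note that newer PyTorch versions
# may expect the pass to implement torch._inductor.custom_graph_pass.
# CustomGraphPass (as vLLM's InductorPass does) so it can be hashed for
# the fx graph cache.
inductor_config.post_grad_custom_post_pass = my_post_grad_pass

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend="inductor")
compiled(torch.randn(2, 8))  # compilation runs the pass
```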
What the custom passes do
The following are the custom passes that currently exist in vLLM. For various reasons, they are all disabled by default. (According to the developers, this is due to the V0-to-V1 transition: "This is a relic of the V1 upgrade, we're working on enabling them by default.") The related issue has more information and the community's current progress.
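If you want to experiment with them anyway, they can be switched on through the compilation config's pass_config. The flag names below are my reading of CompilationConfig.PassConfig and may differ between vLLM versions, so check vllm/config.py for the current set.

```python
from vllm import LLM

# Assumed flag names (enable_noop / enable_fusion); verify against your
# installed vLLM version before relying on them.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    compilation_config={
        "pass_config": {
            "enable_noop": True,
            "enable_fusion": True,
        },
    },
)
```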
All of these passes live in vllm/compilation; below I briefly pick a few to introduce.
FusionPass
- attention + quant
- norm + quant
- compute + comm
These all use the pattern matcher approach described above to replace certain operations with custom ops or ops from third-party libraries (such as flashinfer).
NoOpEliminationPass
This pass removes redundant reshape/slice operations.
For example, this chain of reshapes:

```python
mul_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
view_2: "f16[s0, 4096]" = torch.reshape(view_1, [-1, 4096])
view_3: "f16[s0, 128, 32]" = torch.reshape(view_2, [-1, 128, 32])
```

is reduced to:

```python
mul_1: "f16[s0, 4096]" = ...
view_3: "f16[s0, 128, 32]" = ...
```
```python
class NoOpEliminationPass(VllmInductorPass):
    """
    This is an inductor pass that removes redundant reshape/slice operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case. Additionally, torch internal no-op elimination pass does
    not handle certain slice variants.

    Cases handled:
    1. A chain of reshapes is equivalent to the last reshape called on the
       base tensor (input of the first reshape).
    2. A reshape that produces the shape of the input is redundant
    3. A slice that produces the shape of the input is redundant
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph):
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                # Case 1: rewrite reshape chains to reshapes on the base tensor
                input = node.args[0]
                # If the input is a reshape, rebind to that node
                if is_func(input, torch.ops.aten.reshape.default):
                    # The new input is guaranteed not to be a reshape,
                    # because we process nodes in order
                    node.update_arg(0, input.args[0])
                    if len(input.users) == 0:
                        graph.erase_node(input)
                    count += 1

                # Case 2: remove this reshape if it produces the original shape
                input, shape = node.args[:2]
                input_shape = input.meta["val"].shape
                if len(shape) != len(input_shape):
                    # Reshape changing rank, skip
                    continue
                if shape.count(-1) > 1:
                    # Invalid reshape args, skip
                    continue
                if self.reshape_all_dims_equivalent(shape, input_shape):
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice.Tensor):
                # python slicing semantics are different from reshape
                # Don't treat -1 as inferred dimension
                input, dim_index, start, end = node.args[:4]
                input_shape = input.meta["val"].shape
                output_shape = node.meta["val"].shape
                if output_shape == input_shape:
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice_scatter.default):
                base, view, dim_index, start, end = node.args[:5]
                base_shape = base.meta["val"].shape
                view_shape = view.meta["val"].shape
                if base_shape == view_shape:
                    node.replace_all_uses_with(view)
                    graph.erase_node(node)
                    count += 1

        logger.debug("Removed %s no-op reshapes and slices", count)
```