4 changes: 2 additions & 2 deletions tests/compile/distributed/test_async_tp.py
@@ -326,7 +326,7 @@ def async_tp_pass_on_test_model(
     vllm_config = VllmConfig()
     vllm_config.compilation_config = CompilationConfig(
         pass_config=PassConfig(
-            enable_async_tp=True,
+            fuse_gemm_comms=True,
         ),
     )
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
@@ -413,7 +413,7 @@ def test_async_tp_pass_correctness(
         "mode": CompilationMode.VLLM_COMPILE,
         "compile_sizes": [2, 4, 8],
         "splitting_ops": [],
-        "pass_config": {"enable_async_tp": async_tp_enabled},
+        "pass_config": {"fuse_gemm_comms": async_tp_enabled},
     }

     async_tp_args = [
2 changes: 1 addition & 1 deletion tests/compile/distributed/test_fusion_all_reduce.py
@@ -295,7 +295,7 @@ def all_reduce_fusion_pass_on_test_model(
         )
     )
     vllm_config.compilation_config.pass_config = PassConfig(
-        enable_fi_allreduce_fusion=True, enable_noop=True
+        fuse_allreduce_rms=True, eliminate_noops=True
     )
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
     vllm_config.parallel_config.rank = local_rank  # Setup rank for debug path
16 changes: 8 additions & 8 deletions tests/compile/distributed/test_fusions_e2e.py
@@ -192,7 +192,7 @@ def test_attn_quant(
         splitting_ops=splitting_ops,
         # Common
         mode=CompilationMode.VLLM_COMPILE,
-        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
+        pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
         # Inductor caches custom passes by default as well via uuid
         inductor_compile_config={"force_disable_caches": True},
     )
@@ -282,9 +282,9 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
         # Common
         mode=CompilationMode.VLLM_COMPILE,
         pass_config=PassConfig(
-            enable_attn_fusion=True,
-            enable_noop=True,
-            enable_fi_allreduce_fusion=True,
+            fuse_attn_quant=True,
+            eliminate_noops=True,
+            fuse_allreduce_rms=True,
         ),
         # Inductor caches custom passes by default as well via uuid
         inductor_compile_config={"force_disable_caches": True},
@@ -384,10 +384,10 @@ def test_tp2_attn_quant_async_tp(
         # Common
         level=CompilationMode.VLLM_COMPILE,
         pass_config=PassConfig(
-            enable_attn_fusion=True,
-            enable_noop=True,
-            enable_sequence_parallelism=True,
-            enable_async_tp=True,
+            fuse_attn_quant=True,
+            eliminate_noops=True,
+            enable_sp=True,
+            fuse_gemm_comms=True,
         ),
         # Inductor caches custom passes by default as well via uuid
         inductor_compile_config={"force_disable_caches": True},
18 changes: 9 additions & 9 deletions tests/compile/distributed/test_sequence_parallelism.py
@@ -153,7 +153,7 @@ def ops_in_model_before(self):
         ]

     def ops_in_model(self):
-        if self.vllm_config.compilation_config.pass_config.enable_fusion:
+        if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
             return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
         elif RMSNorm.enabled():
             return [
@@ -183,7 +183,7 @@ def ops_in_model(self):
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("enable_fusion", [True, False])
+@pytest.mark.parametrize("fuse_norm_quant", [True, False])
 @pytest.mark.parametrize("dynamic", [False, True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 def test_sequence_parallelism_pass(
@@ -193,7 +193,7 @@ def test_sequence_parallelism_pass(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
-    enable_fusion: bool,
+    fuse_norm_quant: bool,
     dynamic: bool,
 ):
     num_processes = 2
@@ -211,7 +211,7 @@ def run_torch_spawn(fn, nprocs):
                 seq_len,
                 hidden_size,
                 dtype,
-                enable_fusion,
+                fuse_norm_quant,
                 dynamic,
             ),
             nprocs=nprocs,
@@ -229,7 +229,7 @@ def sequence_parallelism_pass_on_test_model(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
-    enable_fusion: bool,
+    fuse_norm_quant: bool,
     dynamic: bool,
 ):
     current_platform.seed_everything(0)
@@ -260,9 +260,9 @@ def sequence_parallelism_pass_on_test_model(
         cudagraph_mode=CUDAGraphMode.NONE,  # avoid piecewise warnings
         custom_ops=custom_ops_list,
         pass_config=PassConfig(
-            enable_sequence_parallelism=True,
-            enable_fusion=enable_fusion,
-            enable_noop=True,
+            enable_sp=True,
+            fuse_norm_quant=fuse_norm_quant,
+            eliminate_noops=True,
         ),
     )  # NoOp needed for fusion
     device_config = DeviceConfig(device=torch.device("cuda"))
@@ -297,7 +297,7 @@ def sequence_parallelism_pass_on_test_model(
         sequence_parallelism_pass,
     ]

-    if enable_fusion:
+    if fuse_norm_quant:
         fusion_pass = RMSNormQuantFusionPass(vllm_config)
         passes_for_backend.append(fusion_pass)

4 changes: 3 additions & 1 deletion tests/compile/fullgraph/test_full_graph.py
@@ -122,7 +122,9 @@ def test_full_graph(
         CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             custom_ops=["+rms_norm"],
-            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+            pass_config=PassConfig(
+                fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
+            ),
         ),
         *model_info,
     )
77 changes: 65 additions & 12 deletions tests/compile/test_config.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
+import logging
 from contextlib import nullcontext
 from unittest.mock import patch

@@ -10,8 +11,9 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
-from vllm.config.compilation import CompilationMode
+from vllm.config.compilation import CompilationMode, PassConfig
 from vllm.engine.arg_utils import EngineArgs
+from vllm.logger import _print_warning_once
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import _is_torch_equal_or_newer

@@ -191,7 +193,7 @@ def test_splitting_ops_dynamic():
     config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
@@ -206,7 +208,7 @@
     config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
             # work around for accessing all attntion ops
@@ -219,15 +221,15 @@
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             use_inductor_graph_partition=True,
-            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
     )
     # With inductor graph partition, attn_fusion and splitting_ops
     # work together. Default splitting_ops include attention ops.
     assert config.compilation_config.splitting_ops_contain_attention()
-    # enable_attn_fusion is directly supported under
+    # fuse_attn_quant is directly supported under
     # use_inductor_graph_partition=True, and cudagraph_mode
     # is unchanged.
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
@@ -301,7 +303,7 @@ def test_should_split():
     "cudagraph_capture_sizes",
     "max_cudagraph_capture_size",
     "tp_size",
-    "enable_sequence_parallelism",
+    "enable_sp",
     "max_num_batched_tokens",
     "cudagraph_mode",
     "expected_max_size",
@@ -339,7 +341,7 @@ def test_cudagraph_sizes_post_init(
     cudagraph_capture_sizes,
     max_cudagraph_capture_size,
     tp_size,
-    enable_sequence_parallelism,
+    enable_sp,
     max_num_batched_tokens,
     cudagraph_mode,
     expected_max_size,
@@ -355,11 +357,12 @@ def test_cudagraph_sizes_post_init(
     compilation_config = CompilationConfig(
         cudagraph_capture_sizes=cudagraph_capture_sizes,
         max_cudagraph_capture_size=max_cudagraph_capture_size,
-        pass_config={
-            "enable_sequence_parallelism": enable_sequence_parallelism,
-            "enable_fusion": True,
-            "enable_noop": True,
-        },
+        pass_config=PassConfig(
+            enable_sp=enable_sp,
+            fuse_norm_quant=True,
+            fuse_act_quant=True,
+            eliminate_noops=True,
+        ),
         cudagraph_mode=cudagraph_mode,
     )
     engine_args = EngineArgs(
@@ -375,3 +378,53 @@ def test_cudagraph_sizes_post_init(
         vllm_config.compilation_config.max_cudagraph_capture_size
         == expected_max_size
     )
+
+
+def test_pass_config_deprecation(caplog_vllm):
+    caplog_vllm.set_level(logging.WARNING)
+
+    # Clear cache to ensure warnings are re-issued
+    _print_warning_once.cache_clear()
+
+    # Test enable_fusion -> fuse_norm_quant, fuse_act_quant
+    caplog_vllm.clear()
+    config = PassConfig(enable_fusion=True)
+    assert "enable_fusion is deprecated" in caplog_vllm.text
+    assert config.fuse_norm_quant is True
+    assert config.fuse_act_quant is True
+    assert config.enable_fusion is None
+
+    # Test enable_attn_fusion -> fuse_attn_quant
+    caplog_vllm.clear()
+    config = PassConfig(enable_attn_fusion=True)
+    assert "enable_attn_fusion is deprecated" in caplog_vllm.text
+    assert config.fuse_attn_quant is True
+    assert config.enable_attn_fusion is None
+
+    # Test enable_noop -> eliminate_noops
+    caplog_vllm.clear()
+    config = PassConfig(enable_noop=True)
+    assert "enable_noop is deprecated" in caplog_vllm.text
+    assert config.eliminate_noops is True
+    assert config.enable_noop is None
+
+    # Test enable_sequence_parallelism -> enable_sp
+    caplog_vllm.clear()
+    config = PassConfig(enable_sequence_parallelism=True)
+    assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
+    assert config.enable_sp is True
+    assert config.enable_sequence_parallelism is None
+
+    # Test enable_async_tp -> fuse_gemm_comms
+    caplog_vllm.clear()
+    config = PassConfig(enable_async_tp=True)
+    assert "enable_async_tp is deprecated" in caplog_vllm.text
+    assert config.fuse_gemm_comms is True
+    assert config.enable_async_tp is None
+
+    # Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
+    caplog_vllm.clear()
+    config = PassConfig(enable_fi_allreduce_fusion=True)
+    assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
+    assert config.fuse_allreduce_rms is True
+    assert config.enable_fi_allreduce_fusion is None
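
The new `test_pass_config_deprecation` above pins down the alias behavior: passing a deprecated flag logs a warning, its value is carried over to the new flag(s), and the alias itself is reset to `None`. A minimal sketch of such a shim, assuming a plain dataclass and the standard `logging` module (the real `PassConfig` is vLLM's config dataclass and warns via `_print_warning_once`; `PassConfigSketch` and `_DEPRECATED_TO_NEW` are invented names for illustration):

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

# Old flag -> new flag(s), as exercised by test_pass_config_deprecation.
_DEPRECATED_TO_NEW = {
    "enable_fusion": ("fuse_norm_quant", "fuse_act_quant"),
    "enable_attn_fusion": ("fuse_attn_quant",),
    "enable_noop": ("eliminate_noops",),
    "enable_sequence_parallelism": ("enable_sp",),
    "enable_async_tp": ("fuse_gemm_comms",),
    "enable_fi_allreduce_fusion": ("fuse_allreduce_rms",),
}


@dataclass
class PassConfigSketch:
    # New-style flags.
    fuse_norm_quant: bool = False
    fuse_act_quant: bool = False
    fuse_attn_quant: bool = False
    eliminate_noops: bool = False
    enable_sp: bool = False
    fuse_gemm_comms: bool = False
    fuse_allreduce_rms: bool = False
    # Deprecated aliases; None means "not passed by the caller".
    enable_fusion: bool | None = None
    enable_attn_fusion: bool | None = None
    enable_noop: bool | None = None
    enable_sequence_parallelism: bool | None = None
    enable_async_tp: bool | None = None
    enable_fi_allreduce_fusion: bool | None = None

    def __post_init__(self) -> None:
        for old, new_names in _DEPRECATED_TO_NEW.items():
            value = getattr(self, old)
            if value is None:
                continue
            logger.warning(
                "%s is deprecated, use %s instead", old, " and ".join(new_names)
            )
            for new in new_names:
                setattr(self, new, value)
            # Reset the alias so only the new flags carry state, matching
            # the `config.<old> is None` assertions in the test.
            setattr(self, old, None)


# Example: the old spelling still works but is migrated and warned about.
cfg = PassConfigSketch(enable_fusion=True)
assert cfg.fuse_norm_quant and cfg.fuse_act_quant and cfg.enable_fusion is None
```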
6 changes: 5 additions & 1 deletion tests/compile/test_functionalization.py
@@ -223,7 +223,11 @@ def test_fix_functionalization(
         model_config=ModelConfig(dtype=dtype),
         compilation_config=CompilationConfig(
             custom_ops=["all"],
-            pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
+            pass_config=PassConfig(
+                fuse_norm_quant=do_fusion,
+                fuse_act_quant=do_fusion,
+                eliminate_noops=True,
+            ),
         ),
     )

4 changes: 3 additions & 1 deletion tests/compile/test_fusion.py
@@ -159,7 +159,9 @@ def test_fusion_rmsnorm_quant(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             custom_ops=custom_ops,
-            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+            pass_config=PassConfig(
+                fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
+            ),
         ),
     )
     with vllm.config.set_current_vllm_config(vllm_config):
2 changes: 1 addition & 1 deletion tests/compile/test_fusion_attn.py
@@ -373,7 +373,7 @@ def test_attention_quant_pattern(

     # Run model with attn fusion enabled
     vllm_config.compilation_config.pass_config = PassConfig(
-        enable_attn_fusion=True, enable_noop=True
+        fuse_attn_quant=True, eliminate_noops=True
     )
     with (
         set_current_vllm_config(vllm_config),
4 changes: 2 additions & 2 deletions tests/compile/test_noop_elimination.py
@@ -51,7 +51,7 @@ def forward(self, x):
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            pass_config=PassConfig(enable_noop=True),
+            pass_config=PassConfig(eliminate_noops=True),
         )
     )
     with vllm.config.set_current_vllm_config(vllm_config):
@@ -99,7 +99,7 @@ def forward(self, x):
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            pass_config=PassConfig(enable_noop=True),
+            pass_config=PassConfig(eliminate_noops=True),
         )
     )
     with vllm.config.set_current_vllm_config(vllm_config):
7 changes: 5 additions & 2 deletions tests/compile/test_pass_manager.py
@@ -64,8 +64,11 @@ def test_pass_manager_uuid(callable):

     # UUID should be different due to config change
     config2 = copy.deepcopy(config)
-    config2.compilation_config.pass_config.enable_fusion = (
-        not config2.compilation_config.pass_config.enable_fusion
+    config2.compilation_config.pass_config.fuse_norm_quant = (
+        not config2.compilation_config.pass_config.fuse_norm_quant
     )
+    config2.compilation_config.pass_config.fuse_act_quant = (
+        not config2.compilation_config.pass_config.fuse_act_quant
+    )
     pass_manager3 = PostGradPassManager()
     pass_manager3.configure(config2)
2 changes: 1 addition & 1 deletion tests/compile/test_qk_norm_rope_fusion.py
@@ -140,7 +140,7 @@ def test_qk_norm_rope_fusion(
         custom_ops=custom_ops,
         pass_config=PassConfig(
             enable_qk_norm_rope_fusion=True,
-            enable_noop=True,
+            eliminate_noops=True,
         ),
     ),
 )
2 changes: 1 addition & 1 deletion tests/compile/test_silu_mul_quant_fusion.py
@@ -168,7 +168,7 @@ def test_fusion_silu_and_mul_quant(
     compilation_config=CompilationConfig(
         mode=CompilationMode.VLLM_COMPILE,
         custom_ops=custom_ops,
-        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+        pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
     ),
 )

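
Taken together, the hunks above apply one rename mapping across the test suite: `enable_fusion` splits into `fuse_norm_quant` and `fuse_act_quant`, `enable_attn_fusion` becomes `fuse_attn_quant`, `enable_noop` becomes `eliminate_noops`, `enable_sequence_parallelism` becomes `enable_sp`, `enable_async_tp` becomes `fuse_gemm_comms`, and `enable_fi_allreduce_fusion` becomes `fuse_allreduce_rms`. A short usage sketch with the new names, using the same import paths as the tests (which flags to enable depends on your deployment; the inline comments reflect how the tests above pair flags with passes):

```python
from vllm.config import CompilationConfig, VllmConfig
from vllm.config.compilation import CompilationMode, PassConfig

# Enable the renamed fusion/elimination passes explicitly.
config = VllmConfig(
    compilation_config=CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(
            fuse_norm_quant=True,  # RMSNorm+quant fusion (was part of enable_fusion)
            fuse_act_quant=True,   # activation+quant fusion (was part of enable_fusion)
            fuse_attn_quant=True,  # was enable_attn_fusion
            eliminate_noops=True,  # was enable_noop; the fusion passes rely on it
        ),
    )
)
```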