test_flashinfer.py (forked from vllm-project/vllm)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass

import pytest
import torch

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8,
flashinfer_cutlass_moe_fp8,
register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
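
# Skip the whole module unless FlashInfer's CUTLASS fused-MoE kernel is
# available and the GPU is at least compute capability 9.0 (Hopper).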
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
90
):
pytest.skip(
"Supported for sm >= 90",
allow_module_level=True,
)
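
# Test parameter grids: expert count, router top-k, and (m, n, k) shapes,
# i.e. (num_tokens, intermediate_size, hidden_size).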
NUM_EXPERTS = [16]
TOP_KS = [1]
MNK_FACTORS = [
(256, 8192, 5120),
(127, 4096, 5120),
(10, 8192, 5120),
(10, 4096, 5120),
(1, 8192, 5120),
(1, 4096, 5120),
]
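
# Minimal vLLM config; the tests activate it via set_current_vllm_config().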
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
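

# Quantize each expert's weight tensor to fp8 with a per-tensor scale and
# return the stacked fp8 weights together with the inverted scales.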
def quant_fp8_per_tensor_batches(a):
num_batches = a.size(0)
a_quant = []
a_scales = []
for i in range(num_batches):
a_fp8, a_global_sf = input_to_float8(a[i])
a_global_sf = 1.0 / a_global_sf
a_quant.append(a_fp8)
a_scales.append(a_global_sf)
result_a_quant = torch.stack(a_quant)
result_a_scales = torch.stack(a_scales)
return result_a_quant, result_a_scales
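

# Bundle of quantized MoE tensors plus a mock layer populated with the
# weights, scales, and routing attributes used by the code paths under test.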
@dataclass
class TestData:
hidden_states: torch.Tensor
w13_quantized: torch.Tensor
w2_quantized: torch.Tensor
a1_scale: torch.Tensor
a2_scale: torch.Tensor
w13_weight_scale: torch.Tensor
w2_weight_scale: torch.Tensor
layer: torch.nn.Module
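
    # Create bf16 activations and expert weights, quantize them to fp8, and
    # assemble the mock layer described above.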
@staticmethod
def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, reorder: bool
) -> "TestData":
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
        # Quantize activations and weights to fp8 with per-tensor scales.
_, a1_scale = input_to_float8(hidden_states)
a1_scale = 1.0 / a1_scale
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
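
        # Bare nn.Module standing in for a quantized FusedMoE layer,
        # populated with the quantized weights and their scales.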
layer = torch.nn.Module()
layer.w13_weight = w13_quantized.clone()
layer.w2_weight = w2_quantized.clone()
layer.w13_input_scale = a1_scale
layer.w2_input_scale = a2_scale
layer.w13_weight_scale = w13_weight_scale
layer.w2_weight_scale = w2_weight_scale
register_moe_scaling_factors(layer)
# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if reorder:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
layer.custom_routing_function = Llama4MoE.custom_routing_function
layer.intermediate_size_per_partition = n
layer.ep_rank = 0
layer.local_num_experts = e
return TestData(
hidden_states=hidden_states,
w13_quantized=w13_quantized,
w2_quantized=w2_quantized,
a1_scale=a1_scale,
a2_scale=a2_scale,
w13_weight_scale=w13_weight_scale,
w2_weight_scale=w2_weight_scale,
layer=layer,
)
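

# Compare the reference fused_experts (Triton) path against the FlashInfer
# per-tensor-scale fp8 MoE kernel; this kernel additionally requires
# compute capability 10.0, hence the extra skip below.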
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
m: int,
n: int,
k: int,
e: int,
topk: int,
monkeypatch,
):
if not current_platform.has_device_capability(100):
pytest.skip("Test is only supported for sm >= 100")
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
top_k=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
per_act_token_quant=False,
)
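
        # Reference output from the standard fused_experts path.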
output = fused_experts(
td.hidden_states,
td.w13_quantized,
td.w2_quantized,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation="silu",
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
quant_config=quant_config,
)
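
        # Output from the FlashInfer per-tensor-scale fp8 MoE path under test.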
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
layer=td.layer,
hidden_states=td.hidden_states,
router_logits=score,
routing_bias=None,
global_num_experts=e,
top_k=topk,
num_expert_group=None,
topk_group=None,
apply_router_weight_on_input=True,
)
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
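

# Same reference comparison, but against the FlashInfer CUTLASS fused-MoE
# fp8 kernel; the tensors are built with reorder=False, so
# rotate_flashinfer_fp8_moe_weights is not applied.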
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_cutlass_moe_fp8_no_graph(
m: int,
n: int,
k: int,
e: int,
topk: int,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
top_k=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax",
)
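
        # Compared with the per-tensor test above, this config also carries
        # fused (weight_scale * input_scale) alphas and global scales for the
        # CUTLASS path.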
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
w2_scale=td.w2_weight_scale,
g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
a1_scale=td.a1_scale,
a1_gscale=td.a1_scale,
a2_scale=td.a2_scale,
a2_gscale=1.0 / td.a2_scale,
per_act_token_quant=False,
)
output = fused_experts(
td.hidden_states,
td.w13_quantized,
td.w2_quantized,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation="silu",
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
quant_config=quant_config,
)
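
        # Give the mock layer the extra attributes flashinfer_cutlass_moe_fp8
        # needs: dp_size, a quant-config getter, and quant_method pointing
        # back at the layer itself.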
td.layer.dp_size = 1
        def get_fused_moe_quant_config(layer: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config
td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
td.layer.quant_method = td.layer
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
td.hidden_states,
td.layer,
topk_weights,
topk_ids,
activation="silu",
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
)
torch.testing.assert_close(
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
)