Merge pull request #4 from DhruvaBansal00/gptq-marlin-refactor
Refactoring for maintainability
ElizaWszola authored Aug 22, 2024
2 parents 8f4648c + 315e3b6 commit 34bb5b0
Showing 13 changed files with 814 additions and 606 deletions.
76 changes: 41 additions & 35 deletions tests/kernels/test_moe.py
@@ -10,8 +10,9 @@
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import (fused_marlin_moe, fused_moe,
single_marlin_moe)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe_marlin import (
fused_moe_marlin, single_marlin_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
@@ -63,11 +64,11 @@ def test_fused_moe(
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -166,11 +167,11 @@ def test_fused_marlin_moe(

quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[0] = torch.eye(k, n, device='cuda', dtype=dtype)
w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

w_ref1_l = []
qweight1_l = []
@@ -218,27 +219,32 @@ def test_fused_marlin_moe(
g_idx2 = stack_and_dev(g_idx2_l)
sort_indices2 = stack_and_dev(sort_indices2_l)

score = torch.randn((m, e), device='cuda', dtype=dtype)
triton_output = fused_moe(a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False)
marlin_output = fused_marlin_moe(a,
qweight1,
qweight2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk,
renormalize=False,
w1_scale=scales1,
w2_scale=scales2)

assert (compute_max_diff(marlin_output, triton_output) < 4e-2)
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
marlin_output = fused_moe_marlin(
a,
qweight1,
qweight2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk,
renormalize=False,
w1_scale=scales1,
w2_scale=scales2,
num_bits=4,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2


# TODO: make sure this test works
@@ -275,8 +281,8 @@ def test_single_marlin_moe(

quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w = torch.randn((e, n, k), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

w_ref_l = []
qweights_l = []
@@ -300,7 +306,7 @@ def test_single_marlin_moe(
g_idx = stack_and_dev(g_idx_l)
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
@@ -311,4 +317,4 @@ def test_single_marlin_moe(
renormalize=False)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert (compute_max_diff(marlin_output, torch_output) < 1e-2)
assert compute_max_diff(marlin_output, torch_output) < 1e-2
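Editor's note: the call-site change exercised by the updated test can be summarized as below. This is a minimal sketch, not repo code; it assumes the activations, Marlin-packed expert weights, scales, g_idx, and sort-index tensors (a, score, qweight1/2, scales1/2, g_idx1/2, sort_indices1/2) have already been prepared exactly as in test_fused_marlin_moe above. Only the import path, the function name, and the new explicit num_bits argument differ.

# Old call (pre-refactor), exported from the fused_moe module:
from vllm.model_executor.layers.fused_moe import fused_marlin_moe

out_old = fused_marlin_moe(a, qweight1, qweight2, score,
                           g_idx1, g_idx2, sort_indices1, sort_indices2,
                           topk, renormalize=False,
                           w1_scale=scales1, w2_scale=scales2)

# New call (post-refactor), from the dedicated fused_moe_marlin module,
# with the quantization width passed explicitly:
from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin

out_new = fused_moe_marlin(a, qweight1, qweight2, score,
                           g_idx1, g_idx2, sort_indices1, sort_indices2,
                           topk, renormalize=False,
                           w1_scale=scales1, w2_scale=scales2, num_bits=4)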
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
@@ -304,7 +304,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
output = torch.empty((num_experts, size_k // 16, size_n * 2),
output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device,
dtype=b_q_weight.dtype)
for e in range(num_experts):
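Editor's note: the one-line change above generalizes the repacked width from the 4-bit-only value size_n * 2 to size_n * (num_bits // 2). A quick sanity check of the arithmetic, assuming each packed int32 word holds 32 / num_bits quantized values and each output row folds 16 rows of K; the helper below is hypothetical and exists only to illustrate the arithmetic.

def repacked_shape(size_k: int, size_n: int, num_bits: int) -> tuple:
    """Per-expert output shape of gptq_marlin_moe_repack (illustrative sketch)."""
    # Total words per expert = size_k * size_n * num_bits / 32; folding 16
    # K-rows into each output row leaves size_n * (num_bits // 2) columns.
    assert (size_k * size_n * num_bits) % 32 == 0
    return (size_k // 16, size_n * (num_bits // 2))

# 4-bit: matches the old hard-coded width, size_n * 2.
assert repacked_shape(4096, 14336, 4) == (256, 28672)
# 8-bit: twice as many words per row, size_n * 4.
assert repacked_shape(4096, 14336, 8) == (256, 57344)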
7 changes: 3 additions & 4 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,18 +1,17 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_marlin_moe,
single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe_marlin import (
fused_moe_marlin, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.triton_utils import HAS_TRITON

__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"fused_marlin_moe",
"fused_moe_marlin",
"single_marlin_moe",
]

if HAS_TRITON:

from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
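Editor's note: with the __init__.py change above, the package-level import path is unchanged; only the exported symbol is renamed and now resolved via the new fused_moe_marlin module. A minimal before/after, derived from the old and new __init__.py shown above:

# Before this PR:
from vllm.model_executor.layers.fused_moe import fused_marlin_moe, single_marlin_moe

# After this PR (same package, renamed symbol, re-exported from fused_moe_marlin):
from vllm.model_executor.layers.fused_moe import fused_moe_marlin, single_marlin_moe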
178 changes: 0 additions & 178 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -666,181 +666,3 @@ def fused_moe(
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)


def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
g_idx: torch.Tensor,
rand_perm: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
) -> torch.Tensor:
"""
This function computes a Marlin MoE MMM using weights w
and top-k gating mechanism. It is meant for testing and debugging.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w (torch.Tensor): The first set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w and w2. Defaults to False.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
M, K = hidden_states.shape
E = w.shape[0]
N = w.shape[2] // 2

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)

# This might not be an optimal config for a single MMM
get_config_func = functools.partial(try_get_optimal_moe_config,
w.shape,
w.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)

block_size_m = config['BLOCK_SIZE_M']

sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

max_workspace_size = (N // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)

intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, rand_perm, workspace, M, N, K, True, E, topk, block_size_m,
True, False)

return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)


def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
rand_perm1: torch.Tensor,
rand_perm2: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[
1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)

get_config_func = functools.partial(try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)

block_size_m = config['BLOCK_SIZE_M']

sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)

intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)

intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
block_size_m, True, False)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
block_size_m, False, True)

return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
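Editor's note: the deleted code above (relocated to fused_moe_marlin.py) encodes the shape bookkeeping of the two chained Marlin GEMMs. The sketch below is hypothetical helper code, not part of the repo; it only restates the relations implied by the deleted asserts, under the 4-bit packing the old code hard-coded (16 rows of K per packed row, two 4-bit values per int32 column).

def marlin_moe_shapes(hidden_shape, w1_packed_shape, w2_packed_shape, topk):
    # Trace the intermediate shapes implied by the deleted fused_marlin_moe.
    M, K = hidden_shape
    assert w1_packed_shape[1] * 16 == K     # same hidden-size check as above
    assert w2_packed_shape[2] // 2 == K
    N = w2_packed_shape[1] * 16

    gemm1 = (M * topk, 2 * N)   # first marlin_gemm_moe output, viewed flat
    act = (M * topk, N)         # silu_and_mul halves the last dimension
    gemm2 = (M * topk, K)       # second marlin_gemm_moe output
    out = (M, K)                # torch.sum over the top-k dimension
    return gemm1, act, gemm2, out

# 4 tokens, hidden size 512, intermediate size 128, 8 experts, top-2 routing:
print(marlin_moe_shapes((4, 512), (8, 32, 512), (8, 8, 1024), 2))
# ((8, 256), (8, 128), (8, 512), (4, 512))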