diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 88275dbdd83a..55e659679701 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   int64_t num_tokens = input.numel() / input.size(-1);                   \
   dim3 grid(num_tokens);                                                 \
   dim3 block(std::min(d, 1024));                                         \
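+  /* A grid with zero blocks is an invalid launch configuration, so return */ \
+  /* early when the input has no tokens (e.g. an idle data-parallel rank). */ \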
+  if (num_tokens == 0) {                                                 \
+    return;                                                              \
+  }                                                                      \
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
   VLLM_DISPATCH_FLOATING_TYPES(                                          \
diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h
index dc6e0769b878..f7b75c48373f 100644
--- a/csrc/dispatch_utils.h
+++ b/csrc/dispatch_utils.h
@@ -65,5 +65,19 @@
   AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
 
+#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)        \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
+
 #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                              \
+      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
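+
+// Example usage (see csrc/moe/moe_align_sum_kernels.cu): this dispatches on
+// the runtime dtype of `topk_ids`, now including the unsigned integer types,
+// and binds it to `scalar_t` inside the lambda passed via __VA_ARGS__.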
diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu
index d7be769458e3..6b6a9d04a60f 100644
--- a/csrc/moe/moe_align_sum_kernels.cu
+++ b/csrc/moe/moe_align_sum_kernels.cu
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   }
 
   if (use_global_memory) {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
           // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
           // tensors
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               cumsum_buffer.data_ptr<int32_t>());
         });
   } else if (use_i16) {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
           // set dynamic shared mem
           auto kernel =
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               topk_ids.numel());
         });
   } else {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
           auto kernel =
               vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   TORCH_CHECK(num_experts == 256,
               "sgl_moe_align_block_size kernel only supports deepseek v3.");
 
-  VLLM_DISPATCH_INTEGRAL_TYPES(
+  VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
       topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
         // calc needed amount of shared mem for `cumsum` tensors
         auto options_int =
diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu
index de9747b60252..a9379032245d 100644
--- a/csrc/moe/topk_softmax_kernels.cu
+++ b/csrc/moe/topk_softmax_kernels.cu
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
     }
 }
 
-template <int TPB>
-__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
-    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
+template <int TPB, typename IndType>
+__launch_bounds__(TPB) __global__ void moeTopK(
+    const float* inputs_after_softmax,
+    const bool* finished,
+    float* output,
+    IndType* indices,
+    int* source_rows,
+    const int num_experts,
+    const int k,
+    const int start_expert,
+    const int end_expert)
 {
 
     using cub_kvp = cub::KeyValuePair<int, float>;
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
   2) This implementation assumes k is small, but will work for any k.
 */
 
-template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
 __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
-    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
+    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
         int* source_rows, const int k, const int start_expert, const int end_expert)
 {
     // We begin by enforcing compile time assertions and setting up compile time constants.
@@ -397,8 +405,8 @@ struct TopkConstants
 };
 } // namespace detail
 
-template <int EXPERTS, int WARPS_PER_TB>
-void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
+template <int EXPERTS, int WARPS_PER_TB, typename IndType>
+void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
     int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
     static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
         token_expert_indices, num_tokens, topk, 0, num_experts,         \
         stream);
 
+template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
     const float* gating_output,
     float* topk_weights,
-    int* topk_indicies,
+    IndType* topk_indicies,
     int* token_expert_indices,
     float* softmax_workspace,
     const int num_tokens,
@@ -493,14 +502,32 @@ void topk_softmax(
     const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
-    vllm::moe::topkGatingSoftmaxKernelLauncher(
-        gating_output.data_ptr<float>(),
-        topk_weights.data_ptr<float>(),
-        topk_indices.data_ptr<int>(),
-        token_expert_indices.data_ptr<int>(),
-        softmax_workspace.data_ptr<float>(),
-        num_tokens,
-        num_experts,
-        topk,
-        stream);
+
+    if (topk_indices.scalar_type() == at::ScalarType::Int)
+    {
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr<float>(),
+            topk_weights.data_ptr<float>(),
+            topk_indices.data_ptr<int>(),
+            token_expert_indices.data_ptr<int>(),
+            softmax_workspace.data_ptr<float>(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
+    else
+    {
+        TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::UInt32);
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr<float>(),
+            topk_weights.data_ptr<float>(),
+            topk_indices.data_ptr<uint32_t>(),
+            token_expert_indices.data_ptr<int>(),
+            softmax_workspace.data_ptr<float>(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
 }
diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py
index 965915beaf58..f636a08c0b09 100644
--- a/examples/offline_inference/data_parallel.py
+++ b/examples/offline_inference/data_parallel.py
@@ -65,11 +65,17 @@ def parse_args():
                         type=int,
                         default=0,
                         help="Master node port")
+    parser.add_argument("--enforce-eager",
+                        action='store_true',
+                        help="Enforce eager mode execution.")
+    parser.add_argument("--trust-remote-code",
+                        action='store_true',
+                        help="Trust remote code.")
     return parser.parse_args()
 
 
 def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
-         dp_master_port, GPUs_per_dp_rank):
+         dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
     os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
                                      max_tokens=[16, 20][global_dp_rank % 2])
 
     # Create an LLM.
-    llm = LLM(model=model,
-              tensor_parallel_size=GPUs_per_dp_rank,
-              enforce_eager=True,
-              enable_expert_parallel=True)
+    llm = LLM(
+        model=model,
+        tensor_parallel_size=GPUs_per_dp_rank,
+        enforce_eager=enforce_eager,
+        enable_expert_parallel=True,
+        trust_remote_code=trust_remote_code,
+    )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for i, output in enumerate(outputs):
@@ -155,7 +164,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         proc = Process(target=main,
                        args=(args.model, dp_size, local_dp_rank,
                              global_dp_rank, dp_master_ip, dp_master_port,
-                             tp_size))
+                             tp_size, args.enforce_eager,
+                             args.trust_remote_code))
         proc.start()
         procs.append(proc)
     exit_code = 0
diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py
new file mode 100644
index 000000000000..5de87fa581e9
--- /dev/null
+++ b/tests/kernels/moe/test_batched_moe.py
@@ -0,0 +1,107 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+
+import pytest
+import torch
+import triton.language as tl
+
+from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+    invoke_moe_batched_triton_kernel)
+
+
+@dataclass
+class BatchedMMConfig:
+    dtype: torch.dtype
+    num_experts: int
+    max_tokens_per_expert: int
+    K: int
+    N: int
+
+
+@dataclass
+class BatchedMMTensors:
+    A: torch.Tensor  # [E, max_tokens, K]
+    B: torch.Tensor  # [E, K, N] - column major
+    C: torch.Tensor  # [E, max_tokens, N]
+    num_expert_tokens: torch.Tensor  # [E]
+
+    @staticmethod
+    def make_tensors(config: BatchedMMConfig):
+        A = torch.randn(
+            (config.num_experts, config.max_tokens_per_expert, config.K),
+            device="cuda",
+            dtype=config.dtype) / 50.0
+        B = torch.randn((config.num_experts, config.N, config.K),
+                        device="cuda",
+                        dtype=config.dtype) / 50.0
+        C = torch.zeros(
+            (config.num_experts, config.max_tokens_per_expert, config.N),
+            device="cuda",
+            dtype=config.dtype)
+        num_expert_tokens = torch.randint(low=0,
+                                          high=config.max_tokens_per_expert,
+                                          size=(config.num_experts, ),
+                                          device="cuda",
+                                          dtype=torch.int32)
+        return BatchedMMTensors(A, B, C, num_expert_tokens)
+
+
+def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
+             num_expert_tokens: torch.Tensor) -> torch.Tensor:
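+    # Reference batched matmul: for each expert e, multiply only the first
+    # num_expert_tokens[e] rows of A[e] by B[e]^T; the padded rows of C are
+    # left untouched.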
+
+    num_expert_tokens_cpu = num_expert_tokens.clone()
+    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
+    num_experts = num_expert_tokens.size(0)
+
+    for e in range(num_experts):
+        num_tokens = num_expert_tokens_cpu[e]
+        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
+
+    return C
+
+
+@pytest.mark.parametrize("num_experts", [16, 32])
+@pytest.mark.parametrize("max_tokens_per_expert",
+                         [32, 64, 128, 192, 224, 256, 512])
+@pytest.mark.parametrize("K", [128, 256, 1024])
+@pytest.mark.parametrize("N", [128, 256, 512, 1024])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
+                    N: int, dtype: torch.dtype):
+
+    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
+    tensors = BatchedMMTensors.make_tensors(config)
+
+    test_output = tensors.C
+    ref_output = test_output.clone()
+
+    compute_tl_dtype = {
+        torch.float16: tl.float16,
+        torch.bfloat16: tl.bfloat16,
+        torch.float32: tl.float32
+    }[test_output.dtype]
+    invoke_moe_batched_triton_kernel(
+        tensors.A,
+        tensors.B,
+        test_output,
+        tensors.num_expert_tokens,
+        compute_tl_dtype,
+        # Quantization data
+        None,
+        None,
+        None,
+        # Quantization schemes
+        False,
+        False,
+        False,
+        config={
+            "BLOCK_SIZE_M": 16,
+            "BLOCK_SIZE_N": 16,
+            "BLOCK_SIZE_K": 16
+        })
+
+    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
+                          tensors.num_expert_tokens)
+
+    torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3)
\ No newline at end of file
diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py
index 975cd418a171..7db4fe0f46e3 100644
--- a/tests/kernels/moe/test_cutlass_moe.py
+++ b/tests/kernels/moe/test_cutlass_moe.py
@@ -30,6 +30,11 @@
     (224, 3072, 1536),
 ]
 
+vllm_config = VllmConfig(parallel_config=ParallelConfig(
+    pipeline_parallel_size=1))
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 
 @dataclasses.dataclass
 class MOETensors:
@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'w1_q': moe_tensors.w1_q.transpose(1, 2),  # type: ignore[union-attr]
         'w2_q': moe_tensors.w2_q.transpose(1, 2),  # type: ignore[union-attr]
         'topk_weights': topk_weights,
-        'topk_ids_': topk_ids,
+        'topk_ids': topk_ids,
         'ab_strides1': moe_tensors.ab_strides1,
         'c_strides1': moe_tensors.c_strides1,
         'ab_strides2': moe_tensors.ab_strides2,
@@ -231,18 +236,15 @@ def test_cutlass_moe_8_bit_no_graph(
     per_out_ch: bool,
 ):
     current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
-
+    with set_current_vllm_config(vllm_config):
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_ch)
 
         score = torch.randn((m, e), device="cuda", dtype=torch.half)
-        topk_weights, topk_ids = fused_topk(mt.a,
-                                            score,
-                                            topk,
-                                            renormalize=False)
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                               score,
+                                               topk,
+                                               renormalize=False)
 
         # Note that we are using the dequantized versions of the tensors.
         # Using a, w1 and w2 directly results in minor output differences.
@@ -276,20 +278,17 @@ def test_cutlass_moe_8_bit_cuda_graph(
     per_out_ch: bool,
 ):
     current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
-
+    with set_current_vllm_config(vllm_config):
         dtype = torch.half
 
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_ch)
 
         score = torch.randn((m, e), device="cuda", dtype=dtype)
-        topk_weights, topk_ids = fused_topk(mt.a,
-                                            score,
-                                            topk,
-                                            renormalize=False)
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                               score,
+                                               topk,
+                                               renormalize=False)
 
         # Note that we are using the dequantized versions of the tensors.
         # Using a, w1 and w2 directly results in minor output differences.
@@ -334,18 +333,15 @@ def test_cutlass_moe_8_bit_EP(
     ep_size: int,
 ):
     current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
-
+    with set_current_vllm_config(vllm_config):
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_channel)
 
         score = torch.randn((m, e), device="cuda", dtype=torch.half)
-        topk_weights, topk_ids = fused_topk(mt.a,
-                                            score,
-                                            topk,
-                                            renormalize=False)
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                               score,
+                                               topk,
+                                               renormalize=False)
 
         # Note that we are using the dequantized versions of the tensors.
         # Using a, w1 and w2 directly results in minor output differences.
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index abf3e3667a75..befd07075c93 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -12,6 +12,7 @@
 
 import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
@@ -30,6 +31,10 @@
 EP_SIZE = [1, 4]
 TOP_KS = [2, 6]
 
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 
 @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
@@ -68,31 +73,33 @@ def test_fused_moe(
     else:
         e_map = None
 
-    torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    iterative_output = iterative_moe(a,
-                                     w1,
-                                     w2,
-                                     score,
-                                     topk,
-                                     global_num_experts=e,
-                                     expert_map=e_map,
-                                     renormalize=False)
+    with set_current_vllm_config(vllm_config):
+        torch_output = torch_moe(a, w1, w2, score, topk, e_map)
+        iterative_output = iterative_moe(a,
+                                         w1,
+                                         w2,
+                                         score,
+                                         topk,
+                                         global_num_experts=e,
+                                         expert_map=e_map,
+                                         renormalize=False)
+
+        # Pad the weight if moe padding is enabled
+        if padding:
+            w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
+            torch.cuda.empty_cache()
+            w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
+            torch.cuda.empty_cache()
+
+        triton_output = fused_moe(a,
+                                  w1,
+                                  w2,
+                                  score,
+                                  topk,
+                                  global_num_experts=e,
+                                  expert_map=e_map,
+                                  renormalize=False)
 
-    # Pad the weight if moe padding is enabled
-    if padding:
-        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
-        torch.cuda.empty_cache()
-        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
-        torch.cuda.empty_cache()
-
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
@@ -113,7 +120,7 @@ def test_fused_moe(
 def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                         ep_size: int, dtype: torch.dtype, group_size: int,
                         has_zp: bool, weight_bits: int):
-    print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
+    # print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@@ -192,22 +199,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
     else:
         e_map = None
 
-    triton_output = fused_moe(a,
-                              w1_qweight,
-                              w2_qweight,
-                              score,
-                              topk,
-                              renormalize=False,
-                              use_int4_w4a16=weight_bits == 4,
-                              use_int8_w8a16=weight_bits == 8,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              w1_scale=w1_scales,
-                              w2_scale=w2_scales,
-                              w1_zp=w1_qzeros if has_zp else None,
-                              w2_zp=w2_qzeros if has_zp else None,
-                              block_shape=[0, group_size])
-    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
+    with set_current_vllm_config(vllm_config):
+        triton_output = fused_moe(a,
+                                  w1_qweight,
+                                  w2_qweight,
+                                  score,
+                                  topk,
+                                  renormalize=False,
+                                  use_int4_w4a16=weight_bits == 4,
+                                  use_int8_w8a16=weight_bits == 8,
+                                  global_num_experts=e,
+                                  expert_map=e_map,
+                                  w1_scale=w1_scales,
+                                  w2_scale=w2_scales,
+                                  w1_zp=w1_qzeros if has_zp else None,
+                                  w2_zp=w2_qzeros if has_zp else None,
+                                  block_shape=[0, group_size])
+        torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
+
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
 
 
@@ -438,7 +447,8 @@ def test_fused_marlin_moe(
     topk_weights, topk_ids, token_expert_indices = fused_topk(
         a, score, topk, False)
 
-    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
+    with set_current_vllm_config(vllm_config):
+        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
 
     marlin_output = torch.ops.vllm.fused_marlin_moe(
         a,
diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py
new file mode 100644
index 000000000000..542f03a01a1e
--- /dev/null
+++ b/tests/kernels/moe/test_pplx_moe.py
@@ -0,0 +1,620 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the MOE layers.
+
+Run `pytest tests/kernels/moe/test_pplx_moe.py`.
+"""
+import dataclasses
+import os
+import traceback
+from typing import Callable, Optional
+
+import pytest
+import torch
+
+try:
+    from pplx_kernels import AllToAll
+    from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
+                                      nvshmem_finalize, nvshmem_get_unique_id,
+                                      nvshmem_init)
+    has_pplx = True
+except ImportError:
+    has_pplx = False
+
+from torch.multiprocessing import (
+    spawn)  # pyright: ignore[reportPrivateImportUsage]
+from typing_extensions import Concatenate, ParamSpec
+
+import vllm.model_executor.layers.fused_moe  # noqa
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+    BatchedDispatchCombine, BatchedExperts)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
+from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import (
+    PplxDispatchCombine)
+from vllm.platforms import current_platform
+
+NUM_EXPERTS = [8, 64]
+EP_SIZE = [1, 4]
+TOP_KS = [2, 6]
+
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
+P = ParamSpec("P")
+
+requires_pplx = pytest.mark.skipif(
+    not has_pplx,
+    reason="Requires PPLX kernels",
+)
+
+
+@dataclasses.dataclass
+class ProcessGroupInfo:
+    world_size: int
+    world_local_size: int
+    rank: int
+    node_rank: int
+    local_rank: int
+    device: torch.device
+
+
+def _worker_parallel_launch(
+    local_rank: int,
+    world_size: int,
+    world_local_size: int,
+    node_rank: int,
+    init_method: str,
+    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> None:
+    rank = node_rank * world_local_size + local_rank
+    torch.cuda.set_device(local_rank)
+    device = torch.device("cuda", local_rank)
+    torch.distributed.init_process_group(
+        backend="cpu:gloo,cuda:nccl",
+        init_method=init_method,
+        rank=rank,
+        world_size=world_size,
+        device_id=device,
+    )
+    barrier = torch.tensor([rank], device=device)
+    torch.distributed.all_reduce(barrier)
+
+    try:
+        worker(
+            ProcessGroupInfo(
+                world_size=world_size,
+                world_local_size=world_local_size,
+                rank=rank,
+                node_rank=node_rank,
+                local_rank=local_rank,
+                device=device,
+            ),
+            *args,
+            **kwargs,
+        )
+    except Exception as ex:
+        print(ex)
+        traceback.print_exc()
+        raise
+    finally:
+        torch.distributed.destroy_process_group()
+
+
+def parallel_launch(
+    world_size: int,
+    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> None:
+    assert not kwargs
+    spawn(
+        _worker_parallel_launch,
+        args=(
+            world_size,
+            world_size,
+            0,
+            "tcp://localhost:29500",
+            worker,
+        ) + args,
+        nprocs=world_size,
+        join=True,
+    )
+
+
+def parallel_launch_from_env(
+    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> None:
+    """
+    Launches a worker function in parallel across all processes in the current
+    environment. The environment must have the following variables set:
+    - WORLD_SIZE: The total number of processes.
+    - WORLD_LOCAL_SIZE: The number of processes on the current node.
+    - NODE_RANK: The rank of the current node.
+    - MASTER_ADDR: The address of the master process.
+    - MASTER_PORT: The port of the master process.
+    """
+    assert not kwargs
+    world_size = int(os.environ["WORLD_SIZE"])
+    world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
+    node_rank = int(os.environ["NODE_RANK"])
+    assert "MASTER_ADDR" in os.environ
+    assert "MASTER_PORT" in os.environ
+    spawn(
+        _worker_parallel_launch,
+        args=(
+            world_size,
+            world_local_size,
+            node_rank,
+            "env://",
+            worker,
+        ) + args,
+        nprocs=world_local_size,
+        join=True,
+    )
+
+
+def torch_dispatch(
+    a: torch.Tensor,
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    max_num_tokens: Optional[int] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
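+    # Reference dispatch: copy each token's activations into a per-expert
+    # [num_experts, max_num_tokens, hidden_dim] buffer, one slot per selected
+    # expert, preserving the order in which tokens are visited.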
+    assert topk_ids.dim() == 2
+    assert topk_ids.shape[0] == a.shape[0]
+
+    num_tokens, hidden_dim = a.shape
+    topk = topk_ids.shape[1]
+
+    tokens_per_expert = torch.bincount(topk_ids.view(-1),
+                                       minlength=num_experts)
+
+    assert tokens_per_expert.numel() == num_experts
+
+    if max_num_tokens is None:
+        max_num_tokens = int(tokens_per_expert.max().item())
+
+    b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
+                      dtype=a.dtype,
+                      device=a.device)
+
+    token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
+
+    for token in range(num_tokens):
+        for j in range(topk):
+            expert_id = topk_ids[token, j]
+            idx = token_counts[expert_id]
+            b_a[expert_id, idx:idx + 1, :] = a[token, :]
+            token_counts[expert_id] = token_counts[expert_id] + 1
+
+    return b_a, tokens_per_expert
+
+
+def torch_combine(b_out, topk_weight, topk_ids):
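+    # Reference combine: walk tokens in the same order as torch_dispatch and
+    # accumulate each token's per-expert outputs weighted by topk_weight.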
+    num_tokens = topk_ids.shape[0]
+    num_experts = b_out.shape[0]
+    K = b_out.shape[-1]
+    out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
+    expert_counts = torch.zeros(num_experts,
+                                dtype=torch.int,
+                                device=b_out.device)
+    for token in range(num_tokens):
+        expert_ids = topk_ids[token]
+        for i in range(expert_ids.numel()):
+            expert_id = expert_ids[i]
+            idx = expert_counts[expert_id]
+            out[token, :] = out[token, :] + b_out[expert_id, idx:idx +
+                                                  1, :] * topk_weight[token, i]
+            expert_counts[expert_id] = expert_counts[expert_id] + 1
+
+    return out
+
+
+def torch_batched_moe(a, w1, w2, topk_weight, topk_ids):
+    num_experts = w1.shape[0]
+    b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts)
+    assert b_a.dim() == 3
+    num_tokens, topk = topk_ids.shape
+    _, max_num_tokens, K = b_a.shape
+    assert num_experts == b_a.shape[0] and w2.shape[1] == K
+    out = torch.zeros((num_experts, max_num_tokens, K),
+                      dtype=b_a.dtype,
+                      device=b_a.device)
+    tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
+                      dtype=b_a.dtype,
+                      device=b_a.device)
+    for expert in range(num_experts):
+        num = tokens_per_expert[expert]
+        if num > 0:
+            torch.ops._C.silu_and_mul(
+                tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
+            out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)
+
+    return torch_combine(out, topk_weight, topk_ids)
+
+
+def batched_moe(a, w1, w2, topk_weight, topk_ids):
+    num_experts = w1.shape[0]
+
+    fused_experts = FusedMoEModularKernel(
+        BatchedDispatchCombine(a.shape[0], world_size=1, dp_size=1, rank=0),
+        BatchedExperts(a.shape[0]))
+
+    return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
+
+
+# TODO: same as torch_moe but with fused_topk factored out.
+def torch_moe2(a, w1, w2, topk_weight, topk_ids):
+    M, K = a.shape
+    topk = topk_ids.shape[1]
+    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
+    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+    num_experts = w1.shape[0]
+    for i in range(num_experts):
+        mask = (topk_ids == i).view(-1)
+        if mask.sum():
+            out[mask] = SiluAndMul()(
+                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
+
+    return (out.view(M, -1, w2.shape[1]) *
+            topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
+
+
+@pytest.mark.parametrize("m", [1, 33, 64, 222])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+def test_fused_moe_batched_experts(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+):
+    current_platform.seed_everything(7)
+
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    with set_current_vllm_config(vllm_config):
+        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
+        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
+        torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
+        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
+
+    torch.testing.assert_close(baseline_output,
+                               torch_output,
+                               atol=2e-2,
+                               rtol=0)
+    torch.set_printoptions(profile="full")
+    torch.testing.assert_close(baseline_output,
+                               batched_output,
+                               atol=2e-2,
+                               rtol=0)
+
+
+def rank_chunk(num, r, w):
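+    # Number of rows rank r receives when num rows are split across w ranks;
+    # the first (num % w) ranks get one extra row.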
+    rem = num % w
+    return (num // w) + (1 if r < rem else 0)
+
+
+def chunk_by_rank(t, r, w):
+    chunk = rank_chunk(t.shape[0], r, w)
+    return t[(r * chunk):(r + 1) * chunk]
+
+
+def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
+    assert torch.cuda.current_device() == pgi.local_rank
+
+    topk = topk_ids.shape[1]
+    num_tokens, hidden_dim = a.shape
+    block_size = 128
+    device = pgi.device
+    rank = pgi.rank
+    world_size = pgi.world_size
+    max_num_tokens = rank_chunk(num_tokens, 0, world_size)
+
+    ata = AllToAll.internode(
+        max_num_tokens=max_num_tokens,
+        num_experts=num_experts,
+        experts_per_token=topk,
+        rank=rank,
+        world_size=world_size,
+        dp_size=dp_size,
+        hidden_dim=hidden_dim,
+        hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
+        hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
+                                ((hidden_dim + block_size - 1) // block_size *
+                                 torch.float32.itemsize)),
+    )
+
+    topk_ids = topk_ids.to(dtype=torch.uint32)
+
+    dispatch_combine = PplxDispatchCombine(
+        ata,
+        max_num_tokens,
+        world_size,
+        rank,
+        dp_size,
+        a.dtype,
+    )
+
+    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
+    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
+    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
+
+    b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
+        a_chunk,
+        None,
+        None,
+        chunk_topk_weight,
+        chunk_topk_ids,
+        num_experts,
+        None,
+        False,
+    )
+
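+    # Scale the dispatched activations so the dispatch+combine round trip is
+    # observable; the reference in _pplx_dispatch_combine applies the same
+    # 1.5 factor.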
+    b_a = b_a * 1.5
+
+    out = torch.full(
+        (max_num_tokens, hidden_dim),
+        torch.nan,
+        dtype=a.dtype,
+        device=device,
+    )
+
+    dispatch_combine.combine(
+        out,
+        b_a,
+        chunk_topk_weight,
+        chunk_topk_ids,
+        False,
+    )
+
+    torch.cuda.synchronize()
+
+    ata.destroy()
+
+    num_tokens = a_chunk.shape[0]
+
+    return out[:num_tokens]
+
+
+def _pplx_dispatch_combine(
+    pgi: ProcessGroupInfo,
+    dp_size: int,
+    a,
+    score,
+    topk,
+    num_experts,
+):
+    uid = nvshmem_get_unique_id(
+    ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
+    torch.distributed.broadcast(uid, src=0)
+    nvshmem_init(uid, pgi.rank, pgi.world_size)
+    device = pgi.device
+
+    topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
+    k = a.shape[1]
+
+    a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)
+
+    torch_output = (a_rep.view(-1, topk, k) * 1.5 *
+                    topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
+                        a.dtype)
+
+    pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids,
+                                        num_experts)
+
+    torch_output = chunk_by_rank(torch_output, pgi.rank,
+                                 pgi.world_size).to(pplx_output.device)
+
+    torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
+
+    nvshmem_finalize()
+
+
+# TODO: this test point does not work for M == 1
+@pytest.mark.parametrize("m", [4, 32, 64, 222])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("world_dp_size", [[2, 1]])
+@requires_pplx
+def test_pplx_dispatch_combine(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    world_dp_size: tuple[int, int],
+):
+    current_platform.seed_everything(7)
+    world_size, dp_size = world_dp_size
+    device = "cuda"
+    a = torch.randn((m, k), device=device, dtype=dtype) / 10
+    score = torch.randn((m, e), device=device, dtype=dtype)
+
+    parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, score,
+                    topk, e)
+
+
+def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
+    assert torch.cuda.current_device() == pgi.local_rank
+
+    hidden_dim = a.shape[1]
+    num_experts = w1.shape[0]
+    block_size = 128
+    device = pgi.device
+    rank = pgi.rank
+    world_size = pgi.world_size
+    topk = topk_ids.shape[1]
+    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
+
+    ata = AllToAll.internode(
+        max_num_tokens=max_num_tokens,
+        num_experts=num_experts,
+        experts_per_token=topk,
+        rank=rank,
+        world_size=world_size,
+        dp_size=dp_size,
+        hidden_dim=hidden_dim,
+        hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
+        hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
+                                ((hidden_dim + block_size - 1) // block_size *
+                                 torch.float32.itemsize)),
+    )
+
+    topk_ids = topk_ids.to(dtype=torch.uint32)
+
+    dispatch_combine = PplxDispatchCombine(
+        ata,
+        max_num_tokens,
+        world_size,
+        rank,
+        dp_size,
+    )
+
+    experts = BatchedExperts(a.shape[0])
+
+    fused_experts = FusedMoEModularKernel(
+        dispatch_combine,
+        experts,
+    )
+
+    # TODO: workers with the same dp_rank must use the exact same inputs.
+
+    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
+    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
+    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
+
+    out = fused_experts(
+        a_chunk,
+        # Chunking weights like this only works for batched format
+        chunk_by_rank(w1, rank, world_size).to(device),
+        chunk_by_rank(w2, rank, world_size).to(device),
+        chunk_topk_weight,
+        chunk_topk_ids,
+        global_num_experts=num_experts)
+
+    torch.cuda.synchronize()
+
+    ata.destroy()
+
+    return out
+
+
+def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
+    assert torch.cuda.current_device() == pgi.local_rank
+
+    num_experts = w1.shape[0]
+    device = pgi.device
+    rank = pgi.rank
+    world_size = pgi.world_size
+    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
+
+    dispatch_combine = BatchedDispatchCombine(
+        max_num_tokens=max_num_tokens,
+        world_size=world_size,
+        dp_size=dp_size,
+        rank=rank,
+    )
+
+    experts = BatchedExperts(a.shape[0])
+
+    fused_experts = FusedMoEModularKernel(
+        dispatch_combine,
+        experts,
+    )
+
+    # TODO: workers with the same dp_rank must use the exact same inputs.
+
+    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
+    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
+    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
+
+    out = fused_experts(
+        a_chunk,
+        # Chunking weights like this only works for batched format
+        chunk_by_rank(w1, rank, world_size).to(device),
+        chunk_by_rank(w2, rank, world_size).to(device),
+        chunk_topk_weight,
+        chunk_topk_ids,
+        global_num_experts=num_experts)
+
+    return out
+
+
+def _pplx_moe(
+    pgi: ProcessGroupInfo,
+    dp_size: int,
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    score: torch.Tensor,
+    topk: int,
+):
+    uid = nvshmem_get_unique_id(
+    ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
+    torch.distributed.broadcast(uid, src=0)
+    nvshmem_init(uid, pgi.rank, pgi.world_size)
+
+    m, k = a.shape
+    e, _, n = w2.shape
+
+    with set_current_vllm_config(vllm_config):
+        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
+        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
+        pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
+        batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
+                                      topk_ids)
+
+    torch_output = chunk_by_rank(torch_output, pgi.rank,
+                                 pgi.world_size).to(pplx_output.device)
+
+    torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
+    torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
+
+    nvshmem_finalize()
+
+
+@pytest.mark.parametrize("m", [1, 2, 3, 32, 45, 64, 222])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("world_dp_size", [[2, 1]])
+@requires_pplx
+def test_pplx_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    world_dp_size: tuple[int, int],
+):
+    current_platform.seed_everything(7)
+    world_size, dp_size = world_dp_size
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
index 44734e9340aa..3b5838a99fa1 100644
--- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
+++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
@@ -7,6 +7,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.platforms import current_platform
@@ -15,6 +16,10 @@
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
                 allow_module_level=True)
 
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 
 def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
     """Matrix multiplication function that supports per-token input
@@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
     w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
     score = torch.randn((M, E), dtype=dtype)
 
-    ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
-    out = fused_moe(
-        a,
-        w1,
-        w2,
-        score,
-        topk,
-        renormalize=False,
-        use_fp8_w8a8=True,  # using fp8
-        per_channel_quant=True,
-        w1_scale=w1_s,
-        w2_scale=w2_s,
-        block_shape=None,  # Not using block quantization
-    )
+    with set_current_vllm_config(vllm_config):
+        ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
+        out = fused_moe(
+            a,
+            w1,
+            w2,
+            score,
+            topk,
+            renormalize=False,
+            use_fp8_w8a8=True,  # using fp8
+            per_channel_quant=True,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+            block_shape=None,  # Not using block quantization
+        )
 
     # Check results
     rel_diff = (torch.mean(
diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py
index 38c7e461bb9c..ef1d7e47ef81 100644
--- a/tests/kernels/quantization/test_block_fp8.py
+++ b/tests/kernels/quantization/test_block_fp8.py
@@ -11,7 +11,7 @@
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
-    deep_gemm_moe_fp8)
+    _valid_deep_gemm_shape, deep_gemm_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     moe_align_block_size)
@@ -30,6 +30,10 @@
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
                 allow_module_level=True)
 
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 # Test configurations
 DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
 NUM_TOKENS = [7, 83, 2048]
@@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
     score = torch.randn((M, E), dtype=dtype)
 
     # Set the context to avoid lots of warning spam.
-    vllm_config = VllmConfig()
     with set_current_vllm_config(vllm_config):
         out = fused_moe(
             a,
@@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
 @pytest.mark.parametrize(
     "M,N,K,block_size,out_dtype,seed",
     itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
+@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     # only aligned sizes
@@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
     block_size = [block_m, block_m]
     dtype = torch.bfloat16
 
-    # only aligned sizes
-    if (N % block_m != 0 or K % block_m != 0 or topk > E):
-        pytest.skip(
-            f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
-
-    if N <= 512:
-        pytest.skip("Skipping N <= 512 until performance issues solved.")
+    if topk > E:
+        pytest.skip(f"Skipping test: topk={topk} > E={E}")
 
-    vllm_config = VllmConfig()
+    if not _valid_deep_gemm_shape(M, N, K):
+        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
 
     torch.manual_seed(seed)
     fp8_info = torch.finfo(torch.float8_e4m3fn)
diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py
index 104f23fd7cd2..a4e9f83f0eaf 100644
--- a/tests/kernels/quantization/test_block_int8.py
+++ b/tests/kernels/quantization/test_block_int8.py
@@ -18,6 +18,10 @@
     pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
                 allow_module_level=True)
 
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 
 # For test
 def native_per_token_group_quant_int8(x,
@@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
     score = torch.randn((M, E), dtype=dtype)
 
     # Set the context to avoid lots of warning spam.
-    vllm_config = VllmConfig()
     with set_current_vllm_config(vllm_config):
         out = fused_moe(
             a,
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index cb9658ce1004..10db10236f6f 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -23,6 +23,7 @@
 """
 import contextlib
 import gc
+import importlib.util
 import pickle
 import weakref
 from collections import namedtuple
@@ -42,7 +43,7 @@
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
-                        supports_custom_op)
+                        run_once, supports_custom_op)
 
 
 @dataclass
@@ -912,6 +913,45 @@ def init_distributed_environment(
             "world group already initialized with a different world size")
 
 
+PPLX_DID_INIT: bool = False
+
+
+@run_once
+def pplx_init(rank, world_size):
+    has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+
+    if has_pplx and world_size > 1:
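+        # nvshmem bootstrap: rank 0 creates the unique id and every rank must
+        # receive it (broadcast over the world group) before nvshmem_init.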
+        from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
+                                          nvshmem_get_unique_id, nvshmem_init)
+        try:
+            global PPLX_DID_INIT
+            logger.info("PPLX_INIT rank=%d world=%d", rank, world_size)
+            uid = nvshmem_get_unique_id(
+            ) if rank == 0 else nvshmem_alloc_empty_unique_id()
+            uid_gpu = uid.cuda()
+            get_world_group().broadcast(uid_gpu, src=0)
+            logger.debug("PPLX_INIT UID = %s", uid_gpu)
+            uid = uid_gpu.to(device='cpu')
+            nvshmem_init(uid, rank, world_size)
+            PPLX_DID_INIT = True
+        except Exception as ex:
+            logger.error("Failed to initialize nvshmem for pplx: %s", ex)
+
+
+@run_once
+def pplx_finalize():
+    global PPLX_DID_INIT
+    if PPLX_DID_INIT:
+        from pplx_kernels.nvshmem import nvshmem_finalize
+
+        from vllm.model_executor.layers.fused_moe.layer import (
+            _all_to_all_cache)
+        _all_to_all_cache.destroy()
+
+        logger.info("PPLX finalize")
+        nvshmem_finalize()
+
+
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
@@ -1006,6 +1046,8 @@ def initialize_model_parallel(
         "DP rank %s, PP rank %s, TP rank %s", rank, world_size,
         _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
 
+    pplx_init(rank, world_size)
+
 
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
@@ -1081,6 +1123,9 @@ def get_tensor_model_parallel_rank():
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
     global _TP
+
+    pplx_finalize()
+
     if _TP:
         _TP.destroy()
     _TP = None
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index e4d4008cd0a6..18459f982c0a 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -360,7 +360,12 @@ def stateless_destroy_torch_distributed_process_group(
     Destroy ProcessGroup returned by
         stateless_init_torch_distributed_process_group().
     """
-    # Lazy import for non-CUDA backends.
-    from torch.distributed.distributed_c10d import _shutdown_backend
-    _shutdown_backend(pg)
+    try:
+        # pytorch < 2.7
+        # Lazy import for non-CUDA backends.
+        from torch.distributed.distributed_c10d import _shutdown_backend
+        _shutdown_backend(pg)
+    except Exception:
+        # pytorch >= 2.7
+        pg.shutdown()
+
     _unregister_process_group(pg.group_name)
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 9ddc3d1f2c51..f3732f83af5c 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -31,6 +31,7 @@
 
 @dataclass
 class DPMetadata:
+    max_tokens_across_dp_cpu: torch.Tensor
     cu_tokens_across_dp_cpu: torch.Tensor
 
 
@@ -94,8 +95,10 @@ def set_forward_context(attn_metadata: Any,
                                          dtype=torch.int32)
         from vllm.distributed.parallel_state import get_dp_group
         dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
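+        # After the all-reduce every rank holds the per-rank token counts, so
+        # both the max and the cumulative sums across DP ranks can be computed
+        # locally.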
+        max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
-        dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
+        dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
+                                 cu_tokens_across_dp_cpu)
 
     global _forward_context
     prev_context = _forward_context
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 9829ccdb384f..b55a3ede12d3 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -38,8 +38,8 @@ def get_config() -> Optional[Dict[str, Any]]:
     from vllm.model_executor.layers.fused_moe.cutlass_moe import (
         cutlass_moe_fp8)
     from vllm.model_executor.layers.fused_moe.fused_moe import (
-        fused_experts, fused_moe, fused_topk, get_config_file_name,
-        grouped_topk)
+        TritonExperts, fused_experts, fused_moe, fused_topk,
+        get_config_file_name, grouped_topk)
 
     __all__ += [
         "fused_moe",
@@ -48,4 +48,5 @@ def get_config() -> Optional[Dict[str, Any]]:
         "get_config_file_name",
         "grouped_topk",
         "cutlass_moe_fp8",
+        "TritonExperts",
     ]
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 960c7f834857..e52751eddf2c 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -1,10 +1,199 @@
 # SPDX-License-Identifier: Apache-2.0
 """Fused MoE kernel."""
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.fused_moe.dispatch_combine import (
+    StandardDispatchCombine)
+from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
+
+
+class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(
+        self,
+        ab_strides1: torch.Tensor,
+        c_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides2: torch.Tensor,
+        out_dtype: torch.dtype,
+    ):
+        super().__init__()
+        self.ab_strides1 = ab_strides1
+        self.c_strides1 = c_strides1
+        self.ab_strides2 = ab_strides2
+        self.c_strides2 = c_strides2
+        self.out_dtype = out_dtype
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+    ) -> Tuple[int, int, torch.dtype]:
+        # Note that K, N are transposed
+        N, K = K, N
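+        # workspace1 is shared by the first grouped GEMM output (2 * N wide)
+        # and the final output (K wide); workspace2 holds the activation
+        # output (N wide).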
+        workspace1 = M * topk * max(2 * N, K)
+        workspace2 = M * topk * N
+        return (workspace1, workspace2, self.out_dtype)
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_num_tokens: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        a1q = hidden_states
+
+        assert w1_scale is not None
+        assert w2_scale is not None
+        assert w1.dtype == torch.float8_e4m3fn
+        assert w2.dtype == torch.float8_e4m3fn
+        assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1"
+        assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2"
+        assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
+        assert a1q_scale is None or a1q_scale.dim(
+        ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
+            0], "Input scale shape mismatch"
+        assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
+            1] == w1.shape[2], "W1 scale shape mismatch"
+        assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
+            1] == w2.shape[2], "W2 scale shape mismatch"
+        assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
+        assert w1.shape[0] == w1_scale.shape[
+            0], "w1 scales expert number mismatch"
+        assert w1.shape[0] == w2_scale.shape[
+            0], "w2 scales expert number mismatch"
+        assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501
+        assert self.ab_strides1.shape[0] == w1.shape[
+            0], "AB Strides 1 expert number mismatch"
+        assert self.c_strides1.shape[0] == w1.shape[
+            0], "C Strides 1 expert number mismatch"
+        assert self.ab_strides2.shape[0] == w2.shape[
+            0], "AB Strides 2 expert number  mismatch"
+        assert self.c_strides2.shape[0] == w2.shape[
+            0], "C Strides 2 expert number mismatch"
+        assert self.out_dtype in [torch.half,
+                                  torch.bfloat16], "Invalid output dtype"
+
+        M = a1q.shape[0]
+        _, N, K = w2.shape  # because w1 + w2 are transposed
+        device = a1q.device
+
+        assert w1.shape[1] == K
+        assert global_num_experts != -1
+        assert a1q_scale is not None
+
+        if expert_map is not None:
+            "Translate info from expert_map to topk_ids"
+            local_topk_ids = torch.where(expert_map[topk_ids] != -1,
+                                         expert_map[topk_ids], -1)
+        else:
+            local_topk_ids = topk_ids
+
+        topk = local_topk_ids.shape[1]
+
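+        # A single-element activation scale implies per-tensor quantization;
+        # otherwise the scales are per token (per row).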
+        per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else (
+            a2_scale.numel() != 1 if a2_scale is not None else False)
+
+        expert_offsets = torch.empty((global_num_experts + 1),
+                                     dtype=torch.int32,
+                                     device=device)
+        problem_sizes1 = torch.empty((global_num_experts, 3),
+                                     dtype=torch.int32,
+                                     device=device)
+        problem_sizes2 = torch.empty((global_num_experts, 3),
+                                     dtype=torch.int32,
+                                     device=device)
+
+        # With expert_map, each rank processes only a subset of experts. As
+        # a result, not all of the a_map and c3 tensors are filled. We fill
+        # them with zeros for correctness.
+        if expert_map is not None:
+            a_map = torch.zeros((local_topk_ids.numel()),
+                                dtype=torch.int32,
+                                device=device)
+        else:
+            a_map = torch.empty((local_topk_ids.numel()),
+                                dtype=torch.int32,
+                                device=device)
+
+        c_map = torch.empty((local_topk_ids.numel()),
+                            dtype=torch.int32,
+                            device=device)
+
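+        # Compute per-expert offsets, grouped GEMM problem sizes, and the
+        # permutation maps (a_map / c_map) that group tokens by expert.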
+        ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
+                                    problem_sizes1, problem_sizes2, a_map,
+                                    c_map, global_num_experts, N, K)
+
+        a1q = _fp8_perm(a1q, a_map)
+        a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
+
+        c1 = _resize_cache(workspace13, (M * topk, N * 2))
+        c2 = _resize_cache(workspace2, (M * topk, N))
+        c3 = _resize_cache(workspace13, (M * topk, K))
+
+        ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale,
+                           expert_offsets[:-1], problem_sizes1,
+                           self.ab_strides1, self.ab_strides1, self.c_strides1)
+
+        self.activation(activation, c2, c1)
+
+        a2q, a2q_scale = ops.scaled_fp8_quant(
+            c2, a2_scale, use_per_token_if_dynamic=per_act_token)
+
+        if expert_map is not None:
+            c3.fill_(0)
+
+        ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale,
+                           expert_offsets[:-1], problem_sizes2,
+                           self.ab_strides2, self.ab_strides2, self.c_strides2)
+
+        c3 = c3[c_map, ...]
+
+        return c3
+
+
+def modular_cutlass_moe_fp8(
+    per_act_token: bool,
+    ab_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    out_dtype: torch.dtype = torch.half,
+) -> mk.FusedMoEModularKernel:
+    return mk.FusedMoEModularKernel(
+        StandardDispatchCombine(
+            per_channel_quant=per_act_token,
+            quant_dtype=torch.float8_e4m3fn,
+        ),
+        CutlassExperts(
+            ab_strides1,
+            c_strides1,
+            ab_strides2,
+            c_strides2,
+            out_dtype,
+        ),
+    )
 
 
 #TODO make the grouped gemm kernel consistent with scaled gemm kernel
@@ -15,7 +204,7 @@ def cutlass_moe_fp8(
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
     topk_weights: torch.Tensor,
-    topk_ids_: torch.Tensor,
+    topk_ids: torch.Tensor,
     ab_strides1: torch.Tensor,
     c_strides1: torch.Tensor,
     ab_strides2: torch.Tensor,
@@ -57,7 +246,7 @@ def cutlass_moe_fp8(
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
         quantize the intermediate result between the gemms.
         Shape: scalar or [M]
-    - out_dtype (torch.Tensor): The output tensor type.
+    - out_dtype (torch.dtype): The output tensor type.
     - expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
         every Rank is responsible for a subset of experts. expert_map is a
         mapping from global expert-id to local expert-id. When expert_map[i]
@@ -69,112 +258,28 @@ def cutlass_moe_fp8(
     Returns:
     - torch.Tensor: The fp16 output tensor after applying the MoE layer.
     """
-
-    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
-    assert w1_q.dtype == torch.float8_e4m3fn
-    assert w2_q.dtype == torch.float8_e4m3fn
-    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
-    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
-    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
-    assert a1_scale is None or a1_scale.dim(
-    ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
-        0], "Input scale shape mismatch"
-    assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
-        1] == w1_q.shape[2], "W1 scale shape mismatch"
-    assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
-        1] == w2_q.shape[2], "W2 scale shape mismatch"
-    assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
-    assert w1_q.shape[0] == w1_scale.shape[
-        0], "w1 scales expert number mismatch"
-    assert w1_q.shape[0] == w2_scale.shape[
-        0], "w2 scales expert number mismatch"
-    assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501
-    assert ab_strides1.shape[0] == w1_q.shape[
-        0], "AB Strides 1 expert number mismatch"
-    assert c_strides1.shape[0] == w1_q.shape[
-        0], "C Strides 1 expert number mismatch"
-    assert ab_strides2.shape[0] == w2_q.shape[
-        0], "AB Strides 2 expert number  mismatch"
-    assert c_strides2.shape[0] == w2_q.shape[
-        0], "C Strides 2 expert number mismatch"
-    assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
-
-    num_experts = w1_q.size(0)
-    m = a.size(0)
-    k = w1_q.size(1)
-    n = w2_q.size(1)
-
-    local_topk_ids = topk_ids_
-    if expert_map is not None:
-        "Translate info from expert_map to topk_ids"
-        local_topk_ids = torch.where(expert_map[topk_ids_] != -1,
-                                     expert_map[topk_ids_], -1)
-
-    topk = local_topk_ids.size(1)
-
     per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
         a2_scale.numel() != 1 if a2_scale is not None else False)
-    if apply_router_weight_on_input:
-        assert topk == 1, \
-            "apply_router_weight_on_input is only implemented for topk=1"
-        # TODO: this only works for topK=1, will need to update for topK>1
-        a = a * topk_weights.to(out_dtype)
-
-    a_q, a1_scale = ops.scaled_fp8_quant(
-        a, a1_scale, use_per_token_if_dynamic=per_act_token)
-    device = a_q.device
-
-    expert_offsets = torch.empty((num_experts + 1),
-                                 dtype=torch.int32,
-                                 device=device)
-    problem_sizes1 = torch.empty((num_experts, 3),
-                                 dtype=torch.int32,
-                                 device=device)
-    problem_sizes2 = torch.empty((num_experts, 3),
-                                 dtype=torch.int32,
-                                 device=device)
-
-    a_map_initializer = torch.empty
-    c2_initializer = torch.empty
-    if expert_map is not None:
-        # With expert_map each Rank processes only a subset of experts. As
-        # a result not all of a_map and c2 tensors are filled. We fill it
-        # zeros for correctness.
-        a_map_initializer = torch.zeros
-        c2_initializer = torch.zeros
-
-    a_map = a_map_initializer((local_topk_ids.numel()),
-                              dtype=torch.int32,
-                              device=device)
-    c_map = torch.empty((local_topk_ids.numel()),
-                        dtype=torch.int32,
-                        device=device)
-
-    ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1,
-                                problem_sizes2, a_map, c_map, num_experts, n,
-                                k)
-
-    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
-    rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
-
-    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
-    c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype)
-
-    ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
-                       expert_offsets[:-1], problem_sizes1, ab_strides1,
-                       ab_strides1, c_strides1)
-
-    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
-    torch.ops._C.silu_and_mul(intermediate, c1)
-
-    intemediate_q, a2_scale = ops.scaled_fp8_quant(
-        intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
-
-    ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
-                       expert_offsets[:-1], problem_sizes2, ab_strides2,
-                       ab_strides2, c_strides2)
-    # Gather tokens
-    c2 = c2[c_map].view(m, topk, k)
-    if not apply_router_weight_on_input:
-        c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
-    return c2.sum(dim=1)
+
+    fn = modular_cutlass_moe_fp8(
+        per_act_token,
+        ab_strides1,
+        c_strides1,
+        ab_strides2,
+        c_strides2,
+        out_dtype,
+    )
+
+    return fn(
+        a,
+        w1_q,
+        w2_q,
+        topk_weights,
+        topk_ids,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+    )
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index 353c8cc9d59f..4a0fb374bd41 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -1,16 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
+import functools
 import importlib.util
 from typing import Optional, Tuple
 
 import torch
 
-import vllm.envs as envs
-from vllm import _custom_ops as ops
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
-    moe_align_block_size)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
-                                                        _fp8_quantize,
+from vllm.model_executor.layers.fused_moe.dispatch_combine import (
+    StandardDispatchCombine)
+from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
+    _moe_permute)
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
                                                         _resize_cache)
 from vllm.utils import round_up
 
@@ -19,6 +20,19 @@
 has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
 
 
+@functools.cache
+def deep_gemm_block_shape() -> list[int]:
+    # Lazy import to avoid CUDA initialization problems.
+    import deep_gemm as dg
+    block = dg.get_m_alignment_for_contiguous_layout()
+    return [block, block]
+
+
+def _valid_deep_gemm_shape(M: int, N: int, K: int):
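+    # M must cover at least one block; N and K must be multiples of the
+    # DeepGemm block alignment.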
+    align = deep_gemm_block_shape()[0]
+    return align <= M and N % align == 0 and K % align == 0
+
+
 def _valid_deep_gemm(hidden_states: torch.Tensor,
                      w1: torch.Tensor,
                      w2: torch.Tensor,
@@ -29,89 +43,120 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
     aligned by `dg.get_m_alignment_for_contiguous_layout()`.
     """
     if not has_deep_gemm:
+        logger.debug("DeepGemm disabled: deep_gemm not available.")
         return False
 
-    # Lazy import to avoid CUDA initialization problems.
-    import deep_gemm as dg
-
-    # Expert maps not supported yet.
     if expert_map is not None:
+        logger.debug("DeepGemm disabled: expert map NYI.")
         return False
 
-    align = dg.get_m_alignment_for_contiguous_layout()
     M = hidden_states.shape[0]
     _, K, N = w2.shape
-
-    # For now, disable DeepGemm for small N until better permute/unpermute
-    # ops are available.
-    if N <= 512:
+    if not _valid_deep_gemm_shape(M, N, K):
+        logger.debug("DeepGemm disabled: unalinged problem size.")
         return False
 
-    if align > M or N % align != 0 or K % align != 0:
+    if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):
+        logger.debug("DeepGemm disabled: invalid weight dtype(s).")
         return False
 
-    return (hidden_states.is_contiguous() and w1.is_contiguous()
-            and w2.is_contiguous())
-
+    if (not hidden_states.is_contiguous() or not w1.is_contiguous()
+            or not w2.is_contiguous()):
+        logger.debug(
+            "DeepGemm disabled: weights or activations not contiguous.")
+        return False
 
-def _moe_permute(
-    curr_hidden_states: torch.Tensor,
-    a1q_scale: Optional[torch.Tensor],
-    curr_topk_ids: torch.Tensor,
-    global_num_experts: int,
-    expert_map: Optional[torch.Tensor],
-    block_m: int,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
-           Optional[torch.Tensor]]:
-    """
-    Determine the sorted_token_ids, expert_ids for the given problem size.
-    Permute the hidden states and scales according to `sorted_token_ids`.
-    """
-    top_k_num = curr_topk_ids.shape[1]
+    return True
+
+
+class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self):
+        super().__init__()
+        self.block_shape = deep_gemm_block_shape()
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+    ) -> Tuple[int, int, torch.dtype]:
+        block_m = self.block_shape[0]
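+        # Worst-case padded row count: each expert's chunk can add up to
+        # (block_m - 1) rows of padding on top of the M * topk routed rows.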
+        M_sum = (M * topk) + num_experts * (block_m - 1)
+        M_sum = round_up(M_sum, block_m)
+        workspace1 = M_sum * max(N * 2, K)
+        workspace2 = M_sum * N
+        return (workspace1, workspace2, a.dtype)
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_num_tokens: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        import deep_gemm as dg
+
+        a1q = hidden_states
+        _, N, K = w1.shape
+
+        assert global_num_experts != -1
+        assert w2.shape[1] == K
+
+        a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
+            a1q,
+            a1q_scale,
+            topk_ids,
+            global_num_experts,
+            expert_map,
+            self.block_shape[0],
+        )
+
+        # Note: M_sum differs from the pre-permuted shape of a1q.
+        M_sum = a1q.shape[0]
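+        # workspace1 and workspace3 alias workspace13; workspace1 is no
+        # longer needed by the time the second grouped GEMM fills workspace3.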
+        workspace1 = _resize_cache(workspace13, (M_sum, N))
+        workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
+        workspace3 = _resize_cache(workspace13, (M_sum, K))
 
-    tokens_in_chunk, _ = curr_hidden_states.shape
+        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
 
-    sorted_token_ids, expert_ids, num_tokens_post_padded = (
-        moe_align_block_size(curr_topk_ids,
-                             block_m,
-                             global_num_experts,
-                             expert_map,
-                             pad_sorted_ids=True))
+        self.activation(activation, workspace2, workspace1.view(-1, N))
 
-    inv_perm: Optional[torch.Tensor] = None
+        a2q_scale: Optional[torch.Tensor] = None
 
-    num_tokens = top_k_num * tokens_in_chunk
-    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
-    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
-    inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
+        a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False,
+                                       self.block_shape)
 
-    # Permute according to sorted token ids.
-    curr_hidden_states = _fp8_perm(curr_hidden_states,
-                                   sorted_token_ids // top_k_num)
+        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
 
-    if a1q_scale is not None:
-        a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
+        workspace3 = workspace3[inv_perm, ...]
 
-    return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
-            inv_perm)
+        return workspace3
 
 
-def _moe_unpermute_and_reduce(
-    out: torch.Tensor,
-    curr_hidden: torch.Tensor,
-    inv_perm: Optional[torch.Tensor],
-    topk_weight: torch.Tensor,
-) -> None:
-    """
-    Unpermute the final result and apply topk_weights, then perform the final
-    reduction on the hidden states.
-    """
-    M, topk = topk_weight.shape
-    K = curr_hidden.shape[1]
-    curr_hidden = curr_hidden[inv_perm, ...]
-    curr_hidden = curr_hidden.view(-1, topk, K)
-    curr_hidden.mul_(topk_weight.view(M, -1, 1))
-    ops.moe_sum(curr_hidden, out)
+def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel:
+    return mk.FusedMoEModularKernel(
+        StandardDispatchCombine(quant_dtype=torch.float8_e4m3fn,
+                                block_shape=deep_gemm_block_shape()),
+        DeepGemmExperts(),
+    )
 
 
 def deep_gemm_moe_fp8(
@@ -128,6 +173,7 @@ def deep_gemm_moe_fp8(
     expert_map: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input=False,
 ) -> torch.Tensor:
     """
     This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@@ -166,129 +212,20 @@ def deep_gemm_moe_fp8(
     Returns:
     - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
     """
-    # Lazy import to avoid CUDA initialization problems.
-    import deep_gemm as dg
-
-    assert expert_map is None, "Expert maps not supported yet"
-
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
-
-    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
-    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-    assert w1.dtype == torch.float8_e4m3fn
-    assert w2.dtype == torch.float8_e4m3fn
-    assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
-    assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
-    assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
-    assert a1_scale is None or a1_scale.dim(
-    ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
-        0] == hidden_states.shape[0], "Input scale shape mismatch"
-    assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501
-
-    num_tokens, _ = hidden_states.shape
-    E, N, _ = w1.shape
-    K = w2.shape[1]
-    if global_num_experts == -1:
-        global_num_experts = E
-
-    # We execute the fused_moe kernel in chunks to circumvent this issue:
-    # https://github.com/vllm-project/vllm/issues/5938
-    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
-
-    assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
-
-    if inplace:
-        out_hidden_states = hidden_states
-    else:
-        out_hidden_states = torch.empty_like(hidden_states)
-
-    block_m = dg.get_m_alignment_for_contiguous_layout()
-    block_shape = [block_m, block_m]
-
-    assert w1_scale is not None
-    assert w2_scale is not None
-
-    # We attempt to transpose and align offline in Fp8MoEMethod, in which
-    # case these calls will be nops.  Otherwise, they'll be performed every
-    # time the layer is executed.
-    w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
-    w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
-
-    M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
-    M_sum = round_up(M_sum, block_m)
-
-    num_chunks = (num_tokens // CHUNK_SIZE) + 1
-
-    # We can reuse the memory between cache1 and cache3 because by the time
-    # we need cache3, we're done with cache1
-    workspace13 = torch.empty(M_sum * max(N, K),
-                              device=hidden_states.device,
-                              dtype=hidden_states.dtype)
-
-    workspace1 = workspace13[:M_sum * N].view(M_sum, N)
-    workspace2 = torch.empty((M_sum, N // 2),
-                             device=hidden_states.device,
-                             dtype=hidden_states.dtype)
-    workspace3 = workspace13[:M_sum * K].view(M_sum, K)
-
-    for chunk in range(num_chunks):
-        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
-                                          min((chunk + 1) * CHUNK_SIZE,
-                                              num_tokens))
-        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
-        tokens_in_chunk, _ = curr_hidden_states.shape
-
-        if tokens_in_chunk == 0:
-            break
-
-        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
-        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
-
-        a1q_scale: Optional[torch.Tensor] = None
-
-        qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
-                                                       a1_scale, block_shape)
-
-        (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
-         inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
-                                  curr_topk_ids, global_num_experts,
-                                  expert_map, block_m)
-
-        # Adjust the intermediate cache size and config for the last chunk.
-        # Note that in most cases we only have one chunk so the cache size
-        # and config are already set correctly and do not need to be adjusted.
-        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
-            curr_M = sorted_token_ids.numel()
-            workspace1 = _resize_cache(workspace1, (curr_M, N))
-            workspace2 = _resize_cache(workspace2, (curr_M, N // 2))
-            workspace3 = _resize_cache(workspace3, (curr_M, K))
-
-        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1,
-            expert_ids)
-
-        if activation == "silu":
-            torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N))
-        elif activation == "gelu":
-            torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N))
-        else:
-            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
-
-        a2q_scale: Optional[torch.Tensor] = None
-
-        qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale,
-                                               block_shape)
-
-        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
-
-        _moe_unpermute_and_reduce(
-            out_hidden_states[begin_chunk_idx:end_chunk_idx],
-            workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights)
-
-    return out_hidden_states
+    fn = modular_deep_gemm_fused_moe_fp8()
+    return fn(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace,
+        activation,
+        global_num_experts,
+        expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+    )
diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py
new file mode 100644
index 000000000000..9b647a70d5e0
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
+    _moe_unpermute_and_reduce)
+from vllm.model_executor.layers.fused_moe.utils import (
+    moe_kernel_quantize_input)
+
+
+class StandardDispatchCombine(mk.FusedMoEQuantizeDispatchCombine):
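+    """
+    A simple (non-distributed) dispatch/combine: dispatch optionally applies
+    the router weights and quantizes the activations; combine performs the
+    topk-weighted unpermute-and-reduce of the fused expert output.
+    """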
+
+    def __init__(
+        self,
+        quant_dtype: Optional[torch.dtype] = None,
+        per_channel_quant: bool = False,
+        block_shape: Optional[list[int]] = None,
+    ):
+        super().__init__()
+        self.per_channel_quant = per_channel_quant
+        self.block_shape = block_shape
+        self.quant_dtype = quant_dtype
+
+    def dispatch(
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if apply_router_weight_on_input:
+            topk = topk_ids.shape[1]
+            # TODO: this only works for topK=1, will need to update for topK>1
+            assert topk == 1, \
+                "apply_router_weight_on_input is only implemented for topk=1"
+            a1.mul_(topk_weights.to(a1.dtype))
+
+        a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
+                                                   self.quant_dtype,
+                                                   self.per_channel_quant,
+                                                   self.block_shape)
+
+        return a1q, a1q_scale, None
+
+    def combine(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+    ) -> None:
+        _moe_unpermute_and_reduce(output, fused_expert_output, None,
+                                  topk_weights, apply_router_weight_on_input)
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
new file mode 100644
index 000000000000..47739375b141
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -0,0 +1,740 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Fused batched MoE kernel."""
+from typing import List, Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    get_config_dtype_str, try_get_optimal_moe_config)
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+
+
+@triton.jit
+def moe_mmk(
+        a_ptrs,
+        b_ptrs,
+        K,
+        expert_id,
+        a_scale_ptr,
+        b_scale_ptr,
+        # The stride variables represent how much to increase the ptr by when
+        # moving by 1 element in a particular dimension. E.g. `stride_am` is
+        # how much to increase `a_ptr` by to get the element one row down
+        # (A has M rows).
+        stride_ak,
+        stride_bk,
+        stride_asm,
+        stride_ask,
+        stride_bse,
+        stride_bsk,
+        stride_bsn,
+        # Offsets and masks
+        offs_m,
+        offs_n,
+        mask_m,
+        # Block size for block-wise quantization
+        group_n: tl.constexpr,
+        group_k: tl.constexpr,
+        # Meta-parameters
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        BLOCK_K: tl.constexpr,
+        compute_type: tl.constexpr,
+        use_w8a8: tl.constexpr,
+        use_w8a16: tl.constexpr):
+
+    offs_k = tl.arange(0, BLOCK_K)
+
+    if use_w8a16:
+        b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[
+            None, :] * stride_bsn
+        b_scale = tl.load(b_scale_ptrs)
+
+    if use_w8a8:
+        # block-wise
+        if group_k > 0 and group_n > 0:
+            a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
+            offs_bsn = offs_n // group_n
+            b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse +
+                            offs_bsn * stride_bsn)
+        # tensor-wise
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + expert_id)
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+        a = tl.load(a_ptrs,
+                    mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
+                    other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
+        # We accumulate along the K dimension.
+        if use_w8a16:
+            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
+        elif use_w8a8:
+            if group_k > 0 and group_n > 0:
+                k_start = k * BLOCK_K
+                offs_ks = k_start // group_k
+                a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
+                                  mask=mask_m,
+                                  other=0.0)
+                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                accumulator += tl.dot(a, b) * a_scale[:,
+                                                      None] * b_scale[None, :]
+            else:
+                if use_w8a8:
+                    # acc used to enable fp8_fast_accum
+                    accumulator = tl.dot(a, b, acc=accumulator)
+                else:
+                    accumulator += tl.dot(a, b)
+        else:
+            accumulator += tl.dot(a, b)
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_K * stride_ak
+        b_ptrs += BLOCK_K * stride_bk
+
+    if use_w8a16:
+        accumulator = (accumulator * b_scale).to(compute_type)
+    elif use_w8a8:
+        if group_k > 0 and group_n > 0:
+            accumulator = accumulator.to(compute_type)
+        else:
+            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+    else:
+        accumulator = accumulator.to(compute_type)
+
+    return accumulator
+
+
+@triton.jit
+def expert_triton_kernel(
+        a_ptr,  #[max_tokens, K]
+        b_ptr,  #[K, N]
+        c_ptr,  #[max_tokens, N]
+        expert_id,
+        compute_type: tl.constexpr,
+        # Dimensions
+        M,
+        N,
+        K,
+        # Quantization data
+        a_scale_ptr,
+        b_scale_ptr,
+        b_zp_ptr,
+        # strides
+        stride_am,
+        stride_ak,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        stride_asm,
+        stride_ask,
+        stride_bse,
+        stride_bsk,
+        stride_bsn,
+        # Blockwise quantization data
+        group_n,
+        group_k,
+        # Quantization schemes
+        use_fp8_w8a8: tl.constexpr,
+        use_int8_w8a16: tl.constexpr,
+        # Kernel config
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        BLOCK_K: tl.constexpr):
+
+    offs_m = tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N) % N
+    offs_k = tl.arange(0, BLOCK_K)
+    mask_m = offs_m < M
+
+    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
+    accumulator = moe_mmk(
+        a_ptrs,
+        b_ptrs,
+        K,
+        expert_id,
+        a_scale_ptr,
+        b_scale_ptr,
+        # The stride variables represent how much to increase the ptr by when
+        # moving by 1 element in a particular dimension. E.g. `stride_am` is
+        # how much to increase `a_ptr` by to get the element one row down
+        # (A has M rows).
+        stride_ak,
+        stride_bk,
+        stride_asm,
+        stride_ask,
+        stride_bse,
+        stride_bsk,
+        stride_bsn,
+        # Offsets and masks
+        offs_m,
+        offs_n,
+        mask_m,
+        # Block size for block-wise quantization
+        group_n,
+        group_k,
+        # Meta-parameters
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        compute_type,
+        use_fp8_w8a8,
+        use_int8_w8a16)
+
+    # store in C
+    offs_cn = tl.arange(0, BLOCK_N)
+    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+    c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+@triton.jit
+def batched_triton_kernel(
+        a_ptr,  # [E, max_num_tokens, K]
+        b_ptr,  # [E, K, N]
+        c_ptr,  # [E, max_num_tokens, N]
+        expert_num_tokens,  # [E]
+        compute_type: tl.constexpr,
+        # Dimensions
+        max_num_tokens,
+        K,
+        N,
+        # Quantization data
+        a_scale_ptr,
+        b_scale_ptr,
+        b_zp_ptr,
+        # The stride variables represent how much to increase the ptr by when
+        # moving by 1 element in a particular dimension. E.g. `stride_am` is
+        # how much to increase `a_ptr` by to get the element one row down
+        # (A has M rows).
+        stride_ae,
+        stride_am,
+        stride_ak,
+        stride_be,
+        stride_bk,
+        stride_bn,
+        stride_ce,
+        stride_cm,
+        stride_cn,
+        stride_asm,
+        stride_ask,
+        stride_bse,
+        stride_bsk,
+        stride_bsn,
+        # Blockwise quantization data
+        group_n: tl.constexpr,
+        group_k: tl.constexpr,
+        # Quantization schemes
+        use_fp8_w8a8: tl.constexpr,
+        use_int8_w8a16: tl.constexpr,
+        # Kernel config
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        BLOCK_K: tl.constexpr):
+    expert_id = tl.program_id(axis=0)
+    e_num_tokens = tl.load(expert_num_tokens + expert_id)
+    if e_num_tokens == 0:
+        # Early exit
+        return
+
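+    # Program axis 1 flattens the (token-block, N-block) tile grid for this
+    # expert; recover the 2D tile coordinates.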
+    pid_mn = tl.program_id(axis=1)
+    #num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    pid_m = pid_mn // num_pid_n
+    pid_n = pid_mn % num_pid_n
+
+    cta_m_start = pid_m * BLOCK_M
+    cta_n_start = pid_n * BLOCK_N
+    if cta_m_start >= e_num_tokens:
+        # Early exit
+        return
+
+    cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
+    cta_n_size = min(BLOCK_N, N - cta_n_start)
+
+    a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
+    b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
+    c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
+             cta_n_start * stride_cn)
+
+    expert_triton_kernel(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        expert_id,
+        compute_type,
+        cta_m_size,  # M
+        cta_n_size,  # N
+        K,  # K
+        a_scale_ptr,
+        b_scale_ptr,
+        b_zp_ptr,
+        # Strides
+        stride_am,
+        stride_ak,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        stride_asm,
+        stride_ask,
+        stride_bse,
+        stride_bsk,
+        stride_bsn,
+        # Blockwise quantization data
+        group_n,
+        group_k,
+        # Quantization schemes
+        use_fp8_w8a8,
+        use_int8_w8a16,
+        # Kernel config
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K)
+
+
+def invoke_moe_batched_triton_kernel(
+        A: torch.Tensor,  # [E, max_tokens, K]
+        B: torch.Tensor,  # [E, K, N]
+        C: torch.Tensor,  # [E, max_tokens, N]
+        expert_num_tokens: torch.Tensor,  # [E]
+        compute_type: tl.dtype,
+        # Quantization data
+        A_scale: torch.Tensor,
+        B_scale: torch.Tensor,
+        B_zp: torch.Tensor,
+        # Quantization schemes
+        use_fp8_w8a8: bool,
+        use_int8_w8a16: bool,
+        use_int4_w4a16: bool,
+        config: dict[str, int],
+        block_shape: Optional[list[int]] = None):
+
+    assert not use_int4_w4a16
+    max_num_tokens = A.size(1)
+    K = A.size(2)
+    N = C.size(2)
+
+    BLOCK_M = config['BLOCK_SIZE_M']
+    BLOCK_N = config['BLOCK_SIZE_N']
+    BLOCK_K = config['BLOCK_SIZE_K']
+    assert max_num_tokens % BLOCK_M == 0
+
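+    # One program per expert (axis 0) and per output tile (axis 1).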
+    grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
+            triton.cdiv(B.shape[1], BLOCK_N))
+
+    batched_triton_kernel[grid](
+        A,
+        B,
+        C,
+        expert_num_tokens,
+        compute_type,
+        # Dimensions
+        max_num_tokens,
+        K,
+        N,
+        # Quantization data
+        A_scale,
+        B_scale,
+        B_zp,
+        # Strides
+        A.stride(0),
+        A.stride(1),
+        A.stride(2),
+        B.stride(0),
+        B.stride(2),
+        B.stride(1),
+        C.stride(0),
+        C.stride(1),
+        C.stride(2),
+        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        # Blockwise quantization data
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
+        # Quantization schemes
+        use_fp8_w8a8,
+        use_int8_w8a16,
+        # Kernel config
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        BLOCK_K=BLOCK_K)
+
+
+def rank_chunk(num, r, w):
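+    # Number of items assigned to rank `r` when `num` items are split as
+    # evenly as possible across `w` ranks (the first `num % w` ranks get one
+    # extra item).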
+    rem = num % w
+    return (num // w) + (1 if r < rem else 0)
+
+
+class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine):
+
+    def __init__(self, max_num_tokens: Optional[int], world_size: int,
+                 dp_size: int, rank: int):
+        super().__init__()
+        self.world_size = world_size
+        self.dp_size = dp_size
+        self.rank = rank
+        self.max_num_tokens = max_num_tokens
+
+    def dispatch(
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        assert a1.dim() == 2
+        assert topk_ids.dim() == 2
+        assert topk_ids.shape[0] == a1.shape[0]
+
+        if apply_router_weight_on_input:
+            topk = topk_ids.shape[1]
+            # TODO: this only works for topK=1, will need to update for topK>1
+            assert topk == 1, \
+                "apply_router_weight_on_input is only implemented for topk=1"
+            a1.mul_(topk_weights.to(a1.dtype))
+
+        num_tokens, hidden_dim = a1.shape
+        topk = topk_ids.shape[1]
+
+        if self.max_num_tokens is None:
+            tokens_per_expert = torch.bincount(topk_ids.view(-1),
+                                               minlength=num_experts)
+            self.max_num_tokens = int(tokens_per_expert.max().item())
+        else:
+            tokens_per_expert = torch.zeros(num_experts,
+                                            dtype=torch.int,
+                                            device=a1.device)
+
+        assert num_experts % self.world_size == 0
+
+        num_local_experts = num_experts // self.world_size
+
+        b_a1 = torch.zeros(
+            (num_local_experts, self.max_num_tokens, hidden_dim),
+            dtype=a1.dtype,
+            device=a1.device)
+
+        first_expert = num_local_experts * self.rank
+        last_expert = first_expert + num_local_experts
+
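+        # Gather the rows routed to each local expert into its slice of the
+        # batched activation tensor and record the per-expert token counts.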
+        for expert_id in range(first_expert, last_expert):
+            topks = torch.any(topk_ids == expert_id, dim=1).flatten()
+            rows = torch.count_nonzero(topks.flatten())
+            b_a1[expert_id -
+                 first_expert, :rows, :] = a1[:topks.numel()][topks]
+            tokens_per_expert[expert_id - first_expert] = rows
+
+        return b_a1, a1_scale, tokens_per_expert
+
+    def combine(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+    ) -> None:
+        num_tokens = topk_ids.shape[0]
+        num_local_experts = fused_expert_output.shape[0]
+        K = fused_expert_output.shape[-1]
+        assert output.shape[0] == num_tokens and output.shape[1] == K
+
+        output.fill_(0)
+
+        first_expert = num_local_experts * self.rank
+        last_expert = first_expert + num_local_experts
+
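+        # Scatter each local expert's output rows back to their source
+        # tokens, applying the topk weights unless they were folded in at
+        # dispatch time.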
+        for expert_id in range(first_expert, last_expert):
+            topkws = topk_ids == expert_id
+            topks = torch.any(topkws, dim=1).flatten()
+            rows = torch.count_nonzero(topks)
+            rhs = fused_expert_output[expert_id - first_expert, :rows, :]
+            if not apply_router_weight_on_input:
+                rhs.mul_(topk_weights[topkws].view(rhs.shape[0], 1))
+            output[topks] = output[topks] + rhs
+
+
+class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(
+        self,
+        world_size: int,
+        dp_size: int,
+        max_num_tokens: Optional[int] = None,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        block_shape: Optional[List[int]] = None,
+        block_m: Optional[int] = None,
+    ):
+        super().__init__()
+        assert block_shape is None
+        assert block_m is None
+        assert not use_fp8_w8a8, "NYI"
+        assert not use_int8_w8a8, "NYI"
+        assert not use_int8_w8a16, "NYI"
+        assert not use_int4_w4a16, "NYI"
+        self.max_num_tokens = max_num_tokens
+        self.world_size = world_size
+        self.dp_size = dp_size
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+    ) -> Tuple[int, int, torch.dtype]:
+        assert a.dim() == 2
+        num_dp = self.world_size // self.dp_size
+        max_num_tokens = a.shape[
+            0] if self.max_num_tokens is None else self.max_num_tokens
+        #print(f"WORKSPACE {max_num_tokens} {num_dp}")
+        workspace13 = num_experts * max_num_tokens * num_dp * K
+        workspace2 = max_num_tokens * num_dp * N
+        return (workspace13, workspace2, a.dtype)
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_num_tokens: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        assert hidden_states.dim() == 3
+        assert expert_num_tokens is not None
+        hidden_dim = hidden_states.shape[-1]
+
+        if self.max_num_tokens is None:
+            max_num_tokens = hidden_states.shape[1]
+        else:
+            max_num_tokens = self.max_num_tokens
+
+        num_dp = self.world_size // self.dp_size
+        num_experts = global_num_experts
+        out = _resize_cache(workspace13,
+                            (num_experts, max_num_tokens * num_dp, hidden_dim))
+        num_local_experts = w1.shape[0]
+        assert num_local_experts == w1.shape[
+            0], f"{num_local_experts} == {w1.shape[0]}"
+
+        N = w1.shape[1] // 2
+
+        # Checking expert_num_tokens here forces a host sync, which is not
+        # cudagraph friendly, so the check is skipped during graph capture.
+        assert (torch.cuda.is_current_stream_capturing()
+                or torch.all(expert_num_tokens <= max_num_tokens)), (
+                    f"{expert_num_tokens} <= {max_num_tokens}")
+
+        for expert in range(num_local_experts):
+            # Indexing expert_num_tokens doesn't work w/cudagraphs
+            if torch.cuda.is_current_stream_capturing():
+                num = max_num_tokens * num_dp
+            else:
+                num = int(expert_num_tokens[expert].item())
+            tmp = _resize_cache(workspace2, (num, N))
+            input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
+            self.activation(activation, tmp, input)
+            out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
+
+        return out
+
+
+class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(
+        self,
+        max_num_tokens: Optional[int] = None,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        block_shape: Optional[List[int]] = None,
+        world_size: int = 1,
+        dp_size: int = 1,
+    ):
+        super().__init__()
+        self.use_fp8_w8a8 = use_fp8_w8a8
+        self.use_int8_w8a8 = use_int8_w8a8
+        self.use_int4_w4a16 = use_int4_w4a16
+        self.use_int8_w8a16 = use_int8_w8a16
+        self.block_shape = block_shape
+        self.max_num_tokens = max_num_tokens
+        assert not use_int8_w8a8, "NYI"
+        assert not use_int4_w4a16, "NYI"
+        self.world_size = world_size
+        self.dp_size = dp_size
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+    ) -> Tuple[int, int, torch.dtype]:
+        assert a.dim() == 2
+        num_dp = self.world_size // self.dp_size
+        max_num_tokens = a.shape[
+            0] if self.max_num_tokens is None else self.max_num_tokens
+        workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
+        workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
+        return (workspace13, workspace2, a.dtype)
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_num_tokens: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Check constraints.
+        if self.use_int4_w4a16:
+            assert hidden_states.shape[-1] // 2 == w1.shape[
+                2], "Hidden size mismatch"
+        else:
+            assert hidden_states.shape[-1] == w1.shape[2], \
+                (f"Hidden size mismatch {hidden_states.shape[-1]} "
+                 f"!= {w1.shape[2]}")
+
+        assert hidden_states.is_contiguous(
+        ), "Hidden_states must be contiguous"
+        assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
+        assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
+        assert hidden_states.dtype in [
+            torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
+        ]
+
+        # TODO: num_tokens -> max_num_tokens?
+        E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
+            hidden_states, w1, w2, topk_ids)
+
+        assert w1.shape[0] == E
+        assert w2.shape[0] == E
+
+        config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
+                                            use_int8_w8a16=self.use_int8_w8a16,
+                                            use_int4_w4a16=self.use_int4_w4a16,
+                                            dtype=hidden_states.dtype)
+
+        config = try_get_optimal_moe_config(
+            w1.shape,
+            w2.shape,
+            top_k_num,
+            config_dtype,
+            num_tokens,
+            block_shape=self.block_shape,
+        )
+
+        if hidden_states.dtype == torch.bfloat16:
+            compute_type = tl.bfloat16
+        elif hidden_states.dtype == torch.float16:
+            compute_type = tl.float16
+        elif hidden_states.dtype == torch.float32:
+            compute_type = tl.float32
+        elif hidden_states.dtype == torch.float8_e4m3fn:
+            compute_type = tl.bfloat16
+        else:
+            raise ValueError(
+                f"Unsupported compute_type: {hidden_states.dtype}")
+
+        #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
+        # We can reuse the memory between these because by the time we need
+        # cache3, we're done with cache1
+        intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
+        intermediate_cache2 = _resize_cache(workspace2,
+                                            (E, num_tokens, N // 2))
+        intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
+
+        # MM1
+        invoke_moe_batched_triton_kernel(A=hidden_states,
+                                         B=w1,
+                                         C=intermediate_cache1,
+                                         expert_num_tokens=expert_num_tokens,
+                                         compute_type=compute_type,
+                                         A_scale=a1q_scale,
+                                         B_scale=w1_scale,
+                                         B_zp=w1_zp,
+                                         use_fp8_w8a8=self.use_fp8_w8a8,
+                                         use_int8_w8a16=self.use_int8_w8a16,
+                                         use_int4_w4a16=self.use_int4_w4a16,
+                                         config=config,
+                                         block_shape=self.block_shape)
+
+        # Apply the activation (e.g. silu_and_mul), reducing N columns
+        # to N // 2.
+        self.activation(activation, intermediate_cache2.view(-1, N // 2),
+                        intermediate_cache1.view(-1, N))
+
+        #qintermediate_cache2 = intermediate_cache2
+        a2q_scale = a2_scale
+        # TODO (varun) : support w8a8
+        assert not self.use_fp8_w8a8
+        #if self.use_fp8_w8a8:
+        #    qintermediate_cache2, a2q_scale = _fp8_quantize(
+        #        intermediate_cache2, a2_scale, self.block_shape)
+
+        invoke_moe_batched_triton_kernel(A=intermediate_cache2,
+                                         B=w2,
+                                         C=intermediate_cache3,
+                                         expert_num_tokens=expert_num_tokens,
+                                         compute_type=compute_type,
+                                         A_scale=a2q_scale,
+                                         B_scale=w2_scale,
+                                         B_zp=w2_zp,
+                                         use_fp8_w8a8=self.use_fp8_w8a8,
+                                         use_int8_w8a16=self.use_int8_w8a16,
+                                         use_int4_w4a16=self.use_int4_w4a16,
+                                         config=config,
+                                         block_shape=self.block_shape)
+
+        return intermediate_cache3
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index f6305822c2dc..93b651ecde8b 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -8,16 +8,17 @@
 import torch
 
 import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     _valid_deep_gemm, deep_gemm_moe_fp8)
+from vllm.model_executor.layers.fused_moe.dispatch_combine import (
+    StandardDispatchCombine)
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     moe_align_block_size)
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
-from vllm.model_executor.layers.quantization.utils.int8_utils import (
-    per_token_group_quant_int8, per_token_quant_int8)
+from vllm.model_executor.layers.fused_moe.utils import (
+    _resize_cache, moe_kernel_quantize_input)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
@@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
     assert topk_weights is None or topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
+    if use_fp8_w8a8 or use_int8_w8a8:
+        assert B_scale is not None
+        assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
+                == B_scale.shape[-2])
+        assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
+                == B_scale.shape[-1])
+
+    elif use_int8_w8a16 or use_int4_w4a16:
+        assert B_scale is not None
+        assert block_shape is None or block_shape[0] == 0
+    else:
+        assert A_scale is None
+        assert B_scale is None
+
     M = A.shape[0]
     num_tokens = M * top_k
 
@@ -855,6 +870,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    indices_type: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")
@@ -865,10 +881,11 @@ def fused_topk(
                                topk,
                                dtype=torch.float32,
                                device=hidden_states.device)
-    topk_ids = torch.empty(M,
-                           topk,
-                           dtype=torch.int32,
-                           device=hidden_states.device)
+    topk_ids = torch.empty(
+        M,
+        topk,
+        dtype=torch.int32 if indices_type is None else indices_type,
+        device=hidden_states.device)
     token_expert_indices = torch.empty(M,
                                        topk,
                                        dtype=torch.int32,
@@ -962,6 +979,20 @@ def get_config_dtype_str(
     return None
 
 
+# TODO: use scalar_type instead of bools?
+def get_config_qtype(
+    use_fp8_w8a8: bool,
+    use_int8_w8a8: bool,
+    use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
+) -> Optional[torch.dtype]:
+    if use_fp8_w8a8:
+        return torch.float8_e4m3fn
+    elif use_int8_w8a8:
+        return torch.int8
+    return None
+
+
 def inplace_fused_experts(hidden_states: torch.Tensor,
                           w1: torch.Tensor,
                           w2: torch.Tensor,
@@ -1131,7 +1162,10 @@ def fused_experts(hidden_states: torch.Tensor,
                   a2_scale: Optional[torch.Tensor] = None,
                   block_shape: Optional[List[int]] = None,
                   allow_deep_gemm: bool = False) -> torch.Tensor:
-    if (allow_deep_gemm and use_fp8_w8a8
+    # For now, disable DeepGemm for small N (<= 512) until better
+    # permute/unpermute ops are available.
+    N = w1.shape[1]
+    if (allow_deep_gemm and use_fp8_w8a8 and N > 512
             and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
         assert apply_router_weight_on_input is False
         return deep_gemm_moe_fp8(
@@ -1148,6 +1182,7 @@ def fused_experts(hidden_states: torch.Tensor,
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
         )
     else:
         return dispatch_fused_experts_func(inplace)(
@@ -1174,87 +1209,37 @@ def fused_experts(hidden_states: torch.Tensor,
             block_shape=block_shape)
 
 
-def moe_kernel_prepare_input(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    A_scale: Optional[torch.Tensor],
-    B_scale: Optional[torch.Tensor],
-    use_fp8_w8a8: bool,
-    use_int8_w8a8: bool,
-    use_int8_w8a16: bool,
-    use_int4_w4a16: bool,
-    per_channel_quant: bool,
+def fused_experts_impl(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-    if use_fp8_w8a8:
-        assert B_scale is not None
-        if block_shape is None:
-            # If weights are per-channel (per_channel_quant=True), then
-            # activations apply per-token quantization. Otherwise, assume
-            # activation tensor-wise fp8 quantization, dynamic or static
-            A, A_scale = ops.scaled_fp8_quant(
-                A, A_scale, use_per_token_if_dynamic=per_channel_quant)
-        else:
-            # activation block-wise fp8 quantization
-            assert len(block_shape) == 2
-            _, block_k = block_shape[0], block_shape[1]
-            A, A_scale = per_token_group_quant_fp8(A, block_k)
-            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
-            # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
-            # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-    elif use_int8_w8a8:
-        assert B_scale is not None
-        if block_shape is None:
-            # activation channel-wise int8 quantization
-            assert (per_channel_quant
-                    ), "int8 quantization only supports block or channel-wise"
-            A, A_scale = per_token_quant_int8(A)
-        else:
-            # activation block-wise int8 quantization
-            assert len(block_shape) == 2
-            _, block_k = block_shape[0], block_shape[1]
-            A, A_scale = per_token_group_quant_int8(A, block_k)
-            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
-            # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
-            # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-    elif use_int8_w8a16 or use_int4_w4a16:
-        assert B_scale is not None
-        assert block_shape is None or block_shape[0] == 0
-    else:
-        assert A_scale is None
-        assert B_scale is None
-
-    return A, A_scale
-
-
-def fused_experts_impl(hidden_states: torch.Tensor,
-                       w1: torch.Tensor,
-                       w2: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       inplace: bool = False,
-                       activation: str = "silu",
-                       apply_router_weight_on_input: bool = False,
-                       use_fp8_w8a8: bool = False,
-                       use_int8_w8a8: bool = False,
-                       use_int8_w8a16: bool = False,
-                       use_int4_w4a16: bool = False,
-                       per_channel_quant: bool = False,
-                       global_num_experts: int = -1,
-                       expert_map: Optional[torch.Tensor] = None,
-                       w1_scale: Optional[torch.Tensor] = None,
-                       w2_scale: Optional[torch.Tensor] = None,
-                       w1_zp: Optional[torch.Tensor] = None,
-                       w2_zp: Optional[torch.Tensor] = None,
-                       a1_scale: Optional[torch.Tensor] = None,
-                       a2_scale: Optional[torch.Tensor] = None,
-                       block_shape: Optional[List[int]] = None):
+) -> torch.Tensor:
     # Check constraints.
     if use_int4_w4a16:
         assert hidden_states.shape[1] // 2 == w1.shape[
             2], "Hidden size mismatch"
     else:
-        assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+        assert hidden_states.shape[1] == w1.shape[2], (
+            f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
 
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -1264,7 +1249,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         torch.float32, torch.float16, torch.bfloat16
     ]
 
-    num_tokens, _ = hidden_states.shape
+    num_tokens = hidden_states.shape[0]
     E, N, _ = w1.shape
     K = w2.shape[1]
     if global_num_experts == -1:
@@ -1279,6 +1264,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                         use_int4_w4a16=use_int4_w4a16,
                                         dtype=hidden_states.dtype)
 
+    qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
+                             use_int8_w8a8=use_int8_w8a8,
+                             use_int8_w8a16=use_int8_w8a16,
+                             use_int4_w4a16=use_int4_w4a16)
+
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
@@ -1341,15 +1331,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
 
-        qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
+        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
             A=curr_hidden_states,
-            B=w1,
             A_scale=a1_scale,
-            B_scale=w1_scale,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a8=use_int8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            use_int4_w4a16=use_int4_w4a16,
+            qtype=qtype,
             per_channel_quant=per_channel_quant,
             block_shape=block_shape)
 
@@ -1360,7 +1345,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         invoke_fused_moe_kernel(qcurr_hidden_states,
                                 w1,
                                 intermediate_cache1,
-                                qa1_scale,
+                                a1q_scale,
                                 w1_scale,
                                 w1_zp,
                                 curr_topk_weights,
@@ -1387,22 +1372,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         else:
             raise ValueError(f"Unsupported FusedMoe activation: {activation}")
 
-        qintermediate_cache2, qa2_scale = moe_kernel_prepare_input(
+        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
             A=intermediate_cache2,
-            B=w2,
             A_scale=a2_scale,
-            B_scale=w2_scale,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a8=use_int8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            use_int4_w4a16=use_int4_w4a16,
+            qtype=qtype,
             per_channel_quant=per_channel_quant,
             block_shape=block_shape)
 
         invoke_fused_moe_kernel(qintermediate_cache2,
                                 w2,
                                 intermediate_cache3,
-                                qa2_scale,
+                                a2q_scale,
                                 w2_scale,
                                 w2_zp,
                                 curr_topk_weights,
@@ -1537,3 +1517,229 @@ def fused_moe(
                          a1_scale=a1_scale,
                          a2_scale=a2_scale,
                          block_shape=block_shape)
+
+
+class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(
+        self,
+        use_fp8_w8a8: bool,
+        use_int8_w8a8: bool,
+        use_int8_w8a16: bool,
+        use_int4_w4a16: bool,
+        per_channel_quant: bool,
+        block_shape: Optional[List[int]] = None,
+        block_m: Optional[int] = None,
+    ):
+        super().__init__()
+        self.use_fp8_w8a8 = use_fp8_w8a8
+        self.use_int4_w4a16 = use_int4_w4a16
+        self.use_int8_w8a8 = use_int8_w8a8
+        self.use_int8_w8a16 = use_int8_w8a16
+        self.block_shape = block_shape
+        self.block_m = block_m
+        self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
+                                      use_int8_w8a8=use_int8_w8a8,
+                                      use_int8_w8a16=use_int8_w8a16,
+                                      use_int4_w4a16=use_int4_w4a16)
+        self.per_channel_quant = per_channel_quant
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+    ) -> Tuple[int, int, torch.dtype]:
+        factor = num_experts if a.dim() == 3 else 1
+        workspace1 = M * topk * max(N * 2, K) * factor
+        workspace2 = M * topk * N * factor
+        return (workspace1, workspace2, a.dtype)
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_num_tokens: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Check constraints.
+        if self.use_int4_w4a16:
+            assert hidden_states.shape[-1] // 2 == w1.shape[
+                2], "Hidden size mismatch"
+        else:
+            assert hidden_states.shape[-1] == w1.shape[2], \
+                (f"Hidden size mismatch {hidden_states.shape[-1]} "
+                 f"!= {w1.shape[2]}")
+
+        assert hidden_states.is_contiguous(
+        ), "Hidden_states must be contiguous"
+        assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
+        assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
+        assert hidden_states.dtype in [
+            torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
+        ]
+
+        E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
+            hidden_states, w1, w2, topk_ids)
+
+        if global_num_experts == -1:
+            global_num_experts = E
+
+        config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
+                                            use_int8_w8a16=self.use_int8_w8a16,
+                                            use_int4_w4a16=self.use_int4_w4a16,
+                                            dtype=hidden_states.dtype)
+
+        config = try_get_optimal_moe_config(
+            w1.shape,
+            w2.shape,
+            top_k_num,
+            config_dtype,
+            num_tokens,
+            block_shape=self.block_shape,
+        )
+
+        if hidden_states.dtype == torch.bfloat16:
+            compute_type = tl.bfloat16
+        elif hidden_states.dtype == torch.float16:
+            compute_type = tl.float16
+        elif hidden_states.dtype == torch.float32:
+            compute_type = tl.float32
+        elif hidden_states.dtype == torch.float8_e4m3fn:
+            compute_type = tl.bfloat16
+        else:
+            raise ValueError(
+                f"Unsupported compute_type: {hidden_states.dtype}")
+
+        # We can reuse the memory between these because by the time we need
+        # cache3, we're done with cache1
+        intermediate_cache1 = _resize_cache(workspace13,
+                                            (num_tokens, top_k_num, N))
+        intermediate_cache2 = _resize_cache(workspace2,
+                                            (num_tokens * top_k_num, N // 2))
+        intermediate_cache3 = _resize_cache(workspace13,
+                                            (num_tokens, top_k_num, K))
+
+        if hidden_states.dim() == 2:
+            sorted_token_ids, expert_ids, num_tokens_post_padded = (
+                moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
+                                     global_num_experts, expert_map))
+        else:
+            max_num_tokens = hidden_states.shape[1]
+            sorted_token_ids = torch.arange(0,
+                                            hidden_states.shape[0] *
+                                            max_num_tokens,
+                                            device=hidden_states.device,
+                                            dtype=torch.int)
+            sorted_token_ids = sorted_token_ids.flatten()
+            expert_ids = torch.arange(0,
+                                      global_num_experts,
+                                      device=hidden_states.device,
+                                      dtype=torch.int)
+            expert_ids = torch.repeat_interleave(expert_ids,
+                                                 max_num_tokens,
+                                                 dim=0)
+            num_tokens_post_padded = torch.zeros(1,
+                                                 device=hidden_states.device,
+                                                 dtype=torch.int32)
+            num_tokens_post_padded.fill_(max_num_tokens)
+            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+        invoke_fused_moe_kernel(hidden_states,
+                                w1,
+                                intermediate_cache1,
+                                a1q_scale,
+                                w1_scale,
+                                w1_zp,
+                                None,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                False,
+                                top_k_num,
+                                config,
+                                compute_type=compute_type,
+                                use_fp8_w8a8=self.use_fp8_w8a8,
+                                use_int8_w8a8=self.use_int8_w8a8,
+                                use_int8_w8a16=self.use_int8_w8a16,
+                                use_int4_w4a16=self.use_int4_w4a16,
+                                per_channel_quant=self.per_channel_quant,
+                                block_shape=self.block_shape)
+
+        self.activation(activation, intermediate_cache2,
+                        intermediate_cache1.view(-1, N))
+
+        a2q_scale: Optional[torch.Tensor] = None
+
+        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
+            intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
+            self.block_shape)
+
+        invoke_fused_moe_kernel(qintermediate_cache2,
+                                w2,
+                                intermediate_cache3,
+                                a2q_scale,
+                                w2_scale,
+                                w2_zp,
+                                None,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                False,
+                                1,
+                                config,
+                                compute_type=compute_type,
+                                use_fp8_w8a8=self.use_fp8_w8a8,
+                                use_int8_w8a8=self.use_int8_w8a8,
+                                use_int8_w8a16=self.use_int8_w8a16,
+                                use_int4_w4a16=self.use_int4_w4a16,
+                                per_channel_quant=self.per_channel_quant,
+                                block_shape=self.block_shape)
+
+        return intermediate_cache3
+
+
+def modular_triton_fused_moe(
+    use_fp8_w8a8: bool,
+    use_int8_w8a8: bool,
+    use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
+    per_channel_quant: bool,
+    block_shape: Optional[List[int]] = None,
+) -> mk.FusedMoEModularKernel:
+    qtype = get_config_qtype(
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a8=use_int8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
+    )
+    return mk.FusedMoEModularKernel(
+        StandardDispatchCombine(
+            quant_dtype=qtype,
+            per_channel_quant=per_channel_quant,
+            block_shape=block_shape,
+        ),
+        TritonExperts(
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
+            block_shape=block_shape,
+        ),
+    )
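
For reference, a hedged sketch of how the kernel returned by modular_triton_fused_moe might be invoked; the shapes, bf16 dtype, and CUDA device are illustrative assumptions, and the call mirrors the fused_experts-style interface of FusedMoEModularKernel.forward.

```python
# Illustrative usage sketch (assumes a CUDA device and this patched vLLM).
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, modular_triton_fused_moe)

moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
                                  use_int8_w8a8=False,
                                  use_int8_w8a16=False,
                                  use_int4_w4a16=False,
                                  per_channel_quant=False)

M, K, N, E, topk = 16, 1024, 4096, 8, 2
a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(E, K, N // 2, dtype=torch.bfloat16, device="cuda")
gating = torch.randn(M, E, dtype=torch.float32, device="cuda")

topk_weights, topk_ids, _ = fused_topk(a, gating, topk, renormalize=True)
out = moe_fn(a, w1, w2, topk_weights, topk_ids)  # (M, K) output
```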
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 35994c8ac6af..0feee87b4cbd 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1,15 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import importlib
+import threading
 from abc import abstractmethod
+from dataclasses import dataclass
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
+from weakref import WeakValueDictionary
 
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter
 
 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
+from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
@@ -23,8 +27,16 @@
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import direct_register_custom_op
 
+has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+
 if current_platform.is_cuda_alike():
-    from .fused_moe import fused_experts
+    from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts
+    from .fused_moe import TritonExperts, fused_experts
+    from .modular_kernel import (FusedMoEModularKernel,
+                                 FusedMoEPermuteExpertsUnpermute,
+                                 FusedMoEQuantizeDispatchCombine)
+    if has_pplx:
+        from .pplx_dispatch_combine import PplxDispatchCombine
 else:
     fused_experts = None  # type: ignore
 if current_platform.is_tpu():
@@ -34,6 +46,177 @@
     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)
 
+MOE_DP_CHUNK_SIZE = 256
+
+
+@dataclass
+class FusedMoEParallelConfig:
+    tp_size: int
+    dp_size: int
+    ep_size: int
+    tp_rank: int
+    dp_rank: int
+    ep_rank: int
+
+    use_ep: bool  # whether to use EP or not
+
+    @property
+    def use_pplx_kernels(self):
+        return self.dp_size > 1 and self.use_ep and has_pplx
+
+    @staticmethod
+    def make(tp_size_: int, dp_size_: int,
+             vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
+        """
+        Determine MoE parallel configuration. Based on the input tp_size_,
+        dp_size_ and vllm's parallel config, determine what levels of
+        parallelism to use in the fused moe layer.
+
+        Args:
+            tp_size_ (int): tp_size passed into the FusedMoE constructor.
+            dp_size_ (int): dp_size passed into the FusedMoE constructor.
+            vllm_parallel_config (ParallelConfig): vllm's parallel config
+                object.
+
+        Examples:
+        When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
+        we simply return the sizes unaltered and the ranks set to 0.
+
+        Expert Parallelism is considered only when either dp_size_ or tp_size_
+        is non-trivial.
+
+        When TP = 2, DP = 1 and EP = False, the configuration on different
+        devices,
+            - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
+                         legend : {size, rank}
+            - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
+            - Comment : Tensors are sharded across 2 devices.
+
+        When TP = 1, DP = 2 and EP = False, the configuration on different
+        devices,
+            - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
+            - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
+            - Comment: There are 2 engine instances and the tensors are sharded
+              across 2 devices.
+
+        When TP = 2, DP = 2 and EP = False, the configuration on different
+        devices,
+            - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
+            - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
+            - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
+            - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
+            - Comment: There are 2 engine instances and the tensors are sharded
+              across 4 devices.
+
+        When TP = 2, DP = 1 and EP = True, the configuration on different
+        devices,
+            - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
+            - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
+            - Comment: The experts are split between the 2 devices.
+
+        When TP = 1, DP = 2 and EP = True, the configuration on different
+        devices,
+            - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
+            - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
+            - Comment: There are 2 engine instances and the experts are split
+              between the 2 devices.
+
+        When TP = 2, DP = 2 and EP = True, the configuration on different
+        devices,
+            - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
+            - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
+            - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
+            - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
+            - Comment: There are 2 engine instances and the experts are split
+              between the 4 devices.
+        """
+
+        def flatten_tp_across_dp(dp_rank: int):
+            tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
+            # There are actually dp_size_ * tp_size_ devices. Update tp_size
+            # and tp_rank so we shard across all devices.
+            tp_size = dp_size_ * tp_size_
+            tp_rank = dp_rank * tp_size_ + tp_rank
+            return tp_size, tp_rank
+
+        use_ep = (dp_size_ * tp_size_ > 1
+                  and vllm_parallel_config.enable_expert_parallel)
+
+        dp_size = dp_size_
+        dp_rank = get_dp_group().rank_in_group
+        tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
+
+        if not use_ep:
+            return FusedMoEParallelConfig(tp_size=tp_size,
+                                          tp_rank=tp_rank,
+                                          dp_size=dp_size,
+                                          dp_rank=dp_rank,
+                                          ep_size=1,
+                                          ep_rank=0,
+                                          use_ep=False)
+        # DP + EP / TP + EP / DP + TP + EP
+        assert use_ep
+        # In EP, each device owns a set of experts fully. There is no tensor
+        # parallelism, so update tp_size, tp_rank, ep_size and ep_rank to
+        # reflect that.
+        ep_size = tp_size
+        ep_rank = tp_rank
+        return FusedMoEParallelConfig(tp_size=1,
+                                      tp_rank=0,
+                                      dp_size=dp_size,
+                                      dp_rank=dp_rank,
+                                      ep_size=ep_size,
+                                      ep_rank=ep_rank,
+                                      use_ep=True)
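
To make the TP = 2, DP = 2 examples in the docstring concrete, the flattening arithmetic can be reproduced in isolation; flatten below is a hypothetical mirror of flatten_tp_across_dp, not part of the patch.

```python
# Hypothetical standalone mirror of the rank-flattening arithmetic.
def flatten(tp_size_: int, dp_size_: int, tp_rank: int, dp_rank: int):
    tp_size = dp_size_ * tp_size_
    return tp_size, dp_rank * tp_size_ + tp_rank


# TP = 2, DP = 2, EP = False: tensors are sharded across all 4 devices.
ranks = [flatten(2, 2, tp, dp) for dp in range(2) for tp in range(2)]
assert ranks == [(4, 0), (4, 1), (4, 2), (4, 3)]
# With EP = True, make() reuses these values as {ep_size, ep_rank} and
# collapses TP to {1, 0}, matching the last docstring example.
```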
+
+
+# Adapted from pplx-kernels tests/all_to_all_utils.py
+@dataclass
+class MoEConfig:
+    num_experts: int
+    experts_per_token: int
+    hidden_dim: int
+
+    num_local_experts: int
+    moe_parallel_config: FusedMoEParallelConfig
+
+    in_dtype: torch.dtype  # The activation type.
+
+    # TODO: add more quantization params, blocked, per-token, etc.
+    block_size: int = 128
+
+    @property
+    def tp_size(self):
+        return self.moe_parallel_config.tp_size
+
+    @property
+    def dp_size(self):
+        return self.moe_parallel_config.dp_size
+
+    @property
+    def ep_size(self):
+        return self.moe_parallel_config.ep_size
+
+    @property
+    def tp_rank(self):
+        return self.moe_parallel_config.tp_rank
+
+    @property
+    def dp_rank(self):
+        return self.moe_parallel_config.dp_rank
+
+    @property
+    def ep_rank(self):
+        return self.moe_parallel_config.ep_rank
+
+    @property
+    def use_ep(self):
+        return self.moe_parallel_config.use_ep
+
+    @property
+    def use_pplx_kernels(self):
+        return self.moe_parallel_config.use_pplx_kernels
+
 
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
@@ -50,6 +233,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         raise NotImplementedError
 
+    def set_dispatch_combine(
+        self,
+        dp_size: int,
+        world_size: int,
+        dispatch_combine: FusedMoEQuantizeDispatchCombine,
+    ) -> bool:
+        return False
+
     @abstractmethod
     def apply(
         self,
@@ -72,10 +263,54 @@ def apply(
         raise NotImplementedError
 
 
+class AllToAllCache:
+
+    def __init__(self):
+        self._cache: WeakValueDictionary = WeakValueDictionary()
+        self._lock = threading.RLock()  # Reentrant lock for thread safety
+
+    def destroy(self):
+        with self._lock:
+            for _, a2a in self._cache.items():
+                a2a.destroy()
+
+    def get_or_create(self, **kwargs):
+        assert has_pplx
+        import pplx_kernels as pplx
+
+        # Create a hashable key from the kwargs
+        key = tuple(sorted((k, v) for k, v in kwargs.items()))
+
+        with self._lock:
+            instance = self._cache.get(key)
+            if instance is None:
+                # TODO (varun): Add support to switch to intranode
+                # when all communications are within the same
+                # node.
+                instance = pplx.AllToAll.internode(**kwargs)
+                self._cache[key] = instance
+            return instance
+
+
+# Global singleton
+_all_to_all_cache = AllToAllCache()
+
+
+# Factory function as a cleaner interface
+def get_all_to_all(**kwargs):
+    return _all_to_all_cache.get_or_create(**kwargs)
+
+
 @CustomOp.register("unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
+    def __init__(self, moe: MoEConfig):
+        super().__init__()
+        self.fused_experts = fused_experts
+        self.moe = moe
+        self.using_pplx = False
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -173,6 +408,50 @@ def apply(
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input)
 
+    def set_dispatch_combine(
+        self,
+        dp_size: int,
+        world_size: int,
+        dispatch_combine: FusedMoEQuantizeDispatchCombine,
+    ) -> bool:
+        assert self.fused_experts == fused_experts
+
+        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
+
+        self.using_pplx = False
+
+        if isinstance(dispatch_combine,
+                      (BatchedDispatchCombine, PplxDispatchCombine)):
+            logger.debug("BatchedTritonExperts %s", self.moe)
+            experts = BatchedTritonExperts(
+                max_num_tokens=MOE_DP_CHUNK_SIZE,
+                world_size=world_size,
+                dp_size=dp_size,
+                use_fp8_w8a8=False,
+                use_int8_w8a8=False,
+                use_int8_w8a16=False,
+                use_int4_w4a16=False,
+                block_shape=None,
+            )
+            self.using_pplx = isinstance(dispatch_combine, PplxDispatchCombine)
+        else:
+            logger.debug("TritonExperts %s", self.moe)
+            experts = TritonExperts(
+                use_fp8_w8a8=False,
+                use_int8_w8a8=False,
+                use_int8_w8a16=False,
+                use_int4_w4a16=False,
+                block_shape=None,
+                per_channel_quant=False,
+            )
+
+        self.fused_experts = FusedMoEModularKernel(
+            dispatch_combine,
+            experts,
+        )
+
+        return True
+
     def forward_cuda(
         self,
         layer: torch.nn.Module,
@@ -201,9 +480,10 @@ def forward_cuda(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=torch.uint32 if self.using_pplx else None)
 
-        return fused_experts(
+        return self.fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
@@ -213,7 +493,8 @@ def forward_cuda(
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
-            expert_map=expert_map)
+            expert_map=expert_map,
+        )
 
     def forward_cpu(
         self,
@@ -369,6 +650,45 @@ def determine_expert_map(
     return (local_num_experts, expert_map)
 
 
+def _construct_dispatch_combine(
+    moe: MoEConfig, quant_config: Optional[QuantizationConfig]
+) -> Optional[FusedMoEQuantizeDispatchCombine]:
+    max_num_tokens = MOE_DP_CHUNK_SIZE
+    world_size = moe.ep_size
+    dp_size = moe.ep_size // moe.dp_size  # dp_size actually means TP.
+    rank = moe.ep_rank
+
+    if moe.use_pplx_kernels:
+        logger.debug("using pplx dispatch")
+
+        all_to_all = get_all_to_all(
+            max_num_tokens=max_num_tokens,
+            num_experts=moe.num_experts,
+            experts_per_token=moe.experts_per_token,  # topk
+            rank=rank,
+            world_size=world_size,
+            dp_size=dp_size,
+            hidden_dim=moe.hidden_dim,
+            hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
+            # For blocked per token: set to
+            #   ceil_div(hidden_dim, block_size) * sizeof(float32)
+            # For per-token: set to sizeof(float32)
+            hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
+                                    ((moe.hidden_dim + moe.block_size - 1) //
+                                     moe.block_size * torch.float32.itemsize)))
+
+        return PplxDispatchCombine(
+            all_to_all,
+            max_num_tokens=max_num_tokens,
+            world_size=world_size,
+            rank=rank,
+            dp_size=dp_size,
+            quant_dtype=moe.in_dtype,
+        )
+
+    return None
+
+
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
 
@@ -419,21 +739,16 @@ def __init__(
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
 
-        # Note: here we guard against accessing the TP and DP groups when
-        # uninitialized (this happens when testing)
-        self.tp_size = (tp_size if tp_size is not None else
-                        get_tensor_model_parallel_world_size())
-        tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank()
-        self.dp_size = (dp_size
-                        if dp_size is not None else get_dp_group().world_size)
-        self.dp_rank = (0
-                        if self.dp_size == 1 else get_dp_group().rank_in_group)
-        self.global_num_experts = num_experts
-
-        # Use expert parallelism instead of tensor parallelism?
         vllm_config = get_current_vllm_config()
-        use_ep = (vllm_config.parallel_config.enable_expert_parallel
-                  and self.tp_size * self.dp_size > 1)
+        self.moe_parallel_config: FusedMoEParallelConfig = (
+            FusedMoEParallelConfig.make(
+                tp_size_=(tp_size if tp_size is not None else
+                          get_tensor_model_parallel_world_size()),
+                dp_size_=(dp_size if dp_size is not None else
+                          get_dp_group().world_size),
+                vllm_parallel_config=vllm_config.parallel_config))
+
+        self.global_num_experts = num_experts
 
         # For smuggling this layer into the fused moe custom op
         self.use_direct_call = self.dp_size == 1
@@ -444,28 +759,17 @@ def __init__(
             compilation_config.static_forward_context[prefix] = self
             self.layer_name = prefix
 
-        if use_ep:
-            # Set TP size to 1 to adjust for EP and adjust EP size and rank
-            # for DP attention.
-            self.ep_rank = tp_rank + self.tp_size * self.dp_rank
-            self.tp_rank = 0
-            self.ep_size = self.tp_size * self.dp_size
-            self.tp_size = 1
-
+        # Determine expert maps
+        if self.use_ep:
             self.local_num_experts, self.expert_map = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts)
         else:
-            # Adjust TP size for DP attention
-            self.tp_rank = tp_rank + self.tp_size * self.dp_rank
-            self.ep_rank = 0
-            self.tp_size = self.tp_size * self.dp_size
-            self.ep_size = 1
-            self.local_num_experts = self.global_num_experts
-            self.expert_map = None
+            self.local_num_experts, self.expert_map = (self.global_num_experts,
+                                                       None)
+
         self.top_k = top_k
-        self.global_num_experts = num_experts
 
         assert intermediate_size % self.tp_size == 0
         self.hidden_size = hidden_size
@@ -489,14 +793,39 @@ def __init__(
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
 
+        moe = MoEConfig(
+            num_experts=self.global_num_experts,
+            experts_per_token=top_k,
+            hidden_dim=hidden_size,
+            num_local_experts=self.local_num_experts,
+            moe_parallel_config=self.moe_parallel_config,
+            in_dtype=params_dtype,  # TODO: is this right?
+        )
+
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
+        quant_method: Optional[FusedMoEMethodBase] = None
+
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod())
+            quant_method = UnquantizedFusedMoEMethod(moe)
         else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
-        assert self.quant_method is not None
+            quant_method = quant_config.get_quant_method(
+                self, prefix)  # type: ignore
+            assert isinstance(quant_method, FusedMoEMethodBase)
+
+        assert quant_method is not None
+        self.quant_method = quant_method
+
+        dispatch_combine = _construct_dispatch_combine(moe, quant_config)
+
+        if dispatch_combine is not None:
+            world_size = moe.ep_size
+            dp_size = int(moe.ep_size // moe.dp_size)
+            success = self.quant_method.set_dispatch_combine(
+                dp_size, world_size, dispatch_combine)
+            if not success:
+                logger.warning("DP+EP not supported for %s.",
+                               type(self.quant_method))
 
         self.apply_router_weight_on_input = apply_router_weight_on_input
         moe_quant_params = {
@@ -516,6 +845,38 @@ def __init__(
 
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
+    @property
+    def tp_size(self):
+        return self.moe_parallel_config.tp_size
+
+    @property
+    def dp_size(self):
+        return self.moe_parallel_config.dp_size
+
+    @property
+    def ep_size(self):
+        return self.moe_parallel_config.ep_size
+
+    @property
+    def tp_rank(self):
+        return self.moe_parallel_config.tp_rank
+
+    @property
+    def dp_rank(self):
+        return self.moe_parallel_config.dp_rank
+
+    @property
+    def ep_rank(self):
+        return self.moe_parallel_config.ep_rank
+
+    @property
+    def use_ep(self):
+        return self.moe_parallel_config.use_ep
+
+    @property
+    def use_pplx_kernels(self):
+        return self.moe_parallel_config.use_pplx_kernels
+
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
                                       loaded_weight: torch.Tensor,
@@ -783,7 +1144,8 @@ def select_experts(hidden_states: torch.Tensor,
                        num_expert_group: Optional[int] = None,
                        custom_routing_function: Optional[Callable] = None,
                        scoring_func: str = "softmax",
-                       e_score_correction_bias: Optional[torch.Tensor] = None):
+                       e_score_correction_bias: Optional[torch.Tensor] = None,
+                       indices_type: Optional[torch.dtype] = None):
         from vllm.model_executor.layers.fused_moe.fused_moe import (
             fused_topk, grouped_topk)
 
@@ -800,18 +1162,24 @@ def select_experts(hidden_states: torch.Tensor,
                 topk_group=topk_group,
                 scoring_func=scoring_func,
                 e_score_correction_bias=e_score_correction_bias)
+            if indices_type is not None:
+                topk_ids = topk_ids.to(dtype=indices_type)
         elif custom_routing_function is None:
             topk_weights, topk_ids, token_expert_indices = fused_topk(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
                 topk=top_k,
-                renormalize=renormalize)
+                renormalize=renormalize,
+                indices_type=indices_type,
+            )
         else:
             topk_weights, topk_ids = custom_routing_function(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
                 topk=top_k,
                 renormalize=renormalize)
+            if indices_type is not None:
+                topk_ids = topk_ids.to(dtype=indices_type)
 
         return topk_weights, topk_ids
 
@@ -833,6 +1201,31 @@ def naive_multicast(self, x: torch.Tensor,
 
         return buffer
 
+    def must_reduce_shared_expert_outputs(self) -> bool:
+        """
+        The shared_experts are typically computed using the RowParallelLinear
+        layer. The result of this function is typically used as the
+        reduce_results argument to that module.
+        When only tensor parallelism is used, it is not required to reduce
+        the shared_experts results immediately. Instead, we reduce once at
+        the end of the MoE op (refer to the DeepSeekV2MoE module).
+        With EP and the pplx kernels, this is no longer viable since all
+        GPU ranks in the DP group produce the complete set of hidden_states.
+        Therefore, the shared_experts output must be reduced early.
+        """
+        return self.use_pplx_kernels
+
+    def maybe_all_reduce_tensor_model_parallel(
+            self, final_hidden_states: torch.Tensor):
+        """
+        The pplx combine kernel reduces across GPU ranks by default, so an
+        additional tensor-parallel all-reduce is only needed when the pplx
+        kernels are not used.
+        """
+        if self.use_pplx_kernels:
+            return final_hidden_states
+        else:
+            return tensor_model_parallel_all_reduce(final_hidden_states)
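
A sketch of how a model's MoE block might consume these two helpers, loosely following the DeepSeekV2MoE pattern mentioned in the docstring; the SparseMoEBlock class and its shared_experts module are illustrative, not part of this patch.

```python
# Illustrative caller; `shared_experts` is assumed to be built with
# reduce_results=experts.must_reduce_shared_expert_outputs().
import torch


class SparseMoEBlock(torch.nn.Module):

    def __init__(self, experts: "FusedMoE", shared_experts: torch.nn.Module):
        super().__init__()
        self.experts = experts
        self.shared_experts = shared_experts

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        out = self.experts(hidden_states, router_logits)
        out = out + self.shared_experts(hidden_states)
        # With pplx, the combine kernel already reduced across ranks, so
        # this is a no-op; otherwise it performs the TP all-reduce.
        return self.experts.maybe_all_reduce_tensor_model_parallel(out)
```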
+
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         if self.use_direct_call:
@@ -841,13 +1234,66 @@ def forward(self, hidden_states: torch.Tensor,
             return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                               self.layer_name)
 
+    def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
+                             full_router_logits: torch.Tensor):
+
+        full_final_hidden_states = torch.empty_like(full_hidden_states)
+
+        def process_chunk(chunk_start, chunk_end, skip_result_store=False):
+            hidden_states = full_hidden_states[chunk_start:chunk_end, :]
+            router_logits = full_router_logits[chunk_start:chunk_end, :]
+
+            # Matrix multiply.
+            final_hidden_states = self.quant_method.apply(
+                layer=self,
+                x=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                renormalize=self.renormalize,
+                use_grouped_topk=self.use_grouped_topk,
+                global_num_experts=self.global_num_experts,
+                expert_map=self.expert_map,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                custom_routing_function=self.custom_routing_function,
+                scoring_func=self.scoring_func,
+                e_score_correction_bias=self.e_score_correction_bias,
+                activation=self.activation,
+            )
+
+            if not skip_result_store:
+                full_final_hidden_states[chunk_start:chunk_end, :].copy_(
+                    final_hidden_states)
+
+        ctx = get_forward_context()
+        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
+        moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
+
+        num_tokens = full_hidden_states.size(0)
+        for chunk_start_ in range(0, max_tokens_across_dp,
+                                  moe_dp_chunk_size_per_rank):
+            chunk_start = chunk_start_
+            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
+                            max_tokens_across_dp)
+            # clamp start and end
+            chunk_start = min(chunk_start, num_tokens - 1)
+            chunk_end = min(chunk_end, num_tokens)
+
+            process_chunk(chunk_start,
+                          chunk_end,
+                          skip_result_store=chunk_start_ >= num_tokens)
+
+        return full_final_hidden_states
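
The clamping above keeps every DP rank stepping through the same number of chunks even when its local num_tokens is smaller than max_tokens_across_dp; a standalone reproduction of the arithmetic with assumed sizes:

```python
# Assumed sizes: this rank holds 100 tokens while the largest DP rank
# holds 600, so the last two chunks still run but skip storing results.
MOE_DP_CHUNK_SIZE = 256
num_tokens, max_tokens_across_dp = 100, 600

for chunk_start_ in range(0, max_tokens_across_dp, MOE_DP_CHUNK_SIZE):
    chunk_end = min(chunk_start_ + MOE_DP_CHUNK_SIZE, max_tokens_across_dp)
    chunk_start = min(chunk_start_, num_tokens - 1)
    chunk_end = min(chunk_end, num_tokens)
    skip_result_store = chunk_start_ >= num_tokens
    print(chunk_start, chunk_end, skip_result_store)
# -> (0, 100, False), (99, 100, True), (99, 100, True)
```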
+
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None
+        if self.moe_parallel_config.use_pplx_kernels:
+            return self.forward_impl_chunked(hidden_states, router_logits)
 
         if self.dp_size > 1:
-            cu_tokens_across_dp_cpu = get_forward_context(
-            ).dp_metadata.cu_tokens_across_dp_cpu
+            ctx = get_forward_context()
+            cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu
 
             hidden_states = self.naive_multicast(hidden_states,
                                                  cu_tokens_across_dp_cpu)
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
new file mode 100644
index 000000000000..95b0397f9529
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -0,0 +1,364 @@
+# SPDX-License-Identifier: Apache-2.0
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+
+import torch
+
+#
+# This file defines a set of base classes used to make MoE kernels more modular.
+# The goal is to be able to utilize different communication mechanisms with
+# any fused MoE kernel without needing to have combinatoric implementations.
+#
+# The fused moe kernels are broken down into the following components:
+#
+# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
+#
+# Each component will be independent of the others except for
+# [Quantize-Dispatch] and [Combine] (see below). The components can then be
+# mixed and matched so that DP+EP can be supported easily for multiple
+# MoE kernel implementations.
+#
+# The following main classes are defined:
+# * FusedMoEQuantizeDispatchCombine - an abstract base class for quantization,
+#   dispatching and combining. The dispatch method takes care of any needed
+#   quantization and the combine method applies weights and does the final
+#   reduction of the output.
+# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
+#   MoE operation. One important feature to note is that this class does not
+#   apply topk weights or reduce the final output.
+# * FusedMoEModularKernel - an interface class that combines a
+#   FusedMoEQuantizeDispatchCombine and a FusedMoEPermuteExpertsUnpermute to
+#   provide the standard fused MoE kernel interface.
+#
+# [Quantize-Dispatch] and [Combine] functionality are bundled into a single
+# class `FusedMoEQuantizeDispatchCombine` since they could use collective
+# communication mechanisms that need to be consistent.
+#
+
+
+def _moe_problem_size(
+    a1: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_ids: torch.Tensor,
+) -> Tuple[int, int, int, int, int]:
+    """
+    Extract the MoE problem size from the given tensor arguments:
+    - a1: The hidden states, input to the MoE layer.
+    - w1: The first set of expert weights.
+    - w2: The second set of expert weights.
+    - topk_ids: The topk ids.
+
+    Note: extracting the problem shape from the weight and activation tensors is
+    not obvious.  It needs to be done this way specifically due to subtle issues
+    with particular kernels, e.g. the int4 kernels divide the trailing dimension
+    by two, so it's not "correct" to extract N or K from the trailing dimension
+    of w1 or w2.  Similarly, some kernels transpose the weights, so this needs
+    to be kept in mind.
+    """
+    assert w1.dim() == 3 and w2.dim() == 3
+    E, N, _ = w1.shape
+    K = w2.shape[1]
+
+    if a1.dim() == 2:
+        # Make sure we are using the correct a1 (pre-permute).
+        assert topk_ids.shape[0] == a1.shape[0], \
+            f"{topk_ids.shape[0]} != {a1.shape[0]}"
+        M = a1.shape[0]
+    else:
+        assert a1.dim() == 3
+        assert a1.shape[0] == E, f"{a1.shape[0]} != {E}"
+        M = a1.shape[1]  # This is max_num_tokens
+
+    assert topk_ids.dim() == 2
+    topk = topk_ids.shape[1]
+
+    return E, M, N, K, topk
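
For the standard 2-D activation case, the extracted sizes look like this (shapes are illustrative; the call assumes this patched vLLM is importable):

```python
# Illustrative shapes for the standard (non-batched) case.
import torch
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    _moe_problem_size)

E, N, K, M, topk = 8, 4096, 1024, 16, 2
a1 = torch.empty(M, K)
w1 = torch.empty(E, N, K)
w2 = torch.empty(E, K, N // 2)
topk_ids = torch.zeros(M, topk, dtype=torch.int32)

assert _moe_problem_size(a1, w1, w2, topk_ids) == (E, M, N, K, topk)
```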
+
+
+class FusedMoEQuantizeDispatchCombine(ABC):
+    """
+    An abstract base class for the [Quantize-Dispatch] and [Combine] steps
+    described above.
+    """
+
+    @abstractmethod
+    def dispatch(
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Perform any quantization (and/or) dispatching needed
+        for this kernel.
+        - a1: The (unquantized) input to the MoE layer.
+        - a1_scale: Optional scales for a1
+        - a2_scale: Optional scales for the second MoE gemm.  Required to make
+          sure the quantization is consistent for both gemms.
+        - topk_ids: The topk ids.
+        - topk_weights: The topk weights.
+        - num_experts: The total number of experts in the global expert space.
+        - expert_map: A tensor mapping expert indices from the global expert
+          space to the local expert space of the expert parallel shard.
+        - apply_router_weight_on_input: When True, apply the weights to the
+          activations, before quantization + dispatching.
+
+        Returns a tuple of:
+        - quantized + dispatched a1.
+        - quantized + dispatched a1_scale.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def combine(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+    ) -> None:
+        """
+        Perform the combine step: apply the topk weights and reduce the
+        fused expert outputs into the output tensor.
+        - output: The output tensor, written in place.  Must be (M, K) shape.
+        - fused_expert_output: The unweighted, unreduced output of the fused
+          experts; it will have (M, topk, K) shape.
+        - topk_weights: The weights to be applied to the fused_experts_output.
+        - topk_ids: The topk_ids.
+        - apply_router_weight_on_input: When False, apply the weights to
+          fused_expert_output.
+        """
+        raise NotImplementedError
+
+
+class FusedMoEPermuteExpertsUnpermute(ABC):
+    """
+    An abstract base class for the [Permute-Experts-Unpermute] step described
+    above.
+    """
+
+    @abstractmethod
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+    ) -> Tuple[int, int, torch.dtype]:
+        """
+        Compute the number of elements for the temporary outputs of the two
+        gemms and activation in the fused expert function.  Since the first
+        gemm's output is fully consumed before the second gemm runs, the
+        workspace for the first gemm can be shared with the workspace for
+        the last gemm.
+
+        Returns a tuple of:
+        - Number of workspace13 elements: must be large enough to hold the
+          result of either expert gemm.
+        - Number of workspace2 elements: must be large enough to hold the
+          result of the activation function.
+        - Workspace type: The dtype to use for the workspace tensors.
+        """
+        raise NotImplementedError
+
+    def activation(self, activation: str, output: torch.Tensor,
+                   input: torch.Tensor) -> None:
+        assert output.shape[-1] * 2 == input.shape[-1]
+        if activation == "silu":
+            torch.ops._C.silu_and_mul(output, input)
+        elif activation == "gelu":
+            torch.ops._C.gelu_and_mul(output, input)
+        else:
+            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
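
The 2x shape contract asserted above reflects that the input packs the gate and up projections along the last dimension; a non-fused reference for the silu path, stated here as an assumption about the fused op's semantics:

```python
# Non-fused reference for the silu path: silu(gate) * up, where the input
# packs [gate | up] along the last dimension (hence the 2x assert above).
import torch
import torch.nn.functional as F


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


out = silu_and_mul_ref(torch.randn(4, 8))
assert out.shape == (4, 4)
```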
+
+    @abstractmethod
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_num_tokens: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        This function computes the intermediate result of a Mixture of Experts
+        (MoE) layer using two sets of weights, w1 and w2.
+
+        Parameters:
+        - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
+          layer.
+        - w1 (torch.Tensor): The first set of expert weights.
+        - w2 (torch.Tensor): The second set of expert weights.
+        - topk_ids (torch.Tensor): A map of row to expert id.
+        - activation (str): The activation function to apply after the first
+          MoE layer.
+        - global_num_experts (int): The total number of experts in the global
+          expert space.
+        - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices
+          from the global expert space to the local expert space of the expert
+          parallel shard.
+        - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
+        - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
+        - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
+          w1.
+        - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
+          w2.
+        - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
+          used for a1.
+        - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
+        - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs;
+          must be large enough to hold the output of either MoE gemm.
+        - workspace2 (torch.Tensor): A scratch tensor used for the activation
+          function.
+        - expert_num_tokens: An optional tensor containing the number of tokens
+          assigned to each expert when using the batched experts input format.
+
+        Returns:
+        - torch.Tensor: The unweighted, unreduced output tensor
+        """
+        raise NotImplementedError
+
+
+class FusedMoEModularKernel(torch.nn.Module):
+    """
+    This class combines a FusedMoEQuantizeDispatchCombine instance and
+    a FusedMoEPermuteExpertsUnpermute instance to provide an interface that
+    is compatible with the `fused_experts` function in fused_moe.py.
+
+    It takes care of managing any required scratch space.
+
+    Note: Instances of this class should only be used for a single model
+    layer, since the component objects may hold layer-specific state.
+    """
+
+    def __init__(
+        self,
+        dispatch_combine: FusedMoEQuantizeDispatchCombine,
+        fused_experts: FusedMoEPermuteExpertsUnpermute,
+    ):
+        super().__init__()
+        self.dispatch_combine = dispatch_combine
+        self.fused_experts = fused_experts
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        inplace: bool = False,
+        activation: str = "silu",
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        w1_zp: Optional[torch.Tensor] = None,
+        w2_zp: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+    ) -> torch.Tensor:
+        """
+        This function computes a Mixture of Experts (MoE) layer using two sets
+        of weights, w1 and w2, and a top-k gating mechanism.
+
+        Parameters:
+        - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+        - w1 (torch.Tensor): The first set of expert weights.
+        - w2 (torch.Tensor): The second set of expert weights.
+        - topk_weights (torch.Tensor): The topk weights applied at the end of
+          the layer.
+        - topk_ids (torch.Tensor): A map of row to expert id.
+        - inplace (bool): If True, perform the operation in-place.
+          Defaults to False.
+        - activation (str): The activation function to apply after the first
+          MoE layer.
+        - global_num_experts (int): The total number of experts in the global
+          expert space.
+        - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices
+          from the global expert space to the local expert space of the expert
+          parallel shard.
+        - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
+        - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
+        - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
+          w1.
+        - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
+          w2.
+        - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
+        - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
+        - apply_router_weight_on_input (bool): When true, the topk weights are
+          applied directly on the inputs. This is only applicable when topk is
+          1.
+
+        Returns:
+        - torch.Tensor: The output tensor after applying the MoE layer.
+        """
+        a1 = hidden_states
+        E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids)
+
+        if global_num_experts == -1:
+            global_num_experts = E
+
+        output = a1 if inplace else torch.empty_like(a1)
+
+        workspace13_shape, workspace2_shape, workspace_dtype = (
+            self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
+                                                global_num_experts))
+
+        # workspace13 is shared between the first and last gemm outputs: by
+        # the time the last gemm needs the space, the first gemm's output has
+        # already been consumed.
+        workspace13 = torch.empty(workspace13_shape,
+                                  device=a1.device,
+                                  dtype=workspace_dtype)
+        workspace2 = torch.empty(workspace2_shape,
+                                 device=a1.device,
+                                 dtype=workspace_dtype)
+
+        a1q, a1q_scale, expert_num_tokens = self.dispatch_combine.dispatch(
+            a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
+            expert_map, apply_router_weight_on_input)
+
+        fused_out = self.fused_experts.apply(
+            a1q,
+            w1,
+            w2,
+            topk_ids,
+            activation=activation,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_zp=w1_zp,
+            w2_zp=w2_zp,
+            a1q_scale=a1q_scale,
+            a2_scale=a2_scale,
+            workspace13=workspace13,
+            workspace2=workspace2,
+            expert_num_tokens=expert_num_tokens,
+        )
+
+        self.dispatch_combine.combine(output, fused_out, topk_weights,
+                                      topk_ids, apply_router_weight_on_input)
+
+        return output
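
As a rough usage sketch (not part of this diff), the pieces above compose as
follows; `a2a`, the rank/size values, the weight tensors and scales are all
placeholders for objects the caller already owns:

```python
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import (
    PplxDispatchCombine)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)

# Dispatch/combine handles quantization + all-to-all token movement.
dispatch_combine = PplxDispatchCombine(
    a2a,  # pplx.AllToAll shared with the layer
    max_num_tokens=max_num_tokens,
    world_size=world_size,
    rank=rank,
    dp_size=dp_size,
    quant_dtype=torch.float8_e4m3fn,
)

# Permute-experts-unpermute picks DeepGemm or Triton per call.
experts = TritonOrDeepGemmExperts(use_fp8_w8a8=True,
                                  block_shape=block_shape,
                                  allow_deep_gemm=True)

fused = mk.FusedMoEModularKernel(dispatch_combine, experts)

# Drop-in replacement for fused_experts() for this one layer.
out = fused(hidden_states, w13_weight, w2_weight, topk_weights, topk_ids,
            w1_scale=w13_scale, w2_scale=w2_scale, a1_scale=a1_scale)
```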
diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
index cdf7e31c1436..a72d0917e0a5 100644
--- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
+++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
@@ -3,6 +3,74 @@
 
 import torch
 
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
+    moe_align_block_size)
+from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
+
+
+def _moe_permute(
+    curr_hidden_states: torch.Tensor,
+    a1q_scale: Optional[torch.Tensor],
+    curr_topk_ids: torch.Tensor,
+    global_num_experts: int,
+    expert_map: Optional[torch.Tensor],
+    block_m: int,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
+           Optional[torch.Tensor]]:
+    """
+    Determine the sorted_token_ids, expert_ids for the given problem size.
+    Permute the hidden states and scales according to `sorted_token_ids`.
+    """
+    top_k_num = curr_topk_ids.shape[1]
+
+    tokens_in_chunk, _ = curr_hidden_states.shape
+
+    sorted_token_ids, expert_ids, num_tokens_post_padded = (
+        moe_align_block_size(curr_topk_ids,
+                             block_m,
+                             global_num_experts,
+                             expert_map,
+                             pad_sorted_ids=True))
+
+    inv_perm: Optional[torch.Tensor] = None
+
+    num_tokens = top_k_num * tokens_in_chunk
+    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
+    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
+    inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
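+    # inv_perm[i] is the row of the expert-sorted layout holding flattened
+    # (token, topk) slot i, so indexing with it later restores the original
+    # token order.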
+
+    # Permute according to sorted token ids.
+    curr_hidden_states = _fp8_perm(curr_hidden_states,
+                                   sorted_token_ids // top_k_num)
+
+    if a1q_scale is not None:
+        a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
+
+    return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
+            inv_perm)
+
+
+def _moe_unpermute_and_reduce(
+    out: torch.Tensor,
+    curr_hidden: torch.Tensor,
+    inv_perm: Optional[torch.Tensor],
+    topk_weight: torch.Tensor,
+    apply_router_weight_on_input: bool,
+) -> None:
+    """
+    Unpermute the final result and apply topk_weights, then perform the final
+    reduction on the hidden states.
+    """
+    M, topk = topk_weight.shape
+    K = curr_hidden.shape[-1]
+    if inv_perm is not None:
+        curr_hidden = curr_hidden[inv_perm, ...]
+    curr_hidden = curr_hidden.view(-1, topk, K)
+    if not apply_router_weight_on_input:
+        curr_hidden.mul_(topk_weight.view(M, -1, 1))
+    ops.moe_sum(curr_hidden, out)
+
 
 def moe_permute(
     hidden_states: torch.Tensor,
@@ -17,21 +85,21 @@ def moe_permute(
     fill_invalid_expert: int = -1
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    This function expands and permutes activation to gather uncontinuous tokens 
+    This function expands and permutes activation to gather noncontiguous tokens
       for each expert.
     Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.    
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - topk_weights (torch.Tensor): topk expert route weight for each token.
     - topk_ids (torch.Tensor): topk expert route id for each token.
     - token_expert_indices (torch.Tensor): indice for expanded hidden.
     - topk (int): The number of top-k experts to select.
     - n_expert (int): The number of expert.
     - n_local_expert (int): The number of expert in current EP rank.
-    - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices 
-        from the global expert space to the local expert space of the expert 
+    - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices
+        from the global expert space to the local expert space of the expert
         parallel shard.
     - align_block_size (Optional[int]): align group gemm block size for deepgemm
-    - fill_invalid_expert(int): fill expert id in m_indices for invalid expert 
+    - fill_invalid_expert(int): fill expert id in m_indices for invalid expert
       to workaround DeepGemm unsupported -1 in m_indices
     Returns:
     - permuted_hidden_states (torch.Tensor): permuted activation.
@@ -39,7 +107,7 @@ def moe_permute(
        of each expert for standard grouped gemm. if enable 'align_block_size'
        expert_first_token_offset will align up to 'align_block_size'.
     - src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute.
-    - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records 
+    - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
     the group which the j-th row of the LHS belong to.`
     """
     n_token, n_hidden = hidden_states.shape
@@ -87,7 +155,7 @@ def moe_unpermute(
     n_local_expert: int,
 ) -> torch.Tensor:
     """
-    This function expands and permutes activation to gathering uncontinuous 
+    This function unpermutes the expanded activation and reduces the routed
       tokens for each expert.
     Parameters:
     - permuted_hidden_states (torch.Tensor): permuted activation.
@@ -99,8 +167,8 @@ def moe_unpermute(
     - n_expert (int): The number of expert.
     - n_local_expert (int): The number of expert in current EP rank.
     Returns:
-    - hidden_states (torch.Tensor): The reduced and unpermuted activation 
-      tensor.  
+    - hidden_states (torch.Tensor): The reduced and unpermuted activation
+      tensor.
     """
     n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1]
     assert (n_hidden * permuted_hidden_states.element_size()
diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py
new file mode 100644
index 000000000000..7392fe418a45
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py
@@ -0,0 +1,147 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import List, Optional, Tuple
+
+import pplx_kernels as pplx
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.utils import (
+    moe_kernel_quantize_input)
+
+
+# Note: use layer.get_all_to_all() to get an AllToAll instance.
+# The max_num_tokens, world_size and dp_size must be the same
+# as the ones used to create the AllToAll.
+class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine):
+
+    def __init__(self,
+                 a2a: pplx.AllToAll,
+                 max_num_tokens: int,
+                 world_size: int,
+                 rank: int,
+                 dp_size: int,
+                 quant_dtype: Optional[torch.dtype] = None,
+                 block_shape: Optional[List[int]] = None):
+        super().__init__()
+        assert max_num_tokens > 0
+        self.a2a = a2a
+        self.block_shape = block_shape
+        self.max_num_tokens = max_num_tokens
+        self.world_size = world_size
+        self.rank = rank
+        self.dp_size = dp_size
+        self.quant_dtype = quant_dtype
+
+    def dispatch(
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        rank_topk_weights: torch.Tensor,
+        rank_topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        num_tokens = a1.shape[0]  # M
+        hidden_dim = a1.shape[-1]  # K
+
+        assert rank_topk_ids.shape[0] == num_tokens
+        # assert expert_map is None, "NYI"
+
+        # Is this always going to be a1.device?
+        device = a1.device
+
+        if apply_router_weight_on_input:
+            topk = rank_topk_ids.shape[1]
+            # TODO: this only works for topK=1, will need to update for topK>1
+            assert topk == 1, \
+                "apply_router_weight_on_input is only implemented for topk=1"
+            a1 = a1 * rank_topk_weights.to(a1.dtype)
+
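+        # Whether to quantize per activation token is inferred from the scale
+        # shapes: a scale with more than one element implies per-token scales.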
+        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+            a2_scale.numel() != 1 if a2_scale is not None else False)
+
+        a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
+                                                   self.quant_dtype,
+                                                   per_act_token,
+                                                   self.block_shape)
+
+        # rem_experts needs to be 0 for pplx to work properly.
+        rem_experts = num_experts % self.world_size
+        assert rem_experts == 0
+        num_local_experts = ((num_experts // self.world_size) +
+                             (1 if self.rank < rem_experts else 0))
+
+        expert_num_tokens = torch.empty(
+            num_local_experts,
+            dtype=torch.int32,
+            device=device,
+        )
+
+        num_dp = self.world_size // self.dp_size
+        expert_x = torch.empty(
+            (num_local_experts, self.max_num_tokens * num_dp, hidden_dim),
+            dtype=a1q.dtype,
+            device=device,
+        )
+
+        expert_x_scale: Optional[torch.Tensor] = None
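+        # A 1-byte dtype (fp8/int8) means the activations were quantized, so a
+        # scale buffer must be sent through the all-to-all as well.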
+        if a1q.dtype.itemsize == 1:
+            float32_size = torch.float32.itemsize
+            block_size = (self.block_shape[0] if self.block_shape is not None
+                          else 1) * float32_size
+            expert_x_scale = torch.empty(
+                (
+                    num_experts,
+                    expert_x.size(1),
+                    (expert_x.size(2) + block_size - 1) // block_size,
+                ),
+                dtype=torch.float32,
+                device=device,
+            )
+
+        # This argument is optional, defaults to indices.shape[0]
+        # There's not much point setting this unless it is != indices.shape[0]
+        bound_m = None
+
+        self.a2a.dispatch(
+            out_expert_num_tokens=expert_num_tokens,
+            out_expert_x=expert_x,
+            out_expert_x_scale=expert_x_scale,
+            dp_x=a1q,
+            dp_x_scale=a1q_scale,
+            indices=rank_topk_ids,
+            bound_m=bound_m,
+        )
+
+        return expert_x, expert_x_scale, expert_num_tokens
+
+    def combine(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+    ) -> None:
+        num_tokens = output.shape[0]  # M
+        # This argument is optional
+        # There's not much point setting this unless it is != topk_ids.shape[0]
+        bound_m = None
+
+        assert topk_ids.shape[0] == num_tokens
+        assert output.shape[0] <= self.max_num_tokens, \
+            f"{output.shape[0]} <= {self.max_num_tokens}"
+        assert output.shape[1] == fused_expert_output.shape[-1]
+
+        # Set weights to 1 if they were applied in dispatch. This is hacky.
+        if apply_router_weight_on_input:
+            topk_weights = torch.ones_like(topk_weights)
+
+        self.a2a.combine(
+            out_tokens=output,
+            indices=topk_ids,  #.to(torch.uint32),
+            weights=topk_weights,
+            expert_y=fused_expert_output,
+            bound_m=bound_m)
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
new file mode 100644
index 000000000000..e512c11933dc
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import List, Optional, Tuple
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
+    DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
+from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
+
+
+class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self,
+                 use_fp8_w8a8: bool = False,
+                 use_int8_w8a8: bool = False,
+                 use_int8_w8a16: bool = False,
+                 use_int4_w4a16: bool = False,
+                 per_channel_quant: bool = False,
+                 block_shape: Optional[List[int]] = None,
+                 block_m: Optional[int] = None,
+                 allow_deep_gemm: bool = False):
+        super().__init__()
+        self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8,
+                                           use_int8_w8a8=use_int8_w8a8,
+                                           use_int4_w4a16=use_int4_w4a16,
+                                           use_int8_w8a16=use_int8_w8a16,
+                                           per_channel_quant=per_channel_quant,
+                                           block_shape=block_shape,
+                                           block_m=block_m)
+        self.deep_gemm_expert = DeepGemmExperts()
+        self.allow_deep_gemm = allow_deep_gemm
+        self.use_fp8_w8a8 = use_fp8_w8a8
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+    ) -> Tuple[int, int, torch.dtype]:
+        # Note: the deep gemm workspaces are strictly larger than the triton
+        # workspaces so we can be pessimistic here and allocate for DeepGemm
+        # even if we fall back to triton later, e.g. if expert maps are set.
+        if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
+            return self.deep_gemm_expert.workspace_shapes(
+                a, M, N, K, topk, num_experts)
+        else:
+            return self.triton_expert.workspace_shapes(a, M, N, K, topk,
+                                                       num_experts)
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_num_tokens: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        N = w1.shape[1]
+        if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
+                and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
+            return self.deep_gemm_expert.apply(
+                hidden_states,
+                w1,
+                w2,
+                topk_ids,
+                activation,
+                global_num_experts,
+                expert_map,
+                w1_scale,
+                w2_scale,
+                w1_zp,
+                w2_zp,
+                a1q_scale,
+                a2_scale,
+                workspace13,
+                workspace2,
+                expert_num_tokens,
+            )
+        else:
+            return self.triton_expert.apply(
+                hidden_states,
+                w1,
+                w2,
+                topk_ids,
+                activation,
+                global_num_experts,
+                expert_map,
+                w1_scale,
+                w2_scale,
+                w1_zp,
+                w2_zp,
+                a1q_scale,
+                a2_scale,
+                workspace13,
+                workspace2,
+                expert_num_tokens,
+            )
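
For reference, a minimal construction sketch; the [128, 128] block shape is an
illustrative assumption, not taken from this diff:

```python
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)

# DeepGemm is only consulted when fp8 w8a8 is enabled, N > 512 and
# _valid_deep_gemm() accepts the weight shapes; otherwise the Triton
# experts run with the same (pessimistically sized) workspaces.
experts = TritonOrDeepGemmExperts(
    use_fp8_w8a8=True,
    block_shape=[128, 128],
    allow_deep_gemm=True,
)
```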
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index db31422f7275..c16389c25aa9 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -7,6 +7,8 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
+from vllm.model_executor.layers.quantization.utils.int8_utils import (
+    per_token_group_quant_int8, per_token_quant_int8)
 from vllm.utils import cdiv
 
 
@@ -15,34 +17,81 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
     Shrink the given tensor and apply the given view to it.  This is
     used to resize the intermediate fused_moe caches.
     """
-    assert prod(v) <= x.numel()
+    assert prod(
+        v) <= x.numel(), f"{prod(v)} <= {x.numel()}"  # CUDAGRAPH unfriendly?
     return x.flatten()[:prod(v)].view(*v)
 
 
 def _fp8_quantize(
     A: torch.Tensor,
     A_scale: Optional[torch.Tensor],
-    block_shape: Optional[List[int]],
+    per_act_token: bool,
+    block_shape: Optional[List[int]] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Perform fp8 quantization on the inputs.  If a block_shape
     is provided, the output will be blocked.
     """
     if block_shape is None:
-        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+        A, A_scale = ops.scaled_fp8_quant(
+            A, A_scale, use_per_token_if_dynamic=per_act_token)
     else:
         assert len(block_shape) == 2
         _, block_k = block_shape[0], block_shape[1]
         A, A_scale = per_token_group_quant_fp8(A, block_k)
         assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+
+    return A, A_scale
+
+
+def _int8_quantize(
+    A: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    per_act_token: bool,
+    block_shape: Optional[List[int]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Perform int8 quantization on the inputs.  If a block_shape
+    is provided, the output will be blocked.
+    """
+
+    # If the weights are per-channel (per_channel_quant=True), the activations
+    # are quantized per token. Otherwise a block_shape is required; tensor-wise
+    # int8 activation quantization is not supported here.
+    if block_shape is None:
+        assert per_act_token, \
+            "int8 quantization only supports block or channel-wise"
+        A, A_scale = per_token_quant_int8(A)
+    else:
+        assert len(block_shape) == 2
+        _, block_k = block_shape[0], block_shape[1]
+        A, A_scale = per_token_group_quant_int8(A, block_k)
+        assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+
     return A, A_scale
 
 
+def moe_kernel_quantize_input(
+    A: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    qtype: Optional[torch.dtype],
+    per_channel_quant: bool,
+    block_shape: Optional[List[int]] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if qtype == torch.float8_e4m3fn:
+        return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
+    elif qtype == torch.int8:
+        return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
+    else:
+        assert A_scale is None
+        return A, A_scale
+
+
 def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
     """
     A permutation routine that works on fp8 types.
     """
-    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
+    if torch.is_floating_point(m) and m.dtype.itemsize == 1:
         return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
     else:
         return m[idx, ...]
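
A small sketch of how `moe_kernel_quantize_input` behaves on the fp8 and
pass-through paths; the shapes and the [128, 128] block shape are illustrative:

```python
import torch

from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)

a = torch.randn(64, 2048, device="cuda", dtype=torch.bfloat16)

# fp8 with 128-wide groups: returns an fp8 tensor plus per-group scales.
aq, a_scale = moe_kernel_quantize_input(a,
                                        A_scale=None,
                                        qtype=torch.float8_e4m3fn,
                                        per_channel_quant=False,
                                        block_shape=[128, 128])
assert aq.dtype == torch.float8_e4m3fn
assert a_scale.shape[-1] == 2048 // 128

# No qtype: the input passes through unquantized and the scale stays None.
a_out, no_scale = moe_kernel_quantize_input(a,
                                            A_scale=None,
                                            qtype=None,
                                            per_channel_quant=False)
assert a_out is a and no_scale is None
```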
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index f7056016fe8c..86f48296c3a9 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
 import importlib.util
 from typing import Any, Callable, Dict, List, Optional
 
@@ -9,6 +10,7 @@
 from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -426,6 +428,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
+        from vllm.model_executor.layers.fused_moe import fused_experts
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
 
@@ -450,6 +453,11 @@ def __init__(self, quant_config: Fp8Config):
                 logger.warning_once(
                     "DeepGemm not supported on the current platform.")
 
+        self.fused_experts = functools.partial(
+            fused_experts,
+            block_shape=self.quant_config.weight_block_size,
+            allow_deep_gemm=self.allow_deep_gemm)
+
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -769,6 +777,31 @@ def process_weights_after_loading(self, layer: Module) -> None:
             del layer.w13_input_scale
             del layer.w2_input_scale
 
+    def set_dispatch_combine(
+        self,
+        dp_size: int,
+        world_size: int,
+        dispatch_combine: mk.FusedMoEQuantizeDispatchCombine,
+    ) -> bool:
+        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+            TritonOrDeepGemmExperts)
+
+        if self.use_marlin:
+            return False
+
+        experts = TritonOrDeepGemmExperts(
+            use_fp8_w8a8=True,
+            block_shape=self.quant_config.weight_block_size,
+            allow_deep_gemm=self.allow_deep_gemm,
+        )
+
+        self.fused_experts = mk.FusedMoEModularKernel(
+            dispatch_combine,
+            experts,
+        )
+
+        return True
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -787,8 +820,6 @@ def apply(
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -815,28 +846,26 @@ def apply(
                 quant_type_id=scalar_types.float8_e4m3fn.id,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
-            use_fp8_w8a8=True,
-            global_num_experts=global_num_experts,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
-            w1_scale=(layer.w13_weight_scale_inv
-                      if self.block_quant else layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale_inv
-                      if self.block_quant else layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-            allow_deep_gemm=self.allow_deep_gemm,
-        )
+        else:
+            return self.fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                use_fp8_w8a8=True,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                expert_map=expert_map,
+                w1_scale=(layer.w13_weight_scale_inv
+                          if self.block_quant else layer.w13_weight_scale),
+                w2_scale=(layer.w2_weight_scale_inv
+                          if self.block_quant else layer.w2_weight_scale),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index ce86b9b2c4f0..a75c7c116b98 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -24,6 +24,7 @@
 """Inference-only DeepseekV2/DeepseekV3 model."""
 from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
 
+
 import torch
 from torch import nn
 from transformers import PretrainedConfig
@@ -31,9 +32,7 @@
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -143,7 +142,8 @@ def __init__(
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
-                reduce_results=False,
+                reduce_results=self.experts.must_reduce_shared_expert_outputs(
+                ),
                 prefix=f"{prefix}.shared_experts",
             )
 
@@ -154,6 +154,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
+
         if hidden_states.dtype != torch.float16:
             final_hidden_states = self.experts(
                 hidden_states=hidden_states,
@@ -171,8 +172,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 # See DeepseekV2DecoderLayer for more details.
                 final_hidden_states = final_hidden_states + shared_output \
                     * (1. / self.routed_scaling_factor)
+
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                 final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index 7fff14cb9f12..b0c525849a2e 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -70,6 +70,7 @@ def __init__(self,
                  prefix: str = ""):
         super().__init__()
         self.hidden_size = hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
 
         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(hidden_size,
@@ -97,6 +98,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
+
+        if self.tp_size > 1:
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
+                final_hidden_states)
+
         return final_hidden_states.view(orig_shape)
 
 
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index 0fdc30f36f9b..dfd0804f21cf 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -25,8 +25,7 @@
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -89,7 +88,7 @@ def __init__(self,
             quant_config=quant_config,
             bias=False,
             prefix=f"{prefix}.shared_expert",
-            reduce_results=False,  # We need to do scatter before reduce
+            reduce_results=self.experts.must_reduce_shared_expert_outputs(),
         )
 
     def forward(self, hidden_states):
@@ -102,7 +101,8 @@ def forward(self, hidden_states):
         experts_out = routed_out + shared_out
 
         if self.tp_size > 1:
-            experts_out = tensor_model_parallel_all_reduce(experts_out)
+            experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
+                experts_out)
 
         return experts_out
 
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 47d90919ed8f..e45b81cb0158 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -33,9 +33,7 @@
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -129,7 +127,8 @@ def __init__(
                 intermediate_size=config.shared_expert_intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
-                reduce_results=False,
+                reduce_results=self.experts.must_reduce_shared_expert_outputs(
+                ),
             )
         else:
             self.shared_expert = None
@@ -156,7 +155,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                 final_hidden_states)
 
         return final_hidden_states.view(orig_shape)
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index fe6b303ba0b5..f507a072734a 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -30,9 +30,7 @@
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -137,7 +135,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                            router_logits=router_logits)
         final_hidden_states = final_hidden_states
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                 final_hidden_states)
 
         return final_hidden_states.view(orig_shape)
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 0d18a5639c2a..f8746ca708f7 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -882,7 +882,10 @@ def forward(
 
         if attn_metadata is None:
             # Profiling run.
-            return output
+            # The zero fill is required when used with DP + EP
+            # to ensure all ranks within a DP group compute the
+            # same expert outputs.
+            return output.fill_(0)
 
         num_actual_toks = attn_metadata.num_actual_tokens