use dlblas #3469

Closed · wants to merge 3 commits

14 changes: 13 additions & 1 deletion lmdeploy/pytorch/backends/cuda/moe.py
@@ -20,6 +20,8 @@
from ..moe import (FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl, FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder,
                   FusedMoEW8A8Impl)

from typing import List, Tuple

logger = get_logger('lmdeploy')


@@ -588,6 +590,13 @@ def __init__(self,
            self.use_deep_gemm = False
            logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')

        try:
            from dlblas.layers.moe.ep_moe import build_deepep_moe
            self.use_dlblas = True
        except ImportError:
            self.use_dlblas = False
            logger.warning('For higher performance, please install dlBLAS https://github.com/DeepLink-org/dlBLAS')

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
@@ -610,7 +619,10 @@ def do_renormalize(self, topk_weights):
        return _renormalize(topk_weights, self.renormalize)

    def fusedmoe_build(self, low_latency_mode: bool = False):
        if low_latency_mode:
        if self.use_dlblas:
            # build_deepep_moe is imported in __init__ only as an availability probe,
            # so import it again in the scope where it is actually called.
            from dlblas.layers.moe.ep_moe import build_deepep_moe
            return build_deepep_moe(low_latency_mode, self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size,
                                    self.out_dtype)
        elif low_latency_mode:
            return FusedMoELowLatency(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size,
                                      self.out_dtype)
        else:
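Note on the change above: dlBLAS is probed once at construction time and, when it is present, `fusedmoe_build` routes every build through `build_deepep_moe`, bypassing both the low-latency and the default built-in paths. A minimal standalone sketch of that probe-and-dispatch pattern follows; the `build_deepep_moe` argument order is taken from the diff, while the parameter names and the `fallback` hook are illustrative stand-ins.

```python
# Sketch only: probe for an optional accelerated kernel once, remember the result,
# and dispatch on it later.
try:
    from dlblas.layers.moe.ep_moe import build_deepep_moe
    HAS_DLBLAS = True
except ImportError:
    build_deepep_moe = None
    HAS_DLBLAS = False


def build_fused_moe(ep_size, ep_group, num_experts, hidden_dim, block_size, out_dtype,
                    low_latency_mode: bool = False, fallback=None):
    """Prefer the dlBLAS DeepEP MoE kernel when the package is installed."""
    if HAS_DLBLAS:
        # argument order follows the call in the diff above
        return build_deepep_moe(low_latency_mode, ep_size, ep_group, num_experts,
                                hidden_dim, block_size, out_dtype)
    # otherwise fall back to whatever built-in builder the caller supplies
    return fallback(low_latency_mode) if fallback is not None else None
```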
7 changes: 6 additions & 1 deletion lmdeploy/pytorch/models/deepseek_v2.py
@@ -21,6 +21,8 @@
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
import os

# EPLB (expert-parallel load balancing with redundant experts) is toggled via an environment variable.
enable_eplb = os.environ.get('EPLB_ENABLED', '0') == '1'


# microbatch
@@ -655,7 +657,10 @@ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device
        quantization_config = getattr(config, 'quantization_config', None)
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.moe_intermediate_size
        self.num_experts = config.n_routed_experts
        if enable_eplb:
            # reserve 32 extra expert slots for EPLB replicas (count hard-coded in this PR)
            self.num_experts = config.n_routed_experts + 32
        else:
            self.num_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor
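Note on the change above: EPLB is switched on purely through the `EPLB_ENABLED` environment variable, and when it is enabled the routed-expert count is grown by a hard-coded 32 extra slots for replicated experts. A small sketch of that arithmetic; the 32 comes straight from the diff, while the 256-expert config below is only an example.

```python
# Sketch: expert slot count under the EPLB toggle introduced in this file.
import os

EPLB_REDUNDANT_SLOTS = 32  # hard-coded in the diff above


def routed_expert_count(n_routed_experts: int) -> int:
    enable_eplb = os.environ.get('EPLB_ENABLED', '0') == '1'
    return n_routed_experts + EPLB_REDUNDANT_SLOTS if enable_eplb else n_routed_experts


# e.g. a config with 256 routed experts:
#   EPLB_ENABLED unset -> 256 expert slots
#   EPLB_ENABLED=1     -> 288 expert slots (256 logical experts + 32 replicas)
print(routed_expert_count(256))
```

Because the flag is read at module import time, `EPLB_ENABLED=1` has to be set in the environment before lmdeploy is imported.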
84 changes: 68 additions & 16 deletions lmdeploy/pytorch/nn/moe.py
@@ -12,6 +12,10 @@
from ..backends import OpType, get_backend
from .utils import div_up

import os
from collections import defaultdict

# EPLB: a logical expert may be replicated across several local expert slots.
enable_eplb = os.environ.get('EPLB_ENABLED', '0') == '1'


class MoeType(Enum):
"""batch ecex type."""
@@ -73,10 +77,15 @@ def __init__(self,
        self.half_out = out_features // 2

        if self.ep:
            self.expert_map = dict((eid, idx) for idx, eid in enumerate(expert_list))
            self.weight.weight_loader = self.weight_loader_ep
            if enable_eplb:
                self.expert_map = defaultdict(list)
                for idx, eid in enumerate(expert_list):
                    self.expert_map[eid].append(idx)
            else:
                self.expert_map = dict((eid, idx) for idx, eid in enumerate(expert_list))
            self.scale.weight_loader = self.weight_loader_scale_ep
        else:
            self.weight.weight_loader = self.weight_loader_tp
            self.scale.weight_loader = self.weight_loader_scale_tp

    def update_weight(self, weight: torch.Tensor):
        """update weight."""
@@ -110,16 +119,49 @@ def weight_loader_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor
            return

        expert_map = self.expert_map
        param_id = expert_map[expert_id]
        if shard_id == 'gate':
            param_data = param.data[param_id, :self.half_out]
        elif shard_id == 'up':
            param_data = param.data[param_id, self.half_out:]
        elif shard_id == 'down':
            param_data = param.data[param_id]
        if not enable_eplb:
            param_id = expert_map[expert_id]
            if shard_id == 'gate':
                param_data = param.data[param_id, :self.half_out]
            elif shard_id == 'up':
                param_data = param.data[param_id, self.half_out:]
            elif shard_id == 'down':
                param_data = param.data[param_id]
            else:
                raise RuntimeError(f'Unknown shard_id: {shard_id}')
            param_data.copy_(loaded_weight)
        else:
            raise RuntimeError(f'Unknown shard_id: {shard_id}')
        param_data.copy_(loaded_weight)
            param_ids = expert_map[expert_id]
            for param_id in param_ids:
                if param.data.dtype == torch.float8_e4m3fn:
                    # temporarily cast to float16 for indexing
                    temp_param = param.data.to(torch.float16)

                    if shard_id == 'gate':
                        param_data = temp_param[param_id, :self.half_out]
                    elif shard_id == 'up':
                        param_data = temp_param[param_id, self.half_out:]
                    elif shard_id == 'down':
                        param_data = temp_param[param_id]
                    else:
                        raise RuntimeError(f'Unknown shard_id: {shard_id}')

                    # cast loaded_weight to float16 as well
                    weight_to_copy = loaded_weight.to(torch.float16)
                    param_data.copy_(weight_to_copy)

                    # write back to the original param.data (cast back to float8)
                    param.data.copy_(temp_param.to(torch.float8_e4m3fn))
                else:
                    if shard_id == 'gate':
                        param_data = param.data[param_id, :self.half_out]
                    elif shard_id == 'up':
                        param_data = param.data[param_id, self.half_out:]
                    elif shard_id == 'down':
                        param_data = param.data[param_id]
                    else:
                        raise RuntimeError(f'Unknown shard_id: {shard_id}')
                    param_data.copy_(loaded_weight.to(param_data.dtype))


def _gather_input(x: torch.Tensor, tp_sizes: List[int]):
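Aside on the `expert_map` change in this hunk: without EPLB each global expert id maps to exactly one local slot, whereas with EPLB the map becomes a `defaultdict(list)` so that a replicated expert id can own several slots and the loader writes the same checkpoint tensor into every replica. A toy illustration (the `expert_list` below is made up):

```python
# Toy example of the one-to-many expert_map used when enable_eplb is set.
from collections import defaultdict

expert_list = [0, 1, 2, 1]  # local slots 0..3 on this rank; expert 1 is replicated

expert_map = defaultdict(list)
for idx, eid in enumerate(expert_list):
    expert_map[eid].append(idx)

assert expert_map[0] == [0]
assert expert_map[1] == [1, 3]  # both slots receive expert 1's weights during loading
```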
@@ -428,7 +470,12 @@ def __init__(self,
        self.register_parameter('scale', scale)

        if self.ep:
            self.expert_map = dict((eid, idx) for idx, eid in enumerate(expert_list))
            if enable_eplb:
                self.expert_map = defaultdict(list)
                for idx, eid in enumerate(expert_list):
                    self.expert_map[eid].append(idx)
            else:
                self.expert_map = dict((eid, idx) for idx, eid in enumerate(expert_list))
            self.scale.weight_loader = self.weight_loader_scale_ep
        else:
            self.scale.weight_loader = self.weight_loader_scale_tp
@@ -445,9 +492,14 @@ def weight_loader_scale_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor,
                                shard_id: str):
        expert_list = self.expert_list
        if expert_id not in expert_list:
            return
        expert_id = self.expert_map[expert_id]
        self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id)
            return
        if not enable_eplb:
            expert_id = self.expert_map[expert_id]
            self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id)
        else:
            expert_ids = self.expert_map[expert_id]
            for expert_id in expert_ids:
                self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id)

    def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,
                               shard_id: str):
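One more note on the float8 branch of `weight_loader_ep` above: the weights are indexed and copied in float16 and the whole buffer is then cast back to `torch.float8_e4m3fn`, presumably because in-place slice updates on float8 storage are not reliably supported. A reduced sketch of that round trip; the shapes are invented, and it assumes a PyTorch build with float8 support (2.1 or newer).

```python
# Sketch of the fp8 write path: work in fp16, write the slot, cast back to fp8.
import torch

param = torch.zeros(4, 8, dtype=torch.float8_e4m3fn)  # stand-in for param.data (4 expert slots)
loaded = torch.ones(8, dtype=torch.float16)           # stand-in for one expert's loaded_weight

temp = param.to(torch.float16)             # fp8 -> fp16 working copy
temp[2].copy_(loaded)                      # write expert slot 2
param.copy_(temp.to(torch.float8_e4m3fn))  # cast back and write through to the fp8 buffer
```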