diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 14e5579764a..b3c2dfc3167 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -212,7 +212,6 @@ def __init__(
         # set attribute from pretrained_config
         for key, value in pretrained_config.items():
             setattr(self, key, value)
-
         # we need set default value when not exist
         for key, value in PRETRAINED_INIT_CONFIGURATION.items():
             if not hasattr(self, key):
@@ -300,6 +299,9 @@ def override_name_from_config(self):
         if not hasattr(self, "mla_use_absorb"):
             self.mla_use_absorb = False
 
+        if hasattr(self, "num_experts") and getattr(self, "moe_num_experts") is None:
+            self.moe_num_experts = self.num_experts
+
     def read_from_env(self):
         """
         Read configuration information from environment variables and update the object's attributes.
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 1eb7af39490..65aa35df12e 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -117,6 +117,8 @@
     "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
     # Max pre-fetch requests number in PD
     "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
+    # Enable or disable model caching.
+    # When enabled, the quantized model is stored as a cache for future inference to improve loading efficiency.
     "FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))),
     # Whether to clear cpu cache when clearing model weights.
     "FD_ENABLE_SWAP_SPACE_CLEARING": lambda: int(os.getenv("FD_ENABLE_SWAP_SPACE_CLEARING", "0")),
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
index 41b06962da0..e4247e0a599 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
@@ -200,22 +200,16 @@ def apply(
 
 class UnquantizedFusedMoEMethod(MoEMethodBase):
     def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
-
+        num_experts = extra_weight_attrs.pop("num_experts")
+        hidden_size = extra_weight_attrs.pop("hidden_size")
+        moe_intermediate_size = extra_weight_attrs.pop("moe_intermediate_size")
         if current_platform.is_cuda():
-            self.up_gate_proj_weight_shape = [
-                layer.num_local_experts,
-                layer.hidden_size,
-                layer.moe_intermediate_size * 2,
-            ]
-            self.down_proj_weight_shape = [layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size]
+            self.up_gate_proj_weight_shape = [num_experts, hidden_size, moe_intermediate_size * 2]
+            self.down_proj_weight_shape = [num_experts, moe_intermediate_size, hidden_size]
             extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1}}
         else:
-            self.up_gate_proj_weight_shape = [
-                layer.num_local_experts,
-                layer.moe_intermediate_size * 2,
-                layer.hidden_size,
-            ]
-            self.down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
+            self.up_gate_proj_weight_shape = [num_experts, moe_intermediate_size * 2, hidden_size]
+            self.down_proj_weight_shape = [num_experts, hidden_size, moe_intermediate_size]
             extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
 
         layer.up_gate_proj_weight = layer.create_parameter(
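
For context on the `create_weights` change above: the shapes are now derived from values that the MoE layer passes in explicitly (and that the method pops from `extra_weight_attrs`) instead of being read from `layer.num_local_experts` directly. A minimal standalone sketch of that shape selection, using a plain `is_cuda` flag in place of `current_platform.is_cuda()` and made-up sizes (none of this is FastDeploy code):

```python
# Standalone sketch (not FastDeploy code) of the weight-shape selection that
# UnquantizedFusedMoEMethod.create_weights performs on the popped values.
def moe_weight_shapes(num_experts: int, hidden_size: int, moe_intermediate_size: int, is_cuda: bool):
    if is_cuda:
        # CUDA layout: [experts, in_features, out_features]
        up_gate_proj = [num_experts, hidden_size, moe_intermediate_size * 2]
        down_proj = [num_experts, moe_intermediate_size, hidden_size]
    else:
        # Other platforms keep the transposed layout: [experts, out_features, in_features]
        up_gate_proj = [num_experts, moe_intermediate_size * 2, hidden_size]
        down_proj = [num_experts, hidden_size, moe_intermediate_size]
    return up_gate_proj, down_proj


# Example with illustrative sizes: 8 experts, hidden size 1024, intermediate size 2816.
print(moe_weight_shapes(8, 1024, 2816, is_cuda=True))
# -> ([8, 1024, 5632], [8, 2816, 1024])
```
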
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index cd20d7eaf07..4f7a21923db 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -198,11 +198,7 @@ def process_weights_after_loading(self, layer):
                 layer,
                 weight_name,
                 layer.create_parameter(
-                    shape=[
-                        layer.num_local_experts,
-                        ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
-                        ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
-                    ],
+                    shape=weight.shape,
                     dtype=weight_dtype,
                     default_initializer=paddle.nn.initializer.Constant(0),
                 ),
@@ -212,11 +208,7 @@ def process_weights_after_loading(self, layer):
                 layer,
                 scale_name,
                 layer.create_parameter(
-                    shape=[
-                        layer.num_local_experts,
-                        ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
-                        ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
-                    ],
+                    shape=scale.shape,
                     dtype=scale_dtype,
                     default_initializer=paddle.nn.initializer.Constant(0),
                 ),
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 09330e549a7..adf12c0ac0b 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -199,7 +199,12 @@ def __init__(
         else:
             self.gate_correction_bias = None
         self.quant_method.create_weights(
-            self, weight_loader=self.weight_loader, model_format=fd_config.model_config.model_format
+            self,
+            weight_loader=self.weight_loader,
+            model_format=fd_config.model_config.model_format,
+            num_experts=self.num_local_experts if self.ep_size > 1 else self.num_experts,
+            hidden_size=self.hidden_size,
+            moe_intermediate_size=self.moe_intermediate_size,
         )
 
         logger.info(
@@ -214,45 +219,52 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
             # MoE experts has been fused in disk
            self._load_fused_experts_weight(param, loaded_weight)
             return
-        if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
-            SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
-        elif current_platform.is_cuda():
-            SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
-        else:
-            SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
-
-        if not param._is_initialized():
-            param.initialize()
-
-        if shard_id is None:
-            # 1.gate up fused in disk
-            weight_need_transpose = getattr(param, "weight_need_transpose", False)
-            output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
-            per_rank = output_size // 2
-            start = self.tp_rank * per_rank
-            loaded_weight_shard_gate = slice_fn(
-                loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
-            )
-            self._load_gate_up_weight(
-                param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
-            )
-            start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
-            loaded_weight_shard_up = slice_fn(
-                loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
-            )
-            self._load_gate_up_weight(
-                param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
-            )
-        else:
-            # 2.gate up splited in disk
-            assert shard_id in ["gate", "down", "up"]
-            self._load_expert_weight(
-                param=param,
-                expert_id=expert_id,
-                loaded_weight=loaded_weight,
-                shard_id=shard_id,
-                shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
-            )
+
+        if expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts:
+            if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
+                SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
+            elif current_platform.is_cuda():
+                SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
+            else:
+                SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
+
+            if not param._is_initialized():
+                param.initialize()
+
+            if shard_id is None:
+                # 1.gate up fused in disk
+                weight_need_transpose = getattr(param, "weight_need_transpose", False)
+                output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
+                per_rank = output_size // 2
+                start = self.tp_rank * per_rank
+                loaded_weight_shard_gate = slice_fn(
+                    loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
+                )
+                self._load_gate_up_weight(
+                    param,
+                    expert_id,
+                    loaded_weight_shard_gate,
+                    "gate",
+                    SHARD_ID_TO_SHARDED_DIM["gate"],
+                    is_sharded=True,
+                )
+                start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
+                loaded_weight_shard_up = slice_fn(
+                    loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
+                )
+                self._load_gate_up_weight(
+                    param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
+                )
+            else:
+                # 2.gate up splited in disk
+                assert shard_id in ["gate", "down", "up"]
+                self._load_expert_weight(
+                    param=param,
+                    expert_id=expert_id,
+                    loaded_weight=loaded_weight,
+                    shard_id=shard_id,
+                    shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
+                )
 
     def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
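
The behavioral core of the `weight_loader` change above is the ownership check added for expert parallelism: checkpoints still enumerate global expert ids, and each rank now silently skips any expert whose local index falls outside `[0, num_local_experts)`. A minimal sketch of that test (hypothetical helper name, illustrative numbers; not FastDeploy code):

```python
# Illustrative sketch of the expert-ownership condition that weight_loader now
# applies before initializing or writing into a parameter.
def owns_expert(expert_id: int, expert_id_offset: int, num_local_experts: int) -> bool:
    local_id = expert_id - expert_id_offset
    return 0 <= local_id < num_local_experts


# With 64 global experts over 8 EP ranks (8 local experts each),
# rank 2 holds experts 16..23 and skips everything else.
assert owns_expert(17, expert_id_offset=16, num_local_experts=8)
assert not owns_expert(40, expert_id_offset=16, num_local_experts=8)
```
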
diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py
index 8e47a919bc4..fd1b7e4f7c4 100644
--- a/fastdeploy/model_executor/models/qwen3moe.py
+++ b/fastdeploy/model_executor/models/qwen3moe.py
@@ -451,6 +451,20 @@ def compute_logits(self, hidden_states: paddle.Tensor):
 
         return logits
 
+    def empty_input_forward(self):
+        """
+        Run the MoE expert layers on a fake hidden state so the forward pass still executes when this rank has no real input.
+        """
+        fake_hidden_states = paddle.empty(
+            shape=[1, self.fd_config.model_config.hidden_size],
+            dtype=paddle.get_default_dtype(),
+        )
+        for i in range(
+            self.fd_config.model_config.moe_layer_start_index,
+            self.fd_config.model_config.num_hidden_layers,
+        ):
+            self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index 28278d5654f..69c6f9a4c9c 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -145,6 +145,8 @@ def fn(model_sublayer_name: str, param=None):
         quant_method = getattr(model_sublayer, "quant_method", None)
         if not hasattr(quant_method, "process_weights_after_loading"):
             return
+        if param is not None and hasattr(param, "tensor_track") and param.tensor_track is None:
+            return
         if param is not None and hasattr(param, "tensor_track") and not param.tensor_track.is_fully_copied():
             return
         quant_method.process_weights_after_loading(model_sublayer)
@@ -269,10 +271,6 @@ def _err_msg(msg: str) -> str:
         _err_msg("v1 loader currently does not support pre-sliced weights")
         return False
 
-    if fd_config.parallel_config.use_ep:
-        _err_msg("v1 loader currently does not support expert parallelism")
-        return False
-
     if envs.FD_MOE_BACKEND.lower() == "marlin":
         _err_msg("v1 loader currently does not support marlin backend")
         return False
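
In the utils.py hunk above, a parameter whose `tensor_track` is `None` now defers `process_weights_after_loading` the same way a partially copied parameter does, and the blanket rejection of expert parallelism is removed from the v1-loader support check. A simplified restatement of the resulting guard order (assumed helper with `SimpleNamespace` stand-ins; not the actual `fn()` implementation):

```python
from types import SimpleNamespace


# Post-processing runs only when the quant method supports it and, for tracked
# parameters, the copy tracker exists and reports a complete copy. A None
# tensor_track defers processing, as does a partially copied parameter.
def should_process(quant_method, param) -> bool:
    if not hasattr(quant_method, "process_weights_after_loading"):
        return False
    if param is not None and hasattr(param, "tensor_track"):
        if param.tensor_track is None:
            return False  # no tracker attached yet; skip for now
        if not param.tensor_track.is_fully_copied():
            return False  # weights still being copied in
    return True


quant_method = SimpleNamespace(process_weights_after_loading=lambda layer: None)
mid_copy = SimpleNamespace(tensor_track=SimpleNamespace(is_fully_copied=lambda: False))
untracked = SimpleNamespace(tensor_track=None)
assert not should_process(quant_method, mid_copy)
assert not should_process(quant_method, untracked)
```
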
diff --git a/fastdeploy/stop.sh b/fastdeploy/stop.sh
index b12c068ecd2..db69c3420f4 100644
--- a/fastdeploy/stop.sh
+++ b/fastdeploy/stop.sh
@@ -1,10 +1,4 @@
-fastdeploy_inferernce_pids=$(ps auxww | grep "fastdeploy" | grep -v grep | awk '{print $2}')
-echo $fastdeploy_inferernce_pids
-for in_pid in ${fastdeploy_inferernce_pids[@]}; do
-    kill -9 ${in_pid}
-done
-echo 'end fastDeploy inference pids'
 
 api_server_pids=$(ps auxww | grep "api_server" | grep -v grep | awk '{print $2}')
 echo 'end api server pids:'
 
@@ -18,3 +12,11 @@ for pid in $api_server_pids; do
     done
     echo 'end uvicorn multi workers'
 done
+
+
+fastdeploy_inferernce_pids=$(ps auxww | grep "fastdeploy" | grep -v grep | awk '{print $2}')
+echo $fastdeploy_inferernce_pids
+for in_pid in ${fastdeploy_inferernce_pids[@]}; do
+    kill -9 ${in_pid}
+done
+echo 'end fastDeploy inference pids'
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 1d031bda428..6b7c66347db 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -2387,7 +2387,11 @@ def profile_run(self) -> None:
 
         # 2. Dummy run
         self._dummy_run(
-            num_tokens=self.scheduler_config.max_num_batched_tokens,
+            num_tokens=(
+                self.scheduler_config.max_num_seqs
+                if self.scheduler_config.splitwise_role == "decode"
+                else self.scheduler_config.max_num_batched_tokens
+            ),
             batch_size=self.scheduler_config.max_num_seqs,
         )
 
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index f84ad66d239..2b9a466e307 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -853,7 +853,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         num_experts = model_config.moe_num_experts[0]
     else:
         num_experts = model_config.moe_num_experts
-
     num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
     num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
     max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
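
Finally, the worker_process.py hunk above only drops a blank line around the expert-parallel partitioning: each EP rank still receives a contiguous, equally sized block of experts. A standalone sketch of that arithmetic with illustrative numbers (not the actual `initialize_fd_config` code):

```python
# Illustrative expert-parallel partitioning, mirroring the two lines kept in
# initialize_fd_config: per-rank expert count and the rank's starting offset.
def expert_partition(num_experts: int, expert_parallel_size: int, expert_parallel_rank: int):
    num_experts_per_rank = num_experts // expert_parallel_size
    num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
    return num_experts_per_rank, num_experts_start_offset


# 128 experts on 4 EP ranks: rank 3 owns experts 96..127.
print(expert_partition(128, 4, 3))  # -> (32, 96)
```
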