@@ -219,45 +219,52 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
             # MoE experts has been fused in disk
             self._load_fused_experts_weight(param, loaded_weight)
             return
-        if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
-            SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
-        elif current_platform.is_cuda():
-            SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
-        else:
-            SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
-
-        if not param._is_initialized():
-            param.initialize()
-
-        if shard_id is None:
-            # 1.gate up fused in disk
-            weight_need_transpose = getattr(param, "weight_need_transpose", False)
-            output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
-            per_rank = output_size // 2
-            start = self.tp_rank * per_rank
-            loaded_weight_shard_gate = slice_fn(
-                loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
-            )
-            self._load_gate_up_weight(
-                param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
-            )
-            start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
-            loaded_weight_shard_up = slice_fn(
-                loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
-            )
-            self._load_gate_up_weight(
-                param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
-            )
-        else:
-            # 2.gate up splited in disk
-            assert shard_id in ["gate", "down", "up"]
-            self._load_expert_weight(
-                param=param,
-                expert_id=expert_id,
-                loaded_weight=loaded_weight,
-                shard_id=shard_id,
-                shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
-            )
+
+        if expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts:
+            if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
+                SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
+            elif current_platform.is_cuda():
+                SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
+            else:
+                SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
+
+            if not param._is_initialized():
+                param.initialize()
+
+            if shard_id is None:
+                # 1.gate up fused in disk
+                weight_need_transpose = getattr(param, "weight_need_transpose", False)
+                output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
+                per_rank = output_size // 2
+                start = self.tp_rank * per_rank
+                loaded_weight_shard_gate = slice_fn(
+                    loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
+                )
+                self._load_gate_up_weight(
+                    param,
+                    expert_id,
+                    loaded_weight_shard_gate,
+                    "gate",
+                    SHARD_ID_TO_SHARDED_DIM["gate"],
+                    is_sharded=True,
+                )
+                start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
+                loaded_weight_shard_up = slice_fn(
+                    loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
+                )
+                self._load_gate_up_weight(
+                    param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
+                )
+            else:
+                # 2.gate up splited in disk
+                assert shard_id in ["gate", "down", "up"]
+                self._load_expert_weight(
+                    param=param,
+                    expert_id=expert_id,
+                    loaded_weight=loaded_weight,
+                    shard_id=shard_id,
+                    shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
+                )
 
     def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
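
For readers checking the index arithmetic in the fused branch (`shard_id is None`): the checkpoint is assumed to store the expert's gate and up projections concatenated along the sharded output dimension, each projection spanning all tensor-parallel ranks, so rank `r` reads rows `[r*per_rank, (r+1)*per_rank)` for gate and `[per_rank*tp_size + r*per_rank, per_rank*tp_size + (r+1)*per_rank)` for up. Below is a minimal NumPy sketch of that slicing only; the names and toy sizes are illustrative, not part of this PR, and it ignores `weight_need_transpose` and the non-CUDA shard dimensions.

```python
import numpy as np

# Toy sizes; real shapes come from the model config (illustrative assumption).
tp_size = 4          # tensor-parallel world size
hidden_size = 8      # input (hidden) dimension
inter_per_rank = 2   # per-rank slice of the intermediate dimension

# Fused checkpoint layout along the output dim: all gate rows first, then all up rows.
full_inter = inter_per_rank * tp_size
loaded_weight = np.arange(2 * full_inter * hidden_size, dtype=np.float32).reshape(2 * full_inter, hidden_size)

for tp_rank in range(tp_size):
    # The local parameter holds gate+up for one rank, so its output size is
    # 2 * inter_per_rank and per_rank = output_size // 2, mirroring the loader above.
    output_size = 2 * inter_per_rank
    per_rank = output_size // 2

    start = tp_rank * per_rank                          # gate slice for this rank
    gate = loaded_weight[start:start + per_rank]

    start_up = per_rank * tp_size + tp_rank * per_rank  # up slice for this rank
    up = loaded_weight[start_up:start_up + per_rank]

    assert gate.shape == up.shape == (per_rank, hidden_size)
```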