[Feature]support trainer_degree in name_mapping #2922

27 changes: 12 additions & 15 deletions fastdeploy/rl/rollout_model.py
@@ -63,10 +63,10 @@ def _init_model(self) -> nn.Layer:
         model.eval()
         return model

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
-        return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})()
+        return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree)

     def get_quantization_infer_keys(self) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
         return getattr(self.rollout_model, "get_quantization_infer_keys", lambda: {})()
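For context, the wrapper-level change above simply forwards the new argument to the per-architecture model. A minimal sketch of that delegation, under the assumption that the wrapped model may or may not define the hook (the class names below are illustrative, and the missing-attribute fallback here uses None rather than the lambda default in the diff):

```python
class _InnerModel:
    # Stands in for a per-architecture RL model that implements the hook.
    def get_name_mappings_to_training(self, trainer_degree=None):
        return {"infer.param": "train.param"}

class _Wrapper:
    # Stands in for the RolloutModel wrapper.
    def __init__(self, inner):
        self.inner = inner

    def get_name_mappings_to_training(self, trainer_degree=None):
        # Forward trainer_degree when the hook exists, else return an empty mapping.
        fn = getattr(self.inner, "get_name_mappings_to_training", None)
        return fn(trainer_degree) if fn else {}

print(_Wrapper(_InnerModel()).get_name_mappings_to_training(trainer_degree=4))
```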
@@ -108,9 +108,6 @@ def _complete_missing_mappings(self) -> None:
                 # Skip weight scale parameters in mapping. Train and infer have same key.
                 self.infer_to_train_mapping[key] = key

-        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
-            self.infer_to_train_mapping.pop("lm_head.linear.weight")
-
     def get_quantization_infer_keys(self) -> list[str]:
         """Get quantization infer keys"""
         quant_weight_key = []
@@ -143,7 +140,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_MoeForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -216,7 +213,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_VLMoeForConditionalGenerationRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -284,9 +281,9 @@ def _generate_ranges(start, end, step=16, take=8):

         assert isinstance(self.fd_config.model_config.moe_num_experts, list)
         total_moe_num = sum(self.fd_config.model_config.moe_num_experts)
-        rollout_model_degree = self.fd_config.parallel_config.tensor_parallel_size
-        expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // rollout_model_degree
-
+        if not trainer_degree:
+            trainer_degree = self.fd_config.parallel_config.tensor_parallel_size
+        expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // trainer_degree
Review comment: In the None case, use trainer_degree = rollout_model_degree = self.fd_config.parallel_config.tensor_parallel_size, so the default stays compatible with existing callers.

Author reply: done
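A minimal sketch of the fallback agreed on here (the helper name is hypothetical, not part of this PR): when no trainer_degree is passed, the rollout model's tensor-parallel size is used, so existing callers see the old behavior.

```python
def resolve_trainer_degree(trainer_degree, rollout_tensor_parallel_size):
    # None (or 0) falls back to the rollout degree, matching the
    # `if not trainer_degree:` branch added in this hunk.
    return trainer_degree or rollout_tensor_parallel_size

assert resolve_trainer_degree(None, 8) == 8  # old behavior preserved
assert resolve_trainer_degree(4, 8) == 4     # explicit trainer degree wins
```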

         # Process MoE layers
         for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
             _add_expert_mappings(layer_idx, "text", expert_start=0)
@@ -317,7 +314,7 @@ def name(self) -> str:
         """name"""
         return "Qwen2ForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -361,7 +358,7 @@ def name(self) -> str:
         """name"""
         return "Qwen3MoeForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -430,6 +427,6 @@ def __init__(self, fd_config: FDConfig):
     def name(self) -> str:
         """name"""
         return "Qwen3ForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         pass
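To see why the new parameter matters, consider the expert_num_per_rank computation changed in Ernie4_5_VLMoeForConditionalGenerationRL above: the per-rank expert count used for the name mapping should be divided by the trainer's tensor-parallel degree, which may differ from the rollout model's. A small worked example with illustrative numbers (not taken from any real config):

```python
moe_num_experts = 64   # experts per MoE layer (illustrative)
rollout_degree = 8     # tensor_parallel_size of the rollout model
trainer_degree = 4     # tensor-parallel degree used on the training side

# Before this PR the divisor was always the rollout degree:
assert moe_num_experts // rollout_degree == 8
# With trainer_degree passed in, the mapping follows the training layout instead:
assert moe_num_experts // trainer_degree == 16
```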