Move _tp_plan setting to post_init
kwen2501 committed Nov 7, 2024
1 parent 073c521 commit db6e5ee
Showing 5 changed files with 28 additions and 25 deletions.
49 changes: 28 additions & 21 deletions src/transformers/modeling_utils.py
@@ -1399,6 +1399,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     # Has support for a `QuantoQuantizedCache` instance as `past_key_values`
     _supports_quantized_cache = False
 
+    # A tensor parallel plan to be applied to the model when TP is enabled. For
+    # top-level models, this attribute is currently defined in respective model
+    # code. For base models, this attribute comes from
+    # `config.base_model_tp_plan` during `post_init`.
+    _tp_plan = None
+
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
         """
@@ -1443,6 +1449,9 @@ def post_init(self):
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
# If current model is a base model, attach `base_model_tp_plan` from config
if self.base_model is self:
self._tp_plan = self.config.base_model_tp_plan

def dequantize(self):
"""
@@ -3475,9 +3484,8 @@ def from_pretrained(

         tp_plan = kwargs.pop("tp_plan", None)
         if tp_plan is not None and tp_plan != "auto":
-            raise ValueError(
-                f"tp_plan supports 'auto' only for now but got {tp_plan}."
-            )
+            # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
+            raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
 
         if is_fsdp_enabled():
             low_cpu_mem_usage = True
@@ -4095,9 +4103,7 @@ def from_pretrained(
             init_contexts.append(init_empty_weights())
         elif tp_plan is not None:
             if not torch.distributed.is_initialized():
-                raise ValueError(
-                    "Tensor Parallel requires torch.distributed to be initialized first."
-                )
+                raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
 
             # Get device type (e.g. "cuda")
             device_type = torch.distributed.distributed_c10d._device_capability()[0]
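
Taken together, the two checks above mean this loading path currently requires `tp_plan="auto"` and an already-initialized process group. A minimal usage sketch under those assumptions (torchrun-style launch so the rank/world-size environment variables are set; the checkpoint id is a placeholder for any model whose config defines `base_model_tp_plan`):

# Minimal usage sketch, assuming a torchrun-style launch (RANK, WORLD_SIZE,
# MASTER_ADDR/MASTER_PORT set in the environment) and a checkpoint whose
# config defines `base_model_tp_plan`. The model id is a placeholder.
import torch.distributed as dist
from transformers import AutoModelForCausalLM

dist.init_process_group()  # must happen before from_pretrained, per the check above

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # placeholder checkpoint
    tp_plan="auto",             # "auto" is the only accepted value for now
)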
@@ -5063,21 +5069,22 @@ def tensor_parallel(self, device_mesh):
         # parallelize a model.
         def tplize(mod: torch.nn.Module) -> None:
             tp_plan = getattr(mod, "_tp_plan", None)
-            if tp_plan:
-                logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
-                # In model configs, we use a neutral type (string) to specify
-                # parallel styles, here we translate them into torch TP types.
-                # Using tree_map because `tp_plan` is a dict.
-                tp_plan = torch.utils._pytree.tree_map(
-                    translate_to_torch_parallel_style,
-                    tp_plan,
-                )
-                # Apply TP to current module.
-                torch.distributed.tensor.parallel.parallelize_module(
-                    mod,
-                    device_mesh=device_mesh,
-                    parallelize_plan=tp_plan,
-                )
+            if tp_plan is None:
+                return
+            logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
+            # In model configs, we use a neutral type (string) to specify
+            # parallel styles, here we translate them into torch TP types.
+            # Using tree_map because `tp_plan` is a dict.
+            tp_plan = torch.utils._pytree.tree_map(
+                translate_to_torch_parallel_style,
+                tp_plan,
+            )
+            # Apply TP to current module.
+            torch.distributed.tensor.parallel.parallelize_module(
+                mod,
+                device_mesh=device_mesh,
+                parallelize_plan=tp_plan,
+            )
 
         # `apply` is a native method of `nn.Module` that recursively applies a
         # function to every submodule.
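
`parallelize_module`, `ColwiseParallel`, and `RowwiseParallel` are the standard torch tensor-parallel entry points; `translate_to_torch_parallel_style` is the transformers helper that bridges the string styles to them, and its implementation is not shown in this diff. A plausible sketch of the mapping it performs (an assumption, not the real code):

# Plausible sketch of the string-to-style translation; the actual helper lives
# in transformers and may support more styles than shown here.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

def translate_to_torch_parallel_style(style: str):
    if style == "colwise":
        return ColwiseParallel()
    if style == "rowwise":
        return RowwiseParallel()
    raise ValueError(f"Unknown parallel style: {style}")

With a 1-D device mesh spanning the world size (e.g. from `torch.distributed.device_mesh.init_device_mesh`), `tensor_parallel` then walks every submodule via `apply` and parallelizes whichever `_tp_plan` it finds.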
1 change: 0 additions & 1 deletion src/transformers/models/gemma/modeling_gemma.py
@@ -722,7 +722,6 @@ def __init__(self, config: GemmaConfig):
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
-        self._tp_plan = config.base_model_tp_plan
         if getattr(config, "pretraining_tp", 1) != 1:
             logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
 
1 change: 0 additions & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
@@ -742,7 +742,6 @@ def __init__(self, config: Gemma2Config):
         self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
-        self._tp_plan = config.base_model_tp_plan
         if getattr(config, "pretraining_tp", 1) != 1:
             logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
 
1 change: 0 additions & 1 deletion src/transformers/models/glm/modeling_glm.py
@@ -708,7 +708,6 @@ def __init__(self, config: GlmConfig):
             dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
         )
         self.gradient_checkpointing = False
-        self._tp_plan = config.base_model_tp_plan
         if getattr(config, "pretraining_tp", 1) != 1:
             logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
 
1 change: 0 additions & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -812,7 +812,6 @@ def __init__(self, config: LlamaConfig):
         self.rotary_emb = LlamaRotaryEmbedding(config=config)
 
         self.gradient_checkpointing = False
-        self._tp_plan = config.base_model_tp_plan
         if getattr(config, "pretraining_tp", 1) != 1:
             logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
 