diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 87d90fdd9..262df7648 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -67,8 +67,8 @@ def puzzletron( build_library_and_stats.launch_build_library_and_stats(hydra_cfg) dist.barrier() - # # Step 5: calc_one_block_scores (distributed processing) - # scoring.launch_scoring(hydra_cfg) + # Step 5: calc_one_block_scores (distributed processing) + scoring.launch_scoring(hydra_cfg) # # Step 6: mip_and_realize_models (distributed processing) # mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_library.py b/modelopt/torch/puzzletron/replacement_library/replacement_library.py index bf6cc6636..7935fea4a 100644 --- a/modelopt/torch/puzzletron/replacement_library/replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_library.py @@ -12,23 +12,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Replacement library for efficiently loading and managing layer-replaced DeciLM models. +""" +Replacement library for efficiently loading and managing layer-replaced DeciLM models. - Uses replacement_utils for parsing, sorting, and analyzing layer replacement configurations """ # mypy: ignore-errors +import copy import json import re +import tempfile from pathlib import Path +from typing import List, Optional -import numpy as np import torch from immutabledict import immutabledict from lru import LRU +from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from torch import nn +from transformers import PretrainedConfig, PreTrainedModel import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.converter.converter import Converter from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, @@ -51,9 +57,11 @@ init_module_with_state_dict, load_model_config, ) +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_model_config from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( create_dummy_model, is_in_safetensors_format, + load_and_shard_model, load_sharded_state_dict, ) @@ -62,8 +70,10 @@ class ReplacementLibrary: def __init__( self, replacement_library_path: str | Path, - model_config_overrides: dict | None = None, + descriptor, + model_config_overrides: Optional[dict] = None, ): + self.descriptor = descriptor self.replacement_library = self._load_replacement_library(replacement_library_path) self._ensure_all_checkpoints_are_split() self.model_config_overrides = ( @@ -114,42 +124,77 @@ def n_layer(self) -> int: def model_config(self) -> DeciLMConfig: if self._model_config is None: self._model_config = load_model_config( - self.get_arbitrary_checkpoint_dir(), self.model_config_overrides + self.get_arbitrary_checkpoint_dir(), + self.model_config_overrides, + ignore_unexpected_config_keys=True, ) return self._model_config def create_model_config(self, layer_replacements: list[dict]): block_configs, _ = extract_block_configs_and_locations(layer_replacements) - model_config = self.model_config.set_block_configs(block_configs) + model_config = copy.deepcopy(self.model_config) + model_config.block_configs = block_configs + model_config.num_hidden_layers = len(block_configs) return model_config - def load_model(self, layer_replacements: list[dict]) -> DeciLMForCausalLM: - block_configs, block_locations = extract_block_configs_and_locations(layer_replacements) - model_config = self.model_config.set_block_configs(block_configs) + def _get_arbitrary_block_checkpoint_paths(self): + checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) + subblocks_dir = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME + non_block_paths = [p for p in subblocks_dir.glob("*.safetensors") if "block_" not in p.name] + return non_block_paths + + def create_index_file_from_weights(self, weight_paths: List[str]): + weight_map = {} + for weight_path in weight_paths: + weight_path = Path(weight_path) + with safe_open(str(weight_path), framework="pt", device="cpu") as f: + for tensor_name in f.keys(): + weight_map[tensor_name] = f"{SAFETENSORS_SUBBLOCKS_DIR_NAME}/{weight_path.name}" + index = {"metadata": {"format": "pt"}, "weight_map": weight_map} + return index + + def prepare_tmp_checkpoint_dir( + self, + tmpdir: Path, + model_config: PretrainedConfig, + layer_replacements: List[dict], + ): + arbitrary_checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) - owned_block_indexes = _get_owned_block_indexes(model_config.get_num_hidden_layers()) - model = create_dummy_model(model_config, self.dtype) + weight_paths = self._get_arbitrary_block_checkpoint_paths() + for layer_replacement in layer_replacements: + weight_paths += layer_replacement["weight_paths"] - is_first_shard = 0 in owned_block_indexes - if is_first_shard and not isinstance(model.model.get_input_embeddings(), nn.Embedding): - model.set_input_embeddings(self.get_embedding()) + weights_index = self.create_index_file_from_weights(weight_paths) + index_path = tmpdir / "model.safetensors.index.json" + with index_path.open("w", encoding="utf-8") as out: + json.dump(weights_index, out, indent=2, sort_keys=True) - is_last_shard = model_config.get_num_hidden_layers() - 1 in owned_block_indexes - if is_last_shard and not isinstance(model.model.get_output_embeddings(), nn.Linear): - model.model.set_final_layer_norm(self.get_ln_f()) - model.set_output_embeddings(self.get_lm_head()) + Converter.copy_checkpoint_files(arbitrary_checkpoint_dir, tmpdir) + save_model_config(model_config, tmpdir) - active_blocks = [] - for block_idx in owned_block_indexes: - layer_replacement, block_idx_in_replacement = block_locations[block_idx] - block = self.get_block(layer_replacement, block_idx_in_replacement) - model.model.layers[block_idx] = block - active_blocks.append(block) + # create symlinks inside tmpdir + subblocks_dir = tmpdir / SAFETENSORS_SUBBLOCKS_DIR_NAME + subblocks_dir.mkdir(exist_ok=True) + for weight_path in weight_paths: + link_path = subblocks_dir / weight_path.name + link_path.symlink_to(weight_path) - self._move_inactive_blocks_to_cpu(active_blocks) + def load_model( + self, + layer_replacements: list[dict], + ) -> PreTrainedModel: + """Load model using AnyModel approach with temporary checkpoint directory.""" + model_config = self.create_model_config(layer_replacements) + with tempfile.TemporaryDirectory(prefix="replacement_solution_") as tmpdir: + tmpdir = Path(tmpdir) + self.prepare_tmp_checkpoint_dir( + tmpdir, model_config=model_config, layer_replacements=layer_replacements + ) + model = load_and_shard_model(descriptor=self.descriptor, checkpoint_path=tmpdir) return model - def load_checkpoint(self, checkpoint_dir: str | Path) -> DeciLMForCausalLM: + def load_checkpoint(self, checkpoint_dir: str | Path) -> PreTrainedModel: checkpoint_dir = Path(checkpoint_dir).resolve() layer_replacements = self._locate_replacements_of_entire_checkpoint(checkpoint_dir) model = self.load_model(layer_replacements) @@ -221,7 +266,7 @@ def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList: if len(state_dict) > 0: block_indices = [ int(re.findall(r"^model\.layers\.(\d+)\.", param_name)[0]) - for param_name in state_dict + for param_name in state_dict.keys() ] assert sorted(set(block_indices)) == list( range(min(block_indices), max(block_indices) + 1) @@ -239,7 +284,9 @@ def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList: } dtype = infer_weights_dtype(state_dict) - model_config = self.model_config.set_block_configs(layer_replacement["child_block_configs"]) + model_config = copy.deepcopy(self.model_config) + model_config.block_configs = layer_replacement["child_block_configs"] + model_config.num_hidden_layers = len(layer_replacement["child_block_configs"]) module_list = nn.ModuleList( [ @@ -316,7 +363,7 @@ def _get_arbitrary_non_block_param(self, param_name: str) -> torch.Tensor: partial_state_dict = load_sharded_state_dict(checkpoint_dir, [param_name]) return partial_state_dict[param_name] - non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / "non_block.pth" + non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / f"non_block.pth" assert non_block_pth_path.exists(), _error_message_ensure_split(checkpoint_dir) non_block_state_dict = torch.load(non_block_pth_path) return non_block_state_dict[param_name] diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py index 4e3266df4..d253c9445 100644 --- a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py @@ -21,9 +21,11 @@ # mypy: ignore-errors import json +import shutil import warnings from functools import partial from pathlib import Path +from typing import Optional import torch from omegaconf import DictConfig @@ -31,6 +33,8 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.converter import Converter +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.replacement_library.replacement_library import ReplacementLibrary from modelopt.torch.puzzletron.replacement_library.replacement_utils import parse_layer_replacement @@ -40,15 +44,15 @@ copy_tokenizer, ) from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( - copy_deci_lm_hf_code, save_checkpoint, save_safetensors_index, ) +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model from modelopt.torch.puzzletron.tools.validation_utils import ( validate_model_and_extract_hidden_states, validate_model_with_teacher_similarity_metrics, ) -from modelopt.torch.puzzletron.utils.parsing import get_nested_key +from modelopt.torch.puzzletron.utils.parsing import get_nested_key, parse_path from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import perform_pipeline_stitches """ @@ -68,62 +72,55 @@ def validate_puzzle_solutions(args: DictConfig) -> None: Args: args: Configuration object containing the following attributes: - Puzzle Configuration (Required) attributes: - - - ``replacement_library_path`` (Path): Path to the replacement library JSON file. - - ``solutions_path`` (Path): Path to puzzle solutions JSON file or directory containing solution files. - - ``solutions_to_validate`` (list[int], optional): Indices of specific solutions to validate. - Validates all solutions if None. - - ``sort_solutions_by`` (str, optional): JSON field path to sort solutions by before validation. - - ``bigger_is_better`` (bool): If True, sort solutions in descending order. Used with sort_solutions_by. - - ``skip_validation`` (bool): If True, skip model validation and only save models if requested. - - ``save_models`` (bool): If True, save realized model checkpoints for each solution. - - Teacher/Tokenizer Configuration attributes: - - - ``teacher_dir`` (Path, optional): Path to teacher model directory. Auto-inferred if not provided. - - ``tokenizer_name`` (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. - - Model Configuration (Required if skip_validation=False) attributes: - - - ``model_dtype`` (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). - - ``autocast_dtype`` (str or torch.dtype): Autocast data type for mixed precision. - - Dataset Configuration (Required if skip_validation=False) attributes: - - - ``dataset_path`` (str): Path to the validation dataset. - - ``data_column`` (str): Column name in dataset containing text data. - - ``block_size`` (int): Maximum sequence length for tokenization. - - ``eval_samples`` (int, optional): Number of samples to evaluate. - - ``val_dataset_name`` (str): Name of validation dataset split. - - ``source_datasets_to_discard`` (list[str], optional): List of source datasets to exclude. - - ``load_dataset_fn`` (callable, optional): Custom function to load the dataset. - - Data Processing (Required if skip_validation=False) attributes: - - - ``micro_batch_size`` (int): Batch size for evaluation. - - ``seed`` (int): Random seed for reproducibility. - - ``shuffle_seed`` (int, optional): Seed for shuffling data. - - ``varlen`` (bool): Enable variable-length sequences. - - ``bos_rate`` (float): Rate of adding BOS token. - - ``fim_rate`` (float): Fill-in-the-middle rate for code completion tasks. - - ``fim_spm_rate`` (float): SPM-based fill-in-the-middle rate. - - Output Configuration attributes: - - - ``output_dir`` (Path, optional): Directory to save validation results. - Auto-generated from solutions_path if not provided. - - Execution Options (Optional if skip_validation=False) attributes: - - - ``calc_losses_on_cpu`` (bool): Calculate losses on CPU to avoid OOM. - - ``write_results`` (bool): Write validation results to file. - - ``activations_log_dir`` (str, optional): Directory to log activation scores. - - ``activation_hooks_kwargs`` (str or dict, optional): Arguments for activation hooks. + Puzzle Configuration (Required): + - replacement_library_path (Path): Path to the replacement library JSON file. + - solutions_path (Path): Path to puzzle solutions JSON file or directory containing solution files. + - solutions_to_validate (list[int], optional): Indices of specific solutions to validate. Validates all solutions if None. + - sort_solutions_by (str, optional): JSON field path to sort solutions by before validation. + - bigger_is_better (bool): If True, sort solutions in descending order. Used with sort_solutions_by. + - skip_validation (bool): If True, skip model validation and only save models if requested. + - save_models (bool): If True, save realized model checkpoints for each solution. + + Teacher/Tokenizer Configuration: + - teacher_dir (Path, optional): Path to teacher model directory. Auto-inferred if not provided. + - tokenizer_name (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. + + Model Configuration (Required if skip_validation=False): + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration (Required if skip_validation=False): + - dataset_path (str): Path to the validation dataset. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing (Required if skip_validation=False): + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + Output Configuration: + - output_dir (Path, optional): Directory to save validation results. Auto-generated from solutions_path if not provided. + + Execution Options (Optional if skip_validation=False): + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. + - write_results (bool): Write validation results to file. + - activations_log_dir (str, optional): Directory to log activation scores. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. Returns: None. Saves validation results and optionally model checkpoints to disk. """ + descriptor = ModelDescriptorFactory.get(args.descriptor) + puzzle_solutions = load_puzzle_solutions( args.solutions_path, args.sort_solutions_by, args.bigger_is_better ) @@ -143,29 +140,41 @@ def validate_puzzle_solutions(args: DictConfig) -> None: else args.solutions_path.with_name(f"{args.solutions_path.stem}--validation") ) - replacement_library = ReplacementLibrary(args.replacement_library_path) + replacement_library = ReplacementLibrary( + args.replacement_library_path, + descriptor=descriptor, + model_config_overrides={"use_cache": False}, + ) teacher_hidden_states = None if (args.teacher_dir is not None) and (not args.skip_validation): - teacher_model = replacement_library.load_checkpoint(args.teacher_dir) + teacher_model = load_and_shard_model( + checkpoint_path=args.teacher_dir, descriptor=descriptor + ) teacher_model.cuda(dist.local_rank()) - stitched_model = perform_pipeline_stitches(teacher_model) + stitched_model = perform_pipeline_stitches(teacher_model, descriptor=descriptor) teacher_hidden_states = validate_model_and_extract_hidden_states( args, stitched_model, tokenizer, output_dir, model_name="teacher", - pipeline_parallel=True, val_dataloader=val_dataloader, ) + # Properly release CUDA memory after teacher validation + teacher_model.cpu() + stitched_model.cpu() + torch.cuda.empty_cache() + torch.cuda.synchronize() + dist.barrier() + for i_solution, puzzle_solution in tqdm( list(zip(args.solutions_to_validate, puzzle_solutions)), desc="Validating solutions" ): layer_replacements = _extract_layer_replacements_from_puzzle_solution(puzzle_solution) - # realizable_as_symlinks = can_realize_as_symlinks(layer_replacements) - realizable_as_symlinks = False + realizable_as_symlinks = can_realize_as_symlinks(layer_replacements) + # realizable_as_symlinks = False model_config = replacement_library.create_model_config(layer_replacements) if (args.save_models and not realizable_as_symlinks) or (not args.skip_validation): model = replacement_library.load_model(layer_replacements) @@ -177,24 +186,21 @@ def validate_puzzle_solutions(args: DictConfig) -> None: / f"solution_{i_solution}" ) - model_config.dtype = args.model_dtype - model_config.architectures = ["DeciLMForCausalLM"] + model_config.dtype = getattr(args, "model_dtype", "torch.bfloat16") + Converter.copy_checkpoint_files(args.teacher_dir, checkpoint_dir) if realizable_as_symlinks: if dist.is_master(): - save_checkpoint_as_symlinks( - layer_replacements, model_config, checkpoint_dir, replacement_library - ) - else: - save_checkpoint(model, checkpoint_dir) + # save_checkpoint_as_symlinks is currently not supported + pass + save_checkpoint(model, checkpoint_dir, descriptor) copy_tokenizer(args.tokenizer_name, checkpoint_dir) - copy_deci_lm_hf_code(checkpoint_dir) dist.barrier() if not args.skip_validation: model.cuda(dist.local_rank()) - stitched_model = perform_pipeline_stitches(model) + stitched_model = perform_pipeline_stitches(model, descriptor=descriptor) validate_model_with_teacher_similarity_metrics( args, stitched_model, @@ -203,10 +209,15 @@ def validate_puzzle_solutions(args: DictConfig) -> None: output_dir, model_name=f"solution_{i_solution}", extra_payload={"i_solution": i_solution, "puzzle_solution": puzzle_solution}, - pipeline_parallel=True, val_dataloader=val_dataloader, ) + # Properly release CUDA memory after solution validation + model.cpu() + stitched_model.cpu() + torch.cuda.empty_cache() + torch.cuda.synchronize() + dist.barrier() @@ -278,7 +289,7 @@ def _extract_layer_replacements_from_puzzle_solution( def load_puzzle_solutions( solutions_path: Path, - sort_solutions_by: str | None, + sort_solutions_by: Optional[str], bigger_is_better: bool, ) -> list[dict]: assert solutions_path.exists(), f"{solutions_path=} does not exist" diff --git a/modelopt/torch/puzzletron/tools/validation_utils.py b/modelopt/torch/puzzletron/tools/validation_utils.py index 697977cda..d7197e8ab 100644 --- a/modelopt/torch/puzzletron/tools/validation_utils.py +++ b/modelopt/torch/puzzletron/tools/validation_utils.py @@ -21,7 +21,7 @@ # mypy: ignore-errors from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional, Union import torch from omegaconf import DictConfig, OmegaConf @@ -44,8 +44,7 @@ def validate_model_and_extract_hidden_states( tokenizer: PreTrainedTokenizerBase, output_dir: str | Path, model_name: str, - extra_payload: dict[str, Any] | None = None, - pipeline_parallel: bool = False, + extra_payload: Optional[dict[str, Any]] = None, val_dataloader=None, ) -> list[torch.Tensor | LowMemorySparseTensor]: mprint(f""" @@ -60,7 +59,6 @@ def validate_model_and_extract_hidden_states( model, tokenizer, return_hidden_states=True, - pipeline_parallel=pipeline_parallel, val_dataloader=val_dataloader, ) if dist.is_last_process(): @@ -77,8 +75,7 @@ def validate_model_with_teacher_similarity_metrics( target_hidden_states_per_batch: list[torch.Tensor], output_dir: str | Path, model_name: str, - extra_payload: dict[str, Any] | None = None, - pipeline_parallel: bool = False, + extra_payload: Optional[dict[str, Any]] = None, calculate_full_score_ablations: bool = False, val_dataloader=None, ) -> None: @@ -95,7 +92,6 @@ def validate_model_with_teacher_similarity_metrics( model, tokenizer, target_hidden_states_per_batch=target_hidden_states_per_batch, - pipeline_parallel=pipeline_parallel, calculate_full_score_ablations=calculate_full_score_ablations, val_dataloader=val_dataloader, )