Merged

43 commits
e82164f
Add anymodel directories to feature/puzzletron
danielkorzekwa Mar 4, 2026
2099df3
Make any_model conversion work.
danielkorzekwa Mar 5, 2026
eb5cf8a
Update child_init.py with anymodel version
danielkorzekwa Mar 5, 2026
c9de41c
fix attention pruning
danielkorzekwa Mar 5, 2026
3c1bc1f
Add trust_remote_code to load_model_config (defaults to false)
danielkorzekwa Mar 5, 2026
8357136
Make activation scoring work
danielkorzekwa Mar 5, 2026
6cc2194
Comment out all tested models aside from llama_3_1_8b_instruct
danielkorzekwa Mar 5, 2026
ee4e1e3
Delete unneeded decilm test
danielkorzekwa Mar 5, 2026
449b523
Fix broken tests
danielkorzekwa Mar 5, 2026
fb27bba
Update puzzletron_nas_plugin to any_model version
danielkorzekwa Mar 5, 2026
b350f82
Correct test resources used by tests.
danielkorzekwa Mar 5, 2026
fafe5a3
Disable puzzletron tests (will be enabled after all any_model logic i…
danielkorzekwa Mar 5, 2026
e988248
Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
danielkorzekwa Mar 6, 2026
c717852
Comment out unimplemented models.
danielkorzekwa Mar 6, 2026
030f126
format python docs
danielkorzekwa Mar 6, 2026
8dcdfbf
Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
danielkorzekwa Mar 6, 2026
70df0df
Use trust_remote_code in force_cache_dynamic_modules()
danielkorzekwa Mar 6, 2026
bb56662
Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
danielkorzekwa Mar 6, 2026
ecd953e
Fix anymodel pruning
danielkorzekwa Mar 6, 2026
ee8f538
Fix build docs issue.
danielkorzekwa Mar 6, 2026
c9b76a1
Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
danielkorzekwa Mar 6, 2026
6e3af61
Merge branch 'dkorzekwa/anymodel_activation_scoring' into dkorzekwa/a…
danielkorzekwa Mar 6, 2026
0ad6d92
Merging build_library_and_stats
danielkorzekwa Mar 6, 2026
995eb1a
Merging anymodel: calc_one_block_scores
danielkorzekwa Mar 6, 2026
34081c9
Merging any_model: calc_one_block_scores
danielkorzekwa Mar 6, 2026
47414d5
Clarify readme and avoid reusing the same reference in llama_converter.
danielkorzekwa Mar 9, 2026
a8305d8
Fix tied-embedding handling before writing the safetensors index.
danielkorzekwa Mar 9, 2026
68421a5
Fix NaN ranking: NaNs were being selected as “best” experts by default (see the NaN-safe ranking sketch after this commit list).
danielkorzekwa Mar 9, 2026
d6b8028
Code clean up.
danielkorzekwa Mar 9, 2026
ecd2341
Code clean up.
danielkorzekwa Mar 10, 2026
f9d845d
code clean up
danielkorzekwa Mar 10, 2026
d171b01
Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
danielkorzekwa Mar 10, 2026
722da90
Merge branch 'dkorzekwa/anymodel_activation_scoring' into dkorzekwa/a…
danielkorzekwa Mar 10, 2026
934ab2f
code clean up
danielkorzekwa Mar 10, 2026
0f14ec3
Merge branch 'dkorzekwa/anymodel_pruning' into dkorzekwa/anymodel_bui…
danielkorzekwa Mar 10, 2026
dcb9e02
Remove unneeded comment
danielkorzekwa Mar 10, 2026
0c9ea5d
Merge branch 'dkorzekwa/anymodel_build_library_and_stats' into dkorze…
danielkorzekwa Mar 10, 2026
176a435
Fix a broken test_puzzletron test on 2 GPUs.
danielkorzekwa Mar 10, 2026
02e2c9b
Merge branch 'dkorzekwa/anymodel_activation_scoring' into dkorzekwa/a…
danielkorzekwa Mar 10, 2026
92c4419
Merge branch 'dkorzekwa/anymodel_pruning' into dkorzekwa/anymodel_bui…
danielkorzekwa Mar 10, 2026
aa1eb3e
Merge branch 'dkorzekwa/anymodel_build_library_and_stats' into dkorze…
danielkorzekwa Mar 10, 2026
335750c
Merge branch 'feature/puzzletron' into dkorzekwa/any_model_calc_one_b…
danielkorzekwa Mar 12, 2026
58a42ca
Fix tox -e build-docs issues
danielkorzekwa Mar 12, 2026
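
Commit 68421a5 fixes a ranking pitfall worth spelling out: torch.topk treats NaN as larger than any finite value, so NaN scores float to the top of the ranking. A minimal NaN-safe ranking sketch, assuming scores is a 1-D tensor of per-expert scores (rank_experts is a hypothetical helper, not code from this PR):

import torch

def rank_experts(scores: torch.Tensor, k: int) -> torch.Tensor:
    # topk with the default largest=True would rank NaN scores first;
    # map them to -inf so broken experts always sort last.
    safe_scores = torch.nan_to_num(scores, nan=float("-inf"))
    return torch.topk(safe_scores, k).indices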
4 changes: 2 additions & 2 deletions modelopt/torch/puzzletron/puzzletron.py
@@ -67,8 +67,8 @@ def puzzletron(
build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
dist.barrier()

# # Step 5: calc_one_block_scores (distributed processing)
# scoring.launch_scoring(hydra_cfg)
# Step 5: calc_one_block_scores (distributed processing)
scoring.launch_scoring(hydra_cfg)

# # Step 6: mip_and_realize_models (distributed processing)
# mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
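
The hunk above turns Step 5 live while Step 6 stays commented out. Each distributed step in this pipeline follows the same launch-then-synchronize pattern seen around Step 4; a minimal sketch, assuming the per-step launchers take the Hydra config (run_step is a hypothetical helper):

import modelopt.torch.utils.distributed as dist

def run_step(launch_fn, hydra_cfg):
    # Launch the step's distributed work, then block until every rank
    # has finished before the next pipeline step starts.
    launch_fn(hydra_cfg)
    dist.barrier()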
103 changes: 75 additions & 28 deletions modelopt/torch/puzzletron/replacement_library/replacement_library.py
@@ -12,23 +12,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Replacement library for efficiently loading and managing layer-replaced DeciLM models.
"""
Replacement library for efficiently loading and managing layer-replaced DeciLM models.
- Uses replacement_utils for parsing, sorting, and analyzing layer replacement configurations
"""
# mypy: ignore-errors

import copy
import json
import re
import tempfile
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
from immutabledict import immutabledict
from lru import LRU
from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel

import modelopt.torch.utils.distributed as dist
from modelopt.torch.puzzletron.anymodel.converter.converter import Converter
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import (
DeciLMDecoderLayer,
@@ -51,9 +57,11 @@
init_module_with_state_dict,
load_model_config,
)
from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_model_config
from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import (
create_dummy_model,
is_in_safetensors_format,
load_and_shard_model,
load_sharded_state_dict,
)

@@ -62,8 +70,10 @@ class ReplacementLibrary:
def __init__(
self,
replacement_library_path: str | Path,
model_config_overrides: dict | None = None,
descriptor,
model_config_overrides: Optional[dict] = None,
):
self.descriptor = descriptor
self.replacement_library = self._load_replacement_library(replacement_library_path)
self._ensure_all_checkpoints_are_split()
self.model_config_overrides = (
@@ -114,42 +124,77 @@ def n_layer(self) -> int:
def model_config(self) -> DeciLMConfig:
if self._model_config is None:
self._model_config = load_model_config(
self.get_arbitrary_checkpoint_dir(), self.model_config_overrides
self.get_arbitrary_checkpoint_dir(),
self.model_config_overrides,
ignore_unexpected_config_keys=True,
)
return self._model_config

def create_model_config(self, layer_replacements: list[dict]):
block_configs, _ = extract_block_configs_and_locations(layer_replacements)
model_config = self.model_config.set_block_configs(block_configs)
model_config = copy.deepcopy(self.model_config)
model_config.block_configs = block_configs
model_config.num_hidden_layers = len(block_configs)
return model_config

def load_model(self, layer_replacements: list[dict]) -> DeciLMForCausalLM:
block_configs, block_locations = extract_block_configs_and_locations(layer_replacements)
model_config = self.model_config.set_block_configs(block_configs)
def _get_arbitrary_block_checkpoint_paths(self):
checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir())
subblocks_dir = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME
non_block_paths = [p for p in subblocks_dir.glob("*.safetensors") if "block_" not in p.name]
return non_block_paths

def create_index_file_from_weights(self, weight_paths: List[str]):
weight_map = {}
for weight_path in weight_paths:
weight_path = Path(weight_path)
with safe_open(str(weight_path), framework="pt", device="cpu") as f:
for tensor_name in f.keys():
weight_map[tensor_name] = f"{SAFETENSORS_SUBBLOCKS_DIR_NAME}/{weight_path.name}"
index = {"metadata": {"format": "pt"}, "weight_map": weight_map}
return index

def prepare_tmp_checkpoint_dir(
self,
tmpdir: Path,
model_config: PretrainedConfig,
layer_replacements: List[dict],
):
arbitrary_checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir())

owned_block_indexes = _get_owned_block_indexes(model_config.get_num_hidden_layers())
model = create_dummy_model(model_config, self.dtype)
weight_paths = self._get_arbitrary_block_checkpoint_paths()
for layer_replacement in layer_replacements:
weight_paths += layer_replacement["weight_paths"]

is_first_shard = 0 in owned_block_indexes
if is_first_shard and not isinstance(model.model.get_input_embeddings(), nn.Embedding):
model.set_input_embeddings(self.get_embedding())
weights_index = self.create_index_file_from_weights(weight_paths)
index_path = tmpdir / "model.safetensors.index.json"
with index_path.open("w", encoding="utf-8") as out:
json.dump(weights_index, out, indent=2, sort_keys=True)

is_last_shard = model_config.get_num_hidden_layers() - 1 in owned_block_indexes
if is_last_shard and not isinstance(model.model.get_output_embeddings(), nn.Linear):
model.model.set_final_layer_norm(self.get_ln_f())
model.set_output_embeddings(self.get_lm_head())
Converter.copy_checkpoint_files(arbitrary_checkpoint_dir, tmpdir)
save_model_config(model_config, tmpdir)

active_blocks = []
for block_idx in owned_block_indexes:
layer_replacement, block_idx_in_replacement = block_locations[block_idx]
block = self.get_block(layer_replacement, block_idx_in_replacement)
model.model.layers[block_idx] = block
active_blocks.append(block)
# create symlinks inside tmpdir
subblocks_dir = tmpdir / SAFETENSORS_SUBBLOCKS_DIR_NAME
subblocks_dir.mkdir(exist_ok=True)
for weight_path in weight_paths:
link_path = subblocks_dir / weight_path.name
link_path.symlink_to(weight_path)

self._move_inactive_blocks_to_cpu(active_blocks)
def load_model(
self,
layer_replacements: list[dict],
) -> PreTrainedModel:
"""Load model using AnyModel approach with temporary checkpoint directory."""
model_config = self.create_model_config(layer_replacements)
with tempfile.TemporaryDirectory(prefix="replacement_solution_") as tmpdir:
tmpdir = Path(tmpdir)
self.prepare_tmp_checkpoint_dir(
tmpdir, model_config=model_config, layer_replacements=layer_replacements
)
model = load_and_shard_model(descriptor=self.descriptor, checkpoint_path=tmpdir)
return model

def load_checkpoint(self, checkpoint_dir: str | Path) -> DeciLMForCausalLM:
def load_checkpoint(self, checkpoint_dir: str | Path) -> PreTrainedModel:
checkpoint_dir = Path(checkpoint_dir).resolve()
layer_replacements = self._locate_replacements_of_entire_checkpoint(checkpoint_dir)
model = self.load_model(layer_replacements)
@@ -221,7 +266,7 @@ def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList:
if len(state_dict) > 0:
block_indices = [
int(re.findall(r"^model\.layers\.(\d+)\.", param_name)[0])
for param_name in state_dict
for param_name in state_dict.keys()
]
assert sorted(set(block_indices)) == list(
range(min(block_indices), max(block_indices) + 1)
@@ -239,7 +284,9 @@ def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList:
}

dtype = infer_weights_dtype(state_dict)
model_config = self.model_config.set_block_configs(layer_replacement["child_block_configs"])
model_config = copy.deepcopy(self.model_config)
model_config.block_configs = layer_replacement["child_block_configs"]
model_config.num_hidden_layers = len(layer_replacement["child_block_configs"])

module_list = nn.ModuleList(
[
@@ -316,7 +363,7 @@ def _get_arbitrary_non_block_param(self, param_name: str) -> torch.Tensor:
partial_state_dict = load_sharded_state_dict(checkpoint_dir, [param_name])
return partial_state_dict[param_name]

non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / "non_block.pth"
non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / f"non_block.pth"
assert non_block_pth_path.exists(), _error_message_ensure_split(checkpoint_dir)
non_block_state_dict = torch.load(non_block_pth_path)
return non_block_state_dict[param_name]
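
The rewritten load_model never materializes a merged checkpoint: prepare_tmp_checkpoint_dir symlinks the existing shard files into a temporary directory and writes a model.safetensors.index.json mapping each tensor to its shard, which load_and_shard_model then consumes. A minimal sketch of that layout, with illustrative shard and tensor names and "subblocks" standing in for SAFETENSORS_SUBBLOCKS_DIR_NAME:

import json
import tempfile
from pathlib import Path

shard = Path("/library/subblocks/block_0.safetensors")  # illustrative path
with tempfile.TemporaryDirectory(prefix="replacement_solution_") as tmpdir:
    tmpdir = Path(tmpdir)
    subblocks_dir = tmpdir / "subblocks"
    subblocks_dir.mkdir(exist_ok=True)
    # Symlink instead of copy: the temporary checkpoint shares bytes
    # with the replacement library on disk.
    (subblocks_dir / shard.name).symlink_to(shard)
    # The index maps each tensor name to the shard file holding it,
    # in the format HF sharded checkpoints expect.
    index = {
        "metadata": {"format": "pt"},
        "weight_map": {
            "model.layers.0.self_attn.q_proj.weight": f"subblocks/{shard.name}",
        },
    }
    (tmpdir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2))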