diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index cf2336bf4a..562c9c6f43 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1043,6 +1043,28 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) + sequential_checkpoint_dir: str | None = ModeloptField( + default=None, + title="Directory for sequential calibration checkpoints.", + description=( + "If set (together with sequential_checkpoint_interval), sequential calibration " + "will save intermediate checkpoints to this directory. On resume, if a checkpoint " + "with seq_calib_progress metadata is found, calibration resumes from the last " + "completed layer. Uses a rolling checkpoint (overwrites on each save)." + ), + ) + + sequential_checkpoint_interval: int | None = ModeloptField( + default=None, + gt=0, + title="Checkpoint interval for sequential calibration (in layers).", + description=( + "Save a checkpoint every N layers during sequential calibration. " + "Requires sequential_checkpoint_dir to also be set. " + "If None, no checkpoints are saved." + ), + ) + class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 472252e1c7..1fa049af04 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -43,6 +43,7 @@ TensorQuantizer, ) from .utils import is_quantized, is_quantized_linear +from .utils.checkpoint import SEQ_CALIB_PROGRESS_ATTR __all__ = [ "register", @@ -108,6 +109,12 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: details regarding how MCore sharded checkpoint is restored, see modelopt.torch.opt.plugins.mcore_dist_checkpointing.restore_sharded_modelopt_state. """ + # Propagate sequential calibration progress to the model for resume. 
+ # This is global metadata (not per-module), so it must run before the + # MCore early return — it applies to both HF and MCore checkpoint paths. + if "seq_calib_progress" in metadata: + setattr(model, SEQ_CALIB_PROGRESS_ATTR, metadata["seq_calib_progress"]) + if "quantizer_state" not in metadata: # MCore sharded checkpoint (`torch-dist`) has its quantizer_state stored as the # extra_state of `QuantModule`. The quantizer_state is resumed with @@ -170,6 +177,11 @@ def update_quantize_metadata( """Update the quantizer state in the metadata dict.""" metadata["quantizer_state"] = quantizer_state(model) + # Propagate sequential calibration progress if present (for checkpoint save) + progress = getattr(model, SEQ_CALIB_PROGRESS_ATTR, None) + if progress is not None: + metadata["seq_calib_progress"] = progress + def quantizer_state(model: nn.Module) -> dict[str, Any]: """Returns the quantizer state dict describing the quantizer states in the model.""" diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index e08efece9a..5a6de1e411 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -15,6 +15,7 @@ """This module contains the mode descriptor for the quantization mode.""" +import warnings from abc import abstractmethod from collections.abc import Callable @@ -228,6 +229,15 @@ def wrapped_calib_func( kwargs["algorithm"] = method moe_calib_experts_ratio = kwargs.pop("moe_calib_experts_ratio", None) + checkpoint_dir = kwargs.pop("sequential_checkpoint_dir", None) + checkpoint_interval = kwargs.pop("sequential_checkpoint_interval", None) + + if not sequential and (checkpoint_dir is not None or checkpoint_interval is not None): + warnings.warn( + "sequential_checkpoint_dir/sequential_checkpoint_interval are set but " + "use_sequential is False. Checkpoint settings will be ignored." 
+ ) + if moe_calib_experts_ratio is not None: assert ( isinstance(moe_calib_experts_ratio, (int, float)) and 0 < moe_calib_experts_ratio <= 1 @@ -248,6 +258,8 @@ def wrapped_calib_func( model, forward_loop=forward_loop, calib_func=func, + checkpoint_dir=checkpoint_dir, + checkpoint_interval=checkpoint_interval, **kwargs, ) else: diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 89097fd32c..b5dc3efccf 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -49,6 +49,12 @@ reduce_amax, weight_attr_names, ) +from .utils.checkpoint import ( + SEQ_CALIB_PROGRESS_ATTR, + detect_sequential_resume_layer, + save_sequential_checkpoint, + should_save_seq_calib_checkpoint, +) __all__ = [ "awq", @@ -1870,6 +1876,8 @@ def sequential_calibrate( model: nn.Module, forward_loop: ForwardLoop, calib_func: Callable, + checkpoint_dir: str | None = None, + checkpoint_interval: int | None = None, **calib_kwargs, ): """Sequential calibration - a sequential layer-by-layer calibration algorithm. @@ -1877,6 +1885,18 @@ def sequential_calibrate( Runs the full model forward per layer but patches decoder layers with a skip / run / capture strategy so that inter-layer logic in parent modules (e.g. mask construction) executes naturally without model-specific hooks. + + Args: + model: The model to calibrate. + forward_loop: Callable that runs calibration data through the model. + calib_func: Per-layer calibration function (e.g. ``max_calibrate``). + checkpoint_dir: If set (with *checkpoint_interval*), save a rolling + checkpoint every *checkpoint_interval* layers. On re-run with a + model restored from such a checkpoint, calibration resumes + automatically from the last completed layer. + checkpoint_interval: Save a checkpoint every N layers. Requires + *checkpoint_dir* to also be set. + **calib_kwargs: Extra arguments forwarded to *calib_func*. 
""" if forward_loop is None: raise ValueError( @@ -1891,14 +1911,23 @@ def sequential_calibrate( "Sequential calibration requires a model with identifiable transformer layers." ) - print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") + num_layers = len(transformer_layers) + print_rank_0(f"Sequential calibration: Found {num_layers} transformer layers") + + resume_from_layer, layer_output_metas = detect_sequential_resume_layer(model, num_layers) input_getter = LayerActivationCollector(model) - input_getter._patch_all_layers(decoder_layers=transformer_layers) + input_getter._patch_all_layers( + decoder_layers=transformer_layers, layer_output_metas=layer_output_metas + ) try: - for layer_idx, layer in enumerate(transformer_layers): - print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}") + if resume_from_layer > 0: + input_getter.prepare_for_resume(resume_from_layer, forward_loop) + + for layer_idx in range(resume_from_layer, num_layers): + layer = transformer_layers[layer_idx] + print_rank_0(f"Calibrating layer {layer_idx + 1}/{num_layers}") layer_inputs = input_getter.get_input_activations(layer, forward_loop) def _layer_forward_loop(m, _inputs=layer_inputs): @@ -1909,5 +1938,19 @@ def _layer_forward_loop(m, _inputs=layer_inputs): del layer_inputs torch.cuda.empty_cache() + + if should_save_seq_calib_checkpoint( + layer_idx, num_layers, checkpoint_dir, checkpoint_interval + ): + assert checkpoint_dir is not None # narrowed by should_save_seq_calib_checkpoint + layer_output_metas = input_getter.get_layer_output_metas(layer_idx) + save_sequential_checkpoint( + model, layer_idx, num_layers, checkpoint_dir, layer_output_metas + ) finally: + # Sole owner of _seq_calib_progress cleanup. The attribute may be set + # by save_sequential_checkpoint (save path) or restore_quantizer_state + # (resume path); neither deletes it — this is the single cleanup point. 
+ if hasattr(model, SEQ_CALIB_PROGRESS_ATTR): + delattr(model, SEQ_CALIB_PROGRESS_ATTR) input_getter._unpatch_all_layers() diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 0d02716a6e..7f7a75fd5c 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -40,6 +40,7 @@ from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE from ..utils import replace_function, sync_moe_expert_amax from ..utils.activation_collector import LayerActivationCollector +from ..utils.checkpoint import register_seq_calib_checkpoint_saver from .attention import register_attention_for_kv_quant from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin @@ -1472,6 +1473,14 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): ) +def _save_hf_checkpoint(model: nn.Module, checkpoint_dir: str) -> None: + """Save a HuggingFace model checkpoint using ``save_pretrained``.""" + model.save_pretrained(checkpoint_dir) + + +register_seq_calib_checkpoint_saver(_is_supported_hf_model, _save_hf_checkpoint) + + class _QuantMoELinear(QuantModule): """Quantization wrapper for Step3p5 MoELinear modules (fused expert weights). diff --git a/modelopt/torch/quantization/utils/activation_collector.py b/modelopt/torch/quantization/utils/activation_collector.py index 5f187fdcb2..18a171ae39 100644 --- a/modelopt/torch/quantization/utils/activation_collector.py +++ b/modelopt/torch/quantization/utils/activation_collector.py @@ -22,7 +22,7 @@ from collections import deque from dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal import torch import torch.nn as nn @@ -44,7 +44,7 @@ class _LayerCalibState: patched forward to decide skip / run / capture / original behaviour. 
""" - mode: str = "original" + mode: Literal["original", "skip", "run", "capture"] = "original" name: str = "" cached_inputs: deque = field(default_factory=deque) collected_inputs: list = field(default_factory=list) @@ -150,12 +150,39 @@ def _zeros_from_meta(meta): # downstream run-mode layer, which replays from its own cached inputs instead. return meta[1] - def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None): + @staticmethod + def _remap_meta_device(meta, device: torch.device): + """Return a copy of *meta* with all tensor devices replaced by *device*.""" + tag = meta[0] + if tag == "tensor": + _, shape, dtype, _old_device = meta + return ("tensor", shape, dtype, device) + if tag == "tuple": + return ( + "tuple", + tuple(LayerActivationCollector._remap_meta_device(m, device) for m in meta[1]), + ) + if tag == "list": + return ( + "list", + [LayerActivationCollector._remap_meta_device(m, device) for m in meta[1]], + ) + return meta + + def _patch_all_layers( + self, + decoder_layers: nn.ModuleList | None = None, + layer_output_metas: dict | None = None, + ): """Bind the unified forward to every decoder layer and the model. Called once. Args: decoder_layers: Pre-resolved decoder layers. If *None*, layers are discovered via :meth:`get_decoder_layers`. + layer_output_metas: ``{layer_idx: output_meta}`` mapping from a + checkpoint, used to pre-populate skip-mode metadata on resume. + Tensor devices in the metas are remapped to each layer's current + device so that checkpoints are portable across placements. 
""" def _patched_forward(self, *args, **kwargs): @@ -200,10 +227,16 @@ def _patched_forward(self, *args, **kwargs): module_to_name = {m: name for name, m in self.model.named_modules()} try: - for layer in self._decoder_layers: + for i, layer in enumerate(self._decoder_layers): layer._seq_calib = _LayerCalibState( name=module_to_name.get(layer, type(layer).__name__), ) + if layer_output_metas and i in layer_output_metas: + p = next(layer.parameters(), None) + device = p.device if p is not None else torch.device("cpu") + layer._seq_calib.output_meta = self._remap_meta_device( + layer_output_metas[i], device + ) bind_forward_method(layer, _patched_forward, "_original_forward") def _early_stop_forward(module_self, *args, **kwargs): @@ -238,6 +271,26 @@ def _unpatch_all_layers(self): self._cleanup_layers() self._patched = False + def _set_layer_mode( + self, layer_idx: int, mode: Literal["original", "skip", "run", "capture"] + ) -> None: + """Set the mode for a single decoder layer with appropriate side effects.""" + assert self._decoder_layers is not None + state = self._decoder_layers[layer_idx]._seq_calib + state.mode = mode + + if mode == "skip": + state.cached_inputs.clear() + elif mode == "run": + if not state.collected_inputs: + raise RuntimeError( + f"Layer {layer_idx} ({state.name!r}) has no collected inputs to replay." + ) + state.cached_inputs = deque(state.collected_inputs) + state.collected_inputs = [] + elif mode == "capture": + state.collected_inputs = [] + def _set_layer_states(self, layer_idx: int): """Transition layer modes for the next calibration step. @@ -247,30 +300,11 @@ def _set_layer_states(self, layer_idx: int): * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). * Layer ``i`` → **capture** (record inputs, then early-stop). 
""" - assert self._decoder_layers is not None - if layer_idx > 1: - done = self._decoder_layers[layer_idx - 2]._seq_calib - # output_meta is intentionally kept: skip mode needs it to produce - # correctly shaped zero-filled outputs for the parent forward. - done.mode = "skip" - done.cached_inputs.clear() - + self._set_layer_mode(layer_idx - 2, "skip") if layer_idx > 0: - prev = self._decoder_layers[layer_idx - 1]._seq_calib - if not prev.collected_inputs: - raise RuntimeError( - f"Layer {layer_idx - 1} ({prev.name!r}) has no collected inputs to replay. " - "Layers must be calibrated sequentially — ensure get_input_activations() " - "was called for every preceding layer in order." - ) - prev.mode = "run" - prev.cached_inputs = deque(prev.collected_inputs) - prev.collected_inputs = [] - - cur = self._decoder_layers[layer_idx]._seq_calib - cur.mode = "capture" - cur.collected_inputs = [] + self._set_layer_mode(layer_idx - 1, "run") + self._set_layer_mode(layer_idx, "capture") def _log_layer_summary(self, layer_idx: int): """Log a one-line summary of layer modes for the current calibration step.""" @@ -284,10 +318,57 @@ def _log_layer_summary(self, layer_idx: int): parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups] print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}") + def _validate_skip_metas(self, indices: range) -> None: + """Raise if any layer in *indices* is missing ``output_meta`` for skip mode.""" + assert self._decoder_layers is not None + for i in indices: + if self._decoder_layers[i]._seq_calib.output_meta is None: + raise RuntimeError( + f"Layer {i} has no output_meta but must be in skip mode for resume. " + "The checkpoint may be corrupted or missing layer_output_metas." + ) + + def _run_warmup_capture(self, capture_layer_idx: int, forward_loop: ForwardLoop) -> None: + """Run a forward pass with *capture_layer_idx* in capture mode. + + Raises RuntimeError if no inputs are collected. 
+ """ + assert self._decoder_layers is not None + state = self._decoder_layers[capture_layer_idx]._seq_calib + state.mode = "capture" + state.collected_inputs = [] + + try: + forward_loop(self.model) + except Exception: + state.mode = "original" + state.collected_inputs = [] + raise + + if not state.collected_inputs: + state.mode = "original" + raise RuntimeError( + f"Warm-up forward collected no inputs for layer {capture_layer_idx}. " + "Cannot resume sequential calibration." + ) + # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ + def get_layer_output_metas(self, up_to_layer_idx: int) -> dict[int, tuple]: + """Return ``{layer_idx: output_meta}`` for layers ``0 .. up_to_layer_idx`` (inclusive). + + Only layers that have a non-*None* ``output_meta`` are included. + """ + assert self._decoder_layers is not None + metas: dict[int, tuple] = {} + for i in range(min(up_to_layer_idx + 1, len(self._decoder_layers))): + state = getattr(self._decoder_layers[i], self._LAYER_ATTR, None) + if state is not None and state.output_meta is not None: + metas[i] = state.output_meta + return metas + @torch.no_grad() def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: """Collect input activations for *layer* by running a full model forward. @@ -333,3 +414,45 @@ def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoo # in subsequent iterations via _set_layer_states. info.mode = "original" return inputs + + @torch.no_grad() + def prepare_for_resume( + self, + resume_layer_idx: int, + forward_loop: ForwardLoop, + ): + """Set up layer states for resuming sequential calibration from a checkpoint. + + Runs a single warm-up forward pass so that the next call to + :meth:`get_input_activations` for ``resume_layer_idx`` produces the + correct inputs. Layers ``0 .. 
K-2`` run in *original* mode (real + computation), layer ``K-1`` in *capture* mode. After the pass, + ``0 .. K-2`` switch to *skip* and ``K-1`` retains its + ``collected_inputs`` for the subsequent *run* transition. + + Output metas are restored from the checkpoint during ``_patch_all_layers``. + """ + if not self._patched: + raise RuntimeError( + "prepare_for_resume() requires _patch_all_layers() to be called first." + ) + if resume_layer_idx == 0: + return + + k = resume_layer_idx + preceding = range(k - 1) + + for i in preceding: + self._set_layer_mode(i, "original") + + print_rank_0( + f"Running warm-up forward pass for resume " + f"(layers 0..{k - 2} original, layer {k - 1} capture)" + ) + self._run_warmup_capture(k - 1, forward_loop) + + for i in preceding: + self._set_layer_mode(i, "skip") + self._validate_skip_metas(preceding) + + print_rank_0(f"Warm-up complete. Ready to resume from layer {k}.") diff --git a/modelopt/torch/quantization/utils/checkpoint.py b/modelopt/torch/quantization/utils/checkpoint.py new file mode 100644 index 0000000000..4c62907b51 --- /dev/null +++ b/modelopt/torch/quantization/utils/checkpoint.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint save/resume utilities for sequential calibration. + +Provides: + +* A pluggable **save registry** — plugins (e.g. 
huggingface.py) register a + ``(predicate, save_fn)`` pair at import time so that + :func:`get_checkpoint_saver` can find the right saver for any model. + +* **Resume detection** — :func:`detect_sequential_resume_layer` reads progress + metadata previously attached to the model and returns the layer index to + resume from. + +* **Checkpoint saving** — :func:`save_sequential_checkpoint` collects layer + output metadata, attaches progress to the model, and delegates to the + registered saver. +""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +from modelopt.torch.utils import print_rank_0 + +if TYPE_CHECKING: + from collections.abc import Callable + + import torch.nn as nn + +#: Model attribute name used to store sequential calibration progress. +SEQ_CALIB_PROGRESS_ATTR = "_seq_calib_progress" + +# --------------------------------------------------------------------------- +# Save registry +# --------------------------------------------------------------------------- +_CHECKPOINT_SAVE_SUPPORT: list[ + tuple[Callable[[nn.Module], bool], Callable[[nn.Module, str], None]] +] = [] + + +def register_seq_calib_checkpoint_saver( + is_supported: Callable[[nn.Module], bool], + save_fn: Callable[[nn.Module, str], None], +) -> None: + """Register a ``(predicate, saver)`` pair for sequential calibration checkpointing.""" + entry = (is_supported, save_fn) + if entry not in _CHECKPOINT_SAVE_SUPPORT: + _CHECKPOINT_SAVE_SUPPORT.append(entry) + + +def get_checkpoint_saver( + model: nn.Module, +) -> Callable[[nn.Module, str], None] | None: + """Return the registered save function for *model*, or *None*.""" + for is_supported, save_fn in _CHECKPOINT_SAVE_SUPPORT: + if is_supported(model): + return save_fn + return None + + +def detect_sequential_resume_layer(model: nn.Module, num_layers: int) -> tuple[int, dict | None]: + """Read checkpoint progress from the model and return ``(resume_layer_idx, layer_output_metas)``. 
+ + Returns ``(0, None)`` for a fresh run with no checkpoint present. + The attribute is **not** deleted here — cleanup is owned by + :func:`sequential_calibrate`'s ``finally`` block. + """ + progress = getattr(model, SEQ_CALIB_PROGRESS_ATTR, None) + if progress is None: + return 0, None + + if not isinstance(progress, dict): + raise ValueError( + f"Expected seq_calib_progress to be a dict, got {type(progress).__name__}." + ) + for key in ("completed_layer_idx", "total_layers"): + if key not in progress: + raise ValueError(f"Checkpoint progress is missing required key {key!r}.") + + completed_layer = progress["completed_layer_idx"] + saved_total = progress["total_layers"] + + if not isinstance(completed_layer, int) or not isinstance(saved_total, int): + raise ValueError( + f"Checkpoint progress values must be ints, got " + f"completed_layer_idx={completed_layer!r}, total_layers={saved_total!r}." + ) + + if saved_total != num_layers: + raise ValueError( + f"Checkpoint was saved with {saved_total} layers but model has " + f"{num_layers} layers. Cannot resume." + ) + + if not (0 <= completed_layer < num_layers): + raise ValueError( + f"completed_layer_idx={completed_layer} is out of range for " + f"{num_layers} layers (expected 0..{num_layers - 1})." + ) + + resume_from = completed_layer + 1 + print_rank_0( + f"Resuming sequential calibration from layer {resume_from} " + f"(layers 0..{completed_layer} already calibrated)" + ) + return resume_from, progress.get("layer_output_metas", {}) + + +def should_save_seq_calib_checkpoint( + layer_idx: int, num_layers: int, checkpoint_dir: str | None, checkpoint_interval: int | None +) -> bool: + """Return *True* when a checkpoint should be saved after calibrating *layer_idx*.""" + if checkpoint_interval is not None and checkpoint_interval <= 0: + raise ValueError( + f"checkpoint_interval must be a positive integer, got {checkpoint_interval}." 
+ ) + return ( + checkpoint_dir is not None + and checkpoint_interval is not None + and (layer_idx + 1) % checkpoint_interval == 0 + and layer_idx < num_layers - 1 # never save after the final layer + ) + + +def save_sequential_checkpoint( + model: nn.Module, + completed_layer_idx: int, + total_layers: int, + checkpoint_dir: str, + layer_output_metas: dict, +) -> None: + """Save a rolling checkpoint during sequential calibration. + + Temporarily attaches progress to the model so that ``update_quantize_metadata`` + can serialize it during ``save_pretrained``. The attribute is **not** deleted + here — cleanup is owned by :func:`sequential_calibrate`'s ``finally`` block. + """ + saver = get_checkpoint_saver(model) + if saver is None: + print_rank_0( + "Warning: checkpoint_dir is set but no checkpoint saver is registered " + "for this model type. Skipping checkpoint save." + ) + return + + model._seq_calib_progress = { + "completed_layer_idx": completed_layer_idx, + "total_layers": total_layers, + "layer_output_metas": layer_output_metas, + } + + os.makedirs(checkpoint_dir, exist_ok=True) + saver(model, checkpoint_dir) + print_rank_0( + f"Saved sequential calibration checkpoint at layer " + f"{completed_layer_idx + 1}/{total_layers} to {checkpoint_dir}" + ) diff --git a/tests/unit/torch/quantization/test_sequential_calibrate.py b/tests/unit/torch/quantization/test_sequential_calibrate.py index 14c1903de2..126a8e9f3a 100644 --- a/tests/unit/torch/quantization/test_sequential_calibrate.py +++ b/tests/unit/torch/quantization/test_sequential_calibrate.py @@ -15,14 +15,25 @@ """Unit tests for sequential_calibrate and LayerActivationCollector.""" +import io +import warnings from collections import deque +from unittest.mock import MagicMock import pytest import torch import torch.nn as nn +from modelopt.torch.quantization.config import MaxCalibConfig +from modelopt.torch.quantization.mode import wrapped_calib_func from modelopt.torch.quantization.model_calib import 
sequential_calibrate from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.utils.checkpoint import ( + _CHECKPOINT_SAVE_SUPPORT, + SEQ_CALIB_PROGRESS_ATTR, + get_checkpoint_saver, + register_seq_calib_checkpoint_saver, +) class _DecoderBlock(nn.Module): @@ -533,3 +544,537 @@ def forward_loop(m): assert len(meta_1[1]) == 1 finally: collector._unpatch_all_layers() + + +# --------------------------------------------------------------------------- +# Checkpoint save / resume tests +# --------------------------------------------------------------------------- + + +@pytest.fixture +def _clean_checkpoint_registry(): + """Isolate the checkpoint saver registry for a single test.""" + old = _CHECKPOINT_SAVE_SUPPORT.copy() + _CHECKPOINT_SAVE_SUPPORT.clear() + yield + _CHECKPOINT_SAVE_SUPPORT.clear() + _CHECKPOINT_SAVE_SUPPORT.extend(old) + + +@pytest.fixture +def _register_discoverer(monkeypatch, _clean_checkpoint_registry): + """Register a simple discoverer and clear checkpoint saver registry.""" + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + + +def _make_forward_loop(tokens): + def forward_loop(m): + for t in tokens: + m(t) + + return forward_loop + + +def _noop_calib(layer, forward_loop, **kwargs): + """No-op calibration: just run the forward loop.""" + forward_loop(layer) + + +@pytest.mark.usefixtures("_clean_checkpoint_registry") +class TestCheckpointSaveRegistry: + def test_register_and_get_saver(self): + saver = MagicMock() + register_seq_calib_checkpoint_saver(lambda m: True, saver) + model = nn.Linear(4, 4) + assert get_checkpoint_saver(model) is saver + + def test_get_saver_returns_none_when_empty(self): + assert get_checkpoint_saver(nn.Linear(4, 4)) is None + + def test_dedup_registration(self): + pred = lambda m: True # noqa: E731 + saver = lambda m, d: None # noqa: E731 + 
register_seq_calib_checkpoint_saver(pred, saver) + register_seq_calib_checkpoint_saver(pred, saver) + assert len(_CHECKPOINT_SAVE_SUPPORT) == 1 + + +@pytest.mark.usefixtures("_register_discoverer") +class TestCheckpointSave: + def test_no_save_when_dir_is_none(self): + """Default behavior: no checkpoint saving when checkpoint_dir is None.""" + model, tokens = _make_model_and_data(n_layers=4) + calibrated_layers = [] + + def tracking_calib(layer, fwd, **kwargs): + fwd(layer) + calibrated_layers.append(layer) + + sequential_calibrate( + model, + forward_loop=_make_forward_loop(tokens), + calib_func=tracking_calib, + checkpoint_dir=None, + checkpoint_interval=2, + ) + assert len(calibrated_layers) == 4 + assert not hasattr(model, SEQ_CALIB_PROGRESS_ATTR) + + def test_checkpoint_triggers_saver_at_interval(self): + """Saver should be called at the correct layer intervals.""" + model, tokens = _make_model_and_data(n_layers=5) + save_calls = [] + + def mock_saver(m, d): + progress = getattr(m, SEQ_CALIB_PROGRESS_ATTR) + save_calls.append(progress["completed_layer_idx"]) + + register_seq_calib_checkpoint_saver(lambda m: True, mock_saver) + + sequential_calibrate( + model, + forward_loop=_make_forward_loop(tokens), + calib_func=_noop_calib, + checkpoint_dir="/tmp/test_ckpt", + checkpoint_interval=2, + ) + # interval=2: save after layers 1 (idx 1), 3 (idx 3), skip last (idx 4) + assert save_calls == [1, 3] + + def test_checkpoint_skips_final_layer(self): + """No checkpoint is saved after the final layer.""" + model, tokens = _make_model_and_data(n_layers=3) + save_calls = [] + + def mock_saver(m, d): + progress = getattr(m, SEQ_CALIB_PROGRESS_ATTR) + save_calls.append(progress["completed_layer_idx"]) + + register_seq_calib_checkpoint_saver(lambda m: True, mock_saver) + + # interval=1 would save every layer, but should still skip the last + sequential_calibrate( + model, + forward_loop=_make_forward_loop(tokens), + calib_func=_noop_calib, + checkpoint_dir="/tmp/test_ckpt", + 
# NOTE(review): this chunk is the added-lines body of a unified diff that was
# collapsed onto a few physical lines; reconstructed here with conventional
# formatting. All code tokens and runtime strings are preserved verbatim.
            checkpoint_interval=1,
        )
        assert 2 not in save_calls  # layer index 2 (last) should not trigger save

    def test_checkpoint_warns_when_no_saver_registered(self, capsys):
        """Should print a warning when checkpoint_dir is set but no saver is registered."""
        # NOTE(review): `capsys` is requested but never read in this test body —
        # presumably the warning text was meant to be asserted via
        # capsys.readouterr(); confirm against the implementation's warning path.
        model, tokens = _make_model_and_data(n_layers=3)

        sequential_calibrate(
            model,
            forward_loop=_make_forward_loop(tokens),
            calib_func=_noop_calib,
            checkpoint_dir="/tmp/test_ckpt",
            checkpoint_interval=1,
        )
        # No error raised — just a warning printed
        assert not hasattr(model, SEQ_CALIB_PROGRESS_ATTR)

    def test_progress_attr_cleaned_up_after_save(self):
        """The _seq_calib_progress attribute should be cleaned up after save."""
        model, tokens = _make_model_and_data(n_layers=4)

        def mock_saver(m, d):
            # During save, the attribute should be set
            assert hasattr(m, SEQ_CALIB_PROGRESS_ATTR)

        register_seq_calib_checkpoint_saver(lambda m: True, mock_saver)

        sequential_calibrate(
            model,
            forward_loop=_make_forward_loop(tokens),
            calib_func=_noop_calib,
            checkpoint_dir="/tmp/test_ckpt",
            checkpoint_interval=2,
        )
        # After completion, attribute should be gone
        assert not hasattr(model, SEQ_CALIB_PROGRESS_ATTR)

    def test_checkpoint_progress_contains_output_metas(self):
        """Saved progress should include layer_output_metas."""
        model, tokens = _make_model_and_data(n_layers=4)
        saved_progress = {}

        def mock_saver(m, d):
            # Snapshot the progress dict that the saver observes on the model.
            progress = getattr(m, SEQ_CALIB_PROGRESS_ATTR)
            saved_progress.update(progress)

        register_seq_calib_checkpoint_saver(lambda m: True, mock_saver)

        sequential_calibrate(
            model,
            forward_loop=_make_forward_loop(tokens),
            calib_func=_noop_calib,
            checkpoint_dir="/tmp/test_ckpt",
            checkpoint_interval=2,
        )
        # The saved progress payload must carry everything resume needs.
        assert "layer_output_metas" in saved_progress
        assert "completed_layer_idx" in saved_progress
        assert "total_layers" in saved_progress


@pytest.mark.usefixtures("_register_discoverer")
class TestPrepareForResume:
    """Tests for LayerActivationCollector.prepare_for_resume: fast-forwarding
    the collector state so calibration can restart mid-model."""

    def test_basic_resume_from_layer_2(self):
        """After prepare_for_resume(2), layers 0 should be skip, layer 1 should
        have collected_inputs, and get_input_activations(layer_2) should work."""
        model, tokens = _make_model_and_data(n_layers=4)
        fwd = _make_forward_loop(tokens)

        # First, do a full run to get output_metas
        collector = LayerActivationCollector(model)
        collector._patch_all_layers()
        try:
            for layer in model.layers:
                collector.get_input_activations(layer, fwd)
            output_metas = {
                i: model.layers[i]._seq_calib.output_meta
                for i in range(len(model.layers))
                if model.layers[i]._seq_calib.output_meta is not None
            }
        finally:
            collector._unpatch_all_layers()

        # Now simulate resume from layer 2
        collector2 = LayerActivationCollector(model)
        collector2._patch_all_layers(layer_output_metas=output_metas)
        try:
            collector2.prepare_for_resume(2, fwd)

            # Layer 0 should be in skip mode
            assert model.layers[0]._seq_calib.mode == "skip"
            assert model.layers[0]._seq_calib.output_meta is not None

            # Layer 1 should have collected_inputs from warm-up
            assert len(model.layers[1]._seq_calib.collected_inputs) > 0

            # get_input_activations for layer 2 should work
            inputs = collector2.get_input_activations(model.layers[2], fwd)
            assert len(inputs) == len(tokens)
        finally:
            collector2._unpatch_all_layers()

    def test_resume_from_layer_1(self):
        """Edge case: resume from layer 1 (only layer 0 was calibrated)."""
        model, tokens = _make_model_and_data(n_layers=3)
        fwd = _make_forward_loop(tokens)

        # Get output_metas from a full run
        collector = LayerActivationCollector(model)
        collector._patch_all_layers()
        try:
            for layer in model.layers:
                collector.get_input_activations(layer, fwd)
            output_metas = {
                i: model.layers[i]._seq_calib.output_meta
                for i in range(len(model.layers))
                if model.layers[i]._seq_calib.output_meta is not None
            }
        finally:
            collector._unpatch_all_layers()

        # Resume from layer 1
        collector2 = LayerActivationCollector(model)
        collector2._patch_all_layers(layer_output_metas=output_metas)
        try:
            collector2.prepare_for_resume(1, fwd)

            # Layer 0 should have collected_inputs from warm-up capture
            assert len(model.layers[0]._seq_calib.collected_inputs) > 0

            # get_input_activations for layer 1 should work
            inputs = collector2.get_input_activations(model.layers[1], fwd)
            assert len(inputs) == len(tokens)
        finally:
            collector2._unpatch_all_layers()

    def test_resume_from_layer_0_is_noop(self):
        """prepare_for_resume(0) should be a no-op."""
        model, tokens = _make_model_and_data(n_layers=3)
        fwd = _make_forward_loop(tokens)

        collector = LayerActivationCollector(model)
        collector._patch_all_layers()
        try:
            collector.prepare_for_resume(0, fwd)
            # All layers should still be in original mode
            for layer in model.layers:
                assert layer._seq_calib.mode == "original"
        finally:
            collector._unpatch_all_layers()

    def test_resume_requires_patched_state(self):
        """prepare_for_resume should raise if layers aren't patched."""
        model, tokens = _make_model_and_data(n_layers=3)
        fwd = _make_forward_loop(tokens)
        collector = LayerActivationCollector(model)
        # Deliberately no _patch_all_layers() call before resuming.
        with pytest.raises(RuntimeError, match="requires _patch_all_layers"):
            collector.prepare_for_resume(1, fwd)

    def test_resume_missing_output_meta_raises(self):
        """If a layer needs skip mode but has no output_meta, should raise."""
        model, tokens = _make_model_and_data(n_layers=4)
        fwd = _make_forward_loop(tokens)

        collector = LayerActivationCollector(model)
        collector._patch_all_layers()
        try:
            # Try to resume from layer 3 with no saved output_metas
            with pytest.raises(RuntimeError, match="no output_meta"):
                collector.prepare_for_resume(3, fwd)
        finally:
            collector._unpatch_all_layers()


@pytest.mark.usefixtures("_register_discoverer")
class TestResumeDetection:
    """Tests that sequential_calibrate detects SEQ_CALIB_PROGRESS_ATTR on the
    model and resumes (or rejects) accordingly."""

    def test_resume_starts_from_correct_layer(self):
        """sequential_calibrate should skip already-calibrated layers on resume."""
        model, tokens = _make_model_and_data(n_layers=4)
        calibrated_layers = []

        def tracking_calib(layer, fwd, **kwargs):
            fwd(layer)
            calibrated_layers.append(id(layer))

        # First run to get output_metas
        full_run_collector = LayerActivationCollector(model)
        full_run_collector._patch_all_layers()
        fwd = _make_forward_loop(tokens)
        try:
            for layer in model.layers:
                full_run_collector.get_input_activations(layer, fwd)
            output_metas = {
                i: model.layers[i]._seq_calib.output_meta
                for i in range(len(model.layers))
                if model.layers[i]._seq_calib.output_meta is not None
            }
        finally:
            full_run_collector._unpatch_all_layers()

        # Set up resume from layer 2
        setattr(
            model,
            SEQ_CALIB_PROGRESS_ATTR,
            {
                "completed_layer_idx": 1,
                "total_layers": 4,
                "layer_output_metas": output_metas,
            },
        )

        sequential_calibrate(
            model,
            forward_loop=_make_forward_loop(tokens),
            calib_func=tracking_calib,
        )

        # Should only calibrate layers 2 and 3 (not 0 and 1)
        assert len(calibrated_layers) == 2
        assert calibrated_layers[0] == id(model.layers[2])
        assert calibrated_layers[1] == id(model.layers[3])

    def test_resume_mismatched_layer_count_raises(self):
        """Should raise ValueError when checkpoint layer count doesn't match."""
        model, tokens = _make_model_and_data(n_layers=3)

        setattr(
            model,
            SEQ_CALIB_PROGRESS_ATTR,
            {
                "completed_layer_idx": 1,
                "total_layers": 10,  # Mismatch!
                "layer_output_metas": {},
            },
        )

        with pytest.raises(ValueError, match="10 layers but model has 3"):
            sequential_calibrate(
                model,
                forward_loop=_make_forward_loop(tokens),
                calib_func=_noop_calib,
            )

    def test_progress_attr_cleaned_up_after_resume(self):
        """_seq_calib_progress should be deleted after calibration completes."""
        model, tokens = _make_model_and_data(n_layers=3)

        # Get output_metas
        collector = LayerActivationCollector(model)
        collector._patch_all_layers()
        fwd = _make_forward_loop(tokens)
        try:
            for layer in model.layers:
                collector.get_input_activations(layer, fwd)
            output_metas = {
                i: model.layers[i]._seq_calib.output_meta
                for i in range(len(model.layers))
                if model.layers[i]._seq_calib.output_meta is not None
            }
        finally:
            collector._unpatch_all_layers()

        setattr(
            model,
            SEQ_CALIB_PROGRESS_ATTR,
            {
                "completed_layer_idx": 0,
                "total_layers": 3,
                "layer_output_metas": output_metas,
            },
        )

        sequential_calibrate(
            model,
            forward_loop=_make_forward_loop(tokens),
            calib_func=_noop_calib,
        )
        assert not hasattr(model, SEQ_CALIB_PROGRESS_ATTR)


@pytest.mark.usefixtures("_register_discoverer")
class TestMetadataIntegration:
    """Integration of seq-calib progress with modelopt metadata and torch
    serialization."""

    def test_update_quantize_metadata_includes_progress(self):
        """update_quantize_metadata should pick up _seq_calib_progress."""
        from modelopt.torch.quantization.config import QuantizeConfig
        from modelopt.torch.quantization.conversion import update_quantize_metadata

        model = nn.Linear(4, 4)
        progress = {"completed_layer_idx": 5, "total_layers": 10}
        setattr(model, SEQ_CALIB_PROGRESS_ATTR, progress)

        metadata = {}
        update_quantize_metadata(model, QuantizeConfig(), metadata)

        assert metadata["seq_calib_progress"] == progress
        # Clean up so the attribute does not leak into other tests.
        delattr(model, SEQ_CALIB_PROGRESS_ATTR)

    def test_output_meta_serialization_roundtrip(self):
        """layer_output_metas should survive torch.save/load roundtrip."""
        output_metas = {
            0: ("tensor", torch.Size([2, 16]), torch.float32, torch.device("cpu")),
            1: (
                "tuple",
                (
                    ("tensor", torch.Size([2, 16]), torch.float32, torch.device("cpu")),
                    ("other", None),
                ),
            ),
        }
        progress = {
            "completed_layer_idx": 1,
            "total_layers": 4,
            "layer_output_metas": output_metas,
        }

        # Round-trip through an in-memory buffer; weights_only=False because the
        # metas contain non-tensor objects (torch.Size/dtype/device tuples).
        buf = io.BytesIO()
        torch.save(progress, buf)
        buf.seek(0)
        loaded = torch.load(buf, weights_only=False)

        assert loaded["layer_output_metas"][0] == output_metas[0]
        assert loaded["layer_output_metas"][1] == output_metas[1]
        assert loaded["completed_layer_idx"] == 1


@pytest.mark.usefixtures("_register_discoverer")
class TestFullCheckpointResumeRoundtrip:
    """End-to-end: compare a resumed calibration against an uninterrupted run."""

    def test_resume_produces_same_calibration_order(self):
        """Full roundtrip: run with checkpoint, simulate resume, verify all layers calibrated."""
        torch.manual_seed(42)
        n_layers = 5
        model, tokens = _make_model_and_data(n_layers=n_layers)

        # Track calibration order in a full uninterrupted run
        full_run_layers = []

        def tracking_calib_full(layer, fwd, **kwargs):
            fwd(layer)
            full_run_layers.append(id(layer))

        sequential_calibrate(
            model, forward_loop=_make_forward_loop(tokens), calib_func=tracking_calib_full
        )
        assert len(full_run_layers) == n_layers

        # Now simulate a checkpointed run that "resumes" from layer 3
        # Re-seed so model2/tokens2 match the first model's initialization.
        torch.manual_seed(42)
        model2, tokens2 = _make_model_and_data(n_layers=n_layers)

        # First, partially calibrate layers 0-2 to get output_metas
        partial_collector = LayerActivationCollector(model2)
        partial_collector._patch_all_layers()
        fwd = _make_forward_loop(tokens2)
        try:
            for i in range(3):
                inputs = partial_collector.get_input_activations(model2.layers[i], fwd)

                # _inputs=inputs binds the current list as a default, avoiding
                # the late-binding-closure pitfall inside this loop.
                def _fwd(m, _inputs=inputs):
                    for args, kw in _inputs:
                        m(*args, **kw)

                _noop_calib(model2.layers[i], _fwd)
                del inputs

            output_metas = {
                i: model2.layers[i]._seq_calib.output_meta
                for i in range(len(model2.layers))
                if model2.layers[i]._seq_calib.output_meta is not None
            }
        finally:
            partial_collector._unpatch_all_layers()

        # Set progress for resume
        setattr(
            model2,
            SEQ_CALIB_PROGRESS_ATTR,
            {
                "completed_layer_idx": 2,
                "total_layers": n_layers,
                "layer_output_metas": output_metas,
            },
        )

        resumed_layers = []

        def tracking_calib_resume(layer, fwd, **kwargs):
            fwd(layer)
            resumed_layers.append(id(layer))

        sequential_calibrate(
            model2,
            forward_loop=_make_forward_loop(tokens2),
            calib_func=tracking_calib_resume,
        )

        # Should only calibrate layers 3 and 4
        assert len(resumed_layers) == 2
        assert resumed_layers[0] == id(model2.layers[3])
        assert resumed_layers[1] == id(model2.layers[4])


class TestWrappedCalibFuncWarning:
    """Warning behavior of wrapped_calib_func when checkpoint options are set
    without sequential calibration."""

    def test_checkpoint_config_without_sequential_warns(self):
        """Setting checkpoint config without use_sequential=True should warn."""
        config = MaxCalibConfig(
            sequential_checkpoint_dir="/tmp/test",
            sequential_checkpoint_interval=2,
            use_sequential=False,
        )
        model = nn.Linear(4, 4)

        # record=True captures the warning instead of letting it propagate;
        # "always" defeats the default once-per-location warning filter.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            wrapped_calib_func(model, config, forward_loop=None, func=None)
        assert len(w) == 1
        assert "use_sequential is False" in str(w[0].message)