Skip to content

Commit a2638a7

Browse files
committed
Fix: Broken model_dump()
1 parent 655a7e9 commit a2638a7

File tree

2 files changed

+216
-130
lines changed

2 files changed

+216
-130
lines changed

src/llmcompressor/recipe/recipe.py

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ class Recipe(RecipeBase):
3535
when serializing a recipe, yaml will be used by default.
3636
"""
3737

38+
version: Optional[str] = Field(default=None)
39+
args: RecipeArgs = Field(default_factory=RecipeArgs)
40+
stages: List[RecipeStage] = Field(default_factory=list)
41+
metadata: Optional[RecipeMetaData] = Field(default=None)
42+
args_evaluated: RecipeArgs = Field(default_factory=RecipeArgs)
43+
3844
@classmethod
3945
def from_modifiers(
4046
cls,
@@ -280,12 +286,6 @@ def simplify_combine_recipes(
280286

281287
return combined
282288

283-
version: Optional[str] = None
284-
args: RecipeArgs = Field(default_factory=RecipeArgs)
285-
stages: List[RecipeStage] = Field(default_factory=list)
286-
metadata: Optional[RecipeMetaData] = None
287-
args_evaluated: RecipeArgs = Field(default_factory=RecipeArgs)
288-
289289
def calculate_start(self) -> int:
290290
"""
291291
Calculate and return the start epoch of the recipe.
@@ -507,52 +507,60 @@ def combine_metadata(self, metadata: Optional[RecipeMetaData]):
507507

508508
def model_dump(self, *args, **kwargs) -> Dict[str, Any]:
509509
"""
510-
:return: A dictionary representation of the recipe
511-
"""
512-
dict_ = super().model_dump(*args, **kwargs)
513-
stages = {}
514-
515-
for stage in dict_["stages"]:
516-
name = f"{stage['group']}_stage"
517-
del stage["group"]
510+
Generate a serializable dictionary representation of this recipe.
518511
519-
if name not in stages:
520-
stages[name] = []
512+
This method transforms the internal recipe structure into a format
513+
suitable for YAML serialization while preserving all necessary
514+
information for round-trip deserialization.
521515
522-
stages[name].append(stage)
523-
524-
dict_["stages"] = stages
516+
:param args: Additional positional arguments for parent method
517+
:param kwargs: Additional keyword arguments for parent method
518+
:return: Dictionary ready for YAML serialization
519+
"""
520+
# Retrieve base representation from parent class
521+
raw_dict = super().model_dump(*args, **kwargs)
522+
523+
# Initialize clean output dictionary
524+
serializable_dict = {}
525+
526+
# Copy recipe metadata attributes
527+
metadata_keys = ["version", "args", "metadata"]
528+
for key in metadata_keys:
529+
if value := raw_dict.get(key):
530+
serializable_dict[key] = value
531+
532+
# Process and organize stages by group
533+
if "stages" in raw_dict:
534+
# Group stages by their type (e.g., "train", "eval")
535+
grouped_stages = {}
536+
for stage in raw_dict["stages"]:
537+
group_id = (
538+
f"{stage.pop('group')}_stage" # Remove group field and use as key
539+
)
525540

526-
yaml_recipe_dict = {}
541+
if group_id not in grouped_stages:
542+
grouped_stages[group_id] = []
527543

528-
# populate recipe level attributes
529-
recipe_level_attributes = ["version", "args", "metadata"]
544+
grouped_stages[group_id].append(stage)
530545

531-
for attribute in recipe_level_attributes:
532-
if attribute_value := dict_.get(attribute):
533-
yaml_recipe_dict[attribute] = attribute_value
546+
# Format each stage for YAML output
547+
for group_id, stages in grouped_stages.items():
548+
for idx, stage_data in enumerate(stages):
549+
# Create unique identifiers for multiple stages of same type
550+
final_id = f"{group_id}_{idx}" if len(stages) > 1 else group_id
534551

535-
# populate stages
536-
stages = dict_.pop("stages", {})
537-
for stage_name, stage_list in stages.items():
538-
for idx, stage in enumerate(stage_list):
539-
if len(stage_list) > 1:
540-
# resolve name clashes caused by combining recipes with
541-
# duplicate stage names
542-
final_stage_name = f"{stage_name}_{idx}"
543-
else:
544-
final_stage_name = stage_name
545-
stage_dict = get_yaml_serializable_stage_dict(
546-
modifiers=stage["modifiers"]
547-
)
552+
# Create clean stage representation
553+
stage_yaml = get_yaml_serializable_stage_dict(
554+
modifiers=stage_data["modifiers"]
555+
)
548556

549-
# infer run_type from stage
550-
if run_type := stage.get("run_type"):
551-
stage_dict["run_type"] = run_type
557+
# Preserve run type if specified
558+
if run_type := stage_data.get("run_type"):
559+
stage_yaml["run_type"] = run_type
552560

553-
yaml_recipe_dict[final_stage_name] = stage_dict
561+
serializable_dict[final_id] = stage_yaml
554562

555-
return yaml_recipe_dict
563+
return serializable_dict
556564

557565
def yaml(self, file_path: Optional[str] = None) -> str:
558566
"""
Lines changed: 165 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,185 @@
1+
import os
12
import tempfile
23

34
import pytest
45
import yaml
56

67
from llmcompressor.modifiers.obcq.base import SparseGPTModifier
8+
from llmcompressor.modifiers.pruning.constant import ConstantPruningModifier
79
from llmcompressor.recipe import Recipe
810
from tests.llmcompressor.helpers import valid_recipe_strings
911

1012

1113
@pytest.mark.parametrize("recipe_str", valid_recipe_strings())
12-
def test_recipe_create_instance_accepts_valid_recipe_string(recipe_str):
13-
recipe = Recipe.create_instance(recipe_str)
14-
assert recipe is not None, "Recipe could not be created from string"
14+
class TestRecipeWithStrings:
15+
"""Tests that use various recipe strings for validation."""
16+
17+
def test_create_from_string(self, recipe_str):
18+
"""Test creating a Recipe from a YAML string."""
19+
recipe = Recipe.create_instance(recipe_str)
20+
assert recipe is not None, "Recipe could not be created from string"
21+
assert isinstance(recipe, Recipe), "Created object is not a Recipe instance"
22+
23+
def test_create_from_file(self, recipe_str):
24+
"""Test creating a Recipe from a YAML file."""
25+
content = yaml.safe_load(recipe_str)
26+
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
27+
yaml.dump(content, f)
28+
f.flush() # Ensure content is written
29+
recipe = Recipe.create_instance(f.name)
30+
assert recipe is not None, "Recipe could not be created from file"
31+
assert isinstance(recipe, Recipe), "Created object is not a Recipe instance"
32+
33+
def test_yaml_serialization_roundtrip(self, recipe_str):
34+
"""
35+
Test that a recipe can be serialized to YAML
36+
and deserialized back with all properties preserved.
37+
"""
38+
# Create original recipe
39+
original_recipe = Recipe.create_instance(recipe_str)
1540

41+
# Serialize to YAML
42+
yaml_str = original_recipe.yaml()
43+
assert yaml_str, "Serialized YAML string should not be empty"
1644

17-
@pytest.mark.parametrize("recipe_str", valid_recipe_strings())
18-
def test_recipe_create_instance_accepts_valid_recipe_file(recipe_str):
19-
content = yaml.safe_load(recipe_str)
20-
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
21-
yaml.dump(content, f)
22-
recipe = Recipe.create_instance(f.name)
23-
assert recipe is not None, "Recipe could not be created from file"
45+
# Deserialize from YAML
46+
deserialized_recipe = Recipe.create_instance(yaml_str)
2447

48+
# Compare serialized forms
49+
original_dict = original_recipe.model_dump()
50+
deserialized_dict = deserialized_recipe.model_dump()
2551

26-
@pytest.mark.parametrize("recipe_str", valid_recipe_strings())
27-
def test_yaml_serialization(recipe_str):
28-
recipe_instance = Recipe.create_instance(recipe_str)
29-
serialized_recipe = recipe_instance.yaml()
30-
recipe_from_serialized = Recipe.create_instance(serialized_recipe)
52+
assert original_dict == deserialized_dict, "Serialization roundtrip failed"
3153

32-
expected_dict = recipe_instance.model_dump()
33-
actual_dict = recipe_from_serialized.model_dump()
54+
def test_model_dump_and_validate(self, recipe_str):
55+
"""
56+
Test that model_dump produces a format compatible
57+
with model_validate.
58+
"""
59+
recipe = Recipe.create_instance(recipe_str)
60+
validated_recipe = Recipe.model_validate(recipe.model_dump())
61+
assert (
62+
recipe == validated_recipe
63+
), "Recipe instance and validated recipe do not match"
3464

35-
assert expected_dict == actual_dict
3665

66+
class TestRecipeSerialization:
67+
"""
68+
Tests for Recipe serialization and deserialization
69+
edge cases."""
3770

38-
@pytest.mark.parametrize("recipe_str", valid_recipe_strings())
39-
def test_model_dump_and_validate(recipe_str):
40-
recipe_instance = Recipe.create_instance(recipe_str)
41-
validated_recipe = Recipe.model_validate(recipe_instance.model_dump())
42-
assert (
43-
recipe_instance == validated_recipe
44-
), "Recipe instance and validated recipe do not match"
45-
46-
47-
def test_recipe_creates_correct_modifier():
48-
start = 1
49-
end = 10
50-
targets = "__ALL_PRUNABLE__"
51-
52-
yaml_str = f"""
53-
test_stage:
54-
pruning_modifiers:
55-
ConstantPruningModifier:
56-
start: {start}
57-
end: {end}
58-
targets: {targets}
59-
"""
71+
def test_empty_recipe_serialization(self):
72+
"""Test serialization of a minimal recipe with no stages."""
73+
recipe = Recipe()
74+
assert len(recipe.stages) == 0, "New recipe should have no stages"
75+
76+
# Test roundtrip serialization
77+
dumped = recipe.model_dump()
78+
loaded = Recipe.model_validate(dumped)
79+
assert recipe == loaded, "Empty recipe serialization failed"
6080

61-
recipe_instance = Recipe.create_instance(yaml_str)
62-
63-
stage_modifiers = recipe_instance.create_modifier()
64-
assert len(stage_modifiers) == 1
65-
assert len(modifiers := stage_modifiers[0].modifiers) == 1
66-
from llmcompressor.modifiers.pruning.constant import ConstantPruningModifier
67-
68-
assert isinstance(modifier := modifiers[0], ConstantPruningModifier)
69-
assert modifier.start == start
70-
assert modifier.end == end
71-
72-
73-
def test_recipe_can_be_created_from_modifier_instances():
74-
modifier = SparseGPTModifier(
75-
sparsity=0.5,
76-
)
77-
group_name = "dummy"
78-
79-
# for pep8 compliance
80-
recipe_str = (
81-
f"{group_name}_stage:\n"
82-
" pruning_modifiers:\n"
83-
" SparseGPTModifier:\n"
84-
" sparsity: 0.5\n"
85-
)
86-
87-
expected_recipe_instance = Recipe.create_instance(recipe_str)
88-
expected_modifiers = expected_recipe_instance.create_modifier()
89-
90-
actual_recipe_instance = Recipe.create_instance(
91-
[modifier], modifier_group_name=group_name
92-
)
93-
actual_modifiers = actual_recipe_instance.create_modifier()
94-
95-
# assert num stages is the same
96-
assert len(actual_modifiers) == len(expected_modifiers)
97-
98-
# assert num modifiers in each stage is the same
99-
assert len(actual_modifiers[0].modifiers) == len(expected_modifiers[0].modifiers)
100-
101-
# assert modifiers in each stage are the same type
102-
# and have the same parameters
103-
for actual_modifier, expected_modifier in zip(
104-
actual_modifiers[0].modifiers, expected_modifiers[0].modifiers
105-
):
106-
assert isinstance(actual_modifier, type(expected_modifier))
107-
assert actual_modifier.model_dump() == expected_modifier.model_dump()
81+
def test_file_serialization(self):
82+
"""Test serializing a recipe to a file and reading it back."""
83+
recipe = Recipe.create_instance(valid_recipe_strings()[0])
84+
85+
with tempfile.TemporaryDirectory() as temp_dir:
86+
file_path = os.path.join(temp_dir, "recipe.yaml")
87+
88+
# Write to file
89+
recipe.yaml(file_path=file_path)
90+
assert os.path.exists(file_path), "YAML file was not created"
91+
assert os.path.getsize(file_path) > 0, "YAML file is empty"
92+
93+
# Read back from file
94+
loaded_recipe = Recipe.create_instance(file_path)
95+
assert (
96+
recipe == loaded_recipe
97+
), "Recipe loaded from file doesn't match original"
98+
99+
100+
class TestRecipeModifiers:
101+
"""Tests for creating and working with modifiers in recipes."""
102+
103+
def test_creates_correct_modifier(self):
104+
"""
105+
Test that a recipe creates the expected modifier type
106+
with correct parameters.
107+
"""
108+
# Recipe parameters
109+
params = {"start": 1, "end": 10, "targets": "__ALL_PRUNABLE__"}
110+
111+
# Create recipe from YAML
112+
yaml_str = f"""
113+
test_stage:
114+
pruning_modifiers:
115+
ConstantPruningModifier:
116+
start: {params['start']}
117+
end: {params['end']}
118+
targets: {params['targets']}
119+
"""
120+
recipe = Recipe.create_instance(yaml_str)
121+
122+
# Get modifiers from recipe
123+
stage_modifiers = recipe.create_modifier()
124+
assert len(stage_modifiers) == 1, "Expected exactly one stage modifier"
125+
126+
modifiers = stage_modifiers[0].modifiers
127+
assert len(modifiers) == 1, "Expected exactly one modifier in the stage"
128+
129+
# Verify modifier type and parameters
130+
modifier = modifiers[0]
131+
assert isinstance(
132+
modifier, ConstantPruningModifier
133+
), "Wrong modifier type created"
134+
assert modifier.start == params["start"], "Modifier start value incorrect"
135+
assert modifier.end == params["end"], "Modifier end value incorrect"
136+
assert modifier.targets == params["targets"], "Modifier targets incorrect"
137+
138+
def test_create_from_modifier_instances(self):
139+
"""Test creating a recipe from modifier instances."""
140+
# Create a modifier instance
141+
sparsity_value = 0.5
142+
modifier = SparseGPTModifier(sparsity=sparsity_value)
143+
group_name = "dummy"
144+
145+
# Expected YAML representation
146+
recipe_str = (
147+
f"{group_name}_stage:\n"
148+
" pruning_modifiers:\n"
149+
" SparseGPTModifier:\n"
150+
f" sparsity: {sparsity_value}\n"
151+
)
152+
153+
# Create recipes for comparison
154+
expected_recipe = Recipe.create_instance(recipe_str)
155+
actual_recipe = Recipe.create_instance(
156+
[modifier], modifier_group_name=group_name
157+
)
158+
159+
# Compare recipes by creating and checking their modifiers
160+
self._compare_recipe_modifiers(actual_recipe, expected_recipe)
161+
162+
def _compare_recipe_modifiers(self, actual_recipe, expected_recipe):
163+
"""Helper method to compare modifiers created from two recipes."""
164+
actual_modifiers = actual_recipe.create_modifier()
165+
expected_modifiers = expected_recipe.create_modifier()
166+
167+
# Compare stage counts
168+
assert len(actual_modifiers) == len(expected_modifiers), "Stage counts differ"
169+
170+
if not actual_modifiers:
171+
return # No modifiers to compare
172+
173+
# Compare modifier counts in each stage
174+
assert len(actual_modifiers[0].modifiers) == len(
175+
expected_modifiers[0].modifiers
176+
), "Modifier counts differ"
177+
178+
# Compare modifier types and parameters
179+
for actual_mod, expected_mod in zip(
180+
actual_modifiers[0].modifiers, expected_modifiers[0].modifiers
181+
):
182+
assert isinstance(actual_mod, type(expected_mod)), "Modifier types differ"
183+
assert (
184+
actual_mod.model_dump() == expected_mod.model_dump()
185+
), "Modifier parameters differ"

0 commit comments

Comments
 (0)