|
| 1 | +import os |
1 | 2 | import tempfile
|
2 | 3 |
|
3 | 4 | import pytest
|
4 | 5 | import yaml
|
5 | 6 |
|
6 | 7 | from llmcompressor.modifiers.obcq.base import SparseGPTModifier
|
| 8 | +from llmcompressor.modifiers.pruning.constant import ConstantPruningModifier |
7 | 9 | from llmcompressor.recipe import Recipe
|
8 | 10 | from tests.llmcompressor.helpers import valid_recipe_strings
|
9 | 11 |
|
10 | 12 |
|
11 | 13 | @pytest.mark.parametrize("recipe_str", valid_recipe_strings())
|
12 |
| -def test_recipe_create_instance_accepts_valid_recipe_string(recipe_str): |
13 |
| - recipe = Recipe.create_instance(recipe_str) |
14 |
| - assert recipe is not None, "Recipe could not be created from string" |
| 14 | +class TestRecipeWithStrings: |
| 15 | + """Tests that use various recipe strings for validation.""" |
| 16 | + |
| 17 | + def test_create_from_string(self, recipe_str): |
| 18 | + """Test creating a Recipe from a YAML string.""" |
| 19 | + recipe = Recipe.create_instance(recipe_str) |
| 20 | + assert recipe is not None, "Recipe could not be created from string" |
| 21 | + assert isinstance(recipe, Recipe), "Created object is not a Recipe instance" |
| 22 | + |
| 23 | + def test_create_from_file(self, recipe_str): |
| 24 | + """Test creating a Recipe from a YAML file.""" |
| 25 | + content = yaml.safe_load(recipe_str) |
| 26 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: |
| 27 | + yaml.dump(content, f) |
| 28 | + f.flush() # Ensure content is written |
| 29 | + recipe = Recipe.create_instance(f.name) |
| 30 | + assert recipe is not None, "Recipe could not be created from file" |
| 31 | + assert isinstance(recipe, Recipe), "Created object is not a Recipe instance" |
| 32 | + |
| 33 | + def test_yaml_serialization_roundtrip(self, recipe_str): |
| 34 | + """ |
| 35 | + Test that a recipe can be serialized to YAML |
| 36 | + and deserialized back with all properties preserved. |
| 37 | + """ |
| 38 | + # Create original recipe |
| 39 | + original_recipe = Recipe.create_instance(recipe_str) |
15 | 40 |
|
| 41 | + # Serialize to YAML |
| 42 | + yaml_str = original_recipe.yaml() |
| 43 | + assert yaml_str, "Serialized YAML string should not be empty" |
16 | 44 |
|
17 |
| -@pytest.mark.parametrize("recipe_str", valid_recipe_strings()) |
18 |
| -def test_recipe_create_instance_accepts_valid_recipe_file(recipe_str): |
19 |
| - content = yaml.safe_load(recipe_str) |
20 |
| - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: |
21 |
| - yaml.dump(content, f) |
22 |
| - recipe = Recipe.create_instance(f.name) |
23 |
| - assert recipe is not None, "Recipe could not be created from file" |
| 45 | + # Deserialize from YAML |
| 46 | + deserialized_recipe = Recipe.create_instance(yaml_str) |
24 | 47 |
|
| 48 | + # Compare serialized forms |
| 49 | + original_dict = original_recipe.model_dump() |
| 50 | + deserialized_dict = deserialized_recipe.model_dump() |
25 | 51 |
|
26 |
| -@pytest.mark.parametrize("recipe_str", valid_recipe_strings()) |
27 |
| -def test_yaml_serialization(recipe_str): |
28 |
| - recipe_instance = Recipe.create_instance(recipe_str) |
29 |
| - serialized_recipe = recipe_instance.yaml() |
30 |
| - recipe_from_serialized = Recipe.create_instance(serialized_recipe) |
| 52 | + assert original_dict == deserialized_dict, "Serialization roundtrip failed" |
31 | 53 |
|
32 |
| - expected_dict = recipe_instance.model_dump() |
33 |
| - actual_dict = recipe_from_serialized.model_dump() |
| 54 | + def test_model_dump_and_validate(self, recipe_str): |
| 55 | + """ |
| 56 | + Test that model_dump produces a format compatible |
| 57 | + with model_validate. |
| 58 | + """ |
| 59 | + recipe = Recipe.create_instance(recipe_str) |
| 60 | + validated_recipe = Recipe.model_validate(recipe.model_dump()) |
| 61 | + assert ( |
| 62 | + recipe == validated_recipe |
| 63 | + ), "Recipe instance and validated recipe do not match" |
34 | 64 |
|
35 |
| - assert expected_dict == actual_dict |
36 | 65 |
|
| 66 | +class TestRecipeSerialization: |
| 67 | + """ |
| 68 | + Tests for Recipe serialization and deserialization |
| 69 | + edge cases.""" |
37 | 70 |
|
38 |
| -@pytest.mark.parametrize("recipe_str", valid_recipe_strings()) |
39 |
| -def test_model_dump_and_validate(recipe_str): |
40 |
| - recipe_instance = Recipe.create_instance(recipe_str) |
41 |
| - validated_recipe = Recipe.model_validate(recipe_instance.model_dump()) |
42 |
| - assert ( |
43 |
| - recipe_instance == validated_recipe |
44 |
| - ), "Recipe instance and validated recipe do not match" |
45 |
| - |
46 |
| - |
47 |
| -def test_recipe_creates_correct_modifier(): |
48 |
| - start = 1 |
49 |
| - end = 10 |
50 |
| - targets = "__ALL_PRUNABLE__" |
51 |
| - |
52 |
| - yaml_str = f""" |
53 |
| - test_stage: |
54 |
| - pruning_modifiers: |
55 |
| - ConstantPruningModifier: |
56 |
| - start: {start} |
57 |
| - end: {end} |
58 |
| - targets: {targets} |
59 |
| - """ |
| 71 | + def test_empty_recipe_serialization(self): |
| 72 | + """Test serialization of a minimal recipe with no stages.""" |
| 73 | + recipe = Recipe() |
| 74 | + assert len(recipe.stages) == 0, "New recipe should have no stages" |
| 75 | + |
| 76 | + # Test roundtrip serialization |
| 77 | + dumped = recipe.model_dump() |
| 78 | + loaded = Recipe.model_validate(dumped) |
| 79 | + assert recipe == loaded, "Empty recipe serialization failed" |
60 | 80 |
|
61 |
| - recipe_instance = Recipe.create_instance(yaml_str) |
62 |
| - |
63 |
| - stage_modifiers = recipe_instance.create_modifier() |
64 |
| - assert len(stage_modifiers) == 1 |
65 |
| - assert len(modifiers := stage_modifiers[0].modifiers) == 1 |
66 |
| - from llmcompressor.modifiers.pruning.constant import ConstantPruningModifier |
67 |
| - |
68 |
| - assert isinstance(modifier := modifiers[0], ConstantPruningModifier) |
69 |
| - assert modifier.start == start |
70 |
| - assert modifier.end == end |
71 |
| - |
72 |
| - |
73 |
| -def test_recipe_can_be_created_from_modifier_instances(): |
74 |
| - modifier = SparseGPTModifier( |
75 |
| - sparsity=0.5, |
76 |
| - ) |
77 |
| - group_name = "dummy" |
78 |
| - |
79 |
| - # for pep8 compliance |
80 |
| - recipe_str = ( |
81 |
| - f"{group_name}_stage:\n" |
82 |
| - " pruning_modifiers:\n" |
83 |
| - " SparseGPTModifier:\n" |
84 |
| - " sparsity: 0.5\n" |
85 |
| - ) |
86 |
| - |
87 |
| - expected_recipe_instance = Recipe.create_instance(recipe_str) |
88 |
| - expected_modifiers = expected_recipe_instance.create_modifier() |
89 |
| - |
90 |
| - actual_recipe_instance = Recipe.create_instance( |
91 |
| - [modifier], modifier_group_name=group_name |
92 |
| - ) |
93 |
| - actual_modifiers = actual_recipe_instance.create_modifier() |
94 |
| - |
95 |
| - # assert num stages is the same |
96 |
| - assert len(actual_modifiers) == len(expected_modifiers) |
97 |
| - |
98 |
| - # assert num modifiers in each stage is the same |
99 |
| - assert len(actual_modifiers[0].modifiers) == len(expected_modifiers[0].modifiers) |
100 |
| - |
101 |
| - # assert modifiers in each stage are the same type |
102 |
| - # and have the same parameters |
103 |
| - for actual_modifier, expected_modifier in zip( |
104 |
| - actual_modifiers[0].modifiers, expected_modifiers[0].modifiers |
105 |
| - ): |
106 |
| - assert isinstance(actual_modifier, type(expected_modifier)) |
107 |
| - assert actual_modifier.model_dump() == expected_modifier.model_dump() |
| 81 | + def test_file_serialization(self): |
| 82 | + """Test serializing a recipe to a file and reading it back.""" |
| 83 | + recipe = Recipe.create_instance(valid_recipe_strings()[0]) |
| 84 | + |
| 85 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 86 | + file_path = os.path.join(temp_dir, "recipe.yaml") |
| 87 | + |
| 88 | + # Write to file |
| 89 | + recipe.yaml(file_path=file_path) |
| 90 | + assert os.path.exists(file_path), "YAML file was not created" |
| 91 | + assert os.path.getsize(file_path) > 0, "YAML file is empty" |
| 92 | + |
| 93 | + # Read back from file |
| 94 | + loaded_recipe = Recipe.create_instance(file_path) |
| 95 | + assert ( |
| 96 | + recipe == loaded_recipe |
| 97 | + ), "Recipe loaded from file doesn't match original" |
| 98 | + |
| 99 | + |
| 100 | +class TestRecipeModifiers: |
| 101 | + """Tests for creating and working with modifiers in recipes.""" |
| 102 | + |
| 103 | + def test_creates_correct_modifier(self): |
| 104 | + """ |
| 105 | + Test that a recipe creates the expected modifier type |
| 106 | + with correct parameters. |
| 107 | + """ |
| 108 | + # Recipe parameters |
| 109 | + params = {"start": 1, "end": 10, "targets": "__ALL_PRUNABLE__"} |
| 110 | + |
| 111 | + # Create recipe from YAML |
| 112 | + yaml_str = f""" |
| 113 | + test_stage: |
| 114 | + pruning_modifiers: |
| 115 | + ConstantPruningModifier: |
| 116 | + start: {params['start']} |
| 117 | + end: {params['end']} |
| 118 | + targets: {params['targets']} |
| 119 | + """ |
| 120 | + recipe = Recipe.create_instance(yaml_str) |
| 121 | + |
| 122 | + # Get modifiers from recipe |
| 123 | + stage_modifiers = recipe.create_modifier() |
| 124 | + assert len(stage_modifiers) == 1, "Expected exactly one stage modifier" |
| 125 | + |
| 126 | + modifiers = stage_modifiers[0].modifiers |
| 127 | + assert len(modifiers) == 1, "Expected exactly one modifier in the stage" |
| 128 | + |
| 129 | + # Verify modifier type and parameters |
| 130 | + modifier = modifiers[0] |
| 131 | + assert isinstance( |
| 132 | + modifier, ConstantPruningModifier |
| 133 | + ), "Wrong modifier type created" |
| 134 | + assert modifier.start == params["start"], "Modifier start value incorrect" |
| 135 | + assert modifier.end == params["end"], "Modifier end value incorrect" |
| 136 | + assert modifier.targets == params["targets"], "Modifier targets incorrect" |
| 137 | + |
| 138 | + def test_create_from_modifier_instances(self): |
| 139 | + """Test creating a recipe from modifier instances.""" |
| 140 | + # Create a modifier instance |
| 141 | + sparsity_value = 0.5 |
| 142 | + modifier = SparseGPTModifier(sparsity=sparsity_value) |
| 143 | + group_name = "dummy" |
| 144 | + |
| 145 | + # Expected YAML representation |
| 146 | + recipe_str = ( |
| 147 | + f"{group_name}_stage:\n" |
| 148 | + " pruning_modifiers:\n" |
| 149 | + " SparseGPTModifier:\n" |
| 150 | + f" sparsity: {sparsity_value}\n" |
| 151 | + ) |
| 152 | + |
| 153 | + # Create recipes for comparison |
| 154 | + expected_recipe = Recipe.create_instance(recipe_str) |
| 155 | + actual_recipe = Recipe.create_instance( |
| 156 | + [modifier], modifier_group_name=group_name |
| 157 | + ) |
| 158 | + |
| 159 | + # Compare recipes by creating and checking their modifiers |
| 160 | + self._compare_recipe_modifiers(actual_recipe, expected_recipe) |
| 161 | + |
| 162 | + def _compare_recipe_modifiers(self, actual_recipe, expected_recipe): |
| 163 | + """Helper method to compare modifiers created from two recipes.""" |
| 164 | + actual_modifiers = actual_recipe.create_modifier() |
| 165 | + expected_modifiers = expected_recipe.create_modifier() |
| 166 | + |
| 167 | + # Compare stage counts |
| 168 | + assert len(actual_modifiers) == len(expected_modifiers), "Stage counts differ" |
| 169 | + |
| 170 | + if not actual_modifiers: |
| 171 | + return # No modifiers to compare |
| 172 | + |
| 173 | + # Compare modifier counts in each stage |
| 174 | + assert len(actual_modifiers[0].modifiers) == len( |
| 175 | + expected_modifiers[0].modifiers |
| 176 | + ), "Modifier counts differ" |
| 177 | + |
| 178 | + # Compare modifier types and parameters |
| 179 | + for actual_mod, expected_mod in zip( |
| 180 | + actual_modifiers[0].modifiers, expected_modifiers[0].modifiers |
| 181 | + ): |
| 182 | + assert isinstance(actual_mod, type(expected_mod)), "Modifier types differ" |
| 183 | + assert ( |
| 184 | + actual_mod.model_dump() == expected_mod.model_dump() |
| 185 | + ), "Modifier parameters differ" |
0 commit comments