From 7772b2455b41ef98af7c531f1c267363e15f6d23 Mon Sep 17 00:00:00 2001 From: "Anthony V." Date: Thu, 17 Oct 2024 17:58:52 +0200 Subject: [PATCH] fix: Shallow copy group_entities and person_entity when cloning TBS (#294) * fix: Shallow copy group_entities and person_entity when cloning TBS * chore: Lint and changelog * test: Re-add extension tests, which now pass --- changelog_entry.yaml | 6 ++++ .../taxbenefitsystems/tax_benefit_system.py | 13 +++++++++ tests/core/test_yaml.py | 7 +++++ .../yaml_tests/test_with_extension.yaml | 28 +++++++++++++++++++ 4 files changed, 54 insertions(+) create mode 100644 tests/fixtures/yaml_tests/test_with_extension.yaml diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..4d118c82 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,6 @@ +- bump: minor + changes: + changed: + - Shallow copy GroupEntities and PopulationEntity when cloning TaxBenefitSystem object + added: + - Two tests related to extensions that were previously removed \ No newline at end of file diff --git a/policyengine_core/taxbenefitsystems/tax_benefit_system.py b/policyengine_core/taxbenefitsystems/tax_benefit_system.py index 98f812f1..d65dd1bb 100644 --- a/policyengine_core/taxbenefitsystems/tax_benefit_system.py +++ b/policyengine_core/taxbenefitsystems/tax_benefit_system.py @@ -667,6 +667,8 @@ def clone(self) -> "TaxBenefitSystem": "_parameters_at_instant_cache", "variables", "entities", + "person_entity", + "group_entities", ): new_dict[key] = value @@ -676,10 +678,21 @@ def clone(self) -> "TaxBenefitSystem": variable_name: variable.clone() for variable_name, variable in self.variables.items() } + + # Apply shallow copies to all relevant entities new_dict["entities"] = [copy.copy(entity) for entity in self.entities] + new_dict["person_entity"] = copy.copy(self.person_entity) + new_dict["group_entities"] = [ + copy.copy(entity) for entity in self.group_entities + ] + # For all shallow-copied entities, set entity._tax_benefit_system to the new system for entity in new_dict["entities"]: entity.set_tax_benefit_system(new) + for entity in new_dict["group_entities"]: + entity.set_tax_benefit_system(new) + new_dict["person_entity"].set_tax_benefit_system(new) + return new def entities_plural(self) -> dict: diff --git a/tests/core/test_yaml.py b/tests/core/test_yaml.py index 24ad6f3c..acb740d1 100644 --- a/tests/core/test_yaml.py +++ b/tests/core/test_yaml.py @@ -76,6 +76,13 @@ def test_with_reform(tax_benefit_system): ) +def test_with_extension(tax_benefit_system): + assert ( + run_yaml_test(tax_benefit_system, "test_with_extension.yaml") + == EXIT_OK + ) + + def test_with_anchors(tax_benefit_system): assert ( run_yaml_test(tax_benefit_system, "test_with_anchors.yaml") == EXIT_OK diff --git a/tests/fixtures/yaml_tests/test_with_extension.yaml b/tests/fixtures/yaml_tests/test_with_extension.yaml new file mode 100644 index 00000000..682e512f --- /dev/null +++ b/tests/fixtures/yaml_tests/test_with_extension.yaml @@ -0,0 +1,28 @@ +- name: "Test using an extension" + period: 2017-01 + extensions: + - policyengine_core.extension_template + input: + persons: + parent: {} + child1: {} + household: + parents: [parent] + children: [child1] + output: + local_town_child_allowance: 100 + +- name: "Test using an extension" + period: 2017-01 + extensions: + - policyengine_core.extension_template + input: + persons: + parent: {} + child1: {} + child2: {} + household: + parents: [parent] + children: [child1, child2] + output: + local_town_child_allowance: 200 \ No newline at end of file