From e4e67bdb4d4881d364abfb2a714e89527c9e6b12 Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul <mahdi.benjelloul@gmail.com> Date: Wed, 28 Nov 2018 22:41:26 +0100 Subject: [PATCH] Add csv import test --- .../tests/test_scenario_csv.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 openfisca_survey_manager/tests/test_scenario_csv.py diff --git a/openfisca_survey_manager/tests/test_scenario_csv.py b/openfisca_survey_manager/tests/test_scenario_csv.py new file mode 100644 index 00000000..17370f02 --- /dev/null +++ b/openfisca_survey_manager/tests/test_scenario_csv.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- + + +import logging +import os +import pandas as pd +import pkg_resources + +from openfisca_core.model_api import * # noqa analysis:ignore +from openfisca_core import periods +from openfisca_survey_manager.input_dataframe_generator import ( + make_input_dataframe_by_entity, + random_data_generator, + randomly_init_variable, + ) +from openfisca_country_template import CountryTaxBenefitSystem +from openfisca_survey_manager.tests.test_scenario import generate_input_input_dataframe_by_entity +from openfisca_survey_manager.scenarios import AbstractSurveyScenario + + +log = logging.getLogger(__name__) +tax_benefit_system = CountryTaxBenefitSystem() +directory = os.path.join( + pkg_resources.get_distribution('openfisca-survey-manager').location, + 'openfisca_survey_manager', + 'tests', + 'data_files', + 'dump', + ) + + +def create_entity_csv_files(): + input_dataframe_by_entity = generate_input_input_dataframe_by_entity(nb_persons = 10, nb_groups = 5, salary_max_value = 50000, + rent_max_value = 1000) + for entity, dataframe in input_dataframe_by_entity.items(): + dataframe.to_csv(os.path.join(directory, "{}.csv".format(entity)), index = False) + + +def test_survey_scenario_csv_import(): + survey_scenario = AbstractSurveyScenario() + survey_scenario.set_tax_benefit_systems(tax_benefit_system = tax_benefit_system) + survey_scenario.year = 2017 + survey_scenario.used_as_input_variables = ['salary', 'rent'] + period = periods.period('2017-01') + survey_scenario.tax_benefit_system.entities + input_data_frame_by_entity = dict() + for entity in survey_scenario.tax_benefit_system.entities: + entity_key = entity.key + dataframe = pd.read_csv(os.path.join(directory, "{}.csv".format(entity_key))) + input_data_frame_by_entity[entity_key] = dataframe + + data = { + 'input_data_frame_by_entity_by_period': { + period: input_data_frame_by_entity + } + } + survey_scenario.init_from_data(data = data) + simulation = survey_scenario.simulation + error = 1e-03 + assert ( + (simulation.calculate('salary', period) - input_data_frame_by_entity['person']['salary']).abs() + < error).all() + assert ( + (simulation.calculate('rent', period) - input_data_frame_by_entity['household']['rent']).abs() + < error).all() + + +if __name__ == "__main__": + import sys + log = logging.getLogger(__name__) + logging.basicConfig(level = logging.DEBUG, stream = sys.stdout) + create_entity_csv_files() + test_survey_scenario_csv_import() \ No newline at end of file