From 5d7f8b7f57c4be044f046a36c28dec243734cb9b Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Thu, 24 Aug 2023 18:11:15 -0700 Subject: [PATCH] Fix wrong mocks (#694) * Fix wrong mocks * Fix tests * Remove print patch * Fix lint --- rdt/transformers/numerical.py | 3 +-- tests/unit/test_hyper_transformer.py | 4 +--- tests/unit/transformers/pii/test_anonymizer.py | 10 +++++----- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/rdt/transformers/numerical.py b/rdt/transformers/numerical.py index 96316ab81..d0edb2984 100644 --- a/rdt/transformers/numerical.py +++ b/rdt/transformers/numerical.py @@ -524,8 +524,7 @@ def _reverse_transform_helper(self, data): normalized = np.clip(data[:, 0], -1, 1) means = self._bgm_transformer.means_.reshape([-1]) stds = np.sqrt(self._bgm_transformer.covariances_).reshape([-1]) - selected_component = data[:, 1].astype(int) - + selected_component = data[:, 1].astype(int) # maybe round instead? std_t = stds[self.valid_component_indicator][selected_component] mean_t = means[self.valid_component_indicator][selected_component] reversed_data = normalized * self.STD_MULTIPLIER * std_t + mean_t diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index 6aecbaba4..dc31bc8a9 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -2260,8 +2260,7 @@ def test_update_transformers_no_field_transformers(self): assert instance.get_config() == expected_config - @patch('rdt.hyper_transformer.print') - def test_update_transformers_missmatch_sdtypes(self, mock_warnings): + def test_update_transformers_mismatch_sdtypes(self): """Test update transformers. Ensure that the function updates properly the ``self.field_transformers`` and prints the @@ -2303,7 +2302,6 @@ def test_update_transformers_missmatch_sdtypes(self, mock_warnings): with pytest.raises(InvalidConfigError, match=err_msg): instance.update_transformers(column_name_to_transformer) - assert mock_warnings.called_once_with(err_msg) instance._validate_transformers.assert_called_once_with(column_name_to_transformer) def test_update_transformers_transformer_is_none(self): diff --git a/tests/unit/transformers/pii/test_anonymizer.py b/tests/unit/transformers/pii/test_anonymizer.py index 61cd77857..ba6a88b4f 100644 --- a/tests/unit/transformers/pii/test_anonymizer.py +++ b/tests/unit/transformers/pii/test_anonymizer.py @@ -218,7 +218,7 @@ def test___init__default(self, mock_check_provider_function, mock_faker): assert instance.function_name == 'lexify' assert instance.function_kwargs == {} assert instance.locales is None - assert mock_faker.Faker.called_once_with(None) + mock_faker.Faker.assert_called_once_with(None) assert instance.enforce_uniqueness is False assert instance.missing_value_generation == 'random' @@ -279,7 +279,7 @@ def test___init__custom(self, mock_check_provider_function, mock_faker): assert instance.function_name == 'credit_card_full' assert instance.function_kwargs == {'type': 'visa'} assert instance.locales == ['en_US', 'fr_FR'] - assert mock_faker.Faker.called_once_with(['en_US', 'fr_FR']) + mock_faker.Faker.assert_called_once_with(['en_US', 'fr_FR']) assert instance.enforce_uniqueness def test___init__no_function_name(self): @@ -346,7 +346,7 @@ def test_reset_randomization(self, mock_faker, mock_base_reset): AnonymizedFaker.reset_randomization(instance) # Assert - assert mock_faker.Faker.called_once_with(['en_US']) + mock_faker.Faker.assert_has_calls([call(None), call(['en_US'])]) mock_base_reset.assert_called_once() def test__fit(self): @@ -597,7 +597,7 @@ def test___init__super_attrs(self, mock_check_provider_function, mock_faker): assert instance.function_name == 'lexify' assert instance.function_kwargs == {} assert instance.locales is None - assert mock_faker.Faker.called_once_with(None) + mock_faker.Faker.assert_called_once_with(None) @patch('rdt.transformers.pii.anonymizer.faker') @patch('rdt.transformers.pii.anonymizer.AnonymizedFaker.check_provider_function') @@ -641,7 +641,7 @@ def test___init__custom(self, mock_check_provider_function, mock_faker): assert instance.function_name == 'credit_card_full' assert instance.function_kwargs == {'type': 'visa'} assert instance.locales == ['en_US', 'fr_FR'] - assert mock_faker.Faker.called_once_with(['en_US', 'fr_FR']) + mock_faker.Faker.assert_called_once_with(['en_US', 'fr_FR']) def test_get_mapping(self): """Test the ``get_mapping`` method.