Skip to content

Commit

Permalink
Fix wrong mocks (#694)
Browse files Browse the repository at this point in the history
* Fix wrong mocks

* Fix tests

* Remove print patch

* Fix lint
  • Loading branch information
fealho committed Aug 25, 2023
1 parent 2bfa2e2 commit 5d7f8b7
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
3 changes: 1 addition & 2 deletions rdt/transformers/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,7 @@ def _reverse_transform_helper(self, data):
normalized = np.clip(data[:, 0], -1, 1)
means = self._bgm_transformer.means_.reshape([-1])
stds = np.sqrt(self._bgm_transformer.covariances_).reshape([-1])
selected_component = data[:, 1].astype(int)

selected_component = data[:, 1].astype(int) # maybe round instead?
std_t = stds[self.valid_component_indicator][selected_component]
mean_t = means[self.valid_component_indicator][selected_component]
reversed_data = normalized * self.STD_MULTIPLIER * std_t + mean_t
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2260,8 +2260,7 @@ def test_update_transformers_no_field_transformers(self):

assert instance.get_config() == expected_config

@patch('rdt.hyper_transformer.print')
def test_update_transformers_missmatch_sdtypes(self, mock_warnings):
def test_update_transformers_mismatch_sdtypes(self):
"""Test update transformers.
Ensure that the function updates properly the ``self.field_transformers`` and prints the
Expand Down Expand Up @@ -2303,7 +2302,6 @@ def test_update_transformers_missmatch_sdtypes(self, mock_warnings):
with pytest.raises(InvalidConfigError, match=err_msg):
instance.update_transformers(column_name_to_transformer)

assert mock_warnings.called_once_with(err_msg)
instance._validate_transformers.assert_called_once_with(column_name_to_transformer)

def test_update_transformers_transformer_is_none(self):
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/transformers/pii/test_anonymizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def test___init__default(self, mock_check_provider_function, mock_faker):
assert instance.function_name == 'lexify'
assert instance.function_kwargs == {}
assert instance.locales is None
assert mock_faker.Faker.called_once_with(None)
mock_faker.Faker.assert_called_once_with(None)
assert instance.enforce_uniqueness is False
assert instance.missing_value_generation == 'random'

Expand Down Expand Up @@ -279,7 +279,7 @@ def test___init__custom(self, mock_check_provider_function, mock_faker):
assert instance.function_name == 'credit_card_full'
assert instance.function_kwargs == {'type': 'visa'}
assert instance.locales == ['en_US', 'fr_FR']
assert mock_faker.Faker.called_once_with(['en_US', 'fr_FR'])
mock_faker.Faker.assert_called_once_with(['en_US', 'fr_FR'])
assert instance.enforce_uniqueness

def test___init__no_function_name(self):
Expand Down Expand Up @@ -346,7 +346,7 @@ def test_reset_randomization(self, mock_faker, mock_base_reset):
AnonymizedFaker.reset_randomization(instance)

# Assert
assert mock_faker.Faker.called_once_with(['en_US'])
mock_faker.Faker.assert_has_calls([call(None), call(['en_US'])])
mock_base_reset.assert_called_once()

def test__fit(self):
Expand Down Expand Up @@ -597,7 +597,7 @@ def test___init__super_attrs(self, mock_check_provider_function, mock_faker):
assert instance.function_name == 'lexify'
assert instance.function_kwargs == {}
assert instance.locales is None
assert mock_faker.Faker.called_once_with(None)
mock_faker.Faker.assert_called_once_with(None)

@patch('rdt.transformers.pii.anonymizer.faker')
@patch('rdt.transformers.pii.anonymizer.AnonymizedFaker.check_provider_function')
Expand Down Expand Up @@ -641,7 +641,7 @@ def test___init__custom(self, mock_check_provider_function, mock_faker):
assert instance.function_name == 'credit_card_full'
assert instance.function_kwargs == {'type': 'visa'}
assert instance.locales == ['en_US', 'fr_FR']
assert mock_faker.Faker.called_once_with(['en_US', 'fr_FR'])
mock_faker.Faker.assert_called_once_with(['en_US', 'fr_FR'])

def test_get_mapping(self):
"""Test the ``get_mapping`` method.
Expand Down

0 comments on commit 5d7f8b7

Please sign in to comment.