From 971000d851a66aedc3fc31a756b131cb19f3e002 Mon Sep 17 00:00:00 2001
From: myersCody
Date: Fri, 15 Dec 2023 07:46:43 -0500
Subject: [PATCH] Address code smells.

---
 .../upgrade_trino/test/test_verify_parquet_files.py |  8 ++++----
 koku/masu/api/upgrade_trino/util/task_handler.py    |  4 ++--
 .../api/upgrade_trino/util/verify_parquet_files.py  | 13 +++++++------
 koku/masu/celery/tasks.py                           |  2 +-
 4 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py
index 64884514c4..e79497dce2 100644
--- a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py
+++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py
@@ -61,7 +61,7 @@ def create_default_verify_handler(self):
 
     @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files")
     @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource")
-    def test_retrieve_verify_reload_S3_parquet(self, mock_s3_resource, _):
+    def test_retrieve_verify_reload_s3_parquet(self, mock_s3_resource, _):
         """Test fixes for reindexes on all required columns."""
         # build a parquet file where reindex is used for all required columns
         test_metadata = [
@@ -95,7 +95,7 @@
         mock_bucket.objects.filter.side_effect = filter_side_effect
         mock_bucket.download_file.return_value = temp_file
         VerifyParquetFiles.local_path = self.temp_dir
-        verify_handler.retrieve_verify_reload_S3_parquet()
+        verify_handler.retrieve_verify_reload_s3_parquet()
         mock_bucket.upload_fileobj.assert_called()
         table = pq.read_table(temp_file)
         schema = table.schema
@@ -249,7 +249,7 @@ def test_other_providers_s3_paths(self):
 
     @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files")
     @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource")
-    def test_retrieve_verify_reload_S3_parquet_failure(self, mock_s3_resource, _):
+    def test_retrieve_verify_reload_s3_parquet_failure(self, mock_s3_resource, _):
         """Test fixes for reindexes on all required columns."""
         # build a parquet file where reindex is used for all required columns
         file_data = {
@@ -277,7 +277,7 @@ def test_retrieve_verify_reload_S3_parquet_failure(self, mock_s3_resource, _):
         mock_bucket.objects.filter.side_effect = filter_side_effect
         mock_bucket.download_file.return_value = temp_file
         VerifyParquetFiles.local_path = self.temp_dir
-        verify_handler.retrieve_verify_reload_S3_parquet()
+        verify_handler.retrieve_verify_reload_s3_parquet()
         mock_bucket.upload_fileobj.assert_not_called()
         os.remove(temp_file)
diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py
index c9d8ad14bf..15aac5180a 100644
--- a/koku/masu/api/upgrade_trino/util/task_handler.py
+++ b/koku/masu/api/upgrade_trino/util/task_handler.py
@@ -69,10 +69,10 @@ def from_query_params(cls, query_params: QueryDict) -> "FixParquetTaskHandler":
         return reprocess_kwargs
 
     @classmethod
-    def clean_column_names(self, provider_type):
+    def clean_column_names(cls, provider_type):
         """Creates a mapping of columns to expected pyarrow values."""
         clean_column_names = {}
-        provider_mapping = self.REQUIRED_COLUMNS_MAPPING.get(provider_type.replace("-local", ""))
+        provider_mapping = cls.REQUIRED_COLUMNS_MAPPING.get(provider_type.replace("-local", ""))
         # Our required mapping stores the raw column name; however,
         # the parquet files will contain the cleaned column name.
         for raw_col, default_val in provider_mapping.items():
diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py
index 13d7ce74c9..7ad50198e2 100644
--- a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py
+++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py
@@ -53,9 +53,9 @@ def _set_pyarrow_types(self, cleaned_column_mapping):
                 # TODO: AWS saves datetime as timestamp[ms, tz=UTC]
                 # Should we be storing in a standard type here?
                 mapping[key] = pa.timestamp("ms")
-            elif default_val == "":
+            elif isinstance(default_val, str):
                 mapping[key] = pa.string()
-            elif default_val == 0.0:
+            elif isinstance(default_val, float):
                 mapping[key] = pa.float64()
         return mapping
 
@@ -130,7 +130,7 @@ def local_path(self):
         local_path.mkdir(parents=True, exist_ok=True)
         return local_path
 
-    def retrieve_verify_reload_S3_parquet(self):
+    def retrieve_verify_reload_s3_parquet(self):
         """Retrieves the s3 files from s3"""
         s3_resource = get_s3_resource(settings.S3_ACCESS_KEY, settings.S3_SECRET, settings.S3_REGION)
         s3_bucket = s3_resource.Bucket(settings.S3_BUCKET_NAME)
@@ -201,6 +201,8 @@ def retrieve_verify_reload_S3_parquet(self):
 
     def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names):
         """Performs a transformation to change a double to a timestamp."""
+        if not field_names:
+            return
         table = pq.read_table(parquet_file_path)
         schema = table.schema
         fields = []
@@ -229,7 +231,7 @@ def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_n
         pq.write_table(new_table, parquet_file_path)
 
     # Same logic as last time, but combined into one method & added state tracking
-    def _coerce_parquet_data_type(self, parquet_file_path, transformation_enabled=True):
+    def _coerce_parquet_data_type(self, parquet_file_path):
         """If a parquet file has an incorrect dtype we can attempt
         to coerce it to the correct type it.
 
@@ -296,8 +298,7 @@ def _coerce_parquet_data_type(self, parquet_file_path, transformation_enabled=Tr
             table = table.cast(new_schema)
             # Write the table back to the Parquet file
             pa.parquet.write_table(table, parquet_file_path)
-            if double_to_timestamp_fields:
-                self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields)
+            self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields)
             # Signal that we need to send this update to S3.
             return self.file_tracker.COERCE_REQUIRED
 
diff --git a/koku/masu/celery/tasks.py b/koku/masu/celery/tasks.py
index bc61d1c39a..4f5a2668d5 100644
--- a/koku/masu/celery/tasks.py
+++ b/koku/masu/celery/tasks.py
@@ -62,7 +62,7 @@
 @celery_app.task(name="masu.celery.tasks.fix_parquet_data_types", queue=DEFAULT)
 def fix_parquet_data_types(*args, **kwargs):
     verify_parquet = VerifyParquetFiles(*args, **kwargs)
-    verify_parquet.retrieve_verify_reload_S3_parquet()
+    verify_parquet.retrieve_verify_reload_s3_parquet()
 
 
 @celery_app.task(name="masu.celery.tasks.check_report_updates", queue=DEFAULT)
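
Reviewer note on the _set_pyarrow_types() change: the old checks
default_val == "" and default_val == 0.0 only matched exact sentinel
values, and Python's cross-type numeric equality means 0 and False also
compare equal to 0.0. A minimal standalone sketch of the behavioral
difference (pyarrow_type_for is a hypothetical helper written for this
note, not code from the module):

    import pyarrow as pa

    def pyarrow_type_for(default_val):
        """Pick a pyarrow type from a required column's default value."""
        if isinstance(default_val, str):
            # default_val == "" matched only the empty string; any other
            # string default silently fell through.
            return pa.string()
        if isinstance(default_val, float):
            # default_val == 0.0 is also True for 0 and False;
            # isinstance() accepts only genuine floats.
            return pa.float64()
        return None

    assert pyarrow_type_for("unknown") == pa.string()  # non-empty default handled
    assert pyarrow_type_for(0.0) == pa.float64()
    assert pyarrow_type_for(0) is None      # int default no longer hits the float branch
    assert pyarrow_type_for(False) is None  # bool no longer does either (False == 0.0)

The remaining fixes are mechanical: snake_case method names
(retrieve_verify_reload_s3_parquet), cls instead of self in the
clean_column_names() classmethod, dropping the unused
transformation_enabled parameter from _coerce_parquet_data_type(), and
moving the empty-field_names guard into
_perform_transformation_double_to_timestamp() so callers no longer need
their own "if double_to_timestamp_fields:" wrapper.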