Commit 971000d: Address code smells.

myersCody committed Dec 15, 2023
1 parent b91827e

Showing 4 changed files with 14 additions and 13 deletions.
koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py
8 changes: 4 additions & 4 deletions
@@ -61,7 +61,7 @@ def create_default_verify_handler(self):

     @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files")
     @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource")
-    def test_retrieve_verify_reload_S3_parquet(self, mock_s3_resource, _):
+    def test_retrieve_verify_reload_s3_parquet(self, mock_s3_resource, _):
         """Test fixes for reindexes on all required columns."""
         # build a parquet file where reindex is used for all required columns
         test_metadata = [
@@ -95,7 +95,7 @@ def test_retrieve_verify_reload_S3_parquet(self, mock_s3_resource, _):
         mock_bucket.objects.filter.side_effect = filter_side_effect
         mock_bucket.download_file.return_value = temp_file
         VerifyParquetFiles.local_path = self.temp_dir
-        verify_handler.retrieve_verify_reload_S3_parquet()
+        verify_handler.retrieve_verify_reload_s3_parquet()
         mock_bucket.upload_fileobj.assert_called()
         table = pq.read_table(temp_file)
         schema = table.schema
@@ -249,7 +249,7 @@ def test_other_providers_s3_paths(self):

     @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files")
     @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource")
-    def test_retrieve_verify_reload_S3_parquet_failure(self, mock_s3_resource, _):
+    def test_retrieve_verify_reload_s3_parquet_failure(self, mock_s3_resource, _):
         """Test fixes for reindexes on all required columns."""
         # build a parquet file where reindex is used for all required columns
         file_data = {
@@ -277,7 +277,7 @@ def test_retrieve_verify_reload_S3_parquet_failure(self, mock_s3_resource, _):
         mock_bucket.objects.filter.side_effect = filter_side_effect
         mock_bucket.download_file.return_value = temp_file
         VerifyParquetFiles.local_path = self.temp_dir
-        verify_handler.retrieve_verify_reload_S3_parquet()
+        verify_handler.retrieve_verify_reload_s3_parquet()
         mock_bucket.upload_fileobj.assert_not_called()
         os.remove(temp_file)

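The test changes above track a single rename: PEP 8 gives functions and methods lowercase snake_case names, so the embedded acronym in retrieve_verify_reload_S3_parquet is lowercased. A minimal sketch of the convention (the function names here are illustrative, not from the repo):

    # PEP 8: function names use lowercase snake_case, so acronyms
    # such as "S3" are lowercased inside them.
    def upload_to_S3(bucket):  # flagged as a code smell
        ...

    def upload_to_s3(bucket):  # conventional spelling
        ...
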
koku/masu/api/upgrade_trino/util/task_handler.py
4 changes: 2 additions & 2 deletions
@@ -69,10 +69,10 @@ def from_query_params(cls, query_params: QueryDict) -> "FixParquetTaskHandler":
         return reprocess_kwargs

     @classmethod
-    def clean_column_names(self, provider_type):
+    def clean_column_names(cls, provider_type):
         """Creates a mapping of columns to expected pyarrow values."""
         clean_column_names = {}
-        provider_mapping = self.REQUIRED_COLUMNS_MAPPING.get(provider_type.replace("-local", ""))
+        provider_mapping = cls.REQUIRED_COLUMNS_MAPPING.get(provider_type.replace("-local", ""))
         # Our required mapping stores the raw column name; however,
         # the parquet files will contain the cleaned column name.
         for raw_col, default_val in provider_mapping.items():
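Context for the clean_column_names change: Python passes the class object itself as the first argument of a @classmethod regardless of what the parameter is called, so the old spelling ran correctly but read as if an instance were expected. A minimal sketch of the convention, using a made-up class:

    class Handler:
        REQUIRED_COLUMNS_MAPPING = {"AWS": {"resource_id": ""}}

        @classmethod
        def clean_column_names(cls, provider_type):
            # `cls` is the Handler class itself; naming it `self`
            # would still work, but it misleads readers into
            # expecting an instance.
            return cls.REQUIRED_COLUMNS_MAPPING.get(provider_type, {})

    print(Handler.clean_column_names("AWS"))  # no instance required
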
koku/masu/api/upgrade_trino/util/verify_parquet_files.py
13 changes: 7 additions & 6 deletions
@@ -53,9 +53,9 @@ def _set_pyarrow_types(self, cleaned_column_mapping):
                 # TODO: AWS saves datetime as timestamp[ms, tz=UTC]
                 # Should we be storing in a standard type here?
                 mapping[key] = pa.timestamp("ms")
-            elif default_val == "":
+            elif isinstance(default_val, str):
                 mapping[key] = pa.string()
-            elif default_val == 0.0:
+            elif isinstance(default_val, float):
                 mapping[key] = pa.float64()
         return mapping

@@ -130,7 +130,7 @@ def local_path(self):
         local_path.mkdir(parents=True, exist_ok=True)
         return local_path

-    def retrieve_verify_reload_S3_parquet(self):
+    def retrieve_verify_reload_s3_parquet(self):
         """Retrieves the parquet files from S3."""
         s3_resource = get_s3_resource(settings.S3_ACCESS_KEY, settings.S3_SECRET, settings.S3_REGION)
         s3_bucket = s3_resource.Bucket(settings.S3_BUCKET_NAME)
@@ -201,6 +201,8 @@ def retrieve_verify_reload_S3_parquet(self):

     def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names):
         """Performs a transformation to change a double to a timestamp."""
+        if not field_names:
+            return
         table = pq.read_table(parquet_file_path)
         schema = table.schema
         fields = []
@@ -229,7 +231,7 @@ def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names):
         pq.write_table(new_table, parquet_file_path)

     # Same logic as last time, but combined into one method & added state tracking
-    def _coerce_parquet_data_type(self, parquet_file_path, transformation_enabled=True):
+    def _coerce_parquet_data_type(self, parquet_file_path):
         """If a parquet file has an incorrect dtype we can attempt to coerce
         it to the correct type.
@@ -296,8 +298,7 @@ def _coerce_parquet_data_type(self, parquet_file_path, transformation_enabled=True):
         table = table.cast(new_schema)
         # Write the table back to the Parquet file
         pa.parquet.write_table(table, parquet_file_path)
-        if double_to_timestamp_fields:
-            self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields)
+        self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields)
         # Signal that we need to send this update to S3.
         return self.file_tracker.COERCE_REQUIRED

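The last two hunks in this file are one refactor: the empty-input check moves out of the caller (_coerce_parquet_data_type) and into _perform_transformation_double_to_timestamp as a guard clause, and the unused transformation_enabled parameter is dropped. With the guard in the callee, every call site gets the no-op behavior for free. A standalone sketch of the pattern (the file path and field list are invented):

    def perform_transformation(parquet_file_path, field_names):
        # Guard clause: return early on empty input so callers no
        # longer need their own `if field_names:` check.
        if not field_names:
            return
        print(f"transforming {field_names} in {parquet_file_path}")

    perform_transformation("report.parquet", [])               # no-op
    perform_transformation("report.parquet", ["usage_start"])  # runs
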
koku/masu/celery/tasks.py
2 changes: 1 addition & 1 deletion
@@ -62,7 +62,7 @@

 @celery_app.task(name="masu.celery.tasks.fix_parquet_data_types", queue=DEFAULT)
 def fix_parquet_data_types(*args, **kwargs):
     verify_parquet = VerifyParquetFiles(*args, **kwargs)
-    verify_parquet.retrieve_verify_reload_S3_parquet()
+    verify_parquet.retrieve_verify_reload_s3_parquet()

Check warning (Codecov / codecov/patch) on koku/masu/celery/tasks.py#L64-L65: added lines #L64-L65 were not covered by tests.


 @celery_app.task(name="masu.celery.tasks.check_report_updates", queue=DEFAULT)
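For reference, a task registered this way is normally dispatched asynchronously through Celery. A hedged sketch of such a call (the keyword arguments are assumptions, not taken from this commit; they are simply forwarded to VerifyParquetFiles(*args, **kwargs) by the task body above):

    # Queue the task; .delay() is Celery's shorthand for apply_async().
    # The kwargs below are illustrative only.
    fix_parquet_data_types.delay(
        schema_name="org1234567",
        provider_type="AWS",
    )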
