diff --git a/koku/masu/test/util/aws/test_common.py b/koku/masu/test/util/aws/test_common.py index 393b844ad9..20849322ee 100644 --- a/koku/masu/test/util/aws/test_common.py +++ b/koku/masu/test/util/aws/test_common.py @@ -507,34 +507,45 @@ def test_remove_s3_objects_not_matching_metadata(self): "account", Provider.PROVIDER_AWS, "provider_uuid", start_date, Config.CSV_DATA_TYPE ) expected_key = "not_matching_key" - mock_object = Mock(metadata={metadata_key: "this will be deleted"}, key=expected_key) not_matching_summary = Mock() - not_matching_summary.Object.return_value = mock_object + not_matching_summary.key = expected_key + not_expected_key = "matching_key" - mock_object = Mock(metadata={metadata_key: metadata_value}, key=not_expected_key) matching_summary = Mock() - matching_summary.Object.return_value = mock_object + matching_summary.key = not_expected_key + + def mock_head_object(Bucket, Key): + if Key == expected_key: + return {"Metadata": {metadata_key: "this will be deleted"}} + elif Key == not_expected_key: + return {"Metadata": {metadata_key: metadata_value}} + raise ClientError({}, "Error") + with patch("masu.util.aws.common.get_s3_resource") as mock_s3: mock_s3.return_value.Bucket.return_value.objects.filter.return_value = [ not_matching_summary, matching_summary, ] - removed = utils.delete_s3_objects_not_matching_metadata( - "request_id", s3_csv_path, metadata_key=metadata_key, metadata_value_check=metadata_value - ) - self.assertListEqual(removed, [{"Key": expected_key}]) + with patch("boto3.client") as mock_s3_client: + mock_s3_client.return_value.head_object.side_effect = mock_head_object + removed = utils.delete_s3_objects_not_matching_metadata( + "request_id", s3_csv_path, metadata_key=metadata_key, metadata_value_check=metadata_value + ) + self.assertListEqual(removed, [{"Key": expected_key}]) with patch("masu.util.aws.common.get_s3_resource") as mock_s3: - client_error_object = Mock() - client_error_object.Object.side_effect = ClientError({}, "Error") - mock_s3.return_value.Bucket.return_value.objects.filter.return_value = [ - not_matching_summary, - client_error_object, - ] - removed = utils.delete_s3_objects_not_matching_metadata( - "request_id", s3_csv_path, metadata_key=metadata_key, metadata_value_check=metadata_value - ) - self.assertListEqual(removed, []) + client_error_summary = Mock() + client_error_summary.key = expected_key + with patch("boto3.client") as mock_s3_client: + mock_s3_client.return_value.head_object.side_effect = ClientError({}, "Error") + mock_s3.return_value.Bucket.return_value.objects.filter.return_value = [ + not_matching_summary, + client_error_summary, + ] + removed = utils.delete_s3_objects_not_matching_metadata( + "request_id", s3_csv_path, metadata_key=metadata_key, metadata_value_check=metadata_value + ) + self.assertListEqual(removed, []) with patch("masu.util.aws.common.get_s3_objects_not_matching_metadata") as mock_get_objects, patch( "masu.util.aws.common.get_s3_resource" @@ -614,35 +625,45 @@ def test_remove_s3_objects_matching_metadata(self): "account", Provider.PROVIDER_AWS, "provider_uuid", start_date, Config.CSV_DATA_TYPE ) not_expected_key = "not_matching_key" - mock_object = Mock(metadata={metadata_key: "this will not be deleted"}, key=not_expected_key) not_matching_summary = Mock() - not_matching_summary.Object.return_value = mock_object + not_matching_summary.key = not_expected_key expected_key = "matching_key" - mock_object = Mock(metadata={metadata_key: metadata_value}, key=expected_key) matching_summary = Mock() - matching_summary.Object.return_value = mock_object + matching_summary.key = expected_key + + def mock_head_object(Bucket, Key): + if Key == not_expected_key: + return {"Metadata": {metadata_key: "this will not be deleted"}} + elif Key == expected_key: + return {"Metadata": {metadata_key: metadata_value}} + raise ClientError({}, "Error") + with patch("masu.util.aws.common.get_s3_resource") as mock_s3: mock_s3.return_value.Bucket.return_value.objects.filter.return_value = [ not_matching_summary, matching_summary, ] - removed = utils.delete_s3_objects_matching_metadata( - "request_id", s3_csv_path, metadata_key=metadata_key, metadata_value_check=metadata_value - ) - self.assertListEqual(removed, [{"Key": expected_key}]) + with patch("boto3.client") as mock_s3_client: + mock_s3_client.return_value.head_object.side_effect = mock_head_object + removed = utils.delete_s3_objects_matching_metadata( + "request_id", s3_csv_path, metadata_key=metadata_key, metadata_value_check=metadata_value + ) + self.assertListEqual(removed, [{"Key": expected_key}]) with patch("masu.util.aws.common.get_s3_resource") as mock_s3: - client_error_object = Mock() - client_error_object.Object.side_effect = ClientError({}, "Error") - mock_s3.return_value.Bucket.return_value.objects.filter.return_value = [ - not_matching_summary, - client_error_object, - ] - removed = utils.delete_s3_objects_matching_metadata( - "request_id", s3_csv_path, metadata_key=metadata_key, metadata_value_check=metadata_value - ) - self.assertListEqual(removed, []) + client_error_summary = Mock() + client_error_summary.key = not_expected_key + with patch("boto3.client") as mock_s3_client: + mock_s3_client.return_value.head_object.side_effect = ClientError({}, "Error") + mock_s3.return_value.Bucket.return_value.objects.filter.return_value = [ + not_matching_summary, + client_error_summary, + ] + removed = utils.delete_s3_objects_matching_metadata( + "request_id", s3_csv_path, metadata_key=metadata_key, metadata_value_check=metadata_value + ) + self.assertListEqual(removed, []) with patch("masu.util.aws.common.get_s3_objects_matching_metadata") as mock_get_objects, patch( "masu.util.aws.common.get_s3_resource" diff --git a/koku/masu/util/aws/common.py b/koku/masu/util/aws/common.py index 1f4df87815..aaf9791e08 100644 --- a/koku/masu/util/aws/common.py +++ b/koku/masu/util/aws/common.py @@ -715,12 +715,18 @@ def get_s3_objects_matching_metadata( if context is None: context = {} try: + s3_client = boto3.client( + "s3", + aws_access_key_id=settings.S3_ACCESS_KEY, + aws_secret_access_key=settings.S3_SECRET, + region_name=settings.S3_REGION, + ) keys = [] for obj_summary in _get_s3_objects(s3_path): - existing_object = obj_summary.Object() - metadata_value = existing_object.metadata.get(metadata_key) + response = s3_client.head_object(Bucket=obj_summary.bucket_name, Key=obj_summary.key) + metadata_value = response["Metadata"].get(metadata_key) if metadata_value == metadata_value_check: - keys.append(existing_object.key) + keys.append(obj_summary.key) return keys except (EndpointConnectionError, ClientError) as err: LOG.warning( @@ -743,12 +749,18 @@ def get_s3_objects_not_matching_metadata( if context is None: context = {} try: + s3_client = boto3.client( + "s3", + aws_access_key_id=settings.S3_ACCESS_KEY, + aws_secret_access_key=settings.S3_SECRET, + region_name=settings.S3_REGION, + ) keys = [] for obj_summary in _get_s3_objects(s3_path): - existing_object = obj_summary.Object() - metadata_value = existing_object.metadata.get(metadata_key) + response = s3_client.head_object(Bucket=obj_summary.bucket_name, Key=obj_summary.key) + metadata_value = response["Metadata"].get(metadata_key) if metadata_value != metadata_value_check: - keys.append(existing_object.key) + keys.append(obj_summary.key) return keys except (EndpointConnectionError, ClientError) as err: LOG.warning(