From ca11b7af239c90f6d7c00ed5116fdd08fda1a5ab Mon Sep 17 00:00:00 2001
From: "silviu.surcica"
Date: Wed, 6 Mar 2024 18:59:31 +0200
Subject: [PATCH] Fix BigQueryTablePartitionExistenceTrigger partition query (#37655)

* use table_id for partition query

* tests

* fix mock

* fix mock;

* fix mock

---------

Co-authored-by: Silviu-Surcica
---
 .../providers/google/cloud/hooks/bigquery.py  |  4 ++-
 .../google/cloud/triggers/bigquery.py         |  4 ++-
 .../google/cloud/hooks/test_bigquery.py       | 36 +++++++++++++++++++
 3 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 06e59c2b20f3f..20b1ac10c8021 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -3311,6 +3311,7 @@ async def get_job_output(
     async def create_job_for_partition_get(
         self,
         dataset_id: str | None,
+        table_id: str | None = None,
         project_id: str | None = None,
     ):
         """Create a new job and get the job_id using gcloud-aio."""
@@ -3320,7 +3321,8 @@ async def create_job_for_partition_get(
 
             query_request = {
                 "query": "SELECT partition_id "
-                f"FROM `{project_id}.{dataset_id}.INFORMATION_SCHEMA.PARTITIONS`",
+                f"FROM `{project_id}.{dataset_id}.INFORMATION_SCHEMA.PARTITIONS`"
+                + (f" WHERE table_id={table_id}" if table_id else ""),
                 "useLegacySql": False,
             }
             job_query_resp = await job_client.query(query_request, cast(Session, session))
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index bc9e812d1b28c..302316e4ae581 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -681,7 +681,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
                     await asyncio.sleep(self.poll_interval)
 
             else:
-                job_id = await hook.create_job_for_partition_get(self.dataset_id, project_id=self.project_id)
+                job_id = await hook.create_job_for_partition_get(
+                    self.dataset_id, table_id=self.table_id, project_id=self.project_id
+                )
                 self.log.info("Sleeping for %s seconds.", self.poll_interval)
                 await asyncio.sleep(self.poll_interval)
 
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 5ca34b276f21c..0214b382390ff 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2181,6 +2181,42 @@ async def test_get_job_output_assert_once_with(self, mock_job_instance):
         resp = await hook.get_job_output(job_id=JOB_ID, project_id=PROJECT_ID)
         assert resp == response
 
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.ClientSession")
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
+    async def test_create_job_for_partition_get_with_table(self, mock_job_instance, mock_client_session):
+        hook = BigQueryAsyncHook()
+        mock_job_client = AsyncMock(Job)
+        mock_job_instance.return_value = mock_job_client
+        mock_session = AsyncMock()
+        mock_client_session.return_value.__aenter__.return_value = mock_session
+        expected_query_request = {
+            "query": "SELECT partition_id "
+            f"FROM `{PROJECT_ID}.{DATASET_ID}.INFORMATION_SCHEMA.PARTITIONS`"
+            f" WHERE table_id={TABLE_ID}",
+            "useLegacySql": False,
+        }
+        await hook.create_job_for_partition_get(
+            dataset_id=DATASET_ID, table_id=TABLE_ID, project_id=PROJECT_ID
+        )
+        mock_job_client.query.assert_called_once_with(expected_query_request, mock_session)
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.ClientSession")
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
+    async def test_create_job_for_partition_get(self, mock_job_instance, mock_client_session):
+        hook = BigQueryAsyncHook()
+        mock_job_client = AsyncMock(Job)
+        mock_job_instance.return_value = mock_job_client
+        mock_session = AsyncMock()
+        mock_client_session.return_value.__aenter__.return_value = mock_session
+        expected_query_request = {
+            "query": f"SELECT partition_id FROM `{PROJECT_ID}.{DATASET_ID}.INFORMATION_SCHEMA.PARTITIONS`",
+            "useLegacySql": False,
+        }
+        await hook.create_job_for_partition_get(dataset_id=DATASET_ID, project_id=PROJECT_ID)
+        mock_job_client.query.assert_called_once_with(expected_query_request, mock_session)
+
     def test_interval_check_for_airflow_exception(self):
         """
         Assert that check return AirflowException
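
Note (not part of the patch): the hook change above builds the partition query by string
concatenation, appending a WHERE clause only when a table_id is supplied. A minimal
standalone sketch of that construction, using hypothetical project/dataset/table names, is:

    # Illustrative sketch only; it mirrors the concatenation added in the hook hunk.
    # "my-project", "my_dataset" and "my_table" are hypothetical placeholders.
    project_id = "my-project"
    dataset_id = "my_dataset"
    table_id = "my_table"

    query = (
        "SELECT partition_id "
        f"FROM `{project_id}.{dataset_id}.INFORMATION_SCHEMA.PARTITIONS`"
        + (f" WHERE table_id={table_id}" if table_id else "")
    )
    print(query)
    # SELECT partition_id FROM `my-project.my_dataset.INFORMATION_SCHEMA.PARTITIONS` WHERE table_id=my_table

When table_id is None or empty, the conditional appends nothing and the query lists
partitions for the whole dataset, which is exactly what the second test
(test_create_job_for_partition_get) asserts.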