diff --git a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py index 41992d98ff24a..98c370e37555d 100644 --- a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py +++ b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py @@ -78,6 +78,7 @@ LABEL_DAG_ID = "dag_id" LABEL_LOGICAL_DATE = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" LABEL_TRY_NUMBER = "try_number" +LABEL_RUN_ID = "run_id" @attrs.define(kw_only=True) @@ -197,8 +198,20 @@ def upload(self, path: os.PathLike | str, ti: RuntimeTI | None = None) -> None: shutil.rmtree(parent, ignore_errors=True) def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse: - """Read logs from Stackdriver Logging using task instance labels.""" - ti_labels = _task_instance_to_labels(ti) + """Read logs from Stackdriver Logging using task instance labels. + + Filters on ``run_id`` instead of ``logical_date`` because the supervisor + process that hosts ``REMOTE_TASK_LOG`` has no DB connection to convert + ``run_id`` → ``logical_date``. The write path (Bug 1 / #68246) already + writes ``run_id`` as a label, so the read filter matches what was actually + written. + """ + ti_labels = { + LABEL_DAG_ID: ti.dag_id, + LABEL_TASK_ID: ti.task_id, + LABEL_RUN_ID: ti.run_id, + LABEL_TRY_NUMBER: str(ti.try_number), + } log_filter = self.prepare_log_filter(ti_labels) messages, end_of_log, _ = self.read_logs(log_filter, next_page_token=None, all_pages=True) return [f"Reading remote log from Stackdriver for {relative_path}"], [messages] if messages else [] diff --git a/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py index 0eb6c209f8fdb..c805bd866e427 100644 --- a/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py +++ b/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py @@ -78,10 +78,7 @@ def test_read_logs(self, mock_client, mock_get_creds_and_project_id): ti.task_id = "test_task" ti.dag_id = "test_dag" ti.try_number = 1 - if AIRFLOW_V_3_0_PLUS: - ti.logical_date = timezone.datetime(2016, 1, 1) - else: - ti.execution_date = timezone.datetime(2016, 1, 1) + ti.run_id = "run1" messages, logs = self.io.read("dag_id=test_dag/run_id=run1/task_id=test_task/attempt=1.log", ti) @@ -101,16 +98,45 @@ def test_read_logs_empty(self, mock_client, mock_get_creds_and_project_id): ti.task_id = "test_task" ti.dag_id = "test_dag" ti.try_number = 1 - if AIRFLOW_V_3_0_PLUS: - ti.logical_date = timezone.datetime(2016, 1, 1) - else: - ti.execution_date = timezone.datetime(2016, 1, 1) + ti.run_id = "run1" messages, logs = self.io.read("test/path", ti) assert len(messages) == 1 assert logs == [] + @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id") + @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client") + def test_read_logs_uses_run_id_filter(self, mock_client, mock_get_creds_and_project_id): + """``read()`` must filter on ``run_id``, not ``logical_date``. + + In AF3's supervisor model the REMOTE_TASK_LOG handler runs in the + supervisor process which has no DB connection to derive + ``logical_date`` from ``run_id``. The read path must use ``run_id`` + directly so it matches the label the write path already emits (Bug 1). + """ + mock_client.return_value.list_log_entries.return_value.pages = iter( + [_create_list_log_entries_response_mock(["MSG1"], None)] + ) + mock_get_creds_and_project_id.return_value = ("creds", "project_id") + + ti = mock.MagicMock() + ti.task_id = "t" + ti.dag_id = "d" + ti.run_id = "run123" + ti.try_number = 2 + + messages, logs = self.io.read("dag_id=d/run_id=run123/task_id=t/attempt=2.log", ti) + + request = mock_client.return_value.list_log_entries.call_args.kwargs["request"] + assert 'labels.run_id="run123"' in request.filter + assert 'labels.try_number="2"' in request.filter + assert 'labels.task_id="t"' in request.filter + assert 'labels.dag_id="d"' in request.filter + assert "logical_date" not in request.filter + assert "execution_date" not in request.filter + assert logs == ["MSG1"] + @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id") @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client") def test_credentials(self, mock_client, mock_get_creds_and_project_id):