diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py
index 79f25d88c156a0..d3b94d3808240f 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py
@@ -49,6 +49,7 @@
 from datahub.ingestion.source.state.profiling_state_handler import ProfilingHandler
 from datahub.ingestion.source.state.redundant_run_skip_handler import (
     RedundantLineageRunSkipHandler,
+    RedundantQueriesRunSkipHandler,
     RedundantUsageRunSkipHandler,
 )
 from datahub.ingestion.source.state.stale_entity_removal_handler import (
@@ -145,7 +146,10 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config):
         redundant_lineage_run_skip_handler: Optional[RedundantLineageRunSkipHandler] = (
             None
         )
-        if self.config.enable_stateful_lineage_ingestion:
+        if (
+            self.config.enable_stateful_lineage_ingestion
+            and not self.config.use_queries_v2
+        ):
             redundant_lineage_run_skip_handler = RedundantLineageRunSkipHandler(
                 source=self,
                 config=self.config,
@@ -296,6 +300,17 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
         ):
             return
 
+        redundant_queries_run_skip_handler: Optional[
+            RedundantQueriesRunSkipHandler
+        ] = None
+        if self.config.enable_stateful_time_window:
+            redundant_queries_run_skip_handler = RedundantQueriesRunSkipHandler(
+                source=self,
+                config=self.config,
+                pipeline_name=self.ctx.pipeline_name,
+                run_id=self.ctx.run_id,
+            )
+
         with (
             self.report.new_stage(f"*: {QUERIES_EXTRACTION}"),
             BigQueryQueriesExtractor(
@@ -315,6 +330,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
                 structured_report=self.report,
                 filters=self.filters,
                 identifiers=self.identifiers,
+                redundant_run_skip_handler=redundant_queries_run_skip_handler,
                 schema_resolver=self.sql_parser_schema_resolver,
                 discovered_tables=self.bq_schema_extractor.table_refs,
             ) as queries_extractor,
diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py
index 67710a6a5a2de1..314076b8ad6207 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py
@@ -25,6 +25,7 @@
 from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulLineageConfigMixin,
     StatefulProfilingConfigMixin,
+    StatefulTimeWindowConfigMixin,
     StatefulUsageConfigMixin,
 )
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
@@ -271,6 +272,7 @@ class BigQueryV2Config(
     SQLCommonConfig,
     StatefulUsageConfigMixin,
     StatefulLineageConfigMixin,
+    StatefulTimeWindowConfigMixin,
     StatefulProfilingConfigMixin,
     ClassificationSourceConfigMixin,
 ):
@@ -527,6 +529,20 @@ def validate_upstream_lineage_in_report(cls, v: bool, values: Dict) -> bool:
 
         return v
 
+    @root_validator(pre=False, skip_on_failure=True)
+    def validate_queries_v2_stateful_ingestion(cls, values: Dict) -> Dict:
+        if values.get("use_queries_v2"):
+            if values.get("enable_stateful_lineage_ingestion") or values.get(
+                "enable_stateful_usage_ingestion"
+            ):
+                logger.warning(
+                    "enable_stateful_lineage_ingestion and enable_stateful_usage_ingestion are deprecated "
+                    "when using use_queries_v2=True. These configs only work with the legacy (non-queries v2) extraction path. "
+                    "For queries v2, use enable_stateful_time_window instead to enable stateful ingestion "
+                    "for the unified time window extraction (lineage + usage + operations + queries)."
+                )
+        return values
+
     def get_table_pattern(self, pattern: List[str]) -> str:
         return "|".join(pattern) if pattern else ""
 
diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py
index aaa6d75ca96e0f..8097890d38afcf 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py
@@ -36,6 +36,9 @@
     BigQueryFilter,
     BigQueryIdentifierBuilder,
 )
+from datahub.ingestion.source.state.redundant_run_skip_handler import (
+    RedundantQueriesRunSkipHandler,
+)
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
 from datahub.metadata.urns import CorpUserUrn
 from datahub.sql_parsing.schema_resolver import SchemaResolver
@@ -135,6 +138,7 @@ def __init__(
         structured_report: SourceReport,
         filters: BigQueryFilter,
         identifiers: BigQueryIdentifierBuilder,
+        redundant_run_skip_handler: Optional[RedundantQueriesRunSkipHandler] = None,
         graph: Optional[DataHubGraph] = None,
         schema_resolver: Optional[SchemaResolver] = None,
         discovered_tables: Optional[Collection[str]] = None,
@@ -158,6 +162,9 @@ def __init__(
         )
 
         self.structured_report = structured_report
+        self.redundant_run_skip_handler = redundant_run_skip_handler
+
+        self.start_time, self.end_time = self._get_time_window()
 
         self.aggregator = SqlParsingAggregator(
             platform=self.identifiers.platform,
@@ -172,8 +179,8 @@ def __init__(
             generate_query_usage_statistics=self.config.include_query_usage_statistics,
             usage_config=BaseUsageConfig(
                 bucket_duration=self.config.window.bucket_duration,
-                start_time=self.config.window.start_time,
-                end_time=self.config.window.end_time,
+                start_time=self.start_time,
+                end_time=self.end_time,
                 user_email_pattern=self.config.user_email_pattern,
                 top_n_queries=self.config.top_n_queries,
             ),
@@ -199,6 +206,34 @@ def local_temp_path(self) -> pathlib.Path:
         logger.info(f"Using local temp path: {path}")
         return path
 
+    def _get_time_window(self) -> tuple[datetime, datetime]:
+        if self.redundant_run_skip_handler:
+            start_time, end_time = (
+                self.redundant_run_skip_handler.suggest_run_time_window(
+                    self.config.window.start_time,
+                    self.config.window.end_time,
+                )
+            )
+        else:
+            start_time = self.config.window.start_time
+            end_time = self.config.window.end_time
+
+        # Usage statistics are aggregated per bucket (typically per day).
+        # To ensure accurate aggregated metrics, we need to align the start_time
+        # to the beginning of a bucket so that we include complete bucket periods.
+        if self.config.include_usage_statistics:
+            start_time = get_time_bucket(start_time, self.config.window.bucket_duration)
+
+        return start_time, end_time
+
+    def _update_state(self) -> None:
+        if self.redundant_run_skip_handler:
+            self.redundant_run_skip_handler.update_state(
+                self.config.window.start_time,
+                self.config.window.end_time,
+                self.config.window.bucket_duration,
+            )
+
     def is_temp_table(self, name: str) -> bool:
         try:
             table = BigqueryTableIdentifier.from_string_name(name)
@@ -299,6 +334,8 @@ def get_workunits_internal(
             shared_connection.close()
             audit_log_file.unlink(missing_ok=True)
 
+        self._update_state()
+
     def deduplicate_queries(
         self, queries: FileBackedList[ObservedQuery]
     ) -> FileBackedDict[Dict[int, ObservedQuery]]:
@@ -355,8 +392,8 @@ def fetch_region_query_log(
         query_log_query = _build_enriched_query_log_query(
             project_id=project.id,
             region=region,
-            start_time=self.config.window.start_time,
-            end_time=self.config.window.end_time,
+            start_time=self.start_time,
+            end_time=self.end_time,
         )
 
         logger.info(f"Fetching query log from BQ Project {project.id} for {region}")
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
index 566e044057ced6..03ef7afbce713e 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
@@ -31,6 +31,7 @@
 from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulLineageConfigMixin,
     StatefulProfilingConfigMixin,
+    StatefulTimeWindowConfigMixin,
     StatefulUsageConfigMixin,
 )
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
@@ -199,6 +200,7 @@ class SnowflakeV2Config(
     SnowflakeUsageConfig,
     StatefulLineageConfigMixin,
     StatefulUsageConfigMixin,
+    StatefulTimeWindowConfigMixin,
     StatefulProfilingConfigMixin,
     ClassificationSourceConfigMixin,
     IncrementalPropertiesConfigMixin,
@@ -477,6 +479,20 @@ def validate_shares(
 
         return shares
 
+    @root_validator(pre=False, skip_on_failure=True)
+    def validate_queries_v2_stateful_ingestion(cls, values: Dict) -> Dict:
+        if values.get("use_queries_v2"):
+            if values.get("enable_stateful_lineage_ingestion") or values.get(
+                "enable_stateful_usage_ingestion"
+            ):
+                logger.warning(
+                    "enable_stateful_lineage_ingestion and enable_stateful_usage_ingestion are deprecated "
+                    "when using use_queries_v2=True. These configs only work with the legacy (non-queries v2) extraction path. "
+                    "For queries v2, use enable_stateful_time_window instead to enable stateful ingestion "
+                    "for the unified time window extraction (lineage + usage + operations + queries)."
+                )
+        return values
+
     def outbounds(self) -> Dict[str, Set[DatabaseId]]:
         """
         Returns mapping of
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py
index 3ea5daa5ce9184..fceb76166bb7eb 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py
@@ -17,6 +17,7 @@
 from datahub.configuration.time_window_config import (
     BaseTimeWindowConfig,
     BucketDuration,
+    get_time_bucket,
 )
 from datahub.ingestion.api.closeable import Closeable
 from datahub.ingestion.api.common import PipelineContext
@@ -50,6 +51,9 @@
     StoredProcLineageReport,
     StoredProcLineageTracker,
 )
+from datahub.ingestion.source.state.redundant_run_skip_handler import (
+    RedundantQueriesRunSkipHandler,
+)
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
 from datahub.metadata.urns import CorpUserUrn
 from datahub.sql_parsing.schema_resolver import SchemaResolver
@@ -180,6 +184,7 @@ def __init__(
         structured_report: SourceReport,
         filters: SnowflakeFilter,
         identifiers: SnowflakeIdentifierBuilder,
+        redundant_run_skip_handler: Optional[RedundantQueriesRunSkipHandler] = None,
        graph: Optional[DataHubGraph] = None,
         schema_resolver: Optional[SchemaResolver] = None,
         discovered_tables: Optional[List[str]] = None,
@@ -191,9 +196,13 @@ def __init__(
         self.filters = filters
         self.identifiers = identifiers
         self.discovered_tables = set(discovered_tables) if discovered_tables else None
+        self.redundant_run_skip_handler = redundant_run_skip_handler
 
         self._structured_report = structured_report
 
+        # Adjust time window based on stateful ingestion state
+        self.start_time, self.end_time = self._get_time_window()
+
         # The exit stack helps ensure that we close all the resources we open.
         self._exit_stack = contextlib.ExitStack()
 
@@ -211,8 +220,8 @@ def __init__(
             generate_query_usage_statistics=self.config.include_query_usage_statistics,
             usage_config=BaseUsageConfig(
                 bucket_duration=self.config.window.bucket_duration,
-                start_time=self.config.window.start_time,
-                end_time=self.config.window.end_time,
+                start_time=self.start_time,
+                end_time=self.end_time,
                 user_email_pattern=self.config.user_email_pattern,
                 # TODO make the rest of the fields configurable
             ),
@@ -228,6 +237,34 @@ def __init__(
     def structured_reporter(self) -> SourceReport:
         return self._structured_report
 
+    def _get_time_window(self) -> tuple[datetime, datetime]:
+        if self.redundant_run_skip_handler:
+            start_time, end_time = (
+                self.redundant_run_skip_handler.suggest_run_time_window(
+                    self.config.window.start_time,
+                    self.config.window.end_time,
+                )
+            )
+        else:
+            start_time = self.config.window.start_time
+            end_time = self.config.window.end_time
+
+        # Usage statistics are aggregated per bucket (typically per day).
+        # To ensure accurate aggregated metrics, we need to align the start_time
+        # to the beginning of a bucket so that we include complete bucket periods.
+        if self.config.include_usage_statistics:
+            start_time = get_time_bucket(start_time, self.config.window.bucket_duration)
+
+        return start_time, end_time
+
+    def _update_state(self) -> None:
+        if self.redundant_run_skip_handler:
+            self.redundant_run_skip_handler.update_state(
+                self.config.window.start_time,
+                self.config.window.end_time,
+                self.config.window.bucket_duration,
+            )
+
     @functools.cached_property
     def local_temp_path(self) -> pathlib.Path:
         if self.config.local_temp_path:
@@ -355,6 +392,9 @@ def get_workunits_internal(
         with self.report.aggregator_generate_timer:
             yield from auto_workunit(self.aggregator.gen_metadata())
 
+        # Update the stateful ingestion state after successful extraction
+        self._update_state()
+
     def fetch_users(self) -> UsersMapping:
         users: UsersMapping = dict()
         with self.structured_reporter.report_exc("Error fetching users from Snowflake"):
@@ -378,8 +418,8 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:
 
         # Derived from _populate_external_lineage_from_copy_history.
         query: str = SnowflakeQuery.copy_lineage_history(
-            start_time_millis=int(self.config.window.start_time.timestamp() * 1000),
-            end_time_millis=int(self.config.window.end_time.timestamp() * 1000),
+            start_time_millis=int(self.start_time.timestamp() * 1000),
+            end_time_millis=int(self.end_time.timestamp() * 1000),
             downstreams_deny_pattern=self.config.temporary_tables_pattern,
         )
 
@@ -414,8 +454,8 @@ def fetch_query_log(
         Union[PreparsedQuery, TableRename, TableSwap, ObservedQuery, StoredProcCall]
     ]:
         query_log_query = QueryLogQueryBuilder(
-            start_time=self.config.window.start_time,
-            end_time=self.config.window.end_time,
+            start_time=self.start_time,
+            end_time=self.end_time,
             bucket_duration=self.config.window.bucket_duration,
             deny_usernames=self.config.pushdown_deny_usernames,
             allow_usernames=self.config.pushdown_allow_usernames,
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py
index f2161618cc7e98..f5558d1d7cb107 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py
@@ -73,6 +73,7 @@
 from datahub.ingestion.source.state.profiling_state_handler import ProfilingHandler
 from datahub.ingestion.source.state.redundant_run_skip_handler import (
     RedundantLineageRunSkipHandler,
+    RedundantQueriesRunSkipHandler,
     RedundantUsageRunSkipHandler,
 )
 from datahub.ingestion.source.state.stale_entity_removal_handler import (
@@ -207,7 +208,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config):
         )
         self.report.sql_aggregator = self.aggregator.report
 
-        if self.config.include_table_lineage:
+        if self.config.include_table_lineage and not self.config.use_queries_v2:
             redundant_lineage_run_skip_handler: Optional[
                 RedundantLineageRunSkipHandler
             ] = None
@@ -589,6 +590,17 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
         with self.report.new_stage(f"*: {QUERIES_EXTRACTION}"):
             schema_resolver = self.aggregator._schema_resolver
 
+            redundant_queries_run_skip_handler: Optional[
+                RedundantQueriesRunSkipHandler
+            ] = None
+            if self.config.enable_stateful_time_window:
+                redundant_queries_run_skip_handler = RedundantQueriesRunSkipHandler(
+                    source=self,
+                    config=self.config,
+                    pipeline_name=self.ctx.pipeline_name,
+                    run_id=self.ctx.run_id,
+                )
+
             queries_extractor = SnowflakeQueriesExtractor(
                 connection=self.connection,
                 # TODO: this should be its own section in main recipe
@@ -614,6 +626,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
                 structured_report=self.report,
                 filters=self.filters,
                 identifiers=self.identifiers,
+                redundant_run_skip_handler=redundant_queries_run_skip_handler,
                 schema_resolver=schema_resolver,
                 discovered_tables=self.discovered_datasets,
                 graph=self.ctx.graph,
diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py b/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py
index e4a2646f6ccd3c..8171a3aa8f8ad0 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py
@@ -244,3 +244,24 @@ def update_state(
         cur_state.begin_timestamp_millis = datetime_to_ts_millis(start_time)
         cur_state.end_timestamp_millis = datetime_to_ts_millis(end_time)
         cur_state.bucket_duration = bucket_duration
+
+
+class RedundantQueriesRunSkipHandler(RedundantRunSkipHandler):
+    """
+    Handler for stateful ingestion of queries v2 extraction.
+    Manages the time window for audit log extraction that combines
+    lineage, usage, operations, and queries.
+    """
+
+    def get_job_name_suffix(self):
+        return "_audit_window"
+
+    def update_state(
+        self, start_time: datetime, end_time: datetime, bucket_duration: BucketDuration
+    ) -> None:
+        cur_checkpoint = self.get_current_checkpoint()
+        if cur_checkpoint:
+            cur_state = cast(BaseTimeWindowCheckpointState, cur_checkpoint.state)
+            cur_state.begin_timestamp_millis = datetime_to_ts_millis(start_time)
+            cur_state.end_timestamp_millis = datetime_to_ts_millis(end_time)
+            cur_state.bucket_duration = bucket_duration
diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py
index 8d1743ee678fe8..2d31231eb53217 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py
@@ -101,7 +101,9 @@ class StatefulLineageConfigMixin(ConfigModel):
         default=True,
         description="Enable stateful lineage ingestion."
         " This will store lineage window timestamps after successful lineage ingestion. "
-        "and will not run lineage ingestion for same timestamps in subsequent run. ",
+        "and will not run lineage ingestion for same timestamps in subsequent run. "
+        "NOTE: This only works with use_queries_v2=False (legacy extraction path). "
+        "For queries v2, use enable_stateful_time_window instead.",
     )
 
     _store_last_lineage_extraction_timestamp = pydantic_renamed_field(
@@ -150,7 +152,9 @@ class StatefulUsageConfigMixin(BaseTimeWindowConfig):
         default=True,
         description="Enable stateful lineage ingestion."
         " This will store usage window timestamps after successful usage ingestion. "
-        "and will not run usage ingestion for same timestamps in subsequent run. ",
+        "and will not run usage ingestion for same timestamps in subsequent run. "
+        "NOTE: This only works with use_queries_v2=False (legacy extraction path). "
+        "For queries v2, use enable_stateful_time_window instead.",
     )
 
     _store_last_usage_extraction_timestamp = pydantic_renamed_field(
@@ -169,6 +173,30 @@ def last_usage_extraction_stateful_option_validator(cls, values: Dict) -> Dict:
         return values
 
 
+class StatefulTimeWindowConfigMixin(BaseTimeWindowConfig):
+    enable_stateful_time_window: bool = Field(
+        default=False,
+        description="Enable stateful time window tracking."
+        " This will store the time window after successful extraction "
+        "and adjust the time window in subsequent runs to avoid reprocessing. "
+        "NOTE: This is ONLY applicable when using queries v2 (use_queries_v2=True). "
+        "This replaces enable_stateful_lineage_ingestion and enable_stateful_usage_ingestion "
+        "for the queries v2 extraction path, since queries v2 extracts lineage, usage, operations, "
+        "and queries together from a single audit log and uses a unified time window.",
+    )
+
+    @root_validator(skip_on_failure=True)
+    def time_window_stateful_option_validator(cls, values: Dict) -> Dict:
+        sti = values.get("stateful_ingestion")
+        if not sti or not sti.enabled:
+            if values.get("enable_stateful_time_window"):
+                logger.warning(
+                    "Stateful ingestion is disabled, disabling enable_stateful_time_window config option as well"
+                )
+                values["enable_stateful_time_window"] = False
+        return values
+
+
 @dataclass
 class StatefulIngestionReport(SourceReport):
     pass
diff --git a/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries.py b/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries_integration.py
similarity index 100%
rename from metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries.py
rename to metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries_integration.py
diff --git a/metadata-ingestion/tests/unit/snowflake/test_snowflake_queries.py b/metadata-ingestion/tests/unit/snowflake/test_snowflake_queries.py
index 35e5de6ebbe973..010de18f1dacfa 100644
--- a/metadata-ingestion/tests/unit/snowflake/test_snowflake_queries.py
+++ b/metadata-ingestion/tests/unit/snowflake/test_snowflake_queries.py
@@ -1,4 +1,5 @@
 import datetime
+from typing import Optional
 from unittest.mock import Mock, patch
 
 import pytest
@@ -20,6 +21,9 @@
     SnowflakeQueriesExtractorConfig,
 )
 from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
+from datahub.ingestion.source.state.redundant_run_skip_handler import (
+    RedundantQueriesRunSkipHandler,
+)
 
 
 class TestBuildAccessHistoryDatabaseFilterCondition:
@@ -599,3 +603,219 @@ def test_report_counts_with_disabled_features(self):
         # Verify that num_preparsed_queries is 0
         assert extractor.report.sql_aggregator is not None
         assert extractor.report.sql_aggregator.num_preparsed_queries == 0
+
+
+class TestSnowflakeQueriesExtractorStatefulTimeWindowIngestion:
+    """Tests for stateful time window ingestion support in queries v2."""
+
+    def _create_mock_extractor(
+        self,
+        include_usage_statistics: bool = False,
+        redundant_run_skip_handler: Optional[RedundantQueriesRunSkipHandler] = None,
+        bucket_duration: BucketDuration = BucketDuration.DAY,
+    ) -> SnowflakeQueriesExtractor:
+        """Helper to create a SnowflakeQueriesExtractor with mocked dependencies."""
+        mock_connection = Mock()
+        mock_connection.query.return_value = []
+
+        config = SnowflakeQueriesExtractorConfig(
+            window=BaseTimeWindowConfig(
+                start_time=datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc),
+                end_time=datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc),
+                bucket_duration=bucket_duration,
+            ),
+            include_usage_statistics=include_usage_statistics,
+        )
+
+        mock_report = Mock()
+        mock_filters = Mock()
+        mock_identifiers = Mock()
+        mock_identifiers.platform = "snowflake"
+        mock_identifiers.identifier_config = SnowflakeIdentifierConfig()
+
+        extractor = SnowflakeQueriesExtractor(
+            connection=mock_connection,
+            config=config,
+            structured_report=mock_report,
+            filters=mock_filters,
+            identifiers=mock_identifiers,
+            redundant_run_skip_handler=redundant_run_skip_handler,
+        )
+
+        return extractor
+
+    def test_time_window_adjusted_with_handler(self):
+        """Test that time window is adjusted when handler is provided."""
+        adjusted_start_time = datetime.datetime(
+            2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc
+        )
+        adjusted_end_time = datetime.datetime(
+            2021, 1, 2, 12, 0, 0, tzinfo=datetime.timezone.utc
+        )
+
+        mock_handler = Mock(spec=RedundantQueriesRunSkipHandler)
+        mock_handler.suggest_run_time_window.return_value = (
+            adjusted_start_time,
+            adjusted_end_time,
+        )
+
+        extractor = self._create_mock_extractor(
+            redundant_run_skip_handler=mock_handler,
+        )
+
+        mock_handler.suggest_run_time_window.assert_called_once()
+        assert extractor.start_time == adjusted_start_time
+        assert extractor.end_time == adjusted_end_time
+
+    def test_time_window_not_adjusted_without_handler(self):
+        """Test that time window is not adjusted when no handler is provided."""
+        original_start_time = datetime.datetime(
+            2021, 1, 1, tzinfo=datetime.timezone.utc
+        )
+        original_end_time = datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc)
+
+        extractor = self._create_mock_extractor(
+            redundant_run_skip_handler=None,
+        )
+
+        assert extractor.start_time == original_start_time
+        assert extractor.end_time == original_end_time
+
+    def test_bucket_alignment_with_usage_statistics(self):
+        """Test that start_time is aligned to bucket boundaries when usage statistics are enabled."""
+        # Start time at 14:30 should be aligned to beginning of day (00:00)
+        start_time_with_offset = datetime.datetime(
+            2021, 1, 1, 14, 30, 0, tzinfo=datetime.timezone.utc
+        )
+        mock_handler = Mock(spec=RedundantQueriesRunSkipHandler)
+        mock_handler.suggest_run_time_window.return_value = (
+            start_time_with_offset,
+            datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc),
+        )
+
+        extractor = self._create_mock_extractor(
+            include_usage_statistics=True,
+            redundant_run_skip_handler=mock_handler,
+        )
+
+        # Start time should be aligned to beginning of day
+        expected_aligned_start = datetime.datetime(
+            2021, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc
+        )
+        assert extractor.start_time == expected_aligned_start
+        # End time should remain unchanged
+        assert extractor.end_time == datetime.datetime(
+            2021, 1, 2, tzinfo=datetime.timezone.utc
+        )
+
+    def test_no_bucket_alignment_without_usage_statistics(self):
+        """Test that start_time is NOT aligned when usage statistics are disabled."""
+        start_time_with_offset = datetime.datetime(
+            2021, 1, 1, 14, 30, 0, tzinfo=datetime.timezone.utc
+        )
+        mock_handler = Mock(spec=RedundantQueriesRunSkipHandler)
+        mock_handler.suggest_run_time_window.return_value = (
+            start_time_with_offset,
+            datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc),
+        )
+
+        extractor = self._create_mock_extractor(
+            include_usage_statistics=False,
+            redundant_run_skip_handler=mock_handler,
+        )
+
+        # Start time should NOT be aligned
+        assert extractor.start_time == start_time_with_offset
+        assert extractor.end_time == datetime.datetime(
+            2021, 1, 2, tzinfo=datetime.timezone.utc
+        )
+
+    def test_bucket_alignment_hourly_with_usage_statistics(self):
+        """Test that start_time is aligned to hour boundaries when hourly buckets are configured."""
+        start_time_with_offset = datetime.datetime(
+            2021, 1, 1, 14, 30, 45, tzinfo=datetime.timezone.utc
+        )
+        mock_handler = Mock(spec=RedundantQueriesRunSkipHandler)
+        mock_handler.suggest_run_time_window.return_value = (
+            start_time_with_offset,
+            datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc),
+        )
+
+        extractor = self._create_mock_extractor(
+            include_usage_statistics=True,
+            redundant_run_skip_handler=mock_handler,
+            bucket_duration=BucketDuration.HOUR,
+        )
+
+        expected_aligned_start = datetime.datetime(
+            2021, 1, 1, 14, 0, 0, tzinfo=datetime.timezone.utc
+        )
+        assert extractor.start_time == expected_aligned_start
+        assert extractor.end_time == datetime.datetime(
+            2021, 1, 2, tzinfo=datetime.timezone.utc
+        )
+
+    def test_state_updated_after_successful_extraction(self):
+        """Test that state is updated after successful extraction when handler is provided."""
+        mock_handler = Mock(spec=RedundantQueriesRunSkipHandler)
+        mock_handler.suggest_run_time_window.return_value = (
+            datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc),
+            datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc),
+        )
+
+        extractor = self._create_mock_extractor(
+            redundant_run_skip_handler=mock_handler,
+        )
+
+        with (
+            patch.object(extractor, "fetch_users", return_value={}),
+            patch.object(extractor, "fetch_copy_history", return_value=[]),
+            patch.object(extractor, "fetch_query_log", return_value=[]),
+        ):
+            list(extractor.get_workunits_internal())
+
+        mock_handler.update_state.assert_called_once_with(
+            datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc),
+            datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc),
+            BucketDuration.DAY,
+        )
+
+    def test_state_not_updated_without_handler(self):
+        """Test that state is not updated when no handler is provided."""
+        extractor = self._create_mock_extractor(
+            redundant_run_skip_handler=None,
+        )
+
+        with (
+            patch.object(extractor, "fetch_users", return_value={}),
+            patch.object(extractor, "fetch_copy_history", return_value=[]),
+            patch.object(extractor, "fetch_query_log", return_value=[]),
+        ):
+            list(extractor.get_workunits_internal())
+
+    def test_queries_extraction_always_runs_with_handler(self):
+        """Test that queries extraction always runs even with a skip handler."""
+        mock_handler = Mock(spec=RedundantQueriesRunSkipHandler)
+        mock_handler.suggest_run_time_window.return_value = (
+            datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc),
+            datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc),
+        )
+
+        extractor = self._create_mock_extractor(
+            redundant_run_skip_handler=mock_handler,
+        )
+
+        with (
+            patch.object(extractor, "fetch_users", return_value={}) as mock_fetch_users,
+            patch.object(
+                extractor, "fetch_copy_history", return_value=[]
+            ) as mock_fetch_copy_history,
+            patch.object(
+                extractor, "fetch_query_log", return_value=[]
+            ) as mock_fetch_query_log,
+        ):
+            list(extractor.get_workunits_internal())
+
+        mock_fetch_users.assert_called_once()
+        mock_fetch_copy_history.assert_called_once()
+        mock_fetch_query_log.assert_called_once()
diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_redundant_run_skip_handler.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_redundant_run_skip_handler.py
index bfd85a9c2eaed0..da42a460674a52 100644
--- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_redundant_run_skip_handler.py
+++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_redundant_run_skip_handler.py
@@ -31,6 +31,7 @@ def stateful_source(mock_datahub_graph: DataHubGraph) -> Iterable[SnowflakeV2Sou
         account_id="ABC12345.ap-south-1",
         username="TST_USR",
         password="TST_PWD",
+        use_queries_v2=False,  # Use legacy path for testing redundant run skip handlers
         stateful_ingestion=StatefulStaleMetadataRemovalConfig(
             enabled=True,
             # Uses the graph from the pipeline context.
@@ -159,11 +160,6 @@ def test_redundant_run_skip_handler(
     suggested_start_time: datetime,
     suggested_end_time: datetime,
 ) -> None:
-    # mock_datahub_graph
-
-    # mocked_source = mock.MagicMock()
-    # mocked_config = mock.MagicMock()
-
     with mock.patch(
         "datahub.ingestion.source.state.stateful_ingestion_base.StateProviderWrapper.get_last_checkpoint"
     ) as mocked_fn:
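For illustration only (not part of the patch): a minimal, self-contained sketch of the bucket alignment that _get_time_window() applies when usage statistics are enabled. The suggested start below is a made-up value standing in for whatever suggest_run_time_window would return; get_time_bucket and BucketDuration are the same names this patch imports from datahub.configuration.time_window_config.

    # Hypothetical example; datetimes are invented for the sketch.
    from datetime import datetime, timezone

    from datahub.configuration.time_window_config import BucketDuration, get_time_bucket

    # Suppose the skip handler suggests resuming mid-day at 14:30 UTC.
    suggested_start = datetime(2021, 1, 1, 14, 30, tzinfo=timezone.utc)

    # With include_usage_statistics=True, the start is snapped back to the
    # beginning of its DAY bucket so per-bucket usage aggregates stay complete.
    aligned_start = get_time_bucket(suggested_start, BucketDuration.DAY)
    assert aligned_start == datetime(2021, 1, 1, tzinfo=timezone.utc)

The end of the window is left untouched, matching the behavior asserted in test_bucket_alignment_with_usage_statistics above.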