From 1cfa15a601c950c19cceab4965b8aa04e583de1f Mon Sep 17 00:00:00 2001
From: Grzegorz Rusin
Date: Wed, 19 Nov 2025 08:03:45 +0100
Subject: [PATCH 1/2] row_count function

---
 src/databricks/labs/dqx/check_funcs.py | 81 ++++++++++++++++++++++++++
 src/databricks/labs/dqx/utils.py       |  9 +++
 2 files changed, 90 insertions(+)

diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py
index 5ab24df1f..e1679a383 100644
--- a/src/databricks/labs/dqx/check_funcs.py
+++ b/src/databricks/labs/dqx/check_funcs.py
@@ -1,8 +1,12 @@
 import datetime
 import re
+import logging
 import warnings
 import ipaddress
 import uuid
+from typing import Any
+from functools import partial
+from databricks.labs.blueprint.parallel import Threads, ManyError
 from collections.abc import Callable
 from enum import Enum
 from itertools import zip_longest
@@ -19,9 +23,12 @@
     normalize_col_str,
     get_columns_as_strings,
     to_lowercase,
+    strip_jvm_stacktrace,
 )
 from databricks.labs.dqx.errors import MissingParameterError, InvalidParameterError, UnsafeSqlQueryError
 
+logger = logging.getLogger(__name__)
+
 _IPV4_OCTET = r"(25[0-5]|2[0-4]\d|1\d{2}|[1-9]?\d)"
 _IPV4_CIDR_SUFFIX = r"(3[0-2]|[12]?\d)"
 IPV4_MAX_OCTET_COUNT = 4
@@ -2705,3 +2712,77 @@ def _validate_sql_query_params(query: str, merge_columns: list[str]) -> None:
         raise UnsafeSqlQueryError(
             "Provided SQL query is not safe for execution. Please ensure it does not contain any unsafe operations."
         )
+
+@register_rule("dataset")
+def row_count(table_expr: str,
+              row_count_column: str = "row_count",
+              row_count_error_column: str = "row_count_error",
+              worker_count: int = 8):
+    """
+    Computes row counts for the tables listed in the DataFrame's records.
+
+    Each row of the DataFrame is expected to contain a table name. The computation is done in parallel on the specified number of workers.
+
+    Args:
+        table_expr: SQL expression used to derive the fully qualified table name. For example:
+            - `table_name` when there is only one column in the dataframe with table names.
+            - `catalog_name || '.' || schema_name || '.' || table_name` when there are 3 columns in the dataframe: catalog_name, schema_name, table_name.
+        row_count_column: Name of the column that will contain the row count. Defaults to "row_count".
+        worker_count: Number of workers to use for the computation. Defaults to 8.
+
+    Note:
+        - The operation collects all table names from the input DataFrame and computes row counts for each table in parallel.
+        - Ensure that the input DataFrame listing table names is not too large to avoid driver memory issues.
+        - In case of a major driver failure, the operation loses all computed data and needs to be restarted.
+
+    Returns:
+        Input DataFrame enriched with row counts.
+
+        If an error occurs while computing a row count, the error column (`row_count_error_column`) will contain the error message.
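+
+    Example (illustrative sketch, not part of the official docs; `tables_df` is a hypothetical DataFrame
+    with a single `table_name` column and `spark` is the active session):
+        condition, apply_fn = row_count(table_expr="table_name")
+        enriched_df = apply_fn(tables_df, spark, {})  # adds row_count / row_count_error columns
+        enriched_df.select("table_name", "row_count", "row_count_error").show()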
+    """
+
+    def compute_row_counts(spark: SparkSession, table_name: str) -> dict[str, Any]:
+        try:
+            logger.debug(f"Computing row count for table: {table_name!r}")
+            cnt = spark.table(table_name).count()
+            logger.info(f"Row count for table {table_name!r}: {cnt}")
+            return {'__table_name': table_name, row_count_column: cnt}
+        except Exception as e:
+            logger.warning(f"Error computing row count for table {table_name!r}: {strip_jvm_stacktrace(e)}")
+            # store the stripped error message (a string) so it fits the StringType error column
+            return {'__table_name': table_name, row_count_error_column: strip_jvm_stacktrace(e)}
+
+    def apply(df: DataFrame, spark: SparkSession, ref_dfs: dict[str, DataFrame]) -> DataFrame:
+        df = df.withColumn('__table_name', F.expr(table_expr))
+
+        table_names = [row['__table_name'] for row in df.select('__table_name').distinct().collect()]
+
+        tasks = [partial(compute_row_counts, spark, table_name) for table_name in table_names]
+        results, bad = Threads.gather("row_count", tasks, worker_count)
+
+        # bad should normally be empty, as the underlying function does not raise exceptions;
+        # Spark session problems end up as error entries in results rather than in bad.
+        # Failures here can only happen if there is a problem with the Python process itself.
+        if bad:
+            raise ManyError(bad)
+
+        df_counts = spark.createDataFrame(
+            results,
+            types.StructType([
+                types.StructField("__table_name", types.StringType(), False),
+                types.StructField(row_count_column, types.LongType(), True),
+                types.StructField(row_count_error_column, types.StringType(), True)
+            ]))
+
+        final_df = df.alias("df").join(df_counts.alias("counts"), on="__table_name", how="left")
+
+        final_df = final_df.select("df.*", f"counts.{row_count_column}", f"counts.{row_count_error_column}")
+
+        return final_df
+
+    condition = make_condition(
+        F.col(row_count_error_column).isNotNull(),
+        F.concat_ws("", F.lit("Error computing row count for table: "), F.col(row_count_error_column)),
+        f"{row_count_error_column}_error",
+    )
+
+    return condition, apply
diff --git a/src/databricks/labs/dqx/utils.py b/src/databricks/labs/dqx/utils.py
index 3ec8d0b58..a384d1017 100644
--- a/src/databricks/labs/dqx/utils.py
+++ b/src/databricks/labs/dqx/utils.py
@@ -448,3 +448,12 @@ def to_lowercase(col_expr: Column, is_array: bool = False) -> Column:
     if is_array:
         return F.transform(col_expr, F.lower)
     return F.lower(col_expr)
+
+
+def strip_jvm_stacktrace(exception: Exception) -> str:
+    """Returns the exception message with the 'JVM stacktrace:' section stripped."""
+    s = str(exception)
+    stack_idx = s.find("JVM stacktrace:")
+    if stack_idx != -1:
+        return s[:stack_idx].rstrip()
+    return s
\ No newline at end of file

From 462eafe5704a275d962de3e23be50e2171740be7 Mon Sep 17 00:00:00 2001
From: Grzegorz Rusin
Date: Thu, 20 Nov 2025 13:41:16 +0100
Subject: [PATCH 2/2] docs update

---
 docs/dqx/docs/reference/quality_checks.mdx |  39 +++++
 src/databricks/labs/dqx/check_funcs.py     |   1 +
 tests/integration/test_dataset_checks.py   | 193 +++++++++++++++++++++
 3 files changed, 233 insertions(+)

diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx
index 4587eab77..c1a9e478b 100644
--- a/docs/dqx/docs/reference/quality_checks.mdx
+++ b/docs/dqx/docs/reference/quality_checks.mdx
@@ -1396,6 +1396,7 @@ You can also define your own custom dataset-level checks (see [Creating custom c
 | `compare_datasets` | Compares two DataFrames at both row and column levels, providing detailed information about differences, including new or missing rows and column-level changes. Only columns present in both the source and reference DataFrames are compared. Use with caution if `check_missing_records` is enabled, as this may increase the number of rows in the output beyond the original input DataFrame. The comparison does not support Map types (any column comparison on map type is skipped automatically). Comparing datasets is valuable for validating data during migrations, detecting drift, performing regression testing, or verifying synchronization between source and target systems. | `columns`: columns to use for row matching with the reference DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'df.columns'; `ref_columns`: list of columns in the reference DataFrame or Table to row match against the source DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'ref_df.columns'; note that `columns` are matched with `ref_columns` by position, so the order of the provided columns in both lists must be exactly aligned; `exclude_columns`: (optional) list of columns to exclude from the value comparison but not from row matching (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'); the `exclude_columns` field does not alter the list of columns used to determine row matches (columns), it only controls which columns are skipped during the value comparison; `ref_df_name`: (optional) name of the reference DataFrame (dictionary of DataFrames can be passed when applying checks); `ref_table`: (optional) fully qualified reference table name; either `ref_df_name` or `ref_table` must be provided but never both; the number of passed `columns` and `ref_columns` must match and keys are checks in the given order; `check_missing_records`: perform a FULL OUTER JOIN to identify records that are missing from source or reference DataFrames, default is False; use with caution as this may increase the number of rows in the output, as unmatched rows from both sides are included; `null_safe_row_matching`: (optional) treat NULLs as equal when matching rows using `columns` and `ref_columns` (default: True); `null_safe_column_value_matching`: (optional) treat NULLs as equal when comparing column values (default: True) |
 | `is_data_fresh_per_time_window` | Freshness check that validates whether at least X records arrive within every Y-minute time window. | `column`: timestamp column (can be a string column name or a column expression); `window_minutes`: time window in minutes to check for data arrival; `min_records_per_window`: minimum number of records expected per time window; `lookback_windows`: (optional) number of time windows to look back from `curr_timestamp`, it filters records to include only those within the specified number of time windows from `curr_timestamp` (if no lookback is provided, the check is applied to the entire dataset); `curr_timestamp`: (optional) current timestamp column (if not provided, current_timestamp() function is used) |
 | `has_valid_schema` | Schema check that validates whether the DataFrame schema matches an expected schema. In non-strict mode, validates that all expected columns exist with compatible types (allows extra columns). In strict mode, validates exact schema match (same columns, same order, same types) for all columns by default or for all columns specified in `columns`. This check is applied at the dataset level and reports schema violations for all rows in the DataFrame when incompatibilities are detected. | `expected_schema`: expected schema as a DDL string (e.g., "id INT, name STRING") or StructType object; `columns`: (optional) list of columns to validate (if not provided, all columns are considered); `strict`: (optional) whether to perform strict schema validation (default: False) - False: validates that all expected columns exist with compatible types, True: validates exact schema match |
+| `row_count` | Computes row counts for the tables listed in the DataFrame's records. Each row is expected to contain a table name, and the computation is done in parallel using the specified number of workers. Results are added as new columns to enrich the input DataFrame. If an error occurs during computation, the error message is stored in the error column instead of the count. | `table_expr`: SQL expression used to build the fully qualified table name (e.g., `table_name` for a single column, or `catalog_name || '.' || schema_name || '.' || table_name` for three columns); `row_count_column`: (optional) name of the column that will contain the row count (default: "row_count"); `row_count_error_column`: (optional) name of the column that will contain error messages if row count computation fails (default: "row_count_error"); `worker_count`: (optional) number of workers to use for parallel computation (default: 8) |
 
 **Compare datasets check**
 
@@ -1729,6 +1730,23 @@ Complex data types are supported as well.
         - id
         - name
 
+# row_count check with single column containing table names
+- criticality: error
+  check:
+    function: row_count
+    arguments:
+      table_expr: table_name # this column contains the name of the table to compute row counts for
+
+# row_count check with multiple columns forming a fully qualified table name
+- criticality: error
+  check:
+    function: row_count
+    arguments:
+      table_expr: "catalog_name || '.' || schema_name || '.' || table_name" # expression to compute the fully qualified table name
+      row_count_column: table_row_count
+      row_count_error_column: table_row_count_error
+      worker_count: 16
+
 # apply check to multiple columns
 - criticality: error
   check:
@@ -2124,6 +2142,27 @@ checks = [
         },
     ),
 
+    # row_count check with single column containing table names
+    DQDatasetRule(
+        criticality="error",
+        check_func=check_funcs.row_count,
+        check_func_kwargs={
+            "table_expr": "table_name",
+        },
+    ),
+
+    # row_count check with multiple columns forming a fully qualified table name
+    DQDatasetRule(
+        criticality="error",
+        check_func=check_funcs.row_count,
+        check_func_kwargs={
+            "table_expr": "catalog_name || '.' || schema_name || '.' || table_name",
+            "row_count_column": "table_row_count",
+            "row_count_error_column": "table_row_count_error",
+            "worker_count": 16,
+        },
+    ),
+
     # apply check to multiple columns
     *DQForEachColRule(
         check_func=check_funcs.is_unique,  # 'columns' as first argument
diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py
index e1679a383..2033b6846 100644
--- a/src/databricks/labs/dqx/check_funcs.py
+++ b/src/databricks/labs/dqx/check_funcs.py
@@ -2728,6 +2728,7 @@ def row_count(table_expr: str,
             - `table_name` when there is only one column in the dataframe with table names.
             - `catalog_name || '.' || schema_name || '.' || table_name` when there are 3 columns in the dataframe: catalog_name, schema_name, table_name.
         row_count_column: Name of the column that will contain the row count. Defaults to "row_count".
+        row_count_error_column: Name of the column that will contain error messages if row count computation fails. Defaults to "row_count_error".
         worker_count: Number of workers to use for the computation. Defaults to 8.
 
     Note:
diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py
index 99eafa264..fbf46a3a3 100644
--- a/tests/integration/test_dataset_checks.py
+++ b/tests/integration/test_dataset_checks.py
@@ -20,6 +20,7 @@
     compare_datasets,
     is_data_fresh_per_time_window,
     has_valid_schema,
+    row_count,
 )
 from databricks.labs.dqx.utils import get_column_name_or_alias
 from databricks.labs.dqx.errors import InvalidParameterError
@@ -1960,3 +1961,195 @@ def test_has_valid_schema_with_specific_columns_mismatch(spark: SparkSession):
         "a string, b string, c double, has_invalid_schema string",
     )
     assert_df_equality(actual_condition_df, expected_condition_df, ignore_nullable=True)
+
+
+def test_row_count_basic(spark: SparkSession, make_schema, make_random):
+    catalog_name = TEST_CATALOG
+    test_schema = make_schema(catalog_name=catalog_name)
+
+    # Create test tables with known row counts
+    table1_name = f"{catalog_name}.{test_schema.name}.{make_random(10).lower()}"
+    table2_name = f"{catalog_name}.{test_schema.name}.{make_random(10).lower()}"
+    table3_name = f"{catalog_name}.{test_schema.name}.{make_random(10).lower()}"
+
+    # Create tables with different row counts
+    spark.createDataFrame([(i,) for i in range(5)], "value int").write.saveAsTable(table1_name)
+    spark.createDataFrame([(i,) for i in range(10)], "value int").write.saveAsTable(table2_name)
+    spark.createDataFrame([(i,) for i in range(15)], "value int").write.saveAsTable(table3_name)
+
+    # Create input dataframe with table names
+    input_df = spark.createDataFrame(
+        [
+            [table1_name],
+            [table2_name],
+            [table3_name],
+        ],
+        "table_name string",
+    )
+
+    # Apply row_count check
+    condition, apply_method = row_count(table_expr="table_name")
+    actual_df = apply_method(input_df, spark, {})
+
+    # Select relevant columns
+    actual = actual_df.select("table_name", "row_count", "row_count_error", condition)
+
+    expected_schema = "table_name string, row_count long, row_count_error string, row_count_error_error string"
+    expected = spark.createDataFrame(
+        [
+            [table1_name, 5, None, None],
+            [table2_name, 10, None, None],
+            [table3_name, 15, None, None],
+        ],
+        expected_schema,
+    )
+
+    assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)
+
+
+def test_row_count_with_composite_table_expression(spark: SparkSession, make_schema, make_random):
+    catalog_name = TEST_CATALOG
+    test_schema = make_schema(catalog_name=catalog_name)
+
+    # Create test tables with known row counts
+    table1 = make_random(10).lower()
+    table2 = make_random(10).lower()
+
+    table1_fqn = f"{catalog_name}.{test_schema.name}.{table1}"
+    table2_fqn = f"{catalog_name}.{test_schema.name}.{table2}"
+
+    # Create tables with different row counts
+    spark.createDataFrame([(i,) for i in range(7)], "value int").write.saveAsTable(table1_fqn)
+    spark.createDataFrame([(i,) for i in range(12)], "value int").write.saveAsTable(table2_fqn)
+
+    # Create input dataframe with separate columns for catalog, schema, and table
+    input_df = spark.createDataFrame(
+        [
+            [catalog_name, test_schema.name, table1],
+            [catalog_name, test_schema.name, table2],
+        ],
+        "catalog_name string, schema_name string, table_name string",
+    )
+
+    # Apply row_count check with composite table expression
+    condition, apply_method = row_count(
+        table_expr="catalog_name || '.' || schema_name || '.' || table_name"
+    )
+    actual_df = apply_method(input_df, spark, {})
+
+    # Select relevant columns
+    actual = actual_df.select(
+        "catalog_name", "schema_name", "table_name", "row_count", "row_count_error", condition
+    )
+
+    expected_schema = (
+        "catalog_name string, schema_name string, table_name string, "
+        "row_count long, row_count_error string, row_count_error_error string"
+    )
+    expected = spark.createDataFrame(
+        [
+            [catalog_name, test_schema.name, table1, 7, None, None],
+            [catalog_name, test_schema.name, table2, 12, None, None],
+        ],
+        expected_schema,
+    )
+
+    assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)
+
+
+def test_row_count_with_nonexistent_table(spark: SparkSession, make_schema, make_random):
+    catalog_name = TEST_CATALOG
+    test_schema = make_schema(catalog_name=catalog_name)
+
+    # Create one valid table and one non-existent table
+    valid_table = make_random(10).lower()
+    valid_table_fqn = f"{catalog_name}.{test_schema.name}.{valid_table}"
+    nonexistent_table = f"{catalog_name}.{test_schema.name}.nonexistent_{make_random(10).lower()}"
+
+    # Create only the valid table
+    spark.createDataFrame([(i,) for i in range(3)], "value int").write.saveAsTable(valid_table_fqn)
+
+    # Create input dataframe with both valid and non-existent table
+    input_df = spark.createDataFrame(
+        [
+            [valid_table_fqn],
+            [nonexistent_table],
+        ],
+        "table_name string",
+    )
+
+    # Apply row_count check
+    condition, apply_method = row_count(table_expr="table_name")
+    actual_df = apply_method(input_df, spark, {})
+
+    # Select relevant columns
+    actual = actual_df.select("table_name", "row_count", "row_count_error", condition)
+
+    # Collect results to check the error case
+    results = actual.collect()
+
+    # Verify valid table has correct count
+    valid_row = [r for r in results if r["table_name"] == valid_table_fqn][0]
+    assert valid_row["row_count"] == 3
+    assert valid_row["row_count_error"] is None
+    assert valid_row["row_count_error_error"] is None
+
+    # Verify non-existent table has error
+    error_row = [r for r in results if r["table_name"] == nonexistent_table][0]
+    assert error_row["row_count"] is None
+    assert error_row["row_count_error"] is not None
+    assert error_row["row_count_error_error"] is not None
+    assert "Error computing row count for table" in error_row["row_count_error_error"]
+
+
+def test_row_count_with_custom_columns(spark: SparkSession, make_schema, make_random):
+    catalog_name = TEST_CATALOG
+    test_schema = make_schema(catalog_name=catalog_name)
+
+    # Create test tables
+    table1_name = f"{catalog_name}.{test_schema.name}.{make_random(10).lower()}"
+    table2_name = f"{catalog_name}.{test_schema.name}.{make_random(10).lower()}"
+
+    spark.createDataFrame([(i,) for i in range(8)], "value int").write.saveAsTable(table1_name)
+    spark.createDataFrame([(i,) for i in range(4)], "value int").write.saveAsTable(table2_name)
+
+    # Create input dataframe
+    input_df = spark.createDataFrame(
+        [
+            [table1_name],
+            [table2_name],
+        ],
+        "table_name string",
+    )
+
+    # Apply row_count check with custom column names
+    custom_count_col = "custom_count"
+    custom_error_col = "custom_error"
+    condition, apply_method = row_count(
+        table_expr="table_name",
+        row_count_column=custom_count_col,
+        row_count_error_column=custom_error_col,
+    )
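+    # apply_method is the dataset-level closure returned by row_count; it enriches the input
+    # DataFrame with the configured count and error columns, computed in parallel on the driver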
+    actual_df = apply_method(input_df, spark, {})
+
+    # Verify custom columns exist
+    assert custom_count_col in actual_df.columns
+    assert custom_error_col in actual_df.columns
+
+    # Select relevant columns
+    actual = actual_df.select("table_name", custom_count_col, custom_error_col, condition)
+
+    expected_schema = f"table_name string, {custom_count_col} long, {custom_error_col} string, {custom_error_col}_error string"
+    expected = spark.createDataFrame(
+        [
+            [table1_name, 8, None, None],
+            [table2_name, 4, None, None],
+        ],
+        expected_schema,
+    )
+
+    assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)