diff --git a/src/preset_cli/cli/superset/sync/dbt/metrics.py b/src/preset_cli/cli/superset/sync/dbt/metrics.py index 5b5ad269..d82a0cd9 100644 --- a/src/preset_cli/cli/superset/sync/dbt/metrics.py +++ b/src/preset_cli/cli/superset/sync/dbt/metrics.py @@ -89,32 +89,27 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> return f"COUNT(DISTINCT {sql})" if type_ in {"expression", "derived"}: + if metric.get("skip_parsing"): + return sql.strip() + try: expression = sqlglot.parse_one(sql, dialect=metric["dialect"]) + tokens = expression.find_all(exp.Column) + + for token in tokens: + if token.sql() in metrics: + parent_sql = get_metric_expression(token.sql(), metrics) + parent_expression = sqlglot.parse_one( + parent_sql, + dialect=metric["dialect"], + ) + token.replace(parent_expression) + + return expression.sql(dialect=metric["dialect"]) except ParseError: - for parent_metric in metric["depends_on"]: - parent_metric_name = parent_metric.split(".")[-1] - pattern = r"\b" + re.escape(parent_metric_name) + r"\b" - parent_metric_syntax = get_metric_expression( - parent_metric_name, - metrics, - ) - sql = re.sub(pattern, parent_metric_syntax, sql) + sql = replace_metric_syntax(sql, metric["depends_on"], metrics) return sql - tokens = expression.find_all(exp.Column) - - for token in tokens: - if token.sql() in metrics: - parent_sql = get_metric_expression(token.sql(), metrics) - parent_expression = sqlglot.parse_one( - parent_sql, - dialect=metric["dialect"], - ) - token.replace(parent_expression) - - return expression.sql(dialect=metric["dialect"]) - sorted_metric = dict(sorted(metric.items())) raise Exception(f"Unable to generate metric expression from: {sorted_metric}") @@ -206,11 +201,7 @@ def get_metric_definition( kwargs = meta.pop("superset", {}) return { - "expression": ( - get_metric_expression(metric_name, metric_map) - if not metric.get("skip_parsing") - else metric.get("expression") or metric.get("sql") - ), + "expression": get_metric_expression(metric_name, metric_map), "metric_name": metric_name, "metric_type": (metric.get("type") or metric.get("calculation_method")), "verbose_name": metric.get("label", metric_name), @@ -229,14 +220,17 @@ def get_superset_metrics_per_model( """ superset_metrics = defaultdict(list) for metric in og_metrics: - metric_models = get_metric_models(metric["unique_id"], og_metrics) - - # dbt supports creating derived metrics with raw syntax - if len(metric_models) == 0: - try: - metric_models.add(metric["meta"]["superset"].pop("model")) + # dbt supports creating derived metrics with raw syntax. In case the metric doesn't + # rely on other metrics (or rely on other metrics that aren't associated with any + # model), it's required to specify the dataset the metric should be associated with + # under the ``meta.superset.model`` key. If the derived metric is just an expression + # with no dependency, it's not required to parse the metric SQL. + if model := metric.get("meta", {}).get("superset", {}).pop("model", None): + if len(metric["depends_on"]) == 0: metric["skip_parsing"] = True - except KeyError: + else: + metric_models = get_metric_models(metric["unique_id"], og_metrics) + if len(metric_models) == 0: _logger.warning( "Metric %s cannot be calculated because it's not associated with any model." " Please specify the model under metric.meta.superset.model.", @@ -244,19 +238,19 @@ def get_superset_metrics_per_model( ) continue - if len(metric_models) != 1: - _logger.warning( - "Metric %s cannot be calculated because it depends on multiple models: %s", - metric["name"], - ", ".join(sorted(metric_models)), - ) - continue + if len(metric_models) != 1: + _logger.warning( + "Metric %s cannot be calculated because it depends on multiple models: %s", + metric["name"], + ", ".join(sorted(metric_models)), + ) + continue + model = metric_models.pop() metric_definition = get_metric_definition( metric["name"], og_metrics, ) - model = metric_models.pop() superset_metrics[model].append(metric_definition) for sl_metric in sl_metrics or []: @@ -399,3 +393,24 @@ def get_models_from_sql( raise ValueError(f"Unable to find model for SQL source {table}") return [model_map[ModelKey(table.db, table.name)] for table in sources] + + +def replace_metric_syntax( + sql: str, + dependencies: List[str], + metrics: Dict[str, MetricSchema], +) -> str: + """ + Replace metric keys with their SQL syntax. + This method is a fallback in case ``sqlglot`` raises a ``ParseError``. + """ + for parent_metric in dependencies: + parent_metric_name = parent_metric.split(".")[-1] + pattern = r"\b" + re.escape(parent_metric_name) + r"\b" + parent_metric_syntax = get_metric_expression( + parent_metric_name, + metrics, + ) + sql = re.sub(pattern, parent_metric_syntax, sql) + + return sql.strip() diff --git a/tests/cli/superset/sync/dbt/metrics_test.py b/tests/cli/superset/sync/dbt/metrics_test.py index 9cc602a0..3ea099db 100644 --- a/tests/cli/superset/sync/dbt/metrics_test.py +++ b/tests/cli/superset/sync/dbt/metrics_test.py @@ -19,6 +19,7 @@ get_metrics_for_model, get_models_from_sql, get_superset_metrics_per_model, + replace_metric_syntax, ) @@ -929,6 +930,15 @@ def test_get_superset_metrics_per_model_og_derived( "expression": "1", }, ), + og_metric_schema.load( + { + "name": "revenue", + "unique_id": "revenue", + "depends_on": ["orders"], + "calculation_method": "sum", + "expression": "price_each", + }, + ), og_metric_schema.load( { "name": "derived_metric_missing_model_info", @@ -980,6 +990,38 @@ def test_get_superset_metrics_per_model_og_derived( """, }, ), + og_metric_schema.load( + { + "name": "derived_combining_other_derived_including_jinja", + "unique_id": "derived_combining_other_derived_including_jinja", + "depends_on": ["derived_metric_with_jinja_and_other_metric", "revenue"], + "dialect": "postgres", + "calculation_method": "derived", + "expression": "derived_metric_with_jinja_and_other_metric / revenue", + }, + ), + og_metric_schema.load( + { + "name": "simple_derived", + "unique_id": "simple_derived", + "depends_on": [], + "dialect": "postgres", + "calculation_method": "derived", + "expression": "max(order_date)", + "meta": {"superset": {"model": "customers"}}, + }, + ), + og_metric_schema.load( + { + "name": "last_derived_example", + "unique_id": "last_derived_example", + "depends_on": ["simple_derived"], + "dialect": "postgres", + "calculation_method": "derived", + "expression": "simple_derived - 1", + "meta": {"superset": {"model": "customers"}}, + }, + ), ] result = get_superset_metrics_per_model(og_metrics, []) @@ -1000,19 +1042,33 @@ def test_get_superset_metrics_per_model_og_derived( "extra": "{}", }, { - "expression": """ -SUM( + "expression": """SUM( {% for x in filter_values('x_values') %} {{ + x_values }} {% endfor %} -) -""", +)""", "metric_name": "derived_metric_with_jinja", "metric_type": "derived", "verbose_name": "derived_metric_with_jinja", "description": "", "extra": "{}", }, + { + "expression": "max(order_date)", + "metric_name": "simple_derived", + "metric_type": "derived", + "verbose_name": "simple_derived", + "description": "", + "extra": "{}", + }, + { + "expression": "MAX(order_date) - 1", + "metric_name": "last_derived_example", + "metric_type": "derived", + "verbose_name": "last_derived_example", + "description": "", + "extra": "{}", + }, ], "orders": [ { @@ -1024,18 +1080,71 @@ def test_get_superset_metrics_per_model_og_derived( "verbose_name": "sales", }, { - "expression": """ -SUM( + "description": "", + "expression": "SUM(price_each)", + "extra": "{}", + "metric_name": "revenue", + "metric_type": "sum", + "verbose_name": "revenue", + }, + { + "expression": """SUM( {% for x in filter_values('x_values') %} {{ my_sales + SUM(1) }} {% endfor %} -) -""", +)""", "metric_name": "derived_metric_with_jinja_and_other_metric", "metric_type": "derived", "verbose_name": "derived_metric_with_jinja_and_other_metric", "description": "", "extra": "{}", }, + { + "expression": """SUM( + {% for x in filter_values('x_values') %} + {{ my_sales + SUM(1) }} + {% endfor %} +) / SUM(price_each)""", + "metric_name": "derived_combining_other_derived_including_jinja", + "metric_type": "derived", + "verbose_name": "derived_combining_other_derived_including_jinja", + "description": "", + "extra": "{}", + }, ], } + + +def test_replace_metric_syntax() -> None: + """ + Test the ``replace_metric_syntax`` method. + """ + og_metric_schema = MetricSchema() + sql = "revenue - cost" + metrics = { + "revenue": og_metric_schema.load( + { + "name": "revenue", + "unique_id": "revenue", + "depends_on": [], + "calculation_method": "derived", + "expression": "SUM({{ url_param['aggreagtor'] }})", + "dialect": "postgres", + }, + ), + "cost": og_metric_schema.load( + { + "name": "cost", + "unique_id": "cost", + "depends_on": [], + "calculation_method": "derived", + "expression": "SUM({{ filter_values['test'] }})", + "dialect": "postgres", + }, + ), + } + result = replace_metric_syntax(sql, ["revenue", "cost"], metrics) + assert ( + result + == "SUM({{ url_param['aggreagtor'] }}) - SUM({{ filter_values['test'] }})" + )