Skip to content

Commit

Permalink
fix(derived metrics): Support reusing metrics with Jinja in other met…
Browse files Browse the repository at this point in the history
…rics (#285)

* fix(derived metrics): Support reusing metrics with Jinja in other metrics
* Moving off of regex to .strip()
  • Loading branch information
Vitor-Avila committed Apr 22, 2024
1 parent 258b69b commit 2c445c5
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 49 deletions.
97 changes: 56 additions & 41 deletions src/preset_cli/cli/superset/sync/dbt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,32 +89,27 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) ->
return f"COUNT(DISTINCT {sql})"

if type_ in {"expression", "derived"}:
if metric.get("skip_parsing"):
return sql.strip()

try:
expression = sqlglot.parse_one(sql, dialect=metric["dialect"])
tokens = expression.find_all(exp.Column)

for token in tokens:
if token.sql() in metrics:
parent_sql = get_metric_expression(token.sql(), metrics)
parent_expression = sqlglot.parse_one(
parent_sql,
dialect=metric["dialect"],
)
token.replace(parent_expression)

return expression.sql(dialect=metric["dialect"])
except ParseError:
for parent_metric in metric["depends_on"]:
parent_metric_name = parent_metric.split(".")[-1]
pattern = r"\b" + re.escape(parent_metric_name) + r"\b"
parent_metric_syntax = get_metric_expression(
parent_metric_name,
metrics,
)
sql = re.sub(pattern, parent_metric_syntax, sql)
sql = replace_metric_syntax(sql, metric["depends_on"], metrics)
return sql

tokens = expression.find_all(exp.Column)

for token in tokens:
if token.sql() in metrics:
parent_sql = get_metric_expression(token.sql(), metrics)
parent_expression = sqlglot.parse_one(
parent_sql,
dialect=metric["dialect"],
)
token.replace(parent_expression)

return expression.sql(dialect=metric["dialect"])

sorted_metric = dict(sorted(metric.items()))
raise Exception(f"Unable to generate metric expression from: {sorted_metric}")

Expand Down Expand Up @@ -206,11 +201,7 @@ def get_metric_definition(
kwargs = meta.pop("superset", {})

return {
"expression": (
get_metric_expression(metric_name, metric_map)
if not metric.get("skip_parsing")
else metric.get("expression") or metric.get("sql")
),
"expression": get_metric_expression(metric_name, metric_map),
"metric_name": metric_name,
"metric_type": (metric.get("type") or metric.get("calculation_method")),
"verbose_name": metric.get("label", metric_name),
Expand All @@ -229,34 +220,37 @@ def get_superset_metrics_per_model(
"""
superset_metrics = defaultdict(list)
for metric in og_metrics:
metric_models = get_metric_models(metric["unique_id"], og_metrics)

# dbt supports creating derived metrics with raw syntax
if len(metric_models) == 0:
try:
metric_models.add(metric["meta"]["superset"].pop("model"))
# dbt supports creating derived metrics with raw syntax. In case the metric doesn't
# rely on other metrics (or rely on other metrics that aren't associated with any
# model), it's required to specify the dataset the metric should be associated with
# under the ``meta.superset.model`` key. If the derived metric is just an expression
# with no dependency, it's not required to parse the metric SQL.
if model := metric.get("meta", {}).get("superset", {}).pop("model", None):
if len(metric["depends_on"]) == 0:
metric["skip_parsing"] = True
except KeyError:
else:
metric_models = get_metric_models(metric["unique_id"], og_metrics)
if len(metric_models) == 0:
_logger.warning(
"Metric %s cannot be calculated because it's not associated with any model."
" Please specify the model under metric.meta.superset.model.",
metric["name"],
)
continue

if len(metric_models) != 1:
_logger.warning(
"Metric %s cannot be calculated because it depends on multiple models: %s",
metric["name"],
", ".join(sorted(metric_models)),
)
continue
if len(metric_models) != 1:
_logger.warning(
"Metric %s cannot be calculated because it depends on multiple models: %s",
metric["name"],
", ".join(sorted(metric_models)),
)
continue
model = metric_models.pop()

metric_definition = get_metric_definition(
metric["name"],
og_metrics,
)
model = metric_models.pop()
superset_metrics[model].append(metric_definition)

for sl_metric in sl_metrics or []:
Expand Down Expand Up @@ -399,3 +393,24 @@ def get_models_from_sql(
raise ValueError(f"Unable to find model for SQL source {table}")

return [model_map[ModelKey(table.db, table.name)] for table in sources]


def replace_metric_syntax(
sql: str,
dependencies: List[str],
metrics: Dict[str, MetricSchema],
) -> str:
"""
Replace metric keys with their SQL syntax.
This method is a fallback in case ``sqlglot`` raises a ``ParseError``.
"""
for parent_metric in dependencies:
parent_metric_name = parent_metric.split(".")[-1]
pattern = r"\b" + re.escape(parent_metric_name) + r"\b"
parent_metric_syntax = get_metric_expression(
parent_metric_name,
metrics,
)
sql = re.sub(pattern, parent_metric_syntax, sql)

return sql.strip()
125 changes: 117 additions & 8 deletions tests/cli/superset/sync/dbt/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
get_metrics_for_model,
get_models_from_sql,
get_superset_metrics_per_model,
replace_metric_syntax,
)


Expand Down Expand Up @@ -929,6 +930,15 @@ def test_get_superset_metrics_per_model_og_derived(
"expression": "1",
},
),
og_metric_schema.load(
{
"name": "revenue",
"unique_id": "revenue",
"depends_on": ["orders"],
"calculation_method": "sum",
"expression": "price_each",
},
),
og_metric_schema.load(
{
"name": "derived_metric_missing_model_info",
Expand Down Expand Up @@ -980,6 +990,38 @@ def test_get_superset_metrics_per_model_og_derived(
""",
},
),
og_metric_schema.load(
{
"name": "derived_combining_other_derived_including_jinja",
"unique_id": "derived_combining_other_derived_including_jinja",
"depends_on": ["derived_metric_with_jinja_and_other_metric", "revenue"],
"dialect": "postgres",
"calculation_method": "derived",
"expression": "derived_metric_with_jinja_and_other_metric / revenue",
},
),
og_metric_schema.load(
{
"name": "simple_derived",
"unique_id": "simple_derived",
"depends_on": [],
"dialect": "postgres",
"calculation_method": "derived",
"expression": "max(order_date)",
"meta": {"superset": {"model": "customers"}},
},
),
og_metric_schema.load(
{
"name": "last_derived_example",
"unique_id": "last_derived_example",
"depends_on": ["simple_derived"],
"dialect": "postgres",
"calculation_method": "derived",
"expression": "simple_derived - 1",
"meta": {"superset": {"model": "customers"}},
},
),
]

result = get_superset_metrics_per_model(og_metrics, [])
Expand All @@ -1000,19 +1042,33 @@ def test_get_superset_metrics_per_model_og_derived(
"extra": "{}",
},
{
"expression": """
SUM(
"expression": """SUM(
{% for x in filter_values('x_values') %}
{{ + x_values }}
{% endfor %}
)
""",
)""",
"metric_name": "derived_metric_with_jinja",
"metric_type": "derived",
"verbose_name": "derived_metric_with_jinja",
"description": "",
"extra": "{}",
},
{
"expression": "max(order_date)",
"metric_name": "simple_derived",
"metric_type": "derived",
"verbose_name": "simple_derived",
"description": "",
"extra": "{}",
},
{
"expression": "MAX(order_date) - 1",
"metric_name": "last_derived_example",
"metric_type": "derived",
"verbose_name": "last_derived_example",
"description": "",
"extra": "{}",
},
],
"orders": [
{
Expand All @@ -1024,18 +1080,71 @@ def test_get_superset_metrics_per_model_og_derived(
"verbose_name": "sales",
},
{
"expression": """
SUM(
"description": "",
"expression": "SUM(price_each)",
"extra": "{}",
"metric_name": "revenue",
"metric_type": "sum",
"verbose_name": "revenue",
},
{
"expression": """SUM(
{% for x in filter_values('x_values') %}
{{ my_sales + SUM(1) }}
{% endfor %}
)
""",
)""",
"metric_name": "derived_metric_with_jinja_and_other_metric",
"metric_type": "derived",
"verbose_name": "derived_metric_with_jinja_and_other_metric",
"description": "",
"extra": "{}",
},
{
"expression": """SUM(
{% for x in filter_values('x_values') %}
{{ my_sales + SUM(1) }}
{% endfor %}
) / SUM(price_each)""",
"metric_name": "derived_combining_other_derived_including_jinja",
"metric_type": "derived",
"verbose_name": "derived_combining_other_derived_including_jinja",
"description": "",
"extra": "{}",
},
],
}


def test_replace_metric_syntax() -> None:
"""
Test the ``replace_metric_syntax`` method.
"""
og_metric_schema = MetricSchema()
sql = "revenue - cost"
metrics = {
"revenue": og_metric_schema.load(
{
"name": "revenue",
"unique_id": "revenue",
"depends_on": [],
"calculation_method": "derived",
"expression": "SUM({{ url_param['aggreagtor'] }})",
"dialect": "postgres",
},
),
"cost": og_metric_schema.load(
{
"name": "cost",
"unique_id": "cost",
"depends_on": [],
"calculation_method": "derived",
"expression": "SUM({{ filter_values['test'] }})",
"dialect": "postgres",
},
),
}
result = replace_metric_syntax(sql, ["revenue", "cost"], metrics)
assert (
result
== "SUM({{ url_param['aggreagtor'] }}) - SUM({{ filter_values['test'] }})"
)

0 comments on commit 2c445c5

Please sign in to comment.