[DO NOT MERGE] DEMO fix(time_comparison): Use Join queries when using time comparison #27853

Closed · wants to merge 2 commits
17 changes: 17 additions & 0 deletions superset/charts/schemas.py
@@ -26,6 +26,7 @@

from superset import app
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.constants import InstantTimeComparison
from superset.db_engine_specs.base import builtin_time_grains
from superset.tags.models import TagType
from superset.utils import pandas_postprocessing, schema as utils
@@ -948,6 +949,14 @@ class ChartDataFilterSchema(Schema):
)


class InstantTimeComparisonInfoSchema(Schema):
range = fields.String(
metadata={"description": "Type of time comparison to be used"},
validate=validate.OneOf(choices=[ran.value for ran in InstantTimeComparison]),
)
filter = fields.Nested(ChartDataFilterSchema, allow_none=True)


class ChartDataExtrasSchema(Schema):
relative_start = fields.String(
metadata={
@@ -998,6 +1007,14 @@ class ChartDataExtrasSchema(Schema):
},
allow_none=True,
)
instant_time_comparison_info = fields.Nested(
InstantTimeComparisonInfoSchema,
metadata={
"description": "Extra parameters to use instant time comparison"
" with JOINs using a single query"
},
allow_none=True,
)


class AnnotationLayerSchema(Schema):
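For context, this is roughly how the new instant_time_comparison_info block could look inside a chart data query payload once the schema above is in place. This is a minimal sketch, not taken from the PR: only the "range" and "filter" keys come from InstantTimeComparisonInfoSchema, and the surrounding columns, metrics, and values are illustrative.

# Hypothetical chart-data query fragment exercising the new extras field.
# Column names, metrics, and time ranges below are made-up examples.
query_object = {
    "columns": ["country"],
    "metrics": ["count"],
    "filters": [
        {"col": "order_date", "op": "TEMPORAL_RANGE", "val": "Last year"},
    ],
    "extras": {
        "instant_time_comparison_info": {
            # One of the InstantTimeComparison values: "c", "r", "y", "m", "w"
            "range": "y",
            # Only relevant when range is "c" (CUSTOM); shaped like ChartDataFilterSchema
            "filter": {
                "col": "order_date",
                "op": "TEMPORAL_RANGE",
                "val": "2022-01-01 : 2023-01-01",
            },
        },
    },
}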
3 changes: 2 additions & 1 deletion superset/constants.py
@@ -44,10 +44,11 @@

# Used when calculating the time shift for time comparison
class InstantTimeComparison(StrEnum):
CUSTOM = "c"
INHERITED = "r"
YEAR = "y"
MONTH = "m"
WEEK = "w"
YEAR = "y"


class RouteMethod: # pylint: disable=too-few-public-methods
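As a small illustration of how the string values above flow from a request into the enum, here is a standalone sketch (not Superset code; it assumes Python 3.11's enum.StrEnum, whereas Superset ships its own StrEnum helper):

# Sketch: mapping a request's "range" string onto the enum values shown above.
from enum import StrEnum  # assumption: Python 3.11+; Superset uses its own StrEnum


class InstantTimeComparison(StrEnum):
    CUSTOM = "c"
    INHERITED = "r"
    YEAR = "y"
    MONTH = "m"
    WEEK = "w"


def parse_range(value: str) -> InstantTimeComparison:
    # Mirrors what the OneOf validator in schemas.py accepts.
    try:
        return InstantTimeComparison(value)
    except ValueError as ex:
        raise ValueError(f"Unsupported time comparison range: {value!r}") from ex


assert parse_range("y") is InstantTimeComparison.YEAR
assert parse_range("y") == "y"  # StrEnum members compare equal to their string values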
159 changes: 156 additions & 3 deletions superset/models/helpers.py
@@ -17,6 +17,7 @@
# pylint: disable=too-many-lines
"""a collection of model-related helper classes and functions"""
import builtins
import copy
import dataclasses
import json
import logging
@@ -57,7 +58,7 @@
from superset.advanced_data_type.types import AdvancedDataTypeResponse
from superset.common.db_query_status import QueryStatus
from superset.common.utils.time_range_utils import get_since_until_from_time_range
from superset.constants import EMPTY_STRING, NULL_STRING
from superset.constants import EMPTY_STRING, InstantTimeComparison, NULL_STRING
from superset.db_engine_specs.base import TimestampExpression
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
@@ -89,6 +90,7 @@
)
from superset.utils import core as utils
from superset.utils.core import (
FilterOperator,
GenericDataType,
get_column_name,
get_user_id,
@@ -905,6 +907,123 @@ def _apply_cte(sql: str, cte: Optional[str]) -> str:
sql = f"{cte}\n{sql}"
return sql

def extract_column_names(self, final_selected_columns: Any) -> list[str]:
column_names = []
for selected_col in final_selected_columns:
# The key attribute usually holds the name or alias of the column
column_name = selected_col.key if hasattr(selected_col, "key") else None
# If the column has a name attribute, use it as a fallback
if not column_name and hasattr(selected_col, "name"):
column_name = selected_col.name
# For labeled elements, the name is stored in the 'name' attribute
if hasattr(selected_col, "name"):
column_name = selected_col.name
# Append the extracted name to the list
if column_name:
column_names.append(column_name)
return column_names

def process_time_compare_join( # pylint: disable=too-many-locals
self,
query_obj: QueryObjectDict,
sqlaq: SqlaQuery,
mutate: bool,
instant_time_comparison_info: dict[str, Any],
) -> tuple[str, list[str]]:
"""
Main goal of this method is to create a JOIN between a given query object and
other that shifts the time filters. This is different from time_offsets because
we are not joining result sets but rather we're applying the JOIN at query level.
Use case: Compare paginated data in a Table Chart. But ideally can be leveraged by
anything that needs the experimental instant time comparison.
"""
# So we don't override the original QueryObject
query_obj_clone = copy.copy(query_obj)
final_query_sql = ""
# The inner query object doesn't need a row limit or offset
query_obj_clone["row_limit"] = None
query_obj_clone["row_offset"] = None
# Determine which range should be used when building the time comparison shift;
# the time shift is computed from a set of predefined delta options
instant_time_comparison_range = instant_time_comparison_info.get("range")
if instant_time_comparison_range == InstantTimeComparison.CUSTOM:
# If it's a custom range, we take the first temporal filter and replace it with
# whatever value we received in the request as the custom filter.
custom_filter = instant_time_comparison_info.get("filter", {})
temporal_filters = [
filter["col"]
for filter in query_obj_clone.get("filter", {})
if filter.get("op", None) == FilterOperator.TEMPORAL_RANGE
]
non_temporal_filters = [
filter["col"]
for filter in query_obj_clone.get("filter", {})
if filter.get("op", None) != FilterOperator.TEMPORAL_RANGE
]
if len(temporal_filters) > 0:
# Replace the first temporal filter with the custom filter
temporal_filters[0] = custom_filter

new_filters = temporal_filters + non_temporal_filters
query_obj_clone["filter"] = new_filters
if instant_time_comparison_range != InstantTimeComparison.CUSTOM:
# When not custom, we use one of the predefined time ranges:
# Year, Month, Week or Inherited
query_obj_clone["extras"] = {
**query_obj_clone.get("extras", {}),
"instant_time_comparison_range": instant_time_comparison_range,
}
shifted_sqlaq = self.get_sqla_query(**query_obj_clone)
# We JOIN only over columns, not metrics or anything else since those cannot be
# joined
join_columns = query_obj_clone.get("columns") or []
original_query_a = sqlaq.sqla_query
shifted_query_b = shifted_sqlaq.sqla_query
shifted_query_b_subquery = shifted_query_b.subquery()
query_a_cte = original_query_a.cte("query_a_results")
column_names_a = [column.key for column in original_query_a.c]
exclude_columns_b = set(query_obj_clone.get("columns") or [])
# Prepare the column sets to be used for queries A and B
selected_columns_a = [query_a_cte.c[col].label(col) for col in column_names_a]
# Renamed columns from Query B (with "prev_" prefix)
selected_columns_b = [
shifted_query_b_subquery.c[col].label(f"prev_{col}")
for col in shifted_query_b_subquery.c.keys()
if col not in exclude_columns_b
]
# Combine selected columns from both queries
final_selected_columns = selected_columns_a + selected_columns_b
if join_columns and not query_obj_clone.get("is_rowcount"):
# Proceed with JOIN operation as before since join_columns is not empty
join_conditions = [
shifted_query_b_subquery.c[col] == query_a_cte.c[col]
for col in join_columns
if col in shifted_query_b_subquery.c and col in query_a_cte.c
]
final_query = sa.select(*final_selected_columns).select_from(
shifted_query_b_subquery.join(query_a_cte, sa.and_(*join_conditions))
)
else:
# When dealing with queries that have no columns, or that are totals,
# rowcounts, etc., we join on 1 = 1 to create a result set that contains
# both sets (original and prev)
final_query = sa.select(*final_selected_columns).select_from(
shifted_query_b_subquery.join(
query_a_cte, sa.literal(True) == sa.literal(True)
)
)
# Transform the query as you would within get_query_str_extended
final_query_sql = self.database.compile_sqla_query(final_query)
final_query_sql = self._apply_cte(final_query_sql, sqlaq.cte)
final_query_sql = sqlparse.format(final_query_sql, reindent=True)
if mutate:
final_query_sql = self.mutate_query_from_config(final_query_sql)

# Prepare the labels for the columns to be used
labels_expected = self.extract_column_names(final_selected_columns)

return final_query_sql, labels_expected

def get_query_str_extended(
self,
query_obj: QueryObjectDict,
@@ -918,15 +1037,49 @@ def get_query_str_extended(
except SupersetParseError:
logger.warning("Unable to parse SQL to format it, passing it as-is")

# Tell regular queries apart from the ones that need time comparison
query_obj_clone = copy.copy(query_obj)
query_object_extras: dict[str, Any] = query_obj.get("extras", {})
instant_time_comparison_info = query_object_extras.get(
"instant_time_comparison_info", {}
)

if (
is_feature_enabled("CHART_PLUGINS_EXPERIMENTAL")
and instant_time_comparison_info
):
# Check that only DBs that support JOINs and Subqueries use this feature
if (
self.database is not None
and self.database.db_engine_spec is not None
and (
not self.database.db_engine_spec.allows_joins
or not self.database.db_engine_spec.allows_subqueries
)
):
raise QueryObjectValidationError(
_("Instant time comparison is not supported for this database")
)
(
final_query_sql,
labels_expected,
) = self.process_time_compare_join(
query_obj_clone, sqlaq, mutate, instant_time_comparison_info
)
else:
final_query_sql = sql
labels_expected = sqlaq.labels_expected

if mutate:
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
applied_template_filters=sqlaq.applied_template_filters,
applied_filter_columns=sqlaq.applied_filter_columns,
rejected_filter_columns=sqlaq.rejected_filter_columns,
labels_expected=sqlaq.labels_expected,
labels_expected=labels_expected,
prequeries=sqlaq.prequeries,
sql=sql,
sql=final_query_sql if final_query_sql else sql,
)

def _normalize_prequery_result_type(
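To make the JOIN strategy in process_time_compare_join concrete, here is a standalone SQLAlchemy Core sketch of the query shape it builds: the original query becomes a CTE, the time-shifted query becomes a subquery whose columns are re-labelled with a prev_ prefix, and the two are joined on the grouping columns. This is not Superset code; the orders table, its columns, and the date ranges are made up, and grouped_query stands in for get_sqla_query.

# Minimal, self-contained sketch of the single-query time comparison JOIN.
import sqlalchemy as sa

metadata = sa.MetaData()
orders = sa.Table(
    "orders",
    metadata,
    sa.Column("country", sa.String),
    sa.Column("order_date", sa.Date),
    sa.Column("sales", sa.Numeric),
)


def grouped_query(start, end):
    # Stand-in for get_sqla_query(): one aggregate query per time range.
    return (
        sa.select(orders.c.country, sa.func.sum(orders.c.sales).label("sales"))
        .where(orders.c.order_date >= start, orders.c.order_date < end)
        .group_by(orders.c.country)
    )


# Query A: the requested time range, wrapped in a CTE ("query_a_results").
query_a = grouped_query("2023-01-01", "2024-01-01").cte("query_a_results")
# Query B: the shifted range, used as a subquery; its metric column gets a
# "prev_" prefix so both periods land side by side in a single row.
query_b = grouped_query("2022-01-01", "2023-01-01").subquery()

selected = [
    query_a.c.country.label("country"),
    query_a.c.sales.label("sales"),
    query_b.c.sales.label("prev_sales"),
]
# JOIN only on the grouping columns; metrics are never part of the join key.
final_query = sa.select(*selected).select_from(
    query_b.join(query_a, query_b.c.country == query_a.c.country)
)
print(final_query)

When the query object has no grouping columns (totals or row counts), the PR instead joins on a constant true condition so the original and prev_ columns still end up in one result row.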