From 7e321810fae90e9712c09fcef6d06d6cc6e2a4f4 Mon Sep 17 00:00:00 2001 From: Demian Wassermann Date: Sun, 7 Feb 2021 18:24:19 +0100 Subject: [PATCH 01/33] Debugging improvements and added an implementation for logical minus --- dask_sql/context.py | 5 ++ dask_sql/physical/rel/logical/__init__.py | 2 + dask_sql/physical/rel/logical/minus.py | 71 +++++++++++++++++++++++ dask_sql/utils.py | 2 +- 4 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 dask_sql/physical/rel/logical/minus.py diff --git a/dask_sql/context.py b/dask_sql/context.py index 77ae10b6f..220790136 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -86,6 +86,7 @@ def __init__(self): RelConverter.add_plugin_class(logical.LogicalSortPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalTableScanPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalUnionPlugin, replace=False) + RelConverter.add_plugin_class(logical.LogicalMinusPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalValuesPlugin, replace=False) RelConverter.add_plugin_class(logical.SamplePlugin, replace=False) RelConverter.add_plugin_class(custom.AnalyzeTablePlugin, replace=False) @@ -515,6 +516,10 @@ def _get_ral(self, sql): nonOptimizedRelNode = generator.getRelationalAlgebra(validatedSqlNode) rel = generator.getOptimizedRelationalAlgebra(nonOptimizedRelNode) rel_string = str(generator.getRelationalAlgebraString(rel)) + logger.debug( + f"Non optimised query plan: \n " + f"{str(generator.getRelationalAlgebraString(nonOptimizedRelNode))}" + ) except (ValidationException, SqlParseException) as e: logger.debug(f"Original exception raised by Java:\n {e}") # We do not want to re-raise an exception here diff --git a/dask_sql/physical/rel/logical/__init__.py b/dask_sql/physical/rel/logical/__init__.py index 99698157c..9d429e6ac 100644 --- a/dask_sql/physical/rel/logical/__init__.py +++ b/dask_sql/physical/rel/logical/__init__.py @@ -6,6 +6,7 @@ from .sort import LogicalSortPlugin from .table_scan import LogicalTableScanPlugin from .union import LogicalUnionPlugin +from .minus import LogicalMinusPlugin from .values import LogicalValuesPlugin __all__ = [ @@ -16,6 +17,7 @@ LogicalSortPlugin, LogicalTableScanPlugin, LogicalUnionPlugin, + LogicalMinusPlugin, LogicalValuesPlugin, SamplePlugin, ] diff --git a/dask_sql/physical/rel/logical/minus.py b/dask_sql/physical/rel/logical/minus.py new file mode 100644 index 000000000..f26dc6e1b --- /dev/null +++ b/dask_sql/physical/rel/logical/minus.py @@ -0,0 +1,71 @@ +import dask.dataframe as dd + +from dask_sql.physical.rex import RexConverter +from dask_sql.physical.rel.base import BaseRelPlugin +from dask_sql.datacontainer import DataContainer, ColumnContainer + + +class LogicalMinusPlugin(BaseRelPlugin): + """ + LogicalUnion is used on EXCEPT clauses. + It just concatonates the two data frames. 
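+    More precisely, EXCEPT keeps only those rows of the first input
+    that have no match in the second one. Schematically (this mirrors
+    the merge/indicator logic in convert below; the variable names are
+    illustrative only)::
+
+        merged = first_df.merge(second_df, how="left", indicator=True)
+        result = merged[merged["_merge"] == "left_only"].drop(columns="_merge")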
+ """ + + class_name = "org.apache.calcite.rel.logical.LogicalMinus" + + def convert( + self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" + ) -> DataContainer: + first_dc, second_dc = self.assert_inputs(rel, 2, context) + + first_df = first_dc.df + first_cc = first_dc.column_container + + second_df = second_dc.df + second_cc = second_dc.column_container + + # For concatenating, they should have exactly the same fields + output_field_names = [str(x) for x in rel.getRowType().getFieldNames()] + assert len(first_cc.columns) == len(output_field_names) + first_cc = first_cc.rename( + columns={ + col: output_col + for col, output_col in zip(first_cc.columns, output_field_names) + } + ) + first_dc = DataContainer(first_df, first_cc) + + assert len(second_cc.columns) == len(output_field_names) + second_cc = second_cc.rename( + columns={ + col: output_col + for col, output_col in zip(second_cc.columns, output_field_names) + } + ) + second_dc = DataContainer(second_df, second_cc) + + # To concat the to dataframes, we need to make sure the + # columns actually have the specified names in the + # column containers + # Otherwise the concat won't work + first_df = first_dc.assign() + second_df = second_dc.assign() + + self.check_columns_from_row_type(first_df, rel.getExpectedInputRowType(0)) + self.check_columns_from_row_type(second_df, rel.getExpectedInputRowType(1)) + + df = first_df.merge( + second_df, + how='left', + indicator=True, + ) + + df = df[ + df.iloc[:, -1] == "left_only" + ].iloc[:, :-1] + + cc = ColumnContainer(df.columns) + cc = self.fix_column_to_row_type(cc, rel.getRowType()) + dc = DataContainer(df, cc) + dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + return dc diff --git a/dask_sql/utils.py b/dask_sql/utils.py index 51ad06f51..a3b8350e3 100644 --- a/dask_sql/utils.py +++ b/dask_sql/utils.py @@ -212,7 +212,7 @@ def get_table_from_compound_identifier( try: return context.tables[tableName] except KeyError: - raise AttributeError(f"Table {tableName} is not defined.") + raise AttributeError(f"Table '{tableName}' does not exist.") def convert_sql_kwargs( From d6c7e6aa40dd727edf741f13070acffa5a3710a3 Mon Sep 17 00:00:00 2001 From: Demian Wassermann Date: Sun, 7 Feb 2021 18:24:56 +0100 Subject: [PATCH 02/33] Adding some new optimisations --- .../com/dask/sql/application/RelationalAlgebraGenerator.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java index 21bd319b6..04eefc60b 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java @@ -28,9 +28,11 @@ import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule; import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; import org.apache.calcite.rel.rules.FilterAggregateTransposeRule; +import org.apache.calcite.rel.rules.FilterSetOpTransposeRule; import org.apache.calcite.rel.rules.FilterJoinRule; import org.apache.calcite.rel.rules.FilterMergeRule; import org.apache.calcite.rel.rules.FilterRemoveIsNotDistinctFromRule; +import org.apache.calcite.rel.rules.LoptOptimizeJoinRule; import org.apache.calcite.rel.rules.ProjectJoinTransposeRule; import org.apache.calcite.rel.rules.ProjectMergeRule; import org.apache.calcite.rel.rules.ProjectRemoveRule; @@ -135,9 +137,11 @@ private HepPlanner getHepPlanner(final 
FrameworkConfig config) { // Taken from blazingSQL final HepProgram program = new HepProgramBuilder() .addRuleInstance(AggregateExpandDistinctAggregatesRule.Config.JOIN.toRule()) + .addRuleInstance(FilterSetOpTransposeRule.Config.DEFAULT.toRule()) .addRuleInstance(FilterAggregateTransposeRule.Config.DEFAULT.toRule()) .addRuleInstance(FilterJoinRule.JoinConditionPushRule.Config.DEFAULT.toRule()) .addRuleInstance(FilterJoinRule.FilterIntoJoinRule.Config.DEFAULT.toRule()) + .addRuleInstance(LoptOptimizeJoinRule.Config.DEFAULT.toRule()) .addRuleInstance(ProjectMergeRule.Config.DEFAULT.toRule()) .addRuleInstance(FilterMergeRule.Config.DEFAULT.toRule()) .addRuleInstance(ProjectJoinTransposeRule.Config.DEFAULT.toRule()) From 9ac75bbf6447628d77611186bef8695a9443ba89 Mon Sep 17 00:00:00 2001 From: Demian Wassermann Date: Mon, 8 Feb 2021 22:51:54 +0100 Subject: [PATCH 03/33] Squashed commit of the following: commit e5fac1ab0f0b873b8b14e5cec021f5edd00ddff4 Author: Nils Braun Date: Sun Feb 7 16:20:55 2021 +0100 Aggregate improvements and SQL compatibility (#134) * A lot of refactoring the the groupby. Mainly to include both distinct and null-grouping * Test for non-dask aggregations * All NaN data needs to go into the same partition (otherwise we can not sort) * Fix compatibility with SQL on null-joins * Distinct is not needed, as it is optimized away from Calcite * Implement is not distinct * Describe new limitations and remove old ones * Added compatibility test from fugue * Added a test for sorting with multiple partitions and NaNs * Stylefix commit 7273c2d4155814185f21c605402f29dd85999ca5 Author: Nils Braun Date: Sun Feb 7 15:34:55 2021 +0100 Docs improvements (#132) * Fixed a bug in function references in docs * More details on the dask-sql internals commit bdc518ed4e290c3af417f2075a6abafd426824be Author: Nils Braun Date: Sun Feb 7 14:19:50 2021 +0100 Fix the fugue dependency (#133) --- .github/workflows/test.yml | 2 +- dask_sql/physical/rel/logical/aggregate.py | 332 +++++--- dask_sql/physical/rel/logical/join.py | 17 + dask_sql/physical/rel/logical/sort.py | 2 +- dask_sql/physical/rex/core/call.py | 17 + docs/pages/cmd.rst | 2 +- docs/pages/custom.rst | 4 +- docs/pages/data_input.rst | 12 +- docs/pages/how_does_it_work.rst | 118 ++- docs/pages/machine_learning.rst | 4 +- docs/pages/quickstart.rst | 2 +- docs/pages/server.rst | 4 +- docs/pages/sql.rst | 9 +- docs/pages/sql/ml.rst | 4 +- tests/integration/test_compatibility.py | 884 +++++++++++++++++++++ tests/integration/test_groupby.py | 15 +- tests/integration/test_sort.py | 23 + 17 files changed, 1303 insertions(+), 148 deletions(-) create mode 100644 tests/integration/test_compatibility.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b655081e0..4aac6b32f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -84,7 +84,7 @@ jobs: run: | # explicitly install docker, fugue and sqlalchemy package conda install sqlalchemy psycopg2 -c conda-forge - pip install docker fugue + pip install docker "fugue<=0.5.0" if: matrix.os == 'ubuntu-latest' - name: Install Java (again) and test with pytest shell: bash -l {0} diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index 3d1394876..8fa832918 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -1,9 +1,10 @@ import operator from collections import defaultdict from functools import reduce -from typing import Callable, Dict, List, Tuple, Union +from 
typing import Any, Callable, Dict, List, Tuple, Union import logging +import pandas as pd import dask.dataframe as dd from dask_sql.utils import new_temporary_column @@ -13,45 +14,50 @@ logger = logging.getLogger(__name__) -class GroupDatasetDescription: +class ReduceAggregation(dd.Aggregation): """ - Helper class to put dataframes which are filtered according to a specific column - into a dictionary. - Applying the same filter twice on the same dataframe does not give different - dataframes. Therefore we only hash these dataframes according to the column - they are filtered by. + A special form of an aggregation, that applies a given operation + on all elements in a group with "reduce". """ - def __init__(self, df: dd.DataFrame, filtered_column: str = ""): - self.df = df - self.filtered_column = filtered_column + def __init__(self, name: str, operation: Callable): + series_aggregate = lambda s: s.aggregate(lambda x: reduce(operation, x)) - def __eq__(self, rhs: "GroupDatasetDescription") -> bool: - """They are equal of they are filtered by the same column""" - return self.filtered_column == rhs.filtered_column + super().__init__(name, series_aggregate, series_aggregate) - def __hash__(self) -> str: - return hash(self.filtered_column) - def __repr__(self) -> str: - return f"GroupDatasetDescription({self.filtered_column})" +class AggregationOnPandas(dd.Aggregation): + """ + A special form of an aggregation, which does not apply the given function + (given as attribute name) directly to the dask groupby, but + via the groupby().apply() method. This is needed to call + functions directly on the pandas dataframes, but should be done + very carefully (as it is a performance bottleneck). + """ + def __init__(self, function_name: str): + def _f(s): + return s.apply(lambda s0: getattr(s0.dropna(), function_name)()) -# Description of an aggregation in the form of a mapping -# input column -> output column -> aggregation -AggregationDescription = Dict[str, Dict[str, Union[str, dd.Aggregation]]] + super().__init__(function_name, _f, _f) -class ReduceAggregation(dd.Aggregation): +class AggregationSpecification: """ - A special form of an aggregation, that applies a given operation - on all elements in a group with "reduce". + Most of the aggregations in SQL are already + implemented 1:1 in dask and can just be called via their name + (e.g. AVG is the mean). However sometimes those already + implemented functions only work well for numerical + functions. This small container class therefore + can have an additional aggregation function, which is + valid for non-numerical types. """ - def __init__(self, name: str, operation: Callable): - series_aggregate = lambda s: s.aggregate(lambda x: reduce(operation, x)) - - super().__init__(name, series_aggregate, series_aggregate) + def __init__(self, numerical_aggregation, non_numerical_aggregation=None): + self.numerical_aggregation = numerical_aggregation + self.non_numerical_aggregation = ( + non_numerical_aggregation or numerical_aggregation + ) class LogicalAggregatePlugin(BaseRelPlugin): @@ -63,31 +69,45 @@ class LogicalAggregatePlugin(BaseRelPlugin): group over, in the second case we "cheat" and add a 1-column to the dataframe, which allows us to reuse every aggregation function we already know of. + As NULLs are not groupable in dask, we handle them special + by adding a temporary column which is True for all NULL values + and False otherwise (and also group by it). The rest is just a lot of column-name-bookkeeping. 
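+    For the NULL handling, a rough sketch of the trick (the real code
+    lives in _perform_aggregation below; the names here are only
+    illustrative)::
+
+        is_not_null = ~df[group_column].isnull()   # the ~ makes NaN come first
+        without_nan = df[group_column].fillna(0)
+        df.groupby(by=[is_not_null, without_nan])
+
+    i.e. we group by the null-indicator and the NaN-free column together.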
Fortunately calcite will already make sure, that each aggregation function will only every be called with a single input column (by splitting the inner calculation to a step before). + + Open TODO: So far we are following the dask default + to only have a single partition after the group by (which is usual + a reasonable assumption). It would be nice to control + these things via HINTs. """ class_name = "org.apache.calcite.rel.logical.LogicalAggregate" AGGREGATION_MAPPING = { - "$sum0": "sum", - "any_value": dd.Aggregation( - "any_value", - lambda s: s.sample(n=1).values, - lambda s0: s0.sample(n=1).values, + "$sum0": AggregationSpecification("sum", AggregationOnPandas("sum")), + "any_value": AggregationSpecification( + dd.Aggregation( + "any_value", + lambda s: s.sample(n=1).values, + lambda s0: s0.sample(n=1).values, + ) + ), + "avg": AggregationSpecification("mean", AggregationOnPandas("mean")), + "bit_and": AggregationSpecification( + ReduceAggregation("bit_and", operator.and_) ), - "avg": "mean", - "bit_and": ReduceAggregation("bit_and", operator.and_), - "bit_or": ReduceAggregation("bit_or", operator.or_), - "bit_xor": ReduceAggregation("bit_xor", operator.xor), - "count": "count", - "every": dd.Aggregation("every", lambda s: s.all(), lambda s0: s0.all()), - "max": "max", - "min": "min", - "single_value": "first", + "bit_or": AggregationSpecification(ReduceAggregation("bit_or", operator.or_)), + "bit_xor": AggregationSpecification(ReduceAggregation("bit_xor", operator.xor)), + "count": AggregationSpecification("count"), + "every": AggregationSpecification( + dd.Aggregation("every", lambda s: s.all(), lambda s0: s0.all()) + ), + "max": AggregationSpecification("max", AggregationOnPandas("max")), + "min": AggregationSpecification("min", AggregationOnPandas("min")), + "single_value": AggregationSpecification("first"), } def convert( @@ -110,65 +130,22 @@ def convert( cc.get_backend_by_frontend_index(i) for i in group_column_indices ] - # Always keep an additional column around for empty groups and aggregates - additional_column_name = new_temporary_column(df) - - # NOTE: it might be the case that - # we do not need this additional - # column, but hopefully adding a single - # column of 1 is not so problematic... - df = df.assign(**{additional_column_name: 1}) - cc = cc.add(additional_column_name) dc = DataContainer(df, cc) - # Collect all aggregates - filtered_aggregations, output_column_order = self._collect_aggregations( - rel, dc, group_columns, additional_column_name, context - ) - if not group_columns: # There was actually no GROUP BY specified in the SQL # Still, this plan can also be used if we need to aggregate something over the full # data sample # To reuse the code, we just create a new column at the end with a single value - # It is important to do this after creating the aggregations, - # as we do not want this additional column to be used anywhere - group_columns = [additional_column_name] - logger.debug("Performing full-table aggregation") - # Now we can perform the aggregates - # We iterate through all pairs of (possible pre-filtered) - # dataframes and the aggregations to perform in this data... - df_agg = None - for filtered_df_desc, aggregation in filtered_aggregations.items(): - filtered_column = filtered_df_desc.filtered_column - if filtered_column: - logger.debug( - f"Aggregating {dict(aggregation)} on the data filtered by {filtered_column}" - ) - else: - logger.debug(f"Aggregating {dict(aggregation)} on the data") - - # ... we perform the aggregations ... 
- filtered_df = filtered_df_desc.df - # TODO: we could use the type information for - # pre-calculating the meta information - filtered_df_agg = filtered_df.groupby(by=group_columns).agg(aggregation) - - # ... fix the column names to a single level ... - filtered_df_agg.columns = filtered_df_agg.columns.get_level_values(-1) - - # ... and finally concat the new data with the already present columns - if df_agg is None: - df_agg = filtered_df_agg - else: - df_agg = df_agg.assign( - **{col: filtered_df_agg[col] for col in filtered_df_agg.columns} - ) + # Do all aggregates + df_result, output_column_order = self._do_aggregations( + rel, dc, group_columns, context, + ) # SQL does not care about the index, but we do not want to have any multiindices - df_agg = df_agg.reset_index(drop=True) + df_agg = df_result.reset_index(drop=True) # Fix the column names and the order of them, as this was messed with during the aggregations df_agg.columns = df_agg.columns.get_level_values(-1) @@ -179,48 +156,103 @@ def convert( dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) return dc - def _collect_aggregations( + def _do_aggregations( self, rel: "org.apache.calcite.rel.RelNode", dc: DataContainer, group_columns: List[str], - additional_column_name: str, context: "dask_sql.Context", - ) -> Tuple[ - Dict[GroupDatasetDescription, AggregationDescription], List[int], - ]: + ) -> Tuple[dd.DataFrame, List[str]]: """ - Create a mapping of dataframe -> aggregations (in the form input colum, output column, aggregation) - and the expected order of output columns. + Main functionality: return the result dataframe + and the output column order """ - aggregations = defaultdict(lambda: defaultdict(dict)) - output_column_order = [] df = dc.df cc = dc.column_container - # SQL needs to copy the old content also. As the values of the group columns + # We might need it later. + # If not, lets hope that adding a single column should not + # be a huge problem... + additional_column_name = new_temporary_column(df) + df = df.assign(**{additional_column_name: 1}) + + # Add an entry for every grouped column, as SQL wants them first + output_column_order = group_columns.copy() + + # Collect all aggregations we need to do + collected_aggregations, output_column_order = self._collect_aggregations( + rel, df, cc, context, additional_column_name, output_column_order + ) + + # SQL needs to have a column with the grouped values as the first + # output column. + # As the values of the group columns # are the same for a single group anyways, we just use the first row for col in group_columns: - aggregations[GroupDatasetDescription(df)][col][col] = "first" - output_column_order.append(col) + collected_aggregations[None].append((col, col, "first")) + + # Now we can go ahead and use these grouped aggregations + # to perform the actual aggregation + # It is very important to start with the non-filtered entry. + # Otherwise we might loose some entries in the grouped columns + key = None + aggregations = collected_aggregations.pop(key) + df_result = self._perform_aggregation( + df, None, aggregations, additional_column_name, group_columns, + ) + + # Now we can also the the rest + for filter_column, aggregations in collected_aggregations.items(): + agg_result = self._perform_aggregation( + df, filter_column, aggregations, additional_column_name, group_columns, + ) + + # ... 
and finally concat the new data with the already present columns + df_result = df_result.assign( + **{col: agg_result[col] for col in agg_result.columns} + ) + + return df_result, output_column_order + + def _collect_aggregations( + self, + rel: "org.apache.calcite.rel.RelNode", + df: dd.DataFrame, + cc: ColumnContainer, + context: "dask_sql.Context", + additional_column_name: str, + output_column_order: List[str], + ) -> Tuple[Dict[Tuple[str, str], List[Tuple[str, str, Any]]], List[str]]: + """ + Collect all aggregations together, which have the same filter column + so that the aggregations only need to be done once. + + Returns the aggregations as mapping filter_column -> List of Aggregations + where the aggregations are in the form (input_col, output_col, aggregation function (or string)) + """ + collected_aggregations = defaultdict(list) - # Now collect all aggregations for agg_call in rel.getNamedAggCalls(): - output_col = str(agg_call.getValue()) expr = agg_call.getKey() - if expr.hasFilter(): - filter_column = cc.get_backend_by_frontend_index(expr.filterArg) - filter_expression = df[filter_column] - filtered_df = df[filter_expression] - - grouped_df = GroupDatasetDescription(filtered_df, filter_column) + # Find out about the input column + inputs = expr.getArgList() + if len(inputs) == 1: + input_col = cc.get_backend_by_frontend_index(inputs[0]) + elif len(inputs) == 0: + input_col = additional_column_name else: - grouped_df = GroupDatasetDescription(df) + raise NotImplementedError("Can not cope with more than one input") - if expr.isDistinct(): - raise NotImplementedError("DISTINCT is not implemented (yet)") + # Extract flags (filtering/distinct) + if expr.isDistinct(): # pragma: no cover + raise ValueError("Apache Calcite should optimize them away!") + filter_column = None + if expr.hasFilter(): + filter_column = cc.get_backend_by_frontend_index(expr.filterArg) + + # Find out which aggregation function to use aggregation_name = str(expr.getAggregation().getName()) aggregation_name = aggregation_name.lower() try: @@ -232,16 +264,74 @@ def _collect_aggregations( raise NotImplementedError( f"Aggregation function {aggregation_name} not implemented (yet)." 
) + if isinstance(aggregation_function, AggregationSpecification): + dtype = df[input_col].dtype + if pd.api.types.is_numeric_dtype(dtype): + aggregation_function = aggregation_function.numerical_aggregation + else: + aggregation_function = ( + aggregation_function.non_numerical_aggregation + ) - inputs = expr.getArgList() - if len(inputs) == 1: - input_col = cc.get_backend_by_frontend_index(inputs[0]) - elif len(inputs) == 0: - input_col = additional_column_name - else: - raise NotImplementedError("Can not cope with more than one input") + # Finally, extract the output column name + output_col = str(agg_call.getValue()) - aggregations[grouped_df][input_col][output_col] = aggregation_function + # Store the aggregation + key = filter_column + value = (input_col, output_col, aggregation_function) + collected_aggregations[key].append(value) output_column_order.append(output_col) - return aggregations, output_column_order + return collected_aggregations, output_column_order + + def _perform_aggregation( + self, + df: dd.DataFrame, + filter_column: str, + aggregations: List[Tuple[str, str, Any]], + additional_column_name: str, + group_columns: List[str], + ): + tmp_df = df + + if filter_column: + filter_expression = tmp_df[filter_column] + tmp_df = tmp_df[filter_expression] + + logger.debug(f"Filtered by {filter_column} before aggregation.") + + # SQL and dask are treating null columns a bit different: + # SQL will put them to the front, dask will just ignore them + # Therefore we use the same trick as fugue does: + # we will group by both the NaN and the real column value + group_columns_and_nulls = [] + for group_column in group_columns: + # the ~ makes NaN come first + is_null_column = ~(tmp_df[group_column].isnull()) + non_nan_group_column = tmp_df[group_column].fillna(0) + + group_columns_and_nulls += [is_null_column, non_nan_group_column] + + if not group_columns_and_nulls: + # This can happen in statements like + # SELECT SUM(x) FROM data + # without any groupby statement + group_columns_and_nulls = [additional_column_name] + + grouped_df = tmp_df.groupby(by=group_columns_and_nulls) + + # Convert into the correct format for dask + aggregations_dict = defaultdict(dict) + for aggregation in aggregations: + input_col, output_col, aggregation_f = aggregation + + aggregations_dict[input_col][output_col] = aggregation_f + + # Now apply the aggregation + logger.debug(f"Performing aggregation {dict(aggregations_dict)}") + agg_result = grouped_df.agg(aggregations_dict) + + # ... fix the column names to a single level ... + agg_result.columns = agg_result.columns.get_level_values(-1) + + return agg_result diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index 8c2dd8e0f..093845304 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -100,6 +100,23 @@ def convert( f"common_{i}": df_rhs_renamed.iloc[:, index] for i, index in enumerate(rhs_on) } + + # SQL compatibility: when joining on columns that + # contain NULLs, pandas will actually happily + # keep those NULLs. 
That is however not compatible with + # SQL, so we get rid of them here + if join_type in ["inner", "right"]: + df_lhs_filter = reduce( + operator.and_, + [~df_lhs_renamed.iloc[:, index].isna() for index in lhs_on], + ) + df_lhs_renamed = df_lhs_renamed[df_lhs_filter] + if join_type in ["inner", "left"]: + df_rhs_filter = reduce( + operator.and_, + [~df_rhs_renamed.iloc[:, index].isna() for index in rhs_on], + ) + df_rhs_renamed = df_rhs_renamed[df_rhs_filter] else: # We are in the complex join case # where we have no column to merge on diff --git a/dask_sql/physical/rel/logical/sort.py b/dask_sql/physical/rel/logical/sort.py index 3038e445e..4c4a22ad3 100644 --- a/dask_sql/physical/rel/logical/sort.py +++ b/dask_sql/physical/rel/logical/sort.py @@ -135,7 +135,7 @@ def _sort_first_column( col = df[first_sort_column] is_na = col.isna().persist() if is_na.any().compute(): - df_is_na = df[is_na].reset_index(drop=True) + df_is_na = df[is_na].reset_index(drop=True).repartition(1) df_not_is_na = ( df[~is_na] .set_index(first_sort_column, drop=False) diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index ffd82bd57..8d5666a2f 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -240,6 +240,21 @@ def null(self, df: SeriesOrScalar,) -> SeriesOrScalar: return pd.isna(df) or df is None or np.isnan(df) +class IsNotDistinctOperation(Operation): + """The is not distinct operator""" + + def __init__(self): + super().__init__(self.not_distinct) + + def not_distinct(self, lhs: SeriesOrScalar, rhs: SeriesOrScalar) -> SeriesOrScalar: + """ + Returns true where `lhs` is not distinct from `rhs` (or both are null). + """ + is_null = IsNullOperation() + + return (is_null(lhs) & is_null(rhs)) | (lhs == rhs) + + class RegexOperation(Operation): """An abstract regex operation, which transforms the SQL regex into something python can understand""" @@ -627,6 +642,8 @@ class RexCallPlugin(BaseRexPlugin): "-": ReduceOperation(operation=operator.sub, unary_operation=lambda x: -x), "/": ReduceOperation(operation=SQLDivisionOperator()), "*": ReduceOperation(operation=operator.mul), + "is distinct from": NotOperation().of(IsNotDistinctOperation()), + "is not distinct from": IsNotDistinctOperation(), # special operations "cast": lambda x: x, "case": CaseOperation(), diff --git a/docs/pages/cmd.rst b/docs/pages/cmd.rst index 8d3926632..0dcac519f 100644 --- a/docs/pages/cmd.rst +++ b/docs/pages/cmd.rst @@ -21,7 +21,7 @@ or by running these lines of code cmd_loop() Some options can be set, e.g. to preload some testdata. -Have a look into :func:`dask_sql.cmd_loop` or call +Have a look into :func:`~dask_sql.cmd_loop` or call .. code-block:: bash diff --git a/docs/pages/custom.rst b/docs/pages/custom.rst index c0e3d0876..3f2b73527 100644 --- a/docs/pages/custom.rst +++ b/docs/pages/custom.rst @@ -11,7 +11,7 @@ Scalar Functions ---------------- A scalar function (such as :math:`x \to x^2`) turns a given column into another column of the same length. -It can be registered for usage in SQL with the :func:`dask_sql.Context.register_function` method. +It can be registered for usage in SQL with the :func:`~dask_sql.Context.register_function` method. Example: @@ -38,7 +38,7 @@ Aggregation Functions Aggregation functions run on a single column and turn them into a single value. This means they can only be used in ``GROUP BY`` aggregations. -They can be registered with the :func:`dask_sql.Context.register_aggregation` method. 
+They can be registered with the :func:`~dask_sql.Context.register_aggregation` method. This time however, an instance of a :class:`dask.dataframe.Aggregation` needs to be passed instead of a plain function. More information on dask aggregations can be found in the diff --git a/docs/pages/data_input.rst b/docs/pages/data_input.rst index 9a924f215..9b3413c24 100644 --- a/docs/pages/data_input.rst +++ b/docs/pages/data_input.rst @@ -3,14 +3,14 @@ Data Loading and Input ====================== -Before data can be queried with ``dask-sql``, it needs to be loaded into the dask cluster (or local instance) and registered with the :class:`dask_sql.Context`. +Before data can be queried with ``dask-sql``, it needs to be loaded into the dask cluster (or local instance) and registered with the :class:`~dask_sql.Context`. For this, ``dask-sql`` uses the wide field of possible `input formats `_ of ``dask``, plus some additional formats only suitable for `dask-sql`. You have multiple possibilities to load input data in ``dask-sql``: 1. Load it via python ------------------------------- -You can either use already created dask dataframes or create one by using the :func:`create_table` function. +You can either use already created dask dataframes or create one by using the :func:`~dask_sql.Context.create_table` function. Chances are high, there exists already a function to load your favorite format or location (e.g. s3 or hdfs). See below for all formats understood by ``dask-sql``. Make sure to install required libraries both on the driver and worker machines. @@ -58,7 +58,7 @@ In ``dask``, you can publish datasets with names into the cluster memory. This allows to reuse the same data from multiple clients/users in multiple sessions. For example, you can publish your data using the ``client.publish_dataset`` function of the ``distributed.Client``, -and then later register it in the :class:`dask_sql.Context` via SQL: +and then later register it in the :class:`~dask_sql.Context` via SQL: .. code-block:: python @@ -93,7 +93,7 @@ Input Formats * All formats and locations mentioned in `the Dask docu `_, including csv, parquet, json. Just pass in the location as string (and possibly the format, e.g. "csv" if it is not clear from the file extension). The data can be from local disc or many remote locations (S3, hdfs, Azure Filesystem, http, Google Filesystem, ...) - just prefix the path with the matching protocol. - Additional arguments passed to :func:`create_table` or ``CREATE TABLE`` are given to the ``read_`` calls. + Additional arguments passed to :func:`~dask_sql.Context.create_table` or ``CREATE TABLE`` are given to the ``read_`` calls. Example: @@ -113,7 +113,7 @@ Input Formats ) * If your data is already in Pandas (or Dask) DataFrames format, you can just use it as it is via the Python API - by giving it to :ref:`create_table` directly. + by giving it to :func:`~dask_sql.Context.create_table` directly. * You can connect ``dask-sql`` to an `intake `_ catalog and use the data registered there. Assuming you have an intake catalog stored in "catalog.yaml" (can also be the URL of an intake server), you can read in a stored table "data_table" either via Python @@ -161,7 +161,7 @@ Input Formats c.create_table("my_data", cursor, hive_table_name="the_name_in_hive") Again, ``hive_table_name`` is optional and defaults to the table name in ``dask-sql``. - You can also control the database used in Hive via the ``hive_schema_name```parameter. 
+ You can also control the database used in Hive via the ``hive_schema_name`` parameter. Additional arguments are pushed to the internally called ``read_`` functions. .. note:: diff --git a/docs/pages/how_does_it_work.rst b/docs/pages/how_does_it_work.rst index 568605573..61f6442a7 100644 --- a/docs/pages/how_does_it_work.rst +++ b/docs/pages/how_does_it_work.rst @@ -7,8 +7,116 @@ At the core, ``dask-sql`` does two things: which is specified as a tree of java objects - similar to many other SQL engines (Hive, Flink, ...) - convert this description of the query from java objects into dask API calls (and execute them) - returning a dask dataframe. -For the first step, Apache Calcite needs to know about the columns and types of the dask dataframes, -therefore some java classes to store this information for dask dataframes are defined in ``planner``. -After the translation to a relational algebra is done (using ``RelationalAlgebraGenerator.getRelationalAlgebra``), -the python methods defined in ``dask_sql.physical`` turn this into a physical dask execution plan by converting -each piece of the relational algebra one-by-one. +Th following example explains this in quite some technical details. +For most of the users, this level of technical understanding is not needed. + +1. SQL enters the library +------------------------- + +No matter of via the Python API (:ref:`api`), the command line client (:ref:`cmd`) or the server (:ref:`server`), eventually the SQL statement by the user will end up as a string in the function :func:`~dask_sql.Context.sql`. + +2. SQL is parsed +---------------- + +This function will first give the SQL string to the implemented Java classes (especially :class:`RelationalAlgebraGenerator`) via the ``jpype`` library. +Inside this class, Apache Calcite is used to first parse the SQL string and then turn it into a relational algebra. +For this, Apache Calcite uses the SQL language description specified in the Calcite library itself and the additional definitions in the ``.ftl```files in the ``dask-sql`` repository. +They specify custom language features, such as the ``CREATE MODEL`` statement. + +.. note:: + + ``.ftl`` stands for FreeMarker Template Language and is one of the standard templating languages used in the Java ecosystem. + Each of the "functions" defined in the documents defines a part of the (extended) SQL language in ``javacc`` format. + FreeMarker is used to combine these parser definitions with the ones from Apache Calcite. Have a look into the ``config.fmpp`` file for more information. + + For example the following ``javacc`` code + + .. code-block:: + + SqlNode SqlShowTables() : + { + final Span s; + final SqlIdentifier schema; + } + { + { s = span(); } + schema = CompoundIdentifier() + { + return new SqlShowTables(s.end(this), schema); + } + } + + describes a parser line, which understands SQL statements such as + + .. code-block:: sql + + SHOW TABLES FROM "schema" + + While parsing the SQL, they are turned into an instance of the Java class :class:`SqlShowTables` (which is also defined in this project). + The :class:`Span` is used internally in Apache Calcite to store the position in the parsed SQL statement (e.g. for better error output). + The ``SqlShowTables`` javacc function (not the Java class SqlShowTables) is listed in ``config.fmpp`` as a ``statementParserMethods``, which makes it parsable as main SQL statement (similar to any normal ``SELECT ...`` statement). 
+ All Java classes used as parser return values inherit from the Calcite class :class:`SqlNode` or any derived subclass (if it makes sense). Those classes are barely containers to store the information from the parsed SQL statements (such as the schema name in the example above) and do not have any business logic by themselves. + +3. SQL is (maybe) optimized +--------------------------- + +Once the SQL string is parsed into an instance of a :class:`SqlNode` (or a subclass of it), Apache Calcite can convert it into a relational algebra and optimize it. As this is only implemented for Calcite-own classes (and not for the custom classes such as :class:`SqlCreateModel`) this conversion and optimization is not triggered for all SQL statements (have a look into :func:`Context._get_ral`). + +After optimization, the resulting Java instance will be a class of any of the :class:`Logical*` classes in Apache Calcite (such as :class:`LogicalJoin`). Each of those can contain other instances as "inputs" creating a tree of different steps in the SQL statement (see below for an example). + +So after all, the result is either an optimized tree of steps in the relational algebra (represented by instances of the :class:`Logical*` classes) or an instance of a :class:`SqlNode` (sub)class. + +4. Translation to Dask API calls +-------------------------------- + +Depending on which type the resulting java class has, they are converted into calls to python functions using different python "converters". For each Java class, there exist a converter class in the ``dask_sql.physical.rel`` folder, which are registered at the :class:`dask_sql.physical.rel.convert.RelConverter` class. +Their job is to use the information stored in the java class instances and turn it into calls to python functions (see the example below for more information). + +As many SQL statements contain calculations using literals and/or columns, these are split into their own functionality (``dask_sql.physical.rex``) following a similar plugin-based converter system. +Have a look into the specific classes to understand how the conversion of a specific SQL language feature is implemented. + +5. Result +--------- + +The result of each of the conversions is a :class:`dask.DataFrame`, which is given to the user. In case of the command line tool or the SQL server, it is evaluated immediately - otherwise it can be used for further calculations by the user. + +Example +------- + +Let's walk through the steps above using the example SQL statement + +.. code-block:: sql + + SELECT x + y FROM timeseries WHERE x > 0 + +assuming the table "timeseries" is already registered. +If you want to follow along with the steps outlined in the following, start the command line tool in debug mode + +.. code-block:: bash + + dask-sql --load-test-data --startup --log-level DEBUG + +and enter the SQL statement above. + +First, the SQL is parsed by Apache Calcite and (as it is not a custom statement) transformed into a tree of relational algebra objects. + +.. code-block:: none + + LogicalProject(EXPR$0=[+($3, $4)]) + LogicalFilter(condition=[>($3, 0)]) + LogicalTableScan(table=[[schema, timeseries]]) + +The tree output above means, that the outer instance (:class:`LogicalProject`) needs as input the output of the previous instance (:class:`LogicalFilter`) etc. + +Therefore the conversion to python API calls is called recursively (depth-first). First, the :class:`LogicalTableScan` is converted using the :class:`rel.logical.table_scan.LogicalTableScanPlugin` plugin. 
It will just get the correct :class:`dask.DataFrame` from the dictionary of already registered tables of the context. +Next, the :class:`LogicalFilter` (having the dataframe as input), is converted via the :class:`rel.logical.filter.LogicalFilterPlugin`. +The filter expression ``>($3, 0)`` is converted into ``df["x"] > 0`` using a combination of REX plugins (have a look into the debug output to learn more) and applied to the dataframe. +The resulting dataframe is then passed to the converter :class:`rel.logical.project.LogicalProjectPlugin` for the :class:`LogicalProject`. +This will calculate the expression ``df["x"] + df["y"]`` (after having converted it via the class:`RexCallPlugin` plugin) and return the final result to the user. + +.. code-block:: python + + df_table_scan = context.tables["timeseries"] + df_filter = df_table_scan[df_table_scan["x"] > 0] + df_project = df_filter.assign(col=df_filter["x"] + df_filter["y"]) + return df_project[["col"]] \ No newline at end of file diff --git a/docs/pages/machine_learning.rst b/docs/pages/machine_learning.rst index a55412a0b..5abcb0f5f 100644 --- a/docs/pages/machine_learning.rst +++ b/docs/pages/machine_learning.rst @@ -19,7 +19,7 @@ Please also see :ref:`ml` for more information on the SQL statements used on thi ------------------------------------------------------------- If you are familiar with Python and the ML ecosystem in Python, this one is probably -the simplest possibility. You can use the :func:`Context.sql` call as described +the simplest possibility. You can use the :func:`~dask_sql.Context.sql` call as described before to extract the data for your training or ML prediction. The result will be a Dask dataframe, which you can either directly feed into your model or convert to a pandas dataframe with `.compute()` before. @@ -49,7 +49,7 @@ automatically. The syntax is similar to the `BigQuery Predict Syntax `_ or @@ -68,7 +68,7 @@ commands. Preregister your own data sources --------------------------------- -The python function :func:`dask_sql.run_server` accepts an already created :class:`dask_sql.Context`. +The python function :func:`~dask_sql.run_server` accepts an already created :class:`~dask_sql.Context`. This means you can preload your data sources and register them with a context before starting your server. By this, your server will already have data to query: diff --git a/docs/pages/sql.rst b/docs/pages/sql.rst index ace1297ab..c371084af 100644 --- a/docs/pages/sql.rst +++ b/docs/pages/sql.rst @@ -199,14 +199,16 @@ Limitatons ``dask-sql`` is still in early development, therefore exist some limitations: -* Not all operations and aggregations are implemented already, most prominently: ``WINDOW`` is not implemented so far. -* ``GROUP BY`` aggregations can not use ``DISTINCT`` +Not all operations and aggregations are implemented already, most prominently: ``WINDOW`` is not implemented so far. .. note:: Whenever you find a not already implemented operation, keyword or functionality, please raise an issue at our `issue tracker `_ with your use-case. +Dask/pandas and SQL treat null-values (or nan) differently on sorting, grouping and joining. +``dask-sql`` tries to follow the SQL standard as much as possible, so results might be different to what you expect from Dask/pandas. + Apart from those functional limitations, there is a operation which need special care: ``ORDER BY```. Normally, ``dask-sql`` calls create a ``dask`` data frame, which gets only computed when you call the ``.compute()`` member. 
Due to internal constraints, this is currently not the case for ``ORDER BY``. @@ -218,4 +220,5 @@ Including this operation will trigger a calculation of the full data frame alrea The data inside ``dask`` is partitioned, to distribute it over the cluster. ``head`` will only return the first N elements from the first partition - even if N is larger than the partition size. As a benefit, calling ``.head(N)`` is typically faster than calculating the full data sample with ``.compute()``. - ``LIMIT`` on the other hand will always return the first N elements - no matter on how many partitions they are scattered - but will also need to precalculate the first partition to find out, if it needs to have a look into all data or not. + ``LIMIT`` on the other hand will always return the first N elements - no matter on how many partitions they are scattered - + but will also need to precalculate the first partition to find out, if it needs to have a look into all data or not. diff --git a/docs/pages/sql/ml.rst b/docs/pages/sql/ml.rst index bab3837f5..9100145d6 100644 --- a/docs/pages/sql/ml.rst +++ b/docs/pages/sql/ml.rst @@ -13,7 +13,7 @@ As all SQL statements in ``dask-sql`` are eventually converted to Python calls, any custom Python function and library, e.g. Machine Learning libraries. Although it would be possible to register custom functions (see :ref:`custom`) for this and use them, it is much more convenient if this functionality is already included in the core SQL language. -These three statements help in training and using models. Every :class:`Context` has a registry for models, which +These three statements help in training and using models. Every :class:`~dask_sql.Context` has a registry for models, which can be used for training or prediction. For a full example, see :ref:`machine_learning`. @@ -128,7 +128,7 @@ Predict the target using the given model and dataframe from the ``SELECT`` query The return value is the input dataframe with an additional column named "target", which contains the predicted values. The model needs to be registered at the context before using it in this function, -either by calling :func:`Context.register_model` explicitly or by training +either by calling :func:`~dask_sql.Context.register_model` explicitly or by training a model using the ``CREATE MODEL`` SQL statement above. A model can be anything which has a ``predict`` function. diff --git a/tests/integration/test_compatibility.py b/tests/integration/test_compatibility.py new file mode 100644 index 000000000..b3f1dab1b --- /dev/null +++ b/tests/integration/test_compatibility.py @@ -0,0 +1,884 @@ +""" +The tests in this module are taken from +the fugue-sql module to test the compatibility +with their "understanding" of SQL +They run randomized tests and compare with sqlite. 
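+The typical pattern is to build a random input frame and run the very
+same query through both dask-sql and sqlite, for example::
+
+    df = make_rand_df(100, a=(int, 50), b=(str, 50))
+    eq_sqlite("SELECT DISTINCT b, a FROM a ORDER BY a NULLS LAST, b NULLS FIRST", a=df)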
+ +There are some changes compared to the fugueSQL +tests, especially when it comes to sort order: +dask-sql does not enforce a specific order after groupby +""" + +import sqlite3 +from datetime import datetime, timedelta + +import pandas as pd +import numpy as np +from pandas.testing import assert_frame_equal +from dask_sql import Context + + +def eq_sqlite(sql, **dfs): + c = Context() + engine = sqlite3.connect(":memory:") + + for name, df in dfs.items(): + c.create_table(name, df) + df.to_sql(name, engine, index=False) + + dask_result = c.sql(sql).compute().reset_index(drop=True) + sqlite_result = pd.read_sql(sql, engine).reset_index(drop=True) + + assert_frame_equal(dask_result, sqlite_result, check_dtype=False) + + +def make_rand_df(size: int, **kwargs): + np.random.seed(0) + data = {} + for k, v in kwargs.items(): + if not isinstance(v, tuple): + v = (v, 0.0) + dt, null_ct = v[0], v[1] + if dt is int: + s = np.random.randint(10, size=size) + elif dt is bool: + s = np.where(np.random.randint(2, size=size), True, False) + elif dt is float: + s = np.random.rand(size) + elif dt is str: + r = [f"ssssss{x}" for x in range(10)] + c = np.random.randint(10, size=size) + s = np.array([r[x] for x in c]) + elif dt is datetime: + rt = [datetime(2020, 1, 1) + timedelta(days=x) for x in range(10)] + c = np.random.randint(10, size=size) + s = np.array([rt[x] for x in c]) + else: + raise NotImplementedError + ps = pd.Series(s) + if null_ct > 0: + idx = np.random.choice(size, null_ct, replace=False).tolist() + ps[idx] = None + data[k] = ps + return pd.DataFrame(data) + + +def test_basic_select_from(): + df = make_rand_df(5, a=(int, 2), b=(str, 3), c=(float, 4)) + eq_sqlite("SELECT 1 AS a, 1.5 AS b, 'x' AS c") + eq_sqlite("SELECT 1+2 AS a, 1.5*3 AS b, 'x' AS c") + eq_sqlite("SELECT * FROM a", a=df) + eq_sqlite("SELECT * FROM a AS x", a=df) + eq_sqlite("SELECT b AS bb, a+1-2*3.0/4 AS cc, x.* FROM a AS x", a=df) + eq_sqlite("SELECT *, 1 AS x, 2.5 AS y, 'z' AS z FROM a AS x", a=df) + eq_sqlite("SELECT *, -(1.0+a)/3 AS x, +(2.5) AS y FROM a AS x", a=df) + + +def test_case_when(): + a = make_rand_df(100, a=(int, 20), b=(str, 30), c=(float, 40)) + eq_sqlite( + """ + SELECT a,b,c, + CASE + WHEN a<10 THEN a+3 + WHEN c<0.5 THEN a+5 + ELSE (1+2)*3 + a + END AS d + FROM a + """, + a=a, + ) + + +def test_drop_duplicates(): + # simplest + a = make_rand_df(100, a=int, b=int) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a + ORDER BY a NULLS LAST, b NULLS FIRST + """, + a=a, + ) + # mix of number and nan + a = make_rand_df(100, a=(int, 50), b=(int, 50)) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a + ORDER BY a NULLS LAST, b NULLS FIRST + """, + a=a, + ) + # mix of number and string and nulls + a = make_rand_df(100, a=(int, 50), b=(str, 50), c=float) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a + ORDER BY a NULLS LAST, b NULLS FIRST + """, + a=a, + ) + + +def test_order_by_no_limit(): + a = make_rand_df(100, a=(int, 50), b=(str, 50), c=float) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a + ORDER BY a NULLS LAST, b NULLS FIRST + """, + a=a, + ) + + +def test_order_by_limit(): + a = make_rand_df(100, a=(int, 50), b=(str, 50), c=float) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a LIMIT 0 + """, + a=a, + ) + eq_sqlite( + """ + SELECT DISTINCT b, a FROM a ORDER BY a NULLS FIRST, b NULLS FIRST LIMIT 2 + """, + a=a, + ) + eq_sqlite( + """ + SELECT b, a FROM a + ORDER BY a NULLS LAST, b NULLS FIRST LIMIT 10 + """, + a=a, + ) + + +def test_where(): + df = make_rand_df(100, a=(int, 30), b=(str, 30), c=(float, 
30)) + eq_sqlite("SELECT * FROM a WHERE TRUE OR TRUE", a=df) + eq_sqlite("SELECT * FROM a WHERE TRUE AND TRUE", a=df) + eq_sqlite("SELECT * FROM a WHERE FALSE OR FALSE", a=df) + eq_sqlite("SELECT * FROM a WHERE FALSE AND FALSE", a=df) + + eq_sqlite("SELECT * FROM a WHERE TRUE OR b<='ssssss8'", a=df) + eq_sqlite("SELECT * FROM a WHERE TRUE AND b<='ssssss8'", a=df) + eq_sqlite("SELECT * FROM a WHERE FALSE OR b<='ssssss8'", a=df) + eq_sqlite("SELECT * FROM a WHERE FALSE AND b<='ssssss8'", a=df) + eq_sqlite("SELECT * FROM a WHERE a=10 OR b<='ssssss8'", a=df) + eq_sqlite("SELECT * FROM a WHERE c IS NOT NULL OR (a<5 AND b IS NOT NULL)", a=df) + + df = make_rand_df(100, a=(float, 30), b=(float, 30), c=(float, 30)) + eq_sqlite("SELECT * FROM a WHERE a<0.5 AND b<0.5 AND c<0.5", a=df) + eq_sqlite("SELECT * FROM a WHERE a<0.5 OR b<0.5 AND c<0.5", a=df) + eq_sqlite("SELECT * FROM a WHERE a IS NULL OR (b<0.5 AND c<0.5)", a=df) + eq_sqlite("SELECT * FROM a WHERE a*b IS NULL OR (b*c<0.5 AND c*a<0.5)", a=df) + + +def test_in_between(): + df = make_rand_df(10, a=(int, 3), b=(str, 3)) + eq_sqlite("SELECT * FROM a WHERE a IN (2,4,6)", a=df) + eq_sqlite("SELECT * FROM a WHERE a BETWEEN 2 AND 4+1", a=df) + eq_sqlite("SELECT * FROM a WHERE a NOT IN (2,4,6) AND a IS NOT NULL", a=df) + eq_sqlite("SELECT * FROM a WHERE a NOT BETWEEN 2 AND 4+1 AND a IS NOT NULL", a=df) + + +def test_join_inner(): + a = make_rand_df(100, a=(int, 40), b=(str, 40), c=(float, 40)) + b = make_rand_df(80, d=(float, 10), a=(int, 10), b=(str, 10)) + eq_sqlite( + """ + SELECT + a.*, d, d*c AS x + FROM a + INNER JOIN b ON a.a=b.a AND a.b=b.b + ORDER BY a.a NULLS FIRST, a.b NULLS FIRST, a.c NULLS FIRST, d NULLS FIRST + """, + a=a, + b=b, + ) + + +def test_join_left(): + a = make_rand_df(100, a=(int, 40), b=(str, 40), c=(float, 40)) + b = make_rand_df(80, d=(float, 10), a=(int, 10), b=(str, 10)) + eq_sqlite( + """ + SELECT + a.*, d, d*c AS x + FROM a LEFT JOIN b ON a.a=b.a AND a.b=b.b + ORDER BY a.a NULLS FIRST, a.b NULLS FIRST, a.c NULLS FIRST, d NULLS FIRST + """, + a=a, + b=b, + ) + + +def test_join_cross(): + a = make_rand_df(10, a=(int, 4), b=(str, 4), c=(float, 4)) + b = make_rand_df(20, dd=(float, 1), aa=(int, 1), bb=(str, 1)) + eq_sqlite("SELECT * FROM a CROSS JOIN b", a=a, b=b) + + +def test_join_multi(): + a = make_rand_df(100, a=(int, 40), b=(str, 40), c=(float, 40)) + b = make_rand_df(80, d=(float, 10), a=(int, 10), b=(str, 10)) + c = make_rand_df(80, dd=(float, 10), a=(int, 10), b=(str, 10)) + eq_sqlite( + """ + SELECT a.*,d,dd FROM a + INNER JOIN b ON a.a=b.a AND a.b=b.b + INNER JOIN c ON a.a=c.a AND c.b=b.b + ORDER BY a.a NULLS FIRST, a.b NULLS FIRST, a.c NULLS FIRST, dd NULLS FIRST, d NULLS FIRST + """, + a=a, + b=b, + c=c, + ) + + +def test_agg_count_no_group_by(): + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + COUNT(a) AS c_a, + COUNT(DISTINCT a) AS cd_a, + COUNT(b) AS c_b, + COUNT(DISTINCT b) AS cd_b, + COUNT(c) AS c_c, + COUNT(DISTINCT c) AS cd_c, + COUNT(d) AS c_d, + COUNT(DISTINCT d) AS cd_d, + COUNT(e) AS c_e, + COUNT(DISTINCT a) AS cd_e + FROM a + """, + a=a, + ) + + +def test_agg_count(): + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + a, b, a+1 AS c, + COUNT(c) AS c_c, + COUNT(DISTINCT c) AS cd_c, + COUNT(d) AS c_d, + COUNT(DISTINCT d) AS cd_d, + COUNT(e) AS c_e, + COUNT(DISTINCT a) AS cd_e + FROM a GROUP BY a, b + """, + a=a, + ) + + +def 
test_agg_sum_avg_no_group_by(): + eq_sqlite( + """ + SELECT + SUM(a) AS sum_a, + AVG(a) AS avg_a + FROM a + """, + a=pd.DataFrame({"a": [float("nan")]}), + ) + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + SUM(a) AS sum_a, + AVG(a) AS avg_a, + SUM(c) AS sum_c, + AVG(c) AS avg_c, + SUM(e) AS sum_e, + AVG(e) AS avg_e, + SUM(a)+AVG(e) AS mix_1, + SUM(a+e) AS mix_2 + FROM a + """, + a=a, + ) + + +def test_agg_sum_avg(): + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + a,b, a+1 AS c, + SUM(c) AS sum_c, + AVG(c) AS avg_c, + SUM(e) AS sum_e, + AVG(e) AS avg_e, + SUM(a)+AVG(e) AS mix_1, + SUM(a+e) AS mix_2 + FROM a GROUP BY a,b + """, + a=a, + ) + + +def test_agg_min_max_no_group_by(): + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + MIN(a) AS min_a, + MAX(a) AS max_a, + MIN(b) AS min_b, + MAX(b) AS max_b, + MIN(c) AS min_c, + MAX(c) AS max_c, + MIN(d) AS min_d, + MAX(d) AS max_d, + MIN(e) AS min_e, + MAX(e) AS max_e, + MIN(a+e) AS mix_1, + MIN(a)+MIN(e) AS mix_2 + FROM a + """, + a=a, + ) + + +def test_agg_min_max(): + a = make_rand_df( + 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40) + ) + eq_sqlite( + """ + SELECT + a, b, a+1 AS c, + MIN(c) AS min_c, + MAX(c) AS max_c, + MIN(d) AS min_d, + MAX(d) AS max_d, + MIN(e) AS min_e, + MAX(e) AS max_e, + MIN(a+e) AS mix_1, + MIN(a)+MIN(e) AS mix_2 + FROM a GROUP BY a, b + """, + a=a, + ) + + +# TODO: Except not implemented so far +# def test_window_row_number(): +# a = make_rand_df(100, a=int, b=(float, 50)) +# eq_sqlite( +# """ +# SELECT *, +# ROW_NUMBER() OVER (ORDER BY a ASC, b DESC NULLS FIRST) AS a1, +# ROW_NUMBER() OVER (ORDER BY a ASC, b DESC NULLS LAST) AS a2, +# ROW_NUMBER() OVER (ORDER BY a ASC, b ASC NULLS FIRST) AS a3, +# ROW_NUMBER() OVER (ORDER BY a ASC, b ASC NULLS LAST) AS a4, +# ROW_NUMBER() OVER (PARTITION BY a ORDER BY a,b DESC) AS a5 +# FROM a +# """, +# a=a, +# ) + +# a = make_rand_df( +# 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=float +# ) +# eq_sqlite( +# """ +# SELECT *, +# ROW_NUMBER() OVER (ORDER BY a ASC, b DESC NULLS FIRST, e) AS a1, +# ROW_NUMBER() OVER (ORDER BY a ASC, b DESC NULLS LAST, e) AS a2, +# ROW_NUMBER() OVER (PARTITION BY a ORDER BY a,b DESC, e) AS a3, +# ROW_NUMBER() OVER (PARTITION BY a,c ORDER BY a,b DESC, e) AS a4 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_row_number_partition_by(): +# a = make_rand_df(100, a=int, b=(float, 50)) +# eq_sqlite( +# """ +# SELECT *, +# ROW_NUMBER() OVER (PARTITION BY a ORDER BY a,b DESC) AS a5 +# FROM a +# """, +# a=a, +# ) + +# a = make_rand_df( +# 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=float +# ) +# eq_sqlite( +# """ +# SELECT *, +# ROW_NUMBER() OVER (PARTITION BY a ORDER BY a,b DESC, e) AS a3, +# ROW_NUMBER() OVER (PARTITION BY a,c ORDER BY a,b DESC, e) AS a4 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_ranks(): +# a = make_rand_df(100, a=int, b=(float, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT *, +# RANK() OVER (PARTITION BY a ORDER BY b DESC NULLS FIRST, c) AS a1, +# DENSE_RANK() OVER (ORDER BY a ASC, b DESC NULLS LAST, c DESC) AS a2, +# PERCENT_RANK() OVER (ORDER BY a ASC, b ASC NULLS LAST, c) AS a4 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def 
test_window_ranks_partition_by(): +# a = make_rand_df(100, a=int, b=(float, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT *, +# RANK() OVER (PARTITION BY a ORDER BY b DESC NULLS FIRST, c) AS a1, +# DENSE_RANK() OVER +# (PARTITION BY a ORDER BY a ASC, b DESC NULLS LAST, c DESC) +# AS a2, +# PERCENT_RANK() OVER +# (PARTITION BY a ORDER BY a ASC, b ASC NULLS LAST, c) AS a4 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_lead_lag(): +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT +# LEAD(b,1) OVER (ORDER BY a) AS a1, +# LEAD(b,2,10) OVER (ORDER BY a) AS a2, +# LEAD(b,1) OVER (PARTITION BY c ORDER BY a) AS a3, +# LEAD(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS a5, + +# LAG(b,1) OVER (ORDER BY a) AS b1, +# LAG(b,2,10) OVER (ORDER BY a) AS b2, +# LAG(b,1) OVER (PARTITION BY c ORDER BY a) AS b3, +# LAG(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS b5 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_lead_lag_partition_by(): +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT +# LEAD(b,1,10) OVER (PARTITION BY c ORDER BY a) AS a3, +# LEAD(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS a5, + +# LAG(b,1) OVER (PARTITION BY c ORDER BY a) AS b3, +# LAG(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS b5 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_sum_avg(): +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# for func in ["SUM", "AVG"]: +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER () AS a1, +# {func}(b) OVER (PARTITION BY c) AS a2, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6 +# FROM a +# """, +# a=a, +# ) +# # >= 1.1.0 has bug on these agg function with groupby+rolloing +# # https://github.com/pandas-dev/pandas/issues/35557 +# if pd.__version__ < "1.1": +# # irregular windows +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) AS a7, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_sum_avg_partition_by(): +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# for func in ["SUM", "AVG"]: +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6 +# FROM a +# """, +# a=a, +# ) +# # 1.1.0 has bug on these agg function with groupby+rolloing +# # https://github.com/pandas-dev/pandas/issues/35557 +# if pd.__version__ < "1.1": +# # irregular windows +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY b ORDER 
BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) AS a7, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_min_max(): +# for func in ["MIN", "MAX"]: +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER () AS a1, +# {func}(b) OVER (PARTITION BY c) AS a2, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6 +# FROM a +# """, +# a=a, +# ) +# # < 1.1.0 has bugs on these agg function with rolloing (no group by) +# if pd.__version__ >= "1.1": +# # irregular windows +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6, +# {func}(b) OVER (ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) AS a7, +# {func}(b) OVER (ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8 +# FROM a +# """, +# a=a, +# ) +# # == 1.1.0 has bugs on these agg function with rolloing (with group by) +# # https://github.com/pandas-dev/pandas/issues/35557 +# # < 1.1.0 has bugs on nulls when rolling with forward looking +# if pd.__version__ < "1.1": +# b = make_rand_df(10, a=float, b=(int, 0), c=(str, 0)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6 +# FROM a +# """, +# a=b, +# ) + +# TODO: Except not implemented so far +# def test_window_min_max_partition_by(): +# for func in ["MIN", "MAX"]: +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY c) AS a2, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6 +# FROM a +# """, +# a=a, +# ) +# # >= 1.1.0 has bugs on these agg function with rolloing (with group by) +# # https://github.com/pandas-dev/pandas/issues/35557 +# # < 1.1.0 has bugs on nulls when rolling with forward looking +# if pd.__version__ < "1.1": +# b = make_rand_df(10, a=float, b=(int, 0), c=(str, 0)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6 +# FROM a +# """, +# a=b, +# ) + +# TODO: Except not implemented so far +# def test_window_count(): +# for func in ["COUNT"]: +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER () AS a1, +# {func}(b) OVER (PARTITION BY c) AS a2, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY 
b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6, + +# {func}(c) OVER () AS b1, +# {func}(c) OVER (PARTITION BY c) AS b2, +# {func}(c) OVER (PARTITION BY c,b) AS b3, +# {func}(c) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS b4, +# {func}(c) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS b5, +# {func}(c) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS b6 +# FROM a +# """, +# a=a, +# ) +# # < 1.1.0 has bugs on these agg function with rolloing (no group by) +# # == 1.1.0 has this bug +# # https://github.com/pandas-dev/pandas/issues/35579 +# if pd.__version__ >= "1.1": +# # irregular windows +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS a6, +# {func}(b) OVER (PARTITION BY c ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS a9, + +# {func}(c) OVER (ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS b6, +# {func}(c) OVER (PARTITION BY c ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS b9 +# FROM a +# """, +# a=a, +# ) + +# TODO: Except not implemented so far +# def test_window_count_partition_by(): +# for func in ["COUNT"]: +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY c) AS a2, +# {func}(b+a) OVER (PARTITION BY c,b) AS a3, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, +# {func}(b+a) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS a6, + +# {func}(c) OVER (PARTITION BY c) AS b2, +# {func}(c) OVER (PARTITION BY c,b) AS b3, +# {func}(c) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS b4, +# {func}(c) OVER (PARTITION BY b ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS b5, +# {func}(c) OVER (PARTITION BY b ORDER BY a +# ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# AS b6 +# FROM a +# """, +# a=a, +# ) +# # < 1.1.0 has bugs on these agg function with rolloing (no group by) +# # == 1.1.0 has this bug +# # https://github.com/pandas-dev/pandas/issues/35579 +# if pd.__version__ >= "1.1": +# # irregular windows +# eq_sqlite( +# f""" +# SELECT a,b, +# {func}(b) OVER (PARTITION BY c ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS a9, + +# {func}(c) OVER (PARTITION BY c ORDER BY a DESC +# ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS b9 +# FROM a +# """, +# a=a, +# ) + +# TODO: Windowing not implemented so far +# def test_nested_query(): +# a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT * FROM ( +# SELECT *, +# ROW_NUMBER() OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS r +# FROM a) +# WHERE r=1 +# """, +# a=a, +# ) + + +def test_union(): + a = make_rand_df(30, b=(int, 10), c=(str, 10)) + b = make_rand_df(80, b=(int, 50), c=(str, 50)) + c = make_rand_df(100, b=(int, 50), c=(str, 50)) + eq_sqlite( + """ + SELECT * FROM a + UNION SELECT * FROM b + UNION SELECT * FROM c + ORDER BY b NULLS FIRST, c NULLS FIRST + """, + a=a, + b=b, + c=c, + ) + eq_sqlite( + """ + SELECT * FROM a + UNION ALL SELECT * FROM b + UNION ALL SELECT * FROM c + ORDER BY b NULLS FIRST, c NULLS FIRST + """, + a=a, + b=b, + c=c, + ) + + +# 
TODO: Except not implemented so far +# def test_except(): +# a = make_rand_df(30, b=(int, 10), c=(str, 10)) +# b = make_rand_df(80, b=(int, 50), c=(str, 50)) +# c = make_rand_df(100, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT * FROM c +# EXCEPT SELECT * FROM b +# EXCEPT SELECT * FROM c +# """, +# a=a, +# b=b, +# c=c, +# ) + +# TODO: Intersect not implemented so far +# def test_intersect(): +# a = make_rand_df(30, b=(int, 10), c=(str, 10)) +# b = make_rand_df(80, b=(int, 50), c=(str, 50)) +# c = make_rand_df(100, b=(int, 50), c=(str, 50)) +# eq_sqlite( +# """ +# SELECT * FROM c +# INTERSECT SELECT * FROM b +# INTERSECT SELECT * FROM c +# """, +# a=a, +# b=b, +# c=c, +# ) + + +def test_with(): + a = make_rand_df(30, a=(int, 10), b=(str, 10)) + b = make_rand_df(80, ax=(int, 10), bx=(str, 10)) + eq_sqlite( + """ + WITH + aa AS ( + SELECT a AS aa, b AS bb FROM a + ), + c AS ( + SELECT aa-1 AS aa, bb FROM aa + ) + SELECT * FROM c UNION SELECT * FROM b + ORDER BY aa NULLS FIRST, bb NULLS FIRST + """, + a=a, + b=b, + ) + + +def test_integration_1(): + a = make_rand_df(100, a=int, b=str, c=float, d=int, e=bool, f=str, g=str, h=float) + eq_sqlite( + """ + WITH + a1 AS ( + SELECT a+1 AS a, b, c FROM a + ), + a2 AS ( + SELECT a,MAX(b) AS b_max, AVG(c) AS c_avg FROM a GROUP BY a + ), + a3 AS ( + SELECT d+2 AS d, f, g, h FROM a WHERE e + ) + SELECT a1.a,b,c,b_max,c_avg,f,g,h FROM a1 + INNER JOIN a2 ON a1.a=a2.a + LEFT JOIN a3 ON a1.a=a3.d + ORDER BY a1.a NULLS FIRST, b NULLS FIRST, c NULLS FIRST, f NULLS FIRST, g NULLS FIRST, h NULLS FIRST + """, + a=a, + ) diff --git a/tests/integration/test_groupby.py b/tests/integration/test_groupby.py index 76d35939d..d305d6b06 100644 --- a/tests/integration/test_groupby.py +++ b/tests/integration/test_groupby.py @@ -127,7 +127,7 @@ def test_group_by_nan(c): ) df = df.compute() - expected_df = pd.DataFrame({"c": [3, 1]}) + expected_df = pd.DataFrame({"c": [3, float("nan"), 1]}) # The dtype in pandas 1.0.5 and pandas 1.1.0 are different, so # we can not check here assert_frame_equal(df, expected_df, check_dtype=False) @@ -206,3 +206,16 @@ def test_aggregations(c): } ) assert_frame_equal(df.sort_values("user_id").reset_index(drop=True), expected_df) + + df = c.sql( + """ + SELECT + MAX(a) AS "max", + MIN(a) AS "min" + FROM string_table + """ + ) + df = df.compute() + + expected_df = pd.DataFrame({"max": ["a normal string"], "min": ["%_%"]}) + assert_frame_equal(df.reset_index(drop=True), expected_df) diff --git a/tests/integration/test_sort.py b/tests/integration/test_sort.py index 5d8c83807..ffec77a84 100644 --- a/tests/integration/test_sort.py +++ b/tests/integration/test_sort.py @@ -228,6 +228,29 @@ def test_sort_with_nan_more_columns(): ) +def test_sort_with_nan_many_partitions(): + c = Context() + df = pd.DataFrame({"a": [float("nan"), 1] * 30, "b": [1, 2, 3] * 20,}) + c.create_table("df", dd.from_pandas(df, npartitions=10)) + + df_result = ( + c.sql("SELECT * FROM df ORDER BY a NULLS FIRST, b ASC NULLS FIRST") + .compute() + .reset_index(drop=True) + ) + + assert_frame_equal( + df_result, + pd.DataFrame( + { + "a": [float("nan")] * 30 + [1] * 30, + "b": [1] * 10 + [2] * 10 + [3] * 10 + [1] * 10 + [2] * 10 + [3] * 10, + } + ), + check_names=False, + ) + + def test_sort_strings(c): string_table = pd.DataFrame({"a": ["zzhsd", "öfjdf", "baba"]}) c.create_table("string_table", string_table) From 8c5ec9ed9efe1948dd29f716015ea86caa94c28e Mon Sep 17 00:00:00 2001 From: Demian Wassermann Date: Mon, 8 Feb 2021 23:02:09 +0100 Subject: [PATCH 04/33] Test for 
except --- tests/integration/test_except.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tests/integration/test_except.py diff --git a/tests/integration/test_except.py b/tests/integration/test_except.py new file mode 100644 index 000000000..9fe98131a --- /dev/null +++ b/tests/integration/test_except.py @@ -0,0 +1,29 @@ +def test_except_empty(c, df): + result_df = c.sql( + """ + SELECT * FROM df + EXCEPT + SELECT * FROM df + """ + ) + result_df = result_df.compute() + assert len(result_df) == 0 + + +def test_except_non_empty(c, df): + result_df = c.sql( + """ + ( + SELECT 1 as "a" + UNION + SELECT 2 as "a" + UNION + SELECT 3 as "a" + ) + EXCEPT + SELECT 2 as "a" + """ + ) + result_df = result_df.compute() + assert result_df.columns == "a" + assert set(result_df["a"]) == set([1, 3]) From 3a93a1160c86a0c0df0a055afc4cec037349b954 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Tue, 9 Feb 2021 09:53:41 +0100 Subject: [PATCH 05/33] Added Intersect support --- dask_sql/context.py | 1 + dask_sql/physical/rel/logical/__init__.py | 1 + dask_sql/physical/rel/logical/intersect.py | 73 ++++++++++++++++++++++ 3 files changed, 75 insertions(+) create mode 100644 dask_sql/physical/rel/logical/intersect.py diff --git a/dask_sql/context.py b/dask_sql/context.py index 220790136..108595db0 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -86,6 +86,7 @@ def __init__(self): RelConverter.add_plugin_class(logical.LogicalSortPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalTableScanPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalUnionPlugin, replace=False) + RelConverter.add_plugin_class(logical.LogicalIntersectPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalMinusPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalValuesPlugin, replace=False) RelConverter.add_plugin_class(logical.SamplePlugin, replace=False) diff --git a/dask_sql/physical/rel/logical/__init__.py b/dask_sql/physical/rel/logical/__init__.py index 9d429e6ac..cffc2b17f 100644 --- a/dask_sql/physical/rel/logical/__init__.py +++ b/dask_sql/physical/rel/logical/__init__.py @@ -6,6 +6,7 @@ from .sort import LogicalSortPlugin from .table_scan import LogicalTableScanPlugin from .union import LogicalUnionPlugin +from .intersect import LogicalIntersectPlugin from .minus import LogicalMinusPlugin from .values import LogicalValuesPlugin diff --git a/dask_sql/physical/rel/logical/intersect.py b/dask_sql/physical/rel/logical/intersect.py new file mode 100644 index 000000000..8c5b6210b --- /dev/null +++ b/dask_sql/physical/rel/logical/intersect.py @@ -0,0 +1,73 @@ +import dask.dataframe as dd + +from dask_sql.physical.rex import RexConverter +from dask_sql.physical.rel.base import BaseRelPlugin +from dask_sql.datacontainer import DataContainer, ColumnContainer + + +class LogicalIntersectPlugin(BaseRelPlugin): + """ + LogicalIntersect is used on INTERSECT clauses. + It just concatonates the two data frames. 
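+    More precisely, both inputs are renamed to the common output schema and
+    then inner-merged, so only rows present in both inputs are kept.
+    A rough sketch of the final step (assuming two dask dataframes that
+    already share identical column names):
+
+        df = first_df.merge(second_df, how="inner")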
+ """ + + class_name = "org.apache.calcite.rel.logical.LogicalIntersect" + + def convert( + self, + rel: "org.apache.calcite.rel.RelNode", + context: "dask_sql.Context", + ) -> DataContainer: + first_dc, second_dc = self.assert_inputs(rel, 2, context) + + first_df = first_dc.df + first_cc = first_dc.column_container + + second_df = second_dc.df + second_cc = second_dc.column_container + + # For concatenating, they should have exactly the same fields + output_field_names = [str(x) for x in rel.getRowType().getFieldNames()] + assert len(first_cc.columns) == len(output_field_names) + first_cc = first_cc.rename( + columns={ + col: output_col + for col, output_col in zip( + first_cc.columns, output_field_names + ) + } + ) + first_dc = DataContainer(first_df, first_cc) + + assert len(second_cc.columns) == len(output_field_names) + second_cc = second_cc.rename( + columns={ + col: output_col + for col, output_col in zip( + second_cc.columns, output_field_names + ) + } + ) + second_dc = DataContainer(second_df, second_cc) + + # To concat the to dataframes, we need to make sure the + # columns actually have the specified names in the + # column containers + # Otherwise the concat won't work + first_df = first_dc.assign() + second_df = second_dc.assign() + + self.check_columns_from_row_type( + first_df, rel.getExpectedInputRowType(0) + ) + self.check_columns_from_row_type( + second_df, rel.getExpectedInputRowType(1) + ) + + df = first_df.merge(second_df, how="inner") + + cc = ColumnContainer(df.columns) + cc = self.fix_column_to_row_type(cc, rel.getRowType()) + dc = DataContainer(df, cc) + dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + return dc From 83ddf0527e8b1f6765bd5187eebaf2a941dffde3 Mon Sep 17 00:00:00 2001 From: Demian Wassermann Date: Tue, 9 Feb 2021 16:37:50 +0100 Subject: [PATCH 06/33] JoinToMultiJoinRule rule insterted to reorganiser joins --- .../com/dask/sql/application/RelationalAlgebraGenerator.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java index 04eefc60b..9cd486a57 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java @@ -32,6 +32,7 @@ import org.apache.calcite.rel.rules.FilterJoinRule; import org.apache.calcite.rel.rules.FilterMergeRule; import org.apache.calcite.rel.rules.FilterRemoveIsNotDistinctFromRule; +import org.apache.calcite.rel.rules.JoinToMultiJoinRule; import org.apache.calcite.rel.rules.LoptOptimizeJoinRule; import org.apache.calcite.rel.rules.ProjectJoinTransposeRule; import org.apache.calcite.rel.rules.ProjectMergeRule; @@ -141,6 +142,7 @@ private HepPlanner getHepPlanner(final FrameworkConfig config) { .addRuleInstance(FilterAggregateTransposeRule.Config.DEFAULT.toRule()) .addRuleInstance(FilterJoinRule.JoinConditionPushRule.Config.DEFAULT.toRule()) .addRuleInstance(FilterJoinRule.FilterIntoJoinRule.Config.DEFAULT.toRule()) + .addRuleInstance(JoinToMultiJoinRule.Config.DEFAULT.toRule()) .addRuleInstance(LoptOptimizeJoinRule.Config.DEFAULT.toRule()) .addRuleInstance(ProjectMergeRule.Config.DEFAULT.toRule()) .addRuleInstance(FilterMergeRule.Config.DEFAULT.toRule()) From 164099067a0d6bed8cc8aa6303048e39c754f632 Mon Sep 17 00:00:00 2001 From: Demian Wassermann Date: Tue, 9 Feb 2021 19:58:40 +0100 Subject: [PATCH 07/33] WIP --- dask_sql/physical/rel/convert.py | 5 ++++- 
.../dask/sql/application/RelationalAlgebraGenerator.java | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/dask_sql/physical/rel/convert.py b/dask_sql/physical/rel/convert.py index 77f8a3ef7..d7ef0a263 100644 --- a/dask_sql/physical/rel/convert.py +++ b/dask_sql/physical/rel/convert.py @@ -1,4 +1,5 @@ import logging +import time import dask.dataframe as dd @@ -53,6 +54,8 @@ def convert( logger.debug( f"Processing REL {rel} using {plugin_instance.__class__.__name__}..." ) + start_time = time.perf_counter() df = plugin_instance.convert(rel, context=context) - logger.debug(f"Processed REL {rel} into {LoggableDataFrame(df)}") + elapsed_time = time.perf_counter() - start_time + logger.debug(f"Processed REL {rel} into {LoggableDataFrame(df)} ({elapsed_time}s)") return df diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java index 9cd486a57..5ddbd4654 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java @@ -31,9 +31,11 @@ import org.apache.calcite.rel.rules.FilterSetOpTransposeRule; import org.apache.calcite.rel.rules.FilterJoinRule; import org.apache.calcite.rel.rules.FilterMergeRule; +import org.apache.calcite.rel.rules.FilterMultiJoinMergeRule; import org.apache.calcite.rel.rules.FilterRemoveIsNotDistinctFromRule; import org.apache.calcite.rel.rules.JoinToMultiJoinRule; import org.apache.calcite.rel.rules.LoptOptimizeJoinRule; +import org.apache.calcite.rel.rules.MultiJoinOptimizeBushyRule; import org.apache.calcite.rel.rules.ProjectJoinTransposeRule; import org.apache.calcite.rel.rules.ProjectMergeRule; import org.apache.calcite.rel.rules.ProjectRemoveRule; @@ -142,11 +144,13 @@ private HepPlanner getHepPlanner(final FrameworkConfig config) { .addRuleInstance(FilterAggregateTransposeRule.Config.DEFAULT.toRule()) .addRuleInstance(FilterJoinRule.JoinConditionPushRule.Config.DEFAULT.toRule()) .addRuleInstance(FilterJoinRule.FilterIntoJoinRule.Config.DEFAULT.toRule()) + .addRuleInstance(ProjectJoinTransposeRule.Config.DEFAULT.toRule()) .addRuleInstance(JoinToMultiJoinRule.Config.DEFAULT.toRule()) + .addRuleInstance(FilterMultiJoinMergeRule.Config.DEFAULT.toRule()) .addRuleInstance(LoptOptimizeJoinRule.Config.DEFAULT.toRule()) + .addRuleInstance(MultiJoinOptimizeBushyRule.Config.DEFAULT.toRule()) .addRuleInstance(ProjectMergeRule.Config.DEFAULT.toRule()) .addRuleInstance(FilterMergeRule.Config.DEFAULT.toRule()) - .addRuleInstance(ProjectJoinTransposeRule.Config.DEFAULT.toRule()) // In principle, not a bad idea. 
But we need to keep the most // outer project - because otherwise the column name information is lost // in cases such as SELECT x AS a, y AS B FROM df From ddf1551786c69c3bcf379cb9777dad17a4b589a4 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Thu, 11 Feb 2021 15:01:25 +0100 Subject: [PATCH 08/33] Adding missing import --- dask_sql/physical/rel/logical/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dask_sql/physical/rel/logical/__init__.py b/dask_sql/physical/rel/logical/__init__.py index cffc2b17f..eac53d05a 100644 --- a/dask_sql/physical/rel/logical/__init__.py +++ b/dask_sql/physical/rel/logical/__init__.py @@ -18,6 +18,7 @@ LogicalSortPlugin, LogicalTableScanPlugin, LogicalUnionPlugin, + LogicalIntersectPlugin, LogicalMinusPlugin, LogicalValuesPlugin, SamplePlugin, From 5c13175365395a7f4e24f1bf970744203d6658f2 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Fri, 12 Feb 2021 10:16:16 +0100 Subject: [PATCH 09/33] Rewriting assign method to use rename as much as possible --- dask_sql/datacontainer.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py index bddfbcad4..c471f9173 100644 --- a/dask_sql/datacontainer.py +++ b/dask_sql/datacontainer.py @@ -42,7 +42,8 @@ def _copy(self) -> ColumnContainer: Internal function to copy this container """ return ColumnContainer( - self._frontend_columns.copy(), self._frontend_backend_mapping.copy() + self._frontend_columns.copy(), + self._frontend_backend_mapping.copy(), ) def limit_to(self, fields: List[str]) -> ColumnContainer: @@ -137,7 +138,9 @@ def make_unique(self, prefix="col"): where is the column index. """ return self.rename( - columns={str(col): f"{prefix}_{i}" for i, col in enumerate(self.columns)} + columns={ + str(col): f"{prefix}_{i}" for i, col in enumerate(self.columns) + } ) @@ -166,11 +169,22 @@ def assign(self) -> dd.DataFrame: a dataframe which has the the columns specified in the stored ColumnContainer. """ - df = self.df.assign( - **{ - col_from: self.df[col_to] - for col_from, col_to in self.column_container.mapping() - if col_from in self.column_container.columns - } - ) + # We rename as many cols as possible because renaming is much more + # efficient than assigning. + + renames = {} + assigns = {} + for col_from, col_to in self.column_container.mapping(): + if col_from in self.column_container.columns: + if ( + len(renames) < len(self.df.columns) + and col_to not in renames + and (col_from not in self.df.columns or col_from == col_to) + ): + renames[col_to] = col_from + else: + assigns[col_from] = self.df[col_to] + df = self.df.rename(columns=renames) + if len(assigns) > 0: + df = df.assign(**assigns) return df[self.column_container.columns] From 391acbf7c433ca8e38021c3a1b00550c0ad4067f Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Fri, 12 Feb 2021 14:27:50 +0100 Subject: [PATCH 10/33] Rewrite inner-join to merge on already existing columns instead of adding temporary columns to merge on --- dask_sql/physical/rel/logical/join.py | 88 +++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index 093845304..f19a96a28 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -92,6 +92,11 @@ def convert( # We therefore create new columns on purpose, which have a distinct name. 
assert len(lhs_on) == len(rhs_on) if lhs_on: + if join_type == "inner": + return self._do_inner_join_inplace( + df_lhs_renamed, df_rhs_renamed, lhs_on, rhs_on, + join_type, rel, filter_condition, context + ) lhs_columns_to_add = { f"common_{i}": df_lhs_renamed.iloc[:, index] for i, index in enumerate(lhs_on) @@ -176,6 +181,89 @@ def convert( dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) return dc + def _do_inner_join_inplace( + self, df_lhs_renamed, df_rhs_renamed, lhs_on, rhs_on, + join_type, rel, filter_condition, context + ): + """ + Same method as above, but instead of adding temporary join columns to merge on, + we merge on columns already in the dataframe by renaming them. + """ + df_lhs_to_merge = df_lhs_renamed.rename(columns={ + df_lhs_renamed.columns[index]: f"common_{i}" + for i, index in enumerate(lhs_on) + }) + df_rhs_to_merge = df_rhs_renamed.rename(columns={ + df_rhs_renamed.columns[index]: f"common_{i}" + for i, index in enumerate(rhs_on) + }) + on_columns = [f"common_{i}" for i in range(len(lhs_on))] + + # SQL compatibility: when joining on columns that + # contain NULLs, pandas will actually happily + # keep those NULLs. That is however not compatible with + # SQL, so we get rid of them here + df_lhs_filter = reduce( + operator.and_, + [~df_lhs_to_merge.iloc[:, index].isna() for index in lhs_on], + ) + df_lhs_to_merge = df_lhs_to_merge[df_lhs_filter] + df_rhs_filter = reduce( + operator.and_, + [~df_rhs_to_merge.iloc[:, index].isna() for index in rhs_on], + ) + df_rhs_to_merge = df_rhs_to_merge[df_rhs_filter] + + df = dd.merge(df_lhs_to_merge, df_rhs_to_merge, on=on_columns, how=join_type) + + # 6. So the next step is to make sure + # we have the correct column order. + correct_column_order = list(df_lhs_renamed.columns) + list( + df_rhs_renamed.columns + ) + # We update the columns for the rhs to point to the resulting join columns + if lhs_on: + for i, on_column in enumerate(on_columns): + correct_column_order[lhs_on[i]] = on_column + correct_column_order[len(df_lhs_renamed.columns) + rhs_on[i]] = on_column + cc = ColumnContainer(df.columns).limit_to(correct_column_order) + + # and to rename them like the rel specifies + row_type = rel.getRowType() + field_specifications = [str(f) for f in row_type.getFieldNames()] + l_lhs = len(df_lhs_renamed.columns) + cc = cc.rename( + { + from_col: to_col + for from_col, to_col in zip(cc.columns[:l_lhs], field_specifications[:l_lhs]) + } + ) + cc = cc.rename( + { + from_col: to_col + for from_col, to_col in zip(cc.columns[l_lhs:], field_specifications[l_lhs:]) + } + ) + cc = self.fix_column_to_row_type(cc, rel.getRowType()) + dc = DataContainer(df, cc) + + # 7. 
Last but not least we apply any filters by and-chaining together the filters + if filter_condition: + # This line is a bit of code duplication with RexCallPlugin - but I guess it is worth to keep it separate + filter_condition = reduce( + operator.and_, + [ + RexConverter.convert(rex, dc, context=context) + for rex in filter_condition + ], + ) + logger.debug(f"Additionally applying filter {filter_condition}") + df = filter_or_scalar(df, filter_condition) + dc = DataContainer(df, cc) + + dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + return dc + def _split_join_condition( self, join_condition: "org.apache.calcite.rex.RexCall" ) -> Tuple[List[str], List[str], List["org.apache.calcite.rex.RexCall"]]: From 3feaa0128415fb290874ee725517afb24c9fabb3 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Tue, 2 Mar 2021 10:13:30 +0100 Subject: [PATCH 11/33] Commenting the inplace join method which id added as it seems to not work in all cases. Needs to be reworked. --- dask_sql/physical/rel/logical/join.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index f19a96a28..dd3a2aa47 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -92,11 +92,14 @@ def convert( # We therefore create new columns on purpose, which have a distinct name. assert len(lhs_on) == len(rhs_on) if lhs_on: - if join_type == "inner": - return self._do_inner_join_inplace( - df_lhs_renamed, df_rhs_renamed, lhs_on, rhs_on, - join_type, rel, filter_condition, context - ) + # Doing inplace join seems to have unexpected side effects due to + # the join columns being shared in the resulting df for the lhs and rhs. + # It needs to be reworked. + # if join_type == "inner": + # return self._do_inner_join_inplace( + # df_lhs_renamed, df_rhs_renamed, lhs_on, rhs_on, + # join_type, rel, filter_condition, context + # ) lhs_columns_to_add = { f"common_{i}": df_lhs_renamed.iloc[:, index] for i, index in enumerate(lhs_on) From 4cfd39c9f16ff6af0c5a276770bb5b80e37e5dca Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Tue, 9 Mar 2021 09:46:04 +0100 Subject: [PATCH 12/33] Trying to make aggregations work on multi col params. Untested --- dask_sql/physical/rel/logical/aggregate.py | 38 ++++++++++++++++------ 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index 8fa832918..17ddbf322 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -88,6 +88,7 @@ class LogicalAggregatePlugin(BaseRelPlugin): AGGREGATION_MAPPING = { "$sum0": AggregationSpecification("sum", AggregationOnPandas("sum")), + "sum": AggregationSpecification("sum", AggregationOnPandas("sum")), "any_value": AggregationSpecification( dd.Aggregation( "any_value", @@ -242,7 +243,7 @@ def _collect_aggregations( elif len(inputs) == 0: input_col = additional_column_name else: - raise NotImplementedError("Can not cope with more than one input") + input_col = tuple(cc.get_backend_by_frontend_index(inputs[i]) for i in inputs) # Extract flags (filtering/distinct) if expr.isDistinct(): # pragma: no cover @@ -265,7 +266,10 @@ def _collect_aggregations( f"Aggregation function {aggregation_name} not implemented (yet)." 
) if isinstance(aggregation_function, AggregationSpecification): - dtype = df[input_col].dtype + if isinstance(input_col, tuple): + dtype = df[input_col[0]].dtype + else: + dtype = df[input_col].dtype if pd.api.types.is_numeric_dtype(dtype): aggregation_function = aggregation_function.numerical_aggregation else: @@ -322,16 +326,30 @@ def _perform_aggregation( # Convert into the correct format for dask aggregations_dict = defaultdict(dict) + multi_col_aggregations = dict() for aggregation in aggregations: input_col, output_col, aggregation_f = aggregation - - aggregations_dict[input_col][output_col] = aggregation_f + if isinstance(input_col, tuple): + multi_col_aggregations[output_col] = (input_col, aggregation_f) + else: + aggregations_dict[input_col][output_col] = aggregation_f # Now apply the aggregation - logger.debug(f"Performing aggregation {dict(aggregations_dict)}") - agg_result = grouped_df.agg(aggregations_dict) - - # ... fix the column names to a single level ... - agg_result.columns = agg_result.columns.get_level_values(-1) - + agg_result = None + if len(aggregations_dict) > 0: + logger.debug(f"Performing aggregation {dict(aggregations_dict)}") + agg_result = grouped_df.agg(aggregations_dict) + + # ... fix the column names to a single level ... + agg_result.columns = agg_result.columns.get_level_values(-1) + + + # apply multi-column aggregations + for output_col, (input_col, aggregation_f) in multi_col_aggregations.items(): + new_col = grouped_df.apply(lambda x: aggregation_f(*[getattr(x, col) for col in input_col])) + if agg_result is None: + agg_result = new_col.rename(output_col).to_frame() + else: + agg_result = agg_result.assign(**{output_col: new_col}) + return agg_result From be11cd6d1b652afa38be1f1a2b2e6d9cc3cf6785 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Tue, 9 Mar 2021 17:12:54 +0100 Subject: [PATCH 13/33] Better and cleaner rule management in Calcite app --- .../dask/sql/application/DaskRuleSets.java | 153 ++++++++++++++++++ .../RelationalAlgebraGenerator.java | 81 +++++++--- 2 files changed, 208 insertions(+), 26 deletions(-) create mode 100644 planner/src/main/java/com/dask/sql/application/DaskRuleSets.java diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java new file mode 100644 index 000000000..6951abba8 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java @@ -0,0 +1,153 @@ +package com.dask.sql.application; + +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.tools.RuleSet; +import org.apache.calcite.tools.RuleSets; + +public class DaskRuleSets { + /** + * Convert sub-queries before query decorrelation. 
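+	 * These rules rewrite IN/EXISTS style sub-queries in filters, projects and
+	 * joins into correlates ahead of decorrelation.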
+ */ + static final RuleSet TABLE_SUBQUERY_RULES = RuleSets.ofList(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE, + CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE, CoreRules.JOIN_SUB_QUERY_TO_CORRELATE); + + /** + * RuleSet to reduce expressions + */ + static final RuleSet REDUCE_EXPRESSION_RULES = RuleSets.ofList(CoreRules.FILTER_REDUCE_EXPRESSIONS, + CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.CALC_REDUCE_EXPRESSIONS, CoreRules.JOIN_REDUCE_EXPRESSIONS, + CoreRules.AGGREGATE_REDUCE_FUNCTIONS); + + /** + * RuleSet about filter + */ + static final RuleSet FILTER_RULES = RuleSets.ofList( + // push a filter into a join + CoreRules.FILTER_INTO_JOIN, + // push filter into the children of a join + CoreRules.JOIN_CONDITION_PUSH, + // push filter through an aggregation + CoreRules.FILTER_AGGREGATE_TRANSPOSE, + // push a filter past a project + CoreRules.FILTER_PROJECT_TRANSPOSE, + // push a filter past a setop + CoreRules.FILTER_SET_OP_TRANSPOSE, CoreRules.FILTER_MERGE); + + /** + * RuleSet about project + */ + static final RuleSet PROJECT_RULES = RuleSets.ofList( + // push a projection past a filter + CoreRules.PROJECT_FILTER_TRANSPOSE, + // merge projections + CoreRules.PROJECT_MERGE, + // Don't add PROJECT_REMOVE + // CoreRules.PROJECT_REMOVE, + // removes constant keys from an Agg + CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, + // push project through a Union + CoreRules.PROJECT_SET_OP_TRANSPOSE, CoreRules.PROJECT_JOIN_TRANSPOSE); + + /** + * RuleSet about aggregate + */ + static final RuleSet AGGREGATE_RULES = RuleSets.ofList(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, + CoreRules.AGGREGATE_JOIN_TRANSPOSE, CoreRules.AGGREGATE_PROJECT_MERGE, CoreRules.PROJECT_AGGREGATE_MERGE, + CoreRules.AGGREGATE_MERGE); + + /** + * RuleSet for merging joins + */ + static final RuleSet JOIN_REORDER_PREPARE_RULES = RuleSets.ofList( + // merge project to MultiJoin + CoreRules.PROJECT_MULTI_JOIN_MERGE, + // merge filter to MultiJoin + CoreRules.FILTER_MULTI_JOIN_MERGE, + // merge join to MultiJoin + CoreRules.JOIN_TO_MULTI_JOIN); + + /** + * Rules to reorder joins + */ + static final RuleSet JOIN_REORDER_RULES = RuleSets.ofList(CoreRules.MULTI_JOIN_OPTIMIZE, + CoreRules.MULTI_JOIN_OPTIMIZE_BUSHY); + + /** + * RuleSet to do logical optimize. + */ + static final RuleSet LOGICAL_RULES = RuleSets.ofList( + // scan optimization + // PushProjectIntoTableSourceScanRule.INSTANCE, + // PushProjectIntoLegacyTableSourceScanRule.INSTANCE, + // PushFilterIntoTableSourceScanRule.INSTANCE, + // PushFilterIntoLegacyTableSourceScanRule.INSTANCE, + // PushLimitIntoTableSourceScanRule.INSTANCE, + + // reorder sort and projection + // CoreRules.SORT_PROJECT_TRANSPOSE, + // remove unnecessary sort rule + // CoreRules.SORT_REMOVE, + + // join rules + // FlinkJoinPushExpressionsRule.INSTANCE, + // SimplifyJoinConditionRule.INSTANCE, + + // remove union with only a single child + CoreRules.UNION_REMOVE, + // convert non-all union into all-union + distinct + CoreRules.UNION_TO_DISTINCT, + + // aggregation and projection rules + CoreRules.AGGREGATE_PROJECT_MERGE, CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, + + // remove aggregation if it does not aggregate and input is already distinct + // FlinkAggregateRemoveRule.INSTANCE, + // push aggregate through join + // FlinkAggregateJoinTransposeRule.EXTENDED, + // using variants of aggregate union rule + CoreRules.AGGREGATE_UNION_AGGREGATE_FIRST, CoreRules.AGGREGATE_UNION_AGGREGATE_SECOND + + // reduce aggregate functions like AVG, STDDEV_POP etc. 
+ // CoreRules.AGGREGATE_REDUCE_FUNCTIONS, + + // reduce useless aggCall + // PruneAggregateCallRule.PROJECT_ON_AGGREGATE, + // PruneAggregateCallRule.CALC_ON_AGGREGATE, + + // expand grouping sets + // DecomposeGroupingSetsRule.INSTANCE, + + // calc rules + // CoreRules.FILTER_CALC_MERGE, CoreRules.PROJECT_CALC_MERGE, + // CoreRules.FILTER_TO_CALC, + // CoreRules.PROJECT_TO_CALC + // FlinkCalcMergeRule.INSTANCE, + + // semi/anti join transpose rule + // FlinkSemiAntiJoinJoinTransposeRule.INSTANCE, + // FlinkSemiAntiJoinProjectTransposeRule.INSTANCE, + // FlinkSemiAntiJoinFilterTransposeRule.INSTANCE, + + // set operators + // ReplaceIntersectWithSemiJoinRule.INSTANCE, + // RewriteIntersectAllRule.INSTANCE, + // ReplaceMinusWithAntiJoinRule.INSTANCE, + // RewriteMinusAllRule.INSTANCE + ); + + /** + * Initial rule set from dask_sql with a couple rules added by Demian. + */ + static final RuleSet DASK_DEFAULT_CORE_RULES = RuleSets.ofList( + CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, CoreRules.FILTER_SET_OP_TRANSPOSE, + CoreRules.FILTER_AGGREGATE_TRANSPOSE, CoreRules.FILTER_INTO_JOIN, CoreRules.JOIN_CONDITION_PUSH, + CoreRules.PROJECT_JOIN_TRANSPOSE, CoreRules.PROJECT_MULTI_JOIN_MERGE, CoreRules.JOIN_TO_MULTI_JOIN, + CoreRules.MULTI_JOIN_OPTIMIZE, CoreRules.MULTI_JOIN_OPTIMIZE_BUSHY, CoreRules.AGGREGATE_JOIN_TRANSPOSE, + CoreRules.AGGREGATE_PROJECT_MERGE, CoreRules.PROJECT_AGGREGATE_MERGE, CoreRules.AGGREGATE_MERGE, + CoreRules.PROJECT_MERGE, CoreRules.FILTER_MERGE, + // Don't add this rule as it removes projections which are used to rename colums + // CoreRules.PROJECT_REMOVE, + CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.FILTER_REDUCE_EXPRESSIONS, + CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM, CoreRules.AGGREGATE_REDUCE_FUNCTIONS); + +} \ No newline at end of file diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java index 9cd486a57..d7faa129f 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java @@ -4,6 +4,7 @@ import java.sql.DriverManager; import java.sql.SQLException; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Properties; @@ -19,7 +20,9 @@ import org.apache.calcite.jdbc.JavaTypeFactoryImpl; import org.apache.calcite.plan.Context; import org.apache.calcite.plan.Contexts; +import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.hep.HepMatchOrder; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; @@ -54,6 +57,7 @@ import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.Planner; import org.apache.calcite.tools.RelConversionException; +import org.apache.calcite.tools.RuleSet; import org.apache.calcite.tools.ValidationException; /** @@ -64,6 +68,11 @@ * This class is taken (in parts) from the blazingSQL project. 
*/ public class RelationalAlgebraGenerator { + public enum HepExecutionType { + SEQUENCE, + COLLECTION + } + /// The created planner private Planner planner; /// The planner for optimized queries @@ -134,32 +143,51 @@ private CalciteConnection getCalciteConnection() throws SQLException { /// get an optimizer hep planner private HepPlanner getHepPlanner(final FrameworkConfig config) { - // TODO: check if these rules are sensible - // Taken from blazingSQL - final HepProgram program = new HepProgramBuilder() - .addRuleInstance(AggregateExpandDistinctAggregatesRule.Config.JOIN.toRule()) - .addRuleInstance(FilterSetOpTransposeRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterAggregateTransposeRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterJoinRule.JoinConditionPushRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterJoinRule.FilterIntoJoinRule.Config.DEFAULT.toRule()) - .addRuleInstance(JoinToMultiJoinRule.Config.DEFAULT.toRule()) - .addRuleInstance(LoptOptimizeJoinRule.Config.DEFAULT.toRule()) - .addRuleInstance(ProjectMergeRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterMergeRule.Config.DEFAULT.toRule()) - .addRuleInstance(ProjectJoinTransposeRule.Config.DEFAULT.toRule()) - // In principle, not a bad idea. But we need to keep the most - // outer project - because otherwise the column name information is lost - // in cases such as SELECT x AS a, y AS B FROM df - // .addRuleInstance(ProjectRemoveRule.Config.DEFAULT.toRule()) - .addRuleInstance(ReduceExpressionsRule.ProjectReduceExpressionsRule.Config.DEFAULT.toRule()) - // this rule might make sense, but turns a < 1 into a SEARCH expression - // which is currently not supported by dask-sql - // .addRuleInstance(ReduceExpressionsRule.FilterReduceExpressionsRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterRemoveIsNotDistinctFromRule.Config.DEFAULT.toRule()) - // TODO: remove AVG - .addRuleInstance(AggregateReduceFunctionsRule.Config.DEFAULT.toRule()).build(); - - return new HepPlanner(program, config.getContext()); + final HepProgramBuilder builder = new HepProgramBuilder(); + builder.addMatchOrder(HepMatchOrder.ARBITRARY).addMatchLimit(Integer.MAX_VALUE); + // for (RelOptRule rule : DaskRuleSets.DASK_DEFAULT_CORE_RULES){ + // builder.addRuleInstance(rule); + // } + + // join reorder + builder.addSubprogram(getHepProgram(DaskRuleSets.JOIN_REORDER_PREPARE_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + builder.addSubprogram(getHepProgram(DaskRuleSets.JOIN_REORDER_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.SEQUENCE)); + + // project rules + builder.addSubprogram(getHepProgram(DaskRuleSets.AGGREGATE_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + builder.addSubprogram(getHepProgram(DaskRuleSets.PROJECT_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + builder.addSubprogram(getHepProgram(DaskRuleSets.FILTER_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + builder.addSubprogram(getHepProgram(DaskRuleSets.REDUCE_EXPRESSION_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + + // optimize logical plan + builder.addSubprogram(getHepProgram(DaskRuleSets.LOGICAL_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.SEQUENCE)); + + return new HepPlanner(builder.build(), config.getContext()); + } + + /** + * Builds a HepProgram for the given set of rules and with the given order. If type is COLLECTION, + * rules are added as collection. Otherwise, rules are added sequentially. 
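+	 * A COLLECTION subprogram lets the planner match all rules of the set
+	 * together until none of them fires any more, whereas a SEQUENCE subprogram
+	 * exhausts each rule in turn before moving on to the next one.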
+ * @param rules + * @param order + * @param type + * @return + */ + private HepProgram getHepProgram(final RuleSet rules, final HepMatchOrder order, final HepExecutionType type) { + final HepProgramBuilder builder = new HepProgramBuilder().addMatchOrder(order); + switch (type) { + case SEQUENCE: + for (RelOptRule rule : rules) { + builder.addRuleInstance(rule); + } + break; + case COLLECTION: + List rulesCollection = new ArrayList(); + rules.iterator().forEachRemaining(rulesCollection::add); + builder.addRuleCollection(rulesCollection); + break; + } + return builder.build(); } /// Parse a sql string into a sql tree @@ -205,3 +233,4 @@ public String getRelationalAlgebraString(final RelNode relNode) { return RelOptUtil.toString(relNode); } } + From 48be68057b531bff157a555ae696538d7e77175f Mon Sep 17 00:00:00 2001 From: Demian Wassermann Date: Wed, 10 Mar 2021 10:06:30 +0100 Subject: [PATCH 14/33] Optimizations and extension of join implementation --- dask_sql/physical/rel/logical/join.py | 39 ++++++++++--------- .../dask/sql/application/DaskRuleSets.java | 10 ++++- .../RelationalAlgebraGenerator.java | 11 +++--- 3 files changed, 35 insertions(+), 25 deletions(-) diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index dd3a2aa47..f775fd3c8 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -318,21 +318,24 @@ def _extract_lhs_rhs(self, rex): operand_lhs = operands[0] operand_rhs = operands[1] - if isinstance(operand_lhs, org.apache.calcite.rex.RexInputRef) and isinstance( - operand_rhs, org.apache.calcite.rex.RexInputRef - ): - lhs_index = operand_lhs.getIndex() - rhs_index = operand_rhs.getIndex() - - # The rhs table always comes after the lhs - # table. Therefore we have a very simple - # way of checking, which index comes from which - # input - if lhs_index > rhs_index: - lhs_index, rhs_index = rhs_index, lhs_index - - return lhs_index, rhs_index - - raise TypeError( - "Invalid join condition" - ) # pragma: no cover. Do not how how it could be triggered. + indices = [] + for operand in operands: + if isinstance(operand, org.apache.calcite.rex.RexInputRef): + indices.append(operand.getIndex()) + elif ( + isinstance(operand, org.apache.calcite.rex.RexCall) and + isinstance(operand.getOperator(), org.apache.calcite.sql.fun.SqlCastFunction) + ): + indices.append(operand.operands[0].getIndex()) + else: + raise TypeError( + "Invalid join condition" + ) # pragma: no cover. Do not how how it could be triggered. 
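+        # An equi-join condition references exactly one column per side, so the
+        # loop above should have collected exactly two indices. The rhs columns
+        # are appended after the lhs columns, which is why the smaller index
+        # always belongs to the lhs input (swapped below if necessary).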
+ lhs_index, rhs_index = indices + + if lhs_index > rhs_index: + lhs_index, rhs_index = rhs_index, lhs_index + + return lhs_index, rhs_index + + diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java index 6951abba8..f65588606 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java +++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java @@ -37,6 +37,7 @@ public class DaskRuleSets { * RuleSet about project */ static final RuleSet PROJECT_RULES = RuleSets.ofList( + CoreRules.AGGREGATE_PROJECT_MERGE, // push a projection past a filter CoreRules.PROJECT_FILTER_TRANSPOSE, // merge projections @@ -46,7 +47,11 @@ public class DaskRuleSets { // removes constant keys from an Agg CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, // push project through a Union - CoreRules.PROJECT_SET_OP_TRANSPOSE, CoreRules.PROJECT_JOIN_TRANSPOSE); + CoreRules.PROJECT_SET_OP_TRANSPOSE, + CoreRules.JOIN_PROJECT_LEFT_TRANSPOSE, + CoreRules.JOIN_PROJECT_RIGHT_TRANSPOSE, + CoreRules.JOIN_PROJECT_BOTH_TRANSPOSE + ); /** * RuleSet about aggregate @@ -64,7 +69,8 @@ public class DaskRuleSets { // merge filter to MultiJoin CoreRules.FILTER_MULTI_JOIN_MERGE, // merge join to MultiJoin - CoreRules.JOIN_TO_MULTI_JOIN); + CoreRules.JOIN_TO_MULTI_JOIN + ); /** * Rules to reorder joins diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java index 508edc28a..1c3ce8409 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java @@ -151,16 +151,17 @@ private HepPlanner getHepPlanner(final FrameworkConfig config) { // builder.addRuleInstance(rule); // } - // join reorder - builder.addSubprogram(getHepProgram(DaskRuleSets.JOIN_REORDER_PREPARE_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - builder.addSubprogram(getHepProgram(DaskRuleSets.JOIN_REORDER_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.SEQUENCE)); - // project rules builder.addSubprogram(getHepProgram(DaskRuleSets.AGGREGATE_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); builder.addSubprogram(getHepProgram(DaskRuleSets.PROJECT_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); builder.addSubprogram(getHepProgram(DaskRuleSets.FILTER_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); builder.addSubprogram(getHepProgram(DaskRuleSets.REDUCE_EXPRESSION_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - + // join reorder + builder.addSubprogram(getHepProgram(DaskRuleSets.JOIN_REORDER_PREPARE_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + builder.addSubprogram(getHepProgram(DaskRuleSets.JOIN_REORDER_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.SEQUENCE)); + + // project rules + builder.addSubprogram(getHepProgram(DaskRuleSets.PROJECT_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); // optimize logical plan builder.addSubprogram(getHepProgram(DaskRuleSets.LOGICAL_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.SEQUENCE)); From 6b7b78ebe809a696ff6a8387bc85ffcb47989648 Mon Sep 17 00:00:00 2001 From: Demian Wassermann Date: Wed, 10 Mar 2021 11:10:21 +0100 Subject: [PATCH 15/33] Improved log and added one more project rule --- dask_sql/context.py | 10 ++++++---- .../java/com/dask/sql/application/DaskRuleSets.java | 1 + 2 files 
changed, 7 insertions(+), 4 deletions(-) diff --git a/dask_sql/context.py b/dask_sql/context.py index 108595db0..6cb319978 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -515,11 +515,13 @@ def _get_ral(self, sql): else: validatedSqlNode = generator.getValidatedNode(sqlNode) nonOptimizedRelNode = generator.getRelationalAlgebra(validatedSqlNode) + rel_string_non_op = str(generator.getRelationalAlgebraString(nonOptimizedRelNode)) + rel_non_op_count = rel_string_non_op.count('\n') rel = generator.getOptimizedRelationalAlgebra(nonOptimizedRelNode) rel_string = str(generator.getRelationalAlgebraString(rel)) logger.debug( - f"Non optimised query plan: \n " - f"{str(generator.getRelationalAlgebraString(nonOptimizedRelNode))}" + f"Non optimised query plan: {rel_non_op_count} ops\n " + f"{rel_string_non_op}" ) except (ValidationException, SqlParseException) as e: logger.debug(f"Original exception raised by Java:\n {e}") @@ -550,8 +552,8 @@ def _get_ral(self, sql): "Not extracting output column names as the SQL is not a SELECT call" ) select_names = None - - logger.debug(f"Extracted relational algebra:\n {rel_string}") + br = '\n' + logger.debug(f"Extracted relational algebra {rel_string.count(br)} ops:\n {rel_string}") return rel, select_names, rel_string def _to_sql_string(self, s: "org.apache.calcite.sql.SqlNode", default_dialect=None): diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java index f65588606..e881f0eb5 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java +++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java @@ -37,6 +37,7 @@ public class DaskRuleSets { * RuleSet about project */ static final RuleSet PROJECT_RULES = RuleSets.ofList( + CoreRules.PROJECT_MERGE, CoreRules.AGGREGATE_PROJECT_MERGE, // push a projection past a filter CoreRules.PROJECT_FILTER_TRANSPOSE, From 98a3a9d492a096af0d8e41f09db2c802d11e1d6e Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Thu, 11 Mar 2021 17:43:51 +0100 Subject: [PATCH 16/33] Cleaning up rules for optimizer --- .../dask/sql/application/DaskRuleSets.java | 302 +++++++++--------- .../RelationalAlgebraGenerator.java | 53 +-- 2 files changed, 163 insertions(+), 192 deletions(-) diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java index e881f0eb5..091c97423 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java +++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java @@ -1,160 +1,160 @@ package com.dask.sql.application; +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.hep.HepMatchOrder; +import org.apache.calcite.plan.hep.HepProgram; +import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.rel.rules.CoreRules; import org.apache.calcite.tools.RuleSet; import org.apache.calcite.tools.RuleSets; +/** + * RuleSets and utilities for creating Programs to use with Calcite's query + * planners. 
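+ * Most of the rule sets below are applied as separate HepProgram subprograms
+ * by the RelationalAlgebraGenerator.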
This is inspired both from Apache Calcite's default optimization + * programs + * (https://github.com/apache/calcite/blob/master/core/src/main/java/org/apache/calcite/tools/Programs.java) + * and Apache Flink's multi-phase query optimization + * (https://github.com/apache/flink/blob/master/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkStreamProgram.scala) + */ public class DaskRuleSets { - /** - * Convert sub-queries before query decorrelation. - */ - static final RuleSet TABLE_SUBQUERY_RULES = RuleSets.ofList(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE, - CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE, CoreRules.JOIN_SUB_QUERY_TO_CORRELATE); - - /** - * RuleSet to reduce expressions - */ - static final RuleSet REDUCE_EXPRESSION_RULES = RuleSets.ofList(CoreRules.FILTER_REDUCE_EXPRESSIONS, - CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.CALC_REDUCE_EXPRESSIONS, CoreRules.JOIN_REDUCE_EXPRESSIONS, - CoreRules.AGGREGATE_REDUCE_FUNCTIONS); - - /** - * RuleSet about filter - */ - static final RuleSet FILTER_RULES = RuleSets.ofList( - // push a filter into a join - CoreRules.FILTER_INTO_JOIN, - // push filter into the children of a join - CoreRules.JOIN_CONDITION_PUSH, - // push filter through an aggregation - CoreRules.FILTER_AGGREGATE_TRANSPOSE, - // push a filter past a project - CoreRules.FILTER_PROJECT_TRANSPOSE, - // push a filter past a setop - CoreRules.FILTER_SET_OP_TRANSPOSE, CoreRules.FILTER_MERGE); - - /** - * RuleSet about project - */ - static final RuleSet PROJECT_RULES = RuleSets.ofList( - CoreRules.PROJECT_MERGE, - CoreRules.AGGREGATE_PROJECT_MERGE, - // push a projection past a filter - CoreRules.PROJECT_FILTER_TRANSPOSE, - // merge projections - CoreRules.PROJECT_MERGE, - // Don't add PROJECT_REMOVE - // CoreRules.PROJECT_REMOVE, - // removes constant keys from an Agg - CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, - // push project through a Union - CoreRules.PROJECT_SET_OP_TRANSPOSE, - CoreRules.JOIN_PROJECT_LEFT_TRANSPOSE, - CoreRules.JOIN_PROJECT_RIGHT_TRANSPOSE, - CoreRules.JOIN_PROJECT_BOTH_TRANSPOSE - ); - - /** - * RuleSet about aggregate - */ - static final RuleSet AGGREGATE_RULES = RuleSets.ofList(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, - CoreRules.AGGREGATE_JOIN_TRANSPOSE, CoreRules.AGGREGATE_PROJECT_MERGE, CoreRules.PROJECT_AGGREGATE_MERGE, - CoreRules.AGGREGATE_MERGE); - - /** - * RuleSet for merging joins - */ - static final RuleSet JOIN_REORDER_PREPARE_RULES = RuleSets.ofList( - // merge project to MultiJoin - CoreRules.PROJECT_MULTI_JOIN_MERGE, - // merge filter to MultiJoin - CoreRules.FILTER_MULTI_JOIN_MERGE, - // merge join to MultiJoin - CoreRules.JOIN_TO_MULTI_JOIN - ); - - /** - * Rules to reorder joins - */ - static final RuleSet JOIN_REORDER_RULES = RuleSets.ofList(CoreRules.MULTI_JOIN_OPTIMIZE, - CoreRules.MULTI_JOIN_OPTIMIZE_BUSHY); - - /** - * RuleSet to do logical optimize. 
- */ - static final RuleSet LOGICAL_RULES = RuleSets.ofList( - // scan optimization - // PushProjectIntoTableSourceScanRule.INSTANCE, - // PushProjectIntoLegacyTableSourceScanRule.INSTANCE, - // PushFilterIntoTableSourceScanRule.INSTANCE, - // PushFilterIntoLegacyTableSourceScanRule.INSTANCE, - // PushLimitIntoTableSourceScanRule.INSTANCE, - - // reorder sort and projection - // CoreRules.SORT_PROJECT_TRANSPOSE, - // remove unnecessary sort rule - // CoreRules.SORT_REMOVE, - - // join rules - // FlinkJoinPushExpressionsRule.INSTANCE, - // SimplifyJoinConditionRule.INSTANCE, - - // remove union with only a single child - CoreRules.UNION_REMOVE, - // convert non-all union into all-union + distinct - CoreRules.UNION_TO_DISTINCT, - - // aggregation and projection rules - CoreRules.AGGREGATE_PROJECT_MERGE, CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, - - // remove aggregation if it does not aggregate and input is already distinct - // FlinkAggregateRemoveRule.INSTANCE, - // push aggregate through join - // FlinkAggregateJoinTransposeRule.EXTENDED, - // using variants of aggregate union rule - CoreRules.AGGREGATE_UNION_AGGREGATE_FIRST, CoreRules.AGGREGATE_UNION_AGGREGATE_SECOND - - // reduce aggregate functions like AVG, STDDEV_POP etc. - // CoreRules.AGGREGATE_REDUCE_FUNCTIONS, - - // reduce useless aggCall - // PruneAggregateCallRule.PROJECT_ON_AGGREGATE, - // PruneAggregateCallRule.CALC_ON_AGGREGATE, - - // expand grouping sets - // DecomposeGroupingSetsRule.INSTANCE, - - // calc rules - // CoreRules.FILTER_CALC_MERGE, CoreRules.PROJECT_CALC_MERGE, - // CoreRules.FILTER_TO_CALC, - // CoreRules.PROJECT_TO_CALC - // FlinkCalcMergeRule.INSTANCE, - - // semi/anti join transpose rule - // FlinkSemiAntiJoinJoinTransposeRule.INSTANCE, - // FlinkSemiAntiJoinProjectTransposeRule.INSTANCE, - // FlinkSemiAntiJoinFilterTransposeRule.INSTANCE, - - // set operators - // ReplaceIntersectWithSemiJoinRule.INSTANCE, - // RewriteIntersectAllRule.INSTANCE, - // ReplaceMinusWithAntiJoinRule.INSTANCE, - // RewriteMinusAllRule.INSTANCE - ); - - /** - * Initial rule set from dask_sql with a couple rules added by Demian. 
- */ - static final RuleSet DASK_DEFAULT_CORE_RULES = RuleSets.ofList( - CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, CoreRules.FILTER_SET_OP_TRANSPOSE, - CoreRules.FILTER_AGGREGATE_TRANSPOSE, CoreRules.FILTER_INTO_JOIN, CoreRules.JOIN_CONDITION_PUSH, - CoreRules.PROJECT_JOIN_TRANSPOSE, CoreRules.PROJECT_MULTI_JOIN_MERGE, CoreRules.JOIN_TO_MULTI_JOIN, - CoreRules.MULTI_JOIN_OPTIMIZE, CoreRules.MULTI_JOIN_OPTIMIZE_BUSHY, CoreRules.AGGREGATE_JOIN_TRANSPOSE, - CoreRules.AGGREGATE_PROJECT_MERGE, CoreRules.PROJECT_AGGREGATE_MERGE, CoreRules.AGGREGATE_MERGE, - CoreRules.PROJECT_MERGE, CoreRules.FILTER_MERGE, - // Don't add this rule as it removes projections which are used to rename colums - // CoreRules.PROJECT_REMOVE, - CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.FILTER_REDUCE_EXPRESSIONS, - CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM, CoreRules.AGGREGATE_REDUCE_FUNCTIONS); + + // private constructor + private DaskRuleSets() { + } + + /** + * RuleSet to reduce expressions + */ + static final RuleSet REDUCE_EXPRESSION_RULES = RuleSets.ofList(CoreRules.FILTER_REDUCE_EXPRESSIONS, + CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.CALC_REDUCE_EXPRESSIONS, + CoreRules.JOIN_REDUCE_EXPRESSIONS, CoreRules.AGGREGATE_REDUCE_FUNCTIONS); + + /** + * RuleSet about filter + */ + static final RuleSet FILTER_RULES = RuleSets.ofList( + // push a filter into a join + CoreRules.FILTER_INTO_JOIN, + // push filter into the children of a join + CoreRules.JOIN_CONDITION_PUSH, + // push filter through an aggregation + CoreRules.FILTER_AGGREGATE_TRANSPOSE, + // push a filter past a project + CoreRules.FILTER_PROJECT_TRANSPOSE, + // push a filter past a setop + CoreRules.FILTER_SET_OP_TRANSPOSE, CoreRules.FILTER_MERGE); + + /** + * RuleSet about project Dont' add CoreRules.PROJECT_REMOVE + */ + static final RuleSet PROJECT_RULES = RuleSets.ofList(CoreRules.PROJECT_MERGE, CoreRules.AGGREGATE_PROJECT_MERGE, + // push a projection past a filter + CoreRules.PROJECT_FILTER_TRANSPOSE, + // merge projections + CoreRules.PROJECT_MERGE, + // removes constant keys from an Agg + CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, + // push project through a Union + CoreRules.PROJECT_SET_OP_TRANSPOSE, CoreRules.JOIN_PROJECT_LEFT_TRANSPOSE, + CoreRules.JOIN_PROJECT_RIGHT_TRANSPOSE, CoreRules.JOIN_PROJECT_BOTH_TRANSPOSE); + + /** + * RuleSet about aggregate + */ + static final RuleSet AGGREGATE_RULES = RuleSets.ofList(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, + CoreRules.AGGREGATE_JOIN_TRANSPOSE, CoreRules.AGGREGATE_PROJECT_MERGE, + CoreRules.PROJECT_AGGREGATE_MERGE, CoreRules.AGGREGATE_MERGE, + // Important. Removes unecessary distinct calls + CoreRules.AGGREGATE_REMOVE, CoreRules.AGGREGATE_JOIN_REMOVE); + + /** + * RuleSet for merging joins + */ + static final RuleSet JOIN_REORDER_PREPARE_RULES = RuleSets.ofList( + // merge project to MultiJoin + CoreRules.PROJECT_MULTI_JOIN_MERGE, + // merge filter to MultiJoin + CoreRules.FILTER_MULTI_JOIN_MERGE, + // merge join to MultiJoin + CoreRules.JOIN_TO_MULTI_JOIN); + + /** + * Rules to reorder joins + */ + static final RuleSet JOIN_REORDER_RULES = RuleSets.ofList( + // optimize multi joins + CoreRules.MULTI_JOIN_OPTIMIZE, + // optmize bushy multi joins + CoreRules.MULTI_JOIN_OPTIMIZE_BUSHY); + + /** + * Rules to reorder joins using associate and commute rules. See + * https://www.querifylabs.com/blog/rule-based-query-optimization for an + * explanation. JoinCommuteRule causes exhaustive search and should probably not + * be used. 
+ */ + static final RuleSet JOIN_COMMUTE_ASSOCIATE_RULES = RuleSets.ofList( + // changes a join based on associativity rule. + CoreRules.JOIN_ASSOCIATE, CoreRules.JOIN_COMMUTE); + + /** + * RuleSet to do logical optimize. + */ + static final RuleSet LOGICAL_RULES = RuleSets.ofList( + // remove union with only a single child + CoreRules.UNION_REMOVE, + // convert non-all union into all-union + distinct + CoreRules.UNION_TO_DISTINCT, CoreRules.MINUS_MERGE, + // aggregation and projection rules + // CoreRules.AGGREGATE_PROJECT_MERGE, CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, + // CoreRules.AGGREGATE_REMOVE, CoreRules.AGGREGATE_JOIN_TRANSPOSE_EXTENDED, + CoreRules.AGGREGATE_UNION_AGGREGATE_FIRST, CoreRules.AGGREGATE_UNION_AGGREGATE_SECOND); + + /** + * Initial rule set from dask_sql with a couple rules added by Demian. Not used + * but kept for reference. + */ + static final RuleSet DASK_DEFAULT_CORE_RULES = RuleSets.ofList( + CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, CoreRules.FILTER_SET_OP_TRANSPOSE, + CoreRules.FILTER_AGGREGATE_TRANSPOSE, CoreRules.FILTER_INTO_JOIN, CoreRules.JOIN_CONDITION_PUSH, + CoreRules.PROJECT_JOIN_TRANSPOSE, CoreRules.PROJECT_MULTI_JOIN_MERGE, + CoreRules.JOIN_TO_MULTI_JOIN, CoreRules.MULTI_JOIN_OPTIMIZE, + CoreRules.MULTI_JOIN_OPTIMIZE_BUSHY, CoreRules.AGGREGATE_JOIN_TRANSPOSE, + CoreRules.AGGREGATE_PROJECT_MERGE, CoreRules.PROJECT_AGGREGATE_MERGE, CoreRules.AGGREGATE_MERGE, + CoreRules.PROJECT_MERGE, CoreRules.FILTER_MERGE, + // Don't add this rule as it removes projections which are used to rename colums + // CoreRules.PROJECT_REMOVE, + CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.FILTER_REDUCE_EXPRESSIONS, + CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM, CoreRules.AGGREGATE_REDUCE_FUNCTIONS); + + /** + * Builds a HepProgram for the given set of rules and with the given order. If + * type is COLLECTION, rules are added as collection. Otherwise, rules are added + * sequentially. + */ + public static HepProgram hepProgram(final RuleSet rules, final HepMatchOrder order, + final HepExecutionType type) { + final HepProgramBuilder builder = new HepProgramBuilder().addMatchOrder(order); + switch (type) { + case SEQUENCE: + for (RelOptRule rule : rules) { + builder.addRuleInstance(rule); + } + break; + case COLLECTION: + List rulesCollection = new ArrayList(); + rules.iterator().forEachRemaining(rulesCollection::add); + builder.addRuleCollection(rulesCollection); + break; + } + return builder.build(); + } + + public enum HepExecutionType { + SEQUENCE, COLLECTION + } } \ No newline at end of file diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java index 1c3ce8409..6a55fbd08 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Properties; +import com.dask.sql.application.DaskRuleSets.HepExecutionType; import com.dask.sql.schema.DaskSchema; import org.apache.calcite.sql.SqlDialect; import org.apache.calcite.sql.dialect.PostgresqlSqlDialect; @@ -70,10 +71,6 @@ * This class is taken (in parts) from the blazingSQL project. 
*/ public class RelationalAlgebraGenerator { - public enum HepExecutionType { - SEQUENCE, - COLLECTION - } /// The created planner private Planner planner; @@ -147,52 +144,26 @@ private CalciteConnection getCalciteConnection() throws SQLException { private HepPlanner getHepPlanner(final FrameworkConfig config) { final HepProgramBuilder builder = new HepProgramBuilder(); builder.addMatchOrder(HepMatchOrder.ARBITRARY).addMatchLimit(Integer.MAX_VALUE); + // Legacy rule set // for (RelOptRule rule : DaskRuleSets.DASK_DEFAULT_CORE_RULES){ // builder.addRuleInstance(rule); // } - // project rules - builder.addSubprogram(getHepProgram(DaskRuleSets.AGGREGATE_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - builder.addSubprogram(getHepProgram(DaskRuleSets.PROJECT_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - builder.addSubprogram(getHepProgram(DaskRuleSets.FILTER_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - builder.addSubprogram(getHepProgram(DaskRuleSets.REDUCE_EXPRESSION_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - // join reorder - builder.addSubprogram(getHepProgram(DaskRuleSets.JOIN_REORDER_PREPARE_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - builder.addSubprogram(getHepProgram(DaskRuleSets.JOIN_REORDER_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.SEQUENCE)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.AGGREGATE_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.PROJECT_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.FILTER_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.REDUCE_EXPRESSION_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + // join reorder. The first set of rules transforms joins into a large multijoin. + // the second set of rules splits the multijoins by applying a heuristic to determine the best join order. + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.JOIN_REORDER_PREPARE_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.JOIN_REORDER_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.SEQUENCE)); - // project rules - builder.addSubprogram(getHepProgram(DaskRuleSets.PROJECT_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - // optimize logical plan - builder.addSubprogram(getHepProgram(DaskRuleSets.LOGICAL_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.SEQUENCE)); + // optimize logical plan. Be careful not to introduce rules in this set which mess up the join order from the step before. + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.LOGICAL_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.SEQUENCE)); return new HepPlanner(builder.build(), config.getContext()); } - /** - * Builds a HepProgram for the given set of rules and with the given order. If type is COLLECTION, - * rules are added as collection. Otherwise, rules are added sequentially. 
- * @param rules - * @param order - * @param type - * @return - */ - private HepProgram getHepProgram(final RuleSet rules, final HepMatchOrder order, final HepExecutionType type) { - final HepProgramBuilder builder = new HepProgramBuilder().addMatchOrder(order); - switch (type) { - case SEQUENCE: - for (RelOptRule rule : rules) { - builder.addRuleInstance(rule); - } - break; - case COLLECTION: - List rulesCollection = new ArrayList(); - rules.iterator().forEachRemaining(rulesCollection::add); - builder.addRuleCollection(rulesCollection); - break; - } - return builder.build(); - } - /// Parse a sql string into a sql tree public SqlNode getSqlNode(final String sql) throws SqlParseException { try { From 6e35be5c4e78588075ccbce602afca03a59e20dd Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Tue, 16 Mar 2021 10:20:39 +0100 Subject: [PATCH 17/33] adding a Materializer class to add RelOptMaterializations to the planner --- planner/pom.xml | 5 + .../application/DaskCalciteMaterializer.java | 286 ++++++++++++++++++ .../RelationalAlgebraGenerator.java | 60 ++-- .../java/com/dask/sql/schema/DaskTable.java | 21 +- 4 files changed, 336 insertions(+), 36 deletions(-) create mode 100644 planner/src/main/java/com/dask/sql/application/DaskCalciteMaterializer.java diff --git a/planner/pom.xml b/planner/pom.xml index a24f5a1e0..a3ce2c1f5 100755 --- a/planner/pom.xml +++ b/planner/pom.xml @@ -47,6 +47,11 @@ javacc 4.0 + + com.google.guava + guava + 30.1-jre + diff --git a/planner/src/main/java/com/dask/sql/application/DaskCalciteMaterializer.java b/planner/src/main/java/com/dask/sql/application/DaskCalciteMaterializer.java new file mode 100644 index 000000000..2d4a77ddb --- /dev/null +++ b/planner/src/main/java/com/dask/sql/application/DaskCalciteMaterializer.java @@ -0,0 +1,286 @@ +package com.dask.sql.application; + +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +import javax.annotation.Nullable; + +import com.dask.sql.schema.DaskTable; +import com.google.common.collect.ImmutableList; + +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.config.CalciteConnectionConfigImpl; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptMaterialization; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.volcano.VolcanoPlanner; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.prepare.PlannerImpl; +import org.apache.calcite.prepare.RelOptTableImpl; +import org.apache.calcite.prepare.Prepare.CatalogReader; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.schema.Schemas; +import org.apache.calcite.schema.Table; +import org.apache.calcite.schema.impl.StarTable; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.fun.SqlLibrary; +import org.apache.calcite.sql.fun.SqlLibraryOperatorTableFactory; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.parser.SqlParser; +import 
org.apache.calcite.sql.util.SqlOperatorTables; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorUtil; +import org.apache.calcite.sql2rel.SqlToRelConverter; +import org.apache.calcite.sql2rel.StandardConvertletTable; +import org.apache.calcite.tools.FrameworkConfig; + +/** + * Utility class for preparing a list of RelOptMaterializations which can then + * be added to a RelOptPlanner (hep or volcano) to be used during optimization. + * + * Create the class using the constructor with the rootSchema, defaultSchemaName, + * and FrameworkConfig used to create a Planner. Then call getMaterializations() + * to get the list of RelOptMaterializations, containing a RelOptMaterialization + * for each view in the schema. A view is any DaskTable which has an sql query + * associated. + * + * A lot of the code for this class is taken / adapted from : + * https://github.com/apache/calcite/blob/master/core/src/main/java/org/apache/calcite/prepare/Prepare.java + * https://github.com/apache/calcite/blob/master/core/src/main/java/org/apache/calcite/prepare/CalciteMaterializer.java + * https://github.com/apache/calcite/blob/master/core/src/main/java/org/apache/calcite/prepare/CalcitePrepareImpl.java + * + */ +public class DaskCalciteMaterializer { + + private final CatalogReader catalogReader; + private final CalciteSchema schema; + private final SqlValidator sqlValidator; + private final JavaTypeFactory typeFactory; + private final FrameworkConfig config; + private final PlannerImpl planner; + + DaskCalciteMaterializer(final SchemaPlus rootSchema, final String schemaName, final FrameworkConfig config) { + final SchemaPlus schemaPlus = rootSchema.getSubSchema(schemaName); + schema = CalciteSchema.from(schemaPlus); + this.config = config; + planner = new PlannerImpl(config); + + final List schemaPath = new ArrayList(); + schemaPath.add(schema.getName()); + final Properties props = new Properties(); + props.setProperty("defaultSchema", schema.getName()); + catalogReader = new CalciteCatalogReader(schema.root(), schemaPath, + new JavaTypeFactoryImpl(DaskSqlDialect.DASKSQL_TYPE_SYSTEM), new CalciteConnectionConfigImpl(props)); + + final List sqlOperatorTables = new ArrayList<>(); + sqlOperatorTables.add(SqlStdOperatorTable.instance()); + sqlOperatorTables.add(SqlLibraryOperatorTableFactory.INSTANCE.getOperatorTable(SqlLibrary.POSTGRESQL)); + sqlOperatorTables.add(catalogReader); + SqlOperatorTable operatorTable = SqlOperatorTables.chain(sqlOperatorTables); + + typeFactory = new JavaTypeFactoryImpl(DaskSqlDialect.DASKSQL_TYPE_SYSTEM); + final CalciteConnectionConfig connectionConfig = new CalciteConnectionConfigImpl(props); + final SqlValidator.Config validatorConfig = SqlValidator.Config.DEFAULT + .withLenientOperatorLookup(connectionConfig.lenientOperatorLookup()) + .withSqlConformance(connectionConfig.conformance()) + .withDefaultNullCollation(connectionConfig.defaultNullCollation()).withIdentifierExpansion(true); + + sqlValidator = SqlValidatorUtil.newValidator(operatorTable, catalogReader, typeFactory, validatorConfig); + } + + /** + * Prepare a list of RelOptMaterialization to be added to the planner before optimizing + */ + public List getMaterializations() { + List materializations = new ArrayList(); + for (String tableName : schema.getTableNames()) { + CalciteSchema.TableEntry tableEntry = schema.getTable(tableName, true); + DaskTable table = (DaskTable) tableEntry.getTable(); + if (table.isMaterializedView()) { + List qualifiedTableName = 
tableEntry.path(); + // Create a materialization with the Table and SQL query + final Materialization materialization = new Materialization(tableEntry, table.getSql(), qualifiedTableName); + // Populate this materialization's tableRel and queryRel with the corresponding + // RelNode representation for the query and table + populate(materialization); + // Create a RelOptMaterialization to add to the list of materializations + materializations.add( + new RelOptMaterialization( + materialization.tableRel, + materialization.queryRel, + materialization.starRelOptTable, + qualifiedTableName)); + } + } + return materializations; + } + + protected SqlToRelConverter getSqlToRelConverter(SqlValidator validator, CatalogReader catalogReader, + SqlToRelConverter.Config relConfig) { + final RexBuilder rexBuilder = new RexBuilder(typeFactory); + final RelOptPlanner optPlanner = new VolcanoPlanner(this.config.getCostFactory(), this.config.getContext()); + final RelOptCluster cluster = RelOptCluster.create(optPlanner, rexBuilder); + return new SqlToRelConverter(planner, validator, catalogReader, cluster, StandardConvertletTable.INSTANCE, + relConfig); + } + + /** + * Populates a materialization record, converting an sql query string and + * table path (essentially a list of strings, like ["hr", "sales"]) into + * RelNodes which can be used during the relational algebra planning process. + */ + protected void populate(final Materialization materialization) { + SqlParser parser = SqlParser.create(materialization.sql); + SqlNode node; + try { + node = parser.parseStmt(); + } catch (SqlParseException e) { + throw new RuntimeException("parse failed", e); + } + final SqlToRelConverter.Config relConfig = SqlToRelConverter.config().withTrimUnusedFields(true); + SqlToRelConverter sqlToRelConverter2 = getSqlToRelConverter(sqlValidator, catalogReader, relConfig); + + RelRoot root = sqlToRelConverter2.convertQuery(node, true, true); + materialization.queryRel = trimUnusedFields(root).rel; + + // Identify and substitute a StarTable in queryRel. + // + // It is possible that no StarTables match. That is OK, but the + // materialization patterns that are recognized will not be as rich. + // + // It is possible that more than one StarTable matches. TBD: should we + // take the best (whatever that means), or all of them? + useStar(schema, materialization); + + List tableName = materialization.materializedTable.path(); + RelOptTable table = this.catalogReader.getTable(tableName); + materialization.tableRel = sqlToRelConverter2.toRel(table, ImmutableList.of()); + } + + /** + * Walks over a tree of relational expressions, replacing each + * {@link org.apache.calcite.rel.RelNode} with a 'slimmed down' relational + * expression that projects only the columns required by its consumer. + * + * @param root Root of relational expression tree + * @return Trimmed relational expression + */ + protected RelRoot trimUnusedFields(RelRoot root) { + final SqlToRelConverter.Config config = SqlToRelConverter.config().withTrimUnusedFields(shouldTrim(root.rel)) + .withExpand(false); + final SqlToRelConverter converter = getSqlToRelConverter(sqlValidator, catalogReader, config); + final boolean ordered = !root.collation.getFieldCollations().isEmpty(); + final boolean dml = SqlKind.DML.contains(root.kind); + return root.withRel(converter.trimUnusedFields(dml || ordered, root.rel)); + } + + private static boolean shouldTrim(RelNode rootRel) { + // For now, don't trim if there are more than 3 joins. 
The projects + // near the leaves created by trim migrate past joins and seem to + // prevent join-reordering. + return RelOptUtil.countJoins(rootRel) < 2; + } + + /** + * Converts a relational expression to use a {@link StarTable} defined in + * {@code schema}. Uses the first star table that fits. + */ + private void useStar(CalciteSchema schema, Materialization materialization) { + RelNode queryRel = materialization.queryRel; + for (Callback x : useStar(schema, queryRel)) { + // Success -- we found a star table that matches. + materialization.materialize(x.rel, x.starRelOptTable); + System.out.println("Materialization " + materialization.materializedTable + " matched star table " + + x.starTable + "; query after re-write: " + RelOptUtil.toString(queryRel)); + } + } + + /** + * Converts a relational expression to use a + * {@link org.apache.calcite.schema.impl.StarTable} defined in {@code schema}. + * Uses the first star table that fits. + */ + private Iterable useStar(CalciteSchema schema, RelNode queryRel) { + List starTables = Schemas.getStarTables(schema.root()); + if (starTables.isEmpty()) { + // Don't waste effort converting to leaf-join form. + return ImmutableList.of(); + } + final List list = new ArrayList<>(); + final RelNode rel2 = RelOptMaterialization.toLeafJoinForm(queryRel); + for (CalciteSchema.TableEntry starTable : starTables) { + final Table table = starTable.getTable(); + assert table instanceof StarTable; + RelOptTableImpl starRelOptTable = RelOptTableImpl.create(catalogReader, table.getRowType(typeFactory), + starTable, null); + final RelNode rel3 = RelOptMaterialization.tryUseStar(rel2, starRelOptTable); + if (rel3 != null) { + list.add(new Callback(rel3, starTable, starRelOptTable)); + } + } + return list; + } + + /** Called when we discover a star table that matches. */ + static class Callback { + public final RelNode rel; + public final CalciteSchema.TableEntry starTable; + public final RelOptTableImpl starRelOptTable; + + Callback(RelNode rel, CalciteSchema.TableEntry starTable, RelOptTableImpl starRelOptTable) { + this.rel = rel; + this.starTable = starTable; + this.starRelOptTable = starRelOptTable; + } + } + + /** + * Describes that a given SQL query is materialized by a given table. The + * materialization is currently valid, and can be used in the planning process. + */ + public static class Materialization { + /** The table that holds the materialized data. */ + final CalciteSchema.TableEntry materializedTable; + /** The query that derives the data. */ + final String sql; + /** The schema path for the query. */ + final List viewSchemaPath; + /** + * Relational expression for the table. Usually a + * {@link org.apache.calcite.rel.logical.LogicalTableScan}. + */ + @Nullable + RelNode tableRel; + /** Relational expression for the query to populate the table. */ + @Nullable + RelNode queryRel; + /** Star table identified. 
*/ + private @Nullable RelOptTable starRelOptTable; + + public Materialization(CalciteSchema.TableEntry materializedTable, String sql, List viewSchemaPath) { + assert materializedTable != null; + assert sql != null; + this.materializedTable = materializedTable; + this.sql = sql; + this.viewSchemaPath = viewSchemaPath; + } + + public void materialize(RelNode queryRel, RelOptTable starRelOptTable) { + this.queryRel = queryRel; + this.starRelOptTable = starRelOptTable; + // assert starRelOptTable.maybeUnwrap(StarTable.class).isPresent(); + } + } +} diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java index 6a55fbd08..ff7c5a4c8 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java @@ -4,15 +4,12 @@ import java.sql.DriverManager; import java.sql.SQLException; import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.Properties; import com.dask.sql.application.DaskRuleSets.HepExecutionType; import com.dask.sql.schema.DaskSchema; import org.apache.calcite.sql.SqlDialect; -import org.apache.calcite.sql.dialect.PostgresqlSqlDialect; -import org.apache.calcite.avatica.util.Casing; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionConfigImpl; import org.apache.calcite.config.CalciteConnectionProperty; @@ -21,30 +18,13 @@ import org.apache.calcite.jdbc.JavaTypeFactoryImpl; import org.apache.calcite.plan.Context; import org.apache.calcite.plan.Contexts; -import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptMaterialization; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.hep.HepMatchOrder; import org.apache.calcite.plan.hep.HepPlanner; -import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule; -import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; -import org.apache.calcite.rel.rules.FilterAggregateTransposeRule; -import org.apache.calcite.rel.rules.FilterSetOpTransposeRule; -import org.apache.calcite.rel.rules.FilterJoinRule; -import org.apache.calcite.rel.rules.FilterMergeRule; -import org.apache.calcite.rel.rules.FilterMultiJoinMergeRule; -import org.apache.calcite.rel.rules.FilterRemoveIsNotDistinctFromRule; -import org.apache.calcite.rel.rules.JoinToMultiJoinRule; -import org.apache.calcite.rel.rules.LoptOptimizeJoinRule; -import org.apache.calcite.rel.rules.MultiJoinOptimizeBushyRule; -import org.apache.calcite.rel.rules.ProjectJoinTransposeRule; -import org.apache.calcite.rel.rules.ProjectMergeRule; -import org.apache.calcite.rel.rules.ProjectRemoveRule; -import org.apache.calcite.rel.rules.ReduceExpressionsRule; -import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.rex.RexExecutorImpl; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.sql.SqlNode; @@ -60,7 +40,6 @@ import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.Planner; import org.apache.calcite.tools.RelConversionException; -import org.apache.calcite.tools.RuleSet; import org.apache.calcite.tools.ValidationException; /** @@ -91,6 +70,11 @@ public 
RelationalAlgebraGenerator(final DaskSchema schema) throws ClassNotFoundE planner = Frameworks.getPlanner(config); hepPlanner = getHepPlanner(config); + + final DaskCalciteMaterializer materializer = new DaskCalciteMaterializer(rootSchema, schema.getName(), config); + for (RelOptMaterialization materialization : materializer.getMaterializations()) { + hepPlanner.addMaterialization(materialization); + } } /// Create the framework config, e.g. containing with SQL dialect we speak @@ -146,20 +130,29 @@ private HepPlanner getHepPlanner(final FrameworkConfig config) { builder.addMatchOrder(HepMatchOrder.ARBITRARY).addMatchLimit(Integer.MAX_VALUE); // Legacy rule set // for (RelOptRule rule : DaskRuleSets.DASK_DEFAULT_CORE_RULES){ - // builder.addRuleInstance(rule); + // builder.addRuleInstance(rule); // } - builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.AGGREGATE_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.PROJECT_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.FILTER_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.REDUCE_EXPRESSION_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.AGGREGATE_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.PROJECT_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.FILTER_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.REDUCE_EXPRESSION_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); // join reorder. The first set of rules transforms joins into a large multijoin. - // the second set of rules splits the multijoins by applying a heuristic to determine the best join order. - builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.JOIN_REORDER_PREPARE_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.COLLECTION)); - builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.JOIN_REORDER_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.SEQUENCE)); - - // optimize logical plan. Be careful not to introduce rules in this set which mess up the join order from the step before. - builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.LOGICAL_RULES, HepMatchOrder.BOTTOM_UP, HepExecutionType.SEQUENCE)); + // the second set of rules splits the multijoins by applying a heuristic to + // determine the best join order. + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.JOIN_REORDER_PREPARE_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.JOIN_REORDER_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.SEQUENCE)); + + // optimize logical plan. Be careful not to introduce rules in this set which + // mess up the join order from the step before. 
+ builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.LOGICAL_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.SEQUENCE)); return new HepPlanner(builder.build(), config.getContext()); } @@ -207,4 +200,3 @@ public String getRelationalAlgebraString(final RelNode relNode) { return RelOptUtil.toString(relNode); } } - diff --git a/planner/src/main/java/com/dask/sql/schema/DaskTable.java b/planner/src/main/java/com/dask/sql/schema/DaskTable.java index 31681f9fe..c5c730728 100644 --- a/planner/src/main/java/com/dask/sql/schema/DaskTable.java +++ b/planner/src/main/java/com/dask/sql/schema/DaskTable.java @@ -28,13 +28,22 @@ public class DaskTable implements ProjectableFilterableTable { private final ArrayList> tableColumns; // Name of this table private final String name; + // Optional sql query. If given, the table is considered a materialized view + // and added to the planner for view-based optimization + private final String sql; - /// Construct a new table with the given name - public DaskTable(final String name) { + /// Construct a new table with the given name and sql + public DaskTable(final String name, final String sql) { this.name = name; + this.sql = sql; this.tableColumns = new ArrayList>(); } + /// Construct a new table with the given name + public DaskTable(final String name) { + this(name, null); + } + /// Add a column with the given type public void addColumn(final String columnName, final SqlTypeName columnType) { this.tableColumns.add(new Pair<>(columnName, columnType)); @@ -45,6 +54,14 @@ public String getTableName() { return this.name; } + public String getSql() { + return this.sql; + } + + public boolean isMaterializedView() { + return this.sql != null; + } + /// calcite method: Get the type of a row of this table (using the type factory) @Override public RelDataType getRowType(final RelDataTypeFactory relDataTypeFactory) { From 70b922f72cb8c9f70dc5fa9e11b074cd9deadc02 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Tue, 16 Mar 2021 11:01:24 +0100 Subject: [PATCH 18/33] Managing materialized views --- dask_sql/context.py | 19 +++++++++++++++---- .../application/DaskCalciteMaterializer.java | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/dask_sql/context.py b/dask_sql/context.py index 6cb319978..cced80548 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -68,6 +68,8 @@ def __init__(self): """ # Storage for the registered tables self.tables = {} + # Storage for the registered views + self.views = {} # Storage for the registered functions self.functions: Dict[str, Callable] = {} self.function_list: List[FunctionDescription] = [] @@ -118,6 +120,7 @@ def create_table( input_table: InputType, format: str = None, persist: bool = True, + sql: str = None, **kwargs, ): """ @@ -193,6 +196,8 @@ def create_table( **kwargs, ) self.tables[table_name.lower()] = dc + if sql is not None: + self.views[table_name.lower()] = sql def register_dask_table(self, df: dd.DataFrame, name: str): """ @@ -450,11 +455,17 @@ def _prepare_schema(self): logger.warning("No tables are registered.") for name, dc in self.tables.items(): - table = DaskTable(name) df = dc.df - logger.debug( - f"Adding table '{name}' to schema with columns: {list(df.columns)}" - ) + if name in self.views: + table = DaskTable(name, self.views[name]) + logger.debug( + f"Adding materialied table '{name}' to schema with columns: {list(df.columns)}" + ) + else: + table = DaskTable(name) + logger.debug( + f"Adding table '{name}' to schema with columns: {list(df.columns)}" + ) for column in 
df.columns: data_type = df[column].dtype sql_data_type = python_to_sql_type(data_type) diff --git a/planner/src/main/java/com/dask/sql/application/DaskCalciteMaterializer.java b/planner/src/main/java/com/dask/sql/application/DaskCalciteMaterializer.java index 2d4a77ddb..b1f210df0 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskCalciteMaterializer.java +++ b/planner/src/main/java/com/dask/sql/application/DaskCalciteMaterializer.java @@ -142,7 +142,7 @@ protected SqlToRelConverter getSqlToRelConverter(SqlValidator validator, Catalog * RelNodes which can be used during the relational algebra planning process. */ protected void populate(final Materialization materialization) { - SqlParser parser = SqlParser.create(materialization.sql); + SqlParser parser = SqlParser.create(materialization.sql, config.getParserConfig()); SqlNode node; try { node = parser.parseStmt(); From 6269fbc0cf5bb42fccbabc44693492eac96f38d3 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Tue, 16 Mar 2021 15:27:32 +0100 Subject: [PATCH 19/33] Added a hepplanner for materializedview optimisation --- .../dask/sql/application/DaskRuleSets.java | 12 ++++++++++- .../RelationalAlgebraGenerator.java | 21 +++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java index 091c97423..7b0f46686 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java +++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java @@ -8,6 +8,7 @@ import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.rel.rules.materialize.MaterializedViewRules; import org.apache.calcite.tools.RuleSet; import org.apache.calcite.tools.RuleSets; @@ -109,10 +110,19 @@ private DaskRuleSets() { // convert non-all union into all-union + distinct CoreRules.UNION_TO_DISTINCT, CoreRules.MINUS_MERGE, // aggregation and projection rules - // CoreRules.AGGREGATE_PROJECT_MERGE, CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, + // CoreRules.AGGREGATE_PROJECT_MERGE, + // CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, // CoreRules.AGGREGATE_REMOVE, CoreRules.AGGREGATE_JOIN_TRANSPOSE_EXTENDED, CoreRules.AGGREGATE_UNION_AGGREGATE_FIRST, CoreRules.AGGREGATE_UNION_AGGREGATE_SECOND); + /** + * RuleSet for MaterializedViews. + */ + static final RuleSet MATERIALIZATION_RULES = RuleSets.ofList(MaterializedViewRules.FILTER_SCAN, + MaterializedViewRules.PROJECT_FILTER, MaterializedViewRules.FILTER, + MaterializedViewRules.PROJECT_JOIN, MaterializedViewRules.JOIN, + MaterializedViewRules.PROJECT_AGGREGATE, MaterializedViewRules.AGGREGATE); + /** * Initial rule set from dask_sql with a couple rules added by Demian. Not used * but kept for reference. 
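
The two patches above thread an optional SQL definition for a table from the Python Context all the way into the Calcite planner: create_table gains a sql= argument, the query string is stored in Context.views, _prepare_schema hands it to the new two-argument DaskTable constructor, and DaskCalciteMaterializer turns every such table into a RelOptMaterialization that the view-based HepPlanner can use with the MATERIALIZATION_RULES added here. The following is a minimal sketch of how this could look from the Python side: the table names and data are hypothetical, only the sql= keyword comes from the diff above, and whether a given query is actually rewritten against the materialized table depends on the MaterializedViewRules matching the plan and on the Python layer invoking the new view-based optimization step, which is not shown in this part of the series.

    import pandas as pd
    import dask.dataframe as dd
    from dask_sql import Context

    c = Context()

    # Base table (hypothetical example data).
    df = dd.from_pandas(
        pd.DataFrame({"a": [1, 1, 2, 2, 3], "b": [10, 20, 30, 40, 50]}),
        npartitions=1,
    )
    c.create_table("my_table", df)

    # A precomputed aggregate registered together with the SQL query that derives it.
    # Passing sql= stores the query in c.views, so _prepare_schema builds a DaskTable
    # with an attached query and the planner treats it as a materialized view.
    agg_df = df.groupby("a").b.sum().reset_index()
    c.create_table(
        "my_table_agg",
        agg_df,
        sql="SELECT a, SUM(b) AS b FROM my_table GROUP BY a",
    )

    # If the view-based planner rewrites the plan, this aggregation can be answered
    # by scanning my_table_agg instead of re-aggregating my_table.
    result = c.sql("SELECT a, SUM(b) AS b FROM my_table GROUP BY a")
    print(result.compute())

Keeping the SQL string on the table itself, rather than in a separate view registry on the Java side, means the existing schema-preparation path only needs the one extra constructor argument added to DaskTable above.
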
diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java index ff7c5a4c8..798e4001f 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java @@ -55,6 +55,8 @@ public class RelationalAlgebraGenerator { private Planner planner; /// The planner for optimized queries private HepPlanner hepPlanner; + /// The planner to optimise queries with materialized views + private HepPlanner viewBasedPlanner; /// Create a new relational algebra generator from a schema public RelationalAlgebraGenerator(final DaskSchema schema) throws ClassNotFoundException, SQLException { @@ -70,10 +72,12 @@ public RelationalAlgebraGenerator(final DaskSchema schema) throws ClassNotFoundE planner = Frameworks.getPlanner(config); hepPlanner = getHepPlanner(config); - + viewBasedPlanner = getViewBasedPlanner(config); final DaskCalciteMaterializer materializer = new DaskCalciteMaterializer(rootSchema, schema.getName(), config); for (RelOptMaterialization materialization : materializer.getMaterializations()) { - hepPlanner.addMaterialization(materialization); + // System.out.println("Adding materialized view for \n" + getRelationalAlgebraString(materialization.tableRel) + // + "\nwith sql query plan " + getRelationalAlgebraString(materialization.queryRel)); + viewBasedPlanner.addMaterialization(materialization); } } @@ -199,4 +203,17 @@ public RelNode getOptimizedRelationalAlgebra(final RelNode nonOptimizedPlan) { public String getRelationalAlgebraString(final RelNode relNode) { return RelOptUtil.toString(relNode); } + + public RelNode getMaterializedViewsOptimizedRelationalAlgebra(final RelNode relPlan) { + viewBasedPlanner.setRoot(relPlan); + return viewBasedPlanner.findBestExp(); + } + + private HepPlanner getViewBasedPlanner(final FrameworkConfig config) { + final HepProgramBuilder builder = new HepProgramBuilder(); + builder.addMatchOrder(HepMatchOrder.ARBITRARY).addMatchLimit(Integer.MAX_VALUE); + builder.addSubprogram(DaskRuleSets.hepProgram(DaskRuleSets.MATERIALIZATION_RULES, HepMatchOrder.BOTTOM_UP, + HepExecutionType.COLLECTION)); + return new HepPlanner(builder.build(), config.getContext()); + } } From 465fa03a81ca32f63d9710c89cdcce06bdd2456d Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Wed, 17 Mar 2021 11:50:16 +0100 Subject: [PATCH 20/33] Adding a rule to manage filter on join expressions --- dask_sql/physical/rel/logical/join.py | 3 --- .../java/com/dask/sql/application/DaskRuleSets.java | 13 +++++++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index f775fd3c8..052121cf3 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -315,9 +315,6 @@ def _extract_lhs_rhs(self, rex): operands = rex.getOperands() assert len(operands) == 2 - operand_lhs = operands[0] - operand_rhs = operands[1] - indices = [] for operand in operands: if isinstance(operand, org.apache.calcite.rex.RexInputRef): diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java index 7b0f46686..7bd35beca 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java +++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java @@ -39,6 +39,12 @@ private 
DaskRuleSets() { static final RuleSet FILTER_RULES = RuleSets.ofList( // push a filter into a join CoreRules.FILTER_INTO_JOIN, + // Jonas : We need JOIN_PUSH_TRANSITIVE_PREDICATES rule to work + // with the FILTER_INTO_JOIN rule, + // otherwise we end up with filter expressions on join conditions + // (LogicalJoin(condition=[=($0, 1000)], joinType=[inner])) which + // LogicalJoinPlugin can't handle. + CoreRules.JOIN_PUSH_TRANSITIVE_PREDICATES, // push filter into the children of a join CoreRules.JOIN_CONDITION_PUSH, // push filter through an aggregation @@ -59,8 +65,11 @@ private DaskRuleSets() { // removes constant keys from an Agg CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, // push project through a Union - CoreRules.PROJECT_SET_OP_TRANSPOSE, CoreRules.JOIN_PROJECT_LEFT_TRANSPOSE, - CoreRules.JOIN_PROJECT_RIGHT_TRANSPOSE, CoreRules.JOIN_PROJECT_BOTH_TRANSPOSE); + CoreRules.PROJECT_SET_OP_TRANSPOSE, + CoreRules.JOIN_PROJECT_LEFT_TRANSPOSE, + CoreRules.JOIN_PROJECT_RIGHT_TRANSPOSE, + CoreRules.JOIN_PROJECT_BOTH_TRANSPOSE + ); /** * RuleSet about aggregate From 2e7f79e067dd9d96c79d3c893a939863ab5953d3 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Wed, 17 Mar 2021 17:30:10 +0100 Subject: [PATCH 21/33] Changing behaviour of aggregate group_by to get it to work with multi column aggregations --- dask_sql/physical/rel/logical/aggregate.py | 31 +++++++++++++--------- dask_sql/physical/rel/logical/minus.py | 1 - 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index 17ddbf322..0a2fcf8aa 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -243,7 +243,7 @@ def _collect_aggregations( elif len(inputs) == 0: input_col = additional_column_name else: - input_col = tuple(cc.get_backend_by_frontend_index(inputs[i]) for i in inputs) + input_col = tuple(cc.get_backend_by_frontend_index(i) for i in inputs) # Extract flags (filtering/distinct) if expr.isDistinct(): # pragma: no cover @@ -304,17 +304,23 @@ def _perform_aggregation( logger.debug(f"Filtered by {filter_column} before aggregation.") - # SQL and dask are treating null columns a bit different: - # SQL will put them to the front, dask will just ignore them - # Therefore we use the same trick as fugue does: - # we will group by both the NaN and the real column value - group_columns_and_nulls = [] - for group_column in group_columns: - # the ~ makes NaN come first - is_null_column = ~(tmp_df[group_column].isnull()) - non_nan_group_column = tmp_df[group_column].fillna(0) + # Jonas : we don't really care to have the exact same behaviour as SQL + # and grouping by series instead of column names is messing up the + # multi column aggregations so i'm just assuming this will work + # instead of the commented part below. 
+ group_columns_and_nulls = group_columns - group_columns_and_nulls += [is_null_column, non_nan_group_column] + # # SQL and dask are treating null columns a bit different: + # # SQL will put them to the front, dask will just ignore them + # # Therefore we use the same trick as fugue does: + # # we will group by both the NaN and the real column value + # group_columns_and_nulls = [] + # for group_column in group_columns: + # # the ~ makes NaN come first + # is_null_column = ~(tmp_df[group_column].isnull()) + # non_nan_group_column = tmp_df[group_column].fillna(0) + + # group_columns_and_nulls += [is_null_column, non_nan_group_column] if not group_columns_and_nulls: # This can happen in statements like @@ -343,10 +349,9 @@ def _perform_aggregation( # ... fix the column names to a single level ... agg_result.columns = agg_result.columns.get_level_values(-1) - # apply multi-column aggregations for output_col, (input_col, aggregation_f) in multi_col_aggregations.items(): - new_col = grouped_df.apply(lambda x: aggregation_f(*[getattr(x, col) for col in input_col])) + new_col = grouped_df.apply(aggregation_f) if agg_result is None: agg_result = new_col.rename(output_col).to_frame() else: diff --git a/dask_sql/physical/rel/logical/minus.py b/dask_sql/physical/rel/logical/minus.py index f26dc6e1b..57352c1d5 100644 --- a/dask_sql/physical/rel/logical/minus.py +++ b/dask_sql/physical/rel/logical/minus.py @@ -1,6 +1,5 @@ import dask.dataframe as dd -from dask_sql.physical.rex import RexConverter from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.datacontainer import DataContainer, ColumnContainer From 9b735098047b02cfa802fd8d0b9fba6308dbcf19 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Fri, 19 Mar 2021 10:28:30 +0100 Subject: [PATCH 22/33] Adding return_type info to aggregate apply calls --- dask_sql/physical/rel/logical/aggregate.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index 0a2fcf8aa..5c3955d27 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -1,7 +1,7 @@ import operator from collections import defaultdict from functools import reduce -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Type import logging import pandas as pd @@ -190,7 +190,7 @@ def _do_aggregations( # As the values of the group columns # are the same for a single group anyways, we just use the first row for col in group_columns: - collected_aggregations[None].append((col, col, "first")) + collected_aggregations[None].append((col, col, "first", None)) # Now we can go ahead and use these grouped aggregations # to perform the actual aggregation @@ -223,7 +223,7 @@ def _collect_aggregations( context: "dask_sql.Context", additional_column_name: str, output_column_order: List[str], - ) -> Tuple[Dict[Tuple[str, str], List[Tuple[str, str, Any]]], List[str]]: + ) -> Tuple[Dict[Tuple[str, str], List[Tuple[str, str, Any, Type]]], List[str]]: """ Collect all aggregations together, which have the same filter column so that the aggregations only need to be done once. 
@@ -256,6 +256,10 @@ def _collect_aggregations( # Find out which aggregation function to use aggregation_name = str(expr.getAggregation().getName()) aggregation_name = aggregation_name.lower() + return_type = None + for function_description in context.function_list: + if function_description.name == aggregation_name: + return_type = function_description.return_type try: aggregation_function = self.AGGREGATION_MAPPING[aggregation_name] except KeyError: @@ -282,7 +286,7 @@ def _collect_aggregations( # Store the aggregation key = filter_column - value = (input_col, output_col, aggregation_function) + value = (input_col, output_col, aggregation_function, return_type) collected_aggregations[key].append(value) output_column_order.append(output_col) @@ -292,7 +296,7 @@ def _perform_aggregation( self, df: dd.DataFrame, filter_column: str, - aggregations: List[Tuple[str, str, Any]], + aggregations: List[Tuple[str, str, Any, Type]], additional_column_name: str, group_columns: List[str], ): @@ -334,9 +338,9 @@ def _perform_aggregation( aggregations_dict = defaultdict(dict) multi_col_aggregations = dict() for aggregation in aggregations: - input_col, output_col, aggregation_f = aggregation + input_col, output_col, aggregation_f, return_type = aggregation if isinstance(input_col, tuple): - multi_col_aggregations[output_col] = (input_col, aggregation_f) + multi_col_aggregations[output_col] = (input_col, aggregation_f, return_type) else: aggregations_dict[input_col][output_col] = aggregation_f @@ -350,8 +354,8 @@ def _perform_aggregation( agg_result.columns = agg_result.columns.get_level_values(-1) # apply multi-column aggregations - for output_col, (input_col, aggregation_f) in multi_col_aggregations.items(): - new_col = grouped_df.apply(aggregation_f) + for output_col, (input_col, aggregation_f, return_type) in multi_col_aggregations.items(): + new_col = grouped_df.apply(lambda x: aggregation_f(*[getattr(x, col) for col in input_col])) if agg_result is None: agg_result = new_col.rename(output_col).to_frame() else: From e54e40686f4f335049eb5d71c208e4b0016ed442 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Fri, 19 Mar 2021 15:15:19 +0100 Subject: [PATCH 23/33] Rewriting of aggregate to handle both GroupBy-aggregations and GroupBy-apply functions --- dask_sql/physical/rel/logical/aggregate.py | 173 +++++++++++++++------ 1 file changed, 129 insertions(+), 44 deletions(-) diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index 5c3955d27..22de5f079 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -21,7 +21,9 @@ class ReduceAggregation(dd.Aggregation): """ def __init__(self, name: str, operation: Callable): - series_aggregate = lambda s: s.aggregate(lambda x: reduce(operation, x)) + series_aggregate = lambda s: s.aggregate( + lambda x: reduce(operation, x) + ) super().__init__(name, series_aggregate, series_aggregate) @@ -100,8 +102,12 @@ class LogicalAggregatePlugin(BaseRelPlugin): "bit_and": AggregationSpecification( ReduceAggregation("bit_and", operator.and_) ), - "bit_or": AggregationSpecification(ReduceAggregation("bit_or", operator.or_)), - "bit_xor": AggregationSpecification(ReduceAggregation("bit_xor", operator.xor)), + "bit_or": AggregationSpecification( + ReduceAggregation("bit_or", operator.or_) + ), + "bit_xor": AggregationSpecification( + ReduceAggregation("bit_xor", operator.xor) + ), "count": AggregationSpecification("count"), "every": AggregationSpecification( 
dd.Aggregation("every", lambda s: s.all(), lambda s0: s0.all()) @@ -112,7 +118,9 @@ class LogicalAggregatePlugin(BaseRelPlugin): } def convert( - self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" + self, + rel: "org.apache.calcite.rel.RelNode", + context: "dask_sql.Context", ) -> DataContainer: (dc,) = self.assert_inputs(rel, 1, context) @@ -123,7 +131,9 @@ def convert( cc = cc.make_unique() # I have no idea what that is, but so far it was always of length 1 - assert len(rel.getGroupSets()) == 1, "Do not know how to handle this case!" + assert ( + len(rel.getGroupSets()) == 1 + ), "Do not know how to handle this case!" # Extract the information, which columns we need to group for group_column_indices = [int(i) for i in rel.getGroupSet()] @@ -142,7 +152,10 @@ def convert( # Do all aggregates df_result, output_column_order = self._do_aggregations( - rel, dc, group_columns, context, + rel, + dc, + group_columns, + context, ) # SQL does not care about the index, but we do not want to have any multiindices @@ -181,39 +194,82 @@ def _do_aggregations( output_column_order = group_columns.copy() # Collect all aggregations we need to do - collected_aggregations, output_column_order = self._collect_aggregations( + ( + collected_aggregations, + output_column_order, + ) = self._collect_aggregations( rel, df, cc, context, additional_column_name, output_column_order ) # SQL needs to have a column with the grouped values as the first - # output column. - # As the values of the group columns - # are the same for a single group anyways, we just use the first row - for col in group_columns: - collected_aggregations[None].append((col, col, "first", None)) + # output column. As the values of the group columns + # are the same for a single group anyways, we just use the first row. + df_result = None + if group_columns: + default_column_aggregations = [] + for col in group_columns: + default_column_aggregations.append((col, col, "first", None)) + df_result = self._apply_and_assign_aggregations( + df, + df_result, + None, + default_column_aggregations, + additional_column_name, + group_columns, + ) # Now we can go ahead and use these grouped aggregations # to perform the actual aggregation # It is very important to start with the non-filtered entry. # Otherwise we might loose some entries in the grouped columns - key = None - aggregations = collected_aggregations.pop(key) - df_result = self._perform_aggregation( - df, None, aggregations, additional_column_name, group_columns, - ) + if None in collected_aggregations: + key = None + aggregations = collected_aggregations.pop(key) + df_result = self._apply_and_assign_aggregations( + df, + df_result, + None, + aggregations, + additional_column_name, + group_columns, + ) - # Now we can also the the rest + # Now we can also add the rest for filter_column, aggregations in collected_aggregations.items(): - agg_result = self._perform_aggregation( - df, filter_column, aggregations, additional_column_name, group_columns, + df_result = self._apply_and_assign_aggregations( + df, + df_result, + filter_column, + aggregations, + additional_column_name, + group_columns, ) - # ... 
and finally concat the new data with the already present columns + return df_result, output_column_order + + def _apply_and_assign_aggregations( + self, + df: dd.DataFrame, + df_result: dd.DataFrame, + filter_column: str, + aggregations: List[Tuple[str, str, Any, Type]], + additional_column_name: str, + group_columns: List[str], + ): + agg_result = self._perform_aggregation( + df, + filter_column, + aggregations, + additional_column_name, + group_columns, + ) + if df_result is None: + df_result = agg_result + else: df_result = df_result.assign( **{col: agg_result[col] for col in agg_result.columns} ) - - return df_result, output_column_order + return df_result def _collect_aggregations( self, @@ -223,13 +279,15 @@ def _collect_aggregations( context: "dask_sql.Context", additional_column_name: str, output_column_order: List[str], - ) -> Tuple[Dict[Tuple[str, str], List[Tuple[str, str, Any, Type]]], List[str]]: + ) -> Tuple[ + Dict[Tuple[str, str], List[Tuple[str, str, Any, Type]]], List[str] + ]: """ Collect all aggregations together, which have the same filter column so that the aggregations only need to be done once. Returns the aggregations as mapping filter_column -> List of Aggregations - where the aggregations are in the form (input_col, output_col, aggregation function (or string)) + where the aggregations are in the form (input_col, output_col, aggregation function (or string), return_type) """ collected_aggregations = defaultdict(list) @@ -243,7 +301,9 @@ def _collect_aggregations( elif len(inputs) == 0: input_col = additional_column_name else: - input_col = tuple(cc.get_backend_by_frontend_index(i) for i in inputs) + input_col = tuple( + cc.get_backend_by_frontend_index(i) for i in inputs + ) # Extract flags (filtering/distinct) if expr.isDistinct(): # pragma: no cover @@ -251,7 +311,9 @@ def _collect_aggregations( filter_column = None if expr.hasFilter(): - filter_column = cc.get_backend_by_frontend_index(expr.filterArg) + filter_column = cc.get_backend_by_frontend_index( + expr.filterArg + ) # Find out which aggregation function to use aggregation_name = str(expr.getAggregation().getName()) @@ -261,7 +323,9 @@ def _collect_aggregations( if function_description.name == aggregation_name: return_type = function_description.return_type try: - aggregation_function = self.AGGREGATION_MAPPING[aggregation_name] + aggregation_function = self.AGGREGATION_MAPPING[ + aggregation_name + ] except KeyError: try: aggregation_function = context.functions[aggregation_name] @@ -275,7 +339,9 @@ def _collect_aggregations( else: dtype = df[input_col].dtype if pd.api.types.is_numeric_dtype(dtype): - aggregation_function = aggregation_function.numerical_aggregation + aggregation_function = ( + aggregation_function.numerical_aggregation + ) else: aggregation_function = ( aggregation_function.non_numerical_aggregation @@ -334,31 +400,50 @@ def _perform_aggregation( grouped_df = tmp_df.groupby(by=group_columns_and_nulls) - # Convert into the correct format for dask - aggregations_dict = defaultdict(dict) - multi_col_aggregations = dict() + # Dask supports two types of group-by aggregations: by calling .agg + # or .apply on a GroupBy dataframe. We want to call .agg if possible, + # as it's supposed to be faster en cleaner. But it doesn't work for + # all cases, in which case we use .apply instead. We start by + # preparing the aggregate calls in a format dask understands. 
+ aggregate_aggregations = defaultdict(dict) + apply_aggregations = dict() for aggregation in aggregations: input_col, output_col, aggregation_f, return_type = aggregation - if isinstance(input_col, tuple): - multi_col_aggregations[output_col] = (input_col, aggregation_f, return_type) + if isinstance( + aggregation_f, (AggregationSpecification, dd.Aggregation, str) + ): + aggregate_aggregations[input_col][output_col] = aggregation_f else: - aggregations_dict[input_col][output_col] = aggregation_f + apply_aggregations[output_col] = ( + input_col, + aggregation_f, + return_type, + ) - # Now apply the aggregation + # Now we apply the aggregations agg_result = None - if len(aggregations_dict) > 0: - logger.debug(f"Performing aggregation {dict(aggregations_dict)}") - agg_result = grouped_df.agg(aggregations_dict) + if len(aggregate_aggregations) > 0: + logger.debug(f"Performing aggregation {dict(aggregate_aggregations)}") + agg_result = grouped_df.agg(aggregate_aggregations) # ... fix the column names to a single level ... agg_result.columns = agg_result.columns.get_level_values(-1) - - # apply multi-column aggregations - for output_col, (input_col, aggregation_f, return_type) in multi_col_aggregations.items(): - new_col = grouped_df.apply(lambda x: aggregation_f(*[getattr(x, col) for col in input_col])) + + # apply aggregations with .apply + for output_col, ( + input_col, + aggregation_f, + return_type, + ) in apply_aggregations.items(): + new_col = grouped_df.apply( + lambda x: aggregation_f( + *[getattr(x, col) for col in input_col] + ), + meta=(output_col, return_type) + ) if agg_result is None: agg_result = new_col.rename(output_col).to_frame() else: agg_result = agg_result.assign(**{output_col: new_col}) - + return agg_result From 5865355f2dbf1cd22379865399b1d6e4848816ab Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Wed, 24 Mar 2021 15:20:55 +0100 Subject: [PATCH 24/33] Added a rule to push filter expr out of join conditions --- .../main/java/com/dask/sql/application/DaskRuleSets.java | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java index 7bd35beca..f4551fd2b 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java +++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java @@ -39,12 +39,14 @@ private DaskRuleSets() { static final RuleSet FILTER_RULES = RuleSets.ofList( // push a filter into a join CoreRules.FILTER_INTO_JOIN, - // Jonas : We need JOIN_PUSH_TRANSITIVE_PREDICATES rule to work + // Jonas : We need both JOIN_PUSH_TRANSITIVE_PREDICATES + // and JOIN_PUSH_EXPRESSIONS rules to work // with the FILTER_INTO_JOIN rule, // otherwise we end up with filter expressions on join conditions - // (LogicalJoin(condition=[=($0, 1000)], joinType=[inner])) which + // i.e. (emp join dept on emp.deptno * 2 = dept.deptno) which // LogicalJoinPlugin can't handle. CoreRules.JOIN_PUSH_TRANSITIVE_PREDICATES, + CoreRules.JOIN_PUSH_EXPRESSIONS, // push filter into the children of a join CoreRules.JOIN_CONDITION_PUSH, // push filter through an aggregation @@ -81,7 +83,8 @@ private DaskRuleSets() { CoreRules.AGGREGATE_REMOVE, CoreRules.AGGREGATE_JOIN_REMOVE); /** - * RuleSet for merging joins + * RuleSet for merging joins. All joins are merged into a large multi-join, + * which is then optimised by one of the JOIN_REORDER_RULES. 
*/ static final RuleSet JOIN_REORDER_PREPARE_RULES = RuleSets.ofList( // merge project to MultiJoin From 628599254a41bb30522c68268138e3c299765ccc Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Thu, 25 Mar 2021 17:47:54 +0100 Subject: [PATCH 25/33] Remove JOIN_PUSH_TRANSITIVE_PREDICATES rule --- dask_sql/physical/rel/logical/aggregate.py | 2 ++ dask_sql/physical/rel/logical/join.py | 8 -------- .../main/java/com/dask/sql/application/DaskRuleSets.java | 4 +--- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index 22de5f079..af17bab13 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -435,6 +435,8 @@ def _perform_aggregation( aggregation_f, return_type, ) in apply_aggregations.items(): + if not isinstance(input_col, tuple): + input_col = (input_col,) new_col = grouped_df.apply( lambda x: aggregation_f( *[getattr(x, col) for col in input_col] diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index 052121cf3..6833eadf9 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -92,14 +92,6 @@ def convert( # We therefore create new columns on purpose, which have a distinct name. assert len(lhs_on) == len(rhs_on) if lhs_on: - # Doing inplace join seems to have unexpected side effects due to - # the join columns being shared in the resulting df for the lhs and rhs. - # It needs to be reworked. - # if join_type == "inner": - # return self._do_inner_join_inplace( - # df_lhs_renamed, df_rhs_renamed, lhs_on, rhs_on, - # join_type, rel, filter_condition, context - # ) lhs_columns_to_add = { f"common_{i}": df_lhs_renamed.iloc[:, index] for i, index in enumerate(lhs_on) diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java index f4551fd2b..50ebf7e59 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java +++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java @@ -39,13 +39,11 @@ private DaskRuleSets() { static final RuleSet FILTER_RULES = RuleSets.ofList( // push a filter into a join CoreRules.FILTER_INTO_JOIN, - // Jonas : We need both JOIN_PUSH_TRANSITIVE_PREDICATES - // and JOIN_PUSH_EXPRESSIONS rules to work + // Jonas : We need JOIN_PUSH_EXPRESSIONS rule to work // with the FILTER_INTO_JOIN rule, // otherwise we end up with filter expressions on join conditions // i.e. (emp join dept on emp.deptno * 2 = dept.deptno) which // LogicalJoinPlugin can't handle. 
- CoreRules.JOIN_PUSH_TRANSITIVE_PREDICATES, CoreRules.JOIN_PUSH_EXPRESSIONS, // push filter into the children of a join CoreRules.JOIN_CONDITION_PUSH, From 8798026858646b617591c05eb43e2d7a224a398d Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Tue, 6 Apr 2021 09:53:31 +0200 Subject: [PATCH 26/33] Raise an assertionerror when join condition is a constant to treat it as a filter expression --- dask_sql/physical/rel/logical/aggregate.py | 2 +- dask_sql/physical/rel/logical/join.py | 5 +++++ .../src/main/java/com/dask/sql/application/DaskRuleSets.java | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index af17bab13..57a9912d2 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -402,7 +402,7 @@ def _perform_aggregation( # Dask supports two types of group-by aggregations: by calling .agg # or .apply on a GroupBy dataframe. We want to call .agg if possible, - # as it's supposed to be faster en cleaner. But it doesn't work for + # as it's supposed to be faster and cleaner. But it doesn't work for # all cases, in which case we use .apply instead. We start by # preparing the aggregate calls in a format dask understands. aggregate_aggregations = defaultdict(dict) diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index 6833eadf9..13a8215cc 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -316,6 +316,11 @@ def _extract_lhs_rhs(self, rex): isinstance(operand.getOperator(), org.apache.calcite.sql.fun.SqlCastFunction) ): indices.append(operand.operands[0].getIndex()) + elif isinstance(operand, org.apache.calcite.rex.RexLiteral): + # i.e. join condition is col.id == constant + # raising an AssertionError means that the RexExpression will be added + # as a filter condition to be applied after the join. + raise AssertionError("This is actually a filter condition") else: raise TypeError( "Invalid join condition" diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java index 50ebf7e59..48ebd559e 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java +++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java @@ -44,6 +44,7 @@ private DaskRuleSets() { // otherwise we end up with filter expressions on join conditions // i.e. (emp join dept on emp.deptno * 2 = dept.deptno) which // LogicalJoinPlugin can't handle. 
+			// CoreRules.JOIN_PUSH_TRANSITIVE_PREDICATES,
 			CoreRules.JOIN_PUSH_EXPRESSIONS,
 			// push filter into the children of a join
 			CoreRules.JOIN_CONDITION_PUSH,

From 01427b66c25fb96f3d8afe496bde377fdaa1ed50 Mon Sep 17 00:00:00 2001
From: Jonas Renault
Date: Tue, 13 Apr 2021 17:29:37 +0200
Subject: [PATCH 27/33] Remove FILTER_PROJECT_TRANSPOSE rule that was causing
 some errors in neurolang tests

---
 .../main/java/com/dask/sql/application/DaskRuleSets.java  | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java
index 48ebd559e..13f9a3542 100644
--- a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java
+++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java
@@ -50,8 +50,11 @@ private DaskRuleSets() {
 			CoreRules.JOIN_CONDITION_PUSH,
 			// push filter through an aggregation
 			CoreRules.FILTER_AGGREGATE_TRANSPOSE,
-			// push a filter past a project
-			CoreRules.FILTER_PROJECT_TRANSPOSE,
+			// Jonas : the FILTER_PROJECT_TRANSPOSE rule causes Calcite to push filter conditions (like x = y)
+			// into a projection (i.e. project(x, y, z)) which (sometimes) causes errors as the project will
+			// select column y instead of x for instance and lose the reference to y.
+			// It's the reason for several test failures in neurolang.
+			// CoreRules.FILTER_PROJECT_TRANSPOSE,
 			// push a filter past a setop
 			CoreRules.FILTER_SET_OP_TRANSPOSE,
 			CoreRules.FILTER_MERGE);

From f8e05d87774fc5f04d5bfeca83263a2002dfbf6f Mon Sep 17 00:00:00 2001
From: Jonas Renault
Date: Wed, 14 Apr 2021 14:17:56 +0200
Subject: [PATCH 28/33] Refactor aggregate to call drop_duplicates when we're
 doing a groupby on all the columns without an aggregation function (i.e.
 calling distinct())

---
 dask_sql/physical/rel/logical/aggregate.py | 77 ++++++++++++----------
 1 file changed, 42 insertions(+), 35 deletions(-)

diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py
index 57a9912d2..9129f45f8 100644
--- a/dask_sql/physical/rel/logical/aggregate.py
+++ b/dask_sql/physical/rel/logical/aggregate.py
@@ -141,8 +141,6 @@ def convert(
             cc.get_backend_by_frontend_index(i) for i in group_column_indices
         ]
 
-        dc = DataContainer(df, cc)
-
         if not group_columns:
             # There was actually no GROUP BY specified in the SQL
             # Still, this plan can also be used if we need to aggregate something over the full
@@ -150,20 +148,40 @@ def convert(
             # To reuse the code, we just create a new column at the end with a single value
             logger.debug("Performing full-table aggregation")
 
-        # Do all aggregates
-        df_result, output_column_order = self._do_aggregations(
-            rel,
-            dc,
-            group_columns,
-            context,
+        # Add an entry for every grouped column, as SQL wants them first
+        output_column_order = group_columns.copy()
+        additional_column_name = new_temporary_column(df)
+
+        # Collect all aggregations we need to do
+        (
+            collected_aggregations,
+            output_column_order,
+        ) = self._collect_aggregations(
+            rel, df, cc, context, additional_column_name, output_column_order
         )
 
-        # SQL does not care about the index, but we do not want to have any multiindices
-        df_agg = df_result.reset_index(drop=True)
+        # Check if we're doing a real aggregation or just dropping duplicates
+        if (
+            len(group_columns) == len(cc.columns)
+            and len(collected_aggregations) == 0
+        ):
+            # Just drop duplicates
+            df_agg = df.drop_duplicates()
+        else:
+            # Do the aggregations
+            df_result = self._do_aggregations(
+                df,
group_columns, + collected_aggregations, + additional_column_name, + ) + + # SQL does not care about the index, but we do not want to have any multiindices + df_agg = df_result.reset_index(drop=True) - # Fix the column names and the order of them, as this was messed with during the aggregations - df_agg.columns = df_agg.columns.get_level_values(-1) - cc = ColumnContainer(df_agg.columns).limit_to(output_column_order) + # Fix the column names and the order of them, as this was messed with during the aggregations + df_agg.columns = df_agg.columns.get_level_values(-1) + cc = ColumnContainer(df_agg.columns).limit_to(output_column_order) cc = self.fix_column_to_row_type(cc, rel.getRowType()) dc = DataContainer(df_agg, cc) @@ -172,35 +190,22 @@ def convert( def _do_aggregations( self, - rel: "org.apache.calcite.rel.RelNode", - dc: DataContainer, + df: dd.DataFrame, group_columns: List[str], - context: "dask_sql.Context", + collected_aggregations: Dict[ + Tuple[str, str], List[Tuple[str, str, Any, Type]] + ], + additional_column_name: str, ) -> Tuple[dd.DataFrame, List[str]]: """ Main functionality: return the result dataframe and the output column order """ - df = dc.df - cc = dc.column_container - # We might need it later. # If not, lets hope that adding a single column should not # be a huge problem... - additional_column_name = new_temporary_column(df) df = df.assign(**{additional_column_name: 1}) - # Add an entry for every grouped column, as SQL wants them first - output_column_order = group_columns.copy() - - # Collect all aggregations we need to do - ( - collected_aggregations, - output_column_order, - ) = self._collect_aggregations( - rel, df, cc, context, additional_column_name, output_column_order - ) - # SQL needs to have a column with the grouped values as the first # output column. As the values of the group columns # are the same for a single group anyways, we just use the first row. @@ -245,7 +250,7 @@ def _do_aggregations( group_columns, ) - return df_result, output_column_order + return df_result def _apply_and_assign_aggregations( self, @@ -398,7 +403,7 @@ def _perform_aggregation( # without any groupby statement group_columns_and_nulls = [additional_column_name] - grouped_df = tmp_df.groupby(by=group_columns_and_nulls) + grouped_df = tmp_df.groupby(by=group_columns_and_nulls, sort=False) # Dask supports two types of group-by aggregations: by calling .agg # or .apply on a GroupBy dataframe. We want to call .agg if possible, @@ -423,7 +428,9 @@ def _perform_aggregation( # Now we apply the aggregations agg_result = None if len(aggregate_aggregations) > 0: - logger.debug(f"Performing aggregation {dict(aggregate_aggregations)}") + logger.debug( + f"Performing aggregation {dict(aggregate_aggregations)}" + ) agg_result = grouped_df.agg(aggregate_aggregations) # ... fix the column names to a single level ... 
@@ -441,7 +448,7 @@ def _perform_aggregation( lambda x: aggregation_f( *[getattr(x, col) for col in input_col] ), - meta=(output_col, return_type) + meta=(output_col, return_type), ) if agg_result is None: agg_result = new_col.rename(output_col).to_frame() From 4dc898e16aab43cdd310e635818c41ba38737644 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Thu, 15 Apr 2021 14:54:16 +0200 Subject: [PATCH 29/33] Remove Project_join_transpose rules that were causing issues in one of the tests in neurolang --- dask_sql/physical/rel/logical/join.py | 119 ++++-------------- .../dask/sql/application/DaskRuleSets.java | 8 +- 2 files changed, 28 insertions(+), 99 deletions(-) diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index 13a8215cc..e85d07349 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -42,7 +42,9 @@ class LogicalJoinPlugin(BaseRelPlugin): } def convert( - self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" + self, + rel: "org.apache.calcite.rel.RelNode", + context: "dask_sql.Context", ) -> DataContainer: # Joining is a bit more complicated, so lets do it in steps: @@ -78,9 +80,13 @@ def convert( # As this is probably non-sense for large tables, but there is no other # known solution so far. join_condition = rel.getCondition() - lhs_on, rhs_on, filter_condition = self._split_join_condition(join_condition) + lhs_on, rhs_on, filter_condition = self._split_join_condition( + join_condition + ) - logger.debug(f"Joining with type {join_type} on columns {lhs_on}, {rhs_on}.") + logger.debug( + f"Joining with type {join_type} on columns {lhs_on}, {rhs_on}." + ) # lhs_on and rhs_on are the indices of the columns to merge on. # The given column indices are for the full, merged table which consists @@ -108,13 +114,19 @@ def convert( if join_type in ["inner", "right"]: df_lhs_filter = reduce( operator.and_, - [~df_lhs_renamed.iloc[:, index].isna() for index in lhs_on], + [ + ~df_lhs_renamed.iloc[:, index].isna() + for index in lhs_on + ], ) df_lhs_renamed = df_lhs_renamed[df_lhs_filter] if join_type in ["inner", "left"]: df_rhs_filter = reduce( operator.and_, - [~df_rhs_renamed.iloc[:, index].isna() for index in rhs_on], + [ + ~df_rhs_renamed.iloc[:, index].isna() + for index in rhs_on + ], ) df_rhs_renamed = df_rhs_renamed[df_rhs_filter] else: @@ -138,7 +150,9 @@ def convert( # 5. Now we can finally merge on these columns # The resulting dataframe will contain all (renamed) columns from the lhs and rhs # plus the added columns - df = dd.merge(df_lhs_with_tmp, df_rhs_with_tmp, on=added_columns, how=join_type) + df = dd.merge( + df_lhs_with_tmp, df_rhs_with_tmp, on=added_columns, how=join_type + ) # 6. So the next step is to make sure # we have the correct column order (and to remove the temporary join columns) @@ -176,89 +190,6 @@ def convert( dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) return dc - def _do_inner_join_inplace( - self, df_lhs_renamed, df_rhs_renamed, lhs_on, rhs_on, - join_type, rel, filter_condition, context - ): - """ - Same method as above, but instead of adding temporary join columns to merge on, - we merge on columns already in the dataframe by renaming them. 
- """ - df_lhs_to_merge = df_lhs_renamed.rename(columns={ - df_lhs_renamed.columns[index]: f"common_{i}" - for i, index in enumerate(lhs_on) - }) - df_rhs_to_merge = df_rhs_renamed.rename(columns={ - df_rhs_renamed.columns[index]: f"common_{i}" - for i, index in enumerate(rhs_on) - }) - on_columns = [f"common_{i}" for i in range(len(lhs_on))] - - # SQL compatibility: when joining on columns that - # contain NULLs, pandas will actually happily - # keep those NULLs. That is however not compatible with - # SQL, so we get rid of them here - df_lhs_filter = reduce( - operator.and_, - [~df_lhs_to_merge.iloc[:, index].isna() for index in lhs_on], - ) - df_lhs_to_merge = df_lhs_to_merge[df_lhs_filter] - df_rhs_filter = reduce( - operator.and_, - [~df_rhs_to_merge.iloc[:, index].isna() for index in rhs_on], - ) - df_rhs_to_merge = df_rhs_to_merge[df_rhs_filter] - - df = dd.merge(df_lhs_to_merge, df_rhs_to_merge, on=on_columns, how=join_type) - - # 6. So the next step is to make sure - # we have the correct column order. - correct_column_order = list(df_lhs_renamed.columns) + list( - df_rhs_renamed.columns - ) - # We update the columns for the rhs to point to the resulting join columns - if lhs_on: - for i, on_column in enumerate(on_columns): - correct_column_order[lhs_on[i]] = on_column - correct_column_order[len(df_lhs_renamed.columns) + rhs_on[i]] = on_column - cc = ColumnContainer(df.columns).limit_to(correct_column_order) - - # and to rename them like the rel specifies - row_type = rel.getRowType() - field_specifications = [str(f) for f in row_type.getFieldNames()] - l_lhs = len(df_lhs_renamed.columns) - cc = cc.rename( - { - from_col: to_col - for from_col, to_col in zip(cc.columns[:l_lhs], field_specifications[:l_lhs]) - } - ) - cc = cc.rename( - { - from_col: to_col - for from_col, to_col in zip(cc.columns[l_lhs:], field_specifications[l_lhs:]) - } - ) - cc = self.fix_column_to_row_type(cc, rel.getRowType()) - dc = DataContainer(df, cc) - - # 7. 
Last but not least we apply any filters by and-chaining together the filters - if filter_condition: - # This line is a bit of code duplication with RexCallPlugin - but I guess it is worth to keep it separate - filter_condition = reduce( - operator.and_, - [ - RexConverter.convert(rex, dc, context=context) - for rex in filter_condition - ], - ) - logger.debug(f"Additionally applying filter {filter_condition}") - df = filter_or_scalar(df, filter_condition) - dc = DataContainer(df, cc) - - dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) - return dc - def _split_join_condition( self, join_condition: "org.apache.calcite.rex.RexCall" ) -> Tuple[List[str], List[str], List["org.apache.calcite.rex.RexCall"]]: @@ -311,9 +242,11 @@ def _extract_lhs_rhs(self, rex): for operand in operands: if isinstance(operand, org.apache.calcite.rex.RexInputRef): indices.append(operand.getIndex()) - elif ( - isinstance(operand, org.apache.calcite.rex.RexCall) and - isinstance(operand.getOperator(), org.apache.calcite.sql.fun.SqlCastFunction) + elif isinstance( + operand, org.apache.calcite.rex.RexCall + ) and isinstance( + operand.getOperator(), + org.apache.calcite.sql.fun.SqlCastFunction, ): indices.append(operand.operands[0].getIndex()) elif isinstance(operand, org.apache.calcite.rex.RexLiteral): @@ -331,5 +264,3 @@ def _extract_lhs_rhs(self, rex): lhs_index, rhs_index = rhs_index, lhs_index return lhs_index, rhs_index - - diff --git a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java index 13f9a3542..d15236606 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java +++ b/planner/src/main/java/com/dask/sql/application/DaskRuleSets.java @@ -61,7 +61,8 @@ private DaskRuleSets() { /** * RuleSet about project Dont' add CoreRules.PROJECT_REMOVE */ - static final RuleSet PROJECT_RULES = RuleSets.ofList(CoreRules.PROJECT_MERGE, CoreRules.AGGREGATE_PROJECT_MERGE, + static final RuleSet PROJECT_RULES = RuleSets.ofList( + CoreRules.AGGREGATE_PROJECT_MERGE, // push a projection past a filter CoreRules.PROJECT_FILTER_TRANSPOSE, // merge projections @@ -69,10 +70,7 @@ private DaskRuleSets() { // removes constant keys from an Agg CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, // push project through a Union - CoreRules.PROJECT_SET_OP_TRANSPOSE, - CoreRules.JOIN_PROJECT_LEFT_TRANSPOSE, - CoreRules.JOIN_PROJECT_RIGHT_TRANSPOSE, - CoreRules.JOIN_PROJECT_BOTH_TRANSPOSE + CoreRules.PROJECT_SET_OP_TRANSPOSE ); /** From d01dbd5092ed75c6ba9bc266719e5643256a7ddf Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Mon, 21 Jun 2021 14:45:58 +0200 Subject: [PATCH 30/33] Add missing dependency for distributed --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 86e0e424f..d5a0cf84c 100755 --- a/setup.py +++ b/setup.py @@ -75,6 +75,7 @@ def run(self): setup_requires=["setuptools_scm"] + sphinx_requirements, install_requires=[ "dask[dataframe]>=2.19.0", + "distributed", "pandas<1.2.0,>=1.0.0", # pandas 1.2.0 introduced float NaN dtype, # which is currently not working with dask, # so the test is failing, see https://github.com/dask/dask/issues/7156 From 31e475d10c80381cb3528650b189532aca8c2046 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Tue, 7 Sep 2021 10:21:51 +0200 Subject: [PATCH 31/33] Update Calcite version to 1.27 --- planner/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/planner/pom.xml b/planner/pom.xml index a3ce2c1f5..c10fdca06 100755 --- a/planner/pom.xml 
+++ b/planner/pom.xml @@ -17,7 +17,7 @@ 1.7.29 ${java.version} ${java.version} - 1.26.0 + 1.27.0 From 64b8a8c947b84775bda0988f4974ddd8eeae6969 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Mon, 4 Oct 2021 10:51:09 +0200 Subject: [PATCH 32/33] Update pandas version --- setup.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/setup.py b/setup.py index d5a0cf84c..e816cc3a6 100755 --- a/setup.py +++ b/setup.py @@ -76,10 +76,7 @@ def run(self): install_requires=[ "dask[dataframe]>=2.19.0", "distributed", - "pandas<1.2.0,>=1.0.0", # pandas 1.2.0 introduced float NaN dtype, - # which is currently not working with dask, - # so the test is failing, see https://github.com/dask/dask/issues/7156 - # below 1.0, there were no nullable ext. types + "pandas>=1.0.0", # below 1.0, there were no nullable ext. types "jpype1>=1.0.2", "fastapi>=0.61.1", "uvicorn>=0.11.3", From e78363c7ed3daac055741970d13a2a57a6ea9352 Mon Sep 17 00:00:00 2001 From: Jonas Renault Date: Wed, 27 Oct 2021 11:02:31 +0200 Subject: [PATCH 33/33] Add a persist call when applying multiple aggregates so that they are evaluated in the for loop, and not at the end with only the last function --- dask_sql/physical/rel/logical/aggregate.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index 9129f45f8..55256d5dd 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -436,7 +436,9 @@ def _perform_aggregation( # ... fix the column names to a single level ... agg_result.columns = agg_result.columns.get_level_values(-1) - # apply aggregations with .apply + # apply aggregations with .apply. The .persist() calls on agg_result are + # important otherwise when the dataframe gets computed, the lambda function + # to apply will be the last one of the list. for output_col, ( input_col, aggregation_f, @@ -451,8 +453,8 @@ def _perform_aggregation( meta=(output_col, return_type), ) if agg_result is None: - agg_result = new_col.rename(output_col).to_frame() + agg_result = new_col.rename(output_col).to_frame().persist() else: - agg_result = agg_result.assign(**{output_col: new_col}) + agg_result = agg_result.assign(**{output_col: new_col}).persist() return agg_result
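
For illustration, a minimal standalone sketch (plain Python, no dask, and not part of the patch above) of the late-binding behaviour that the added .persist() calls work around: each lambda created in the loop closes over the loop variable itself, so if evaluation is deferred until after the loop, every deferred call ends up using the last aggregation function.

    funcs = {"sum": sum, "max": max}

    # Deferred calls created in a loop: each lambda looks up `f` lazily,
    # so by the time they run, `f` has been rebound to the last function.
    deferred = []
    for _, f in funcs.items():
        deferred.append(lambda xs: f(xs))
    print([d([1, 2, 3]) for d in deferred])  # [3, 3] -- both use max

    # Binding the current value eagerly (default argument), or forcing
    # evaluation inside the loop as the patch does with .persist(),
    # gives the intended per-function results.
    deferred = [lambda xs, f=f: f(xs) for _, f in funcs.items()]
    print([d([1, 2, 3]) for d in deferred])  # [6, 3]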