Splitting implementation
aarati-K committed Mar 15, 2024
1 parent aef01e0 commit 3bbb931
Showing 1 changed file with 226 additions and 25 deletions.
251 changes: 226 additions & 25 deletions ibis/backends/duckdb/__init__.py
@@ -23,8 +23,19 @@
import toolz
from packaging.version import parse as vparse

from sqlglot import parse_one, exp
from sqlglot.optimizer import optimize
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
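# NOTE: the individual optimizer rules above are presumably imported so a
# custom ``rules=`` tuple can be passed to optimize(); as written, optimize()
# is called with its default rule set.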

import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
@@ -40,8 +51,6 @@
import pandas as pd
import torch

-import ibis.expr.operations as ops
-

def normalize_filenames(source_list):
    # Promote to list
@@ -77,6 +86,201 @@ class Backend(BaseAlchemyBackend, CanCreateSchema):
    compiler = DuckDBSQLCompiler
    supports_create_or_replace = True

    override_schemas = {}
    # Sample schema entry for a view named "accidents" that has been split
    # into one fact table and thirteen dimension tables:
    # override_schemas = {"accidents": {
    #     "fact": "accidents_fact",
    #     # {table_name: joining_key}
    #     "dimension_tables": {
    #         "accidents_dim0": "p0",
    #         "accidents_dim1": "p1",
    #         "accidents_dim2": "p2",
    #         "accidents_dim3": "p3",
    #         "accidents_dim4": "p4",
    #         "accidents_dim5": "p5",
    #         "accidents_dim6": "p6",
    #         "accidents_dim7": "p7",
    #         "accidents_dim8": "p8",
    #         "accidents_dim9": "p9",
    #         "accidents_dim10": "p10",
    #         "accidents_dim11": "p11",
    #         "accidents_dim12": "p12",
    #     },
    #     "col_to_table_map": {
    #         'ID': 'accidents_fact',
    #         'Severity': 'accidents_dim1',
    #         'Start_Time': 'accidents_fact',
    #         'End_Time': 'accidents_fact',
    #         'Start_Lat': 'accidents_fact',
    #         'Start_Lng': 'accidents_fact',
    #         'End_Lat': 'accidents_fact',
    #         'End_Lng': 'accidents_fact',
    #         'Distance(mi)': 'accidents_dim8',
    #         'Description': 'accidents_dim12',
    #         'Number': 'accidents_fact',
    #         'Street': 'accidents_dim9',
    #         'Side': 'accidents_dim0',
    #         'City': 'accidents_dim7',
    #         'County': 'accidents_dim6',
    #         'State': 'accidents_dim1',
    #         'Zipcode': 'accidents_dim10',
    #         'Country': 'accidents_dim0',
    #         'Timezone': 'accidents_dim1',
    #         'Airport_Code': 'accidents_fact',
    #         'Weather_Timestamp': 'accidents_dim11',
    #         'Temperature(F)': 'accidents_dim4',
    #         'Wind_Chill(F)': 'accidents_dim4',
    #         'Humidity(%)': 'accidents_dim2',
    #         'Pressure(in)': 'accidents_dim5',
    #         'Visibility(mi)': 'accidents_dim2',
    #         'Wind_Direction': 'accidents_dim1',
    #         'Wind_Speed(mph)': 'accidents_dim3',
    #         'Precipitation(in)': 'accidents_dim3',
    #         'Weather_Condition': 'accidents_dim2',
    #         'Amenity': 'accidents_dim0',
    #         'Bump': 'accidents_dim0',
    #         'Crossing': 'accidents_dim0',
    #         'Give_Way': 'accidents_dim0',
    #         'Junction': 'accidents_dim0',
    #         'No_Exit': 'accidents_dim0',
    #         'Railway': 'accidents_dim0',
    #         'Roundabout': 'accidents_dim0',
    #         'Station': 'accidents_dim0',
    #         'Stop': 'accidents_dim0',
    #         'Traffic_Calming': 'accidents_dim0',
    #         'Traffic_Signal': 'accidents_dim0',
    #         'Turning_Loop': 'accidents_dim0',
    #         'Sunrise_Sunset': 'accidents_dim0',
    #         'Civil_Twilight': 'accidents_dim0',
    #         'Nautical_Twilight': 'accidents_dim0',
    #         'Astronomical_Twilight': 'accidents_dim0',
    #     },
    # }}

    def register_schema(self, schema):
        """Register fact/dimension split metadata for one or more views."""
        self.override_schemas.update(schema)
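
    # Usage sketch: register the sample schema above, then query the view as
    # usual (names are illustrative; the split tables must already exist in
    # the attached database):
    #
    #   con = ibis.duckdb.connect("accidents.db")
    #   con.register_schema({"accidents": {...}})  # dict shaped as documented above
    #   con.table("accidents")[["Severity"]].limit(5).execute()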

    def rewrite_sql(self, sql: str) -> str:
        """Rewrite queries over a split view to read from its fact table,
        joining in only the dimension tables whose columns are referenced."""
        expression_tree = optimize(sql)
        table_names = set()
        column_names = set()

        # Transformer over the expression tree that records the table and
        # column names referenced by the query. Note that a query might not
        # reference any columns at all, e.g.: SELECT * FROM accidents LIMIT 5
        def get_table_and_column_names(node):
            if isinstance(node, exp.From) and node.name:
                table_names.add(node.name)
            if isinstance(node, exp.Column):
                column_names.add(node.name)
            return node

        expression_tree = expression_tree.transform(get_table_and_column_names)

        # Without any column references there is nothing to map onto dimension
        # tables, so leave the query untouched. (Whether this is always the
        # right behavior for SELECT * queries is an open question.)
        if not column_names:
            return sql

        # Check override_schemas to see whether any referenced table is a
        # split view. Multi-table queries are not handled yet.
        for table_name in table_names:
            try:
                schema = self.override_schemas[table_name]
            except KeyError:
                continue
            if not schema['dimension_tables']:
                continue

            # Collect the dimension tables that need to be joined
            dim_to_join = set()
            for col in column_names:
                try:
                    dim_name = schema['col_to_table_map'][col]
                except KeyError:
                    continue
                if dim_name == schema['fact']:
                    continue
                dim_to_join.add(dim_name)

            # Rewrite the FROM clause
            if not dim_to_join:
                # All referenced columns live in the fact table
                rewrite_string = "FROM " + schema['fact']
            else:
                # Join the fact table with each required dimension table
                join_clauses = []
                for dim in dim_to_join:
                    joining_col = schema['dimension_tables'][dim]
                    join_clauses.append(
                        f"{schema['fact']}.{joining_col} = {dim}.{joining_col}"
                    )
                join_clause = ' AND '.join(join_clauses)
                dim_to_join.add(schema['fact'])
                table_clause = ', '.join(dim_to_join)
                rewrite_string = f"FROM (SELECT * FROM {table_clause} WHERE {join_clause})"

            def rewrite_from(node):
                # Swap the original FROM target for the fact/dimension join,
                # re-aliased to the original name so that qualified column
                # references still resolve.
                if isinstance(node, exp.From) and node.name == table_name:
                    updated = rewrite_string
                    if node.alias_or_name:
                        updated += " AS " + node.alias_or_name
                    return parse_one(updated, into=exp.From)
                return node

            expression_tree = expression_tree.transform(rewrite_from)

        return expression_tree.sql()
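
    # Illustrative end-to-end rewrite for the sample "accidents" schema above
    # (modulo sqlglot's qualification and identifier quoting):
    #
    #   in:  SELECT Severity, City FROM accidents LIMIT 5
    #   out: SELECT Severity, City
    #        FROM (SELECT * FROM accidents_fact, accidents_dim1, accidents_dim7
    #              WHERE accidents_fact.p1 = accidents_dim1.p1
    #                AND accidents_fact.p7 = accidents_dim7.p7) AS accidents
    #        LIMIT 5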

    def execute(
        self,
        expr: ir.Expr,
        params: Mapping[ir.Scalar, Any] | None = None,
        limit: str = "default",
        **kwargs: Any,
    ):
"""Compile and execute an Ibis expression.
Compile and execute Ibis expression using this backend client
interface, returning results in-memory in the appropriate object type
Parameters
----------
expr
Ibis expression
limit
For expressions yielding result sets; retrieve at most this number
of values/rows. Overrides any limit already set on the expression.
params
Named unbound parameters
kwargs
Backend specific arguments. For example, the clickhouse backend
uses this to receive `external_tables` as a dictionary of pandas
DataFrames.
Returns
-------
DataFrame | Series | Scalar
* `Table`: pandas.DataFrame
* `Column`: pandas.Series
* `Scalar`: Python scalar value
"""
        # TODO Reconsider having `kwargs` here. It's needed to support
        # `external_tables` in clickhouse, but it would be better to deprecate
        # that feature than to keep all this magic.
        self._run_pre_execute_hooks(expr)

        # we don't want to pass `timecontext` to `raw_sql`
        kwargs.pop("timecontext", None)
        # query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
        # sql = query_ast.compile()
        sql = self._to_sql(expr, limit=limit, params=params)
        sql = self.rewrite_sql(sql)
        self._log(sql)

        schema = expr.as_table().schema()
        with self._safe_raw_sql(sql, **kwargs) as cursor:
            result = self.fetch_from_cursor(cursor, schema)

        return expr.__pandas_result__(result)
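
    # With a schema registered via register_schema(), an ordinary call such as
    #
    #   con.table("accidents")[["Severity"]].limit(5).execute()
    #
    # compiles to SQL as usual and is then routed through rewrite_sql(), so
    # the split fact/dimension tables are joined transparently.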

    @property
    def current_database(self) -> str:
        return self._scalar_query(sa.select(sa.func.current_database()))
@@ -88,31 +292,30 @@ def list_databases(self, like: str | None = None) -> list[str]:
schema="information_schema",
)

query = sa.select(sa.distinct(s.c.catalog_name)).order_by(s.c.catalog_name)
query = sa.select(sa.distinct(s.c.catalog_name))
with self.begin() as con:
results = list(con.execute(query).scalars())
return self._filter_with_like(results, like=like)

-    @property
-    def current_schema(self) -> str:
-        return self._scalar_query(sa.select(sa.func.current_schema()))
-
-    def list_schemas(self, like: str | None = None) -> list[str]:
-        s = sa.table(
-            "schemata",
-            sa.column("catalog_name", sa.TEXT()),
-            sa.column("schema_name", sa.TEXT()),
-            schema="information_schema",
-        )
    def list_schemas(
        self, like: str | None = None, database: str | None = None
    ) -> list[str]:
        # override duckdb because all databases are always visible
        text = """\
SELECT schema_name
FROM information_schema.schemata
WHERE catalog_name = :database"""
        query = sa.text(text).bindparams(
            database=database if database is not None else self.current_database
        )

-        query = (
-            sa.select(s.c.schema_name)
-            .where(s.c.catalog_name == sa.func.current_database())
-            .order_by(s.c.schema_name)
-        )
        with self.begin() as con:
-            results = list(con.execute(query).scalars())
-        return self._filter_with_like(results, like=like)
            schemas = list(con.execute(query).scalars())
        return self._filter_with_like(schemas, like=like)

    @property
    def current_schema(self) -> str:
        return self._scalar_query(sa.select(sa.func.current_schema()))

    @staticmethod
    def _convert_kwargs(kwargs: MutableMapping) -> None:
@@ -711,11 +914,9 @@ def attach_sqlite(
con.execute(sa.text(f"CALL sqlite_attach('{path}', overwrite={overwrite})"))

def _run_pre_execute_hooks(self, expr: ir.Expr) -> None:
from ibis.expr.analysis import find_physical_tables

# Warn for any tables depending on RecordBatchReaders that have already
# started being consumed.
for t in find_physical_tables(expr.op()):
for t in expr.op().find(ops.PhysicalTable):
started = self._record_batch_readers_consumed.get(t.name)
if started is True:
warnings.warn(
@@ -1029,8 +1230,8 @@ def _register_udfs(self, expr: ir.Expr) -> None:
    def _compile_udf(self, udf_node: ops.ScalarUDF) -> None:
        func = udf_node.__func__
        name = func.__name__
-        input_types = [DuckDBType.to_string(arg.output_dtype) for arg in udf_node.args]
-        output_type = DuckDBType.to_string(udf_node.output_dtype)
        input_types = [DuckDBType.to_string(arg.dtype) for arg in udf_node.args]
        output_type = DuckDBType.to_string(udf_node.dtype)

        def register_udf(con):
            return con.connection.create_function(
