# Diff: src/databricks/labs/lakebridge/reconcile/design/expressions.py
# New file — 120 additions, 0 deletions. The added file contents follow.
import dataclasses
import typing as t
from abc import ABC, abstractmethod
from functools import reduce

import sqlglot.expressions as e
from duckdb.duckdb import alias
from sqlglot.dialects import Dialect as SqlglotDialect

DialectType = t.Union[str, SqlglotDialect, t.Type[SqlglotDialect], None]

@dataclasses.dataclass(frozen=True)
class ExpressionTransformation:
    """A deferred sqlglot expression constructor plus its keyword arguments.

    ``func`` is expected to be a sqlglot expression class (e.g. ``e.Coalesce``);
    ``ExpressionBuilder._apply_transformations`` later invokes it as
    ``func(this=<inner expression>, **args)``.
    """

    func: t.Callable  # a sqlglot expression class, called as func(this=..., **args)
    args: dict[str, t.Any]


class AnyExpression(ABC):
    """Minimal interface for anything that can render itself to a SQL fragment."""

    @abstractmethod
    def build(self) -> str:
        """Render this expression as a SQL string."""


class ExpressionBuilder(AnyExpression):
    """Fluent builder for a single column expression.

    Accumulates an optional table qualifier, an optional alias, and a list of
    transformations (COALESCE, TRIM, ...) and renders them to a SQL fragment
    in the configured dialect via :meth:`build`.
    """

    def __init__(self, column_name: str, dialect: str, table_name: str | None = None):
        self._column_name = column_name
        self._alias: str | None = None
        self._table_name = table_name
        self._dialect = dialect
        self._transformations: list[ExpressionTransformation] = []

    def build(self) -> str:
        """Render the column, its transformations, and its alias as a SQL fragment."""
        if self._table_name:
            column = e.Column(this=self._column_name, table=self._table_name)
        else:
            column = e.Column(this=self._column_name, quoted=False)
        aliased = e.Alias(this=column, alias=self._alias) if self._alias else column
        transformed = self._apply_transformations(aliased)
        # Render via a full SELECT, then strip the keyword so only the
        # column expression (with transformations) remains.
        select_stmt = e.select(transformed).sql(dialect=self._dialect)
        return select_stmt.removeprefix("SELECT ")

    def _apply_transformations(self, column: e.Expression) -> e.Expression:
        """Wrap ``column`` with each registered transformation, in registration order."""
        exp = column
        for transformation in self._transformations:
            # TODO: add error handling for constructors that reject these kwargs.
            exp = transformation.func(this=exp.copy(), **transformation.args)
        return exp

    def column_name(self, name: str) -> "ExpressionBuilder":
        self._column_name = name
        return self

    def alias(self, alias: str | None) -> "ExpressionBuilder":
        self._alias = alias
        return self

    def table_name(self, name: str) -> "ExpressionBuilder":
        # BUG FIX: previously assigned to ``_column_name``, clobbering the
        # column and never setting the table qualifier.
        self._table_name = name
        return self

    def transform(self, func: t.Callable, **kwargs) -> "ExpressionBuilder":
        """Queue ``func(this=<expr>, **kwargs)`` to be applied at build time."""
        self._transformations.append(ExpressionTransformation(func, kwargs))
        return self

    def concat(self, other: "ExpressionBuilder"):
        # TODO(design): intentionally a no-op placeholder for now.
        pass

class HashExpressionsBuilder(AnyExpression):

def __init__(self, dialect: str, columns: list[ExpressionBuilder]):
self._dialect = dialect
self._alias = None
self._expressions: list[ExpressionBuilder] = columns

def build(self) -> str:
columns_to_hash = [col.alias(None).build() for col in self._expressions]
columns_to_hash_expr = [e.Column(this=col) for col in columns_to_hash]
concat_expr = e.Concat(expressions=columns_to_hash_expr)
if self._dialect == "oracle":
concat_expr = reduce(lambda x, y: e.DPipe(this=x, expression=y), concat_expr.expressions)
match self._dialect: # Implement for the rest
case "tsql": return (
"CONVERT(VARCHAR(256), HASHBYTES("
"'SHA2_256', CONVERT(VARCHAR(256),{})), 2)"
f" AS {self._alias}" if self._alias else ""
.format(concat_expr.sql(dialect=self._dialect))
)
case _:
sha = e.SHA2(this=concat_expr, length=e.Literal(this=256, is_string=False))
if self._alias: sha = e.Alias(this=sha, alias=self._alias)
return sha.sql(dialect=self._dialect)

def alias(self, alias: str | None):
self._alias = alias
return self


class QueryBuilder:
    """Assembles a full SELECT over a placeholder table from built expressions."""

    def __init__(self, dialect: str, columns: list[AnyExpression]):
        self._dialect = dialect
        self._expressions: list[AnyExpression] = columns

    def build(self) -> str:
        """Render ``SELECT <exprs> FROM :table`` in the configured dialect."""
        projection = (expression.build() for expression in self._expressions)
        return e.select(*projection).from_(":table").sql(dialect=self._dialect)


def coalesce(column: ExpressionBuilder, default=0, is_string=False) -> ExpressionBuilder:
    """Wrap the column in ``COALESCE(<col>, <default>)``."""
    fallback = e.Literal(this=default, is_string=is_string)
    return column.transform(e.Coalesce, expressions=[fallback])

def trim(column: ExpressionBuilder) -> ExpressionBuilder:
    """Wrap the column in ``TRIM(<col>)``."""
    return column.transform(e.Trim)

def unix_time(column: ExpressionBuilder) -> ExpressionBuilder:
    """Convert the column to unix epoch time (placeholder transformation)."""
    return column.transform(e.TimeStrToUnix)  # placeholder — confirm final transform

# Diff: src/databricks/labs/lakebridge/reconcile/design/normalizers.py
# New file — 238 additions, 0 deletions. The added file contents follow.
import dataclasses
from abc import ABC, abstractmethod

import expressions as e
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from utypes import ExternalType, UType, ColumnTypeName


@dataclasses.dataclass(frozen=True)
class ExternalColumnDefinition:
    """Describes a column as declared in an external (source or target) system."""

    column_name: str
    data_type: ExternalType
    encoding: str = "utf-8"  # NOTE(review): text encoding of stored values — confirm with callers

@dataclasses.dataclass(frozen=True)
class DatetimeColumnDefinition(ExternalColumnDefinition):
    """Column definition for datetime columns; adds the source timezone."""

    timezone: str = "UTC"  # NOTE(review): assumed to be a zone name like "UTC" — confirm


class AbstractNormalizer(ABC):
    """Interface for column normalizers registered in :class:`NormalizersRegistry`.

    The registry keys normalizers by (family, key); ``normalize`` rewrites an
    expression builder for a given dialect and column definition.
    """

    @classmethod
    @abstractmethod
    def registry_key_family(cls) -> str:
        """Top-level registry bucket (e.g. "Universal" or "ForType")."""
        pass

    @classmethod
    @abstractmethod
    def registry_key(cls) -> str:
        """Key within the family bucket; must be unique per family."""
        pass

    @abstractmethod
    def normalize(self, column: e.ExpressionBuilder, dialect: e.DialectType, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        """Return a builder with this normalizer's transformation applied."""
        pass

class UniversalNormalizer(AbstractNormalizer, ABC):
    """A normalizer applied to every column regardless of its data type."""

    @classmethod
    def registry_key_family(cls) -> str:
        return "Universal"

class HandleNullsAndTrimNormalizer(UniversalNormalizer):
    """Trims whitespace and replaces NULLs with a sentinel so comparisons are stable."""

    @classmethod
    def registry_key(cls) -> str:
        return cls.__name__

    def normalize(self, column: e.ExpressionBuilder, dialect: e.DialectType, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        trimmed = e.trim(column)
        return e.coalesce(trimmed, "__null_recon__", is_string=True)

class QuoteIdentifierNormalizer(UniversalNormalizer):
@classmethod
def registry_key(cls) -> str:
return cls.__name__

def normalize(self, column: e.ExpressionBuilder, dialect: e.DialectType, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
match dialect:
case "oracle": return self._quote_oracle(column, column_def)
case "databricks": return self._quote_databricks(column, column_def)
case "snowflake": return self._quote_snowflake(column, column_def)
case _: return column # instead of error, return as is

def _quote_oracle(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
normalized = DialectUtils.normalize_identifier(
column_def.column_name,
source_start_delimiter='"',
source_end_delimiter='"',
).source_normalized
return column.column_name(normalized)

def _quote_databricks(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
normalized = DialectUtils.ansi_quote_identifier(column_def.column_name)
return column.column_name(normalized)

def _quote_snowflake(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
normalized = DialectUtils.normalize_identifier(
column_def.column_name,
source_start_delimiter='"',
source_end_delimiter='"',
).source_normalized
return column.column_name(normalized)


class AbstractTypeNormalizer(AbstractNormalizer):
    """Base for normalizers keyed by a unified column type; dispatches per dialect."""

    @classmethod
    def registry_key_family(cls) -> str:
        return "ForType"

    @classmethod
    @abstractmethod
    def utype(cls) -> UType:
        """The unified type this normalizer handles."""

    def normalize(self, column: e.ExpressionBuilder, dialect: str, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        """Route to the dialect-specific hook; unknown dialects pass through unchanged."""
        handlers = {
            "oracle": self._normalize_oracle,
            "databricks": self._normalize_databricks,
            "snowflake": self._normalize_snowflake,
        }
        handler = handlers.get(dialect)
        if handler is None:
            return column  # instead of raising, return the column as-is
        return handler(column, column_def)

    @abstractmethod
    def _normalize_oracle(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        """Oracle-specific normalization hook."""

    @abstractmethod
    def _normalize_databricks(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        """Databricks-specific normalization hook."""

    @abstractmethod
    def _normalize_snowflake(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        """Snowflake-specific normalization hook."""

class UDatetimeTypeNormalizer(AbstractTypeNormalizer):
    """Normalizes datetime columns toward unix time across dialects."""

    @classmethod
    def registry_key(cls) -> str:
        return cls.utype().name.name

    @classmethod
    def utype(cls) -> UType:
        return UType(ColumnTypeName.DATETIME)

    def _normalize_oracle(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        # Oracle side intentionally left untouched for now.
        return column

    def _normalize_databricks(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        return e.unix_time(column)

    def _normalize_snowflake(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        # Snowflake side intentionally left untouched for now.
        return column

class UStringTypeNormalizer(AbstractTypeNormalizer):
    """Normalizes string-like columns by trimming and null-replacing, identically in every dialect."""

    _delegate = HandleNullsAndTrimNormalizer()

    @classmethod
    def registry_key(cls) -> str:
        return cls.utype().name.name

    @classmethod
    def utype(cls) -> UType:
        return UType(ColumnTypeName.VARCHAR)

    def _apply_delegate(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        # The delegate is dialect-independent, so an empty dialect string is passed.
        return self._delegate.normalize(column, "", column_def)

    def _normalize_oracle(self, column: e.ExpressionBuilder,
                          column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        return self._apply_delegate(column, column_def)

    def _normalize_databricks(self, column: e.ExpressionBuilder,
                              column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        return self._apply_delegate(column, column_def)

    def _normalize_snowflake(self, column: e.ExpressionBuilder,
                             column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        return self._apply_delegate(column, column_def)


class NormalizersRegistry:
    """Two-level registry: family key ("Universal"/"ForType") -> registry key -> normalizer."""

    # NOTE(review): class-level dict is shared by every registry instance — confirm intended.
    _registry: dict[str, dict[str, AbstractNormalizer]] = {}

    def register_normalizer(self, normalizer: AbstractNormalizer) -> None:
        """Register ``normalizer`` under its (family, key) pair.

        Raises:
            ValueError: if a normalizer is already registered for the same pair.
        """
        family_key = normalizer.registry_key_family()
        key = normalizer.registry_key()
        # IDIOM FIX: setdefault replaces the get/if-not-family dance, and a
        # membership test replaces the fragile truthiness check on the value.
        family = self._registry.setdefault(family_key, {})
        if key in family:
            raise ValueError(f"Normalizer already registered for utype: {family_key},{key}")
        family[key] = normalizer

    def get_type_normalizer(self, name: ColumnTypeName) -> AbstractTypeNormalizer | None:
        """Look up the type-specific normalizer for ``name``, if any."""
        return self._registry.get(AbstractTypeNormalizer.registry_key_family(), {}).get(name.name)

    def get_universal_normalizers(self):
        """All normalizers that apply to every column."""
        return self._registry.get(UniversalNormalizer.registry_key_family(), {}).values()

class DialectNormalizer(ABC):
    """Applies the universal normalizers, then a type-specific one, for a single dialect."""

    DbTypeNormalizerType = dict[ColumnTypeName, ColumnTypeName]
    # NOTE(review): a name-to-name map loses extra type info (scale, precision) — revisit.

    # concrete subclasses set the sqlglot dialect identifier
    dialect: e.DialectType

    def __init__(self, registry: NormalizersRegistry):
        self._registry = registry

    @classmethod
    def type_normalizers(cls) -> DbTypeNormalizerType:
        """Map external type names to the unified type handled by a registered normalizer."""
        mapping: dict[ColumnTypeName, ColumnTypeName] = {
            ColumnTypeName("DATE"): UDatetimeTypeNormalizer.utype().name,
        }
        for string_like in ("NCHAR", "CHAR", "VARCHAR", "NVARCHAR", "VARCHAR2"):
            mapping[ColumnTypeName(string_like)] = UStringTypeNormalizer.utype().name
        return mapping

    def normalize(self, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
        """Build the fully normalized expression for ``column_def`` in this dialect."""
        builder = e.ExpressionBuilder(column_def.column_name, self.dialect)
        for universal in self._registry.get_universal_normalizers():
            builder = universal.normalize(builder, self.dialect, column_def)
        utype = self.type_normalizers().get(column_def.data_type.name)
        if not utype:
            return builder
        type_normalizer = self._registry.get_type_normalizer(utype)
        if not type_normalizer:
            return builder
        return type_normalizer.normalize(builder, self.dialect, column_def)


class OracleNormalizer(DialectNormalizer):
    """Dialect normalizer bound to Oracle."""

    dialect = "oracle"


class SnowflakeNormalizer(DialectNormalizer):
    """Dialect normalizer bound to Snowflake."""

    dialect = "snowflake"

if __name__ == "__main__":
    # Smoke test: register the normalizers and check that Oracle and Snowflake
    # render the same normalized expression for a string column.
    registry = NormalizersRegistry()
    registry.register_normalizer(UDatetimeTypeNormalizer())
    registry.register_normalizer(UStringTypeNormalizer())
    # registry.register_normalizer(HandleNullsAndTrimNormalizer())
    registry.register_normalizer(QuoteIdentifierNormalizer())
    oracle = OracleNormalizer(registry)
    snow = SnowflakeNormalizer(registry)

    column = ExternalColumnDefinition("student_id", ExternalType(ColumnTypeName["NCHAR"]))

    # BUG FIX: DialectNormalizer.normalize takes only the column definition;
    # the extra ``registry`` argument raised TypeError.
    oracle_column = oracle.normalize(column).build()
    assert oracle_column == "COALESCE(TRIM(\"student_id\"), '__null_recon__')"
    snow_column = snow.normalize(column).build()
    assert snow_column == "COALESCE(TRIM(\"student_id\"), '__null_recon__')"


"""
1. source system
2. target system
3. datatype
4. encoding
5. query
"""


# (end of diff)