Implement chain group_by (#482)
* Implement chain group_by
* Use 'sql_to_python' for GenericFunction type conversion
dreadatour authored Oct 16, 2024
1 parent 6da688f commit c6ca542
Showing 19 changed files with 750 additions and 142 deletions.
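
In short, this commit adds a chainable `group_by` to `DataChain`, backed by a new `datachain.lib.func` package of aggregate helpers. A minimal usage sketch based on the docstrings below (the dataset name and signal names are illustrative, not part of this commit):

```py
from datachain import DataChain, func

chain = (
    DataChain.from_dataset("my_dataset")   # hypothetical dataset name
    .group_by(
        cnt=func.count(),                  # rows per group
        total=func.sum("file.size"),       # aggregate over an assumed numeric signal
        partition_by=("file_source", "file_ext"),
    )
)
```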
2 changes: 2 additions & 0 deletions src/datachain/__init__.py
@@ -1,3 +1,4 @@
from datachain.lib import func
from datachain.lib.data_model import DataModel, DataType, is_chain_type
from datachain.lib.dc import C, Column, DataChain, Sys
from datachain.lib.file import (
@@ -34,6 +35,7 @@
"Sys",
"TarVFile",
"TextFile",
"func",
"is_chain_type",
"metrics",
"param",
8 changes: 8 additions & 0 deletions src/datachain/data_storage/sqlite.py
@@ -763,6 +763,14 @@ def copy_table(
        query: Select,
        progress_cb: Optional[Callable[[int], None]] = None,
    ) -> None:
        if len(query._group_by_clause) > 0:
            select_q = query.with_only_columns(
                *[c for c in query.selected_columns if c.name != "sys__id"]
            )
            q = table.insert().from_select(list(select_q.selected_columns), select_q)
            self.db.execute(q)
            return

        if "sys__id" in query.selected_columns:
            col_id = query.selected_columns.sys__id
        else:
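
The new branch relies on SQLAlchemy's insert-from-select; a standalone sketch of the same pattern (the tables and columns here are made up for illustration, not DataChain internals):

```py
import sqlalchemy as sa

metadata = sa.MetaData()
src = sa.Table("src", metadata, sa.Column("ext", sa.String), sa.Column("size", sa.Integer))
dst = sa.Table("dst", metadata, sa.Column("ext", sa.String), sa.Column("cnt", sa.Integer))

# build an aggregated SELECT and copy its rows straight into the target table
select_q = sa.select(src.c.ext, sa.func.count().label("cnt")).group_by(src.c.ext)
insert_q = dst.insert().from_select(list(select_q.selected_columns), select_q)
```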
20 changes: 8 additions & 12 deletions src/datachain/lib/convert/sql_to_python.py
@@ -4,15 +4,11 @@
from sqlalchemy import ColumnElement


def sql_to_python(args_map: dict[str, ColumnElement]) -> dict[str, Any]:
    res = {}
    for name, sql_exp in args_map.items():
        try:
            type_ = sql_exp.type.python_type
            if type_ == Decimal:
                type_ = float
        except NotImplementedError:
            type_ = str
        res[name] = type_

    return res
def sql_to_python(sql_exp: ColumnElement) -> Any:
    try:
        type_ = sql_exp.type.python_type
        if type_ == Decimal:
            type_ = float
    except NotImplementedError:
        type_ = str
    return type_
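
A quick illustration of the new single-expression signature (a sketch using ad-hoc SQLAlchemy columns rather than DataChain signals):

```py
import sqlalchemy as sa

from datachain.lib.convert.sql_to_python import sql_to_python

assert sql_to_python(sa.column("size", sa.Integer())) is int
assert sql_to_python(sa.column("price", sa.Numeric())) is float  # Decimal maps to float
assert sql_to_python(sa.literal_column("x")) is str              # no python_type -> str fallback
```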
84 changes: 62 additions & 22 deletions src/datachain/lib/dc.py
@@ -29,6 +29,7 @@
from datachain.lib.dataset_info import DatasetInfo
from datachain.lib.file import ArrowRow, File, get_file_type
from datachain.lib.file import ExportPlacement as FileExportPlacement
from datachain.lib.func import Func
from datachain.lib.listing import (
    is_listing_dataset,
    is_listing_expired,
@@ -42,21 +43,12 @@
from datachain.lib.model_store import ModelStore
from datachain.lib.settings import Settings
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.udf import (
    Aggregator,
    BatchMapper,
    Generator,
    Mapper,
    UDFBase,
)
from datachain.lib.udf import Aggregator, BatchMapper, Generator, Mapper, UDFBase
from datachain.lib.udf_signature import UdfSignature
from datachain.lib.utils import DataChainParamsError
from datachain.lib.utils import DataChainColumnError, DataChainParamsError
from datachain.query import Session
from datachain.query.dataset import (
    DatasetQuery,
    PartitionByType,
)
from datachain.query.schema import DEFAULT_DELIMITER, Column
from datachain.query.dataset import DatasetQuery, PartitionByType
from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
from datachain.sql.functions import path as pathfunc
from datachain.telemetry import telemetry
from datachain.utils import batched_it, inside_notebook
@@ -149,11 +141,6 @@ def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str:
super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")


class DataChainColumnError(DataChainParamsError): # noqa: D101
def __init__(self, col_name, msg): # noqa: D107
super().__init__(f"Error for column {col_name}: {msg}")


OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]


@@ -982,10 +969,9 @@ def distinct(self, arg: str, *args: str) -> "Self":  # type: ignore[override]
        row is left in the result set.
        Example:
            ```py
            dc.distinct("file.parent", "file.name")
            )
            ```
            ```py
            dc.distinct("file.parent", "file.name")
            ```
        """
        return self._evolve(
            query=self._query.distinct(
@@ -1011,6 +997,60 @@ def select_except(self, *args: str) -> "Self":
            query=self._query.select(*columns), signal_schema=new_schema
        )

    def group_by(
        self,
        *,
        partition_by: Union[str, Sequence[str]],
        **kwargs: Func,
    ) -> "Self":
        """Group rows by specified set of signals and return new signals
        with aggregated values.

        Example:
            ```py
            chain = chain.group_by(
                cnt=func.count(),
                partition_by=("file_source", "file_ext"),
            )
            ```
        """
        if isinstance(partition_by, str):
            partition_by = [partition_by]
        if not partition_by:
            raise ValueError("At least one column should be provided for partition_by")

        if not kwargs:
            raise ValueError("At least one column should be provided for group_by")
        for col_name, func in kwargs.items():
            if not isinstance(func, Func):
                raise DataChainColumnError(
                    col_name,
                    f"Column {col_name} has type {type(func)} but expected Func object",
                )

        partition_by_columns: list[Column] = []
        signal_columns: list[Column] = []
        schema_fields: dict[str, DataType] = {}

        # validate partition_by columns and add them to the schema
        for col_name in partition_by:
            col_db_name = ColumnMeta.to_db_name(col_name)
            col_type = self.signals_schema.get_column_type(col_db_name)
            col = Column(col_db_name, python_to_sql(col_type))
            partition_by_columns.append(col)
            schema_fields[col_db_name] = col_type

        # validate signal columns and add them to the schema
        for col_name, func in kwargs.items():
            col = func.get_column(self.signals_schema, label=col_name)
            signal_columns.append(col)
            schema_fields[col_name] = func.get_result_type(self.signals_schema)

        return self._evolve(
            query=self._query.group_by(signal_columns, partition_by_columns),
            signal_schema=SignalSchema(schema_fields),
        )

    def mutate(self, **kwargs) -> "Self":
        """Create new signals based on existing signals.
14 changes: 14 additions & 0 deletions src/datachain/lib/func/__init__.py
@@ -0,0 +1,14 @@
from .aggregate import any_value, avg, collect, concat, count, max, min, sum
from .func import Func

__all__ = [
"Func",
"any_value",
"avg",
"collect",
"concat",
"count",
"max",
"min",
"sum",
]
42 changes: 42 additions & 0 deletions src/datachain/lib/func/aggregate.py
@@ -0,0 +1,42 @@
from typing import Optional

from sqlalchemy import func as sa_func

from datachain.sql import functions as dc_func

from .func import Func


def count(col: Optional[str] = None) -> Func:
    return Func(inner=sa_func.count, col=col, result_type=int)


def sum(col: str) -> Func:
    return Func(inner=sa_func.sum, col=col)


def avg(col: str) -> Func:
    return Func(inner=dc_func.aggregate.avg, col=col)


def min(col: str) -> Func:
    return Func(inner=sa_func.min, col=col)


def max(col: str) -> Func:
    return Func(inner=sa_func.max, col=col)


def any_value(col: str) -> Func:
    return Func(inner=dc_func.aggregate.any_value, col=col)


def collect(col: str) -> Func:
    return Func(inner=dc_func.aggregate.collect, col=col, is_array=True)


def concat(col: str, separator="") -> Func:
    def inner(arg):
        return dc_func.aggregate.group_concat(arg, separator)

    return Func(inner=inner, col=col, result_type=str)
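
A hedged note on these helpers: `count` is the only one whose column is optional, `collect` yields a list per group, and `concat` joins values into a single string. Illustrative use inside the new `group_by` (signal names are invented):

```py
from datachain import func

chain = chain.group_by(
    rows=func.count(),                               # no column needed
    paths=func.collect("file.path"),                 # -> list[str] per group
    joined=func.concat("file.path", separator=";"),  # -> one string per group
    partition_by="file_ext",
)
```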
64 changes: 64 additions & 0 deletions src/datachain/lib/func/func.py
@@ -0,0 +1,64 @@
from typing import TYPE_CHECKING, Callable, Optional

from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.utils import DataChainColumnError
from datachain.query.schema import Column, ColumnMeta

if TYPE_CHECKING:
    from datachain import DataType
    from datachain.lib.signal_schema import SignalSchema


class Func:
    def __init__(
        self,
        inner: Callable,
        col: Optional[str] = None,
        result_type: Optional["DataType"] = None,
        is_array: bool = False,
    ) -> None:
        self.inner = inner
        self.col = col
        self.result_type = result_type
        self.is_array = is_array

    @property
    def db_col(self) -> Optional[str]:
        return ColumnMeta.to_db_name(self.col) if self.col else None

    def db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]:
        if not self.db_col:
            return None
        col_type: type = signals_schema.get_column_type(self.db_col)
        return list[col_type] if self.is_array else col_type  # type: ignore[valid-type]

    def get_result_type(self, signals_schema: "SignalSchema") -> "DataType":
        col_type = self.db_col_type(signals_schema)

        if self.result_type:
            return self.result_type

        if col_type:
            return col_type

        raise DataChainColumnError(
            str(self.inner),
            "Column name is required to infer result type",
        )

    def get_column(
        self, signals_schema: "SignalSchema", label: Optional[str] = None
    ) -> Column:
        if self.col:
            col_type = self.get_result_type(signals_schema)
            col = Column(self.db_col, python_to_sql(col_type))
            func_col = self.inner(col)
        else:
            func_col = self.inner()

        if label:
            func_col = func_col.label(label)

        return func_col
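
Conceptually, `get_column` maps the signal name to its flattened DB name, wraps it in the aggregate, and labels the result. A rough hand-written equivalent of what `Func(inner=sa.func.sum, col="file.size").get_column(schema, label="total")` would produce (names are illustrative, assuming the default `__` delimiter for nested signals):

```py
import sqlalchemy as sa

# "file.size" flattened to "file__size", wrapped in SUM, and labeled
expr = sa.func.sum(sa.column("file__size", sa.Integer())).label("total")
```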
12 changes: 9 additions & 3 deletions src/datachain/lib/signal_schema.py
@@ -400,6 +400,12 @@ def _set_file_stream(
            if ModelStore.is_pydantic(finfo.annotation):
                SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)

    def get_column_type(self, col_name: str) -> DataType:
        for path, _type, has_subtree, _ in self.get_flat_tree():
            if not has_subtree and DEFAULT_DELIMITER.join(path) == col_name:
                return _type
        raise SignalResolvingError([col_name], "is not found")

    def db_signals(
        self, name: Optional[str] = None, as_columns=False
    ) -> Union[list[str], list[Column]]:
@@ -490,7 +496,7 @@ def mutate(self, args_map: dict) -> "SignalSchema":
                new_values[name] = args_map[name]
            else:
                # adding new signal
                new_values.update(sql_to_python({name: value}))
                new_values[name] = sql_to_python(value)

        return SignalSchema(new_values)

@@ -534,12 +540,12 @@ def _build_tree(
            for name, val in values.items()
        }

    def get_flat_tree(self) -> Iterator[tuple[list[str], type, bool, int]]:
    def get_flat_tree(self) -> Iterator[tuple[list[str], DataType, bool, int]]:
        yield from self._get_flat_tree(self.tree, [], 0)

    def _get_flat_tree(
        self, tree: dict, prefix: list[str], depth: int
    ) -> Iterator[tuple[list[str], type, bool, int]]:
    ) -> Iterator[tuple[list[str], DataType, bool, int]]:
        for name, (type_, substree) in tree.items():
            suffix = name.split(".")
            new_prefix = prefix + suffix
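
The new `get_column_type` walks the flattened schema tree to resolve a DB column name back to its Python type; a small sketch assuming a schema built from plain types:

```py
from datachain.lib.signal_schema import SignalSchema

schema = SignalSchema({"name": str, "size": int})
assert schema.get_column_type("size") is int
```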
5 changes: 5 additions & 0 deletions src/datachain/lib/utils.py
@@ -23,3 +23,8 @@ def __init__(self, message):
class DataChainParamsError(DataChainError):
    def __init__(self, message):
        super().__init__(message)


class DataChainColumnError(DataChainParamsError):
    def __init__(self, col_name, msg):
        super().__init__(f"Error for column {col_name}: {msg}")