Skip to content

Commit

Permalink
Auto-register data classes & misc changes
Browse files Browse the repository at this point in the history
  • Loading branch information
dmpetrov committed Jul 14, 2024
1 parent 64a2e87 commit cc718ae
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 30 deletions.
4 changes: 3 additions & 1 deletion src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import sqlalchemy

from datachain.lib.feature import FeatureType
from datachain.lib.feature import FeatureType, registry_list
from datachain.lib.feature_utils import features_to_tuples
from datachain.lib.file import File, IndexedFile, get_file
from datachain.lib.meta_formats import read_meta, read_schema
Expand Down Expand Up @@ -534,6 +534,8 @@ def _udf_to_obj(
name = self.name or ""

sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
registry_list(list(sign.output_schema.values.values()))

params_schema = self.signals_schema.slice(sign.params, self._setup)

return UDFBase._create(target_class, sign, params_schema)
Expand Down
6 changes: 6 additions & 0 deletions src/datachain/lib/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@
feature_classes_lookup: dict[type, bool] = {}


def registry_list(output: Sequence[FeatureType]):
    """Add every feature type found in ``output`` to ``Registry``.

    Entries for which ``is_feature`` is false are skipped.

    Args:
        output: Types to inspect, visited in order.
    """
    # Lazy filter: each entry is checked and registered before the next
    # one is examined.
    feature_types = (candidate for candidate in output if is_feature(candidate))
    for feature_type in feature_types:
        Registry.add(feature_type)


def is_standard_type(t: type) -> bool:
    """Tell whether ``t`` is one of the standard feature types.

    ``t`` matches when it is identical to a type argument of
    ``FeatureStandardType``, or to the first type argument of one of them.

    Args:
        t: Type to test (compared by identity, not equality).

    Returns:
        ``True`` on the first match, ``False`` if nothing matches.
    """
    for standard in get_args(FeatureStandardType):
        if t is standard:
            return True
        # Only reached when ``t`` is not the argument itself; compare with
        # the argument's own first type parameter.
        if t is get_args(standard)[0]:
            return True
    return False

Expand Down
15 changes: 12 additions & 3 deletions src/datachain/lib/feature_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import logging
from typing import Any, ClassVar, Optional

Expand All @@ -17,7 +18,9 @@ def get_version(cls, model: type[BaseModel]) -> int:

@classmethod
def get_name(cls, model) -> str:
    """Return the registry name for ``model``.

    Args:
        model: Class whose ``__name__`` forms the base of the result.

    Returns:
        ``"<ClassName>@v<version>"`` when ``cls.get_version(model)`` is
        positive, otherwise just ``"<ClassName>"``.
    """
    version = cls.get_version(model)
    if version > 0:
        # The "v" prefix is the form ``parse_name_version`` reads back.
        return f"{model.__name__}@v{version}"
    return model.__name__

@classmethod
def add(cls, fr: type) -> None:
Expand All @@ -30,6 +33,12 @@ def add(cls, fr: type) -> None:
logger.warning("Feature %s is already registered", full_name)
cls.reg[name][version] = fr

if issubclass(fr, BaseModel):
for f_info in fr.model_fields.values():
anno = f_info.annotation
if inspect.isclass(anno) and issubclass(anno, BaseModel):
cls.add(anno)

@classmethod
def get(cls, name: str, version: Optional[int] = None) -> Optional[type]:
class_dict = cls.reg.get(name, None)
Expand All @@ -45,12 +54,12 @@ def get(cls, name: str, version: Optional[int] = None) -> Optional[type]:
@classmethod
def parse_name_version(cls, fullname: str) -> tuple[str, int]:
    """Split a serialized feature name into its base name and version.

    Accepted forms are ``"Name"``, ``"Name@vN"`` and the plain ``"Name@N"``.

    Args:
        fullname: Serialized name, optionally followed by ``@`` and a
            version.

    Returns:
        A ``(name, version)`` tuple. ``version`` is ``0`` when ``fullname``
        has no ``@`` or the part after it is blank.

    Raises:
        ValueError: If the version part is not an integer (optionally
            prefixed with ``v``), e.g. when ``fullname`` contains more
            than one ``@``.
    """
    name, separator, version_str = fullname.partition("@")
    version_str = version_str.strip()
    if not separator or not version_str:
        return name, 0

    # Strip the "v" marker only when it is actually present, so that the
    # plain numeric form ("Name@1") parses instead of becoming int("").
    if version_str.startswith("v"):
        version_str = version_str[1:]
    return name, int(version_str)

Expand Down
23 changes: 16 additions & 7 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
from collections.abc import Iterator, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_args, get_origin

Expand Down Expand Up @@ -57,7 +58,13 @@ def __init__(self, method: str, field):
)


@dataclass
class SignalSchema:
values: dict[str, FeatureType]
tree: dict[str, Any]
setup_func: dict[str, Callable]
setup_values: Optional[dict[str, Callable]]

def __init__(
self,
values: dict[str, FeatureType],
Expand Down Expand Up @@ -123,15 +130,17 @@ def deserialize(schema: dict[str, str]) -> "SignalSchema":
if not fr:
type_name, version = Registry.parse_name_version(type_name)
fr = Registry.get(type_name, version)

if not fr:
raise SignalSchemaError(
f"cannot deserialize '{signal}': "
f"unregistered type '{type_name}'."
f" Try to register it with `Registry.add({type_name})`."
)
except TypeError as err:
raise SignalSchemaError(
f"cannot deserialize '{signal}': {err}"
) from err

if not fr:
raise SignalSchemaError(
f"cannot deserialize '{signal}': unsupported type '{type_name}'"
)
signals[signal] = fr

return SignalSchema(signals)
Expand All @@ -152,11 +161,11 @@ def row_to_objs(self, row: Sequence[Any]) -> list[FeatureType]:
objs = []
pos = 0
for name, fr_type in self.values.items():
if val := self.setup_values.get(name, None): # type: ignore[attr-defined]
if self.setup_values and (val := self.setup_values.get(name, None)):
objs.append(val)
elif (fr := to_feature(fr_type)) is not None:
j, pos = ModelUtil.unflatten_to_json_pos(fr, row, pos)
objs.append(fr(**j))
objs.append(fr(**j)) # type: ignore[arg-type]
else:
objs.append(row[pos])
pos += 1
Expand Down
20 changes: 10 additions & 10 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpy as np
import pandas as pd
import pytest
from pydantic import BaseModel

from datachain.lib.dc import C, DataChain
from datachain.lib.feature import VersionedModel
from datachain.lib.file import File
from datachain.lib.signal_schema import (
SignalResolvingError,
Expand All @@ -29,12 +29,12 @@
}


class MyFr(VersionedModel):
class MyFr(BaseModel):
nnn: str
count: int


class MyNested(VersionedModel):
class MyNested(BaseModel):
label: str
fr: MyFr

Expand Down Expand Up @@ -194,7 +194,7 @@ def test_file_list(catalog):


def test_gen(catalog):
class _TestFr(VersionedModel):
class _TestFr(BaseModel):
file: File
sqrt: float
my_name: str
Expand All @@ -221,7 +221,7 @@ class _TestFr(VersionedModel):


def test_map(catalog):
class _TestFr(VersionedModel):
class _TestFr(BaseModel):
sqrt: float
my_name: str

Expand All @@ -241,7 +241,7 @@ class _TestFr(VersionedModel):


def test_agg(catalog):
class _TestFr(VersionedModel):
class _TestFr(BaseModel):
f: File
cnt: int
my_name: str
Expand Down Expand Up @@ -269,7 +269,7 @@ class _TestFr(VersionedModel):


def test_agg_two_params(catalog):
class _TestFr(VersionedModel):
class _TestFr(BaseModel):
f: File
cnt: int
my_name: str
Expand Down Expand Up @@ -325,7 +325,7 @@ def func(key) -> int:

with pytest.raises(UdfSignatureError):

class _MyCls(VersionedModel):
class _MyCls(BaseModel):
x: int

def func(key) -> _MyCls: # type: ignore[misc]
Expand All @@ -342,7 +342,7 @@ def func(key) -> tuple[File, str]: # type: ignore[misc]


def test_agg_tuple_result_iterator(catalog):
class _ImageGroup(VersionedModel):
class _ImageGroup(BaseModel):
name: str
size: int

Expand All @@ -364,7 +364,7 @@ def func(key, val) -> Iterator[tuple[File, _ImageGroup]]:


def test_agg_tuple_result_generator(catalog):
class _ImageGroup(VersionedModel):
class _ImageGroup(BaseModel):
name: str
size: int

Expand Down
22 changes: 13 additions & 9 deletions tests/unit/lib/test_signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from typing import Optional, Union

import pytest
from pydantic import BaseModel

from datachain.lib.feature import ModelUtil, VersionedModel
from datachain.lib.feature import ModelUtil
from datachain.lib.feature_registry import Registry
from datachain.lib.file import File
from datachain.lib.signal_schema import (
SetupError,
Expand All @@ -25,18 +27,18 @@ class _MyFile(File):
return SignalSchema(schema)


class MyType1(VersionedModel):
class MyType1(BaseModel):
aa: int
bb: str


class MyType2(VersionedModel):
class MyType2(BaseModel):
name: str
deep: MyType1


def test_deserialize_basic():
stored = {"name": "str", "count": "int", "file": "File@1"}
stored = {"name": "str", "count": "int", "file": "File@v1"}
signals = SignalSchema.deserialize(stored)

assert len(signals.values) == 3
Expand Down Expand Up @@ -68,7 +70,7 @@ def test_serialize_basic():
assert len(signals) == 3
assert signals["name"] == "str"
assert signals["age"] == "float"
assert signals["f"] == "File@1"
assert signals["f"] == "File@v1"


def test_feature_schema_serialize_optional():
Expand Down Expand Up @@ -101,7 +103,7 @@ def test_to_udf_spec():
{
"age": "float",
"address": "str",
"f": "File@1",
"f": "File@v1",
}
)

Expand All @@ -123,11 +125,12 @@ def test_to_udf_spec():


def test_select():
Registry.add(MyType2)
schema = SignalSchema.deserialize(
{
"age": "float",
"address": "str",
"f": "MyType1@1",
"f": "MyType1",
}
)

Expand All @@ -143,10 +146,11 @@ def test_select():


def test_select_nested_names():
Registry.add(MyType2)
schema = SignalSchema.deserialize(
{
"address": "str",
"fr": "MyType2@1",
"fr": "MyType2",
}
)

Expand All @@ -165,7 +169,7 @@ def test_select_nested_errors():
schema = SignalSchema.deserialize(
{
"address": "str",
"fr": "MyType2@1",
"fr": "MyType2",
}
)

Expand Down

0 comments on commit cc718ae

Please sign in to comment.