Skip to content

Commit

Permalink
pydantic_to_feature(): support enum type & nested lists (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmpetrov authored Jul 11, 2024
1 parent f500052 commit e8e7ce7
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 20 deletions.
9 changes: 7 additions & 2 deletions src/datachain/lib/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings
from collections.abc import Iterable, Sequence
from datetime import datetime
from enum import Enum
from functools import lru_cache
from types import GenericAlias
from typing import (
Expand Down Expand Up @@ -63,6 +64,7 @@
str: String,
Literal: String,
LiteralEx: String,
Enum: String,
float: Float,
bool: Boolean,
datetime: DateTime, # Note, list of datetime is not supported yet
Expand Down Expand Up @@ -364,8 +366,11 @@ def _resolve(cls, name, field_info, prefix: list[str]):


def convert_type_to_datachain(typ): # noqa: PLR0911
if inspect.isclass(typ) and issubclass(typ, SQLType):
return typ
if inspect.isclass(typ):
if issubclass(typ, SQLType):
return typ
if issubclass(typ, Enum):
return str

res = TYPE_TO_DATACHAIN.get(typ)
if res:
Expand Down
52 changes: 35 additions & 17 deletions src/datachain/lib/feature_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import inspect
import string
from collections.abc import Sequence
from enum import Enum
from typing import Any, Union, get_args, get_origin

from pydantic import BaseModel, create_model
Expand Down Expand Up @@ -35,23 +37,7 @@ def pydantic_to_feature(data_cls: type[BaseModel]) -> type[Feature]:
for name, field_info in data_cls.model_fields.items():
anno = field_info.annotation
if anno not in TYPE_TO_DATACHAIN:
orig = get_origin(anno)
if orig is list:
anno = get_args(anno) # type: ignore[assignment]
if isinstance(anno, Sequence):
anno = anno[0] # type: ignore[unreachable]
is_list = True
else:
is_list = False

try:
convert_type_to_datachain(anno)
except TypeError:
if not Feature.is_feature(anno): # type: ignore[arg-type]
anno = pydantic_to_feature(anno) # type: ignore[arg-type]

if is_list:
anno = list[anno] # type: ignore[valid-type]
anno = _to_feature_type(anno)
fields[name] = (anno, field_info.default)

cls = create_model(
Expand All @@ -63,6 +49,38 @@ def pydantic_to_feature(data_cls: type[BaseModel]) -> type[Feature]:
return cls


def _to_feature_type(anno):
if inspect.isclass(anno) and issubclass(anno, Enum):
return str

orig = get_origin(anno)
if orig is list:
anno = get_args(anno) # type: ignore[assignment]
if isinstance(anno, Sequence):
anno = anno[0] # type: ignore[unreachable]
is_list = True
else:
is_list = False

try:
convert_type_to_datachain(anno)
except TypeError:
if not Feature.is_feature(anno): # type: ignore[arg-type]
orig = get_origin(anno)
if orig in TYPE_TO_DATACHAIN:
anno = _to_feature_type(anno)
else:
if orig == Union:
args = get_args(anno)
if len(args) == 2 and (type(None) in args):
return _to_feature_type(args[0])

anno = pydantic_to_feature(anno) # type: ignore[arg-type]
if is_list:
anno = list[anno] # type: ignore[valid-type]
return anno


def features_to_tuples(
ds_name: str = "",
output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
Expand Down
38 changes: 37 additions & 1 deletion tests/unit/lib/test_feature_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from enum import Enum
from typing import get_args, get_origin

import pytest
from pydantic import BaseModel

from datachain.lib.dc import DataChain
from datachain.lib.feature_utils import FeatureToTupleError, features_to_tuples
from datachain.lib.feature import Feature
from datachain.lib.feature_utils import (
FeatureToTupleError,
features_to_tuples,
pydantic_to_feature,
)
from datachain.query.schema import Column


Expand Down Expand Up @@ -104,3 +111,32 @@ def test_resolve_column():
def test_resolve_column_attr():
    """Chained attribute access on ``Column`` yields a ``__``-joined name."""
    assert Column.hello.world.again.name == "hello__world__again"


def test_to_feature_list_of_lists():
    """A ``list[list[Model]]`` field converts to ``list[list[Feature]]``."""

    class MyName1(BaseModel):
        id: int
        name: str

    class Mytest2(BaseModel):
        loc: str
        identity: list[list[MyName1]]

    cls = pydantic_to_feature(Mytest2)

    assert issubclass(cls, Feature)

    # Both list levels must survive the conversion and the innermost
    # pydantic model must have been turned into a feature class.
    outer = cls.model_fields["identity"].annotation
    assert get_origin(outer) is list
    (inner,) = get_args(outer)
    assert get_origin(inner) is list
    (item,) = get_args(inner)
    assert issubclass(item, Feature)


def test_to_feature_function():
    """An enum-typed pydantic field is exposed as ``str`` on the feature."""

    class MyEnum(str, Enum):
        func = "function"

    class MyCall(BaseModel):
        id: str
        type: MyEnum

    feature_cls = pydantic_to_feature(MyCall)

    assert issubclass(feature_cls, Feature)
    assert feature_cls.model_fields["type"].annotation is str

0 comments on commit e8e7ce7

Please sign in to comment.