Skip to content

Commit

Permalink
Merge branch 'main' into fix-distributed-test
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein authored Jul 12, 2024
2 parents 1fe1612 + ee3d16c commit 72ec99a
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 68 deletions.
97 changes: 58 additions & 39 deletions README.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
|PyPI| |Python Version| |Codecov| |Tests| |License|
|PyPI| |Python Version| |Codecov| |Tests|

.. |PyPI| image:: https://img.shields.io/pypi/v/datachain.svg
:target: https://pypi.org/project/datachain/
Expand Down Expand Up @@ -39,21 +39,24 @@ For example, let us consider a dataset from Karlsruhe Institute of Technology de
# this example requires a free Mistral API key, get yours at https://console.mistral.ai
# add the key to your shell environment: $ export MISTRAL_API_KEY= your key
# pip install mistralai
# this example requires a free Mistral API key, get yours at https://console.mistral.ai
# add the key to your shell environment: $ export MISTRAL_API_KEY= your key
import os
import pandas as pd
from datachain.lib.feature import Feature
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from datachain.lib.dc import Column, DataChain
source = "gs://datachain-demo/chatbot-KiT/"
from datachain.lib.dc import DataChain, Column
PROMPT = "Was this bot dialog successful? Describe the 'result' as 'Yes' or 'No' in a short JSON"
model = "mistral-large-latest"
api_key = os.environ["MISTRAL_API_KEY"]
chain = (
DataChain.from_storage(source)
DataChain.from_storage("gs://datachain-demo/chatbot-KiT/")
.limit(5)
.settings(cache=True, parallel=5)
.map(
Expand All @@ -64,17 +67,21 @@ For example, let us consider a dataset from Karlsruhe Institute of Technology de
messages=[
ChatMessage(role="user", content=f"{PROMPT}: {file.get_value()}")
],
).choices[0].message.content,
)
.choices[0]
.message.content,
)
.save()
)
try:
print(chain.select("mistral_response").results())
print(chain.select("mistral_response").results())
except Exception as e:
print(f"do you have the right Mistral API key? {e}")
print(f"do you have the right Mistral API key? {e}")
.. code:: shell
->
[('{"result": "Yes"}',), ('{"result": "No"}',), ... , ('{"result": "Yes"}',)]
Now we have parallel-processed an LLM API-based query over cloud data and persisted the results.
Expand All @@ -90,8 +97,10 @@ Datachain internally represents datasets as tables, so analytical queries on the
success_rate = failed_dialogs.count() / chain.count()
print(f"Chatbot dialog success rate: {100*success_rate:.2f}%")
->
"40.00%" (results may vary)
.. code:: shell
"40.00%"
Note that DataChain represents file samples as pointers into their respective storage locations. This means a newly created dataset version does not duplicate files in storage, and storage remains the single source of truth for the original samples

Expand All @@ -104,29 +113,37 @@ For example, instead of collecting just a text response from Mistral API, we mig
.. code:: py
import os
from datachain.lib.feature import Feature
from datachain.lib.dc import Column, DataChain
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
source = "gs://datachain-demo/chatbot-KiT/"
PROMPT = "Was this dialog successful? Describe the 'result' as 'Yes' or 'No' in a short JSON"
from datachain.lib.dc import DataChain
from datachain.lib.feature import Feature
PROMPT = (
"Was this dialog successful? Describe the 'result' as 'Yes' or 'No' in a short JSON"
)
model = "mistral-large-latest"
api_key = os.environ["MISTRAL_API_KEY"]
## define the data model ###
class Usage(Feature):
prompt_tokens: int = 0
completion_tokens: int = 0
class MyChatMessage(Feature):
role: str = ""
content: str = ""
class CompletionResponseChoice(Feature):
message: MyChatMessage = MyChatMessage()
class MistralModel(Feature):
id: str = ""
choices: list[CompletionResponseChoice]
Expand All @@ -135,7 +152,7 @@ For example, instead of collecting just a text response from Mistral API, we mig
## Populate model instances ###
chain = (
DataChain.from_storage(source)
DataChain.from_storage("gs://datachain-demo/chatbot-KiT/")
.limit(5)
.settings(cache=True, parallel=5)
.map(
Expand All @@ -147,7 +164,8 @@ For example, instead of collecting just a text response from Mistral API, we mig
messages=[
ChatMessage(role="user", content=f"{PROMPT}: {file.get_value()}")
],
).dict()
)
.dict()
),
output=MistralModel,
)
Expand All @@ -158,19 +176,18 @@ After the chain execution, we can collect the objects:

.. code:: py
responses = chain.collect_one("mistral_response")
for object in responses:
print(type(object))
->
<class '__main__.MistralModel'>
<class '__main__.MistralModel'>
<class '__main__.MistralModel'>
<class '__main__.MistralModel'>
<class '__main__.MistralModel'>
for obj in responses:
assert isinstance(obj, MistralModel)
print(obj.dict())
.. code:: shell
{'choices': [{'message': {'role': 'assistant', 'content': '{"result": "Yes"}'}}], 'usage': {'prompt_tokens': 610, 'completion_tokens': 6}}
{'choices': [{'message': {'role': 'assistant', 'content': '{"result": "No"}'}}], 'usage': {'prompt_tokens': 3983, 'completion_tokens': 6}}
{'choices': [{'message': {'role': 'assistant', 'content': '{"result": "Yes"}'}}], 'usage': {'prompt_tokens': 706, 'completion_tokens': 6}}
{'choices': [{'message': {'role': 'assistant', 'content': '{"result": "No"}'}}], 'usage': {'prompt_tokens': 1250, 'completion_tokens': 6}}
{'choices': [{'message': {'role': 'assistant', 'content': '{"result": "Yes"}'}}], 'usage': {'prompt_tokens': 1217, 'completion_tokens': 6}}
print(responses[0].usage.prompt_tokens)
->
610
Dataset persistence
--------------------
Expand Down Expand Up @@ -215,9 +232,7 @@ Here is an example of reading a CSV file where schema is heuristically derived f
.. code:: py
from datachain.lib.dc import DataChain
uri="gs://datachain-demo/chatbot-csv/"
csv_dataset = DataChain.from_csv(uri)
csv_dataset = DataChain.from_csv("gs://datachain-demo/chatbot-csv/")
print(csv_dataset.to_pandas())
Expand Down Expand Up @@ -264,17 +279,21 @@ To deal with this layout, we can take the following steps:
from datachain.lib.dc import DataChain
image_uri="gs://datachain-demo/coco2017/images/val/"
coco_json="gs://datachain-demo/coco2017/annotations_captions"
images = DataChain.from_storage(image_uri)
meta = DataChain.from_json(coco_json, jmespath = "images")
images = DataChain.from_storage("gs://datachain-demo/coco2017/images/val/")
meta = DataChain.from_json("gs://datachain-demo/coco2017/annotations_captions", jmespath = "images")
images_with_meta = images.merge(meta, on="file.name", right_on="images.file_name")
print(images_with_meta.limit(1).results())
.. code:: shell
Processed: 5000 rows [00:00, 15481.66 rows/s]
Processed: 1 rows [00:00, 1291.75 rows/s]
Processed: 1 rows [00:00, 4.70 rows/s]
Generated: 5000 rows [00:00, 27128.67 rows/s]
[(1, 2336066478558845549, '', 0, 'coco2017/images/val', '000000000139.jpg', 'CNvXoemj8IYDEAE=', '1719096046021595', 1, datetime.datetime(2024, 6, 22, 22, 40, 46, 70000, tzinfo=datetime.timezone.utc), 161811, '', '', None, 'gs://datachain-demo', 'gs://datachain-demo', 'coco2017/images/val', '000000000139.jpg', 161811, '1719096046021595', 'CNvXoemj8IYDEAE=', 1, datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), None, '', 4146, 6967063844996569113, 2, '000000000139.jpg', 'http://images.cocodataset.org/val2017/000000000139.jpg', 426, 640, '2013-11-21 01:34:01', 'http://farm9.staticflickr.com/8035/8024364858_9c41dc1666_z.jpg', 139)]
Passing data to training
------------------------
Expand Down
34 changes: 34 additions & 0 deletions src/datachain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from datachain.lib.dc import C, DataChain
from datachain.lib.feature import Feature
from datachain.lib.feature_utils import pydantic_to_feature
from datachain.lib.file import File, FileError, FileFeature, IndexedFile, TarVFile
from datachain.lib.image import ImageFile, convert_images
from datachain.lib.text import convert_text
from datachain.lib.udf import Aggregator, Generator, Mapper
from datachain.lib.utils import AbstractUDF, DataChainError
from datachain.query.dataset import UDF as BaseUDF # noqa: N811
from datachain.query.schema import Column
from datachain.query.session import Session

__all__ = [
"AbstractUDF",
"Aggregator",
"BaseUDF",
"C",
"Column",
"DataChain",
"DataChainError",
"Feature",
"File",
"FileError",
"FileFeature",
"Generator",
"ImageFile",
"IndexedFile",
"Mapper",
"Session",
"TarVFile",
"convert_images",
"convert_text",
"pydantic_to_feature",
]
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ def monkeypatch_session() -> Generator[MonkeyPatch, None, None]:
mpatch.undo()


@pytest.fixture(autouse=True)
def clean_session() -> None:
"""
Make sure we clean leftover session before each test case
"""
Session.cleanup_for_tests()


@pytest.fixture(scope="session", autouse=True)
def clean_environment(
monkeypatch_session: MonkeyPatch,
Expand Down
28 changes: 15 additions & 13 deletions tests/func/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import open_clip
import pytest
from torch import Size, Tensor
Expand All @@ -8,10 +10,10 @@
from datachain.lib.pytorch import PytorchDataset


@pytest.fixture
def fake_dataset(tmp_path, catalog):
@pytest.fixture(scope="module")
def fake_dataset(tmpdir_factory):
# Create fake images in labeled dirs
data_path = tmp_path / "data" / ""
data_path = Path(tmpdir_factory.mktemp("data"))
for i, (img, label) in enumerate(FakeData()):
label = str(label)
(data_path / label).mkdir(parents=True, exist_ok=True)
Expand All @@ -37,11 +39,11 @@ def test_pytorch_dataset(fake_dataset):
transform=transform,
tokenizer=tokenizer,
)
for img, text, label in pt_dataset:
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])
img, text, label = next(iter(pt_dataset))
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])


def test_pytorch_dataset_sample(fake_dataset):
Expand All @@ -62,8 +64,8 @@ def test_to_pytorch(fake_dataset):
tokenizer = open_clip.get_tokenizer("ViT-B-32")
pt_dataset = fake_dataset.to_pytorch(transform=transform, tokenizer=tokenizer)
assert isinstance(pt_dataset, IterableDataset)
for img, text, label in pt_dataset:
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])
img, text, label = next(iter(pt_dataset))
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])
21 changes: 21 additions & 0 deletions tests/unit/lib/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest
import torch
from torch import float32
from torchvision.transforms import v2


@pytest.fixture(scope="session")
def fake_clip_model():
class Model:
def encode_image(self, tensor):
return torch.randn(len(tensor), 512)

def encode_text(self, tensor):
return torch.randn(len(tensor), 512)

def tokenizer(tensor, context_length=77):
return torch.randn(len(tensor), context_length)

model = Model()
preprocess = v2.ToDtype(float32, scale=True)
return model, preprocess, tokenizer
12 changes: 4 additions & 8 deletions tests/unit/lib/test_clip.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import open_clip
import pytest
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
Expand All @@ -7,10 +6,6 @@

IMAGES = [Image.new(mode="RGB", size=(64, 64)), Image.new(mode="RGB", size=(32, 32))]
TEXTS = ["text1", "text2"]
MODEL, _, PREPROCESS = open_clip.create_model_and_transforms(
"ViT-B-32", pretrained="laion2b_s34b_b79k"
)
TOKENIZER = open_clip.get_tokenizer("ViT-B-32")


@pytest.mark.parametrize(
Expand All @@ -20,15 +15,16 @@
@pytest.mark.parametrize("text", [None, "text", TEXTS])
@pytest.mark.parametrize("prob", [True, False])
@pytest.mark.parametrize("image_to_text", [True, False])
def test_similarity_scores(images, text, prob, image_to_text):
def test_similarity_scores(fake_clip_model, images, text, prob, image_to_text):
model, preprocess, tokenizer = fake_clip_model
if not (images or text):
with pytest.raises(ValueError):
scores = similarity_scores(
images, text, MODEL, PREPROCESS, TOKENIZER, prob, image_to_text
images, text, model, preprocess, tokenizer, prob, image_to_text
)
else:
scores = similarity_scores(
images, text, MODEL, PREPROCESS, TOKENIZER, prob, image_to_text
images, text, model, preprocess, tokenizer, prob, image_to_text
)
assert isinstance(scores, list)
if not images:
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/lib/test_datachain_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def teardown(self):
self.value = MyMapper.TEARDOWN_VALUE


def test_udf(catalog):
def test_udf():
vals = ["a", "b", "c", "d", "e", "f"]
chain = DataChain.from_features(key=vals)

Expand All @@ -36,7 +36,7 @@ def test_udf(catalog):


@pytest.mark.skip(reason="Skip until tests module will be importer for unit-tests")
def test_udf_parallel(catalog):
def test_udf_parallel():
vals = ["a", "b", "c", "d", "e", "f"]
chain = DataChain.from_features(key=vals)

Expand All @@ -45,7 +45,7 @@ def test_udf_parallel(catalog):
assert res == [MyMapper.BOOTSTRAP_VALUE] * len(vals)


def test_no_bootstrap_for_callable(catalog):
def test_no_bootstrap_for_callable():
class MyMapper:
def __init__(self):
self._had_bootstrap = False
Expand Down
Loading

0 comments on commit 72ec99a

Please sign in to comment.