Skip to content

Commit 98429b0

Browse files
authored
fix: support sagemaker batch transform for clip (#6171)
1 parent 12e2a94 commit 98429b0

File tree

7 files changed

+152
-2
lines changed

7 files changed

+152
-2
lines changed

jina/serve/runtimes/worker/http_csp_app.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,14 @@ def construct_model_from_line(
189189
)
190190
else:
191191
parsed_fields[field_name] = parsed_list
192-
# Handle direct assignment for basic types
192+
# General parsing attempt for other types
193193
else:
194-
parsed_fields[field_name] = field_info.type_(field_str)
194+
if field_str:
195+
try:
196+
parsed_fields[field_name] = field_info.type_(field_str)
197+
except (ValueError, TypeError):
198+
# Fallback to parse_obj_as when type is more complex, e., AnyUrl or ImageBytes
199+
parsed_fields[field_name] = parse_obj_as(field_info.type_, field_str)
195200

196201
return model(**parsed_fields)
197202

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# SampleClipExecutor
2+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
jtype: SampleClipExecutor
2+
py_modules:
3+
- executor.py
4+
metas:
5+
name: SampleClipExecutor
6+
description:
7+
url:
8+
keywords: []
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import Optional
2+
3+
import numpy as np
4+
from docarray import BaseDoc, DocList
5+
from docarray.typing import NdArray
6+
from docarray.typing.bytes import ImageBytes
7+
from docarray.typing.url import AnyUrl
8+
from jina import Executor, requests
9+
from pydantic import Field
10+
11+
12+
class TextAndImageDoc(BaseDoc):
13+
text: Optional[str] = None
14+
url: Optional[AnyUrl] = None
15+
bytes: Optional[ImageBytes] = None
16+
17+
18+
class EmbeddingResponseModel(TextAndImageDoc):
19+
embeddings: NdArray = Field(description="The embedding of the texts", default=[])
20+
21+
class Config(BaseDoc.Config):
22+
allow_population_by_field_name = True
23+
arbitrary_types_allowed = True
24+
json_encoders = {NdArray: lambda v: v.tolist()}
25+
26+
27+
class SampleClipExecutor(Executor):
28+
@requests(on="/encode")
29+
def foo(
30+
self, docs: DocList[TextAndImageDoc], **kwargs
31+
) -> DocList[EmbeddingResponseModel]:
32+
ret = []
33+
for doc in docs:
34+
ret.append(
35+
EmbeddingResponseModel(
36+
id=doc.id,
37+
text=doc.text,
38+
url=doc.url,
39+
bytes=doc.bytes,
40+
embeddings=np.random.random((1, 64)),
41+
)
42+
)
43+
return DocList[EmbeddingResponseModel](ret)

tests/integration/docarray_v2/csp/SampleClipExecutor/requirements.txt

Whitespace-only changes.
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import csv
2+
import io
3+
import os
4+
5+
import requests
6+
from jina.orchestrate.pods import Pod
7+
from jina.parsers import set_pod_parser
8+
9+
sagemaker_port = 8080
10+
11+
12+
def test_provider_sagemaker_pod_rank():
13+
args, _ = set_pod_parser().parse_known_args(
14+
[
15+
"--uses",
16+
os.path.join(os.path.dirname(__file__), "SampleClipExecutor", "config.yml"),
17+
"--provider",
18+
"sagemaker",
19+
"--provider-endpoint",
20+
"encode",
21+
"serve", # This is added by sagemaker
22+
]
23+
)
24+
with Pod(args):
25+
# Test the `GET /ping` endpoint (added by jina for sagemaker)
26+
resp = requests.get(f"http://localhost:{sagemaker_port}/ping")
27+
assert resp.status_code == 200
28+
assert resp.json() == {}
29+
30+
# Test the `POST /invocations` endpoint for inference
31+
# Note: this endpoint is not implemented in the sample executor
32+
resp = requests.post(
33+
f"http://localhost:{sagemaker_port}/invocations",
34+
json={
35+
"data": [
36+
{"url": "http://google.com"},
37+
]
38+
},
39+
)
40+
assert resp.status_code == 200
41+
resp_json = resp.json()
42+
assert len(resp_json["data"]) == 1
43+
assert len(resp_json["data"][0]["embeddings"][0]) == 64
44+
assert resp_json["data"][0]["url"] == "http://google.com"
45+
46+
47+
def test_provider_sagemaker_pod_batch_transform_valid():
48+
args, _ = set_pod_parser().parse_known_args(
49+
[
50+
"--uses",
51+
os.path.join(os.path.dirname(__file__), "SampleClipExecutor", "config.yml"),
52+
"--provider",
53+
"sagemaker",
54+
"serve", # This is added by sagemaker
55+
]
56+
)
57+
with Pod(args):
58+
# Test `POST /invocations` endpoint for batch-transform with valid input
59+
with open(
60+
os.path.join(os.path.dirname(__file__), "valid_clip_input.csv"), "r"
61+
) as f:
62+
csv_data = f.read()
63+
64+
text = []
65+
for line in csv.reader(
66+
io.StringIO(csv_data),
67+
delimiter=",",
68+
quoting=csv.QUOTE_NONE,
69+
escapechar="\\",
70+
):
71+
text.append(line)
72+
73+
resp = requests.post(
74+
f"http://localhost:{sagemaker_port}/invocations",
75+
headers={
76+
"accept": "application/json",
77+
"content-type": "text/csv",
78+
},
79+
data=csv_data,
80+
)
81+
assert resp.status_code == 200
82+
resp_json = resp.json()
83+
assert len(resp_json["data"]) == 3
84+
assert resp_json["data"][0]["text"] == "the cat is in my house"
85+
assert (
86+
resp_json["data"][1]["url"]
87+
== "https://dummyimage3.com/333/000/fff.jpg&text=embed+this"
88+
)
89+
assert "hWjj1RNtNftP" in resp_json["data"][2]["bytes"]

tests/integration/docarray_v2/csp/valid_clip_input.csv

Lines changed: 3 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)