Skip to content

Commit

Permalink
test: test CSP parameters (#6176)
Browse files Browse the repository at this point in the history
Co-authored-by: Jina Dev Bot <[email protected]>
  • Loading branch information
JoanFM and jina-bot authored Jul 9, 2024
1 parent 58e0397 commit e3ea29f
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 17 deletions.
31 changes: 16 additions & 15 deletions jina/serve/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,25 +634,26 @@ def _validate_sagemaker(self):
and self.runtime_args.provider_endpoint
):
endpoint_to_use = ('/' + self.runtime_args.provider_endpoint).lower()
if endpoint_to_use in list(self.requests.keys()):
self.logger.warning(
f'Using "{endpoint_to_use}" as "/invocations" route'
)
self.requests['/invocations'] = self.requests[endpoint_to_use]
for k in remove_keys:
self.requests.pop(k)
return

if len(self.requests) == 1:
route = list(self.requests.keys())[0]
self.logger.warning(f'Using "{route}" as "/invocations" route')
self.requests['/invocations'] = self.requests[route]
elif len(self.requests) == 1:
endpoint_to_use = list(self.requests.keys())[0]
else:
raise ValueError('Cannot identify the endpoint to use for "/invocations"')

if endpoint_to_use in list(self.requests.keys()):
self.logger.warning(f'Using "{endpoint_to_use}" as "/invocations" route')
self.requests['/invocations'] = self.requests[endpoint_to_use]
if (
getattr(self, 'dynamic_batching', {}).get(endpoint_to_use, None)
is not None
):
self.dynamic_batching['/invocations'] = self.dynamic_batching[
endpoint_to_use
]
self.dynamic_batching.pop(endpoint_to_use)
for k in remove_keys:
self.requests.pop(k)
return

raise ValueError('Cannot identify the endpoint to use for "/invocations"')

def _add_dynamic_batching(self, _dynamic_batching: Optional[Dict]):
if _dynamic_batching:
self.dynamic_batching = getattr(self, 'dynamic_batching', {})
Expand Down
9 changes: 9 additions & 0 deletions jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,16 @@ def _init_batchqueue_dict(self):
# Endpoints allow specific configurations while functions allow configs to be applied to all endpoints of the function
dbatch_endpoints = []
dbatch_functions = []
request_models_map = self._executor._get_endpoint_models_dict()

for key, dbatch_config in self._executor.dynamic_batching.items():
if request_models_map.get(key, {}).get('parameters', {}).get('model', None) is not None:
error_msg = f'Executor Dynamic Batching cannot be used for endpoint {key} because it depends on parameters.'
self.logger.error(
error_msg
)
raise Exception(error_msg)

if key.startswith('/'):
dbatch_endpoints.append((key, dbatch_config))
else:
Expand Down
20 changes: 19 additions & 1 deletion tests/integration/docarray_v2/csp/SampleExecutor/executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from docarray import BaseDoc, DocList
from docarray.typing import NdArray
from pydantic import Field
from pydantic import Field, BaseModel

from jina import Executor, requests

Expand All @@ -19,6 +19,11 @@ class Config(BaseDoc.Config):
json_encoders = {NdArray: lambda v: v.tolist()}


# Typed request-parameters model for the sample executor.
# `emb_dim` is read by the `/encode_parameter` handler (`bar`) as the second
# dimension of the random embedding it generates for each document.
class Parameters(BaseModel):
emb_dim: int



class SampleExecutor(Executor):
@requests(on="/encode")
def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseModel]:
Expand All @@ -32,3 +37,16 @@ def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseMode
)
)
return DocList[EmbeddingResponseModel](ret)

@requests(on="/encode_parameter")
def bar(self, docs: DocList[TextDoc], parameters: Parameters, **kwargs) -> DocList[EmbeddingResponseModel]:
    """Return one random embedding per input document, sized by the request parameters.

    :param docs: the incoming text documents
    :param parameters: typed request parameters; ``emb_dim`` is the embedding width
    :param kwargs: extra keyword arguments, unused here
    :return: a DocList with one response per input doc, carrying the same ``id`` and
        ``text`` and a random array of shape ``(1, parameters.emb_dim)``
    """
    responses = [
        EmbeddingResponseModel(
            id=doc.id,
            text=doc.text,
            embeddings=np.random.random((1, parameters.emb_dim)),
        )
        for doc in docs
    ]
    return DocList[EmbeddingResponseModel](responses)
65 changes: 64 additions & 1 deletion tests/integration/docarray_v2/csp/test_sagemaker_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def test_provider_sagemaker_pod_inference():
os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
'--provider',
'sagemaker',
"--provider-endpoint",
"encode",
'serve', # This is added by sagemaker
]
)
Expand All @@ -60,6 +62,43 @@ def test_provider_sagemaker_pod_inference():
assert len(resp_json['data'][0]['embeddings'][0]) == 64


def test_provider_sagemaker_pod_inference_parameters():
    """Check that request `parameters` reach the endpoint served as `/invocations`.

    The pod is started with `--provider-endpoint encode_parameter`; each request
    passes a different `emb_dim` and the returned embedding width must match it.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "SampleExecutor", "config.yml"
    )
    cli_args = [
        '--uses',
        config_path,
        '--provider',
        'sagemaker',
        "--provider-endpoint",
        "encode_parameter",
        'serve',  # This is added by sagemaker
    ]
    args, _ = set_pod_parser().parse_known_args(cli_args)

    with Pod(args):
        base_url = f'http://localhost:{sagemaker_port}'

        # `GET /ping` is the health route added by jina for sagemaker
        ping = requests.get(f'{base_url}/ping')
        assert ping.status_code == 200
        assert ping.json() == {}

        for emb_dim in {32, 64, 128}:
            # `POST /invocations` is the inference route; the requested
            # embedding width travels in the `parameters` field of the body.
            payload = {
                'data': [
                    {'text': 'hello world'},
                ],
                'parameters': {'emb_dim': emb_dim},
            }
            resp = requests.post(f'{base_url}/invocations', json=payload)
            assert resp.status_code == 200

            body = resp.json()
            assert len(body['data']) == 1
            assert len(body['data'][0]['embeddings'][0]) == emb_dim



@pytest.mark.parametrize(
"filename",
[
Expand All @@ -74,6 +113,8 @@ def test_provider_sagemaker_pod_batch_transform_valid(filename):
os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
'--provider',
'sagemaker',
"--provider-endpoint",
"encode",
'serve', # This is added by sagemaker
]
)
Expand Down Expand Up @@ -114,6 +155,8 @@ def test_provider_sagemaker_pod_batch_transform_invalid():
os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
'--provider',
'sagemaker',
"--provider-endpoint",
"encode",
'serve', # This is added by sagemaker
]
)
Expand Down Expand Up @@ -145,6 +188,7 @@ def test_provider_sagemaker_deployment_inference():
with Deployment(
uses=os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
provider='sagemaker',
provider_endpoint='encode',
port=dep_port,
):
# Test the `GET /ping` endpoint (added by jina for sagemaker)
Expand All @@ -171,7 +215,7 @@ def test_provider_sagemaker_deployment_inference():
def test_provider_sagemaker_deployment_inference_docker(replica_docker_image_built):
dep_port = random_port()
with Deployment(
uses='docker://sampler-executor', provider='sagemaker', port=dep_port
uses='docker://sampler-executor', provider='sagemaker', provider_endpoint='encode', port=dep_port
):
# Test the `GET /ping` endpoint (added by jina for sagemaker)
rsp = requests.get(f'http://localhost:{dep_port}/ping')
Expand Down Expand Up @@ -200,6 +244,7 @@ def test_provider_sagemaker_deployment_batch():
with Deployment(
uses=os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
provider='sagemaker',
provider_endpoint='encode',
port=dep_port,
):
# Test the `POST /invocations` endpoint for batch-transform
Expand Down Expand Up @@ -230,6 +275,24 @@ def test_provider_sagemaker_deployment_wrong_port():
os.path.dirname(__file__), "SampleExecutor", "config.yml"
),
provider='sagemaker',
provider_endpoint='encode',
port=8080,
):
pass


def test_provider_sagemaker_deployment_wrong_dynamic_batching():
    """Expect startup to fail when dynamic batching targets `/encode_parameter`.

    The deployment configures `uses_dynamic_batching` for `/encode_parameter`,
    the endpoint whose handler takes a typed `parameters` argument, and the
    test asserts that entering the deployment raises `RuntimeFailToStart`.
    """
    # Imported locally, as in the original test, to keep the dependency next
    # to its only use.
    from jina.excepts import RuntimeFailToStart

    # The raised exception is not inspected, so it is not bound to a name.
    with pytest.raises(RuntimeFailToStart):
        with Deployment(
            uses=os.path.join(
                os.path.dirname(__file__), "SampleExecutor", "config.yml"
            ),
            provider='sagemaker',
            provider_endpoint='encode_parameter',
            uses_dynamic_batching={
                '/encode_parameter': {'preferred_batch_size': 20, 'timeout': 50}
            },
        ):
            pass

0 comments on commit e3ea29f

Please sign in to comment.