refactor: keep only sagemaker endpoint (#6152)
JoanFM authored Mar 20, 2024
1 parent ade9084 commit b5793f0
Showing 5 changed files with 141 additions and 37 deletions.
11 changes: 11 additions & 0 deletions jina/serve/executors/__init__.py
@@ -619,7 +619,14 @@ def _validate_sagemaker(self):
):
return

remove_keys = set()
for k in self.requests.keys():
if k != '/invocations':
remove_keys.add(k)

if '/invocations' in self.requests:
for k in remove_keys:
self.requests.pop(k)
return

if (
@@ -632,12 +639,16 @@ def _validate_sagemaker(self):
f'Using "{endpoint_to_use}" as "/invocations" route'
)
self.requests['/invocations'] = self.requests[endpoint_to_use]
for k in remove_keys:
self.requests.pop(k)
return

if len(self.requests) == 1:
route = list(self.requests.keys())[0]
self.logger.warning(f'Using "{route}" as "/invocations" route')
self.requests['/invocations'] = self.requests[route]
for k in remove_keys:
self.requests.pop(k)
return

raise ValueError('Cannot identify the endpoint to use for "/invocations"')
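The hunk above is the core of the refactor: `_validate_sagemaker` now collects every route other than '/invocations' into `remove_keys` and drops them, so a SageMaker-targeted Executor ends up exposing a single endpoint. When '/invocations' is not registered yet, the existing fallback (the default endpoint, or a lone registered route) is aliased to it first and the leftover route is removed as well. A minimal sketch of that pruning, written as a standalone function for illustration only (the single-route fallback is shown; the default-endpoint branch from the diff is omitted):

# Sketch only: mirrors the route pruning added to _validate_sagemaker above.
def keep_only_invocations(requests: dict) -> dict:
    extra = {k for k in requests if k != '/invocations'}
    if '/invocations' not in requests:
        if len(requests) != 1:
            raise ValueError('Cannot identify the endpoint to use for "/invocations"')
        # alias the only registered route to '/invocations'
        only_route = next(iter(requests))
        requests['/invocations'] = requests[only_route]
    for k in extra:
        requests.pop(k)
    return requests

keep_only_invocations({'/hello': print})                       # -> {'/invocations': print}
keep_only_invocations({'/invocations': print, '/foo': print})  # -> {'/invocations': print}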
17 changes: 15 additions & 2 deletions jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py
@@ -272,8 +272,21 @@ async def event_generator():
input_doc_model = input_output_map['input']
output_doc_model = input_output_map['output']
is_generator = input_output_map['is_generator']
parameters_model = input_output_map['parameters'] or Optional[Dict]
default_parameters = ... if input_output_map['parameters'] else None
parameters_model = input_output_map['parameters']
parameters_model_needed = parameters_model is not None
if parameters_model_needed:
try:
_ = parameters_model()
parameters_model_needed = False
except:
parameters_model_needed = True
parameters_model = parameters_model if parameters_model_needed else Optional[parameters_model]
default_parameters = (
... if parameters_model_needed else None
)
else:
parameters_model = Optional[Dict]
default_parameters = None

_config = inherit_config(InnerConfig, BaseDoc.__config__)

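The replaced lines treated any user-supplied parameters model as required. The new block probes whether the Pydantic model can be instantiated with no arguments: if it can (every field has a default), the field type becomes Optional[model] with a default of None, so clients may omit 'parameters' entirely; only models with required fields keep the required (`...`) default. A self-contained sketch of that probe, with illustrative model names:

from typing import Dict, Optional
from pydantic import BaseModel

class AllDefaults(BaseModel):          # illustrative
    param: str = 'default'
    num: int = 5

class HasRequired(BaseModel):          # illustrative
    param: str                         # no default -> empty construction fails

def resolve_parameters(parameters_model):
    # Mirrors the probe above: no model -> Optional[Dict] / None,
    # all-default model -> Optional[model] / None, otherwise the model stays required.
    if parameters_model is None:
        return Optional[Dict], None
    try:
        parameters_model()             # succeeds only if every field has a default
        return Optional[parameters_model], None
    except Exception:
        return parameters_model, ...   # Ellipsis keeps the field required

assert resolve_parameters(AllDefaults)[1] is None
assert resolve_parameters(HasRequired)[1] is ...
assert resolve_parameters(None)[0] == Optional[Dict]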
19 changes: 15 additions & 4 deletions jina/serve/runtimes/worker/http_fastapi_app.py
@@ -162,10 +162,21 @@ async def streaming_get(request: Request = None, body: input_doc_model = None):
input_doc_model = input_output_map['input']['model']
output_doc_model = input_output_map['output']['model']
is_generator = input_output_map['is_generator']
parameters_model = input_output_map['parameters']['model'] or Optional[Dict]
default_parameters = (
... if input_output_map['parameters']['model'] else None
)
parameters_model = input_output_map['parameters']['model']
parameters_model_needed = parameters_model is not None
if parameters_model_needed:
try:
_ = parameters_model()
parameters_model_needed = False
except:
parameters_model_needed = True
parameters_model = parameters_model if parameters_model_needed else Optional[parameters_model]
default_parameters = (
... if parameters_model_needed else None
)
else:
parameters_model = Optional[Dict]
default_parameters = None

if docarray_v2:
_config = inherit_config(InnerConfig, BaseDoc.__config__)
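The worker's FastAPI app applies the same probe, so the generated request model only requires 'parameters' when the Executor declares a parameters model with required fields. Under that change, a request like the following validates without any 'parameters' key and the handler receives the declared defaults (a hedged example: the host, port, and '/hello' route come from the test added below, not from this file):

import requests  # illustrative client call against a Deployment(protocol='http')

resp = requests.post(
    'http://localhost:12345/hello',     # port is an assumption
    json={'data': [{'text': ''}]},      # 'parameters' omitted on purpose
)
print(resp.status_code, resp.json())    # defaults from the parameters model apply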
55 changes: 33 additions & 22 deletions jina/serve/runtimes/worker/http_sagemaker_app.py
@@ -12,11 +12,11 @@


def get_fastapi_app(
request_models_map: Dict,
caller: Callable,
logger: 'JinaLogger',
cors: bool = False,
**kwargs,
request_models_map: Dict,
caller: Callable,
logger: 'JinaLogger',
cors: bool = False,
**kwargs,
):
"""
Get the app from FastAPI as the REST interface.
Expand Down Expand Up @@ -70,11 +70,11 @@ class InnerConfig(BaseConfig):
logger.warning('CORS is enabled. This service is accessible from any website!')

def add_post_route(
endpoint_path,
input_model,
output_model,
input_doc_list_model=None,
output_doc_list_model=None,
endpoint_path,
input_model,
output_model,
input_doc_list_model=None,
output_doc_list_model=None,
):
import json
from typing import List, Type, Union
@@ -155,13 +155,13 @@ async def post(request: Request):
)

def construct_model_from_line(
model: Type[BaseModel], line: List[str]
model: Type[BaseModel], line: List[str]
) -> BaseModel:
parsed_fields = {}
model_fields = model.__fields__

for field_str, (field_name, field_info) in zip(
line, model_fields.items()
line, model_fields.items()
):
field_type = field_info.outer_type_

@@ -204,16 +204,16 @@ def construct_model_from_line(
field_names = [f for f in input_doc_list_model.__fields__]
data = []
for line in csv.reader(
StringIO(csv_body),
delimiter=',',
quoting=csv.QUOTE_NONE,
escapechar='\\',
StringIO(csv_body),
delimiter=',',
quoting=csv.QUOTE_NONE,
escapechar='\\',
):
if len(line) != len(field_names):
raise HTTPException(
status_code=400,
detail=f'Invalid CSV format. Line {line} doesn\'t match '
f'the expected field order {field_names}.',
f'the expected field order {field_names}.',
)
data.append(construct_model_from_line(input_doc_list_model, line))

@@ -223,17 +223,28 @@ def construct_model_from_line(
raise HTTPException(
status_code=400,
detail=f'Invalid content-type: {content_type}. '
f'Please use either application/json or text/csv.',
f'Please use either application/json or text/csv.',
)

for endpoint, input_output_map in request_models_map.items():
if endpoint != '_jina_dry_run_':
input_doc_model = input_output_map['input']['model']
output_doc_model = input_output_map['output']['model']
parameters_model = input_output_map['parameters']['model'] or Optional[Dict]
default_parameters = (
... if input_output_map['parameters']['model'] else None
)
parameters_model = input_output_map['parameters']['model']
parameters_model_needed = parameters_model is not None
if parameters_model_needed:
try:
_ = parameters_model()
parameters_model_needed = False
except:
parameters_model_needed = True
parameters_model = parameters_model if parameters_model_needed else Optional[parameters_model]
default_parameters = (
... if parameters_model_needed else None
)
else:
parameters_model = Optional[Dict]
default_parameters = None

_config = inherit_config(InnerConfig, BaseDoc.__config__)
endpoint_input_model = pydantic.create_model(
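Most of this file's diff is re-indentation plus the same parameters probe, but it also shows the CSV branch of the '/invocations' route: each CSV row is zipped against the input model's field order and validated through Pydantic. A simplified sketch of that row-to-model step (the real construct_model_from_line above also handles nested and union field types via field_info.outer_type_; this version only covers flat fields):

import csv
from io import StringIO
from typing import List, Type
from pydantic import BaseModel

class TextDoc(BaseModel):      # stand-in for the DocArray TextDoc
    text: str

def rows_to_models(csv_body: str, model: Type[BaseModel]) -> List[BaseModel]:
    field_names = list(model.__fields__)
    out = []
    reader = csv.reader(StringIO(csv_body), delimiter=',',
                        quoting=csv.QUOTE_NONE, escapechar='\\')
    for line in reader:
        if len(line) != len(field_names):
            raise ValueError(f'row {line} does not match fields {field_names}')
        out.append(model(**dict(zip(field_names, line))))
    return out

print(rows_to_models('hello\nworld', TextDoc))   # [TextDoc(text='hello'), TextDoc(text='world')]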
76 changes: 67 additions & 9 deletions tests/integration/docarray_v2/test_parameters_as_pydantic.py
@@ -23,7 +23,7 @@ class Parameters(BaseModel):
class FooParameterExecutor(Executor):
@requests(on='/hello')
def foo(
self, docs: DocList[TextDoc], parameters: Parameters, **kwargs
self, docs: DocList[TextDoc], parameters: Parameters, **kwargs
) -> DocList[TextDoc]:
for doc in docs:
doc.text += f'Processed by foo with param: {parameters.param} and num: {parameters.num}'
@@ -68,15 +68,15 @@ def bar(self, doc: TextDoc, parameters: Parameters, **kwargs) -> TextDoc:
resp = global_requests.post(url, json=myobj)
resp_json = resp.json()
assert (
resp_json['data'][0]['text']
== f'Processed by {processed_by} with param: value and num: 5'
resp_json['data'][0]['text']
== f'Processed by {processed_by} with param: value and num: 5'
)
myobj = {'data': [{'text': ''}], 'parameters': {'param': 'value'}}
resp = global_requests.post(url, json=myobj)
resp_json = resp.json()
assert (
resp_json['data'][0]['text']
== f'Processed by {processed_by} with param: value and num: 5'
resp_json['data'][0]['text']
== f'Processed by {processed_by} with param: value and num: 5'
)


@@ -94,7 +94,7 @@ class Parameters(BaseModel):
class FooInvalidParameterExecutor(Executor):
@requests(on='/hello')
def foo(
self, docs: DocList[TextDoc], parameters: Parameters, **kwargs
self, docs: DocList[TextDoc], parameters: Parameters, **kwargs
) -> DocList[TextDoc]:
for doc in docs:
doc.text += f'Processed by foo with param: {parameters.param} and num: {parameters.num}'
@@ -131,7 +131,7 @@ class ParametersFirst(BaseModel):
class Exec1Chain(Executor):
@requests(on='/bar')
def bar(
self, docs: DocList[Input1], parameters: ParametersFirst, **kwargs
self, docs: DocList[Input1], parameters: ParametersFirst, **kwargs
) -> DocList[Output1]:
docs_return = DocList[Output1](
[Output1(price=5 * parameters.mult) for _ in range(len(docs))]
@@ -180,7 +180,7 @@ def bar(self, docs: DocList[Input1], **kwargs) -> DocList[Output1]:
class Exec2Chain(Executor):
@requests(on='/bar')
def bar(
self, docs: DocList[Output1], parameters: ParametersSecond, **kwargs
self, docs: DocList[Output1], parameters: ParametersSecond, **kwargs
) -> DocList[Output2]:
docs_return = DocList[Output2](
[
@@ -231,7 +231,7 @@ class MyConfigParam(BaseModel):
class MyExecDocWithExample(Executor):
@requests
def foo(
self, docs: DocList[MyDocWithExample], parameters: MyConfigParam, **kwargs
self, docs: DocList[MyDocWithExample], parameters: MyConfigParam, **kwargs
) -> DocList[MyDocWithExample]:
pass

@@ -261,3 +261,61 @@ def foo(
assert 'Configuration for Executor endpoint' in resp_str
assert 'batch size' in resp_str
assert '256' in resp_str


@pytest.mark.parametrize('ctxt_manager', ['deployment', 'flow'])
def test_parameters_all_default_not_required(ctxt_manager):
class DefaultParameters(BaseModel):
param: str = 'default'
num: int = 5

class DefaultParamExecutor(Executor):
@requests(on='/hello')
def foo(
self, docs: DocList[TextDoc], parameters: DefaultParameters, **kwargs
) -> DocList[TextDoc]:
for doc in docs:
doc.text += f'Processed by foo with param: {parameters.param} and num: {parameters.num}'

@requests(on='/hello_single')
def bar(self, doc: TextDoc, parameters: DefaultParameters, **kwargs) -> TextDoc:
doc.text = f'Processed by bar with param: {parameters.param} and num: {parameters.num}'

if ctxt_manager == 'flow':
ctxt_mgr = Flow(protocol='http').add(uses=DefaultParamExecutor)
else:
ctxt_mgr = Deployment(protocol='http', uses=DefaultParamExecutor)

with ctxt_mgr:
ret = ctxt_mgr.post(
on='/hello',
inputs=DocList[TextDoc]([TextDoc(text='')]),
)
assert len(ret) == 1
assert ret[0].text == 'Processed by foo with param: default and num: 5'

ret = ctxt_mgr.post(
on='/hello_single',
inputs=DocList[TextDoc]([TextDoc(text='')]),
)
assert len(ret) == 1
assert ret[0].text == 'Processed by bar with param: default and num: 5'
import requests as global_requests

for endpoint in {'hello', 'hello_single'}:
processed_by = 'foo' if endpoint == 'hello' else 'bar'
url = f'http://localhost:{ctxt_mgr.port}/{endpoint}'
myobj = {'data': {'text': ''}}
resp = global_requests.post(url, json=myobj)
resp_json = resp.json()
assert (
resp_json['data'][0]['text']
== f'Processed by {processed_by} with param: default and num: 5'
)
myobj = {'data': [{'text': ''}]}
resp = global_requests.post(url, json=myobj)
resp_json = resp.json()
assert (
resp_json['data'][0]['text']
== f'Processed by {processed_by} with param: default and num: 5'
)
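
The new test exercises the all-default case end to end: it calls '/hello' and '/hello_single' through the client and through raw HTTP, never sends 'parameters', and expects the defaults 'default'/5 to appear in the processed text. Combined with the executor change above, a SageMaker-style deployment of such an Executor would answer only on '/invocations' and accept the same defaulted payload. A hedged illustration (host, port, and the exact body shape are assumptions, mirroring the JSON requests in this test):

import requests

resp = requests.post(
    'http://localhost:8080/invocations',      # assumed host/port
    json={'data': [{'text': 'hello'}]},       # no 'parameters': defaults apply
    headers={'Content-Type': 'application/json'},
)
print(resp.json())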
