diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index b33099f61e90e..fa8e08064f616 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -619,7 +619,14 @@ def _validate_sagemaker(self): ): return + remove_keys = set() + for k in self.requests.keys(): + if k != '/invocations': + remove_keys.add(k) + if '/invocations' in self.requests: + for k in remove_keys: + self.requests.pop(k) return if ( @@ -632,12 +639,16 @@ def _validate_sagemaker(self): f'Using "{endpoint_to_use}" as "/invocations" route' ) self.requests['/invocations'] = self.requests[endpoint_to_use] + for k in remove_keys: + self.requests.pop(k) return if len(self.requests) == 1: route = list(self.requests.keys())[0] self.logger.warning(f'Using "{route}" as "/invocations" route') self.requests['/invocations'] = self.requests[route] + for k in remove_keys: + self.requests.pop(k) return raise ValueError('Cannot identify the endpoint to use for "/invocations"') diff --git a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py index c4153ec3480fc..76e3c429da7b9 100644 --- a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py +++ b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py @@ -272,8 +272,21 @@ async def event_generator(): input_doc_model = input_output_map['input'] output_doc_model = input_output_map['output'] is_generator = input_output_map['is_generator'] - parameters_model = input_output_map['parameters'] or Optional[Dict] - default_parameters = ... if input_output_map['parameters'] else None + parameters_model = input_output_map['parameters'] + parameters_model_needed = parameters_model is not None + if parameters_model_needed: + try: + _ = parameters_model() + parameters_model_needed = False + except: + parameters_model_needed = True + parameters_model = parameters_model if parameters_model_needed else Optional[parameters_model] + default_parameters = ( + ... if parameters_model_needed else None + ) + else: + parameters_model = Optional[Dict] + default_parameters = None _config = inherit_config(InnerConfig, BaseDoc.__config__) diff --git a/jina/serve/runtimes/worker/http_fastapi_app.py b/jina/serve/runtimes/worker/http_fastapi_app.py index 47006dd4be329..fd4aeabf8c79c 100644 --- a/jina/serve/runtimes/worker/http_fastapi_app.py +++ b/jina/serve/runtimes/worker/http_fastapi_app.py @@ -162,10 +162,21 @@ async def streaming_get(request: Request = None, body: input_doc_model = None): input_doc_model = input_output_map['input']['model'] output_doc_model = input_output_map['output']['model'] is_generator = input_output_map['is_generator'] - parameters_model = input_output_map['parameters']['model'] or Optional[Dict] - default_parameters = ( - ... if input_output_map['parameters']['model'] else None - ) + parameters_model = input_output_map['parameters']['model'] + parameters_model_needed = parameters_model is not None + if parameters_model_needed: + try: + _ = parameters_model() + parameters_model_needed = False + except: + parameters_model_needed = True + parameters_model = parameters_model if parameters_model_needed else Optional[parameters_model] + default_parameters = ( + ... if parameters_model_needed else None + ) + else: + parameters_model = Optional[Dict] + default_parameters = None if docarray_v2: _config = inherit_config(InnerConfig, BaseDoc.__config__) diff --git a/jina/serve/runtimes/worker/http_sagemaker_app.py b/jina/serve/runtimes/worker/http_sagemaker_app.py index e44082afc57f2..62ef822b19ad4 100644 --- a/jina/serve/runtimes/worker/http_sagemaker_app.py +++ b/jina/serve/runtimes/worker/http_sagemaker_app.py @@ -12,11 +12,11 @@ def get_fastapi_app( - request_models_map: Dict, - caller: Callable, - logger: 'JinaLogger', - cors: bool = False, - **kwargs, + request_models_map: Dict, + caller: Callable, + logger: 'JinaLogger', + cors: bool = False, + **kwargs, ): """ Get the app from FastAPI as the REST interface. @@ -70,11 +70,11 @@ class InnerConfig(BaseConfig): logger.warning('CORS is enabled. This service is accessible from any website!') def add_post_route( - endpoint_path, - input_model, - output_model, - input_doc_list_model=None, - output_doc_list_model=None, + endpoint_path, + input_model, + output_model, + input_doc_list_model=None, + output_doc_list_model=None, ): import json from typing import List, Type, Union @@ -155,13 +155,13 @@ async def post(request: Request): ) def construct_model_from_line( - model: Type[BaseModel], line: List[str] + model: Type[BaseModel], line: List[str] ) -> BaseModel: parsed_fields = {} model_fields = model.__fields__ for field_str, (field_name, field_info) in zip( - line, model_fields.items() + line, model_fields.items() ): field_type = field_info.outer_type_ @@ -204,16 +204,16 @@ def construct_model_from_line( field_names = [f for f in input_doc_list_model.__fields__] data = [] for line in csv.reader( - StringIO(csv_body), - delimiter=',', - quoting=csv.QUOTE_NONE, - escapechar='\\', + StringIO(csv_body), + delimiter=',', + quoting=csv.QUOTE_NONE, + escapechar='\\', ): if len(line) != len(field_names): raise HTTPException( status_code=400, detail=f'Invalid CSV format. Line {line} doesn\'t match ' - f'the expected field order {field_names}.', + f'the expected field order {field_names}.', ) data.append(construct_model_from_line(input_doc_list_model, line)) @@ -223,17 +223,28 @@ def construct_model_from_line( raise HTTPException( status_code=400, detail=f'Invalid content-type: {content_type}. ' - f'Please use either application/json or text/csv.', + f'Please use either application/json or text/csv.', ) for endpoint, input_output_map in request_models_map.items(): if endpoint != '_jina_dry_run_': input_doc_model = input_output_map['input']['model'] output_doc_model = input_output_map['output']['model'] - parameters_model = input_output_map['parameters']['model'] or Optional[Dict] - default_parameters = ( - ... if input_output_map['parameters']['model'] else None - ) + parameters_model = input_output_map['parameters']['model'] + parameters_model_needed = parameters_model is not None + if parameters_model_needed: + try: + _ = parameters_model() + parameters_model_needed = False + except: + parameters_model_needed = True + parameters_model = parameters_model if parameters_model_needed else Optional[parameters_model] + default_parameters = ( + ... if parameters_model_needed else None + ) + else: + parameters_model = Optional[Dict] + default_parameters = None _config = inherit_config(InnerConfig, BaseDoc.__config__) endpoint_input_model = pydantic.create_model( diff --git a/tests/integration/docarray_v2/test_parameters_as_pydantic.py b/tests/integration/docarray_v2/test_parameters_as_pydantic.py index 94045b5a64e83..c0bae50061a19 100644 --- a/tests/integration/docarray_v2/test_parameters_as_pydantic.py +++ b/tests/integration/docarray_v2/test_parameters_as_pydantic.py @@ -23,7 +23,7 @@ class Parameters(BaseModel): class FooParameterExecutor(Executor): @requests(on='/hello') def foo( - self, docs: DocList[TextDoc], parameters: Parameters, **kwargs + self, docs: DocList[TextDoc], parameters: Parameters, **kwargs ) -> DocList[TextDoc]: for doc in docs: doc.text += f'Processed by foo with param: {parameters.param} and num: {parameters.num}' @@ -68,15 +68,15 @@ def bar(self, doc: TextDoc, parameters: Parameters, **kwargs) -> TextDoc: resp = global_requests.post(url, json=myobj) resp_json = resp.json() assert ( - resp_json['data'][0]['text'] - == f'Processed by {processed_by} with param: value and num: 5' + resp_json['data'][0]['text'] + == f'Processed by {processed_by} with param: value and num: 5' ) myobj = {'data': [{'text': ''}], 'parameters': {'param': 'value'}} resp = global_requests.post(url, json=myobj) resp_json = resp.json() assert ( - resp_json['data'][0]['text'] - == f'Processed by {processed_by} with param: value and num: 5' + resp_json['data'][0]['text'] + == f'Processed by {processed_by} with param: value and num: 5' ) @@ -94,7 +94,7 @@ class Parameters(BaseModel): class FooInvalidParameterExecutor(Executor): @requests(on='/hello') def foo( - self, docs: DocList[TextDoc], parameters: Parameters, **kwargs + self, docs: DocList[TextDoc], parameters: Parameters, **kwargs ) -> DocList[TextDoc]: for doc in docs: doc.text += f'Processed by foo with param: {parameters.param} and num: {parameters.num}' @@ -131,7 +131,7 @@ class ParametersFirst(BaseModel): class Exec1Chain(Executor): @requests(on='/bar') def bar( - self, docs: DocList[Input1], parameters: ParametersFirst, **kwargs + self, docs: DocList[Input1], parameters: ParametersFirst, **kwargs ) -> DocList[Output1]: docs_return = DocList[Output1]( [Output1(price=5 * parameters.mult) for _ in range(len(docs))] @@ -180,7 +180,7 @@ def bar(self, docs: DocList[Input1], **kwargs) -> DocList[Output1]: class Exec2Chain(Executor): @requests(on='/bar') def bar( - self, docs: DocList[Output1], parameters: ParametersSecond, **kwargs + self, docs: DocList[Output1], parameters: ParametersSecond, **kwargs ) -> DocList[Output2]: docs_return = DocList[Output2]( [ @@ -231,7 +231,7 @@ class MyConfigParam(BaseModel): class MyExecDocWithExample(Executor): @requests def foo( - self, docs: DocList[MyDocWithExample], parameters: MyConfigParam, **kwargs + self, docs: DocList[MyDocWithExample], parameters: MyConfigParam, **kwargs ) -> DocList[MyDocWithExample]: pass @@ -261,3 +261,61 @@ def foo( assert 'Configuration for Executor endpoint' in resp_str assert 'batch size' in resp_str assert '256' in resp_str + + +@pytest.mark.parametrize('ctxt_manager', ['deployment', 'flow']) +def test_parameters_all_default_not_required(ctxt_manager): + class DefaultParameters(BaseModel): + param: str = 'default' + num: int = 5 + + class DefaultParamExecutor(Executor): + @requests(on='/hello') + def foo( + self, docs: DocList[TextDoc], parameters: DefaultParameters, **kwargs + ) -> DocList[TextDoc]: + for doc in docs: + doc.text += f'Processed by foo with param: {parameters.param} and num: {parameters.num}' + + @requests(on='/hello_single') + def bar(self, doc: TextDoc, parameters: DefaultParameters, **kwargs) -> TextDoc: + doc.text = f'Processed by bar with param: {parameters.param} and num: {parameters.num}' + + if ctxt_manager == 'flow': + ctxt_mgr = Flow(protocol='http').add(uses=DefaultParamExecutor) + else: + ctxt_mgr = Deployment(protocol='http', uses=DefaultParamExecutor) + + with ctxt_mgr: + ret = ctxt_mgr.post( + on='/hello', + inputs=DocList[TextDoc]([TextDoc(text='')]), + ) + assert len(ret) == 1 + assert ret[0].text == 'Processed by foo with param: default and num: 5' + + ret = ctxt_mgr.post( + on='/hello_single', + inputs=DocList[TextDoc]([TextDoc(text='')]), + ) + assert len(ret) == 1 + assert ret[0].text == 'Processed by bar with param: default and num: 5' + import requests as global_requests + + for endpoint in {'hello', 'hello_single'}: + processed_by = 'foo' if endpoint == 'hello' else 'bar' + url = f'http://localhost:{ctxt_mgr.port}/{endpoint}' + myobj = {'data': {'text': ''}} + resp = global_requests.post(url, json=myobj) + resp_json = resp.json() + assert ( + resp_json['data'][0]['text'] + == f'Processed by {processed_by} with param: default and num: 5' + ) + myobj = {'data': [{'text': ''}]} + resp = global_requests.post(url, json=myobj) + resp_json = resp.json() + assert ( + resp_json['data'][0]['text'] + == f'Processed by {processed_by} with param: default and num: 5' + )