diff --git a/jina/serve/runtimes/helper.py b/jina/serve/runtimes/helper.py index 70bb75a485c1b..0319f7c8a18d8 100644 --- a/jina/serve/runtimes/helper.py +++ b/jina/serve/runtimes/helper.py @@ -99,7 +99,9 @@ def _parse_specific_params(parameters: Dict, executor_name: str): 'default', ] - def _create_aux_model_doc_list_to_list(model): + + def _create_aux_model_doc_list_to_list(model, cached_models=None): + cached_models = cached_models or {} fields: Dict[str, Any] = {} for field_name, field in model.__annotations__.items(): if field_name not in model.__fields__: @@ -108,18 +110,25 @@ def _create_aux_model_doc_list_to_list(model): try: if issubclass(field, DocList): t: Any = field.doc_type - t_aux = _create_aux_model_doc_list_to_list(t) - fields[field_name] = (List[t_aux], field_info) + if t.__name__ in cached_models: + fields[field_name] = (List[cached_models[t.__name__]], field_info) + else: + t_aux = _create_aux_model_doc_list_to_list(t, cached_models) + cached_models[t.__name__] = t_aux + fields[field_name] = (List[t_aux], field_info) else: fields[field_name] = (field, field_info) except TypeError: fields[field_name] = (field, field_info) - return create_model( + new_model = create_model( model.__name__, __base__=model, __validators__=model.__validators__, - **fields, - ) + **fields) + cached_models[model.__name__] = new_model + + return new_model + def _get_field_from_type( field_schema, @@ -264,6 +273,7 @@ def _get_field_from_type( ) return ret + def _create_pydantic_model_from_schema( schema: Dict[str, any], model_name: str, diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 0849aaebb388d..6756eb0356ed8 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -178,6 +178,7 @@ def call_handle(request): ] return self.process_single_data(request, None, is_generator=is_generator) + app = get_fastapi_app( request_models_map=request_models_map, caller=call_handle, **kwargs ) @@ -1001,6 +1002,7 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: endpoints_proto.write_endpoints.extend(list(self._executor.write_endpoints)) schemas = self._executor._get_endpoint_models_dict() if docarray_v2: + cached_aux_models = {} from docarray.documents.legacy import LegacyDocument from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list @@ -1011,16 +1013,14 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: inner_dict['input']['model'] = legacy_doc_schema else: inner_dict['input']['model'] = _create_aux_model_doc_list_to_list( - inner_dict['input']['model'] + inner_dict['input']['model'], cached_aux_models ).schema() - if inner_dict['output']['model'].schema() == legacy_doc_schema: inner_dict['output']['model'] = legacy_doc_schema else: inner_dict['output']['model'] = _create_aux_model_doc_list_to_list( - inner_dict['output']['model'] + inner_dict['output']['model'], cached_aux_models ).schema() - if inner_dict['parameters']['model'] is not None: inner_dict['parameters']['model'] = inner_dict['parameters'][ 'model' @@ -1031,6 +1031,7 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: inner_dict['output']['model'] = inner_dict['output']['model'].schema() inner_dict['parameters'] = {} json_format.ParseDict(schemas, endpoints_proto.schemas) + self.logger.debug('return an endpoint discovery request') return endpoints_proto def _extract_tracing_context( diff --git a/tests/unit/serve/runtimes/test_helper.py b/tests/unit/serve/runtimes/test_helper.py index 43abba91f2157..b16ac39dd18bf 100644 --- a/tests/unit/serve/runtimes/test_helper.py +++ b/tests/unit/serve/runtimes/test_helper.py @@ -382,3 +382,28 @@ class SearchResult(BaseDoc): reconstructed_in_gateway_from_Search_results = QuoteFile_reconstructed_in_gateway_from_Search_results( texts=textlist) assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey' + + +@pytest.mark.skipif(not docarray_v2, reason='Test only working with docarray v2') +def test_create_aux_model_with_multiple_doclists_of_same_type(): + from docarray import DocList, BaseDoc + from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list + + class MyTextDoc(BaseDoc): + text: str + + class QuoteFile(BaseDoc): + texts: DocList[MyTextDoc] + + class QuoteFileType(BaseDoc): + """ + QuoteFileType class. + """ + id: str = None # same as name, compatibility reasons for a generic, shared `id` field + name: str = None + total_count: int = None + docs: DocList[QuoteFile] = None + chunks: DocList[QuoteFile] = None + + new_model = _create_aux_model_doc_list_to_list(QuoteFileType) + new_model.schema()