Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix issue 6140 #6141

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions jina/serve/runtimes/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@
'default',
]

def _create_aux_model_doc_list_to_list(model):

def _create_aux_model_doc_list_to_list(model, cached_models=None):
cached_models = cached_models or {}

Check warning on line 104 in jina/serve/runtimes/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/helper.py#L103-L104

Added lines #L103 - L104 were not covered by tests
fields: Dict[str, Any] = {}
for field_name, field in model.__annotations__.items():
if field_name not in model.__fields__:
Expand All @@ -108,18 +110,25 @@
try:
if issubclass(field, DocList):
t: Any = field.doc_type
t_aux = _create_aux_model_doc_list_to_list(t)
fields[field_name] = (List[t_aux], field_info)
if t.__name__ in cached_models:
fields[field_name] = (List[cached_models[t.__name__]], field_info)

Check warning on line 114 in jina/serve/runtimes/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/helper.py#L113-L114

Added lines #L113 - L114 were not covered by tests
else:
t_aux = _create_aux_model_doc_list_to_list(t, cached_models)
cached_models[t.__name__] = t_aux
fields[field_name] = (List[t_aux], field_info)

Check warning on line 118 in jina/serve/runtimes/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/helper.py#L116-L118

Added lines #L116 - L118 were not covered by tests
else:
fields[field_name] = (field, field_info)
except TypeError:
fields[field_name] = (field, field_info)
return create_model(
new_model = create_model(

Check warning on line 123 in jina/serve/runtimes/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/helper.py#L123

Added line #L123 was not covered by tests
model.__name__,
__base__=model,
__validators__=model.__validators__,
**fields,
)
**fields)
cached_models[model.__name__] = new_model

Check warning on line 128 in jina/serve/runtimes/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/helper.py#L128

Added line #L128 was not covered by tests

return new_model

Check warning on line 130 in jina/serve/runtimes/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/helper.py#L130

Added line #L130 was not covered by tests


def _get_field_from_type(
field_schema,
Expand Down Expand Up @@ -264,6 +273,7 @@
)
return ret


def _create_pydantic_model_from_schema(
schema: Dict[str, any],
model_name: str,
Expand Down
9 changes: 5 additions & 4 deletions jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
]

return self.process_single_data(request, None, is_generator=is_generator)

app = get_fastapi_app(
request_models_map=request_models_map, caller=call_handle, **kwargs
)
Expand Down Expand Up @@ -1001,6 +1002,7 @@
endpoints_proto.write_endpoints.extend(list(self._executor.write_endpoints))
schemas = self._executor._get_endpoint_models_dict()
if docarray_v2:
cached_aux_models = {}

Check warning on line 1005 in jina/serve/runtimes/worker/request_handling.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/request_handling.py#L1005

Added line #L1005 was not covered by tests
from docarray.documents.legacy import LegacyDocument

from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list
Expand All @@ -1011,16 +1013,14 @@
inner_dict['input']['model'] = legacy_doc_schema
else:
inner_dict['input']['model'] = _create_aux_model_doc_list_to_list(
inner_dict['input']['model']
inner_dict['input']['model'], cached_aux_models
).schema()

if inner_dict['output']['model'].schema() == legacy_doc_schema:
inner_dict['output']['model'] = legacy_doc_schema
else:
inner_dict['output']['model'] = _create_aux_model_doc_list_to_list(
inner_dict['output']['model']
inner_dict['output']['model'], cached_aux_models
).schema()

if inner_dict['parameters']['model'] is not None:
inner_dict['parameters']['model'] = inner_dict['parameters'][
'model'
Expand All @@ -1031,6 +1031,7 @@
inner_dict['output']['model'] = inner_dict['output']['model'].schema()
inner_dict['parameters'] = {}
json_format.ParseDict(schemas, endpoints_proto.schemas)
self.logger.debug('return an endpoint discovery request')
return endpoints_proto

def _extract_tracing_context(
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/serve/runtimes/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,28 @@ class SearchResult(BaseDoc):
reconstructed_in_gateway_from_Search_results = QuoteFile_reconstructed_in_gateway_from_Search_results(
texts=textlist)
assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey'


@pytest.mark.skipif(not docarray_v2, reason='Test only working with docarray v2')
def test_create_aux_model_with_multiple_doclists_of_same_type():
from docarray import DocList, BaseDoc
from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list

class MyTextDoc(BaseDoc):
text: str

class QuoteFile(BaseDoc):
texts: DocList[MyTextDoc]

class QuoteFileType(BaseDoc):
"""
QuoteFileType class.
"""
id: str = None # same as name, compatibility reasons for a generic, shared `id` field
name: str = None
total_count: int = None
docs: DocList[QuoteFile] = None
chunks: DocList[QuoteFile] = None

new_model = _create_aux_model_doc_list_to_list(QuoteFileType)
new_model.schema()
Loading