Skip to content

Commit

Permalink
fix: add and use post endpoint for streaming (#6093)
Browse files Browse the repository at this point in the history
  • Loading branch information
NarekA authored Oct 25, 2023
1 parent ee95f6e commit f2085a9
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 22 deletions.
5 changes: 5 additions & 0 deletions docs/concepts/serving/executor/add-endpoints.md
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,11 @@ Streaming endpoints receive one Document as input and yields one Document at a t
:class: note

Streaming endpoints are only supported for HTTP and gRPC protocols and for Deployment and Flow with one single Executor.

For HTTP deployment streaming executors generate both a GET and POST endpoint.
The GET endpoint support documents with string, integer, or float fields only,
whereas, POST requests support all docarrays.
The Jina client uses the POST endpoints.
```

A streaming endpoint has the following signature:
Expand Down
10 changes: 2 additions & 8 deletions jina/clients/base/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,13 @@ async def send_streaming_message(self, doc: 'Document', on: str):
:param on: Request endpoint
:yields: responses
"""
if docarray_v2:
req_dict = doc.dict()
else:
req_dict = doc.to_dict()

request_kwargs = {
'url': self.url,
'headers': {'Accept': 'text/event-stream'},
'json': doc.dict() if docarray_v2 else doc.to_dict(),
}
req_dict = {key: value for key, value in req_dict.items() if value is not None}
request_kwargs['params'] = req_dict

async with self.session.get(**request_kwargs) as response:
async with self.session.post(**request_kwargs) as response:
async for chunk in response.content.iter_any():
events = chunk.split(b'event: ')[1:]
for event in events:
Expand Down
27 changes: 25 additions & 2 deletions jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ async def post(body: input_model, response: Response):
)
return result

def add_streaming_get_route(
def add_streaming_routes(
endpoint_path,
input_doc_model=None,
):
Expand All @@ -258,6 +258,29 @@ async def streaming_get(request: Request):
async def event_generator():
async for doc, error in streamer.stream_doc(
doc=input_doc_model(**query_params), exec_endpoint=endpoint_path
):
if error:
raise HTTPException(status_code=499, detail=str(error))
yield {
'event': 'update',
'data': doc.dict()
}
yield {
'event': 'end'
}

return EventSourceResponse(event_generator())

@app.api_route(
path=f'/{endpoint_path.strip("/")}',
methods=['POST'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_post(body: dict):

async def event_generator():
async for doc, error in streamer.stream_doc(
doc=input_doc_model.parse_obj(body), exec_endpoint=endpoint_path
):
if error:
raise HTTPException(status_code=499, detail=str(error))
Expand Down Expand Up @@ -293,7 +316,7 @@ async def event_generator():
)

if is_generator:
add_streaming_get_route(
add_streaming_routes(
endpoint,
input_doc_model=input_doc_model,
)
Expand Down
26 changes: 21 additions & 5 deletions jina/serve/runtimes/worker/http_fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ async def post(body: input_model, response: Response):
ret = output_model(data=docs_response, parameters=resp.parameters)
return ret

def add_streaming_get_route(
def add_streaming_routes(
endpoint_path,
input_doc_model=None,
):
Expand All @@ -138,12 +138,28 @@ async def streaming_get(request: Request):
req = DataRequest()
req.header.exec_endpoint = endpoint_path
if not docarray_v2:
from docarray import Document

req.data.docs = DocumentArray([Document.from_dict(query_params)])
else:
req.document_array_cls = DocList[input_doc_model]
req.data.docs = DocList[input_doc_model]([input_doc_model(**query_params)])
req.data.docs = DocList[input_doc_model](
[input_doc_model(**query_params)]
)
event_generator = _gen_dict_documents(await caller(req))
return EventSourceResponse(event_generator)

@app.api_route(
path=f'/{endpoint_path.strip("/")}',
methods=['POST'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_post(body: input_doc_model, request: Request):
req = DataRequest()
req.header.exec_endpoint = endpoint_path
if not docarray_v2:
req.data.docs = DocumentArray([body])
else:
req.document_array_cls = DocList[input_doc_model]
req.data.docs = DocList[input_doc_model]([body])
event_generator = _gen_dict_documents(await caller(req))
return EventSourceResponse(event_generator)

Expand Down Expand Up @@ -176,7 +192,7 @@ async def streaming_get(request: Request):
)

if is_generator:
add_streaming_get_route(
add_streaming_routes(
endpoint,
input_doc_model=input_doc_model,
)
Expand Down
53 changes: 51 additions & 2 deletions tests/integration/docarray_v2/test_issues.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List, Optional
from typing import List, Optional, Dict

import pytest
from docarray import BaseDoc, DocList
from pydantic import Field

from jina import Executor, Flow, requests
from jina import Executor, Flow, requests, Deployment, Client


class Nested2Doc(BaseDoc):
Expand Down Expand Up @@ -92,3 +94,50 @@ def foo(self, docs: DocList[A], **kwargs) -> DocList[A]:
f = Flow().add(uses=MyIssue6084Exec).add(uses=MyIssue6084Exec)
with f:
pass


@pytest.mark.asyncio
async def test_issue_6090():
"""Tests if streaming works with pydantic models with complex fields which are not
str, int, or float.
"""

class NestedFieldSchema(BaseDoc):
name: str = "test_name"
dict_field: Dict = Field(default_factory=dict)

class InputWithComplexFields(BaseDoc):
text: str = "test"
nested_field: NestedFieldSchema = Field(default_factory=NestedFieldSchema)
dict_field: Dict = Field(default_factory=dict)
bool_field: bool = False

class MyExecutor(Executor):
@requests(on="/stream")
async def stream(
self, doc: InputWithComplexFields, parameters: Optional[Dict] = None, **kwargs
) -> InputWithComplexFields:
for i in range(4):
yield InputWithComplexFields(text=f"hello world {doc.text} {i}")

docs = []
protocol = "http"
with Deployment(uses=MyExecutor, protocol=protocol) as dep:
client = Client(port=dep.port, protocol=protocol, asyncio=True)
example_doc = InputWithComplexFields(text="my input text")
async for doc in client.stream_doc(
on="/stream",
inputs=example_doc,
input_type=InputWithComplexFields,
return_type=InputWithComplexFields,
):
docs.append(doc)

assert [d.text for d in docs] == [
"hello world my input text 0",
"hello world my input text 1",
"hello world my input text 2",
"hello world my input text 3",
]
assert docs[0].nested_field.name == "test_name"

12 changes: 7 additions & 5 deletions tests/integration/docarray_v2/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,19 @@ async def test_streaming_delay(protocol, include_gateway):
):
client = Client(port=port, protocol=protocol, asyncio=True)
i = 0
start_time = time.time()
async for doc in client.stream_doc(
stream = client.stream_doc(
on='/hello',
inputs=MyDocument(text='hello world', number=i),
return_type=MyDocument,
):
)
start_time = None
async for doc in stream:
start_time = start_time or time.time()
assert doc.text == f'hello world {i}'
i += 1

delay = time.time() - start_time
# 0.5 seconds between each request + 0.5 seconds tolerance interval
assert time.time() - start_time < (0.5 * i) + 0.5
assert delay < (0.5 * i), f'Expected delay to be less than {0.5 * i}, got {delay} on iteration {i}'


@pytest.mark.asyncio
Expand Down

0 comments on commit f2085a9

Please sign in to comment.