Skip to content

Commit

Permalink
fix: use body for streaming instead of params (#6098)
Browse files Browse the repository at this point in the history
  • Loading branch information
NarekA authored Oct 27, 2023
1 parent 2cee961 commit ab2cc19
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 115 deletions.
10 changes: 6 additions & 4 deletions docs/concepts/serving/executor/add-endpoints.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,12 @@ Streaming endpoints receive one Document as input and yields one Document at a t

Streaming endpoints are only supported for HTTP and gRPC protocols and for Deployment and Flow with one single Executor.

For HTTP deployment streaming executors generate both a GET and POST endpoint.
The GET endpoint support documents with string, integer, or float fields only,
whereas, POST requests support all docarrays.
The Jina client uses the POST endpoints.
For HTTP deployment streaming executors generate a GET endpoint.
The GET endpoint support passing documet fields in
the request body or as URL query parameters,
however, query parameters only support string, integer, or float fields,
whereas, the request body support all serializable docarrays.
The Jina client uses the request body.
```

A streaming endpoint has the following signature:
Expand Down
5 changes: 3 additions & 2 deletions jina/clients/base/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,14 @@ async def send_streaming_message(self, doc: 'Document', on: str):
:param on: Request endpoint
:yields: responses
"""
req_dict = doc.to_dict() if hasattr(doc, "to_dict") else doc.dict()
request_kwargs = {
'url': self.url,
'headers': {'Accept': 'text/event-stream'},
'json': doc.dict() if docarray_v2 else doc.to_dict(),
'json': req_dict,
}

async with self.session.post(**request_kwargs) as response:
async with self.session.get(**request_kwargs) as response:
async for chunk in response.content.iter_any():
events = chunk.split(b'event: ')[1:]
for event in events:
Expand Down
30 changes: 4 additions & 26 deletions jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,35 +252,13 @@ def add_streaming_routes(
methods=['GET'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_get(request: Request):
query_params = dict(request.query_params)
async def streaming_get(request: Request, body: input_doc_model = None):
body = body or dict(request.query_params)
body = input_doc_model.parse_obj(body)

async def event_generator():
async for doc, error in streamer.stream_doc(
doc=input_doc_model(**query_params), exec_endpoint=endpoint_path
):
if error:
raise HTTPException(status_code=499, detail=str(error))
yield {
'event': 'update',
'data': doc.dict()
}
yield {
'event': 'end'
}

return EventSourceResponse(event_generator())

@app.api_route(
path=f'/{endpoint_path.strip("/")}',
methods=['POST'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_post(body: dict):

async def event_generator():
async for doc, error in streamer.stream_doc(
doc=input_doc_model.parse_obj(body), exec_endpoint=endpoint_path
doc=body, exec_endpoint=endpoint_path
):
if error:
raise HTTPException(status_code=499, detail=str(error))
Expand Down
13 changes: 12 additions & 1 deletion jina/serve/runtimes/gateway/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,19 @@ async def _load_balance(self, request):

try:
async with aiohttp.ClientSession() as session:

if request.method == 'GET':
async with session.get(target_url) as response:
request_kwargs = {}
try:
payload = await request.json()
if payload:
request_kwargs['json'] = payload
except Exception:
self.logger.debug('No JSON payload found in request')

async with session.get(
url=target_url, **request_kwargs
) as response:
# Create a StreamResponse with the same headers and status as the target response
stream_response = web.StreamResponse(
status=response.status,
Expand Down
68 changes: 32 additions & 36 deletions jina/serve/runtimes/worker/http_fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

from jina import DocumentArray, Document
from jina import Document, DocumentArray
from jina._docarray import docarray_v2
from jina.importer import ImportExtensions
from jina.serve.networking.sse import EventSourceResponse
Expand All @@ -11,15 +11,15 @@
from jina.logging.logger import JinaLogger

if docarray_v2:
from docarray import DocList, BaseDoc
from docarray import BaseDoc, DocList


def get_fastapi_app(
request_models_map: Dict,
caller: Callable,
logger: 'JinaLogger',
cors: bool = False,
**kwargs,
request_models_map: Dict,
caller: Callable,
logger: 'JinaLogger',
cors: bool = False,
**kwargs,
):
"""
Get the app from FastAPI as the REST interface.
Expand All @@ -35,15 +35,18 @@ def get_fastapi_app(
from fastapi import FastAPI, Response, HTTPException
import pydantic
from fastapi.middleware.cors import CORSMiddleware
import os

from pydantic import BaseModel, Field
from pydantic.config import BaseConfig, inherit_config

from jina.proto import jina_pb2
from jina.serve.runtimes.gateway.models import _to_camel_case
import os

class Header(BaseModel):
request_id: Optional[str] = Field(description='Request ID', example=os.urandom(16).hex())
request_id: Optional[str] = Field(
description='Request ID', example=os.urandom(16).hex()
)

class Config(BaseConfig):
alias_generator = _to_camel_case
Expand All @@ -66,11 +69,11 @@ class InnerConfig(BaseConfig):
logger.warning('CORS is enabled. This service is accessible from any website!')

def add_post_route(
endpoint_path,
input_model,
output_model,
input_doc_list_model=None,
output_doc_list_model=None,
endpoint_path,
input_model,
output_model,
input_doc_list_model=None,
output_doc_list_model=None,
):
app_kwargs = dict(
path=f'/{endpoint_path.strip("/")}',
Expand Down Expand Up @@ -123,8 +126,8 @@ async def post(body: input_model, response: Response):
return ret

def add_streaming_routes(
endpoint_path,
input_doc_model=None,
endpoint_path,
input_doc_model=None,
):
from fastapi import Request

Expand All @@ -133,26 +136,17 @@ def add_streaming_routes(
methods=['GET'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_get(request: Request):
query_params = dict(request.query_params)
req = DataRequest()
req.header.exec_endpoint = endpoint_path
if not docarray_v2:
req.data.docs = DocumentArray([Document.from_dict(query_params)])
else:
req.document_array_cls = DocList[input_doc_model]
req.data.docs = DocList[input_doc_model](
[input_doc_model(**query_params)]
async def streaming_get(request: Request = None, body: input_doc_model = None):
if not body:
query_params = dict(request.query_params)
body = (
input_doc_model.parse_obj(query_params)
if docarray_v2
else Document.from_dict(query_params)
)
event_generator = _gen_dict_documents(await caller(req))
return EventSourceResponse(event_generator)

@app.api_route(
path=f'/{endpoint_path.strip("/")}',
methods=['POST'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_post(body: input_doc_model, request: Request):
else:
if not docarray_v2:
body = Document.from_pydantic_model(body)
req = DataRequest()
req.header.exec_endpoint = endpoint_path
if not docarray_v2:
Expand All @@ -169,7 +163,9 @@ async def streaming_post(body: input_doc_model, request: Request):
output_doc_model = input_output_map['output']['model']
is_generator = input_output_map['is_generator']
parameters_model = input_output_map['parameters']['model'] or Optional[Dict]
default_parameters = ... if input_output_map['parameters']['model'] else None
default_parameters = (
... if input_output_map['parameters']['model'] else None
)

if docarray_v2:
_config = inherit_config(InnerConfig, BaseDoc.__config__)
Expand Down
132 changes: 96 additions & 36 deletions tests/integration/docarray_v2/test_issues.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List, Optional, Dict
from typing import Dict, List, Optional

import aiohttp
import pytest
from docarray import BaseDoc, DocList
from pydantic import Field

from jina import Executor, Flow, requests, Deployment, Client
from jina import Client, Deployment, Executor, Flow, requests
from jina.clients.base.helper import HTTPClientlet


class Nested2Doc(BaseDoc):
Expand Down Expand Up @@ -78,6 +80,7 @@ def test_issue_6019_with_nested_list():
assert res[0].text == 'hello world'
assert res[0].nested[0].nested.value == 'test'


def test_issue_6084():
class EnvInfo(BaseDoc):
history: str = ''
Expand All @@ -86,7 +89,6 @@ class A(BaseDoc):
b: EnvInfo

class MyIssue6084Exec(Executor):

@requests
def foo(self, docs: DocList[A], **kwargs) -> DocList[A]:
pass
Expand All @@ -96,48 +98,106 @@ def foo(self, docs: DocList[A], **kwargs) -> DocList[A]:
pass


class NestedFieldSchema(BaseDoc):
name: str = "test_name"
dict_field: Dict = Field(default_factory=dict)


class InputWithComplexFields(BaseDoc):
text: str = "test"
nested_field: NestedFieldSchema = Field(default_factory=NestedFieldSchema)
dict_field: Dict = Field(default_factory=dict)
bool_field: bool = False


class SimpleInput(BaseDoc):
text: str = "test"


class MyExecutor(Executor):
@requests(on="/stream")
async def stream(
self,
doc: InputWithComplexFields,
parameters: Optional[Dict] = None,
**kwargs,
) -> InputWithComplexFields:
for i in range(4):
yield InputWithComplexFields(text=f"hello world {doc.text} {i}")

@requests(on="/stream-simple")
async def stream_simple(
self,
doc: SimpleInput,
parameters: Optional[Dict] = None,
**kwargs,
) -> SimpleInput:
for i in range(4):
yield SimpleInput(text=f"hello world {doc.text} {i}")


@pytest.fixture(scope="module")
def streaming_deployment():
protocol = "http"
with Deployment(uses=MyExecutor, protocol=protocol) as dep:
yield dep


@pytest.mark.asyncio
async def test_issue_6090():
async def test_issue_6090(streaming_deployment):
"""Tests if streaming works with pydantic models with complex fields which are not
str, int, or float.
"""

class NestedFieldSchema(BaseDoc):
name: str = "test_name"
dict_field: Dict = Field(default_factory=dict)

class InputWithComplexFields(BaseDoc):
text: str = "test"
nested_field: NestedFieldSchema = Field(default_factory=NestedFieldSchema)
dict_field: Dict = Field(default_factory=dict)
bool_field: bool = False

class MyExecutor(Executor):
@requests(on="/stream")
async def stream(
self, doc: InputWithComplexFields, parameters: Optional[Dict] = None, **kwargs
) -> InputWithComplexFields:
for i in range(4):
yield InputWithComplexFields(text=f"hello world {doc.text} {i}")

docs = []
protocol = "http"
with Deployment(uses=MyExecutor, protocol=protocol) as dep:
client = Client(port=dep.port, protocol=protocol, asyncio=True)
example_doc = InputWithComplexFields(text="my input text")
async for doc in client.stream_doc(
on="/stream",
inputs=example_doc,
input_type=InputWithComplexFields,
return_type=InputWithComplexFields,
):
docs.append(doc)
client = Client(port=streaming_deployment.port, protocol=protocol, asyncio=True)
example_doc = InputWithComplexFields(text="my input text")
async for doc in client.stream_doc(
on="/stream",
inputs=example_doc,
input_type=InputWithComplexFields,
return_type=InputWithComplexFields,
):
docs.append(doc)

assert [d.text for d in docs] == [
"hello world my input text 0",
"hello world my input text 1",
"hello world my input text 2",
"hello world my input text 3",
'hello world my input text 0',
'hello world my input text 1',
'hello world my input text 2',
'hello world my input text 3',
]
assert docs[0].nested_field.name == "test_name"


@pytest.mark.asyncio
async def test_issue_6090_get_params(streaming_deployment):
"""Tests if streaming works with pydantic models with complex fields which are not
str, int, or float.
"""

docs = []
url = (
f"htto://localhost:{streaming_deployment.port}/stream-simple?text=my_input_text"
)
async with aiohttp.ClientSession() as session:

async with session.get(url) as resp:
async for chunk in resp.content.iter_any():
print(chunk)
events = chunk.split(b'event: ')[1:]
for event in events:
if event.startswith(b'update'):
parsed = event[HTTPClientlet.UPDATE_EVENT_PREFIX:].decode()
parsed = SimpleInput.parse_raw(parsed)
print(parsed)
docs.append(parsed)
elif event.startswith(b'end'):
pass

assert [d.text for d in docs] == [
'hello world my_input_text 0',
'hello world my_input_text 1',
'hello world my_input_text 2',
'hello world my_input_text 3',
]
Loading

0 comments on commit ab2cc19

Please sign in to comment.