Skip to content

Commit

Permalink
test: add tests for metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
NarekA committed Dec 7, 2023
1 parent a8098f2 commit 0fb78c9
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 2 deletions.
2 changes: 1 addition & 1 deletion jina/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _ignore_google_warnings():

# do not change this line manually
# this is managed by proto/build-proto.sh and updated on every execution
__proto_version__ = '0.1.27'
__proto_version__ = '0.1.28'

try:
__docarray_version__ = _docarray.__version__
Expand Down
3 changes: 3 additions & 0 deletions jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,13 @@ async def post(body: input_model, response: Response, request: Request):
)
return result



def add_streaming_routes(
endpoint_path,
input_doc_model=None,
):
from fastapi import Request
@app.api_route(
path=f'/{endpoint_path.strip("/")}',
methods=['GET'],
Expand Down
3 changes: 2 additions & 1 deletion jina/serve/runtimes/worker/http_fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def post(body: input_model, response: Response, request: Request):

if body.parameters is not None:
req.parameters = body.parameters
req.metadata = dict(request.headers or {"no_headers": "true"})
req.metadata = dict(request.headers or {})
req.header.exec_endpoint = endpoint_path
data = body.data
if isinstance(data, list):
Expand Down Expand Up @@ -152,6 +152,7 @@ async def streaming_get(request: Request = None, body: input_doc_model = None):
body = Document.from_pydantic_model(body)
req = DataRequest()
req.header.exec_endpoint = endpoint_path
req.metadata = dict(request.headers or {})
if not docarray_v2:
req.data.docs = DocumentArray([body])
else:
Expand Down
150 changes: 150 additions & 0 deletions tests/integration/docarray_v2/test_metadata_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import logging
from typing import Dict, List, Literal, Optional

import pytest
from docarray import BaseDoc, DocList

from jina import Client, Deployment, Executor, requests
from jina.helper import random_port


class PortGetter:
def __init__(self):
self.ports = {
"http": {
True: random_port(),
False: random_port(),
},
"grpc": {
True: random_port(),
False: random_port(),
},
}

def get_port(self, protocol: Literal["http", "grpc"], include_gateway: bool) -> int:
return self.ports[protocol][include_gateway]

@property
def gateway_ports(self) -> List[int]:
return [self.ports["http"][True], self.ports["grpc"][True]]

@property
def no_gateway_ports(self) -> List[int]:
return [self.ports["http"][False], self.ports["grpc"][False]]


@pytest.fixture(scope='module')
def port_getter() -> callable:
getter = PortGetter()
return getter


class DictDoc(BaseDoc):
data: dict


class MetadataExecutor(Executor):
@requests(on="/get-metadata-headers")
def post_endpoint(
self,
docs: DocList[DictDoc],
parameters: Optional[Dict] = None,
metadata: Optional[Dict] = None,
**kwargs,
) -> DocList[DictDoc]:
return DocList[DictDoc]([DictDoc(data=metadata)])

@requests(on='/stream-metadata-headers')
async def stream_task(
self, doc: DictDoc, metadata: Optional[dict] = None, **kwargs
) -> DictDoc:
for k, v in sorted((metadata or {}).items()):
yield DictDoc(data={k: v})

yield DictDoc(data={"DONE": "true"})


@pytest.fixture(scope='module')
def deployment_no_gateway(port_getter: PortGetter) -> Deployment:

with Deployment(
uses=MetadataExecutor,
protocol=["http", "grpc"],
port=port_getter.no_gateway_ports,
include_gateway=False,
) as dep:
yield dep


@pytest.fixture(scope='module')
def deployment_gateway(port_getter: PortGetter) -> Deployment:

with Deployment(
uses=MetadataExecutor,
protocol=["http", "grpc"],
port=port_getter.gateway_ports,
include_gateway=False,
) as dep:
yield dep


@pytest.fixture(scope='module')
def deployments(deployment_gateway, deployment_no_gateway) -> Dict[bool, Deployment]:
return {
True: deployment_gateway,
False: deployment_no_gateway,
}


@pytest.mark.parametrize('include_gateway', [False, True])
def test_headers_in_http_metadata(
include_gateway, port_getter: PortGetter, deployments
):
port = port_getter.get_port("http", include_gateway)
data = {
"data": [{"text": "test"}],
"parameters": {
"parameter1": "value1",
},
}
logging.info(f"Posting to {port}")
client = Client(port=port, protocol="http")
resp = client.post(
on=f'/get-metadata-headers',
inputs=DocList([DictDoc(data=data)]),
headers={
"header1": "value1",
"header2": "value2",
},
return_type=DocList[DictDoc],
)
assert resp[0].data['header1'] == 'value1'


@pytest.mark.asyncio
@pytest.mark.parametrize('include_gateway', [False, True])
async def test_headers_in_http_metadata_streaming(
include_gateway, port_getter: PortGetter, deployments
):
client = Client(
port=port_getter.get_port("http", include_gateway),
protocol="http",
asyncio=True,
)
data = {"data": [{"text": "test"}], "parameters": {"parameter1": "value1"}}
chunks = []

async for doc in client.stream_doc(
on=f'/stream-metadata-headers',
inputs=DictDoc(data=data),
headers={
"header1": "value1",
"header2": "value2",
},
return_type=DictDoc,
):
chunks.append(doc)
assert len(chunks) > 2

assert DictDoc(data={'header1': 'value1'}) in chunks
assert DictDoc(data={'header2': 'value2'}) in chunks

0 comments on commit 0fb78c9

Please sign in to comment.