Skip to content

Commit

Permalink
fix: issue in gcp app
Browse files (browse the repository at this point in the history)
  • Loading branch information
zac-li committed Jan 16, 2024
1 parent a3dfc9c commit 8a48521
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 9 deletions.
8 changes: 4 additions & 4 deletions jina/serve/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def __init__(
self._add_dynamic_batching(dynamic_batching)
self._add_runtime_args(runtime_args)
self.logger = JinaLogger(self.__class__.__name__, **vars(self.runtime_args))
self._validate_sagemaker()
self._validate_csp()
self._init_instrumentation(runtime_args)
self._init_monitoring()
self._init_workspace = workspace
Expand Down Expand Up @@ -599,14 +599,14 @@ def _add_requests(self, _requests: Optional[Dict]):
f'expect {typename(self)}.{func} to be a function, but receiving {typename(_func)}'
)

def _validate_sagemaker(self):
# sagemaker expects the POST /invocations endpoint to be defined.
def _validate_csp(self):
# csp (sagemaker/azure/gcp) expects the POST /invocations endpoint to be defined.
# if it is not defined, we check if there is only one endpoint defined,
# and if so, we use it as the POST /invocations endpoint, or raise an error
if (
not hasattr(self, 'runtime_args')
or not hasattr(self.runtime_args, 'provider')
or self.runtime_args.provider != ProviderType.SAGEMAKER.value
or self.runtime_args.provider not in (ProviderType.SAGEMAKER.value, ProviderType.GCP.value)
):
return

Expand Down
17 changes: 17 additions & 0 deletions jina/serve/runtimes/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,23 @@ def _get_server(self):
cors=getattr(self.args, 'cors', None),
is_cancel=self.is_cancel,
)
elif (
hasattr(self.args, 'provider')
and self.args.provider == ProviderType.GCP
):
from jina.serve.runtimes.servers.http import GCPHTTPServer

return GCPHTTPServer(
name=self.args.name,
runtime_args=self.args,
req_handler_cls=self.req_handler_cls,
proxy=getattr(self.args, 'proxy', None),
uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None),
ssl_keyfile=getattr(self.args, 'ssl_keyfile', None),
ssl_certfile=getattr(self.args, 'ssl_certfile', None),
cors=getattr(self.args, 'cors', None),
is_cancel=self.is_cancel,
)
elif not hasattr(self.args, 'protocol') or (
len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.GRPC
):
Expand Down
5 changes: 2 additions & 3 deletions jina/serve/runtimes/worker/http_gcp_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_fastapi_app(
from jina.serve.runtimes.gateway.models import _to_camel_case

if not docarray_v2:
logger.warning('Only docarray v2 is supported with Sagemaker. ')
logger.warning('Only docarray v2 is supported with GCP. ')
return

class Header(BaseModel):
Expand Down Expand Up @@ -129,7 +129,6 @@ async def process(body) -> output_model:
raise HTTPException(status_code=499, detail=status.description)
else:
return {"predictions": resp.docs}
return output_model(predictions=resp.docs)

@app.api_route(**app_kwargs)
async def post(request: Request):
Expand Down Expand Up @@ -175,7 +174,7 @@ async def post(request: Request):

from jina.serve.runtimes.gateway.health_model import JinaHealthModel

# `/ping` route is required by AWS Sagemaker
# `/ping` route is required by GCP
@app.get(
path='/ping',
summary='Get the health of Jina Executor service',
Expand Down
2 changes: 1 addition & 1 deletion jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def _init_monitoring(
if metrics_registry:
with ImportExtensions(
required=True,
help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
help_text='You need to install the `prometheus_client` to use the monitoring functionality of jina',
):
from prometheus_client import Counter, Summary

Expand Down
22 changes: 21 additions & 1 deletion tests/integration/docarray_v2/gcp/test_gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,25 @@ def test_provider_gcp_pod_inference():
assert resp.status_code == 200
resp_json = resp.json()
assert len(resp_json['predictions']) == 2
print(resp_json)


def test_provider_gcp_deployment_inference():
    """Check the HTTP routes of a ``Deployment`` started with ``provider='gcp'``.

    The deployment is created from ``SampleExecutor/config.yml`` (located next
    to this test file) on a random port. The test then verifies that:

    * ``GET /ping`` answers 200 with an empty JSON object;
    * ``POST /invocations`` with two ``instances`` answers 200 and returns
      two ``predictions``.

    Every HTTP call carries an explicit timeout so that an unresponsive server
    fails the test instead of blocking the test run indefinitely.
    """
    # Seconds to wait for the server before failing; `requests` has no default
    # timeout and would otherwise wait forever.
    request_timeout = 30

    with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')):
        dep_port = random_port()
        with Deployment(uses='config.yml', provider='gcp', port=dep_port):
            # Test the `GET /ping` endpoint (added by jina for gcp)
            resp = requests.get(
                f'http://localhost:{dep_port}/ping',
                timeout=request_timeout,
            )
            assert resp.status_code == 200
            assert resp.json() == {}

            # Test the `POST /invocations` endpoint
            # Note: this endpoint is not implemented in the sample executor
            resp = requests.post(
                f'http://localhost:{dep_port}/invocations',
                json={
                    'instances': ["hello world", "good apple"]
                },
                timeout=request_timeout,
            )
            assert resp.status_code == 200
            resp_json = resp.json()
            # One prediction is expected per submitted instance.
            assert len(resp_json['predictions']) == 2

0 comments on commit 8a48521

Please sign in to comment.