diff --git a/nuclio_lighting_flash/main.py b/nuclio_lighting_flash/main.py index 8dd6cf7..bb728f4 100644 --- a/nuclio_lighting_flash/main.py +++ b/nuclio_lighting_flash/main.py @@ -3,12 +3,24 @@ import base64 from PIL import Image import io +import numpy as np import yaml from flash.image import ObjectDetector from flash_model_handler import FlashModelHandler +class NpEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) + + def init_context(context): context.logger.info("Init context... 0%") @@ -20,6 +32,10 @@ def init_context(context): labels_spec = annotations["spec"] labels = {item["id"]: item["name"] for item in json.loads(labels_spec)} + print(f"Model head: {annotations['head']}") + print(f"Model backbone: {annotations['backbone']}") + print(f"Num classes: {len(labels)}") + # Read the DL model # Either "checkpoint_path" or "head" and "backbone" should be specified model = ( @@ -48,7 +64,7 @@ def handler(context, event): results = context.user_data.model.infer(image, threshold) return context.Response( - body=json.dumps(results), + body=json.dumps(results, cls=NpEncoder), headers={}, content_type="application/json", status_code=200,