From 7d75b81889544ac3019eaafef2ff688e33f5ef6f Mon Sep 17 00:00:00 2001 From: David Revay Date: Thu, 9 Jun 2022 13:46:34 +1000 Subject: [PATCH] chore: add unit tests --- Dockerfile | 8 ++-- .../fixtures/fixtures_nuclio.py | 38 +++++++++++++++++++ .../test_flash_model_handler.py | 23 +++++++---- nuclio_lighting_flash/test_main.py | 37 ++++++++++++++++++ 4 files changed, 95 insertions(+), 11 deletions(-) create mode 100644 nuclio_lighting_flash/fixtures/fixtures_nuclio.py create mode 100644 nuclio_lighting_flash/test_main.py diff --git a/Dockerfile b/Dockerfile index 3d8fc04..aae4da2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,6 @@ FROM pytorchlightning/pytorch_lightning:1.6.4-py3.9-torch1.9 LABEL org.opencontainers.image.source=https://github.com/Greenroom-Robotics/nuclio_lighting_flash -COPY ./nuclio_lighting_flash /opt/nuclio - # Remote nvidia lists as they have a borked GPG RUN rm /etc/apt/sources.list.d/nvidia-ml.list /etc/apt/sources.list.d/cuda.list @@ -13,6 +11,10 @@ RUN apt-get install ffmpeg libsm6 libxext6 -y # Install lighting flash and it's deps RUN pip install lightning-flash icevision 'lightning-flash[image]' +COPY ./nuclio_lighting_flash /opt/nuclio + WORKDIR /opt/nuclio -CMD python ./test_flash_model_handler.py \ No newline at end of file +# Run the tests +# Nuclio will overwrite this CMD when it is deployed +CMD pytest . \ No newline at end of file diff --git a/nuclio_lighting_flash/fixtures/fixtures_nuclio.py b/nuclio_lighting_flash/fixtures/fixtures_nuclio.py new file mode 100644 index 0000000..8a9fcdd --- /dev/null +++ b/nuclio_lighting_flash/fixtures/fixtures_nuclio.py @@ -0,0 +1,38 @@ +import base64 +from dataclasses import dataclass + + +class Logger: + def info(self, message: str): + ... + + +class UserData: + model = None + + +@dataclass +class Response: + body: str + headers: dict + content_type: str + status_code: int + + +class Context: + logger = Logger() + user_data = UserData() + Response = Response + + +class Event: + def __init__(self, image_path: str, threshold: float): + """ + Create an event with the image body serialised as base64 + """ + with open(image_path, "rb") as image_file: + image_base64 = base64.b64encode(image_file.read()).decode("ascii") + self.body = { + "image": image_base64, + "threshold": threshold, + } diff --git a/nuclio_lighting_flash/test_flash_model_handler.py b/nuclio_lighting_flash/test_flash_model_handler.py index 7a8a58a..ae0f021 100644 --- a/nuclio_lighting_flash/test_flash_model_handler.py +++ b/nuclio_lighting_flash/test_flash_model_handler.py @@ -4,11 +4,18 @@ from flash_model_handler import FlashModelHandler -model = ObjectDetector( - head="efficientdet", backbone="d0", num_classes=91, image_size=1024 -) -model_handler = FlashModelHandler(model=model, image_size=1024, labels={25: "giraffe"}) -image = Image.open(os.path.join(os.getcwd(), "./fixtures/giraffe.jpg")) - -result = model_handler.infer(image, 0) -print(result) + +def test_flash_model_handler(): + """ + Runs the flash model handler (bypassing nuclio) + Should detect giraffes in the COCO image + """ + model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=91, image_size=1024) + model_handler = FlashModelHandler(model=model, image_size=1024, labels={25: "giraffe"}) + image = Image.open(os.path.join(os.getcwd(), "./fixtures/giraffe.jpg")) + + result = model_handler.infer(image, 0.0) + + assert len(result) == 2 + assert result[0]["label"] == "giraffe" + assert result[1]["label"] == "giraffe" diff --git a/nuclio_lighting_flash/test_main.py b/nuclio_lighting_flash/test_main.py new file mode 100644 index 0000000..dac0e4a --- /dev/null +++ b/nuclio_lighting_flash/test_main.py @@ -0,0 +1,37 @@ +import json +import os + +from main import init_context, handler +from fixtures.fixtures_nuclio import Context, Event + + +def test_main(): + """ + Runs main nuclio function with a fake nuclio context + """ + context = Context() + init_context(context) + + event = Event(image_path=os.path.join(os.getcwd(), "./fixtures/giraffe.jpg"), threshold=0.5) + response = handler(context, event) + assert response.body == json.dumps( + [ + { + "confidence": 0.9171872138977051, + "label": "giraffe", + "points": [596.2103271484375, 276.046875, 962.4700927734375, 759.198974609375], + "type": "rectangle", + }, + { + "confidence": 0.8744097352027893, + "label": "giraffe", + "points": [ + 84.43610382080078, + 734.1155395507812, + 297.03192138671875, + 840.5248413085938, + ], + "type": "rectangle", + }, + ] + )