Skip to content

Commit

Permalink
chore: add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MrBlenny committed Jun 9, 2022
1 parent 26d8ac4 commit 7d75b81
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 11 deletions.
8 changes: 5 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
FROM pytorchlightning/pytorch_lightning:1.6.4-py3.9-torch1.9
LABEL org.opencontainers.image.source=https://github.com/Greenroom-Robotics/nuclio_lighting_flash

COPY ./nuclio_lighting_flash /opt/nuclio

# Remote nvidia lists as they have a borked GPG
RUN rm /etc/apt/sources.list.d/nvidia-ml.list /etc/apt/sources.list.d/cuda.list

Expand All @@ -13,6 +11,10 @@ RUN apt-get install ffmpeg libsm6 libxext6 -y
# Install lighting flash and it's deps
RUN pip install lightning-flash icevision 'lightning-flash[image]'

COPY ./nuclio_lighting_flash /opt/nuclio

WORKDIR /opt/nuclio

CMD python ./test_flash_model_handler.py
# Run the tests
# Nuclio will overwrite this CMD when it is deployed
CMD pytest .
38 changes: 38 additions & 0 deletions nuclio_lighting_flash/fixtures/fixtures_nuclio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import base64
from dataclasses import dataclass


class Logger:
def info(self, message: str):
...


class UserData:
model = None


@dataclass
class Response:
body: str
headers: dict
content_type: str
status_code: int


class Context:
logger = Logger()
user_data = UserData()
Response = Response


class Event:
def __init__(self, image_path: str, threshold: float):
"""
Create an event with the image body serialised as base64
"""
with open(image_path, "rb") as image_file:
image_base64 = base64.b64encode(image_file.read()).decode("ascii")
self.body = {
"image": image_base64,
"threshold": threshold,
}
23 changes: 15 additions & 8 deletions nuclio_lighting_flash/test_flash_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@

from flash_model_handler import FlashModelHandler

model = ObjectDetector(
head="efficientdet", backbone="d0", num_classes=91, image_size=1024
)
model_handler = FlashModelHandler(model=model, image_size=1024, labels={25: "giraffe"})
image = Image.open(os.path.join(os.getcwd(), "./fixtures/giraffe.jpg"))

result = model_handler.infer(image, 0)
print(result)

def test_flash_model_handler():
"""
Runs the flash model handler (bypassing nuclio)
Should detect giraffes in the COCO image
"""
model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=91, image_size=1024)
model_handler = FlashModelHandler(model=model, image_size=1024, labels={25: "giraffe"})
image = Image.open(os.path.join(os.getcwd(), "./fixtures/giraffe.jpg"))

result = model_handler.infer(image, 0.0)

assert len(result) == 2
assert result[0]["label"] == "giraffe"
assert result[1]["label"] == "giraffe"
37 changes: 37 additions & 0 deletions nuclio_lighting_flash/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import json
import os

from main import init_context, handler
from fixtures.fixtures_nuclio import Context, Event


def test_main():
"""
Runs main nuclio function with a fake nuclio context
"""
context = Context()
init_context(context)

event = Event(image_path=os.path.join(os.getcwd(), "./fixtures/giraffe.jpg"), threshold=0.5)
response = handler(context, event)
assert response.body == json.dumps(
[
{
"confidence": 0.9171872138977051,
"label": "giraffe",
"points": [596.2103271484375, 276.046875, 962.4700927734375, 759.198974609375],
"type": "rectangle",
},
{
"confidence": 0.8744097352027893,
"label": "giraffe",
"points": [
84.43610382080078,
734.1155395507812,
297.03192138671875,
840.5248413085938,
],
"type": "rectangle",
},
]
)

0 comments on commit 7d75b81

Please sign in to comment.