Skip to content

Commit

Permalink
validation websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
Noza23 committed Jan 15, 2024
1 parent bf5f396 commit 8f05dad
Showing 1 changed file with 80 additions and 11 deletions.
91 changes: 80 additions & 11 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from contextlib import asynccontextmanager
from functools import lru_cache
from typing import Annotated
import json

from fastapi import (
FastAPI,
Expand All @@ -10,13 +9,17 @@
UploadFile,
File,
HTTPException,
WebSocket,
WebSocketException,
status,
)
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from redis import asyncio as aioredis # type: ignore

from myo_sam.inference.pipeline import Pipeline
from myo_sam.inference.models.base import Nucleis, Myotubes
from myo_sam.inference.models.base import Nucleis, Myotubes, MyoObjects
from myo_sam.inference.utils import hash_bytes

from .models import (
Settings,
Expand Down Expand Up @@ -120,17 +123,23 @@ async def run_validation(

if await is_cached(keys.result_key(img_hash), redis):
# Case when image is cached
myos = json.loads(await redis.get(keys.result_key(img_hash)))
myos = Myotubes.model_validate_json(
await redis.get(keys.result_key(img_hash))
)
state = State.model_validate_json(
await redis.get(keys.state_key(img_hash))
)
path = await redis.get(keys.image_path_key(img_hash))
img_drawn = pipeline.draw_contours_on_myotube_image(
myos.filter_by_ids(state.valid), thickness=2
)
img_drawn_hash = hash_bytes(img_drawn)
path = await redis.get(keys.image_path_key(img_drawn_hash))
if not path:
# path might be cleaned by regular image cleaning
path = get_fp(settings.images_dir)
_ = pipeline.save_myotube_image(path)
pipeline.save_myotube_image(path, img_drawn)
background_tasks.add_task(
set_cache, {keys.image_path_key(img_hash): path}, redis
set_cache, {keys.image_path_key(img_drawn_hash): path}, redis
)
else:
# Case when image is not cached
Expand All @@ -151,11 +160,71 @@ async def run_validation(
redis,
)

return ValidationResponse(
roi_coords=[myo.roi_coords for myo in myos],
state=state,
hash_str=img_hash,
return ValidationResponse(state=state, hash_str=img_hash, image_path=path)


# Validation Socket: 0 = Invalid, 1 = Valid, 2 = Skip, -1 = Undo
# Before starting socket front gets state on start it gets the first contour
@app.websocket("/validation/{hash}/")
async def validation_ws(
websocket: WebSocket,
hash: str,
background_tasks: BackgroundTasks,
keys: Annotated[REDIS_KEYS, Depends(REDIS_KEYS)],
redis: Annotated[aioredis.Redis, Depends(setup_redis)],
):
"""Websocket for validation mode."""
await websocket.accept()
mo = MyoObjects.model_validate_json(await redis.get(keys.result_key(hash)))
state = State.model_validate_json(await redis.get(keys.state_key(hash)))
if state.done:
await websocket.send_text("done")
websocket.close(reason="Validation done.")
i = state.get_next()
# Starting Contour send on connection openning
await websocket.send_json(
{"roi_coords": mo[i].roi_coords, "contour_id": i}
)
while True:
if len(mo) == i:
# When front gets "done" it should tell the user that validation is done
# and give the option to download the results until closing the websocket
state.done = True
websocket.send_text("done")
# Wating for response from front
data = int(await websocket.receive_text())
# Invalid = 0, Valid = 1, Skip = 2, Undo = -1
assert data in (0, 1, 2, -1)
if data == 0:
# Invalid contour
state.invalid.add(i)
elif data == 1:
# Valid contour
state.valid.add(i)
elif data == 2:
# Skip contour: move to end and recache result
_ = mo.move_object_to_end(i)
await set_cache(
{keys.result_key(hash): mo.model_dump_json()}, redis
)
elif data == -1:
# Undo contour
state.valid.discard(i)
state.invalid.discard(i)
else:
raise WebSocketException(
status_code=status.WS_1008_POLICY_VIOLATION,
detail="Invalid data received.",
)
# Update state in cache
background_tasks.add_task(
set_cache, {keys.state_key(hash): state.model_dump_json()}, redis
)
# Send next contour
step = data != 2 if data != -1 else -1
await websocket.send_json(
{"roi_coords": mo[i + step].roi_coords, "contour_id": i + step}
)


@app.post("/inference/", response_model=InferenceResponse)
Expand Down Expand Up @@ -223,7 +292,7 @@ async def run_inference(


@app.get("/redis_status/")
async def status(redis: Annotated[aioredis.Redis, Depends(setup_redis)]):
async def redis_status(redis: Annotated[aioredis.Redis, Depends(setup_redis)]):
"""check status of the redis connection"""
try:
status = await redis.ping()
Expand Down

0 comments on commit 8f05dad

Please sign in to comment.