Skip to content

Commit

Permalink
make adjust_unit endpoint simpler
Browse files Browse the repository at this point in the history
  • Loading branch information
Noza23 committed Jan 16, 2024
1 parent b5b804e commit afa0bcf
Showing 1 changed file with 29 additions and 36 deletions.
65 changes: 29 additions & 36 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from contextlib import asynccontextmanager
from functools import lru_cache
from typing import Annotated
import json

from fastapi import (
FastAPI,
Expand Down Expand Up @@ -28,9 +29,9 @@
State,
ValidationResponse,
InferenceResponse,
ObjectNames,
)
from .utils import get_fp
from pydantic import ValidationError


settings = Settings(_env_file=".env", _env_file_encoding="utf-8")
Expand Down Expand Up @@ -125,7 +126,7 @@ async def run_validation(
myos = Myotubes.model_validate_json(
await redis.get(keys.result_key(img_hash))
)
if not myos.myo_objects:
if not myos:
raise HTTPException(
status_code=404,
detail="myotubes not found for the given hash.",
Expand Down Expand Up @@ -172,32 +173,32 @@ async def run_validation(
return ValidationResponse(hash_str=img_hash, image_path=path)


# Validation Socket: 0 = Invalid, 1 = Valid, 2 = Skip, -1 = Undo
# Before starting socket front gets state on start it gets the first contour
@app.websocket("/validation/{hash}/")
@app.websocket("/validation/{hash_str}/")
async def validation_ws(
websocket: WebSocket,
hash: str,
hash_str: str,
background_tasks: BackgroundTasks,
keys: Annotated[REDIS_KEYS, Depends(REDIS_KEYS)],
redis: Annotated[aioredis.Redis, Depends(setup_redis)],
):
"""Websocket for validation mode."""
await websocket.accept()
mo = MyoObjects.model_validate_json(await redis.get(keys.result_key(hash)))
state = State.model_validate_json(await redis.get(keys.state_key(hash)))
mo = MyoObjects.model_validate_json(
await redis.get(keys.result_key(hash_str))
)
state = State.model_validate_json(
await redis.get(keys.state_key(hash_str))
)

if state.done:
await websocket.send_text("done")
websocket.close(reason="Validation done.")
i = state.get_next()
# Starting Contour send on connection openning
await websocket.send_json(
{"roi_coords": mo[i].roi_coords, "contour_id": i}
)
while True:
if len(mo) == i:
# When front gets "done" it should tell the user that validation is done
# and give the option to download the results until closing the websocket
state.done = True
websocket.send_text("done")
# Wating for response from front
Expand All @@ -214,7 +215,7 @@ async def validation_ws(
# Skip contour: move to end and recache result
_ = mo.move_object_to_end(i)
await set_cache(
{keys.result_key(hash): mo.model_dump_json()}, redis
{keys.result_key(hash_str): mo.model_dump_json()}, redis
)
elif data == -1:
# Undo contour
Expand All @@ -227,7 +228,9 @@ async def validation_ws(
)
# Update state in cache
background_tasks.add_task(
set_cache, {keys.state_key(hash): state.model_dump_json()}, redis
set_cache,
{keys.state_key(hash_str): state.model_dump_json()},
redis,
)
# Send next contour
step = data != 2 if data != -1 else -1
Expand Down Expand Up @@ -311,33 +314,23 @@ async def redis_status(redis: Annotated[aioredis.Redis, Depends(setup_redis)]):
return {"status": False}


@app.get("/adjust_unit/{obj_name}/")
@app.get("/adjust_unit/{hash_str}/{mu}")
async def adjust_unit(
obj_name: ObjectNames,
hash: str,
hash_str: str,
mu: float,
keys: Annotated[REDIS_KEYS, Depends(REDIS_KEYS)],
redis: Annotated[aioredis.Redis, Depends(setup_redis)],
):
"""Adjust unit of the metrics"""
if obj_name == ObjectNames.MYOTUBES:
myos = await redis.get(keys.result_key(hash))
if not myos:
raise HTTPException(
status_code=404,
detail="myotubes not found for the given hash.",
)
myos = Myotubes.model_validate_json(myos)
myos.adjust_measure_unit(mu)
await redis.set(keys.result_key(hash), myos.model_dump_json())
elif obj_name == ObjectNames.NUCLEIS:
nucls = await redis.get(keys.result_key(hash))
if not nucls:
raise HTTPException(
status_code=404,
detail="nucleis not found for the given hash.",
)
nucls = Nucleis.model_validate_json(nucls)
nucls.adjust_measure_unit(mu)
await redis.set(keys.result_key(hash), nucls.model_dump_json())
objs = json.loads(await redis.get(keys.result_key(hash_str)))
if not objs["myo_objects"]:
raise HTTPException(
status_code=404, detail="myo_objects not found for the given hash."
)
try:
objs = Myotubes.model_validate_json(objs)
except ValidationError:
objs = Nucleis.model_validate_json(objs)
objs.adjust_measure_unit(mu)
await redis.set(keys.result_key(hash_str), objs.model_dump_json())
return Response(status_code=200)

0 comments on commit afa0bcf

Please sign in to comment.