Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsmechtel committed Dec 14, 2024
1 parent 2dd8dce commit a1bffaf
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions bioimageio_colab/register_sam_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ def parse_requirements(file_path) -> list:
lines = file.readlines()
# Filter and clean package names (skip comments and empty lines)
skip_lines = ("#", "-r ", "ray")
packages = [line.strip() for line in lines if line.strip() and not line.startswith(skip_lines)]
packages = [
line.strip()
for line in lines
if line.strip() and not line.startswith(skip_lines)
]
return packages


Expand All @@ -59,7 +63,9 @@ def compute_image_embedding(
"""
try:
user_id = context["user"].get("id") if context else "anonymous"
logger.info(f"User '{user_id}' - Computing embedding (model: '{model_name}')...")
logger.info(
f"User '{user_id}' - Computing embedding (model: '{model_name}')..."
)

sam_predictor = load_model_from_ckpt(
model_name=model_name,
Expand All @@ -76,7 +82,8 @@ def compute_image_embedding(
except Exception as e:
logger.error(f"User '{user_id}' - Error computing embedding: {e}")
raise e



@ray.remote
def compute_image_embedding_ray(kwargs: dict) -> np.ndarray:
from bioimageio_colab.register_sam_service import compute_image_embedding
Expand Down Expand Up @@ -112,15 +119,17 @@ def compute_mask(

# Set the embedding
sam_predictor.original_size = image_size
sam_predictor.input_size = tuple([sam_predictor.model.image_encoder.img_size] * 2)
sam_predictor.input_size = tuple(
[sam_predictor.model.image_encoder.img_size] * 2
)
sam_predictor.features = torch.as_tensor(embedding, device=sam_predictor.device)
sam_predictor.is_image_set = True

# Segment the image
masks = segment_image(
sam_predictor=sam_predictor,
point_coords=point_coords,
point_labels=point_labels
point_labels=point_labels,
)

if format == "mask":
Expand All @@ -134,7 +143,8 @@ def compute_mask(
except Exception as e:
logger.error(f"User '{user_id}' - Error segmenting image: {e}")
raise e



@ray.remote
def compute_mask_ray(kwargs: dict) -> np.ndarray:
from bioimageio_colab.register_sam_service import compute_mask
Expand Down Expand Up @@ -192,7 +202,7 @@ async def register_service(args: dict) -> None:
sam_requirements = parse_requirements("../requirements-sam.txt")
runtime_env = {
"pip": base_requirements + sam_requirements,
"py_modules": ["../bioimageio_colab"]
"py_modules": ["../bioimageio_colab"],
}

# Connect to Ray
Expand All @@ -209,8 +219,11 @@ def compute_mask_function(**kwargs: dict):
def test_model_function(**kwargs: dict):
kwargs["cache_dir"] = cache_dir
return ray.get(test_model_ray.remote(kwargs))

else:
compute_embedding_function = partial(compute_image_embedding, cache_dir=cache_dir)
compute_embedding_function = partial(
compute_image_embedding, cache_dir=cache_dir
)
compute_mask_function = partial(compute_mask, cache_dir=cache_dir)
test_model_function = partial(test_model, cache_dir=cache_dir)

Expand Down

0 comments on commit a1bffaf

Please sign in to comment.