From de7a7ed16aa1eca21494feb6290bf1682194e981 Mon Sep 17 00:00:00 2001 From: Nils Mechtel Date: Sat, 14 Dec 2024 19:54:17 +0100 Subject: [PATCH 1/8] prepare service for ray computation --- .gitignore | 1 + bioimageio_colab/__main__.py | 45 ++- bioimageio_colab/bioengine_app.py | 89 ----- bioimageio_colab/create_workspace.py | 8 +- bioimageio_colab/register_sam_service.py | 358 +++++-------------- bioimageio_colab/sam.py | 104 ++++++ chatbot_extension/data_provider_extension.py | 48 --- docs/data_providing_service.py | 11 +- requirements-sam.txt | 6 +- requirements.txt | 4 +- 10 files changed, 253 insertions(+), 421 deletions(-) delete mode 100644 bioimageio_colab/bioengine_app.py create mode 100644 bioimageio_colab/sam.py delete mode 100644 chatbot_extension/data_provider_extension.py diff --git a/.gitignore b/.gitignore index 102de1a..c545fb4 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ data/ .DS_Store *.onnx *.pt +*.pth *tif *zip visualize_annotation.py diff --git a/bioimageio_colab/__main__.py b/bioimageio_colab/__main__.py index f24b7c2..5d8d4e6 100644 --- a/bioimageio_colab/__main__.py +++ b/bioimageio_colab/__main__.py @@ -1,5 +1,42 @@ -def main(): - print("Welcome to the bioimageio-colab") +import argparse +import asyncio -if __name__ == '__main__': - main() +from bioimageio_colab.register_sam_service import register_service + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Register SAM annotation service on BioImageIO Colab workspace." + ) + parser.add_argument( + "--server_url", + default="https://hypha.aicell.io", + help="URL of the Hypha server", + ) + parser.add_argument( + "--workspace_name", default="bioimageio-colab", help="Name of the workspace" + ) + parser.add_argument( + "--service_id", + default="microsam", + help="Service ID for registering the service", + ) + parser.add_argument( + "--token", + default=None, + help="Workspace token for connecting to the Hypha server", + ) + parser.add_argument( + "--cache_dir", + default="./models", + help="Directory for caching the models", + ) + parser.add_argument( + "--ray_address", + default=None, + help="Address of the Ray cluster for running SAM", + ) + args = parser.parse_args() + + loop = asyncio.get_event_loop() + loop.create_task(register_service(args=args)) + loop.run_forever() diff --git a/bioimageio_colab/bioengine_app.py b/bioimageio_colab/bioengine_app.py deleted file mode 100644 index 7159084..0000000 --- a/bioimageio_colab/bioengine_app.py +++ /dev/null @@ -1,89 +0,0 @@ -import asyncio -import os - -import requests -from hypha_rpc import connect_to_server -from tifffile import imread - -BASE_URL = "https://raw.githubusercontent.com/bioimage-io/bioimageio-colab/" -BRANCH = "main" - -# Download the register_sam_service.py file -script_url = os.path.join(BASE_URL, BRANCH, "bioimageio_colab/register_sam_service.py") -script = requests.get(script_url).text -# Remove everything after the 'async def register_service' function -script = script.split("async def register_service")[0] - -# Imports -imports = "\n".join([line for line in script.split("\n") if "import" in line]) - - -# Functions -functions = "def" + "def".join([f for f in script.split("def")[1:]]) - -functions = "\n".join([" " + line for line in functions.split("\n")]) - -# Define the execute function -run_segmentation_script = imports + """ -def execute(image, point_coordinates, point_labels): - MODELS = { - "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", - "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt", - "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", - } - STORAGE = {} - CONTEXT = {"user": {"id": "dummy"}} - - logger = getLogger(__name__) - logger.setLevel("INFO") - -""" + functions + """ - compute_embedding("vit_b", image, CONTEXT) - features = segment(point_coordinates, point_labels, CONTEXT) - return features -""" - -print(run_segmentation_script) - -# Define the pip requirements -base_requirements_file = os.path.join(BASE_URL, BRANCH, "requirements.txt") -base_requirements = requests.get(base_requirements_file).text - -sam_requirements_file = os.path.join(BASE_URL, BRANCH, "requirements-sam.txt") -sam_requirements = requests.get(sam_requirements_file).text - -pip_requirements = [ - requirement - for requirement in (base_requirements + sam_requirements).split("\n") - if requirement and not requirement.startswith(("#", "-r")) -] + ["python-dotenv"] -print(pip_requirements) - - -async def main(name, script, pip_requirements): - # Connect to the Hypha server - server_url = "https://hypha.aicell.io" - workspace_id = "bioengine-apps" - service_id = "ray-function-registry" - - server = await connect_to_server({"server_url": server_url}) - - # Retrieve the Ray Function Registry service - svc = await server.get_service(f"{workspace_id}/{service_id}") - - # Register the ResNet function - function_id = await svc.register_function(name=name, script=script, pip_requirements=pip_requirements) - print(f"Registered function with id: {function_id}") - - # Example image - image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "example_image.tif") - image = imread(image_path) # np.ndarray - point_coordinates = [[10, 10]] # Union[list, np.ndarray] - point_labels = [0] # Union[list, np.ndarray] - - # Run the ResNet function - result = await svc.run_function(function_id=function_id, args=[image, point_coordinates, point_labels]) - print("Segmentation result:", result) - -if __name__ == "__main__": - asyncio.run(main("microSAM", script, pip_requirements)) \ No newline at end of file diff --git a/bioimageio_colab/create_workspace.py b/bioimageio_colab/create_workspace.py index 5f82f8b..6ebdeff 100644 --- a/bioimageio_colab/create_workspace.py +++ b/bioimageio_colab/create_workspace.py @@ -1,7 +1,12 @@ +""" +Hypha now supports creating workspaces from the Hypha Dashboard. This script is no longer needed. +""" + import argparse -from hypha_rpc import connect_to_server, login import asyncio +from hypha_rpc import connect_to_server, login + async def create_workspace_token(args): # Get a user login token @@ -64,6 +69,7 @@ async def create_workspace_token(args): for service in services: print(f"- {service['name']}") + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Create the BioImageIO Colab workspace." diff --git a/bioimageio_colab/register_sam_service.py b/bioimageio_colab/register_sam_service.py index 74356be..7cec084 100644 --- a/bioimageio_colab/register_sam_service.py +++ b/bioimageio_colab/register_sam_service.py @@ -1,28 +1,21 @@ -import argparse -import io +import logging import os from functools import partial -import logging -from typing import Union +from typing import Literal import numpy as np -import requests import torch -from cachetools import TTLCache from dotenv import find_dotenv, load_dotenv from hypha_rpc import connect_to_server from kaibu_utils import mask_to_features -from segment_anything import SamPredictor, sam_model_registry +from bioimageio_colab.sam import compute_embedding, load_model_from_ckpt, segment_image + +# Load environment variables ENV_FILE = find_dotenv() if ENV_FILE: load_dotenv(ENV_FILE) -MODELS = { - "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", - "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt", - "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", -} logger = logging.getLogger(__name__) logger.setLevel("INFO") @@ -30,205 +23,91 @@ logger.propagate = False # Create a new console handler console_handler = logging.StreamHandler() -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') +formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" +) console_handler.setFormatter(formatter) logger.addHandler(console_handler) -def _load_model( - model_cache: TTLCache, model_name: str, user_id: str -) -> torch.nn.Module: - if model_name not in MODELS: - raise ValueError( - f"Model {model_name} not found. Available models: {list(MODELS.keys())}" - ) - - # Check cache first - sam = model_cache.get(model_name, None) - if sam: - logger.info( - f"User {user_id} - Loading model '{model_name}' from cache (device={sam.device})..." - ) - else: - # Download model if not in cache - model_url = MODELS[model_name] - logger.info( - f"User {user_id} - Loading model '{model_name}' from {model_url}..." - ) - response = requests.get(model_url) - if response.status_code != 200: - raise RuntimeError(f"Failed to download model from {model_url}") - buffer = io.BytesIO(response.content) - - # Load model state - device = "cuda" if torch.cuda.is_available() else "cpu" - ckpt = torch.load(buffer, map_location=device) - model_type = model_name[:5] - sam = sam_model_registry[model_type]() - sam.load_state_dict(ckpt) - logger.info( - f"User {user_id} - Caching model '{model_name}' (device={device})..." - ) - - # Cache the model / renew the TTL - model_cache[model_name] = sam - - # Create a SAM predictor - sam_predictor = SamPredictor(sam) - return sam_predictor +def hello(context: dict = None) -> str: + return "Welcome to the Interactive Segmentation service!" -def _to_image(input_: np.ndarray) -> np.ndarray: - # we require the input to be uint8 - if input_.dtype != np.dtype("uint8"): - # first normalize the input to [0, 1] - input_ = input_.astype("float32") - input_.min() - input_ = input_ / input_.max() - # then bring to [0, 255] and cast to uint8 - input_ = (input_ * 255).astype("uint8") - if input_.ndim == 2: - image = np.concatenate([input_[..., None]] * 3, axis=-1) - elif input_.ndim == 3 and input_.shape[-1] == 3: - image = input_ - else: - raise ValueError( - f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image." - ) - return image +def ping(context: dict = None) -> str: + return {"status": "ok"} -def _calculate_embedding( - embedding_cache: TTLCache, - sam_predictor: SamPredictor, +def compute_image_embedding( + cache_dir: str, + ray_address: str, model_name: str, image: np.ndarray, - user_id: str, + context: dict = None, ) -> np.ndarray: - # Calculate the embedding if not cached - predictor_dict = embedding_cache.get(user_id, {}) - if predictor_dict and predictor_dict.get("model_name") == model_name: - logger.info( - f"User {user_id} - Loading image embedding from cache (model: '{model_name}')..." - ) - for key, value in predictor_dict.items(): - if key != "model_name": - setattr(sam_predictor, key, value) - else: - logger.info( - f"User {user_id} - Computing image embedding (model: '{model_name}')..." - ) - sam_predictor.set_image(_to_image(image)) - logger.info( - f"User {user_id} - Caching image embedding (model: '{model_name}')..." - ) - predictor_dict = { - "model_name": model_name, - "original_size": sam_predictor.original_size, - "input_size": sam_predictor.input_size, - "features": sam_predictor.features, # embedding - "is_image_set": sam_predictor.is_image_set, - } - # Cache the embedding / renew the TTL - embedding_cache[user_id] = predictor_dict + """ + Compute the embeddings of an image using the specified model. + """ + user_id = context["user"].get("id") + logger.info(f"User '{user_id}' - Computing embedding (model: '{model_name}')...") - return sam_predictor + sam_predictor = load_model_from_ckpt(model_name, cache_dir) + sam_predictor = compute_embedding(sam_predictor, image) + logger.info(f"User '{user_id}' - Embedding computed successfully.") -def _segment_image( - sam_predictor: SamPredictor, - model_name: str, - point_coordinates: Union[list, np.ndarray], - point_labels: Union[list, np.ndarray], - user_id: str, -): - if isinstance(point_coordinates, list): - point_coordinates = np.array(point_coordinates, dtype=np.float32) - if isinstance(point_labels, list): - point_labels = np.array(point_labels, dtype=np.float32) - logger.debug( - f"User {user_id} - point coordinates: {point_coordinates}, {point_labels}" - ) - logger.info(f"User {user_id} - Segmenting image (model: '{model_name}')...") - mask, scores, logits = sam_predictor.predict( - point_coords=point_coordinates[:, ::-1], # SAM has reversed XY conventions - point_labels=point_labels, - multimask_output=False, - ) - logger.debug(f"User {user_id} - predicted mask of shape {mask.shape}") - features = mask_to_features(mask[0]) - return features + return sam_predictor.features.detach().cpu().numpy() -def segment( - model_cache: TTLCache, - embedding_cache: TTLCache, +def compute_mask( + cache_dir: str, + ray_address: str, model_name: str, - image: np.ndarray, - point_coordinates: Union[list, np.ndarray], - point_labels: Union[list, np.ndarray], + embedding: np.ndarray, + image_size: tuple, + point_coords: np.ndarray, + point_labels: np.ndarray, + format: Literal["mask", "kaibu"], context: dict = None, -) -> list: +) -> np.ndarray: + """ + Segment the image using the specified model and the provided point coordinates and labels. + """ user_id = context["user"].get("id") - if not user_id: - logger.info("User ID not found in context.") - return False + logger.info(f"User '{user_id}' - Segmenting image (model: '{model_name}')...") + + if not format in ["mask", "kaibu"]: + raise ValueError("Invalid format. Please choose either 'mask' or 'kaibu'.") # Load the model - sam_predictor = _load_model(model_cache, model_name, user_id) + sam_predictor = load_model_from_ckpt(model_name, cache_dir) - # Calculate the embedding - sam_predictor = _calculate_embedding( - embedding_cache, sam_predictor, model_name, image, user_id - ) + # Set the embedding + sam_predictor.original_size = image_size + sam_predictor.input_size = tuple([sam_predictor.model.image_encoder.img_size] * 2) + sam_predictor.features = torch.as_tensor(embedding, device=sam_predictor.device) + sam_predictor.is_image_set = True # Segment the image - features = _segment_image( - sam_predictor, model_name, point_coordinates, point_labels, user_id - ) + masks = segment_image(sam_predictor, point_coords, point_labels) - return features + if format == "mask": + features = masks + elif format == "kaibu": + features = [mask_to_features(mask) for mask in masks] -def compute_embedding(model_cache: TTLCache, model_name, image, context=None): - user_id = context["user"].get("id") - sam_predictor = _load_model(model_cache, model_name, user_id) - logger.info( - f"User {user_id} - Computing image embedding (model: '{model_name}')..." - ) - sam_predictor.set_image(_to_image(image)) - return { - "model_name": model_name, - "original_size": sam_predictor.original_size, - "input_size": sam_predictor.input_size, - "features": sam_predictor.get_image_embedding().cpu().numpy(), - } - -def clear_cache(embedding_cache: TTLCache, context: dict = None) -> bool: - user_id = context["user"].get("id") - if user_id not in embedding_cache: - return False - else: - logger.info(f"User {user_id} - Resetting embedding cache...") - del embedding_cache[user_id] - return True - - -def hello(context: dict = None) -> str: - return "Welcome to the Interactive Segmentation service!" + logger.info(f"User '{user_id}' - Image segmented successfully.") -def ping(context: dict = None) -> str: - return "pong" + return features -def check_readiness(service_list): - # Check if the service is ready - assert len(service_list) > 0, "Service is not ready" - return {"status": "ok"} -async def check_liveness(colab_client, workspace_name, client_id, service_id): - # Check if the service is alive - sid = f"{workspace_name}/{client_id}:{service_id}" - service = await colab_client.get_service(sid) - alive = await service.ping() == "pong" - assert alive, "Service is not alive" +def test_model(cache_dir: str, ray_address: str, model_name: str, context: dict = None): + """ + Test the segmentation service. + """ + image = np.random.rand(1024, 1024, 3) + embedding = compute_image_embedding(cache_dir, model_name, image, context) + assert embedding return {"status": "ok"} @@ -252,24 +131,6 @@ async def register_service(args: dict) -> None: client_id = colab_client.config["client_id"] client_base_url = f"{args.server_url}/{args.workspace_name}/services/{client_id}" - # Register a probe for the service - service_list = [] - await colab_client.register_probes({ - "readiness": partial(check_readiness, service_list), - "liveness": partial(check_liveness, colab_client, args.workspace_name, client_id, args.service_id) - }) - logger.info(f"Probes registered in workspace: {args.workspace_name}") - logger.info( - f"Test the readiness probe here: {client_base_url}:probes/readiness" - ) - logger.info( - f"Test the liveness probe here: {client_base_url}:probes/liveness" - ) - - # Initialize caches - model_cache = TTLCache(maxsize=len(MODELS), ttl=args.model_timeout) - embedding_cache = TTLCache(maxsize=args.max_num_clients, ttl=args.embedding_timeout) - # Register a new service service_info = await colab_client.register_service( { @@ -283,81 +144,38 @@ async def register_service(args: dict) -> None: # Exposed functions: "hello": hello, "ping": ping, - # **Run segmentation** - # Params: - # - model name - # - image to compute the embeddings on - # - point coordinates (XY format) - # - point labels - # Returns: - # - a list of XY coordinates of the segmented polygon in the format (1, N, 2) - "segment": partial(segment, model_cache, embedding_cache), - # **Compute the embedding of an image** - # Params: - # - model name - # - image to compute the embeddings on - # Returns: - # - a dictionary containing the computed embedding, original size, and input size - "compute_embedding": partial(compute_embedding, model_cache), - # **Clear the embedding cache** - # Returns: - # - True if the embedding was removed successfully - # - False if the user was not found in the cache - "clear_cache": partial(clear_cache, embedding_cache), + "compute_embedding": partial( + compute_image_embedding, args.cache_dir, ray_address=args.ray_address + ), + "compute_mask": partial( + compute_mask, args.cache_dir, ray_address=args.ray_address + ), + "test_model": partial( + test_model, args.cache_dir, ray_address=args.ray_address + ), } ) sid = service_info["id"] logger.info(f"Service registered with ID: {sid}") - logger.info( - f"Test the service here: {client_base_url}:{args.service_id}/hello" - ) - service_list.append(sid) + logger.info(f"Test the service here: {client_base_url}:{args.service_id}/hello") if __name__ == "__main__": - import asyncio - - parser = argparse.ArgumentParser( - description="Register SAM annotation service on BioImageIO Colab workspace." - ) - parser.add_argument( - "--server_url", - default="https://hypha.aicell.io", - help="URL of the Hypha server", - ) - parser.add_argument( - "--workspace_name", default="bioimageio-colab", help="Name of the workspace" + model_name = "vit_b" + cache_dir = "./models" + embedding = compute_image_embedding( + cache_dir=cache_dir, + model_name=model_name, + image=np.random.rand(1024, 1024, 3), + context={"user": {"id": "test"}}, + ) + mask = compute_mask( + cache_dir=cache_dir, + model_name="vit_b", + embedding=embedding, + image_size=(1024, 1024), + point_coords=np.array([[10, 10]]), + point_labels=np.array([1]), + format="kaibu", + context={"user": {"id": "test"}}, ) - parser.add_argument( - "--service_id", - default="microsam", - help="Service ID for registering the service", - ) - parser.add_argument( - "--token", - default=None, - help="Workspace token for connecting to the Hypha server", - ) - parser.add_argument( - "--model_timeout", - type=int, - default=9600, # 3 hours - help="Model cache timeout in seconds", - ) - parser.add_argument( - "--embedding_timeout", - type=int, - default=600, # 10 minutes - help="Embedding cache timeout in seconds", - ) - parser.add_argument( - "--max_num_clients", - type=int, - default=50, - help="Maximum number of clients to cache embeddings for", - ) - args = parser.parse_args() - - loop = asyncio.get_event_loop() - loop.create_task(register_service(args=args)) - loop.run_forever() diff --git a/bioimageio_colab/sam.py b/bioimageio_colab/sam.py new file mode 100644 index 0000000..6af290b --- /dev/null +++ b/bioimageio_colab/sam.py @@ -0,0 +1,104 @@ +import os +from typing import Optional + +import numpy as np +import requests +import torch +from segment_anything import SamPredictor, sam_model_registry + +MODELS = { + "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt", + "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", +} + + +def load_model_from_ckpt(model_name: str, cache_dir: str) -> torch.nn.Module: + model_url = MODELS.get(model_name, None) + + if model_url is None: + raise ValueError( + f"Model {model_name} not found. Available models: {list(MODELS.keys())}" + ) + + # Download model if not in cache + basename = model_url.split("/")[-1] + model_path = os.path.join(cache_dir, basename) + + if not os.path.exists(model_path): + os.makedirs(cache_dir, exist_ok=True) + response = requests.get(model_url) + if response.status_code != 200: + raise RuntimeError(f"Failed to download model from {model_url}") + with open(model_path, "wb") as f: + f.write(response.content) + + # Load model state + device = "cuda" if torch.cuda.is_available() else "cpu" + ckpt = torch.load(model_path, map_location=device) + model_architecture = model_name[:5] + sam = sam_model_registry[model_architecture]() + sam.load_state_dict(ckpt) + + # Create a SAM predictor + sam_predictor = SamPredictor(sam) + sam_predictor.model_architecture = model_architecture + + return sam_predictor + + +def _to_image_format(array: np.ndarray) -> np.ndarray: + # we require the input to be uint8 + if array.dtype != np.dtype("uint8"): + # first normalize the input to [0, 1] + array = array.astype("float32") - array.min() + array = array / array.max() + # then bring to [0, 255] and cast to uint8 + array = (array * 255).astype("uint8") + if array.ndim == 2: + image = np.concatenate([array[..., None]] * 3, axis=-1) + elif array.ndim == 3 and array.shape[-1] == 3: + image = array + else: + raise ValueError( + f"Invalid input image of shape {array.shape}. Expected either 2-channel grayscale or 3-channel RGB image." + ) + return image + + +def compute_embedding( + sam_predictor: SamPredictor, + array: np.ndarray, +) -> np.ndarray: + # Run image encoder to compute the embedding + image = _to_image_format(array) + sam_predictor.set_image(image) + + return sam_predictor + + +def segment_image( + sam_predictor: SamPredictor, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + # box: Optional[np.ndarray] = None, + # mask_input: Optional[np.ndarray] = None, +): + if isinstance(point_coords, list): + point_coords = np.array(point_coords, dtype=np.float32) + if isinstance(point_labels, list): + point_labels = np.array(point_labels, dtype=np.float32) + masks, scores, logits = sam_predictor.predict( + point_coords=point_coords[:, ::-1], # SAM has reversed XY conventions + point_labels=point_labels, + box=None, # Not supported yet + mask_input=None, # Not supported yet + multimask_output=False, + ) + return masks + + +if __name__ == "__main__": + sam_predictor = load_model_from_ckpt("vit_b", "./models") + sam_predictor = compute_embedding(sam_predictor, np.random.rand(256, 256)) + masks = segment_image(sam_predictor, [[10, 10]], [1]) diff --git a/chatbot_extension/data_provider_extension.py b/chatbot_extension/data_provider_extension.py deleted file mode 100644 index bc438df..0000000 --- a/chatbot_extension/data_provider_extension.py +++ /dev/null @@ -1,48 +0,0 @@ -import asyncio -from pydantic import BaseModel, Field -from hypha_rpc import login, connect_to_server -from bioimageio_colab.data_provider import start_server - -class RegisterService(BaseModel): - """Register collaborative annotation service to start a collaborative annotation session. Returns the URL for the annotation server.""" - path2data: str = Field(..., description="Path to data folder from which the images are loaded; example: /mnt/data") - outpath: str = Field(..., description="Path to output folder to which the annotations are saved; example: /mnt/annotations") - - -async def register_service(kwargs): - annotator_url = await start_server(**kwargs) - return f"Annotation server is running at: {annotator_url}" - - -def get_schema(): - return { - "move_stage": RegisterService.schema() - } - -async def main(): - # Define an chatbot extension - microscope_control_extension = { - "_rintf": True, - "id": "annotation-colab-provider", - "config": {"visibility": "public"}, - "type": "bioimageio-chatbot-extension", - "name": "Annotation Colab Data Provider", - "description": "This extension starts a collaborative annotation session. It provides data for remote annotation and then saves the annotations.", - "get_schema": get_schema, - "tools": { - "move_stage": register_service, - } - } - - # Connect to the chat server - server_url = "https://chat.bioimage.io" - token = await login({"server_url": server_url}) - server = await connect_to_server({"server_url": server_url, "token": token}) - # Register the extension service - svc = await server.register_service(microscope_control_extension) - print(f"Extension service registered with id: {svc.id}, you can visit the service at: https://bioimage.io/chat?server={server_url}&extension={svc.id}&assistant=Bridget") - -if __name__ == "__main__": - loop = asyncio.get_event_loop() - loop.create_task(main()) - loop.run_forever() diff --git a/docs/data_providing_service.py b/docs/data_providing_service.py index 89c00ce..e302794 100644 --- a/docs/data_providing_service.py +++ b/docs/data_providing_service.py @@ -29,8 +29,10 @@ def get_random_image(image_folder: str, supported_file_types: Tuple[str]): return (image, file_name.split(".")[0]) -def save_annotation(annotations_folder: str, image_name: str, features, image_shape): - mask = features_to_mask(features, image_shape) +def save_annotation( + annotations_folder: str, image_name: str, features: list, image_shape: tuple +): + mask = features_to_mask(features, image_shape[:2]) n_image_masks = len( [f for f in os.listdir(annotations_folder) if f.startswith(image_name)] ) @@ -56,6 +58,7 @@ def upload_image_to_s3(): """ raise NotImplementedError + async def register_service( server_url: str, token: str, @@ -82,7 +85,7 @@ async def register_service( { "name": name, "description": description, - "id": "data-provider-" + str(int(time.time()*100)), + "id": "data-provider-" + str(int(time.time() * 100)), "type": "annotation-data-provider", "config": { "visibility": "public", # TODO: make protected @@ -99,4 +102,4 @@ async def register_service( "save_annotation": partial(save_annotation, annotations_path), } ) - print(f"Service registered with ID: {svc['id']}") \ No newline at end of file + print(f"Service registered with ID: {svc['id']}") diff --git a/requirements-sam.txt b/requirements-sam.txt index 6af2dcb..23e0c12 100644 --- a/requirements-sam.txt +++ b/requirements-sam.txt @@ -1,5 +1,5 @@ -r "requirements.txt" -torch==2.3.1 -torchvision==0.18.1 +ray==2.33.0 segment_anything==1.0 -cachetools==5.5.0 \ No newline at end of file +torch==2.3.1 +torchvision==0.18.1 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ea0a1f6..2bfe6aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ hypha-rpc==0.20.43 -numpy==1.26.4 -requests==2.31.0 kaibu-utils==0.1.14 +numpy==1.26.4 python-dotenv==1.0.1 +requests==2.31.0 From d0c080f45da1f4f263c756028f52db083706454d Mon Sep 17 00:00:00 2001 From: Nils Mechtel Date: Sat, 14 Dec 2024 20:37:10 +0100 Subject: [PATCH 2/8] update tests --- bioimageio_colab/register_sam_service.py | 50 ++++++++++---- bioimageio_colab/sam.py | 2 +- test/test_image_provider.py | 39 +++++++++++ test/test_model_service.py | 86 +++++++++++++++++------- test/test_sam.py | 20 ++++++ 5 files changed, 159 insertions(+), 38 deletions(-) create mode 100644 test/test_image_provider.py create mode 100644 test/test_sam.py diff --git a/bioimageio_colab/register_sam_service.py b/bioimageio_colab/register_sam_service.py index 7cec084..40d6a20 100644 --- a/bioimageio_colab/register_sam_service.py +++ b/bioimageio_colab/register_sam_service.py @@ -48,11 +48,17 @@ def compute_image_embedding( """ Compute the embeddings of an image using the specified model. """ - user_id = context["user"].get("id") + user_id = context["user"].get("id") if context else "anonymous" logger.info(f"User '{user_id}' - Computing embedding (model: '{model_name}')...") - sam_predictor = load_model_from_ckpt(model_name, cache_dir) - sam_predictor = compute_embedding(sam_predictor, image) + sam_predictor = load_model_from_ckpt( + model_name=model_name, + cache_dir=cache_dir, + ) + sam_predictor = compute_embedding( + sam_predictor=sam_predictor, + array=image, + ) logger.info(f"User '{user_id}' - Embedding computed successfully.") @@ -67,20 +73,23 @@ def compute_mask( image_size: tuple, point_coords: np.ndarray, point_labels: np.ndarray, - format: Literal["mask", "kaibu"], + format: Literal["mask", "kaibu"] = "mask", context: dict = None, ) -> np.ndarray: """ Segment the image using the specified model and the provided point coordinates and labels. """ - user_id = context["user"].get("id") + user_id = context["user"].get("id") if context else "anonymous" logger.info(f"User '{user_id}' - Segmenting image (model: '{model_name}')...") if not format in ["mask", "kaibu"]: raise ValueError("Invalid format. Please choose either 'mask' or 'kaibu'.") # Load the model - sam_predictor = load_model_from_ckpt(model_name, cache_dir) + sam_predictor = load_model_from_ckpt( + model_name=model_name, + cache_dir=cache_dir, + ) # Set the embedding sam_predictor.original_size = image_size @@ -89,7 +98,11 @@ def compute_mask( sam_predictor.is_image_set = True # Segment the image - masks = segment_image(sam_predictor, point_coords, point_labels) + masks = segment_image( + sam_predictor=sam_predictor, + point_coords=point_coords, + point_labels=point_labels, + ) if format == "mask": features = masks @@ -105,9 +118,16 @@ def test_model(cache_dir: str, ray_address: str, model_name: str, context: dict """ Test the segmentation service. """ - image = np.random.rand(1024, 1024, 3) - embedding = compute_image_embedding(cache_dir, model_name, image, context) - assert embedding + image = np.random.rand(1024, 1024) + embedding = compute_image_embedding( + cache_dir=cache_dir, + ray_address=ray_address, + model_name=model_name, + image=image, + context=context, + ) + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (1, 256, 64, 64) return {"status": "ok"} @@ -145,13 +165,15 @@ async def register_service(args: dict) -> None: "hello": hello, "ping": ping, "compute_embedding": partial( - compute_image_embedding, args.cache_dir, ray_address=args.ray_address + compute_image_embedding, + cache_dir=args.cache_dir, + ray_address=args.ray_address, ), "compute_mask": partial( - compute_mask, args.cache_dir, ray_address=args.ray_address + compute_mask, cache_dir=args.cache_dir, ray_address=args.ray_address ), "test_model": partial( - test_model, args.cache_dir, ray_address=args.ray_address + test_model, cache_dir=args.cache_dir, ray_address=args.ray_address ), } ) @@ -166,7 +188,7 @@ async def register_service(args: dict) -> None: embedding = compute_image_embedding( cache_dir=cache_dir, model_name=model_name, - image=np.random.rand(1024, 1024, 3), + image=np.random.rand(1024, 1024), context={"user": {"id": "test"}}, ) mask = compute_mask( diff --git a/bioimageio_colab/sam.py b/bioimageio_colab/sam.py index 6af290b..12708a0 100644 --- a/bioimageio_colab/sam.py +++ b/bioimageio_colab/sam.py @@ -100,5 +100,5 @@ def segment_image( if __name__ == "__main__": sam_predictor = load_model_from_ckpt("vit_b", "./models") - sam_predictor = compute_embedding(sam_predictor, np.random.rand(256, 256)) + sam_predictor = compute_embedding(sam_predictor, np.random.rand(1024, 1024)) masks = segment_image(sam_predictor, [[10, 10]], [1]) diff --git a/test/test_image_provider.py b/test/test_image_provider.py new file mode 100644 index 0000000..084ca3a --- /dev/null +++ b/test/test_image_provider.py @@ -0,0 +1,39 @@ +import os +import sys + +import numpy as np + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from kaibu_utils import mask_to_features + +from docs.data_providing_service import get_random_image, save_annotation + + +def test_load_image(): + supported_file_types = ("tif", "tiff") + image, image_name = get_random_image( + image_folder="./bioimageio_colab/", + supported_file_types=supported_file_types, + ) + assert image is not None + assert isinstance(image, np.ndarray) + assert image.shape == (512, 512, 3) + assert image_name is not None + assert isinstance(image_name, str) + assert image_name == "example_image" + + +# def test_save_annotation(): +# mask = np.zeros((512, 512)) +# mask[10:20, 10:20] = 1 +# features = mask_to_features(mask) + +# save_annotation( +# annotations_folder="./", +# image_name="test_image", +# features=features, # square coordinates +# image_shape=(512, 512) +# ) +# assert os.path.exists("test_image_mask_1.tif") +# os.remove("test_image_mask_1.tif") +# assert True diff --git a/test/test_model_service.py b/test/test_model_service.py index 88843c5..f870836 100644 --- a/test/test_model_service.py +++ b/test/test_model_service.py @@ -1,7 +1,7 @@ -from hypha_rpc.sync import connect_to_server import numpy as np import requests - +from hypha_rpc.sync import connect_to_server +from bioimageio_colab.register_sam_service import compute_image_embedding, compute_mask SERVER_URL = "https://hypha.aicell.io" WORKSPACE_NAME = "bioimageio-colab" @@ -9,35 +9,75 @@ MODEL_NAME = "vit_b" -def test_service_available(): +def test_service_functions(): + cache_dir = "./models" + embedding = compute_image_embedding( + cache_dir=cache_dir, + ray_address=None, + model_name=MODEL_NAME, + image=np.random.rand(1024, 1024), + context={"user": {"id": "test"}}, + ) + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (1, 256, 64, 64) + + polygon_features = compute_mask( + cache_dir=cache_dir, + ray_address=None, + model_name=MODEL_NAME, + embedding=embedding, + image_size=(1024, 1024), + point_coords=np.array([[10, 10]]), + point_labels=np.array([1]), + format="kaibu", + context={"user": {"id": "test"}}, + ) + assert isinstance(polygon_features, list) + assert len(polygon_features) == 1 # Only one point given + + +def test_service_is_running_http_api(): service_url = f"{SERVER_URL}/{WORKSPACE_NAME}/services/{SERVICE_ID}/ping" response = requests.get(service_url) assert response.status_code == 200 - assert response.json() == "pong" + assert response.json() == {"status": "ok"} + -def test_get_service(): +def test_service_python_api(): client = connect_to_server({"server_url": SERVER_URL, "method_timeout": 5}) assert client sid = f"{WORKSPACE_NAME}/{SERVICE_ID}" segment_svc = client.get_service(sid, {"mode": "random"}) assert segment_svc.config.workspace == WORKSPACE_NAME - assert segment_svc.get("segment") - assert segment_svc.get("clear_cache") - - # Test segmentation - image = np.random.rand(256, 256) - features = segment_svc.segment(model_name=MODEL_NAME, image=image, point_coordinates=[[128, 128]], point_labels=[1]) - assert features - - # Test embedding caching - features = segment_svc.segment(model_name=MODEL_NAME, image=image, point_coordinates=[[20, 50]], point_labels=[1]) - features = segment_svc.segment(model_name=MODEL_NAME, image=image, point_coordinates=[[180, 10]], point_labels=[1]) - - # Test embedding computation for running SAM client-side - result = segment_svc.compute_embedding(model_name=MODEL_NAME, image=image) - assert result - embedding = result["features"] - assert embedding.shape == (1, 256, 64, 64) + assert segment_svc.get("hello") + assert segment_svc.get("ping") + assert segment_svc.get("compute_embedding") + assert segment_svc.get("compute_mask") + assert segment_svc.get("test_model") + + # TODO: Uncomment again when hypha-rpc is fixed + # Test embedding computation + # image = np.random.rand(1024, 1024) + # embedding = segment_svc.compute_embedding( + # model_name=MODEL_NAME, + # image=image, + # ) + # assert isinstance(embedding, np.nparray) + # assert embedding.shape == (1, 256, 64, 64) + + # Test mask computation + # polygon_features = segment_svc.compute_mask( + # model_name=MODEL_NAME, + # embedding=embedding, + # image_size=image.shape[:2], + # point_coordinates=[[10, 10]], + # point_labels=[1], + # format="kaibu", + # ) + # assert isinstance(polygon_features, list) + # assert len(polygon_features) == 1 # Only one point given - assert segment_svc.clear_cache() + # Test service test run + result = segment_svc.test_model(model_name="vit_b") + assert result == {"status": "ok"} diff --git a/test/test_sam.py b/test/test_sam.py new file mode 100644 index 0000000..8339620 --- /dev/null +++ b/test/test_sam.py @@ -0,0 +1,20 @@ +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from bioimageio_colab.sam import compute_embedding, load_model_from_ckpt, segment_image +from docs.data_providing_service import get_random_image + + +def test_sam(): + sam_predictor = load_model_from_ckpt(model_name="vit_b", cache_dir="./models/") + assert os.path.exists("./models/sam_vit_b_01ec64.pth") + image, _ = get_random_image( + image_folder="./bioimageio_colab/", + supported_file_types=("tif"), + ) + sam_predictor = compute_embedding(sam_predictor, image) + assert sam_predictor.is_image_set is True + + masks = segment_image(sam_predictor, point_coords=[[80, 80]], point_labels=[1]) + assert all([mask.shape == image.shape[:2] for mask in masks]) From 9eee2af9aad22a746d1ad7bb5a0dbf3a97050375 Mon Sep 17 00:00:00 2001 From: Nils Mechtel Date: Sat, 14 Dec 2024 20:56:26 +0100 Subject: [PATCH 3/8] update docker image --- Dockerfile | 11 ++++------- bioimageio_colab/__main__.py | 2 +- bioimageio_colab/sam.py | 2 +- scripts/start_service.sh | 3 --- test/test_model_service.py | 2 +- test/test_sam.py | 4 ++-- 6 files changed, 9 insertions(+), 15 deletions(-) delete mode 100644 scripts/start_service.sh diff --git a/Dockerfile b/Dockerfile index 19e70f7..ca0dc7a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,19 +31,16 @@ COPY ./requirements-sam.txt /app/requirements-sam.txt RUN pip install -r /app/requirements-sam.txt # Copy the python script to the docker environment -COPY ./bioimageio_colab/register_sam_service.py /app/register_sam_service.py +COPY ./bioimageio_colab /app/bioimageio_colab -# Copy the start service script -COPY ./scripts/start_service.sh /app/start_service.sh +# Create cache directory for models +RUN mkdir -p /app/.model_cache # Change ownership of the application directory to the non-root user RUN chown -R bioimageio_colab:bioimageio_colab /app/ -# Make the start script executable -RUN chmod +x /app/start_service.sh - # Switch to the non-root user USER bioimageio_colab # Use the start script as the entrypoint and forward arguments -ENTRYPOINT ["python", "register_sam_service.py"] +ENTRYPOINT ["python", "-m", "bioimageio_colab"] diff --git a/bioimageio_colab/__main__.py b/bioimageio_colab/__main__.py index 5d8d4e6..c0bd7e1 100644 --- a/bioimageio_colab/__main__.py +++ b/bioimageio_colab/__main__.py @@ -27,7 +27,7 @@ ) parser.add_argument( "--cache_dir", - default="./models", + default="./.model_cache", help="Directory for caching the models", ) parser.add_argument( diff --git a/bioimageio_colab/sam.py b/bioimageio_colab/sam.py index 12708a0..03644ea 100644 --- a/bioimageio_colab/sam.py +++ b/bioimageio_colab/sam.py @@ -99,6 +99,6 @@ def segment_image( if __name__ == "__main__": - sam_predictor = load_model_from_ckpt("vit_b", "./models") + sam_predictor = load_model_from_ckpt("vit_b", "./model_cache") sam_predictor = compute_embedding(sam_predictor, np.random.rand(1024, 1024)) masks = segment_image(sam_predictor, [[10, 10]], [1]) diff --git a/scripts/start_service.sh b/scripts/start_service.sh deleted file mode 100644 index 8e45ddd..0000000 --- a/scripts/start_service.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -# Pass all arguments to the Python script -python register_sam_service.py "$@" diff --git a/test/test_model_service.py b/test/test_model_service.py index f870836..8e4d463 100644 --- a/test/test_model_service.py +++ b/test/test_model_service.py @@ -10,7 +10,7 @@ def test_service_functions(): - cache_dir = "./models" + cache_dir = "./model_cache" embedding = compute_image_embedding( cache_dir=cache_dir, ray_address=None, diff --git a/test/test_sam.py b/test/test_sam.py index 8339620..e8048c8 100644 --- a/test/test_sam.py +++ b/test/test_sam.py @@ -7,8 +7,8 @@ def test_sam(): - sam_predictor = load_model_from_ckpt(model_name="vit_b", cache_dir="./models/") - assert os.path.exists("./models/sam_vit_b_01ec64.pth") + sam_predictor = load_model_from_ckpt(model_name="vit_b", cache_dir="./model_cache/") + assert os.path.exists("./model_cache/sam_vit_b_01ec64.pth") image, _ = get_random_image( image_folder="./bioimageio_colab/", supported_file_types=("tif"), From 4f4b9edc60fcf3069640c838b5a2800605e79e46 Mon Sep 17 00:00:00 2001 From: Nils Mechtel Date: Sat, 14 Dec 2024 20:56:36 +0100 Subject: [PATCH 4/8] log errors --- bioimageio_colab/register_sam_service.py | 97 +++++++++++++----------- 1 file changed, 52 insertions(+), 45 deletions(-) diff --git a/bioimageio_colab/register_sam_service.py b/bioimageio_colab/register_sam_service.py index 40d6a20..f639a37 100644 --- a/bioimageio_colab/register_sam_service.py +++ b/bioimageio_colab/register_sam_service.py @@ -48,21 +48,25 @@ def compute_image_embedding( """ Compute the embeddings of an image using the specified model. """ - user_id = context["user"].get("id") if context else "anonymous" - logger.info(f"User '{user_id}' - Computing embedding (model: '{model_name}')...") + try: + user_id = context["user"].get("id") if context else "anonymous" + logger.info(f"User '{user_id}' - Computing embedding (model: '{model_name}')...") - sam_predictor = load_model_from_ckpt( - model_name=model_name, - cache_dir=cache_dir, - ) - sam_predictor = compute_embedding( - sam_predictor=sam_predictor, - array=image, - ) + sam_predictor = load_model_from_ckpt( + model_name=model_name, + cache_dir=cache_dir, + ) + sam_predictor = compute_embedding( + sam_predictor=sam_predictor, + array=image, + ) - logger.info(f"User '{user_id}' - Embedding computed successfully.") + logger.info(f"User '{user_id}' - Embedding computed successfully.") - return sam_predictor.features.detach().cpu().numpy() + return sam_predictor.features.detach().cpu().numpy() + except Exception as e: + logger.error(f"User '{user_id}' - Error computing embedding: {e}") + raise e def compute_mask( @@ -79,39 +83,43 @@ def compute_mask( """ Segment the image using the specified model and the provided point coordinates and labels. """ - user_id = context["user"].get("id") if context else "anonymous" - logger.info(f"User '{user_id}' - Segmenting image (model: '{model_name}')...") + try: + user_id = context["user"].get("id") if context else "anonymous" + logger.info(f"User '{user_id}' - Segmenting image (model: '{model_name}')...") - if not format in ["mask", "kaibu"]: - raise ValueError("Invalid format. Please choose either 'mask' or 'kaibu'.") + if not format in ["mask", "kaibu"]: + raise ValueError("Invalid format. Please choose either 'mask' or 'kaibu'.") - # Load the model - sam_predictor = load_model_from_ckpt( - model_name=model_name, - cache_dir=cache_dir, - ) + # Load the model + sam_predictor = load_model_from_ckpt( + model_name=model_name, + cache_dir=cache_dir, + ) - # Set the embedding - sam_predictor.original_size = image_size - sam_predictor.input_size = tuple([sam_predictor.model.image_encoder.img_size] * 2) - sam_predictor.features = torch.as_tensor(embedding, device=sam_predictor.device) - sam_predictor.is_image_set = True - - # Segment the image - masks = segment_image( - sam_predictor=sam_predictor, - point_coords=point_coords, - point_labels=point_labels, - ) + # Set the embedding + sam_predictor.original_size = image_size + sam_predictor.input_size = tuple([sam_predictor.model.image_encoder.img_size] * 2) + sam_predictor.features = torch.as_tensor(embedding, device=sam_predictor.device) + sam_predictor.is_image_set = True + + # Segment the image + masks = segment_image( + sam_predictor=sam_predictor, + point_coords=point_coords, + point_labels=point_labels + ) - if format == "mask": - features = masks - elif format == "kaibu": - features = [mask_to_features(mask) for mask in masks] + if format == "mask": + features = masks + elif format == "kaibu": + features = [mask_to_features(mask) for mask in masks] - logger.info(f"User '{user_id}' - Image segmented successfully.") + logger.info(f"User '{user_id}' - Image segmented successfully.") - return features + return features + except Exception as e: + logger.error(f"User '{user_id}' - Error segmenting image: {e}") + raise e def test_model(cache_dir: str, ray_address: str, model_name: str, context: dict = None): @@ -152,6 +160,7 @@ async def register_service(args: dict) -> None: client_base_url = f"{args.server_url}/{args.workspace_name}/services/{client_id}" # Register a new service + cache_dir = os.path.abspath(args.cache_dir) service_info = await colab_client.register_service( { "name": "Interactive Segmentation", @@ -165,15 +174,13 @@ async def register_service(args: dict) -> None: "hello": hello, "ping": ping, "compute_embedding": partial( - compute_image_embedding, - cache_dir=args.cache_dir, - ray_address=args.ray_address, + compute_image_embedding, cache_dir=cache_dir, ray_address=args.ray_address ), "compute_mask": partial( - compute_mask, cache_dir=args.cache_dir, ray_address=args.ray_address + compute_mask, cache_dir=cache_dir, ray_address=args.ray_address ), "test_model": partial( - test_model, cache_dir=args.cache_dir, ray_address=args.ray_address + test_model, cache_dir=cache_dir, ray_address=args.ray_address ), } ) @@ -184,7 +191,7 @@ async def register_service(args: dict) -> None: if __name__ == "__main__": model_name = "vit_b" - cache_dir = "./models" + cache_dir = "./model_cache" embedding = compute_image_embedding( cache_dir=cache_dir, model_name=model_name, From 90948fe5312b4213f32ad45cbfd072e6a5632b75 Mon Sep 17 00:00:00 2001 From: Nils Mechtel Date: Sat, 14 Dec 2024 20:57:01 +0100 Subject: [PATCH 5/8] bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8add3c0..eaf6f46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"] [project] name = "bioimageio-colab" -version = "0.1.5" +version = "0.1.6" readme = "README.md" description = "Collaborative image annotation and model training with human in the loop." dependencies = [ From 2bd5ceded3cb1f3bb62b06cb1c8a4583de7cfd3e Mon Sep 17 00:00:00 2001 From: Nils Mechtel Date: Sat, 14 Dec 2024 22:04:39 +0100 Subject: [PATCH 6/8] include remote ray execution --- bioimageio_colab/register_sam_service.py | 73 +++++++++++++++++++----- test/test_model_service.py | 2 - 2 files changed, 60 insertions(+), 15 deletions(-) diff --git a/bioimageio_colab/register_sam_service.py b/bioimageio_colab/register_sam_service.py index f639a37..8c929ad 100644 --- a/bioimageio_colab/register_sam_service.py +++ b/bioimageio_colab/register_sam_service.py @@ -4,6 +4,7 @@ from typing import Literal import numpy as np +import ray import torch from dotenv import find_dotenv, load_dotenv from hypha_rpc import connect_to_server @@ -30,6 +31,15 @@ logger.addHandler(console_handler) +def parse_requirements(file_path): + with open(file_path, "r") as file: + lines = file.readlines() + # Filter and clean package names (skip comments and empty lines) + skip_lines = ("#", "-r ", "ray") + packages = [line.strip() for line in lines if line.strip() and not line.startswith(skip_lines)] + return packages + + def hello(context: dict = None) -> str: return "Welcome to the Interactive Segmentation service!" @@ -40,7 +50,6 @@ def ping(context: dict = None) -> str: def compute_image_embedding( cache_dir: str, - ray_address: str, model_name: str, image: np.ndarray, context: dict = None, @@ -67,11 +76,16 @@ def compute_image_embedding( except Exception as e: logger.error(f"User '{user_id}' - Error computing embedding: {e}") raise e + +@ray.remote +def compute_image_embedding_ray(kwargs: dict): + from bioimageio_colab.register_sam_service import compute_image_embedding + + return compute_image_embedding(**kwargs) def compute_mask( cache_dir: str, - ray_address: str, model_name: str, embedding: np.ndarray, image_size: tuple, @@ -120,16 +134,21 @@ def compute_mask( except Exception as e: logger.error(f"User '{user_id}' - Error segmenting image: {e}") raise e + +@ray.remote +def compute_mask_ray(kwargs: dict): + from bioimageio_colab.register_sam_service import compute_mask + + return compute_mask(**kwargs) -def test_model(cache_dir: str, ray_address: str, model_name: str, context: dict = None): +def test_model(cache_dir: str, model_name: str, context: dict = None): """ Test the segmentation service. """ image = np.random.rand(1024, 1024) embedding = compute_image_embedding( cache_dir=cache_dir, - ray_address=ray_address, model_name=model_name, image=image, context=context, @@ -139,6 +158,13 @@ def test_model(cache_dir: str, ray_address: str, model_name: str, context: dict return {"status": "ok"} +@ray.remote +def test_model_ray(kwargs: dict): + from bioimageio_colab.register_sam_service import test_model + + return test_model(**kwargs) + + async def register_service(args: dict) -> None: """ Register the SAM annotation service on the BioImageIO Colab workspace. @@ -159,6 +185,33 @@ async def register_service(args: dict) -> None: client_id = colab_client.config["client_id"] client_base_url = f"{args.server_url}/{args.workspace_name}/services/{client_id}" + if args.ray_address: + # Create runtime environment + base_requirements = parse_requirements("../requirements.txt") + sam_requirements = parse_requirements("../requirements-sam.txt") + + runtime_env = { + "pip": { + "pip": base_requirements + sam_requirements, + }, + "py_modules": ["../bioimageio_colab"] + } + + ray.init(runtime_env=runtime_env, address=args.ray_address) + + def compute_embedding_function(kwargs: dict): + return ray.get(compute_image_embedding_ray.remote(kwargs)) + + def compute_mask_function(kwargs: dict): + return ray.get(compute_mask_ray.remote(kwargs)) + + def test_model_function(kwargs: dict): + return ray.get(test_model_ray.remote(kwargs)) + else: + compute_embedding_function = compute_image_embedding + compute_mask_function = compute_mask + test_model_function = test_model + # Register a new service cache_dir = os.path.abspath(args.cache_dir) service_info = await colab_client.register_service( @@ -173,15 +226,9 @@ async def register_service(args: dict) -> None: # Exposed functions: "hello": hello, "ping": ping, - "compute_embedding": partial( - compute_image_embedding, cache_dir=cache_dir, ray_address=args.ray_address - ), - "compute_mask": partial( - compute_mask, cache_dir=cache_dir, ray_address=args.ray_address - ), - "test_model": partial( - test_model, cache_dir=cache_dir, ray_address=args.ray_address - ), + "compute_embedding": partial(compute_embedding_function, cache_dir=args.cache_dir), + "compute_mask": partial(compute_mask_function, cache_dir=args.cache_dir), + "test_model": partial(test_model_function, cache_dir=args.cache_dir), } ) sid = service_info["id"] diff --git a/test/test_model_service.py b/test/test_model_service.py index 8e4d463..bd548e7 100644 --- a/test/test_model_service.py +++ b/test/test_model_service.py @@ -13,7 +13,6 @@ def test_service_functions(): cache_dir = "./model_cache" embedding = compute_image_embedding( cache_dir=cache_dir, - ray_address=None, model_name=MODEL_NAME, image=np.random.rand(1024, 1024), context={"user": {"id": "test"}}, @@ -23,7 +22,6 @@ def test_service_functions(): polygon_features = compute_mask( cache_dir=cache_dir, - ray_address=None, model_name=MODEL_NAME, embedding=embedding, image_size=(1024, 1024), From 2dd8dcea0229111dbb01735d4c3774cf699fc4ea Mon Sep 17 00:00:00 2001 From: Nils Mechtel Date: Sat, 14 Dec 2024 22:50:28 +0100 Subject: [PATCH 7/8] fix parsed kwargs --- bioimageio_colab/register_sam_service.py | 43 ++++++++++++------------ 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/bioimageio_colab/register_sam_service.py b/bioimageio_colab/register_sam_service.py index 8c929ad..a3614cd 100644 --- a/bioimageio_colab/register_sam_service.py +++ b/bioimageio_colab/register_sam_service.py @@ -31,7 +31,7 @@ logger.addHandler(console_handler) -def parse_requirements(file_path): +def parse_requirements(file_path) -> list: with open(file_path, "r") as file: lines = file.readlines() # Filter and clean package names (skip comments and empty lines) @@ -78,7 +78,7 @@ def compute_image_embedding( raise e @ray.remote -def compute_image_embedding_ray(kwargs: dict): +def compute_image_embedding_ray(kwargs: dict) -> np.ndarray: from bioimageio_colab.register_sam_service import compute_image_embedding return compute_image_embedding(**kwargs) @@ -136,13 +136,13 @@ def compute_mask( raise e @ray.remote -def compute_mask_ray(kwargs: dict): +def compute_mask_ray(kwargs: dict) -> np.ndarray: from bioimageio_colab.register_sam_service import compute_mask return compute_mask(**kwargs) -def test_model(cache_dir: str, model_name: str, context: dict = None): +def test_model(cache_dir: str, model_name: str, context: dict = None) -> dict: """ Test the segmentation service. """ @@ -159,7 +159,7 @@ def test_model(cache_dir: str, model_name: str, context: dict = None): @ray.remote -def test_model_ray(kwargs: dict): +def test_model_ray(kwargs: dict) -> dict: from bioimageio_colab.register_sam_service import test_model return test_model(**kwargs) @@ -184,36 +184,37 @@ async def register_service(args: dict) -> None: ) client_id = colab_client.config["client_id"] client_base_url = f"{args.server_url}/{args.workspace_name}/services/{client_id}" + cache_dir = os.path.abspath(args.cache_dir) if args.ray_address: # Create runtime environment base_requirements = parse_requirements("../requirements.txt") sam_requirements = parse_requirements("../requirements-sam.txt") - runtime_env = { - "pip": { - "pip": base_requirements + sam_requirements, - }, + "pip": base_requirements + sam_requirements, "py_modules": ["../bioimageio_colab"] } + # Connect to Ray ray.init(runtime_env=runtime_env, address=args.ray_address) - def compute_embedding_function(kwargs: dict): + def compute_embedding_function(**kwargs: dict): + kwargs["cache_dir"] = cache_dir return ray.get(compute_image_embedding_ray.remote(kwargs)) - - def compute_mask_function(kwargs: dict): + + def compute_mask_function(**kwargs: dict): + kwargs["cache_dir"] = cache_dir return ray.get(compute_mask_ray.remote(kwargs)) - - def test_model_function(kwargs: dict): + + def test_model_function(**kwargs: dict): + kwargs["cache_dir"] = cache_dir return ray.get(test_model_ray.remote(kwargs)) else: - compute_embedding_function = compute_image_embedding - compute_mask_function = compute_mask - test_model_function = test_model + compute_embedding_function = partial(compute_image_embedding, cache_dir=cache_dir) + compute_mask_function = partial(compute_mask, cache_dir=cache_dir) + test_model_function = partial(test_model, cache_dir=cache_dir) # Register a new service - cache_dir = os.path.abspath(args.cache_dir) service_info = await colab_client.register_service( { "name": "Interactive Segmentation", @@ -226,9 +227,9 @@ def test_model_function(kwargs: dict): # Exposed functions: "hello": hello, "ping": ping, - "compute_embedding": partial(compute_embedding_function, cache_dir=args.cache_dir), - "compute_mask": partial(compute_mask_function, cache_dir=args.cache_dir), - "test_model": partial(test_model_function, cache_dir=args.cache_dir), + "compute_embedding": compute_embedding_function, + "compute_mask": compute_mask_function, + "test_model": test_model_function, } ) sid = service_info["id"] From a1bffafe40ce56a9b8ee3d9f2a51a73daa16dfcf Mon Sep 17 00:00:00 2001 From: Nils Mechtel Date: Sat, 14 Dec 2024 22:54:57 +0100 Subject: [PATCH 8/8] reformat --- bioimageio_colab/register_sam_service.py | 29 +++++++++++++++++------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/bioimageio_colab/register_sam_service.py b/bioimageio_colab/register_sam_service.py index a3614cd..8b12c8d 100644 --- a/bioimageio_colab/register_sam_service.py +++ b/bioimageio_colab/register_sam_service.py @@ -36,7 +36,11 @@ def parse_requirements(file_path) -> list: lines = file.readlines() # Filter and clean package names (skip comments and empty lines) skip_lines = ("#", "-r ", "ray") - packages = [line.strip() for line in lines if line.strip() and not line.startswith(skip_lines)] + packages = [ + line.strip() + for line in lines + if line.strip() and not line.startswith(skip_lines) + ] return packages @@ -59,7 +63,9 @@ def compute_image_embedding( """ try: user_id = context["user"].get("id") if context else "anonymous" - logger.info(f"User '{user_id}' - Computing embedding (model: '{model_name}')...") + logger.info( + f"User '{user_id}' - Computing embedding (model: '{model_name}')..." + ) sam_predictor = load_model_from_ckpt( model_name=model_name, @@ -76,7 +82,8 @@ def compute_image_embedding( except Exception as e: logger.error(f"User '{user_id}' - Error computing embedding: {e}") raise e - + + @ray.remote def compute_image_embedding_ray(kwargs: dict) -> np.ndarray: from bioimageio_colab.register_sam_service import compute_image_embedding @@ -112,7 +119,9 @@ def compute_mask( # Set the embedding sam_predictor.original_size = image_size - sam_predictor.input_size = tuple([sam_predictor.model.image_encoder.img_size] * 2) + sam_predictor.input_size = tuple( + [sam_predictor.model.image_encoder.img_size] * 2 + ) sam_predictor.features = torch.as_tensor(embedding, device=sam_predictor.device) sam_predictor.is_image_set = True @@ -120,7 +129,7 @@ def compute_mask( masks = segment_image( sam_predictor=sam_predictor, point_coords=point_coords, - point_labels=point_labels + point_labels=point_labels, ) if format == "mask": @@ -134,7 +143,8 @@ def compute_mask( except Exception as e: logger.error(f"User '{user_id}' - Error segmenting image: {e}") raise e - + + @ray.remote def compute_mask_ray(kwargs: dict) -> np.ndarray: from bioimageio_colab.register_sam_service import compute_mask @@ -192,7 +202,7 @@ async def register_service(args: dict) -> None: sam_requirements = parse_requirements("../requirements-sam.txt") runtime_env = { "pip": base_requirements + sam_requirements, - "py_modules": ["../bioimageio_colab"] + "py_modules": ["../bioimageio_colab"], } # Connect to Ray @@ -209,8 +219,11 @@ def compute_mask_function(**kwargs: dict): def test_model_function(**kwargs: dict): kwargs["cache_dir"] = cache_dir return ray.get(test_model_ray.remote(kwargs)) + else: - compute_embedding_function = partial(compute_image_embedding, cache_dir=cache_dir) + compute_embedding_function = partial( + compute_image_embedding, cache_dir=cache_dir + ) compute_mask_function = partial(compute_mask, cache_dir=cache_dir) test_model_function = partial(test_model, cache_dir=cache_dir)