diff --git a/bioimageio_colab/register_sam_service.py b/bioimageio_colab/register_sam_service.py index 4eb6b35..412277a 100644 --- a/bioimageio_colab/register_sam_service.py +++ b/bioimageio_colab/register_sam_service.py @@ -84,7 +84,7 @@ def compute_image_embedding( raise e -@ray.remote +@ray.remote(num_gpus=1) def compute_image_embedding_ray(kwargs: dict) -> np.ndarray: from bioimageio_colab.register_sam_service import compute_image_embedding @@ -145,7 +145,7 @@ def compute_mask( raise e -@ray.remote +@ray.remote(num_gpus=1) def compute_mask_ray(kwargs: dict) -> np.ndarray: from bioimageio_colab.register_sam_service import compute_mask @@ -158,7 +158,7 @@ def test_model(cache_dir: str, model_name: str, context: dict = None) -> dict: """ user_id = context["user"].get("id") if context else "anonymous" logger.info(f"User '{user_id}' - Test run for model '{model_name}'...") - + image = np.random.rand(1024, 1024) embedding = compute_image_embedding( cache_dir=cache_dir, @@ -171,7 +171,7 @@ def test_model(cache_dir: str, model_name: str, context: dict = None) -> dict: return {"status": "ok"} -@ray.remote +@ray.remote(num_gpus=1) def test_model_ray(kwargs: dict) -> dict: from bioimageio_colab.register_sam_service import test_model @@ -214,7 +214,11 @@ async def register_service(args: dict) -> None: } # Connect to Ray - ray.init(runtime_env=runtime_env, address=args.ray_address) + ray.init( + address=args.ray_address, + num_gpus=1, + runtime_env=runtime_env, + ) def compute_embedding_function(**kwargs: dict): kwargs["cache_dir"] = cache_dir