Skip to content

Commit

Permalink
specify GPU allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsmechtel committed Dec 16, 2024
1 parent 63e5c58 commit 17ed877
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions bioimageio_colab/register_sam_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def compute_image_embedding(
raise e


@ray.remote
@ray.remote(num_gpus=1)
def compute_image_embedding_ray(kwargs: dict) -> np.ndarray:
from bioimageio_colab.register_sam_service import compute_image_embedding

Expand Down Expand Up @@ -145,7 +145,7 @@ def compute_mask(
raise e


@ray.remote
@ray.remote(num_gpus=1)
def compute_mask_ray(kwargs: dict) -> np.ndarray:
from bioimageio_colab.register_sam_service import compute_mask

Expand All @@ -158,7 +158,7 @@ def test_model(cache_dir: str, model_name: str, context: dict = None) -> dict:
"""
user_id = context["user"].get("id") if context else "anonymous"
logger.info(f"User '{user_id}' - Test run for model '{model_name}'...")

image = np.random.rand(1024, 1024)
embedding = compute_image_embedding(
cache_dir=cache_dir,
Expand All @@ -171,7 +171,7 @@ def test_model(cache_dir: str, model_name: str, context: dict = None) -> dict:
return {"status": "ok"}


@ray.remote
@ray.remote(num_gpus=1)
def test_model_ray(kwargs: dict) -> dict:
from bioimageio_colab.register_sam_service import test_model

Expand Down Expand Up @@ -214,7 +214,11 @@ async def register_service(args: dict) -> None:
}

# Connect to Ray
ray.init(runtime_env=runtime_env, address=args.ray_address)
ray.init(
address=args.ray_address,
num_gpus=1,
runtime_env=runtime_env,
)

def compute_embedding_function(**kwargs: dict):
kwargs["cache_dir"] = cache_dir
Expand Down

0 comments on commit 17ed877

Please sign in to comment.