Skip to content

Commit a1bffaf

Browse files
committed
reformat
1 parent 2dd8dce commit a1bffaf

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

bioimageio_colab/register_sam_service.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ def parse_requirements(file_path) -> list:
3636
lines = file.readlines()
3737
# Filter and clean package names (skip comments and empty lines)
3838
skip_lines = ("#", "-r ", "ray")
39-
packages = [line.strip() for line in lines if line.strip() and not line.startswith(skip_lines)]
39+
packages = [
40+
line.strip()
41+
for line in lines
42+
if line.strip() and not line.startswith(skip_lines)
43+
]
4044
return packages
4145

4246

@@ -59,7 +63,9 @@ def compute_image_embedding(
5963
"""
6064
try:
6165
user_id = context["user"].get("id") if context else "anonymous"
62-
logger.info(f"User '{user_id}' - Computing embedding (model: '{model_name}')...")
66+
logger.info(
67+
f"User '{user_id}' - Computing embedding (model: '{model_name}')..."
68+
)
6369

6470
sam_predictor = load_model_from_ckpt(
6571
model_name=model_name,
@@ -76,7 +82,8 @@ def compute_image_embedding(
7682
except Exception as e:
7783
logger.error(f"User '{user_id}' - Error computing embedding: {e}")
7884
raise e
79-
85+
86+
8087
@ray.remote
8188
def compute_image_embedding_ray(kwargs: dict) -> np.ndarray:
8289
from bioimageio_colab.register_sam_service import compute_image_embedding
@@ -112,15 +119,17 @@ def compute_mask(
112119

113120
# Set the embedding
114121
sam_predictor.original_size = image_size
115-
sam_predictor.input_size = tuple([sam_predictor.model.image_encoder.img_size] * 2)
122+
sam_predictor.input_size = tuple(
123+
[sam_predictor.model.image_encoder.img_size] * 2
124+
)
116125
sam_predictor.features = torch.as_tensor(embedding, device=sam_predictor.device)
117126
sam_predictor.is_image_set = True
118127

119128
# Segment the image
120129
masks = segment_image(
121130
sam_predictor=sam_predictor,
122131
point_coords=point_coords,
123-
point_labels=point_labels
132+
point_labels=point_labels,
124133
)
125134

126135
if format == "mask":
@@ -134,7 +143,8 @@ def compute_mask(
134143
except Exception as e:
135144
logger.error(f"User '{user_id}' - Error segmenting image: {e}")
136145
raise e
137-
146+
147+
138148
@ray.remote
139149
def compute_mask_ray(kwargs: dict) -> np.ndarray:
140150
from bioimageio_colab.register_sam_service import compute_mask
@@ -192,7 +202,7 @@ async def register_service(args: dict) -> None:
192202
sam_requirements = parse_requirements("../requirements-sam.txt")
193203
runtime_env = {
194204
"pip": base_requirements + sam_requirements,
195-
"py_modules": ["../bioimageio_colab"]
205+
"py_modules": ["../bioimageio_colab"],
196206
}
197207

198208
# Connect to Ray
@@ -209,8 +219,11 @@ def compute_mask_function(**kwargs: dict):
209219
def test_model_function(**kwargs: dict):
210220
kwargs["cache_dir"] = cache_dir
211221
return ray.get(test_model_ray.remote(kwargs))
222+
212223
else:
213-
compute_embedding_function = partial(compute_image_embedding, cache_dir=cache_dir)
224+
compute_embedding_function = partial(
225+
compute_image_embedding, cache_dir=cache_dir
226+
)
214227
compute_mask_function = partial(compute_mask, cache_dir=cache_dir)
215228
test_model_function = partial(test_model, cache_dir=cache_dir)
216229

0 commit comments

Comments
 (0)