Skip to content

Commit fd4be38

Browse files
committed
Adds GenerateHeatmap transform
Adds a `GenerateHeatmap` transform to generate heatmaps from point data. This transform creates heatmaps from point data, validating that the dtype is a floating-point type. Signed-off-by: sewon.jeon <[email protected]>
1 parent 1b5888b commit fd4be38

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

monai/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@
293293
AsDiscrete,
294294
DistanceTransformEDT,
295295
FillHoles,
296+
GenerateHeatmap,
296297
Invert,
297298
KeepLargestConnectedComponent,
298299
LabelFilter,
@@ -319,6 +320,9 @@
319320
FillHolesD,
320321
FillHolesd,
321322
FillHolesDict,
323+
GenerateHeatmapd,
324+
GenerateHeatmapD,
325+
GenerateHeatmapDict,
322326
InvertD,
323327
Invertd,
324328
InvertDict,

monai/transforms/post/array.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,9 @@ def __init__(
799799
self.normalize = normalize
800800
self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
801801
self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
802+
# Validate that dtype is floating-point for meaningful Gaussian values
803+
if self.torch_dtype not in (torch.float16, torch.float32, torch.float64, torch.bfloat16):
804+
raise ValueError(f"dtype must be a floating-point type, got {self.torch_dtype}")
802805
self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape)
803806

804807
def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor:

0 commit comments

Comments
 (0)