Visium HD rasterize bins labels (#811)
* rasterize bins labels

* rasterize bins labels

* fix mypy

* minor fixes in docstrings, mypy, exceptions, todos

* fix tests/pre-commit (merge with main)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test

* fix docs

* add tests

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* moved _get_uint_dtype() to models._utils; fix docs

---------

Co-authored-by: LucaMarconato <[email protected]>
Co-authored-by: Luca Marconato <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Jan 2, 2025
1 parent 7bce868 commit cf0445b
Showing 7 changed files with 329 additions and 67 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning][].
### Major

- Added attributes at the SpatialData object level (`.attrs`)
- `rasterize_bins()` can now produce a labels element #811 @ArneDefauw

## [0.2.6] - 2024-11-26

1 change: 1 addition & 0 deletions docs/api.md
@@ -37,6 +37,7 @@ Operations on `SpatialData` objects.
transform
rasterize
rasterize_bins
rasterize_bins_link_table_to_labels
to_circles
to_polygons
aggregate
3 changes: 2 additions & 1 deletion src/spatialdata/__init__.py
@@ -27,6 +27,7 @@
"concatenate",
"rasterize",
"rasterize_bins",
"rasterize_bins_link_table_to_labels",
"to_circles",
"to_polygons",
"transform",
@@ -61,7 +62,7 @@
from spatialdata._core.operations.aggregate import aggregate
from spatialdata._core.operations.map import map_raster, relabel_sequential
from spatialdata._core.operations.rasterize import rasterize
from spatialdata._core.operations.rasterize_bins import rasterize_bins
from spatialdata._core.operations.rasterize_bins import rasterize_bins, rasterize_bins_link_table_to_labels
from spatialdata._core.operations.transform import transform
from spatialdata._core.operations.vectorize import to_circles, to_polygons
from spatialdata._core.query._utils import get_bounding_box_corners
181 changes: 140 additions & 41 deletions src/spatialdata/_core/operations/rasterize_bins.py
@@ -5,6 +5,7 @@
import dask.array as da
import numpy as np
import pandas as pd
from anndata import AnnData
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from numpy.random import default_rng
@@ -14,12 +15,16 @@
from xarray import DataArray

from spatialdata._core.query.relational_query import get_values
from spatialdata._logging import logger
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, get_table_keys
from spatialdata.models import Image2DModel, Labels2DModel, get_table_keys
from spatialdata.models._utils import _get_uint_dtype
from spatialdata.transformations import Affine, Sequence, get_transformation

RNG = default_rng(0)

__all__ = ["rasterize_bins", "rasterize_bins_link_table_to_labels"]


if TYPE_CHECKING:
from spatialdata import SpatialData
@@ -32,6 +37,7 @@ def rasterize_bins(
col_key: str,
row_key: str,
value_key: str | list[str] | None = None,
return_region_as_labels: bool = False,
) -> DataArray:
"""
Rasterizes grid-like binned shapes/points annotated by a table (e.g. Visium HD data).
@@ -51,6 +57,14 @@
value_key
The key(s) (obs columns/var names) in the table that will be used to rasterize the bins.
If `None`, all the var names will be used, and the returned object will be lazily constructed.
Ignored if `return_region_as_labels` is `True`.
return_region_as_labels
If `False`, this function returns an `xarray.DataArray` of shape `(c, y, x)`, where the size of `c` equals the
number of key(s) specified in `value_key`, or the number of var names in `table_name` if `value_key` is `None`.
If `True`, it returns labels of shape `(y, x)`, where each bin of the `bins` element is represented as a pixel.
By default the table will not be set to annotate the new rasterized labels; this can be achieved with the helper
function `spatialdata.rasterize_bins_link_table_to_labels()`.
Returns
-------
@@ -73,24 +87,93 @@
"""
element = sdata[bins]
table = sdata.tables[table_name]
if not isinstance(element, GeoDataFrame | DaskDataFrame):
raise ValueError("The bins should be a GeoDataFrame or a DaskDataFrame.")
if not isinstance(element, GeoDataFrame | DaskDataFrame | DataArray):
raise ValueError("The bins should be a GeoDataFrame, a DaskDataFrame or a DataArray.")
if isinstance(element, DataArray):
if "c" in element.dims:
raise ValueError(
"If bins is a DataArray, it should hold labels; found a image element instead, with"
f" 'c': {element.dims}."
)
if not np.issubdtype(element.dtype, np.integer):
raise ValueError(f"If bins is a DataArray, it should hold integers. Found dtype {element.dtype}.")

_, region_key, instance_key = get_table_keys(table)
if not table.obs[region_key].dtype == "category":
raise ValueError(f"Please convert `table.obs['{region_key}']` to a category series to improve performances")
unique_regions = table.obs[region_key].cat.categories
if len(unique_regions) > 1 or unique_regions[0] != bins:
if len(unique_regions) > 1:
raise ValueError(
f"Found multiple regions annotated by the table: {', '.join(list(unique_regions))}, "
"currently only tables annotating a single region are supported. Please open a feature request if you are "
"interested in the general case."
)
if unique_regions[0] != bins:
raise ValueError("The table should be associated with the specified bins.")

if isinstance(element, DataArray) and return_region_as_labels:
raise ValueError(
"The table should be associated with the specified bins. "
f"Found multiple regions annotated by the table: {', '.join(list(unique_regions))}."
f"bins is already a labels layer that annotates the table '{table_name}'. "
"Consider setting 'return_region_as_labels' to 'False' to create a lazy spatial image."
)

min_row, min_col = table.obs[row_key].min(), table.obs[col_key].min()
n_rows, n_cols = table.obs[row_key].max() - min_row + 1, table.obs[col_key].max() - min_col + 1
y = (table.obs[row_key] - min_row).values
x = (table.obs[col_key] - min_col).values

if isinstance(element, DataArray):
transformations = get_transformation(element, get_all=True)
assert isinstance(transformations, dict)
else:
# get the transformation
if table.n_obs < 6:
raise ValueError("At least 6 bins are needed to estimate the transformation.")

random_indices = RNG.choice(table.n_obs, min(20, table.n_obs), replace=True)
location_ids = table.obs[instance_key].iloc[random_indices].values
sub_df, sub_table = element.loc[location_ids], table[random_indices]

src = np.stack([sub_table.obs[col_key] - min_col, sub_table.obs[row_key] - min_row], axis=1)
if isinstance(sub_df, GeoDataFrame):
if isinstance(sub_df.iloc[0].geometry, Point):
sub_x = sub_df.geometry.x.values
sub_y = sub_df.geometry.y.values
else:
assert isinstance(sub_df.iloc[0].geometry, Polygon | MultiPolygon)
sub_x = sub_df.centroid.x
sub_y = sub_df.centroid.y
else:
assert isinstance(sub_df, DaskDataFrame)
sub_x = sub_df.x.compute().values
sub_y = sub_df.y.compute().values
dst = np.stack([sub_x, sub_y], axis=1)

to_bins = Sequence(
[
Affine(
estimate_transform(ttype="affine", src=src, dst=dst).params,
input_axes=("x", "y"),
output_axes=("x", "y"),
)
]
)
bins_transformations = get_transformation(element, get_all=True)

assert isinstance(bins_transformations, dict)

transformations = {cs: to_bins.compose_with(t) for cs, t in bins_transformations.items()}

if return_region_as_labels:
new_instance_key = _get_relabeled_column_name(instance_key)
table.obs[new_instance_key] = _relabel_labels(table=table, instance_key=instance_key)
dtype = table.obs[new_instance_key].dtype
labels_element = np.zeros((n_rows, n_cols), dtype=dtype)
# make a labels layer that can visually represent the cells
labels_element[y, x] = table.obs[new_instance_key].values.T

return Labels2DModel.parse(data=labels_element, dims=("y", "x"), transformations=transformations)

keys = ([value_key] if isinstance(value_key, str) else value_key) if value_key is not None else table.var_names

if (value_key is None or any(key in table.var_names for key in keys)) and not isinstance(
@@ -115,7 +198,6 @@ def rasterize_bins(
shape = (n_rows, n_cols)

def channel_rasterization(block_id: tuple[int, int, int] | None) -> ArrayLike:

image: ArrayLike = np.zeros((1, *shape), dtype=dtype)

if block_id is None:
@@ -148,42 +230,59 @@ def channel_rasterization(block_id: tuple[int, int, int] | None) -> ArrayLike:
else:
image[i, y, x] = table.X[:, key_index]

# get the transformation
if table.n_obs < 6:
raise ValueError("At least 6 bins are needed to estimate the transformation.")
return Image2DModel.parse(
data=image,
dims=("c", "y", "x"),
transformations=transformations,
c_coords=keys,
)

random_indices = RNG.choice(table.n_obs, min(20, table.n_obs), replace=True)
location_ids = table.obs[instance_key].iloc[random_indices].values
sub_df, sub_table = element.loc[location_ids], table[random_indices]

src = np.stack([sub_table.obs[col_key] - min_col, sub_table.obs[row_key] - min_row], axis=1)
if isinstance(sub_df, GeoDataFrame):
if isinstance(sub_df.iloc[0].geometry, Point):
sub_x = sub_df.geometry.x.values
sub_y = sub_df.geometry.y.values
else:
assert isinstance(sub_df.iloc[0].geometry, Polygon | MultiPolygon)
sub_x = sub_df.centroid.x
sub_y = sub_df.centroid.y
else:
assert isinstance(sub_df, DaskDataFrame)
sub_x = sub_df.x.compute().values
sub_y = sub_df.y.compute().values
dst = np.stack([sub_x, sub_y], axis=1)

to_bins = Sequence(
[
Affine(
estimate_transform(ttype="affine", src=src, dst=dst).params,
input_axes=("x", "y"),
output_axes=("x", "y"),
)
]
)
bins_transformations = get_transformation(element, get_all=True)
def _get_relabeled_column_name(column_name: str) -> str:
return f"relabeled_{column_name}"

assert isinstance(bins_transformations, dict)

transformations = {cs: to_bins.compose_with(t) for cs, t in bins_transformations.items()}
def _relabel_labels(table: AnnData, instance_key: str) -> pd.Series:
labels_values_count = len(table.obs[instance_key].unique())

return Image2DModel.parse(image, transformations=transformations, c_coords=keys, dims=("c", "y", "x"))
is_not_numeric = not np.issubdtype(table.obs[instance_key].dtype, np.number)
zero_in_instance_key = 0 in table.obs[instance_key].values
has_gaps = not is_not_numeric and labels_values_count != table.obs[instance_key].max() + int(zero_in_instance_key)

relabeling_is_needed = is_not_numeric or zero_in_instance_key or has_gaps
if relabeling_is_needed:
logger.info(
f"The instance_key column in 'table.obs' ('table.obs[{instance_key}]') will be relabeled to ensure"
" a numeric data type, with a continuous range and without including the value 0 (which is reserved "
"for the background). The new labels will be stored in a new column named "
f"{_get_relabeled_column_name(instance_key)!r}."
)

relabeled_instance_key_column = table.obs[instance_key].astype("category").cat.codes + int(zero_in_instance_key)
# use only the allowed dtypes that pass our model validations, in particular no uint8
dtype = _get_uint_dtype(value=relabeled_instance_key_column.max())
return relabeled_instance_key_column.astype(dtype)


def rasterize_bins_link_table_to_labels(sdata: SpatialData, table_name: str, rasterized_labels_name: str) -> None:
"""
Change the annotation target of the table to the rasterized labels.
This function should be called after rasterizing the bins (i.e. calling `rasterize_bins()` with
`return_region_as_labels=True`) and after adding the rasterized labels to the SpatialData object.
Parameters
----------
sdata
The spatial data object containing the rasterized labels.
table_name
The name of the table to be annotated.
rasterized_labels_name
The name of the rasterized labels in the spatial data object.
"""
_, region_key, instance_key = get_table_keys(sdata[table_name])
sdata[table_name].obs[region_key] = rasterized_labels_name
relabeled_instance_key = _get_relabeled_column_name(instance_key)
sdata.set_table_annotates_spatialelement(
table_name=table_name, region=rasterized_labels_name, region_key=region_key, instance_key=relabeled_instance_key
)
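
As a hedged usage sketch (not part of the diff): assuming a `SpatialData` object `sdata` containing a Visium HD bins element and its annotating table, the new option and helper would be used roughly as follows. The element, table, and column names below are hypothetical placeholders.

```python
import spatialdata as sd

# Rasterize the bins into a labels element: one pixel per bin, shape (y, x).
# "bins_016um", "table_016um", "array_col" and "array_row" are placeholder names.
labels = sd.rasterize_bins(
    sdata,
    bins="bins_016um",
    table_name="table_016um",
    col_key="array_col",
    row_key="array_row",
    return_region_as_labels=True,  # new flag added in this PR
)

# Add the rasterized labels to the SpatialData object, then re-point the table
# so that it annotates the new labels element instead of the original bins.
sdata["rasterized_bins"] = labels
sd.rasterize_bins_link_table_to_labels(
    sdata,
    table_name="table_016um",
    rasterized_labels_name="rasterized_bins",
)
```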
20 changes: 19 additions & 1 deletion src/spatialdata/models/_utils.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import warnings
from functools import singledispatch
from typing import TYPE_CHECKING, Any, TypeAlias
@@ -367,7 +369,7 @@ def force_2d(gdf: GeoDataFrame) -> None:
gdf.geometry = new_shapes


def get_raster_model_from_data_dims(dims: tuple[str, ...]) -> type["RasterSchema"]:
def get_raster_model_from_data_dims(dims: tuple[str, ...]) -> type[RasterSchema]:
"""
Get the raster model from the dimensions of the data.
@@ -435,3 +437,19 @@ def set_channel_names(element: DataArray | DataTree, channel_names: str | list[s
raise TypeError("Element model does not support setting channel names, no `c` dimension found.")

return element


def _get_uint_dtype(value: int) -> str:
max_uint64 = np.iinfo(np.uint64).max
max_uint32 = np.iinfo(np.uint32).max
max_uint16 = np.iinfo(np.uint16).max

if max_uint16 >= value:
dtype = "uint16"
elif max_uint32 >= value:
dtype = "uint32"
elif max_uint64 >= value:
dtype = "uint64"
else:
raise ValueError(f"Maximum cell number is {value}. Values higher than {max_uint64} are not supported.")
return dtype
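
For reference, a short sketch of what the new `_get_uint_dtype()` helper returns for a few illustrative values (the values are chosen for the example, not taken from the diff):

```python
import numpy as np

from spatialdata.models._utils import _get_uint_dtype

# The helper picks the smallest unsigned integer dtype that can hold `value`,
# starting at uint16 because uint8 is not among the dtypes accepted by the
# label model validations.
assert _get_uint_dtype(300) == "uint16"  # fits in uint16 (max 65_535)
assert _get_uint_dtype(70_000) == "uint32"  # exceeds uint16, fits in uint32
assert _get_uint_dtype(np.iinfo("uint32").max + 1) == "uint64"
# Values above np.iinfo("uint64").max raise a ValueError.
```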
8 changes: 5 additions & 3 deletions src/spatialdata/transformations/_utils.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from functools import singledispatch
from typing import TYPE_CHECKING, Any, Optional, Union

@@ -8,7 +9,6 @@
from geopandas import GeoDataFrame
from xarray import DataArray, Dataset, DataTree

from spatialdata._logging import logger
from spatialdata._types import ArrayLike

if TYPE_CHECKING:
@@ -253,10 +253,12 @@ def scale_radii(radii: ArrayLike, affine: Affine, axes: tuple[str, ...]) -> Arra
modules = np.absolute(eigenvalues)
if not np.allclose(modules, modules[0]):
scale_factor = np.mean(modules)
logger.warning(
warnings.warn(
"The vector part of the transformation matrix is not isotropic, the radius will be scaled by the average "
f"of the modules of eigenvalues of the affine transformation matrix.\nmatrix={matrix}\n"
f"eigenvalues={eigenvalues}\nscale_factor={scale_factor}"
f"eigenvalues={eigenvalues}\nscale_factor={scale_factor}",
UserWarning,
stacklevel=2,
)
else:
scale_factor = modules[0]
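
Since the anisotropy message in `scale_radii()` is now emitted through the standard `warnings` machinery rather than the logger, callers can filter or escalate it with stock Python tooling. A minimal sketch (the filter pattern is just the start of the message shown in the diff):

```python
import warnings

# Silence the anisotropic-radius warning emitted by scale_radii().
warnings.filterwarnings(
    "ignore",
    message="The vector part of the transformation matrix is not isotropic",
    category=UserWarning,
)

# Or, e.g. in a test suite, turn it into an error to catch unexpected anisotropy:
# with warnings.catch_warnings():
#     warnings.simplefilter("error", UserWarning)
#     ...  # code that ends up calling scale_radii()
```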