Skip to content

Commit 153fe4c

Browse files
authored
New API: PathGraph dataclass (#246)
The Skeleton class is a bit of a hodgepodge of data, generated data, things that should be properties, manually cached properties, and functions that should not be part of the class at all. Thanks to the helpful prodding of @kevinyamauchi, this PR attempts to simplify the concepts and data structures in skan into a simple class that can be created without an image — since the path graph, not the image, is the core of the computational abilities of skan. This is still a work in progress but should already be useful: if you have a graph as a `scipy.sparse.csr_array` (note: `csr_array`, not the deprecated `csr_matrix` used by skan so far) generated through your own means, you can make a "Skeleton" as follows: ```python from skan.csr import PathGraph, Skeleton, summarize g = PathGraph.from_graph( node_coordinates=coordinates, # (n, ndim) NumPy array graph=graph, # scipy.sparse.csr_array ) s = Skeleton.from_path_graph(g) summary = summarize(s, separator='_') ``` Still to do in this PR: - allow making PathGraphs from Skeletons - allow `summarize` to take in PathGraphs directly In the future, we might want to deprecate Skeleton altogether, but I'm happy to do this over a long time. @kevinyamauchi, I'm curious what you think about `paths` being a data attribute, even though it is generated. I think it's expensive enough to compute that we want it to be data and serializable. But it kinda breaks the dataclass paradigm a little bit. @DragaDoncila, you should be able to pull down this branch and use it on your networkx tracking graphs after exporting them to csr arrays (I'm pretty sure that's built-in to nx).
1 parent 0e2da68 commit 153fe4c

File tree

3 files changed

+211
-11
lines changed

3 files changed

+211
-11
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# skan v0.13.0
2+
3+
This is a minor step forward from v0.12.x which adds a new API:
4+
`skan.csr.PathGraph` is a more abstract version of a Skeleton, which only
5+
needs a pixel adjacency matrix to work, rather than a full skeleton image.
6+
The motivations for PathGraph are numerous:
7+
8+
- a simpler dataclass object that doesn't require complex instantiation
9+
logic. See e.g. Glyph's [Stop Writing `__init__`
10+
Methods](https://blog.glyph.im/2025/04/stop-writing-init-methods.html)
11+
- making it easier to compute the pixel adjacency matrix separately, for
12+
example [using dask when the images don't fit in
13+
memory](https://blog.dask.org/2021/05/07/skeleton-analysis), and having
14+
an API for which you can provide this matrix (rather than having to
15+
modify a Skeleton instance in-place).
16+
- allowing more flexible use cases, for example to use skan to measure
17+
tracks, as in
18+
[live-image-tracking-tools/traccuracy#251](https://github.com/live-image-tracking-tools/traccuracy/pull/251).
19+
20+
Due to some urgent need to use this code in the wild, this release doesn't
21+
provide any documentation examples. Indeed, the new class may see some changes
22+
in upcoming releases based on user feedback. See the discussion in
23+
[jni/skan#246](https://github.com/jni/skan/pull/246) for details. Look for
24+
further refinement of this idea in the 0.13.x releases!
25+
26+
## New features
27+
28+
- [#246](https://github.com/jni/skan/pull/246): New API: PathGraph dataclass
29+

src/skan/csr.py

Lines changed: 149 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from __future__ import annotations
22

3+
from dataclasses import dataclass
4+
from functools import cached_property
5+
36
import networkx as nx
47
import numpy as np
58
import pandas as pd
9+
import scipy
610
from scipy import sparse, ndimage as ndi
711
from scipy.sparse import csgraph
812
from scipy.spatial import distance_matrix
@@ -203,8 +207,8 @@ def csr_to_nbgraph(csr, node_props=None):
203207
node_props = np.broadcast_to(1., csr.shape[0])
204208
node_props.flags.writeable = True
205209
return NBGraph(
206-
csr.indptr,
207-
csr.indices,
210+
csr.indptr.astype(np.int32, copy=False),
211+
csr.indices.astype(np.int32, copy=False),
208212
csr.data,
209213
np.array(csr.shape, dtype=np.int32),
210214
node_props.astype(np.float64),
@@ -410,8 +414,9 @@ def _build_skeleton_path_graph(graph):
410414
degrees = np.diff(graph.indptr)
411415
visited_data = np.zeros(graph.data.shape, dtype=bool)
412416
visited = NBGraphBool(
413-
graph.indptr, graph.indices, visited_data, graph.shape,
414-
np.broadcast_to(1.0, graph.shape[0])
417+
graph.indptr.astype(np.int32, copy=False),
418+
graph.indices.astype(np.int32, copy=False), visited_data,
419+
graph.shape, np.broadcast_to(1.0, graph.shape[0])
415420
)
416421
endpoints = (degrees != 2)
417422
endpoint_degrees = degrees[endpoints]
@@ -440,6 +445,122 @@ def _build_skeleton_path_graph(graph):
440445
return paths
441446

442447

448+
@dataclass
449+
class PathGraph:
450+
"""Generalization of a skeleton.
451+
452+
A morphological skeleton is a collection of single pixel wide paths in
453+
a binary (or, more generally, quantitative) image. Skan was created to
454+
make measurements of those paths in image data. It turns out, though,
455+
that the neighborhood topology of those paths (a graph containing long
456+
strings of nodes of degree 2) is more generally useful, and indeed makes
457+
sense without a source image or without representing pixels. (One example:
458+
tracklets in cell tracking data.)
459+
460+
In the text below, we use the following notation:
461+
462+
- N: the number of points in the pixel skeleton,
463+
- ndim: the dimensionality of the skeleton
464+
- P: the number of paths in the skeleton (also the number of links in the
465+
junction graph).
466+
- J: the number of junction nodes
467+
- Sd: the sum of the degrees of all the junction nodes
468+
- [Nt], [Np], Nr, Nc: the dimensions of the source image
469+
"""
470+
adj: scipy.sparse.csr_array # pixel/node-id neighbor adjacency matrix
471+
node_coordinates: np.ndarray | None
472+
node_values: np.ndarray | None
473+
paths: scipy.sparse.csr_array # paths[i, j] = 1 iff coord j is in path i
474+
spacing: float | tuple[float, ...] = 1 # spatial scale between coordinates
475+
476+
@classmethod
477+
def from_graph(cls, *, adj, node_coordinates, node_values=None, spacing=1):
478+
"""Build a PathGraph from an adjacency matrix and node coordinates.
479+
480+
Parameters
481+
----------
482+
adj : scipy.sparse.csr_array
483+
An adjacency matrix where adj[i, j] is nonzero iff there is an
484+
edge between node i and node j.
485+
node_coordinates : np.ndarray, shape (N, ndim)
486+
The coordinates of the nodes. node_coordinates[i] is the
487+
coordinate of node i. The indices of these coordinates must match
488+
the indices of adj.
489+
node_values : np.ndarray, shape (N,)
490+
Values of the nodes. Could be image intensity, height, or some
491+
other quantity of which you want to compute statistics along the
492+
path.
493+
spacing : float or tuple of float, shape (ndim,)
494+
The pixel/voxel spacing along each axis coordinate.
495+
"""
496+
nbgraph = csr_to_nbgraph(adj, node_values)
497+
paths = _build_skeleton_path_graph(nbgraph)
498+
return cls(adj, node_coordinates, node_values, paths, spacing)
499+
500+
@classmethod
501+
def from_image(cls, skeleton_image, *, spacing=1, value_is_height=False):
502+
"""Build a PathGraph from a skeleton image.
503+
504+
This is just a convenience meant to mirror Skeleton.__init__.
505+
"""
506+
graph, coords = skeleton_to_csgraph(
507+
skeleton_image,
508+
spacing=spacing,
509+
value_is_height=value_is_height,
510+
)
511+
values = _extract_values(skeleton_image, coords)
512+
return cls.from_graph(
513+
adj=graph,
514+
node_coordinates=np.transpose(coords),
515+
node_values=values,
516+
spacing=spacing,
517+
)
518+
519+
@cached_property
520+
def nbgraph(self):
521+
return csr_to_nbgraph(self.adj, self.node_values)
522+
523+
@cached_property
524+
def distances(self):
525+
"""The path distances.
526+
527+
Returns
528+
-------
529+
distances : np.ndarray of float, shape (P,)
530+
distances[i] contains the distance of path i.
531+
"""
532+
distances = np.empty(self.n_paths, dtype=float)
533+
_compute_distances(
534+
self.nbgraph, self.paths.indptr, self.paths.indices, distances
535+
)
536+
return distances
537+
538+
@property
539+
def n_paths(self):
540+
return self.paths.shape[0]
541+
542+
@cached_property
543+
def degrees(self):
544+
"""The degree (number of neighbors) of each node/pixel.
545+
546+
Returns
547+
-------
548+
degrees : np.ndarray of int, shape (N,)
549+
"""
550+
return np.diff(self.adj.indptr)
551+
552+
553+
def _extract_values(image, coords):
554+
if image.dtype == np.bool_:
555+
return None
556+
values = image[coords]
557+
output_dtype = (
558+
np.float64 if np.issubdtype(image.dtype, np.integer) else
559+
image.dtype
560+
)
561+
return values.astype(output_dtype, copy=False)
562+
563+
443564
class Skeleton:
444565
"""Object to group together all the properties of a skeleton.
445566
@@ -522,12 +643,7 @@ def __init__(
522643
spacing=spacing,
523644
value_is_height=value_is_height,
524645
)
525-
if np.issubdtype(skeleton_image.dtype, np.floating):
526-
self.pixel_values = skeleton_image[coords]
527-
elif np.issubdtype(skeleton_image.dtype, np.integer):
528-
self.pixel_values = skeleton_image.astype(np.float64)[coords]
529-
else:
530-
self.pixel_values = None
646+
self.pixel_values = _extract_values(skeleton_image, coords)
531647
self.graph = graph
532648
self.nbgraph = csr_to_nbgraph(graph, self.pixel_values)
533649
self.coordinates = np.transpose(coords)
@@ -549,6 +665,26 @@ def __init__(
549665
self.skeleton_image = skeleton_image
550666
self.source_image = source_image
551667

668+
@classmethod
669+
def from_path_graph(cls, pg: PathGraph):
670+
dtype = bool if pg.node_values is None else pg.node_values.dtype
671+
ndim = pg.node_coordinates.shape[1]
672+
dummy_image = np.zeros((4,) * ndim, dtype=dtype)
673+
dummy_image[([1, 2],) * ndim] = 1 # make a single diagonal branch
674+
obj = cls(dummy_image, spacing=pg.spacing, keep_images=False)
675+
obj.pixel_values = pg.node_values
676+
obj.graph = pg.adj
677+
obj.nbgraph = pg.nbgraph
678+
obj.coordinates = pg.node_coordinates
679+
obj.paths = pg.paths
680+
obj.n_paths = pg.n_paths
681+
obj.distances = pg.distances
682+
obj.skeleton_shape = np.max(obj.coordinates, axis=0) + 1
683+
obj.skeleton_dtype = dtype
684+
obj.degrees = pg.degrees
685+
obj.spacing = pg.spacing
686+
return obj
687+
552688
def path(self, index):
553689
"""Return the pixel indices of path number `index`.
554690
@@ -721,7 +857,7 @@ def __array__(self, dtype=None):
721857

722858

723859
def summarize(
724-
skel: Skeleton,
860+
skel: Skeleton | PathGraph,
725861
*,
726862
value_is_height: bool = False,
727863
find_main_branch: bool = False,
@@ -754,6 +890,8 @@ def summarize(
754890
A summary of the branches including branch length, mean branch value,
755891
branch euclidean distance, etc.
756892
"""
893+
if isinstance(skel, PathGraph):
894+
skel = Skeleton.from_path_graph(skel)
757895
if separator is None:
758896
warnings.warn(
759897
"separator in column name will change to _ in version 0.13; "

src/skan/test/test_csr.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,39 @@ def test_skeletonlabel():
384384
assert stats['mean-pixel-value'].max() > 1
385385

386386

387+
@pytest.mark.parametrize(('np_skeleton', 'spacing', 'dtype'),
388+
product(
389+
[
390+
tinycycle, tinyline, skeleton0,
391+
skeleton1, skeleton2, skeleton3d
392+
],
393+
[1.0, 2.0, (5.0, 2.5, 2.5)],
394+
[bool, np.float32, np.float64],
395+
))
396+
def test_pathgraph_skeleton_equiv(np_skeleton, spacing, dtype):
397+
if dtype is not bool:
398+
np_skeleton = (
399+
np_skeleton
400+
* np.random.random(np_skeleton.shape).astype(dtype)
401+
)
402+
if isinstance(spacing, tuple):
403+
spacing = spacing[:np_skeleton.ndim] # truncate spacing to image ndim
404+
s = csr.Skeleton(np_skeleton, spacing=spacing)
405+
p = csr.PathGraph.from_image(np_skeleton, spacing=spacing)
406+
p2 = csr.PathGraph.from_graph(
407+
adj=p.adj,
408+
node_coordinates=p.node_coordinates,
409+
node_values=p.node_values,
410+
spacing=spacing,
411+
)
412+
ss = csr.summarize(s)
413+
sp = csr.summarize(p)
414+
sp2 = csr.summarize(p2)
415+
416+
np.testing.assert_allclose(ss.to_numpy(), sp.to_numpy())
417+
np.testing.assert_allclose(sp.to_numpy(), sp2.to_numpy())
418+
419+
387420
@pytest.mark.parametrize(
388421
('np_skeleton', 'summary', 'nodes', 'edges'),
389422
[

0 commit comments

Comments
 (0)