Skip to content

ENH: remove sft._data usage part 1 - tractogram coloring scripts + more #1105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
aa132ac
ENH: remote sft._data usage
AntoineTheb Dec 12, 2024
e0808c4
Merge branch 'master' into atheb/sft_data
AntoineTheb Dec 12, 2024
c8b0cf6
ENH: more _data cleanup
AntoineTheb Dec 12, 2024
4fbfdc9
Merge remote-tracking branch 'upstream/master' into atheb/sft_data
AntoineTheb Dec 17, 2024
9d14d3f
ENH: better dps/dpp handling
AntoineTheb Dec 26, 2024
1132b4a
Merge remote-tracking branch 'upstream/master' into atheb/sft_data
AntoineTheb Dec 26, 2024
4bf9d51
FIX: pep8
AntoineTheb Dec 26, 2024
7806740
Merge remote-tracking branch 'upstream/master' into atheb/sft_data
AntoineTheb Jan 8, 2025
9b961de
ENH: more tests to appease codecov gods + docstring
AntoineTheb Jan 9, 2025
894bb34
Merge remote-tracking branch 'upstream/master' into atheb/sft_data
AntoineTheb Jan 9, 2025
0a01e6f
FIX: pep8 spaces comments
AntoineTheb Jan 9, 2025
e232d65
FIX: comments
AntoineTheb Jan 9, 2025
d39b2b1
Merge branch 'master' into atheb/sft_data
AntoineTheb Jan 13, 2025
1715a25
Merge branch 'master' into atheb/sft_data
AntoineTheb Jan 17, 2025
26edbd2
Merge branch 'master' into atheb/sft_data
AntoineTheb Jan 23, 2025
bc0a45b
Merge remote-tracking branch 'upstream/master' into atheb/sft_data
AntoineTheb Feb 22, 2025
07dd5da
FIX: comments + cut
AntoineTheb Feb 22, 2025
37e5717
FIX: tests + per-point
AntoineTheb Feb 25, 2025
8a7270c
ENH: simplify cut invalid
AntoineTheb Feb 25, 2025
96459aa
Merge remote-tracking branch 'upstream/master' into atheb/sft_data
AntoineTheb Apr 10, 2025
4dadee5
Merge remote-tracking branch 'upstream/master' into atheb/sft_data
AntoineTheb Apr 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 80 additions & 63 deletions scilpy/tractograms/dps_and_dpp_management.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,101 @@
# -*- coding: utf-8 -*-
import numpy as np

from scilpy.viz.color import clip_and_normalize_data_for_cmap
from nibabel.streamlines import ArraySequence


def add_data_as_color_dpp(sft, cmap, data, clip_outliers=False, min_range=None,
max_range=None, min_cmap=None, max_cmap=None,
log=False, LUT=None):
def get_data_as_arraysequence(data, ref_sft, per_point=False):
""" Get data in the same shape as a reference StatefulTractogram's
streamlines, so it can be used to set data_per_point or
data_per_streamline. The data may represent one value per streamline or one
value per point. The function will return an ArraySequence with the same
shape as the streamlines.

Parameters
----------
data: np.ndarray
The data to convert to ArraySequence.
ref_sft: StatefulTractogram
The reference StatefulTractogram containing the streamlines.
per_point: bool, optional
Return one value per point if True, one value per streamline otherwise.

Returns
-------
data_as_arraysequence: ArraySequence
The data as an ArraySequence.
"""
Normalizes data between 0 and 1 for an easier management with colormaps.
The real lower bound and upperbound are returned.
# Check if data has the right shape, either one value per streamline or one
# value per point.
if data.shape[0] == ref_sft._get_streamline_count():
# Two consective if statements to handle both 1D and 2D arrays
# and turn them into lists of lists of lists.
# Check if the data is a vector or a scalar.
if len(data.shape) == 1:
data = data[:, None]
# ArraySequence expects a list of lists of lists, so we need to add
# an extra dimension.
if len(data.shape) == 2:
data = data[:, None, :]

# Repeat the data for each point in the streamline.
if per_point:
data = [
[data[i]]*len(s) for i, s in enumerate(ref_sft.streamlines)]

data_as_arraysequence = ArraySequence(data)

elif data.shape[0] == ref_sft._get_point_count():
# Split the data into a list of arrays, one per streamline.
# np.split takes the indices at which to split the array, so use
# np.cumsum to get the indices of the end of each streamline.
data_split = np.split(
data, np.cumsum(ref_sft.streamlines._lengths)[:-1])
# Create an ArraySequence from the list of arrays.
data_as_arraysequence = ArraySequence(data_split)
else:
raise ValueError("Data has the wrong shape. Expecting either one value"
" per streamline ({}) or one per point ({}) but got "
"{}."
.format(len(ref_sft), len(ref_sft.streamlines._data),
data.shape[0]))
return data_as_arraysequence


Data can be clipped to (min_range, max_range) before normalization.
Alternatively, data can be kept as is, but the colormap be fixed to
(min_cmap, max_cmap).
def add_data_as_color_dpp(sft, color):
"""
Ensures the color data is in the right shape and adds it to the
data_per_point of the StatefulTractogram. The color data must have one
color per point. The function will return the StatefulTractogram with the
color data added.

Parameters
----------
sft: StatefulTractogram
The tractogram
cmap: plt colormap
The colormap. Ex, see scilpy.viz.utils.get_colormap().
data: np.ndarray or list[list] or list[np.ndarray]
The data to convert to color. Expecting one value per point to add as
dpp. If instead data has one value per streamline, setting the same
color to all points of the streamline (as dpp).
Either a vector numpy array (all streamlines concatenated), or a list
of arrays per streamline.
clip_outliers: bool
See description of the following parameters in
clip_and_normalize_data_for_cmap.
min_range: float
Data values below min_range will be clipped.
max_range: float
Data values above max_range will be clipped.
min_cmap: float
Minimum value of the colormap. Most useful when min_range and max_range
are not set; to fix the colormap range without modifying the data.
max_cmap: float
Maximum value of the colormap. Idem.
log: bool
If True, apply a logarithmic scale to the data.
LUT: np.ndarray
If set, replaces the data values by the Look-Up Table values. In order,
the first value of the LUT is set everywhere where data==1, etc.
color: ArraySequence
The color data.

Returns
-------
sft: StatefulTractogram
The tractogram, with dpp 'color' added.
lbound: float
The lower bound of the associated colormap.
ubound: float
The upper bound of the associated colormap.
"""
# If data is a list of lists, merge.
if isinstance(data[0], list) or isinstance(data[0], np.ndarray):
data = np.hstack(data)

values, lbound, ubound = clip_and_normalize_data_for_cmap(
data, clip_outliers, min_range, max_range,
min_cmap, max_cmap, log, LUT)

# Important: values are in float after clip_and_normalize.
color = np.asarray(cmap(values)[:, 0:3]) * 255
if len(color) == len(sft):
tmp = [np.tile([color[i][0], color[i][1], color[i][2]],
(len(sft.streamlines[i]), 1))
for i in range(len(sft.streamlines))]
sft.data_per_point['color'] = tmp
elif len(color) == len(sft.streamlines._data):
sft.data_per_point['color'] = sft.streamlines
sft.data_per_point['color']._data = color
else:
raise ValueError("Error in the code... Colors do not have the right "
"shape. Expecting either one color per streamline "
"({}) or one per point ({}) but got {}."
.format(len(sft), len(sft.streamlines._data),
len(color)))
return sft, lbound, ubound
if color.total_nb_rows != sft._get_point_count():
raise ValueError("Colors do not have the right shape. Expecting one "
"color per point ({}) but got {}.".format(
sft._get_point_count(),
color.total_nb_rows))

# Check if the color data is in the right shape (tuple of 3 values).
if (type(color.common_shape) is not tuple or
len(color.common_shape) < 1 or
color.common_shape[-1] != 3):
raise ValueError("Colors do not have the right shape. Expecting RGB "
"colors but got shape {}.".format(color.common_shape))

sft.data_per_point['color'] = color
return sft


def convert_dps_to_dpp(sft, keys, overwrite=False):
Expand Down
46 changes: 22 additions & 24 deletions scilpy/tractograms/streamline_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,35 +257,33 @@ def cut_invalid_streamlines(sft, epsilon=0.001):
cutting_counter = 0
for ind in range(len(sft.streamlines)):
if ind in indices_to_cut:
# This streamline was detected as invalid
pos = 0
cur_seg = [0, 0]
best_seg = [0, 0]
while pos < len(sft.streamlines[ind]):
point = sft.streamlines[ind][pos]
cur_seg[1] = pos + 1
if (point < epsilon).any() or \
(point >= sft.dimensions - epsilon).any():
cur_seg = [pos+1, pos+1]
elif cur_seg[1] - cur_seg[0] > best_seg[1] - best_seg[0]:
# We found a longer good segment.
best_seg = cur_seg.copy()

# Ready to check next point.
pos += 1

# Appending the longest segment to the list of streamlines
if not best_seg == [0, 0]:
new_streamlines.append(
sft.streamlines[ind][best_seg[0]:best_seg[1]])
cutting_counter += 1
in_vol = np.logical_and(
sft.streamlines[ind] >= epsilon,
sft.streamlines[ind] < sft.dimensions - epsilon).all(axis=1)

# Get segments in the streamline that are within the volume using
# ndi.label
blobs, _ = ndi.label(in_vol)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@frheault @arnaudbore I have greatly simplified and sped up the invalid-cutting process. Let me know how you feel about it.


# Get the largest blob
largest_blob = np.argmax(np.bincount(blobs.ravel())[1:]) + 1

# Get the indices of the points in the largest blob
ind_in_vol = np.where(blobs == largest_blob)[0]
# If there are points in the volume
if len(ind_in_vol):
# Get the streamline segment that is within the volume
new_streamline = sft.streamlines[ind][ind_in_vol]
new_streamlines.append(new_streamline)

for key in sft.data_per_streamline.keys():
new_data_per_streamline[key].append(
sft.data_per_streamline[key][ind])
for key in sft.data_per_point.keys():
new_data_per_point[key].append(
sft.data_per_point[key][ind][
best_seg[0]:best_seg[1]])
ind_in_vol])
cutting_counter += 1
else:
logging.warning('Streamline entirely out of the volume.')
else:
Expand All @@ -296,6 +294,7 @@ def cut_invalid_streamlines(sft, epsilon=0.001):
sft.data_per_streamline[key][ind])
for key in sft.data_per_point.keys():
new_data_per_point[key].append(sft.data_per_point[key][ind])

new_sft = StatefulTractogram.from_sft(
new_streamlines, sft, data_per_streamline=new_data_per_streamline,
data_per_point=new_data_per_point)
Expand All @@ -306,7 +305,6 @@ def cut_invalid_streamlines(sft, epsilon=0.001):

new_sft.to_space(space)
new_sft.to_origin(origin)

return new_sft, cutting_counter


Expand Down
127 changes: 97 additions & 30 deletions scilpy/tractograms/tests/test_dps_and_dpp_management.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# -*- coding: utf-8 -*-
import nibabel as nib
import numpy as np
import pytest

from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin

from scilpy.image.volume_space_management import DataVolume
from scilpy.tests.utils import nan_array_equal
from scilpy.tractograms.dps_and_dpp_management import (
get_data_as_arraysequence,
add_data_as_color_dpp, convert_dps_to_dpp, project_map_to_streamlines,
project_dpp_to_map, perform_operation_on_dpp, perform_operation_dpp_to_dps,
perform_correlation_on_endpoints)
Expand All @@ -27,45 +30,109 @@ def _get_small_sft():
return fake_sft


def test_add_data_as_color_dpp():
lut = get_lookup_table('viridis')
def test_get_data_as_arraysequence_dpp():
fake_sft = _get_small_sft()

some_data = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5])

# Test 1: One value per point.
array_seq = get_data_as_arraysequence(some_data, fake_sft)

assert fake_sft._get_point_count() == array_seq.total_nb_rows

# Important. cmap(1) != cmap(1.0)
lowest_color = np.asarray(lut(0.0)[0:3]) * 255
highest_color = np.asarray(lut(1.0)[0:3]) * 255

def test_get_data_as_arraysequence_dps():
fake_sft = _get_small_sft()

# Not testing the clipping options. Will be tested through viz.utils tests
some_data = np.asarray([2, 20])

# Test 1: One value per point.
# Lowest cmap color should be first point of second streamline.
some_data = [[2, 20, 200], [0.1, 0.3, 22, 5]]
colored_sft, lbound, ubound = add_data_as_color_dpp(
fake_sft, lut, some_data)
# Test: One value per streamline.
array_seq = get_data_as_arraysequence(some_data, fake_sft)
assert fake_sft._get_streamline_count() == array_seq.total_nb_rows


def test_get_data_as_arraysequence_dps_2D():
fake_sft = _get_small_sft()

some_data = np.asarray([[2], [20]])

# Test: One value per streamline.
array_seq = get_data_as_arraysequence(some_data, fake_sft)
assert fake_sft._get_streamline_count() == array_seq.total_nb_rows


def test_get_data_as_arraysequence_error():
fake_sft = _get_small_sft()

some_data = np.asarray([2, 20, 200, 0.1])

# Test: Too many values per streamline, not enough per point.
with pytest.raises(ValueError):
_ = get_data_as_arraysequence(some_data, fake_sft)


def test_add_data_as_dpp_1_per_point():

fake_sft = _get_small_sft()
cmap = get_lookup_table('jet')

# Test: One value per point.
values = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5])
color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8)

array_seq = get_data_as_arraysequence(color, fake_sft)
colored_sft = add_data_as_color_dpp(
fake_sft, array_seq)
assert len(colored_sft.data_per_streamline.keys()) == 0
assert list(colored_sft.data_per_point.keys()) == ['color']
assert lbound == 0.1
assert ubound == 200
assert np.array_equal(colored_sft.data_per_point['color'][1][0, :],
lowest_color)
assert np.array_equal(colored_sft.data_per_point['color'][0][2, :],
highest_color)

# Test 2: One value per streamline
# Lowest cmap color should be every point in first streamline
some_data = np.asarray([4, 5])
colored_sft, lbound, ubound = add_data_as_color_dpp(
fake_sft, lut, some_data)


def test_add_data_as_dpp_1_per_streamline():

fake_sft = _get_small_sft()
cmap = get_lookup_table('jet')

# Test: One value per streamline
values = np.asarray([4, 5])
color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8)

array_seq = get_data_as_arraysequence(color, fake_sft, per_point=True)

colored_sft = add_data_as_color_dpp(
fake_sft, array_seq)

assert len(colored_sft.data_per_streamline.keys()) == 0
assert list(colored_sft.data_per_point.keys()) == ['color']
assert lbound == 4
assert ubound == 5
# Lowest cmap color should be first point of second streamline.
# Same value for all points.
colors_first_line = colored_sft.data_per_point['color'][0]
assert np.array_equal(colors_first_line[0, :], lowest_color)
assert np.all(colors_first_line[1:, :] == colors_first_line[0, :])


def test_add_data_as_color_error_common_shape():

fake_sft = _get_small_sft()

# Test: One value per streamline
# Should fail because the values aren't RGB values
values = np.asarray([4, 5])
array_seq = get_data_as_arraysequence(values, fake_sft)

with pytest.raises(ValueError):
_ = add_data_as_color_dpp(
fake_sft, array_seq)


def test_add_data_as_color_error_number():

fake_sft = _get_small_sft()
cmap = get_lookup_table('jet')

# Test: One value per streamline
# Should fail because the values aren't RGB values
values = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5])
array_seq = get_data_as_arraysequence(values, fake_sft)
color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8)
color = color[:-2] # Remove last streamline colors
with pytest.raises(ValueError):
_ = add_data_as_color_dpp(
fake_sft, array_seq)


def test_convert_dps_to_dpp():
Expand Down
Loading
Loading