Skip to content

Commit d663903

Browse files
committed
RF: ANTs h5 loading
1 parent 8d44792 commit d663903

File tree

2 files changed

+32
-24
lines changed

2 files changed

+32
-24
lines changed

nibabies/interfaces/resampling.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from scipy.sparse import hstack as sparse_hstack
2525
from sdcflows.transform import grid_bspline_weights
2626
from sdcflows.utils.tools import ensure_positive_cosines
27+
from transforms3d.affines import compose as compose_affine
2728

2829
R = TypeVar('R')
2930

@@ -62,16 +63,6 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans
6263
return chain
6364

6465

65-
FIXED_PARAMS = np.array([
66-
193.0, 229.0, 193.0, # Size
67-
96.0, 132.0, -78.0, # Origin
68-
1.0, 1.0, 1.0, # Spacing
69-
-1.0, 0.0, 0.0, # Directions
70-
0.0, -1.0, 0.0,
71-
0.0, 0.0, 1.0,
72-
]) # fmt:skip
73-
74-
7566
def load_ants_h5(filename: Path) -> nt.base.TransformBase:
7667
"""Load ANTs H5 files as a nitransforms TransformChain"""
7768
# Borrowed from https://github.com/feilong/process
@@ -80,7 +71,8 @@ def load_ants_h5(filename: Path) -> nt.base.TransformBase:
8071
# Changes:
8172
# * Tolerate a missing displacement field
8273
# * Return the original affine without a round-trip
83-
# * Always return a nitransforms TransformChain
74+
# * Always return a nitransforms TransformBase
75+
# * Construct warp affine from fixed parameters
8476
#
8577
# This should be upstreamed into nitransforms
8678
h = h5py.File(filename)
@@ -104,22 +96,37 @@ def load_ants_h5(filename: Path) -> nt.base.TransformBase:
10496
msg += f'[{i}]: {h["TransformGroup"][i]["TransformType"][:][0]}\n'
10597
raise ValueError(msg)
10698

107-
fixed_params = transform2['TransformFixedParameters'][:]
108-
shape = tuple(fixed_params[:3].astype(int))
99+
# Warp field fixed parameters as defined in
100+
# https://itk.org/Doxygen/html/classitk_1_1DisplacementFieldTransform.html
101+
shape = transform2['TransformFixedParameters'][:3]
102+
origin = transform2['TransformFixedParameters'][3:6]
103+
spacing = transform2['TransformFixedParameters'][6:9]
104+
direction = transform2['TransformFixedParameters'][9:].reshape((3, 3))
105+
106+
# We are not yet confident that we handle non-unit spacing
107+
# or direction cosine ordering correctly.
108+
# If we confirm or fix, we can remove these checks.
109+
if not np.allclose(spacing, 1):
110+
raise ValueError(f'Unexpected spacing: {spacing}')
111+
if not np.allclose(direction, direction.T):
112+
raise ValueError(f'Asymmetric direction matrix: {direction}')
113+
114+
# ITK uses LPS affines
115+
lps_affine = compose_affine(T=origin, R=direction, Z=spacing)
116+
ras_affine = np.diag([-1, -1, 1, 1]) @ lps_affine
117+
109118
# ITK stores warps in Fortran-order, where the vector components change fastest
110-
# Nitransforms expects 3 volumes, not a volume of three-vectors, so transpose
111-
warp = np.reshape(
119+
# Vectors are in mm LPS
120+
itk_warp = np.reshape(
112121
transform2['TransformParameters'],
113-
(3, *shape),
122+
(3, *shape.astype(int)),
114123
order='F',
115-
).transpose(1, 2, 3, 0)
116-
117-
warp_affine = np.eye(4)
118-
warp_affine[:3, :3] = fixed_params[9:].reshape((3, 3))
119-
warp_affine[:3, 3] = fixed_params[3:6]
120-
lps_to_ras = np.eye(4) * np.array([-1, -1, 1, 1])
121-
warp_affine = lps_to_ras @ warp_affine
122-
transforms.insert(0, nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine)))
124+
)
125+
126+
# Nitransforms warps are in RAS, with the vector components changing slowest
127+
nt_warp = itk_warp.transpose(1, 2, 3, 0) * np.array([-1, -1, 1])
128+
129+
transforms.insert(0, nt.DenseFieldTransform(nb.Nifti1Image(nt_warp, ras_affine)))
123130
return nt.TransformChain(transforms)
124131

125132

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dependencies = [
3535
"smriprep @ git+https://github.com/nipreps/smriprep.git@enh/nibabies-fit-apply",
3636
"tedana >= 23.0.2",
3737
"templateflow >= 24.2.0",
38+
"transforms3d",
3839
"toml",
3940
]
4041
dynamic = ["version"]

0 commit comments

Comments
 (0)