Skip to content

Commit

Permalink
Merge pull request #147 from jhlegarreta/AllowTractogramDownsamplingR…
Browse files Browse the repository at this point in the history
…eproducibility

ENH: Make tractogram downsampling results reproducible by default
  • Loading branch information
ljod authored Sep 21, 2023
2 parents cebf9f0 + 96f6719 commit 3f74faf
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
11 changes: 9 additions & 2 deletions bin/wm_preprocess_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def main():
parser.add_argument(
'-retaindata', action='store_true', dest="flag_retaindata",
help='If given, all point and cell data stored along the tractography will be retained.')
parser.add_argument(
'--nonidentical', action='store_true',
help='Obtain nonidentical results across runs for downsampling.')

args = parser.parse_args()

Expand Down Expand Up @@ -88,7 +91,11 @@ def main():
else:
print("Remove all data stored along the tractography and only keep fiber streamlines.")
retaindata = args.flag_retaindata


random_seed = 1234
if args.nonidentical:
random_seed = None

print("==========================")

# =======================================================================
Expand Down Expand Up @@ -154,7 +161,7 @@ def pipeline(inputPolyDatas, sidx, args):
print(id_msg + msg)

# , preserve_point_data=True needs editing of preprocess function to use mask function
wm3 = wma.filter.downsample(wm2, args.numberOfFibers, preserve_point_data=retaindata, preserve_cell_data=retaindata, verbose=False)
wm3 = wma.filter.downsample(wm2, args.numberOfFibers, preserve_point_data=retaindata, preserve_cell_data=retaindata, verbose=False, random_seed=random_seed)
print("Number of fibers retained: ", wm3.GetNumberOfLines(), "/", num_lines)

if wm3 is None:
Expand Down
2 changes: 1 addition & 1 deletion whitematteranalysis/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def preprocess(inpd, min_length_mm,
else:
return outpd

def downsample(inpd, output_number_of_lines, return_indices=False, preserve_point_data=False, preserve_cell_data=True, initial_indices=None, verbose=True, random_seed=None):
def downsample(inpd, output_number_of_lines, return_indices=False, preserve_point_data=False, preserve_cell_data=True, initial_indices=None, verbose=True, random_seed=1234):
""" Random (down)sampling of fibers without replacement. """

if initial_indices is None:
Expand Down

0 comments on commit 3f74faf

Please sign in to comment.