Skip to content

Commit

Permalink
Add tomogram mask option for template matching (#185)
Browse files Browse the repository at this point in the history
* Add tomogram mask option in the volume control for template matching

* implement mask check during volume splitting

* use tomogram mask during extraction as well

* update entry point documentation

* fix typo

* reword entry point documentation

* Name the tests that need to be added

* add tests for the volume splitting and the extraction logic

* fix added tests

* Update tests/test_tmjob.py

Co-authored-by: Marten Chaillet <[email protected]>

* test validity of the tomogram mask already in the tmjob init

* it is a tomogram mask, not a template mask...

* Update src/pytom_tm/tmjob.py

Co-authored-by: Marten Chaillet <[email protected]>

---------

Co-authored-by: Marten Chaillet <[email protected]>
  • Loading branch information
sroet and McHaillet authored Jun 14, 2024
1 parent fcdb82c commit 10e6655
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 9 deletions.
13 changes: 12 additions & 1 deletion src/pytom_tm/entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ def extract_candidates(argv=None):
help="Here you can provide a mask for the extraction with dimensions equal to "
"the tomogram. All values in the mask that are smaller or equal to 0 will be "
"removed, all values larger than 0 are considered regions of interest. It can "
"be used to extract annotations only within a specific cellular region.",
"be used to extract annotations only within a specific cellular region."
"If the job was run with a tomogram mask, this file will be used instead of the job mask",
)
parser.add_argument(
"-n",
Expand Down Expand Up @@ -635,6 +636,15 @@ def match_template(argv=None):
help="Start and end indices of the search along the z-axis, "
"e.g. --search-x 30 230 ",
)
volume_group.add_argument(
"--tomogram-mask",
type=pathlib.Path,
required=False,
action=CheckFileExists,
help="Here you can provide a mask for matching with dimensions equal to "
"the tomogram. If a subvolume only has values <= 0 for this mask it will be skipped.",
)

filter_group = parser.add_argument_group("Filter control")
filter_group.add_argument(
"-a",
Expand Down Expand Up @@ -856,6 +866,7 @@ def match_template(argv=None):
search_x=args.search_x,
search_y=args.search_y,
search_z=args.search_z,
tomogram_mask=args.tomogram_mask,
voxel_size=args.voxel_size_angstrom,
low_pass=args.low_pass,
high_pass=args.high_pass,
Expand Down
19 changes: 13 additions & 6 deletions src/pytom_tm/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def extract_particles(
tune the number of false positives to be included for automated error function cut-off estimation:
should be a float > 0
tomogram_mask_path: Optional[pathlib.Path]
path to a tomographic binary mask for extraction
path to a tomographic binary mask for extraction, will override job.tomogram_mask
tophat_filter: bool
attempt to only select sharp peaks with the tophat filter
create_plot: bool, default True
Expand Down Expand Up @@ -202,12 +202,19 @@ def extract_particles(
)

# apply tomogram mask if provided
tomogram_mask = None
if tomogram_mask_path is not None:
tomogram_mask = read_mrc(tomogram_mask_path)[
job.search_origin[0] : job.search_origin[0] + job.search_size[0],
job.search_origin[1] : job.search_origin[1] + job.search_size[1],
job.search_origin[2] : job.search_origin[2] + job.search_size[2],
] # mask should be larger than zero in regions of interest!
tomogram_mask = read_mrc(tomogram_mask_path)
elif job.tomogram_mask is not None:
tomogram_mask = read_mrc(job.tomogram_mask)

if tomogram_mask is not None:
slices = [
slice(origin, origin + size)
for origin, size in zip(job.search_origin, job.search_size)
]
tomogram_mask = tomogram_mask[*slices]
# mask should be larger than zero in regions of interest!
score_volume[tomogram_mask <= 0] = 0

# mask edges of score volume
Expand Down
29 changes: 29 additions & 0 deletions src/pytom_tm/tmjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def load_json_to_tmjob(
search_x=data["search_x"],
search_y=data["search_y"],
search_z=data["search_z"],
# Use 'get' for backwards compatibility
tomogram_mask=data.get("tomogram_mask", None),
voxel_size=data["voxel_size"],
low_pass=data["low_pass"],
# Use 'get' for backwards compatibility
Expand Down Expand Up @@ -213,6 +215,7 @@ def __init__(
search_x: Optional[list[int, int]] = None,
search_y: Optional[list[int, int]] = None,
search_z: Optional[list[int, int]] = None,
tomogram_mask: Optional[pathlib.Path] = None,
voxel_size: Optional[float] = None,
low_pass: Optional[float] = None,
high_pass: Optional[float] = None,
Expand Down Expand Up @@ -256,6 +259,8 @@ def __init__(
restrict tomogram search region along the y-axis
search_z: Optional[list[int, int]], default None
restrict tomogram search region along the z-axis
tomogram_mask: Optional[pathlib.Path], default None
when volume splitting tomograms, only subjobs where any(mask > 0) will be generated
voxel_size: Optional[float], default None
voxel size of tomogram and template (in A) if not provided will be read from template/tomogram MRCs
low_pass: Optional[float], default None
Expand Down Expand Up @@ -363,6 +368,13 @@ def __init__(
]

logging.debug(f"origin, size = {self.search_origin}, {self.search_size}")
self.tomogram_mask = tomogram_mask
if tomogram_mask is not None:
temp = read_mrc(tomogram_mask)
if np.all(temp <= 0):
raise ValueError(
f"No values larger than 0 found in the tomogram mask: {tomogram_mask}"
)

self.whole_start = None
# For the main job these are always [0,0,0] and self.search_size, for sub_jobs these will differ from
Expand Down Expand Up @@ -535,6 +547,8 @@ def split_volume_search(self, split: tuple[int, int, int]) -> list[TMJob, ...]:
Finally, new_job.sub_start and new_job.sub_step, extract the score and angle map without the template
overhang from the subvolume.
If self.tomogram_mask is set, we will skip subjobs where all(mask <= 0).
Parameters
----------
split: tuple[int, int, int]
Expand All @@ -551,6 +565,11 @@ def split_volume_search(self, split: tuple[int, int, int]) -> list[TMJob, ...]:
)

search_size = self.search_size
if self.tomogram_mask is not None:
# This should have some positive values after the check in the __init__
tomogram_mask = read_mrc(self.tomogram_mask)
else:
tomogram_mask = None
# shape of template for overhang
overhang = self.template_shape
# use overhang//2 (+1 for odd sizes)
Expand All @@ -573,6 +592,16 @@ def split_volume_search(self, split: tuple[int, int, int]) -> list[TMJob, ...]:
whole_start = tuple(dim_data[1][0] for dim_data in data_3D)
sub_start = tuple(dim_data[1][0] - dim_data[0][0] for dim_data in data_3D)
sub_step = tuple(dim_data[1][1] - dim_data[1][0] for dim_data in data_3D)

# check if any of the unique data points of this subjob are where tomogram_mask > 0
if tomogram_mask is not None:
slices = [
slice(origin, origin + step)
for origin, step in zip(whole_start, sub_step)
]
if np.all(tomogram_mask[*slices] <= 0):
# No non-masked unique data-points, skipping
continue
new_job = self.copy()
new_job.leader = self.job_key
new_job.job_key = self.job_key + str(i)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ def test_predict_tophat_mask(self):
self.assertEqual(
tophat_mask.dtype, bool, msg="predicted tophat mask should be boolean"
)

# part of the extraction test in test_tmjob.py
# def test_extract_job_with_tomogram_mask(self):
# pass
45 changes: 43 additions & 2 deletions tests/test_tmjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ANGULAR_SEARCH = "38.53"
TEST_DATA_DIR = pathlib.Path(__file__).parent.joinpath("test_data")
TEST_TOMOGRAM = TEST_DATA_DIR.joinpath("tomogram.mrc")
TEST_BROKEN_TOMOGRAM_MASK = TEST_DATA_DIR.joinpath("broken_tomogram_mask.mrc")
TEST_EXTRACTION_MASK_OUTSIDE = TEST_DATA_DIR.joinpath("extraction_mask_outside.mrc")
TEST_EXTRACTION_MASK_INSIDE = TEST_DATA_DIR.joinpath("extraction_mask_inside.mrc")
TEST_TEMPLATE = TEST_DATA_DIR.joinpath("template.mrc")
Expand Down Expand Up @@ -119,9 +120,14 @@ def setUpClass(cls) -> None:
)
job.write_to_json(TEST_JOB_JSON_WHITENING)

# write broken tomogram mask
broken_tomogram_mask = np.zeros(TOMO_SHAPE, dtype=np.float32)
write_mrc(TEST_BROKEN_TOMOGRAM_MASK, broken_tomogram_mask, 1.0)

@classmethod
def tearDownClass(cls) -> None:
TEST_MASK.unlink()
TEST_BROKEN_TOMOGRAM_MASK.unlink()
TEST_EXTRACTION_MASK_OUTSIDE.unlink()
TEST_EXTRACTION_MASK_INSIDE.unlink()
TEST_TEMPLATE.unlink()
Expand Down Expand Up @@ -232,6 +238,20 @@ def test_tm_job_errors(self):
voxel_size=1.0,
)

# Test broken tomogram mask
with self.assertRaisesRegex(ValueError, str(TEST_BROKEN_TOMOGRAM_MASK)):
TMJob(
"0",
10,
TEST_TOMOGRAM,
TEST_TEMPLATE,
TEST_MASK,
TEST_DATA_DIR,
angle_increment=ANGULAR_SEARCH,
voxel_size=1.0,
tomogram_mask=TEST_BROKEN_TOMOGRAM_MASK,
)

def test_tm_job_copy(self):
copy = self.job.copy()
self.assertIsNot(
Expand Down Expand Up @@ -481,6 +501,12 @@ def test_tm_job_split_volume(self):
"almost identical.",
)

def test_splitting_with_tomogram_mask(self):
job = self.job.copy()
job.tomogram_mask = TEST_EXTRACTION_MASK_INSIDE
job.split_volume_search((10, 10, 10))
self.assertLess(len(job.sub_jobs), 10 * 10 * 10)

def test_splitting_with_offsets(self):
# check if subjobs have correct offsets for the main job, the last sub job will have the largest errors
job = TMJob(
Expand Down Expand Up @@ -602,10 +628,25 @@ def test_extraction(self):
msg="Length of returned list should be 0 after applying mask where the "
"object is not in the region of interest.",
)
# test if the extraction mask can be grabbed from the job instead
job = self.job.copy()
job.tomogram_mask = TEST_EXTRACTION_MASK_OUTSIDE
df, scores = extract_particles(
job,
5,
100,
create_plot=False,
)
self.assertEqual(
len(scores),
0,
msg="Length of returned list should be 0 after applying mask where the "
"object is not in the region of interest.",
)

# test mask that covers the particle
# test mask that covers the particle and should override the one now attached to the job
df, scores = extract_particles(
self.job,
job,
5,
100,
tomogram_mask_path=TEST_EXTRACTION_MASK_INSIDE,
Expand Down

0 comments on commit 10e6655

Please sign in to comment.