Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add z drift metrics to imaging.py #159

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions element_calcium_imaging/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from . import imaging_report, scan
from .scan import (
get_calcium_imaging_files,
get_zstack_files,
get_imaging_root_data_dir,
get_processed_root_data_dir,
)
Expand Down Expand Up @@ -75,6 +76,165 @@ def activate(
# -------------- Table declarations --------------


@schema
class ZDriftParamSet(dj.Manual):
definition = """
paramset_idx: int
---
paramset_desc: varchar(1280) # Parameter-set description
param_set_hash: uuid # A universally unique identifier for the parameter set
unique index (param_set_hash)
params: longblob # Parameter Set, a dictionary of all z-drift parameters.
"""

@classmethod
def insert_new_params(
cls,
paramset_idx: int,
paramset_desc: str,
params: dict,
):
"""Insert a parameter set into ProcessingParamSet table.
This function automates the parameter set hashing and avoids insertion of an
existing parameter set.
Attributes:
processing_method (str): Processing method/package used for processing of
calcium imaging.
paramset_idx (int): Unique parameter set ID.
paramset_desc (str): Parameter set description.
params (dict): Parameter Set, all applicable parameters to the
z-axis correlation analysis.
"""
param_dict = {
"paramset_idx": paramset_idx,
"paramset_desc": paramset_desc,
"params": params,
"param_set_hash": dict_to_uuid(params),
}
q_param = cls & {"param_set_hash": param_dict["param_set_hash"]}

if q_param: # If the specified param-set already exists
p_name = q_param.fetch1("paramset_idx")
if p_name == paramset_idx: # If the existed set has the same name: job done
return
else: # If not same name: human error, trying to add the same paramset with different name
raise dj.DataJointError(
"The specified param-set already exists - name: {}".format(p_name)
)
else:
cls.insert1(param_dict)


@schema
class ZDriftMetrics(dj.Computed):
"""Generate z-axis motion report.
Attributes:
scan.Scan (foreign key): Primary key from scan.Scan.
ZDriftParamSet (foreign key): Primary key from ZDriftParamSet.
z_drift (longblob): Amount of drift in microns per frame in Z direction.
"""

definition = """
-> scan.Scan
-> ZDriftParamSet
---
z_drift: longblob # Amount of drift in microns per frame in Z direction.
"""

def make(self, key):
def _make_taper(size, width):
m = np.ones(size - width + 1)
k = np.hanning(width)
return np.convolve(m, k, mode="full") / k.sum()

image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path")
image_files = [
find_full_path(get_imaging_root_data_dir()[0], image_file)
for image_file in image_files
]
zstack_files = get_zstack_files(key)

acq_software = (scan.Scan & key).fetch1("acq_software")

if acq_software == "NIS":
import nd2

ca_imaging_movie = nd2.imread(image_files[0])
zstack = nd2.imread(zstack_files[0])

else:
raise NotImplementedError(
f"Z-drift metrics functionality for {acq_software} acquisition software is not supported yet. Please contact '[email protected]' to request the feature."
)

drift_params = (ZDriftParamSet & key).fetch1("params")
required_params = ["pad_length", "slice_interval", "channel"]

if not all(parameter in drift_params.keys() for parameter in required_params):
raise Exception(
"Z-drift parameters must include a keys for 'pad_length', 'slice_interval', and 'channel'."
)

ca_imaging_movie = ca_imaging_movie[:, drift_params["channel"], :, :]
zstack = zstack[:, drift_params["channel"], :, :]
# center zstack
zstack = zstack - zstack.mean(axis=(1, 2), keepdims=True)

# taper zstack
ytaper = _make_taper(zstack.shape[1], drift_params["pad_length"])
zstack = zstack * ytaper[None, :, None]

xtaper = _make_taper(zstack.shape[2], drift_params["pad_length"])
zstack = zstack * xtaper[None, None, :]

# normalize zstack
zstack = zstack - zstack.mean(axis=(1, 2), keepdims=True)
zstack /= np.sqrt((zstack**2).sum(axis=(1, 2), keepdims=True))

# pad zstack
zstack = np.pad(
zstack,
(
(0, 0),
(drift_params["pad_length"], drift_params["pad_length"]),
(drift_params["pad_length"], drift_params["pad_length"]),
),
)

# normalize movie
ca_imaging_movie = ca_imaging_movie - ca_imaging_movie.mean(
axis=(1, 2), keepdims=True
)
ca_imaging_movie /= np.sqrt(
(ca_imaging_movie**2).sum(axis=(1, 2), keepdims=True)
)

# Vectorized implementation
middle = (zstack.shape[0] - 1) // 2
_, ny, nx = ca_imaging_movie.shape
offsets = list(
(dy, dx)
for dx in range(2 * drift_params["pad_length"] + 1)
for dy in range(2 * drift_params["pad_length"] + 1)
)
c = list(
np.einsum(
"dij, tij -> td",
zstack[:, dy : dy + ny, dx : dx + nx],
ca_imaging_movie,
)
for dy, dx in offsets
)

drift = ((np.argmax(np.stack(c).max(axis=0), axis=1)) - middle) * drift_params[
"slice_interval"
]

self.insert1(
dict(**key, z_drift=drift),
)


@schema
class ProcessingMethod(dj.Lookup):
"""Package used for processing of calcium imaging data (e.g. Suite2p, CaImAn, etc.).
Expand Down
11 changes: 11 additions & 0 deletions element_calcium_imaging/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ def get_calcium_imaging_files(scan_key: dict, acq_software: str) -> list:
return _linking_module.get_calcium_imaging_files(scan_key, acq_software)


def get_zstack_files(scan_key: dict) -> list:
"""Retrieve the list of zstack files associated with a given Scan.
Args:
scan_key: Primary key of a Scan entry.
Returns:
A list of zstack files' full file-paths.
"""

return _linking_module.get_zstack_files(scan_key)


# ----------------------------- Table declarations ----------------------


Expand Down