Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Entzynger tree is working #991

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lineage/Analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def Results(tHMMobj: tHMM, LL: float) -> dict[str, Any]:
results_dict["total_number_of_lineages"] = len(tHMMobj.X)
results_dict["LL"] = LL
results_dict["total_number_of_cells"] = sum(
[len(lineage.output_lineage) for lineage in tHMMobj.X]
[len(lineage) for lineage in tHMMobj.X]
)

true_states_by_lineage = [
Expand Down
49 changes: 28 additions & 21 deletions lineage/BaumWelch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .LineageTree import get_Emission_Likelihoods
from .states.StateDistributionGamma import atonce_estimator
from .HMM.M_step import get_all_zetas, sum_nonleaf_gammas
from .HMM.E_step import get_beta_and_NF, get_MSD, get_gamma
from .HMM.E_step import get_beta, get_MSD, get_gamma, get_leaf_Normalizing_Factors


def do_E_step(tHMMobj: tHMM) -> Tuple[list, list, list, list]:
Expand All @@ -26,17 +26,18 @@ def do_E_step(tHMMobj: tHMM) -> Tuple[list, list, list, list]:
EL = get_Emission_Likelihoods(tHMMobj.X, tHMMobj.estimate.E)

for ii, lO in enumerate(tHMMobj.X):
MSD.append(
get_MSD(lO.cell_to_daughters, tHMMobj.estimate.pi, tHMMobj.estimate.T)
)

NF_one, beta = get_beta_and_NF(
lO.leaves_idx, lO.cell_to_daughters, tHMMobj.estimate.T, MSD[ii], EL[ii]
MSD.append(get_MSD(len(lO), tHMMobj.estimate.pi, tHMMobj.estimate.T))
NF.append(get_leaf_Normalizing_Factors(MSD[ii], EL[ii]))
betas.append(
get_beta(
tHMMobj.estimate.T,
MSD[ii],
EL[ii],
NF[ii],
)
)
NF.append(NF_one)
betas.append(beta)
gammas.append(
get_gamma(lO.cell_to_daughters, tHMMobj.estimate.T, MSD[ii], betas[ii])
get_gamma(tHMMobj.estimate.T, MSD[ii], betas[ii])
)

return MSD, NF, betas, gammas
Expand Down Expand Up @@ -169,14 +170,12 @@ def do_M_T_step(
for num, lO in enumerate(tt.X):
# local T estimate
numer_e += get_all_zetas(
lO.leaves_idx,
lO.cell_to_daughters,
betas[i][num],
MSD[i][num],
gammas[i][num],
tt.estimate.T,
)
denom_e += sum_nonleaf_gammas(lO.leaves_idx, gammas[i][num])
denom_e += sum_nonleaf_gammas(gammas[i][num])

T_estimate = numer_e / denom_e[:, np.newaxis]
T_estimate /= T_estimate.sum(axis=1)[:, np.newaxis]
Expand All @@ -197,8 +196,17 @@ def do_M_E_step(tHMMobj: tHMM, gammas: list[np.ndarray]):
:type tHMMobj: object
:param gammas: gamma values. The conditional probability of states, given the observation of the whole tree
"""
all_cells = [cell.obs for lineage in tHMMobj.X for cell in lineage.output_lineage]
cell_arr = np.array(all_cells)
all_cells: list[np.ndarray] = []

for lineage in tHMMobj.X:
for cell in lineage.output_lineage:
if cell is None:
all_cells.append(-1 * np.ones(all_cells[0].size))
else:
all_cells.append(np.array(cell.obs))

all_cells = np.array(all_cells) # type: ignore

all_gammas = np.vstack(gammas)
for state_j in range(tHMMobj.num_states):
tHMMobj.estimate.E[state_j].estimator(cell_arr, all_gammas[:, state_j])
Expand All @@ -214,10 +222,9 @@ def do_M_E_step_atonce(all_tHMMobj: list[tHMM], all_gammas: list[list[np.ndarray
for gm in all_gammas:
gms.append(np.vstack(gm))

all_cells = np.array(
[cell.obs for lineage in all_tHMMobj[0].X for cell in lineage.output_lineage]
)
if len(all_cells[1, :]) == 6:
all_cells = all_tHMMobj[0].X[0].get_observations()

if all_cells.shape[1] == 6:
phase = True
else:
phase = False
Expand All @@ -226,8 +233,8 @@ def do_M_E_step_atonce(all_tHMMobj: list[tHMM], all_gammas: list[list[np.ndarray
G2cells = []
cells = []
for tHMMobj in all_tHMMobj:
all_cells = np.array(
[cell.obs for lineage in tHMMobj.X for cell in lineage.output_lineage]
all_cells = np.vstack(
[lineage.get_observations() for lineage in tHMMobj.X]
)
if phase:
G1cells.append(all_cells[:, np.array([0, 2, 4])])
Expand Down
17 changes: 0 additions & 17 deletions lineage/CellVar.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,6 @@ def __init__(self, parent: Optional["CellVar"], state: Optional[int] = None):
self.right = None
self.obs = np.empty(0, dtype=float)

def divide(self, T: np.ndarray, rng=None):
"""
Member function that performs division of a cell.
Equivalent to adding another timestep in a Markov process.
:param T: The array containing the likelihood of a cell switching states.
"""
rng = np.random.default_rng(rng)
# Checking that the inputs are of the right shape
assert T.shape[0] == T.shape[1]

# roll a loaded die according to the row in the transtion matrix
left_state, right_state = rng.choice(T.shape[0], size=2, p=T[self.state, :])
self.left = CellVar(state=left_state, parent=self)
self.right = CellVar(state=right_state, parent=self)

return self.left, self.right

def isLeafBecauseTerminal(self) -> bool:
"""
Returns true when a cell is a leaf with no children.
Expand Down
117 changes: 68 additions & 49 deletions lineage/HMM/E_step.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,51 @@
import numpy as np
import numpy.typing as npt
from numba import jit


def get_leaf_Normalizing_Factors(
MSD: npt.NDArray[np.float64],
EL: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
"""
Normalizing factor (NF) matrix and base case at the leaves.

Each element in this N by 1 matrix is the normalizing
factor for each beta value calculation for each node.
This normalizing factor is essentially the marginal
observation distribution for a node.

This function gets the normalizing factor for
the upward recursion only for the leaves.
We first calculate the joint probability
using the definition of conditional probability:

:math:`P(x_n = x | z_n = k) * P(z_n = k) = P(x_n = x , z_n = k)`,
where n are the leaf nodes.

We can then sum this joint probability over k,
which are the possible states z_n can be,
and through the law of total probability,
obtain the marginal observation distribution
:math:`P(x_n = x) = sum_k ( P(x_n = x , z_n = k) ) = P(x_n = x)`.

:param EL: The emissions likelihood
:param MSD: The marginal state distribution P(z_n = k)
:return: normalizing factor. The marginal observation distribution P(x_n = x)
"""
NF_array = np.zeros(MSD.shape[0], dtype=float) # instantiating N by 1 array
first_leaf = int(np.floor(MSD.shape[0] / 2))

# P(x_n = x , z_n = k) = P(x_n = x | z_n = k) * P(z_n = k)
# this product is the joint probability
# P(x_n = x) = sum_k ( P(x_n = x , z_n = k) )
# the sum of the joint probabilities is the marginal probability
NF_array[first_leaf:] = np.sum(MSD[first_leaf:, :] * EL[first_leaf:, :], axis=1)
assert np.all(np.isfinite(NF_array))
return NF_array


def get_MSD(
cell_to_daughters: np.ndarray,
pi: npt.NDArray[np.float64],
T: npt.NDArray[np.float64],
n_cells: int, pi: npt.NDArray[np.float64], T: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
r"""Marginal State Distribution (MSD) matrix by upward recursion.
This is the probability that a hidden state variable :math:`z_n` is of
Expand All @@ -29,48 +68,26 @@ def get_MSD(
:param T: State transitions matrix
:return: The marginal state distribution
"""
m = np.zeros((cell_to_daughters.shape[0], pi.size))
m = np.zeros((n_cells, pi.size))
m[0, :] = pi

# recursion based on parent cell
for pIDX, cIDX in enumerate(cell_to_daughters):
if cIDX[0] != -1:
m[cIDX[0], :] = m[pIDX, :] @ T
if cIDX[1] != -1:
m[cIDX[1], :] = m[pIDX, :] @ T
for cIDX in range(1, n_cells):
pIDX = int(np.floor(cIDX / 2))
m[cIDX, :] = m[pIDX, :] @ T

# Assert all ~= 1.0
assert np.linalg.norm(np.sum(m, axis=1) - 1.0) < 1e-9
return m


@jit
def get_beta_and_NF(
leaves_idx, cell_to_daughters, T: np.ndarray, MSD: np.ndarray, EL: np.ndarray
):
r"""
Normalizing factor (NF) matrix and base case at the leaves.

Each element in this N by 1 matrix is the normalizing
factor for each beta value calculation for each node.
This normalizing factor is essentially the marginal
observation distribution for a node.

This function gets the normalizing factor for
the upward recursion only for the leaves.
We first calculate the joint probability
using the definition of conditional probability:

:math:`P(x_n = x | z_n = k) * P(z_n = k) = P(x_n = x , z_n = k)`,
where n are the leaf nodes.

We can then sum this joint probability over k,
which are the possible states z_n can be,
and through the law of total probability,
obtain the marginal observation distribution
:math:`P(x_n = x) = sum_k ( P(x_n = x , z_n = k) ) = P(x_n = x)`.

Beta matrix and base case at the leaves.
def get_beta(
T: npt.NDArray[np.float64],
MSD: npt.NDArray[np.float64],
EL: npt.NDArray[np.float64],
NF: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
r"""Beta matrix and base case at the leaves.

Each element in this N by K matrix is the beta value
for each cell and at each state. In particular, this
Expand Down Expand Up @@ -120,20 +137,24 @@ def get_beta_and_NF(

### beta calculation
beta = np.zeros_like(MSD)
first_leaf = int(np.floor(MSD.shape[0] / 2))

# Emission Likelihood, Marginal State Distribution, Normalizing Factor (same regardless of state)
# P(x_n = x | z_n = k), P(z_n = k), P(x_n = x)
beta[leaves_idx, :] = ELMSD[leaves_idx, :] / NF[leaves_idx, np.newaxis]
ZZ = EL[first_leaf:, :] * MSD[first_leaf:, :] / NF[first_leaf:, np.newaxis]
beta[first_leaf:, :] = ZZ

# Assert all ~= 1.0
assert np.abs(np.sum(beta[-1]) - 1.0) < 1e-9

cIDXs = np.arange(MSD.shape[0])
cIDXs = np.delete(cIDXs, leaves_idx)
cIDXs = np.flip(cIDXs)
MSD_array = np.maximum(
MSD, np.finfo(MSD.dtype).eps
) # MSD of the respective lineage
ELMSD = EL * MSD

for pii in range(first_leaf - 1, -1, -1):
ch_ii = np.array([pii * 2 + 1, pii * 2 + 2])

for pii in cIDXs:
ch_ii = cell_to_daughters[pii, :]
ratt = (beta[ch_ii, :] / MSD_array[ch_ii, :]) @ T.T
fac1 = ratt[0, :] * ratt[1, :] * ELMSD[pii, :]

Expand All @@ -144,7 +165,6 @@ def get_beta_and_NF(


def get_gamma(
cell_to_daughters: npt.NDArray[np.uintp],
T: npt.NDArray[np.float64],
MSD: npt.NDArray[np.float64],
beta: npt.NDArray[np.float64],
Expand All @@ -167,12 +187,11 @@ def get_gamma(
coeffs = np.maximum(coeffs, epss)
beta_parents = T @ coeffs.T

# Getting lineage by generation, but it is sorted this way
for pidx, cis in enumerate(cell_to_daughters):
for ci in cis:
if ci == -1:
continue
first_leaf = int(np.floor(MSD.shape[0] / 2))

# Getting lineage by generation, but it is sorted this way
for pidx in range(first_leaf):
for ci in [pidx * 2 + 1, pidx * 2 + 2]:
A = gamma[pidx, :].T / beta_parents[:, ci]

gamma[ci, :] = coeffs[ci, :] * (A @ T)
Expand Down
17 changes: 5 additions & 12 deletions lineage/HMM/M_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import numpy.typing as npt


def sum_nonleaf_gammas(
leaves_idx, gammas: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
def sum_nonleaf_gammas(gammas: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
"""
Sum of the gammas of the cells that are able to divide, that is,
sum the of the gammas of all the nonleaf cells. It is used in estimating the transition probability matrix.
Expand All @@ -16,16 +14,13 @@ def sum_nonleaf_gammas(
:param gamma_arr: the gamma values for each lineage
:return: the sum of gamma values for each state for non-leaf cells.
"""
# Remove leaves
gs = np.delete(gammas, leaves_idx, axis=0)
first_leaf = int(np.floor(gammas.shape[0] / 2))

# sum the gammas for cells that are transitioning (all but gen 0)
return np.sum(gs[1:, :], axis=0)
return np.sum(gammas[1:first_leaf, :], axis=0)


def get_all_zetas(
leaves_idx: npt.NDArray[np.uintp],
cell_to_daughters: npt.NDArray[np.uintp],
beta_array: npt.NDArray[np.float64],
MSD_array: npt.NDArray[np.float64],
gammas: npt.NDArray[np.float64],
Expand All @@ -45,10 +40,8 @@ def get_all_zetas(
betaMSD = beta_array / np.clip(MSD_array, np.finfo(float).eps, np.inf)
TbetaMSD = np.clip(betaMSD @ T.T, np.finfo(float).eps, np.inf)

cIDXs = np.arange(gammas.shape[0])
cIDXs = np.delete(cIDXs, leaves_idx)

dIDXs = cell_to_daughters[cIDXs, :]
cIDXs = np.arange(int(np.floor(gammas.shape[0] / 2)) - 1)
dIDXs = np.vstack((cIDXs * 2 + 1, cIDXs * 2 + 2)).T

# Getting lineage by generation, but it is sorted this way
js = gammas[cIDXs, np.newaxis, :] / TbetaMSD[dIDXs, :]
Expand Down
Loading
Loading