diff --git a/.gitignore b/.gitignore index c7e3604..6aaec02 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,4 @@ docs/notebooks/*.pt /.hatch/ src/perturbo/simulation/.ipynb_checkpoints/ +lightning_logs/ \ No newline at end of file diff --git a/src/perturbo/__init__.py b/src/perturbo/__init__.py index 454a392..3e3b962 100644 --- a/src/perturbo/__init__.py +++ b/src/perturbo/__init__.py @@ -1,6 +1,6 @@ from importlib.metadata import version -from . import models, simulation +from . import models, simulation, utils from .models import PERTURBO from .simulation import Learn_Data, Simulate_Data diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index a3ab0d7..2666ba0 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -1,28 +1,27 @@ import logging +import numba as nb import numpy as np import pandas as pd +import pyro.optim as optim import scipy.sparse as sp import torch from mudata import AnnData, MuData from pandas import DataFrame from pyro import poutine -from pyro.infer import TraceEnum_ELBO, infer_discrete -from scipy.sparse import issparse +from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, infer_discrete +from scipy.sparse import coo_matrix, csc_matrix, issparse from scipy.stats import chi2 from scvi._types import AnnOrMuData from scvi.data import AnnDataManager, fields from scvi.dataloaders import AnnDataLoader, DataSplitter, DeviceBackedDataSplitter -from scvi.model.base import ( - BaseModelClass, - PyroJitGuideWarmup, - PyroSampleMixin, - PyroSviTrainMixin, -) +from scvi.model._utils import parse_device_args +from scvi.model.base import BaseModelClass, PyroJitGuideWarmup, PyroSampleMixin, PyroSviTrainMixin from scvi.train import PyroTrainingPlan from scvi.utils._docstrings import devices_dsp from sklearn.isotonic import IsotonicRegression from sklearn.linear_model import LinearRegression +from tqdm.auto import trange from ._constants import REGISTRY_KEYS from ._module import PerTurboPyroModule @@ -30,6 +29,78 @@ logger = logging.getLogger(__name__) +def compute_element_lfc_initialization( + X: np.ndarray | sp.spmatrix, + guide_obs: np.ndarray | sp.spmatrix, + guide_by_element: torch.Tensor | np.ndarray, + control_indices: np.ndarray | None = None, + n_cells_for_control: int = 10000, + pseudocount: float = 0.1, +) -> np.ndarray: + """ + Compute log fold-change (LFC) estimates for element effects on genes using matrix operations. + + For each element, computes the mean expression in cells targeted by that element versus + baseline (control or random cells), using a pseudocount to avoid log(0). + + Parameters + ---------- + X : np.ndarray or sp.spmatrix + Count matrix (n_cells x n_genes). + guide_obs : np.ndarray or sp.spmatrix + Binary or count matrix of guides per cell (n_cells x n_guides). + guide_by_element : torch.Tensor or np.ndarray + Binary matrix mapping guides to elements (n_guides x n_elements). + control_indices : np.ndarray or None + Boolean or integer indices of control cells (baseline for LFC). + If None, samples 1000 random cells (or all if fewer available). + n_cells_for_control : int + Number of random cells to sample for control if control_indices is None. + pseudocount : float + Pseudocount added to avoid log(0). Default: 0.1. + + Returns + ------- + np.ndarray + Log fold-change matrix (n_elements x n_genes). Entry [i, j] is the LFC + of element i on gene j. 
+ """ + n_guides, n_elements = guide_by_element.shape + n_genes = X.shape[1] + + # Determine baseline (control) indices + if control_indices is None: + n_cells = X.shape[0] + n_control = min(n_cells_for_control, n_cells) + control_indices = np.random.choice(n_cells, size=n_control, replace=False) + elif isinstance(control_indices, (list, np.ndarray)): + if len(control_indices) > 0 and control_indices.dtype == bool: + control_indices = np.where(control_indices)[0] + + # Compute mean expression in control cells + X_control = X[control_indices, :] # (n_control x n_genes) + control_mean = X_control.mean(axis=0) # (n_genes,) + + # For each element, identify cells targeted by any guide targeting that element + # guide_obs: (n_cells x n_guides) + # guide_by_element: (n_guides x n_elements) + # cells_by_element: (n_cells x n_elements) = guide_obs @ guide_by_element + if isinstance(guide_by_element, torch.Tensor): + guide_by_element = guide_by_element.cpu().numpy() + + guide_obs = csc_matrix(guide_obs) + lfc = np.zeros((n_elements, n_genes), dtype=np.float32) + + print("Computing element log fold-change initializations...") + for elem_idx in trange(n_elements): + element_guides = guide_by_element[:, elem_idx] != 0 + guide_idx = guide_obs[:, element_guides].sum(axis=1).A1 != 0 + if guide_idx.any(): + element_mean = X[guide_idx, :].mean(axis=0) + lfc[elem_idx, :] = np.log((element_mean + pseudocount) / (control_mean + pseudocount)) + return lfc + + class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass): def __init__( self, @@ -96,6 +167,7 @@ def __init__( else: control_guide_idx = grna_counts[:, control_guides].sum(axis=1) > 0 X = X[control_guide_idx, :] + self.control_guides = control_guides n_cells_for_init = X.shape[0] if n_cells_for_init == 0: @@ -119,6 +191,30 @@ def __init__( # model_kwargs["control_pcs"] = v.T.unsqueeze(dim=-2) # print(v) + # Compute element LFC initialization if elements are present and not already in model_kwargs + element_lfc_init = None + if n_elements is not None and guide_by_element is not None and "element_effects_init" not in model_kwargs: + # Use full data for LFC computation (not filtered by control guides) + X_full = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) + grna_counts_full = self.adata_manager.get_from_registry(REGISTRY_KEYS.PERTURBATION_KEY) + + # Determine baseline indices: use control guides if available, else sample 1000 random cells + if control_guides is not None: + if issparse(grna_counts_full): + control_indices = grna_counts_full[:, control_guides].sum(axis=1).A1 > 0 + else: + control_indices = grna_counts_full[:, control_guides].sum(axis=1) > 0 + else: + control_indices = None + + element_lfc_init = compute_element_lfc_initialization( + X=X_full, + guide_obs=grna_counts_full, + guide_by_element=guide_by_element, + control_indices=control_indices, + pseudocount=0.1, + ) + self.module = PerTurboPyroModule( n_cells=self.summary_stats.n_cells, n_batches=self.summary_stats.n_batch, @@ -128,6 +224,7 @@ def __init__( n_elements=n_elements, log_gene_mean_init=torch.tensor(log_means, dtype=torch.float32), log_gene_dispersion_init=torch.tensor(log_disp_smoothed, dtype=torch.float32), + lfc_init=torch.tensor(element_lfc_init, dtype=torch.float32) if element_lfc_init is not None else None, guide_by_element=guide_by_element, gene_by_element=gene_by_element, # n_cats_per_cov=n_cats_per_cov, @@ -166,10 +263,13 @@ def read_matrix_from_registry(self, registry_key: str) -> torch.Tensor: @classmethod def setup_anndata( cls, - adata: AnnData, + 
adata: AnnOrMuData, **kwargs, ): - raise NotImplementedError("MuData input required, use setup_mudata.") + if isinstance(adata, AnnData): + raise NotImplementedError("MuData input required, use setup_mudata.") + else: + cls.setup_mudata(adata, **kwargs) @classmethod def setup_mudata( @@ -364,18 +464,126 @@ def setup_mudata( adata_manager.register_fields(mdata, **kwargs) cls.register_manager(adata_manager) + def pretrain( + self, + max_epochs: int = 100, + indices: list[int] | list[bool] | None = None, + num_particles: int = 1, + accelerator: str = "cpu", + device: int | str = "auto", + batch_size: int = 1024, + # early_stopping: bool = False, + lr: float | None = 0.01, + ): + """ + Pretrain the model on a subset of the data. + + Parameters + ---------- + indices : array-like + Indices of the subset of data to use for pretraining. + max_epochs : int + Number of passes through the dataset. + accelerator : str + Accelerator type ("cpu", "gpu", etc.). + device : int or str + Device identifier. + batch_size : int + Minibatch size to use during training. + early_stopping : bool + Perform early stopping. + lr : float or None + Optimizer learning rate. + """ + _, _, device = parse_device_args(accelerator, device, return_device="torch", validate_single_device=True) + + loader = AnnDataLoader( + adata_manager=self.adata_manager, + indices=indices, + batch_size=min(batch_size, len(indices)), + data_and_attributes=self.data_and_attrs, + ) + + if self.module.local_effects and self.module.sparse_tensors: + zero_element_effects = torch.zeros((self.module.n_element_effects), dtype=torch.float32, device=device) + zero_guide_efficacy = torch.zeros((self.module.n_guide_effects), dtype=torch.float32, device=device) + else: + zero_element_effects = torch.zeros( + (self.module.n_elements, self.module.n_genes), dtype=torch.float32, device=device + ) + zero_guide_efficacy = torch.full( + (self.module.n_perturbations, self.module.n_genes), fill_value=0.5, dtype=torch.float32, device=device + ) + + self.module.to(device) + pretrain_model = poutine.condition( + self.module.model, + data={ + "element_effects": zero_element_effects, + "guide_efficacy": zero_guide_efficacy, + }, + ) + pretrain_init_values = { + "log_gene_mean": self.module.log_gene_mean_init.to(device), + "log_gene_dispersion": self.module.log_gene_dispersion_init.to(device), + } + + pretrain_guide = self.module._guide_factory( + poutine.block(self.module.model, hide=["element_effects", "guide_efficacy"]), + init_values=pretrain_init_values, + ) + + pretrain_guide.to(device) + svi = SVI( + pretrain_model, + pretrain_guide, + optim.Adam({"lr": lr}), + loss=Trace_ELBO(max_plate_nesting=3, num_particles=num_particles), + ) + losses = [] + + if len(loader) == 1: + batch = next(iter(loader)) + args, kwargs = self.module._get_fn_args_from_batch(batch) + args = tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args) + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + kwargs[k] = v.to(device) + + t = trange(max_epochs, desc="Pretrain epochs", leave=True) + for _ in t: + loss = svi.step(*args, **kwargs) + losses.append(loss) + t.set_postfix(loss=loss) + + else: + t = trange(max_epochs, desc="Pretrain epochs", leave=True) + for _ in t: + batch_losses = [] + for batch in loader: + args, kwargs = self.module._get_fn_args_from_batch(batch) + args = tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args) + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + kwargs[k] = v.to(device) + loss = svi.step(*args, 
**kwargs) + batch_losses.append(loss) + t.set_postfix(loss=np.mean(batch_losses)) + losses.append(np.mean(batch_losses)) + return losses + @devices_dsp.dedent def train( self, max_epochs: int = 1000, accelerator: str = "cpu", device: int | str = "auto", - train_size: float = 1.0, - validation_size: float | None = None, - shuffle_set_split: bool = False, batch_size: int = 1024, early_stopping: bool = False, - lr: float | None = 0.005, + indices: list[int] | None = None, + pretrain: bool = False, + blocked_sites: list[str] | None = None, + lr: float | None = 0.01, load_sparse_tensor: bool = "auto", training_plan: PyroTrainingPlan = PyroTrainingPlan, plan_kwargs: dict | None = None, @@ -422,7 +630,68 @@ def train( Any The result of the training runner. """ - plan_kwargs = plan_kwargs if plan_kwargs is not None else {} + _, _, torch_device = parse_device_args(accelerator, device, return_device="torch", validate_single_device=True) + + if not hasattr(self.module, "_guide") or self.module._guide is None: + self.module._guide = self.module._guide_factory( + self.module.model, + init_values={ + "log_gene_mean": self.module.log_gene_mean_init.to(torch_device), + "log_gene_dispersion": self.module.log_gene_dispersion_init.to(torch_device), + "element_effects": self.module.lfc_init.to(torch_device), + }, + ) + + # if self.module.local_effects and self.module.sparse_tensors: + # zero_element_effects = torch.zeros( + # (self.module.n_element_effects), dtype=torch.float32, device=torch_device + # ) + # zero_guide_efficacy = torch.zeros((self.module.n_guide_effects), dtype=torch.float32, device=torch_device) + # else: + # zero_element_effects = torch.zeros( + # (self.module.n_elements, self.module.n_genes), dtype=torch.float32, device=torch_device + # ) + # zero_guide_efficacy = torch.full( + # (self.module.n_perturbations, self.module.n_genes), + # fill_value=0.5, + # dtype=torch.float32, + # device=torch_device, + # ) + + if pretrain and indices is None: + if self.control_guides is None: + raise ValueError("Pretraining requires control_guides to be set during model initialization.") + indices = np.where( + self.adata_manager.get_from_registry(REGISTRY_KEYS.PERTURBATION_KEY)[:, self.control_guides].sum(axis=1) + > 0 + )[0] + assert isinstance(indices, np.ndarray) and len(indices) > 0, ( + "No cells with control guides found for pretraining." 
+ ) + + if indices is not None: + control_indices = [indices, np.setdiff1d(np.arange(self.summary_stats.n_cells), indices), None] + else: + control_indices = None + # else: + # control_indices = None + # self.pretrain( + # max_epochs=max_epochs if pretrain_max_epochs is None else pretrain_max_epochs, + # indices=control_indices, + # accelerator=accelerator, + # device=device, + # batch_size=batch_size, + # lr=lr, + # **(pretrain_kwargs if pretrain_kwargs is not None else {}), + # ) + + plan_kwargs = ( + plan_kwargs + if plan_kwargs is not None + else {"n_epochs_kl_warmup": None, "n_steps_kl_warmup": None, "scale_elbo": 1.0} + ) + if blocked_sites is not None: + plan_kwargs["blocked"] = blocked_sites if len(self.module.discrete_sites) > 0: plan_kwargs.update({"loss_fn": TraceEnum_ELBO(max_plate_nesting=3)}) if lr is not None and "optim" not in plan_kwargs.keys(): @@ -433,12 +702,13 @@ def train( data_splitter_kwargs["data_and_attributes"] = self.data_and_attrs if load_sparse_tensor == "auto": load_sparse_tensor = accelerator == "gpu" + if batch_size is None: # use data splitter which moves data to GPU once data_splitter = DeviceBackedDataSplitter( self.adata_manager, - train_size=train_size, - validation_size=validation_size, + train_size=1, + external_indexing=control_indices, accelerator=accelerator, device=device, **data_splitter_kwargs, @@ -446,9 +716,8 @@ def train( else: data_splitter = DataSplitter( self.adata_manager, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, + train_size=1, + external_indexing=control_indices, batch_size=batch_size, load_sparse_tensor=load_sparse_tensor, **data_splitter_kwargs, @@ -472,6 +741,7 @@ def train( devices=device, **trainer_kwargs, ) + return runner() def get_element_names(self) -> list: @@ -489,7 +759,64 @@ def get_element_names(self) -> list: element_ids = self.adata_manager.get_state_registry(REGISTRY_KEYS.PERTURBATION_KEY).column_names return element_ids - def get_element_effects(self) -> pd.DataFrame: + def get_z_values(self, return_loc_scale=False) -> pd.DataFrame: + """ + Return a DataFrame summary of the effects for targeted elements on each gene. + + Returns + ------- + pd.DataFrame + DataFrame with columns for effect location, scale, element, gene, z-value, and q-value. + """ + element_ids = self.get_element_names() + gene_ids = self.adata_manager.get_state_registry("X").column_names + + # Check if all element effects are factorized and raise an error if so + if "element_effects" not in self.module.guide.median(): + raise NotImplementedError( + "All element effects are factorized. Use 'get_factorized_element_effects' instead." 
+ ) + else: + for guide in self.module.guide: + if "element_effects" in guide.median(): + loc_values, scale_values = guide._get_loc_and_scale("element_effects") + z_values = loc_values / scale_values + + # loc_values, scale_values = self.module.guide._get_loc_and_scale("element_effects") + + if hasattr(self.module, "element_by_gene_idx"): + # loc/scale_values are the nonzero elements of a sparse matrix of elements by genes + i, j = self.module.element_by_gene_idx.detach().cpu().numpy().astype(int) + z_values_coo = coo_matrix( + (z_values.detach().cpu().numpy(), (i, j)), + shape=(len(element_ids), len(gene_ids)), + ) + loc_values_coo = coo_matrix( + (loc_values.detach().cpu().numpy(), (i, j)), + shape=(len(element_ids), len(gene_ids)), + ) + scale_values_coo = coo_matrix( + (scale_values.detach().cpu().numpy(), (i, j)), + shape=(len(element_ids), len(gene_ids)), + ) + + z_values_matrix = z_values_coo.toarray() + loc_values_matrix = loc_values_coo.toarray() + scale_values_matrix = scale_values_coo.toarray() + else: + z_values_matrix = z_values.detach().cpu().numpy() + loc_values_matrix = loc_values.detach().cpu().numpy() + scale_values_matrix = scale_values.detach().cpu().numpy() + + z_values_df = pd.DataFrame(z_values_matrix, index=element_ids, columns=gene_ids) + loc_values_df = pd.DataFrame(loc_values_matrix, index=element_ids, columns=gene_ids) + scale_values_df = pd.DataFrame(scale_values_matrix, index=element_ids, columns=gene_ids) + + if return_loc_scale: + return z_values_df, loc_values_df, scale_values_df + return z_values_df + + def get_element_effects(self, return_long=True) -> pd.DataFrame: """ Return a DataFrame summary of the effects for targeted elements on each gene. diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index bb08fb4..34eab27 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -1,3 +1,4 @@ +import functools import warnings from collections.abc import Mapping from typing import Literal @@ -6,7 +7,7 @@ import pyro.distributions as dist import torch from pyro import poutine -from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal, init_to_median +from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal, init_to_median, init_to_value from scvi.module.base import PyroBaseModuleClass from ._constants import REGISTRY_KEYS @@ -31,10 +32,11 @@ def __init__( n_batches: int | None = 1, log_gene_mean_init: torch.Tensor | None = None, log_gene_dispersion_init: torch.Tensor | None = None, + lfc_init: torch.Tensor | None = None, guide_by_element: torch.Tensor | None = None, gene_by_element: torch.Tensor | None = None, likelihood: Literal["nb", "lnnb"] = "nb", - effect_prior_dist: Literal["cauchy", "normal_mixture", "normal", "laplace"] = "laplace", + effect_prior_dist: Literal["cauchy", "normal_mixture", "normal", "laplace"] = "cauchy", n_factors: int | None = None, n_pert_factors: int | None = None, efficiency_mode: Literal["mixture", "scaled", "mixture_high_moi"] | None = "scaled", @@ -150,35 +152,18 @@ def __init__( self.n_batches = n_batches # Sites to approximate with Delta distribution instead of default Normal distribution. 
- self.delta_sites = [] + # self.delta_sites = [] + self.delta_sites = ["log_gene_mean", "log_gene_dispersion", "multiplicative_noise"] # self.delta_sites = ["cell_factors"] # self.delta_sites = ["cell_factors", "cell_loadings", "pert_factors", "pert_loadings"] if log_gene_mean_init is None: log_gene_mean_init = torch.zeros(self.n_genes) + self.log_gene_mean_init = log_gene_mean_init if log_gene_dispersion_init is None: log_gene_dispersion_init = torch.ones(self.n_genes) - - # if control_pcs is not None and n_factors is not None: - # init_values["cell_loadings"] = control_pcs - - self._guide = AutoGuideList(self.model, create_plates=self.create_plates) - - self._guide.append( - AutoNormal( - poutine.block(self.model, hide=self.delta_sites + self.discrete_sites), - init_loc_fn=lambda x: init_to_median(x, num_samples=100), - ), - ) - - if self.delta_sites: - self._guide.append( - AutoDelta( - poutine.block(self.model, expose=self.delta_sites), - init_loc_fn=lambda x: init_to_median(x, num_samples=100), - ) - ) + self.log_gene_dispersion_init = log_gene_dispersion_init ## register hyperparameters as buffers so they get automatically moved to GPU by scvi-tools @@ -208,11 +193,14 @@ def __init__( self.register_buffer("one", torch.tensor(1.0)) # per-gene hyperparams - self.register_buffer("gene_mean_prior_loc", log_gene_mean_init) - self.register_buffer("gene_disp_prior_loc", log_gene_dispersion_init) + self.register_buffer("gene_mean_prior_loc", torch.tensor(0.0)) + self.register_buffer("gene_disp_prior_loc", torch.tensor(1.0)) + + self.register_buffer("gene_mean_prior_scale", torch.tensor(3.0)) + self.register_buffer("gene_disp_prior_scale", torch.tensor(1.0)) - self.register_buffer("gene_mean_prior_scale", torch.tensor(0.2)) - self.register_buffer("gene_disp_prior_scale", torch.tensor(0.2)) + self.register_buffer("noise_prior_loc", torch.tensor(-1.0)) + self.register_buffer("noise_prior_scale", torch.tensor(0.5)) # batch/covariate hyperparams self.register_buffer("batch_effect_prior_scale", torch.tensor(0.2)) @@ -243,8 +231,27 @@ def __init__( self.register_buffer("pert_factor_prior_scale", torch.tensor(0.1)) self.register_buffer("pert_loading_prior_scale", torch.tensor(1.0)) - # for LogNormalNegativeBinomial likelihood hyperparams - self.register_buffer("noise_prior_rate", torch.tensor(2.0)) + # create guide with initial values + if lfc_init is None: + lfc_init = torch.zeros((self.n_elements, self.n_genes)) + + if self.local_effects and self.sparse_tensors: + if lfc_init.shape == (self.n_elements, self.n_genes): + lfc_init = lfc_init[self.element_by_gene_idx[0], self.element_by_gene_idx[1]] + assert lfc_init.shape == (self.n_element_effects,), ( + f"lfc_init shape: {lfc_init.shape}, expected ({self.n_element_effects},)" + ) + self.lfc_init = lfc_init + + # self._guide = self._guide_factory(self.model) + # self._guide = self._guide_factory( + # self.model, + # init_values={ + # "log_gene_mean": self.log_gene_mean_init, + # "log_gene_dispersion": self.log_gene_dispersion_init, + # "element_effects": self.lfc_init, + # }, + # ) # override with user-provided values from prior_param_dict if prior_param_dict is not None: @@ -254,6 +261,34 @@ def __init__( assert v.shape == self.get_buffer(k).shape self.register_buffer(k, v) + def _guide_factory(self, model, init_values=None, init_scale=0.05): + guide = AutoGuideList(model, create_plates=self.create_plates) + if init_values is None: + init_values = {} + # init_values = { + # "log_gene_mean": self.log_gene_mean_init, + # "log_gene_dispersion": 
self.log_gene_dispersion_init, + # "element_effects": self.lfc_init, + # } + init_loc_fn = functools.partial(init_to_value, values=init_values, fallback=init_to_median(num_samples=100)) + + guide.append( + AutoNormal( + poutine.block(model, hide=self.delta_sites + self.discrete_sites), + init_loc_fn=init_loc_fn, + init_scale=init_scale, + ), + ) + + if self.delta_sites: + guide.append( + AutoDelta( + poutine.block(model, expose=self.delta_sites), + init_loc_fn=init_loc_fn, + ) + ) + return guide + @staticmethod def _get_fn_args_from_batch(tensor_dict: dict) -> tuple[tuple[torch.Tensor], dict]: fit_size_factor_covariate = False @@ -547,7 +582,10 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: if self.likelihood == "lnnb": # additional noise for LogNormalNegativeBinomial likelihood - multiplicative_noise = pyro.sample("multiplicative_noise", dist.Exponential(self.noise_prior_rate)) + # multiplicative_noise = pyro.sample("multiplicative_noise", dist.LogNormal(self.noise_prior_rate)) + multiplicative_noise = pyro.sample( + "multiplicative_noise", dist.LogNormal(self.noise_prior_loc, self.noise_prior_scale) + ) # multiplicative_noise = 1 / self.noise_prior_rate with batch_plate: @@ -618,6 +656,7 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: "obs", LogNormalNegativeBinomial( logits=nb_log_mean - nb_log_dispersion - multiplicative_noise**2 / 2, + # logits=nb_log_mean - nb_log_dispersion, total_count=nb_log_dispersion.exp(), multiplicative_noise_scale=multiplicative_noise, num_quad_points=self.lnnb_quad_points, diff --git a/src/perturbo/utils/__init__.py b/src/perturbo/utils/__init__.py new file mode 100644 index 0000000..a0b9a36 --- /dev/null +++ b/src/perturbo/utils/__init__.py @@ -0,0 +1 @@ +from .utils import empirical_pvals_from_null, compute_empirical_pvals diff --git a/src/perturbo/utils/utils.py b/src/perturbo/utils/utils.py new file mode 100644 index 0000000..48e28fd --- /dev/null +++ b/src/perturbo/utils/utils.py @@ -0,0 +1,269 @@ +import numpy as np +import pandas as pd +import anndata as ad +from statsmodels.stats.multitest import multipletests +from typing import Literal + +def compute_empirical_pvals( + data_real: pd.DataFrame, + data_shuffled: pd.DataFrame, + value_col: str ="z_value", + adata: ad.AnnData | None = None, + adata_count_col: str | None = None, + pval_adj_method: str | None = None, + group_col: str | None = None, + two_sided: bool = True, + bias_correction: bool = True, + winsor: float | None = None, + method: Literal["default", "tnull_fixed0"] = "default", + n_quantiles: int | None = None, +): + """ + Compute empirical p-values for real data based on null distribution from shuffled data. + + Parameters + ---------- + data_real : pd.DataFrame + DataFrame containing real test statistics. + data_shuffled : pd.DataFrame + DataFrame containing null test statistics. + value_col : str, default "z_value" + Column name for test statistics in both DataFrames. + adata : AnnData, optional + AnnData object for computing expression quantiles if n_quantiles is specified. + adata_count_col : str, optional + Column in adata.var to use for quantile computation. + pval_adj_method : str or None, optional + Method for multiple testing correction (e.g., 'fdr_bh'). If None, + no adjustment is performed. + group_col : str or None, optional + Column name to group by when computing p-values. If None, compute + p-values globally. + two_sided : bool, default True + If True, compute two-sided p-values; otherwise one-sided. 
+ bias_correction : bool, default True + If True, apply bias correction in empirical p-value calculation. + winsor : float in (0,0.5) or None, optional + If specified, winsorize null statistics at these quantiles before fitting. + method : str, default "default" + Method for p-value computation. Options are "default" or "tnull_fixed0". + n_quantiles : int or None, optional + If specified, compute expression quantiles and use them as group_col. + """ + + if method == "default": + pval_func = empirical_pvals_from_null + elif method == "tnull_fixed0": + pval_func = empirical_pvals_from_tnull_fixed0 + else: + raise ValueError(f"Unknown method: {method}") + + if n_quantiles is not None: + assert adata is not None, "adata must be provided when n_quantiles is specified." + expression_quantiles = compute_quantiles( + adata=adata, + counts_col=adata_count_col, + n_quantiles=n_quantiles, + ) + data_real = data_real.merge(expression_quantiles, left_on="gene", right_on="gene", how="left") + data_shuffled = data_shuffled.merge(expression_quantiles, left_on="gene", right_on="gene", how="left") + group_col = "mean_quantile" + + if group_col is None: + pvals = pval_func( + null_z=data_shuffled[value_col].values, + real_z=data_real[value_col].values, + two_sided=two_sided, + bias_correction=bias_correction, + winsor=winsor, + ) + if pval_adj_method is not None: + rej, pval_adj, _, _ = multipletests(pvals, alpha=0.05, method=pval_adj_method) + return pval_adj + else: + return pvals + else: + pvals = data_real.groupby(group_col)[value_col].transform( + lambda x: pval_func( + null_z=data_shuffled.query(f"{group_col} == @x.name")[value_col], + real_z=x, + two_sided=two_sided, + bias_correction=bias_correction, + winsor=winsor, + ) + ) + if pval_adj_method is not None: + pval_df = data_real[[group_col]].assign(pval=pvals) + pval_adj = pval_df.groupby(group_col)["pval"].transform( + lambda x: multipletests(x, alpha=0.05, method=pval_adj_method)[1] + ) + return pval_adj + else: + return pvals + + +def empirical_pvals_from_tnull_fixed0( + null_z, + real_z, + two_sided: bool = True, + winsor: float | None = None, + return_params: bool = False, + **kwargs, +): + """ + Fit a Student-t(df, scale) with location fixed at 0 to null z-values, + then compute parametric (empirical) p-values for real z-values. + + Parameters + ---------- + null_z : array-like or pandas.Series + Null z-values (1D). + real_z : array-like or pandas.Series + Observed z-values to evaluate. + two_sided : bool, default True + If True, two-sided p = 2 * sf(|z|); otherwise one-sided upper tail. + winsor : float in (0,0.5), optional + Winsorize null_z at quantiles (winsor, 1-winsor) before fitting. + return_params : bool, default False + If True, return (pvals, params_dict). + + Returns + ------- + pvals : Series or ndarray + Empirical p-values under fitted t(df, loc=0, scale). 
+ params : dict, optional + {'df': df, 'scale': scale, 'n_null': B} + """ + from scipy import stats + + # --- clean nulls + null = pd.Series(null_z, dtype=float).replace([np.inf, -np.inf], np.nan).dropna() + if null.size == 0: + raise ValueError("No valid null statistics.") + B = null.size + + # Optional winsorization for stability + if winsor is not None: + if not (0 < winsor < 0.5): + raise ValueError("winsor must be in (0,0.5)") + q_lo, q_hi = null.quantile([winsor, 1 - winsor]) + null = null.clip(q_lo, q_hi) + + # --- fit Student-t with loc fixed at 0 + df_hat, loc_hat, scale_hat = stats.t.fit(null.values, floc=0.0) + # loc_hat will be 0.0 by construction + + # --- prep real values + if isinstance(real_z, pd.Series): + real = real_z.astype(float).replace([np.inf, -np.inf], np.nan) + real_index = real.index + real = real.values + return_series = True + else: + real = np.asarray(real_z, dtype=float) + real_index = None + return_series = False + + # --- compute p-values + z_std = real / scale_hat + if two_sided: + p = 2.0 * stats.t.sf(np.abs(z_std), df_hat) + else: + p = stats.t.sf(z_std, df_hat) + p = np.clip(p, 0.0, 1.0) + + if return_series: + p = pd.Series(p, index=real_index, name="p_tnull") + + if return_params: + params = {"df": float(df_hat), "scale": float(scale_hat), "n_null": int(B)} + return p, params + return p + + +def empirical_pvals_from_null( + null_z, + real_z, + two_sided: bool = True, + bias_correction: bool = True, + **kwargs, +): + """ + Compute empirical p-values from a pooled null of z-like statistics. + + Parameters + ---------- + null_z : array-like or pandas.Series + 1-D collection of null test statistics (e.g., from shuffled data). + real_z : array-like or pandas.Series + 1-D collection of observed test statistics to evaluate. + two_sided : bool, default True + If True, uses |z| (two-sided). If False, uses one-sided (upper tail). + bias_correction : bool, default True + If True, uses (r+1)/(B+1). If False, uses r/B. + + Returns + ------- + pvals : pandas.Series or np.ndarray + Empirical p-values aligned to real_z (Series preserves index). 
+ """ + # Convert & clean + null = pd.Series(null_z).astype(float).replace([np.inf, -np.inf], np.nan).dropna().values + if null.size == 0: + raise ValueError("No valid null statistics provided.") + + if isinstance(real_z, pd.Series): + real = real_z.astype(float).replace([np.inf, -np.inf], np.nan) + real_index = real.index + real = real.values + return_series = True + else: + real = np.asarray(real_z, dtype=float) + real_index = None + return_series = False + + # Transform for sidedness + if two_sided: + null_t = np.abs(null) + real_t = np.abs(real) + else: + null_t = null + real_t = real + + # Sort null once (ascending) + null_sorted = np.sort(null_t) + B = null_sorted.size + + # For each real stat, count null >= real (upper tail) + # searchsorted gives index of first value >= real_t (with 'left'), + # so r = B - idx + idx = np.searchsorted(null_sorted, real_t, side="left") + r = B - idx # exceedance counts + + if bias_correction: + p = (r + 1.0) / (B + 1.0) + else: + # Guard against division by zero if B==0 (already caught above) + p = r / B + + # Preserve index/type if input was a Series + if return_series: + return pd.Series(p, index=real_index, name="empirical_p") + return p + + +def compute_quantiles(adata, counts_col=None, gene_name_col="gene", n_quantiles=10): + if counts_col is None: + print("computing mean counts for quantile assignment and outputting to 'mean_expression' column") + mean_counts = adata.X.sum(axis=0) / adata.n_obs + if isinstance(mean_counts, np.matrix): + mean_counts = np.asarray(mean_counts).squeeze() + adata.var["mean_expression"] = mean_counts + counts_col = "mean_expression" + + gene_df = adata.var.reset_index(names=[gene_name_col]) + gene_df["mean_quantile"] = pd.qcut(gene_df[counts_col], n_quantiles, labels=False) + 1 + decile_df = gene_df[[gene_name_col, "mean_quantile"]] + return decile_df + + diff --git a/tests/test_basic.py b/tests/test_basic.py index 70661bf..10a64af 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -69,6 +69,11 @@ def test_model_mdata( assert model.summary_stats.n_vars == len(mdata[rna_key].var) assert model.summary_stats.n_perturbations == len(mdata[perturb_key].var) + model.pretrain( + indices=np.random.choice(len(mdata), size=len(mdata) // 2, replace=False), + max_epochs=2, + lr=0.1, + ) model.train( # accelerator="auto", max_epochs=5, @@ -76,6 +81,7 @@ def test_model_mdata( batch_size=2, # load_sparse_tensor=sparse_tensors, ) + # model.train( # # accelerator="auto", # max_epochs=5,