From c31495513db3833d46acf5fa0965b1edf47f06a4 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Tue, 25 Nov 2025 00:47:25 -0500 Subject: [PATCH 01/10] implement basic pretraining using pyro param store hack --- src/perturbo/models/_model.py | 105 +++++++++++++++++++++++++++++---- src/perturbo/models/_module.py | 36 ++++++----- tests/test_basic.py | 6 ++ 3 files changed, 121 insertions(+), 26 deletions(-) diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index a3ab0d7..f22c390 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -2,27 +2,25 @@ import numpy as np import pandas as pd +import pyro.optim as optim import scipy.sparse as sp import torch from mudata import AnnData, MuData from pandas import DataFrame from pyro import poutine -from pyro.infer import TraceEnum_ELBO, infer_discrete +from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, infer_discrete from scipy.sparse import issparse from scipy.stats import chi2 from scvi._types import AnnOrMuData from scvi.data import AnnDataManager, fields from scvi.dataloaders import AnnDataLoader, DataSplitter, DeviceBackedDataSplitter -from scvi.model.base import ( - BaseModelClass, - PyroJitGuideWarmup, - PyroSampleMixin, - PyroSviTrainMixin, -) +from scvi.model._utils import parse_device_args +from scvi.model.base import ArchesMixin, BaseModelClass, PyroJitGuideWarmup, PyroSampleMixin, PyroSviTrainMixin from scvi.train import PyroTrainingPlan from scvi.utils._docstrings import devices_dsp from sklearn.isotonic import IsotonicRegression from sklearn.linear_model import LinearRegression +from tqdm.auto import trange from ._constants import REGISTRY_KEYS from ._module import PerTurboPyroModule @@ -30,7 +28,7 @@ logger = logging.getLogger(__name__) -class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass): +class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass, ArchesMixin): def __init__( self, mdata: AnnOrMuData, @@ -166,10 +164,13 @@ def read_matrix_from_registry(self, registry_key: str) -> torch.Tensor: @classmethod def setup_anndata( cls, - adata: AnnData, + adata: AnnOrMuData, **kwargs, ): - raise NotImplementedError("MuData input required, use setup_mudata.") + if isinstance(adata, AnnData): + raise NotImplementedError("MuData input required, use setup_mudata.") + else: + cls.setup_mudata(adata, **kwargs) @classmethod def setup_mudata( @@ -364,6 +365,90 @@ def setup_mudata( adata_manager.register_fields(mdata, **kwargs) cls.register_manager(adata_manager) + def pretrain( + self, + indices, + max_epochs: int = 100, + accelerator: str = "cpu", + device: int | str = "auto", + batch_size: int = 1024, + # early_stopping: bool = False, + lr: float | None = 0.01, + ): + """ + Pretrain the model on a subset of the data. + + Parameters + ---------- + indices : array-like + Indices of the subset of data to use for pretraining. + max_epochs : int + Number of passes through the dataset. + accelerator : str + Accelerator type ("cpu", "gpu", etc.). + device : int or str + Device identifier. + batch_size : int + Minibatch size to use during training. + early_stopping : bool + Perform early stopping. + lr : float or None + Optimizer learning rate. 
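+
+        Returns
+        -------
+        list of float
+            ELBO loss recorded at each epoch (the mean across minibatches when
+            the data does not fit in a single batch).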
+ """ + _, _, device = parse_device_args(accelerator, device, return_device="torch", validate_single_device=True) + loader = AnnDataLoader( + adata_manager=self.adata_manager, + indices=indices, + batch_size=min(batch_size, len(indices)), + data_and_attributes=self.data_and_attrs, + ) + + if self.module.local_effects and self.module.sparse_tensors: + zero_element_effects = torch.zeros((self.module.n_element_effects), dtype=torch.float32, device=device) + else: + zero_element_effects = torch.zeros( + (self.module.n_elements, self.module.n_genes), dtype=torch.float32, device=device + ) + pretrain_model = poutine.condition( + self.module.model, + data={"element_effects": zero_element_effects} + if self.module.local_effects and self.module.sparse_tensors + else {}, + ) + pretrain_guide = self.module._guide_factory(poutine.block(pretrain_model, hide=["element_effects"])) + + svi = SVI(pretrain_model, pretrain_guide, optim.Adam({"lr": lr}), loss=Trace_ELBO(max_plate_nesting=2)) + losses = [] + + t = trange(max_epochs, desc="SVI epochs", leave=True) + if len(loader) == 1: + batch = next(iter(loader)) + args, kwargs = self.module._get_fn_args_from_batch(batch) + args = tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args) + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + kwargs[k] = v.to(device) + + for _ in t: + loss = svi.step(*args, **kwargs) + losses.append(loss) + t.set_postfix(loss=loss) + + else: + for _ in t: + batch_losses = [] + for batch in loader: + args, kwargs = self.module._get_fn_args_from_batch(batch) + args = tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args) + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + kwargs[k] = v.to(device) + loss = svi.step(*args, **kwargs) + batch_losses.append(loss) + t.set_postfix(loss=np.mean(batch_losses)) + losses.append(np.mean(batch_losses)) + return losses + @devices_dsp.dedent def train( self, diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index bb08fb4..f569e5c 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -163,22 +163,7 @@ def __init__( # if control_pcs is not None and n_factors is not None: # init_values["cell_loadings"] = control_pcs - self._guide = AutoGuideList(self.model, create_plates=self.create_plates) - - self._guide.append( - AutoNormal( - poutine.block(self.model, hide=self.delta_sites + self.discrete_sites), - init_loc_fn=lambda x: init_to_median(x, num_samples=100), - ), - ) - - if self.delta_sites: - self._guide.append( - AutoDelta( - poutine.block(self.model, expose=self.delta_sites), - init_loc_fn=lambda x: init_to_median(x, num_samples=100), - ) - ) + self._guide = self._guide_factory(self.model) ## register hyperparameters as buffers so they get automatically moved to GPU by scvi-tools @@ -254,6 +239,25 @@ def __init__( assert v.shape == self.get_buffer(k).shape self.register_buffer(k, v) + def _guide_factory(self, model): + guide = AutoGuideList(model, create_plates=self.create_plates) + + guide.append( + AutoNormal( + poutine.block(model, hide=self.delta_sites + self.discrete_sites), + init_loc_fn=lambda x: init_to_median(x, num_samples=100), + ), + ) + + if self.delta_sites: + guide.append( + AutoDelta( + poutine.block(model, expose=self.delta_sites), + init_loc_fn=lambda x: init_to_median(x, num_samples=100), + ) + ) + return guide + @staticmethod def _get_fn_args_from_batch(tensor_dict: dict) -> tuple[tuple[torch.Tensor], dict]: fit_size_factor_covariate = False diff --git 
a/tests/test_basic.py b/tests/test_basic.py index 70661bf..10a64af 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -69,6 +69,11 @@ def test_model_mdata( assert model.summary_stats.n_vars == len(mdata[rna_key].var) assert model.summary_stats.n_perturbations == len(mdata[perturb_key].var) + model.pretrain( + indices=np.random.choice(len(mdata), size=len(mdata) // 2, replace=False), + max_epochs=2, + lr=0.1, + ) model.train( # accelerator="auto", max_epochs=5, @@ -76,6 +81,7 @@ def test_model_mdata( batch_size=2, # load_sparse_tensor=sparse_tensors, ) + # model.train( # # accelerator="auto", # max_epochs=5, From a3e31994c038316b4bdb01a6b6d98326ed4bf0e3 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Thu, 4 Dec 2025 04:24:44 +0000 Subject: [PATCH 02/10] add gpu support, update pretrain recipe --- src/perturbo/models/_model.py | 80 +++++++++++++++++++++++++++++----- src/perturbo/models/_module.py | 40 ++++++++++++----- 2 files changed, 99 insertions(+), 21 deletions(-) diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index f22c390..5fb2c75 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -9,13 +9,13 @@ from pandas import DataFrame from pyro import poutine from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, infer_discrete -from scipy.sparse import issparse +from scipy.sparse import coo_matrix, issparse from scipy.stats import chi2 from scvi._types import AnnOrMuData from scvi.data import AnnDataManager, fields from scvi.dataloaders import AnnDataLoader, DataSplitter, DeviceBackedDataSplitter from scvi.model._utils import parse_device_args -from scvi.model.base import ArchesMixin, BaseModelClass, PyroJitGuideWarmup, PyroSampleMixin, PyroSviTrainMixin +from scvi.model.base import BaseModelClass, PyroJitGuideWarmup, PyroSampleMixin, PyroSviTrainMixin from scvi.train import PyroTrainingPlan from scvi.utils._docstrings import devices_dsp from sklearn.isotonic import IsotonicRegression @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) -class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass, ArchesMixin): +class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass): def __init__( self, mdata: AnnOrMuData, @@ -94,6 +94,7 @@ def __init__( else: control_guide_idx = grna_counts[:, control_guides].sum(axis=1) > 0 X = X[control_guide_idx, :] + self.control_guides = control_guides n_cells_for_init = X.shape[0] if n_cells_for_init == 0: @@ -367,8 +368,9 @@ def setup_mudata( def pretrain( self, - indices, max_epochs: int = 100, + indices: list[int] | list[bool] | None = None, + num_particles: int = 1, accelerator: str = "cpu", device: int | str = "auto", batch_size: int = 1024, @@ -396,6 +398,7 @@ def pretrain( Optimizer learning rate. 
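        num_particles : int
            Number of particles used by ``Trace_ELBO`` to estimate the ELBO and its gradient.

        Returns
        -------
        list of float
            ELBO loss recorded at each epoch (the mean across minibatches when
            the data does not fit in a single batch).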
""" _, _, device = parse_device_args(accelerator, device, return_device="torch", validate_single_device=True) + loader = AnnDataLoader( adata_manager=self.adata_manager, indices=indices, @@ -405,19 +408,38 @@ def pretrain( if self.module.local_effects and self.module.sparse_tensors: zero_element_effects = torch.zeros((self.module.n_element_effects), dtype=torch.float32, device=device) + zero_guide_efficacy = torch.zeros((self.module.n_guide_effects), dtype=torch.float32, device=device) else: zero_element_effects = torch.zeros( (self.module.n_elements, self.module.n_genes), dtype=torch.float32, device=device ) + zero_guide_efficacy = torch.full( + (self.module.n_perturbations, self.module.n_genes), fill_value=0.5, dtype=torch.float32, device=device + ) + + self.module.to(device) pretrain_model = poutine.condition( self.module.model, - data={"element_effects": zero_element_effects} - if self.module.local_effects and self.module.sparse_tensors - else {}, + data={ + "element_effects": zero_element_effects, + "guide_efficacy": zero_guide_efficacy, + }, + ) + pretrain_guide = self.module._guide_factory( + poutine.block(self.module.model, hide=["element_effects", "guide_efficacy"]), + init_values={ + "log_gene_mean": self.module.log_gene_mean_init.to(device), + "log_gene_dispersion": self.module.log_gene_dispersion_init.to(device), + }, ) - pretrain_guide = self.module._guide_factory(poutine.block(pretrain_model, hide=["element_effects"])) - svi = SVI(pretrain_model, pretrain_guide, optim.Adam({"lr": lr}), loss=Trace_ELBO(max_plate_nesting=2)) + pretrain_guide.to(device) + svi = SVI( + pretrain_model, + pretrain_guide, + optim.Adam({"lr": lr}), + loss=Trace_ELBO(max_plate_nesting=3, num_particles=num_particles), + ) losses = [] t = trange(max_epochs, desc="SVI epochs", leave=True) @@ -557,6 +579,7 @@ def train( devices=device, **trainer_kwargs, ) + return runner() def get_element_names(self) -> list: @@ -574,7 +597,44 @@ def get_element_names(self) -> list: element_ids = self.adata_manager.get_state_registry(REGISTRY_KEYS.PERTURBATION_KEY).column_names return element_ids - def get_element_effects(self) -> pd.DataFrame: + def get_z_values(self) -> pd.DataFrame: + """ + Return a DataFrame summary of the effects for targeted elements on each gene. + + Returns + ------- + pd.DataFrame + DataFrame with columns for effect location, scale, element, gene, z-value, and q-value. + """ + element_ids = self.get_element_names() + gene_ids = self.adata_manager.get_state_registry("X").column_names + + # Check if all element effects are factorized and raise an error if so + if "element_effects" not in self.module.guide.median(): + raise NotImplementedError( + "All element effects are factorized. Use 'get_factorized_element_effects' instead." 
+ ) + else: + for guide in self.module.guide: + if "element_effects" in guide.median(): + loc_values, scale_values = guide._get_loc_and_scale("element_effects") + z_values = loc_values / scale_values + + # loc_values, scale_values = self.module.guide._get_loc_and_scale("element_effects") + + if hasattr(self.module, "element_by_gene_idx"): + # loc/scale_values are the nonzero elements of a sparse matrix of elements by genes + i, j = self.module.element_by_gene_idx.detach().cpu().numpy().astype(int) + z_values_coo = coo_matrix( + (z_values.detach().cpu().numpy(), (i, j)), + shape=(len(element_ids), len(gene_ids)), + ) + z_values_matrix = z_values_coo.to_csr() + else: + z_values_matrix = z_values.detach().cpu().numpy() + return pd.DataFrame(z_values_matrix, index=element_ids, columns=gene_ids) + + def get_element_effects(self, return_long=True) -> pd.DataFrame: """ Return a DataFrame summary of the effects for targeted elements on each gene. diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index f569e5c..0a10b73 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -1,3 +1,4 @@ +import functools import warnings from collections.abc import Mapping from typing import Literal @@ -6,7 +7,7 @@ import pyro.distributions as dist import torch from pyro import poutine -from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal, init_to_median +from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal, init_to_median, init_to_value from scvi.module.base import PyroBaseModuleClass from ._constants import REGISTRY_KEYS @@ -156,9 +157,11 @@ def __init__( if log_gene_mean_init is None: log_gene_mean_init = torch.zeros(self.n_genes) + self.log_gene_mean_init = log_gene_mean_init if log_gene_dispersion_init is None: log_gene_dispersion_init = torch.ones(self.n_genes) + self.log_gene_dispersion_init = log_gene_dispersion_init # if control_pcs is not None and n_factors is not None: # init_values["cell_loadings"] = control_pcs @@ -193,11 +196,14 @@ def __init__( self.register_buffer("one", torch.tensor(1.0)) # per-gene hyperparams - self.register_buffer("gene_mean_prior_loc", log_gene_mean_init) - self.register_buffer("gene_disp_prior_loc", log_gene_dispersion_init) + self.register_buffer("gene_mean_prior_loc", torch.tensor(0.0)) + self.register_buffer("gene_disp_prior_loc", torch.tensor(1.0)) - self.register_buffer("gene_mean_prior_scale", torch.tensor(0.2)) - self.register_buffer("gene_disp_prior_scale", torch.tensor(0.2)) + self.register_buffer("gene_mean_prior_scale", torch.tensor(2.0)) + self.register_buffer("gene_disp_prior_scale", torch.tensor(1.0)) + + self.register_buffer("noise_prior_loc", torch.tensor(-2.0)) + self.register_buffer("noise_prior_scale", torch.tensor(0.5)) # batch/covariate hyperparams self.register_buffer("batch_effect_prior_scale", torch.tensor(0.2)) @@ -213,7 +219,7 @@ def __init__( ## element effect size hyperparams # Normal/Laplace/Cauchy prior - effect_prior_scales = {"cauchy": 0.1, "laplace": 0.5, "normal": 1.0} + effect_prior_scales = {"cauchy": 0.2, "laplace": 0.5, "normal": 1.0} model_effect_prior_scale = effect_prior_scales[effect_prior_dist] self.register_buffer("element_effects_prior_scale", torch.tensor(model_effect_prior_scale)) self.register_buffer("guide_effects_prior_scale", torch.tensor(model_effect_prior_scale)) @@ -229,7 +235,7 @@ def __init__( self.register_buffer("pert_loading_prior_scale", torch.tensor(1.0)) # for LogNormalNegativeBinomial likelihood hyperparams - 
self.register_buffer("noise_prior_rate", torch.tensor(2.0)) + # self.register_buffer("noise_prior_rate", torch.tensor(2.0)) # override with user-provided values from prior_param_dict if prior_param_dict is not None: @@ -239,13 +245,21 @@ def __init__( assert v.shape == self.get_buffer(k).shape self.register_buffer(k, v) - def _guide_factory(self, model): + def _guide_factory(self, model, init_values=None): guide = AutoGuideList(model, create_plates=self.create_plates) + if init_values is None: + init_values = {} + # if init_values is None: + # init_values = { + # "log_gene_mean": self.log_gene_mean_init, + # "log_gene_dispersion": self.log_gene_dispersion_init, + # } + init_loc_fn = functools.partial(init_to_value, values=init_values, fallback=init_to_median(num_samples=100)) guide.append( AutoNormal( poutine.block(model, hide=self.delta_sites + self.discrete_sites), - init_loc_fn=lambda x: init_to_median(x, num_samples=100), + init_loc_fn=init_loc_fn, ), ) @@ -253,7 +267,7 @@ def _guide_factory(self, model): guide.append( AutoDelta( poutine.block(model, expose=self.delta_sites), - init_loc_fn=lambda x: init_to_median(x, num_samples=100), + init_loc_fn=init_loc_fn, ) ) return guide @@ -551,7 +565,10 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: if self.likelihood == "lnnb": # additional noise for LogNormalNegativeBinomial likelihood - multiplicative_noise = pyro.sample("multiplicative_noise", dist.Exponential(self.noise_prior_rate)) + # multiplicative_noise = pyro.sample("multiplicative_noise", dist.LogNormal(self.noise_prior_rate)) + multiplicative_noise = pyro.sample( + "multiplicative_noise", dist.LogNormal(self.noise_prior_loc, self.noise_prior_scale) + ) # multiplicative_noise = 1 / self.noise_prior_rate with batch_plate: @@ -622,6 +639,7 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: "obs", LogNormalNegativeBinomial( logits=nb_log_mean - nb_log_dispersion - multiplicative_noise**2 / 2, + # logits=nb_log_mean - nb_log_dispersion, total_count=nb_log_dispersion.exp(), multiplicative_noise_scale=multiplicative_noise, num_quad_points=self.lnnb_quad_points, From 947fef7dd072c9c2996b3d3326ac50a565f06b26 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Fri, 5 Dec 2025 03:42:21 +0000 Subject: [PATCH 03/10] add optional pretraining routine to main train fn --- .gitignore | 1 + src/perturbo/models/_model.py | 23 +++- src/perturbo/utils/__init__.py | 1 + src/perturbo/utils/utils.py | 191 +++++++++++++++++++++++++++++++++ 4 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 src/perturbo/utils/__init__.py create mode 100644 src/perturbo/utils/utils.py diff --git a/.gitignore b/.gitignore index c7e3604..6aaec02 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,4 @@ docs/notebooks/*.pt /.hatch/ src/perturbo/simulation/.ipynb_checkpoints/ +lightning_logs/ \ No newline at end of file diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index 5fb2c75..78e48bd 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -442,7 +442,6 @@ def pretrain( ) losses = [] - t = trange(max_epochs, desc="SVI epochs", leave=True) if len(loader) == 1: batch = next(iter(loader)) args, kwargs = self.module._get_fn_args_from_batch(batch) @@ -451,12 +450,14 @@ def pretrain( if isinstance(v, torch.Tensor): kwargs[k] = v.to(device) + t = trange(max_epochs, desc="SVI epochs", leave=True) for _ in t: loss = svi.step(*args, **kwargs) losses.append(loss) t.set_postfix(loss=loss) else: + t = trange(max_epochs, desc="SVI epochs", 
leave=True) for _ in t: batch_losses = [] for batch in loader: @@ -482,6 +483,9 @@ def train( shuffle_set_split: bool = False, batch_size: int = 1024, early_stopping: bool = False, + pretrain: bool = False, + pretrain_kwargs=None, + pretrain_max_epochs: int | None = None, lr: float | None = 0.005, load_sparse_tensor: bool = "auto", training_plan: PyroTrainingPlan = PyroTrainingPlan, @@ -529,6 +533,23 @@ def train( Any The result of the training runner. """ + if pretrain: + if self.control_guides is None: + raise ValueError("Pretraining requires control_guides to be set during model initialization.") + control_indices = ( + self.adata_manager.get_from_registry(REGISTRY_KEYS.PERTURBATION_KEY)[:, self.control_guides].sum(axis=1) + > 0 + ) + self.pretrain( + max_epochs=max_epochs if pretrain_max_epochs is None else pretrain_max_epochs, + indices=control_indices, + accelerator=accelerator, + device=device, + batch_size=batch_size, + lr=lr, + **(pretrain_kwargs if pretrain_kwargs is not None else {}), + ) + plan_kwargs = plan_kwargs if plan_kwargs is not None else {} if len(self.module.discrete_sites) > 0: plan_kwargs.update({"loss_fn": TraceEnum_ELBO(max_plate_nesting=3)}) diff --git a/src/perturbo/utils/__init__.py b/src/perturbo/utils/__init__.py new file mode 100644 index 0000000..a0b9a36 --- /dev/null +++ b/src/perturbo/utils/__init__.py @@ -0,0 +1 @@ +from .utils import empirical_pvals_from_null, compute_empirical_pvals diff --git a/src/perturbo/utils/utils.py b/src/perturbo/utils/utils.py new file mode 100644 index 0000000..17bb2c6 --- /dev/null +++ b/src/perturbo/utils/utils.py @@ -0,0 +1,191 @@ +import numpy as np +import pandas as pd +from statsmodels.stats.multitest import multipletests + + +def empirical_pvals_from_tnull_fixed0( + null_z, + real_z, + two_sided: bool = True, + winsor: float | None = None, + return_params: bool = False, +): + """ + Fit a Student-t(df, scale) with location fixed at 0 to null z-values, + then compute parametric (empirical) p-values for real z-values. + + Parameters + ---------- + null_z : array-like or pandas.Series + Null z-values (1D). + real_z : array-like or pandas.Series + Observed z-values to evaluate. + two_sided : bool, default True + If True, two-sided p = 2 * sf(|z|); otherwise one-sided upper tail. + winsor : float in (0,0.5), optional + Winsorize null_z at quantiles (winsor, 1-winsor) before fitting. + return_params : bool, default False + If True, return (pvals, params_dict). + + Returns + ------- + pvals : Series or ndarray + Empirical p-values under fitted t(df, loc=0, scale). 
+ params : dict, optional + {'df': df, 'scale': scale, 'n_null': B} + """ + from scipy import stats + + # --- clean nulls + null = pd.Series(null_z, dtype=float).replace([np.inf, -np.inf], np.nan).dropna() + if null.size == 0: + raise ValueError("No valid null statistics.") + B = null.size + + # Optional winsorization for stability + if winsor is not None: + if not (0 < winsor < 0.5): + raise ValueError("winsor must be in (0,0.5)") + q_lo, q_hi = null.quantile([winsor, 1 - winsor]) + null = null.clip(q_lo, q_hi) + + # --- fit Student-t with loc fixed at 0 + df_hat, loc_hat, scale_hat = stats.t.fit(null.values, floc=0.0) + # loc_hat will be 0.0 by construction + + # --- prep real values + if isinstance(real_z, pd.Series): + real = real_z.astype(float).replace([np.inf, -np.inf], np.nan) + real_index = real.index + real = real.values + return_series = True + else: + real = np.asarray(real_z, dtype=float) + real_index = None + return_series = False + + # --- compute p-values + z_std = real / scale_hat + if two_sided: + p = 2.0 * stats.t.sf(np.abs(z_std), df_hat) + else: + p = stats.t.sf(z_std, df_hat) + p = np.clip(p, 0.0, 1.0) + + if return_series: + p = pd.Series(p, index=real_index, name="p_tnull") + + if return_params: + params = {"df": float(df_hat), "scale": float(scale_hat), "n_null": int(B)} + return p, params + return p + + +def empirical_pvals_from_null( + null_z, + real_z, + two_sided: bool = True, + bias_correction: bool = True, +): + """ + Compute empirical p-values from a pooled null of z-like statistics. + + Parameters + ---------- + null_z : array-like or pandas.Series + 1-D collection of null test statistics (e.g., from shuffled data). + real_z : array-like or pandas.Series + 1-D collection of observed test statistics to evaluate. + two_sided : bool, default True + If True, uses |z| (two-sided). If False, uses one-sided (upper tail). + bias_correction : bool, default True + If True, uses (r+1)/(B+1). If False, uses r/B. + + Returns + ------- + pvals : pandas.Series or np.ndarray + Empirical p-values aligned to real_z (Series preserves index). 
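+
+    Notes
+    -----
+    With bias correction, p = (r + 1) / (B + 1), where r is the number of null
+    statistics at least as extreme as the observed one and B is the number of
+    null statistics; without it, p = r / B.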
+ """ + # Convert & clean + null = pd.Series(null_z).astype(float).replace([np.inf, -np.inf], np.nan).dropna().values + if null.size == 0: + raise ValueError("No valid null statistics provided.") + + if isinstance(real_z, pd.Series): + real = real_z.astype(float).replace([np.inf, -np.inf], np.nan) + real_index = real.index + real = real.values + return_series = True + else: + real = np.asarray(real_z, dtype=float) + real_index = None + return_series = False + + # Transform for sidedness + if two_sided: + null_t = np.abs(null) + real_t = np.abs(real) + else: + null_t = null + real_t = real + + # Sort null once (ascending) + null_sorted = np.sort(null_t) + B = null_sorted.size + + # For each real stat, count null >= real (upper tail) + # searchsorted gives index of first value >= real_t (with 'left'), + # so r = B - idx + idx = np.searchsorted(null_sorted, real_t, side="left") + r = B - idx # exceedance counts + + if bias_correction: + p = (r + 1.0) / (B + 1.0) + else: + # Guard against division by zero if B==0 (already caught above) + p = r / B + + # Preserve index/type if input was a Series + if return_series: + return pd.Series(p, index=real_index, name="empirical_p") + return p + + +def compute_empirical_pvals( + data_real, + data_shuffled, + value_col="z_value", + pval_adj_method=None, + group_col=None, + two_sided=True, + bias_correction=True, +): + if group_col is None: + pvals = empirical_pvals_from_null( + null_z=data_shuffled[value_col].values, + real_z=data_real[value_col].values, + two_sided=two_sided, + bias_correction=bias_correction, + ) + if pval_adj_method is not None: + rej, pval_adj, _, _ = multipletests(pvals, alpha=0.05, method=pval_adj_method) + return pval_adj + else: + return pvals + else: + pvals = data_real.groupby(group_col)[value_col].transform( + lambda x: empirical_pvals_from_null( + null_z=data_shuffled.query(f"{group_col} == @x.name")[value_col], + real_z=x, + two_sided=two_sided, + bias_correction=bias_correction, + ) + ) + if pval_adj_method is not None: + pval_df = data_real[[group_col]].assign(pval=pvals) + pval_adj = pval_df.groupby(group_col)["pval"].transform( + lambda x: multipletests(x, alpha=0.05, method=pval_adj_method)[1] + ) + return pval_adj + else: + return pvals From a13689f0d80a30acd10355d3dd4842fad3fc2c35 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Tue, 9 Dec 2025 04:49:34 +0000 Subject: [PATCH 04/10] initialize lfc values to basic statistical estimates --- src/perturbo/models/_model.py | 232 ++++++++++++++++++++++++++++----- src/perturbo/models/_module.py | 45 +++++-- 2 files changed, 232 insertions(+), 45 deletions(-) diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index 78e48bd..1e53cbf 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -28,6 +28,88 @@ logger = logging.getLogger(__name__) +def compute_element_lfc_initialization( + X: np.ndarray | sp.spmatrix, + guide_obs: np.ndarray | sp.spmatrix, + guide_by_element: torch.Tensor | np.ndarray, + control_indices: np.ndarray | None = None, + n_cells_for_control: int = 10000, + pseudocount: float = 0.1, +) -> np.ndarray: + """ + Compute log fold-change (LFC) estimates for element effects on genes using matrix operations. + + For each element, computes the mean expression in cells targeted by that element versus + baseline (control or random cells), using a pseudocount to avoid log(0). + + Parameters + ---------- + X : np.ndarray or sp.spmatrix + Count matrix (n_cells x n_genes). 
+    guide_obs : np.ndarray or sp.spmatrix
+        Binary or count matrix of guides per cell (n_cells x n_guides).
+    guide_by_element : torch.Tensor or np.ndarray
+        Binary matrix mapping guides to elements (n_guides x n_elements).
+    control_indices : np.ndarray or None
+        Boolean or integer indices of control cells (baseline for LFC).
+        If None, samples n_cells_for_control random cells (or all if fewer available).
+    n_cells_for_control : int
+        Number of random cells to sample for the baseline if control_indices is None.
+        Default: 10000.
+    pseudocount : float
+        Pseudocount added to avoid log(0). Default: 0.1.
+
+    Returns
+    -------
+    np.ndarray
+        Log fold-change matrix (n_elements x n_genes). Entry [i, j] is the LFC
+        of element i on gene j.
+    """
+    # Convert inputs to numpy arrays if needed
+    if isinstance(guide_by_element, torch.Tensor):
+        guide_by_element = guide_by_element.detach().cpu().numpy()
+    if isinstance(guide_obs, sp.spmatrix):
+        guide_obs = guide_obs.toarray()
+    if isinstance(X, sp.spmatrix):
+        X = X.toarray()
+
+    n_guides, n_elements = guide_by_element.shape
+    n_genes = X.shape[1]
+
+    # Determine baseline (control) indices
+    if control_indices is None:
+        n_cells = X.shape[0]
+        n_control = min(n_cells_for_control, n_cells)
+        control_indices = np.random.choice(n_cells, size=n_control, replace=False)
+    elif isinstance(control_indices, (list, np.ndarray)):
+        # normalize to an array first; plain lists have no .dtype attribute
+        control_indices = np.asarray(control_indices)
+        if control_indices.size > 0 and control_indices.dtype == bool:
+            control_indices = np.where(control_indices)[0]
+
+    # Compute mean expression in control cells
+    X_control = X[control_indices, :]  # (n_control x n_genes)
+    control_mean = X_control.mean(axis=0)  # (n_genes,)
+
+    # For each element, identify cells targeted by any guide targeting that element
+    # guide_obs: (n_cells x n_guides)
+    # guide_by_element: (n_guides x n_elements)
+    # cells_by_element: (n_cells x n_elements) = guide_obs @ guide_by_element
+    cells_by_element = guide_obs @ guide_by_element  # (n_cells x n_elements)
+    cells_by_element = cells_by_element > 0  # Convert to boolean
+
+    # Compute mean expression for each element
+    element_means = np.zeros((n_elements, n_genes), dtype=np.float32)
+    for elem_idx in range(n_elements):
+        elem_cell_mask = cells_by_element[:, elem_idx]
+        if elem_cell_mask.sum() > 0:
+            element_means[elem_idx, :] = X[elem_cell_mask, :].mean(axis=0)
+        else:
+            element_means[elem_idx, :] = 0.0
+
+    # Compute log fold-change with pseudocount
+    lfc = np.log((element_means + pseudocount) / (control_mean[np.newaxis, :] + pseudocount))
+
+    return lfc
+
+
 class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass):
     def __init__(
@@ -118,6 +200,30 @@ def __init__(
             # model_kwargs["control_pcs"] = v.T.unsqueeze(dim=-2)
             # print(v)

+        # Compute element LFC initialization if elements are present and not already in model_kwargs
+        element_lfc_init = None
+        if n_elements is not None and guide_by_element is not None and "element_effects_init" not in model_kwargs:
+            # Use full data for LFC computation (not filtered by control guides)
+            X_full = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
+            grna_counts_full = self.adata_manager.get_from_registry(REGISTRY_KEYS.PERTURBATION_KEY)
+
+            # Determine baseline indices: use control guides if available, else sample random cells
+            if control_guides is not None:
+                if issparse(grna_counts_full):
+                    control_indices = grna_counts_full[:, control_guides].sum(axis=1).A1 > 0
+                else:
+                    control_indices = grna_counts_full[:, control_guides].sum(axis=1) > 0
+            else:
+                control_indices = None
+
+            element_lfc_init = compute_element_lfc_initialization(
+
X=X_full, + guide_obs=grna_counts_full, + guide_by_element=guide_by_element, + control_indices=control_indices, + pseudocount=0.1, + ) + self.module = PerTurboPyroModule( n_cells=self.summary_stats.n_cells, n_batches=self.summary_stats.n_batch, @@ -127,6 +233,7 @@ def __init__( n_elements=n_elements, log_gene_mean_init=torch.tensor(log_means, dtype=torch.float32), log_gene_dispersion_init=torch.tensor(log_disp_smoothed, dtype=torch.float32), + lfc_init=torch.tensor(element_lfc_init, dtype=torch.float32) if element_lfc_init is not None else None, guide_by_element=guide_by_element, gene_by_element=gene_by_element, # n_cats_per_cov=n_cats_per_cov, @@ -425,12 +532,14 @@ def pretrain( "guide_efficacy": zero_guide_efficacy, }, ) + pretrain_init_values = { + "log_gene_mean": self.module.log_gene_mean_init.to(device), + "log_gene_dispersion": self.module.log_gene_dispersion_init.to(device), + } + pretrain_guide = self.module._guide_factory( poutine.block(self.module.model, hide=["element_effects", "guide_efficacy"]), - init_values={ - "log_gene_mean": self.module.log_gene_mean_init.to(device), - "log_gene_dispersion": self.module.log_gene_dispersion_init.to(device), - }, + init_values=pretrain_init_values, ) pretrain_guide.to(device) @@ -450,14 +559,14 @@ def pretrain( if isinstance(v, torch.Tensor): kwargs[k] = v.to(device) - t = trange(max_epochs, desc="SVI epochs", leave=True) + t = trange(max_epochs, desc="Pretrain epochs", leave=True) for _ in t: loss = svi.step(*args, **kwargs) losses.append(loss) t.set_postfix(loss=loss) else: - t = trange(max_epochs, desc="SVI epochs", leave=True) + t = trange(max_epochs, desc="Pretrain epochs", leave=True) for _ in t: batch_losses = [] for batch in loader: @@ -478,15 +587,12 @@ def train( max_epochs: int = 1000, accelerator: str = "cpu", device: int | str = "auto", - train_size: float = 1.0, - validation_size: float | None = None, - shuffle_set_split: bool = False, batch_size: int = 1024, early_stopping: bool = False, + indices: list[int] | None = None, pretrain: bool = False, - pretrain_kwargs=None, - pretrain_max_epochs: int | None = None, - lr: float | None = 0.005, + blocked_sites: list[str] | None = None, + lr: float | None = 0.01, load_sparse_tensor: bool = "auto", training_plan: PyroTrainingPlan = PyroTrainingPlan, plan_kwargs: dict | None = None, @@ -533,24 +639,68 @@ def train( Any The result of the training runner. 
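
        Notes
        -----
        When ``pretrain=True`` and ``indices`` is None, cells carrying control
        guides are selected automatically, which requires ``control_guides`` to
        have been provided at model initialization. ``blocked_sites``, if given,
        is forwarded to the training plan as ``plan_kwargs["blocked"]``.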
""" - if pretrain: + _, _, torch_device = parse_device_args(accelerator, device, return_device="torch", validate_single_device=True) + + if not hasattr(self.module, "_guide") or self.module._guide is None: + self.module._guide = self.module._guide_factory( + self.module.model, + init_values={ + "log_gene_mean": self.module.log_gene_mean_init.to(torch_device), + "log_gene_dispersion": self.module.log_gene_dispersion_init.to(torch_device), + "element_effects": self.module.lfc_init.to(torch_device), + }, + ) + + # if self.module.local_effects and self.module.sparse_tensors: + # zero_element_effects = torch.zeros( + # (self.module.n_element_effects), dtype=torch.float32, device=torch_device + # ) + # zero_guide_efficacy = torch.zeros((self.module.n_guide_effects), dtype=torch.float32, device=torch_device) + # else: + # zero_element_effects = torch.zeros( + # (self.module.n_elements, self.module.n_genes), dtype=torch.float32, device=torch_device + # ) + # zero_guide_efficacy = torch.full( + # (self.module.n_perturbations, self.module.n_genes), + # fill_value=0.5, + # dtype=torch.float32, + # device=torch_device, + # ) + + if pretrain and indices is None: if self.control_guides is None: raise ValueError("Pretraining requires control_guides to be set during model initialization.") - control_indices = ( + indices = np.where( self.adata_manager.get_from_registry(REGISTRY_KEYS.PERTURBATION_KEY)[:, self.control_guides].sum(axis=1) > 0 - ) - self.pretrain( - max_epochs=max_epochs if pretrain_max_epochs is None else pretrain_max_epochs, - indices=control_indices, - accelerator=accelerator, - device=device, - batch_size=batch_size, - lr=lr, - **(pretrain_kwargs if pretrain_kwargs is not None else {}), + )[0] + assert isinstance(indices, np.ndarray) and len(indices) > 0, ( + "No cells with control guides found for pretraining." 
) - plan_kwargs = plan_kwargs if plan_kwargs is not None else {} + if indices is not None: + control_indices = [indices, np.setdiff1d(np.arange(self.summary_stats.n_cells), indices), None] + else: + control_indices = None + # else: + # control_indices = None + # self.pretrain( + # max_epochs=max_epochs if pretrain_max_epochs is None else pretrain_max_epochs, + # indices=control_indices, + # accelerator=accelerator, + # device=device, + # batch_size=batch_size, + # lr=lr, + # **(pretrain_kwargs if pretrain_kwargs is not None else {}), + # ) + + plan_kwargs = ( + plan_kwargs + if plan_kwargs is not None + else {"n_epochs_kl_warmup": None, "n_steps_kl_warmup": None, "scale_elbo": 1.0} + ) + if blocked_sites is not None: + plan_kwargs["blocked"] = blocked_sites if len(self.module.discrete_sites) > 0: plan_kwargs.update({"loss_fn": TraceEnum_ELBO(max_plate_nesting=3)}) if lr is not None and "optim" not in plan_kwargs.keys(): @@ -561,12 +711,13 @@ def train( data_splitter_kwargs["data_and_attributes"] = self.data_and_attrs if load_sparse_tensor == "auto": load_sparse_tensor = accelerator == "gpu" + if batch_size is None: # use data splitter which moves data to GPU once data_splitter = DeviceBackedDataSplitter( self.adata_manager, - train_size=train_size, - validation_size=validation_size, + train_size=1, + external_indexing=control_indices, accelerator=accelerator, device=device, **data_splitter_kwargs, @@ -574,9 +725,8 @@ def train( else: data_splitter = DataSplitter( self.adata_manager, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, + train_size=1, + external_indexing=control_indices, batch_size=batch_size, load_sparse_tensor=load_sparse_tensor, **data_splitter_kwargs, @@ -618,7 +768,7 @@ def get_element_names(self) -> list: element_ids = self.adata_manager.get_state_registry(REGISTRY_KEYS.PERTURBATION_KEY).column_names return element_ids - def get_z_values(self) -> pd.DataFrame: + def get_z_values(self, return_loc_scale=False) -> pd.DataFrame: """ Return a DataFrame summary of the effects for targeted elements on each gene. 
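
        Z-values are the posterior means of the element effects divided by their
        posterior scales under the variational guide; when ``return_loc_scale``
        is True, DataFrames of the posterior locations and scales are returned
        as well.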
@@ -650,10 +800,30 @@
                 (z_values.detach().cpu().numpy(), (i, j)),
                 shape=(len(element_ids), len(gene_ids)),
             )
+            loc_values_coo = coo_matrix(
+                (loc_values.detach().cpu().numpy(), (i, j)),
+                shape=(len(element_ids), len(gene_ids)),
+            )
+            scale_values_coo = coo_matrix(
+                (scale_values.detach().cpu().numpy(), (i, j)),
+                shape=(len(element_ids), len(gene_ids)),
+            )
             z_values_matrix = z_values_coo.tocsr()
+            loc_values_matrix = loc_values_coo.tocsr()
+            scale_values_matrix = scale_values_coo.tocsr()
         else:
             z_values_matrix = z_values.detach().cpu().numpy()
-        return pd.DataFrame(z_values_matrix, index=element_ids, columns=gene_ids)
+            loc_values_matrix = loc_values.detach().cpu().numpy()
+            scale_values_matrix = scale_values.detach().cpu().numpy()
+
+        z_values_df = pd.DataFrame(z_values_matrix, index=element_ids, columns=gene_ids)
+        loc_values_df = pd.DataFrame(loc_values_matrix, index=element_ids, columns=gene_ids)
+        scale_values_df = pd.DataFrame(scale_values_matrix, index=element_ids, columns=gene_ids)
+
+        if return_loc_scale:
+            return z_values_df, loc_values_df, scale_values_df
+        return z_values_df

     def get_element_effects(self, return_long=True) -> pd.DataFrame:
         """
diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py
index 0a10b73..023eb34 100644
--- a/src/perturbo/models/_module.py
+++ b/src/perturbo/models/_module.py
@@ -32,10 +32,11 @@ def __init__(
         n_batches: int | None = 1,
         log_gene_mean_init: torch.Tensor | None = None,
         log_gene_dispersion_init: torch.Tensor | None = None,
+        lfc_init: torch.Tensor | None = None,
         guide_by_element: torch.Tensor | None = None,
         gene_by_element: torch.Tensor | None = None,
         likelihood: Literal["nb", "lnnb"] = "nb",
-        effect_prior_dist: Literal["cauchy", "normal_mixture", "normal", "laplace"] = "laplace",
+        effect_prior_dist: Literal["cauchy", "normal_mixture", "normal", "laplace"] = "cauchy",
         n_factors: int | None = None,
         n_pert_factors: int | None = None,
         efficiency_mode: Literal["mixture", "scaled", "mixture_high_moi"] | None = "scaled",
@@ -152,6 +153,7 @@ def __init__(

         # Sites to approximate with Delta distribution instead of default Normal distribution.
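         # (Sites listed in delta_sites get an AutoDelta point-mass approximation;
         #  other continuous sites get the AutoNormal mean-field approximation.)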
         self.delta_sites = []
+        # self.delta_sites = ["log_gene_mean", "log_gene_dispersion", "multiplicative_noise"]
         # self.delta_sites = ["cell_factors"]
         # self.delta_sites = ["cell_factors", "cell_loadings", "pert_factors", "pert_loadings"]
@@ -163,11 +165,6 @@ def __init__(
             log_gene_dispersion_init = torch.ones(self.n_genes)
         self.log_gene_dispersion_init = log_gene_dispersion_init

-        # if control_pcs is not None and n_factors is not None:
-        #     init_values["cell_loadings"] = control_pcs
-
-        self._guide = self._guide_factory(self.model)
-
         ## register hyperparameters as buffers so they get automatically moved to GPU by scvi-tools

         if self.local_effects:
@@ -234,8 +231,27 @@ def __init__(
         self.register_buffer("pert_factor_prior_scale", torch.tensor(0.1))
         self.register_buffer("pert_loading_prior_scale", torch.tensor(1.0))

-        # for LogNormalNegativeBinomial likelihood hyperparams
-        # self.register_buffer("noise_prior_rate", torch.tensor(2.0))
+        # create guide with initial values
+        if lfc_init is None:
+            lfc_init = torch.zeros((self.n_elements, self.n_genes))
+
+        if self.local_effects and self.sparse_tensors:
+            if lfc_init.shape == (self.n_elements, self.n_genes):
+                lfc_init = lfc_init[self.element_by_gene_idx[0], self.element_by_gene_idx[1]]
+            assert lfc_init.shape == (self.n_element_effects,), (
+                f"lfc_init shape: {lfc_init.shape}, expected ({self.n_element_effects},)"
+            )
+        self.lfc_init = lfc_init
+
+        # self._guide = self._guide_factory(self.model)
+        # self._guide = self._guide_factory(
+        #     self.model,
+        #     init_values={
+        #         "log_gene_mean": self.log_gene_mean_init,
+        #         "log_gene_dispersion": self.log_gene_dispersion_init,
+        #         "element_effects": self.lfc_init,
+        #     },
+        # )

         # override with user-provided values from prior_param_dict
         if prior_param_dict is not None:
@@ -245,21 +261,22 @@ def __init__(
                 assert v.shape == self.get_buffer(k).shape
                 self.register_buffer(k, v)

-    def _guide_factory(self, model, init_values=None):
+    def _guide_factory(self, model, init_values=None, init_scale=0.1):
         guide = AutoGuideList(model, create_plates=self.create_plates)
         if init_values is None:
             init_values = {}
-        # if init_values is None:
-        #     init_values = {
-        #         "log_gene_mean": self.log_gene_mean_init,
-        #         "log_gene_dispersion": self.log_gene_dispersion_init,
-        #     }
+        # init_values = {
+        #     "log_gene_mean": self.log_gene_mean_init,
+        #     "log_gene_dispersion": self.log_gene_dispersion_init,
+        #     "element_effects": self.lfc_init,
+        # }
         init_loc_fn = functools.partial(init_to_value, values=init_values, fallback=init_to_median(num_samples=100))

         guide.append(
             AutoNormal(
                 poutine.block(model, hide=self.delta_sites + self.discrete_sites),
                 init_loc_fn=init_loc_fn,
+                init_scale=init_scale,
             ),
         )

From 79eae29c4dcc4fa8ffceec10098600174e1f52df Mon Sep 17 00:00:00 2001
From: Logan Blaine
Date: Tue, 9 Dec 2025 16:16:13 +0000
Subject: [PATCH 05/10] expose utils module

---
 src/perturbo/__init__.py       | 2 +-
 src/perturbo/models/_module.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/perturbo/__init__.py b/src/perturbo/__init__.py
index 454a392..3e3b962 100644
--- a/src/perturbo/__init__.py
+++ b/src/perturbo/__init__.py
@@ -1,6 +1,6 @@
 from importlib.metadata import version

-from . import models, simulation
+from . 
import models, simulation, utils from .models import PERTURBO from .simulation import Learn_Data, Simulate_Data diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index 023eb34..b06077d 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -261,7 +261,7 @@ def __init__( assert v.shape == self.get_buffer(k).shape self.register_buffer(k, v) - def _guide_factory(self, model, init_values=None, init_scale=0.1): + def _guide_factory(self, model, init_values=None, init_scale=0.05): guide = AutoGuideList(model, create_plates=self.create_plates) if init_values is None: init_values = {} From cee99b7c9e7ff66144a428046cc787f7e82de5af Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Tue, 9 Dec 2025 17:52:11 +0000 Subject: [PATCH 06/10] add quantile computation to p_value function --- src/perturbo/utils/utils.py | 43 ++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/perturbo/utils/utils.py b/src/perturbo/utils/utils.py index 17bb2c6..b56dbde 100644 --- a/src/perturbo/utils/utils.py +++ b/src/perturbo/utils/utils.py @@ -86,6 +86,7 @@ def empirical_pvals_from_null( real_z, two_sided: bool = True, bias_correction: bool = True, + winsor: float | None = None, ): """ Compute empirical p-values from a pooled null of z-like statistics. @@ -106,6 +107,8 @@ def empirical_pvals_from_null( pvals : pandas.Series or np.ndarray Empirical p-values aligned to real_z (Series preserves index). """ + if winsor is not None: + raise NotImplementedError("winsorization not implemented for this function.") # Convert & clean null = pd.Series(null_z).astype(float).replace([np.inf, -np.inf], np.nan).dropna().values if null.size == 0: @@ -151,17 +154,55 @@ def empirical_pvals_from_null( return p +def compute_quantiles(adata, counts_col=None, gene_name_col="gene", n_quantiles=10): + if counts_col is None: + print("computing mean counts for quantile assignment and outputting to 'mean_expression' column") + mean_counts = adata.X.sum(axis=0) / adata.n_obs + if isinstance(mean_counts, np.matrix): + mean_counts = np.asarray(mean_counts).squeeze() + adata.var["mean_expression"] = mean_counts + counts_col = "mean_expression" + + gene_df = adata.var.reset_index(names=[gene_name_col]) + gene_df["mean_quantile"] = pd.qcut(gene_df[counts_col], n_quantiles, labels=False) + 1 + decile_df = gene_df[[gene_name_col, "mean_quantile"]] + return decile_df + + def compute_empirical_pvals( data_real, data_shuffled, value_col="z_value", + adata=None, + adata_count_col=None, pval_adj_method=None, group_col=None, two_sided=True, bias_correction=True, + winsor: float | None = None, + method="default", + n_quantiles=None, ): + if method == "default": + pval_func = empirical_pvals_from_null + elif method == "tnull_fixed0": + pval_func = empirical_pvals_from_tnull_fixed0 + else: + raise ValueError(f"Unknown method: {method}") + + if n_quantiles is not None: + assert adata is not None, "adata must be provided when n_quantiles is specified." 
+ expression_quantiles = compute_quantiles( + adata=adata, + counts_col=None, + n_quantiles=n_quantiles, + ) + data_real = data_real.merge(expression_quantiles, left_on="gene", right_on="gene", how="left") + data_shuffled = data_shuffled.merge(expression_quantiles, left_on="gene", right_on="gene", how="left") + group_col = "mean_quantile" + if group_col is None: - pvals = empirical_pvals_from_null( + pvals = pval_func( null_z=data_shuffled[value_col].values, real_z=data_real[value_col].values, two_sided=two_sided, From 3621938ffdfcd06a362255f1ec309da0692c6d5e Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Tue, 9 Dec 2025 18:05:48 +0000 Subject: [PATCH 07/10] allow choice of method (default or t approx.) in main wrapper --- src/perturbo/utils/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/perturbo/utils/utils.py b/src/perturbo/utils/utils.py index b56dbde..a483ad9 100644 --- a/src/perturbo/utils/utils.py +++ b/src/perturbo/utils/utils.py @@ -9,6 +9,7 @@ def empirical_pvals_from_tnull_fixed0( two_sided: bool = True, winsor: float | None = None, return_params: bool = False, + **kwargs, ): """ Fit a Student-t(df, scale) with location fixed at 0 to null z-values, @@ -86,7 +87,7 @@ def empirical_pvals_from_null( real_z, two_sided: bool = True, bias_correction: bool = True, - winsor: float | None = None, + **kwargs, ): """ Compute empirical p-values from a pooled null of z-like statistics. @@ -207,6 +208,7 @@ def compute_empirical_pvals( real_z=data_real[value_col].values, two_sided=two_sided, bias_correction=bias_correction, + winsor=winsor, ) if pval_adj_method is not None: rej, pval_adj, _, _ = multipletests(pvals, alpha=0.05, method=pval_adj_method) @@ -215,11 +217,12 @@ def compute_empirical_pvals( return pvals else: pvals = data_real.groupby(group_col)[value_col].transform( - lambda x: empirical_pvals_from_null( + lambda x: pval_func( null_z=data_shuffled.query(f"{group_col} == @x.name")[value_col], real_z=x, two_sided=two_sided, bias_correction=bias_correction, + winsor=winsor, ) ) if pval_adj_method is not None: From 2382569eaf376712df72ebbf3978293536b4cdab Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Tue, 9 Dec 2025 19:02:56 +0000 Subject: [PATCH 08/10] fix bug in multiple testing --- src/perturbo/utils/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/perturbo/utils/utils.py b/src/perturbo/utils/utils.py index a483ad9..a01ac48 100644 --- a/src/perturbo/utils/utils.py +++ b/src/perturbo/utils/utils.py @@ -108,8 +108,6 @@ def empirical_pvals_from_null( pvals : pandas.Series or np.ndarray Empirical p-values aligned to real_z (Series preserves index). 
""" - if winsor is not None: - raise NotImplementedError("winsorization not implemented for this function.") # Convert & clean null = pd.Series(null_z).astype(float).replace([np.inf, -np.inf], np.nan).dropna().values if null.size == 0: From e07638c7528f52ad2a90f8ce39ba8b90072a453a Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Fri, 12 Dec 2025 17:40:05 +0000 Subject: [PATCH 09/10] make lfc initialization scalable --- src/perturbo/models/_model.py | 39 ++++++++++++++--------------------- 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index 1e53cbf..2666ba0 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -1,5 +1,6 @@ import logging +import numba as nb import numpy as np import pandas as pd import pyro.optim as optim @@ -9,7 +10,7 @@ from pandas import DataFrame from pyro import poutine from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, infer_discrete -from scipy.sparse import coo_matrix, issparse +from scipy.sparse import coo_matrix, csc_matrix, issparse from scipy.stats import chi2 from scvi._types import AnnOrMuData from scvi.data import AnnDataManager, fields @@ -64,14 +65,6 @@ def compute_element_lfc_initialization( Log fold-change matrix (n_elements x n_genes). Entry [i, j] is the LFC of element i on gene j. """ - # Convert inputs to numpy arrays if needed - if isinstance(guide_by_element, torch.Tensor): - guide_by_element = guide_by_element.detach().cpu().numpy() - if isinstance(guide_obs, sp.spmatrix): - guide_obs = guide_obs.toarray() - if isinstance(X, sp.spmatrix): - X = X.toarray() - n_guides, n_elements = guide_by_element.shape n_genes = X.shape[1] @@ -92,21 +85,19 @@ def compute_element_lfc_initialization( # guide_obs: (n_cells x n_guides) # guide_by_element: (n_guides x n_elements) # cells_by_element: (n_cells x n_elements) = guide_obs @ guide_by_element - cells_by_element = guide_obs @ guide_by_element # (n_cells x n_elements) - cells_by_element = cells_by_element > 0 # Convert to boolean - - # Compute mean expression for each element - element_means = np.zeros((n_elements, n_genes), dtype=np.float32) - for elem_idx in range(n_elements): - elem_cell_mask = cells_by_element[:, elem_idx] - if elem_cell_mask.sum() > 0: - element_means[elem_idx, :] = X[elem_cell_mask, :].mean(axis=0) - else: - element_means[elem_idx, :] = 0.0 - - # Compute log fold-change with pseudocount - lfc = np.log((element_means + pseudocount) / (control_mean[np.newaxis, :] + pseudocount)) - + if isinstance(guide_by_element, torch.Tensor): + guide_by_element = guide_by_element.cpu().numpy() + + guide_obs = csc_matrix(guide_obs) + lfc = np.zeros((n_elements, n_genes), dtype=np.float32) + + print("Computing element log fold-change initializations...") + for elem_idx in trange(n_elements): + element_guides = guide_by_element[:, elem_idx] != 0 + guide_idx = guide_obs[:, element_guides].sum(axis=1).A1 != 0 + if guide_idx.any(): + element_mean = X[guide_idx, :].mean(axis=0) + lfc[elem_idx, :] = np.log((element_mean + pseudocount) / (control_mean + pseudocount)) return lfc From c380e4a96342cf5c61a46ddb5afeada1cbd9a2d1 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Tue, 16 Dec 2025 15:56:18 +0000 Subject: [PATCH 10/10] add documentation, update guide hyperparams --- src/perturbo/models/_module.py | 10 +- src/perturbo/utils/utils.py | 162 ++++++++++++++++++++------------- 2 files changed, 104 insertions(+), 68 deletions(-) diff --git a/src/perturbo/models/_module.py 
b/src/perturbo/models/_module.py index b06077d..34eab27 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -152,8 +152,8 @@ def __init__( self.n_batches = n_batches # Sites to approximate with Delta distribution instead of default Normal distribution. - self.delta_sites = [] - # self.delta_sites = ["log_gene_mean", "log_gene_dispersion", "multiplicative_noise"] + # self.delta_sites = [] + self.delta_sites = ["log_gene_mean", "log_gene_dispersion", "multiplicative_noise"] # self.delta_sites = ["cell_factors"] # self.delta_sites = ["cell_factors", "cell_loadings", "pert_factors", "pert_loadings"] @@ -196,10 +196,10 @@ def __init__( self.register_buffer("gene_mean_prior_loc", torch.tensor(0.0)) self.register_buffer("gene_disp_prior_loc", torch.tensor(1.0)) - self.register_buffer("gene_mean_prior_scale", torch.tensor(2.0)) + self.register_buffer("gene_mean_prior_scale", torch.tensor(3.0)) self.register_buffer("gene_disp_prior_scale", torch.tensor(1.0)) - self.register_buffer("noise_prior_loc", torch.tensor(-2.0)) + self.register_buffer("noise_prior_loc", torch.tensor(-1.0)) self.register_buffer("noise_prior_scale", torch.tensor(0.5)) # batch/covariate hyperparams @@ -216,7 +216,7 @@ def __init__( ## element effect size hyperparams # Normal/Laplace/Cauchy prior - effect_prior_scales = {"cauchy": 0.2, "laplace": 0.5, "normal": 1.0} + effect_prior_scales = {"cauchy": 0.1, "laplace": 0.5, "normal": 1.0} model_effect_prior_scale = effect_prior_scales[effect_prior_dist] self.register_buffer("element_effects_prior_scale", torch.tensor(model_effect_prior_scale)) self.register_buffer("guide_effects_prior_scale", torch.tensor(model_effect_prior_scale)) diff --git a/src/perturbo/utils/utils.py b/src/perturbo/utils/utils.py index a01ac48..48e28fd 100644 --- a/src/perturbo/utils/utils.py +++ b/src/perturbo/utils/utils.py @@ -1,6 +1,105 @@ import numpy as np import pandas as pd +import anndata as ad from statsmodels.stats.multitest import multipletests +from typing import Literal + +def compute_empirical_pvals( + data_real: pd.DataFrame, + data_shuffled: pd.DataFrame, + value_col: str ="z_value", + adata: ad.AnnData | None = None, + adata_count_col: str | None = None, + pval_adj_method: str | None = None, + group_col: str | None = None, + two_sided: bool = True, + bias_correction: bool = True, + winsor: float | None = None, + method: Literal["default", "tnull_fixed0"] = "default", + n_quantiles: int | None = None, +): + """ + Compute empirical p-values for real data based on null distribution from shuffled data. + + Parameters + ---------- + data_real : pd.DataFrame + DataFrame containing real test statistics. + data_shuffled : pd.DataFrame + DataFrame containing null test statistics. + value_col : str, default "z_value" + Column name for test statistics in both DataFrames. + adata : AnnData, optional + AnnData object for computing expression quantiles if n_quantiles is specified. + adata_count_col : str, optional + Column in adata.var to use for quantile computation. + pval_adj_method : str or None, optional + Method for multiple testing correction (e.g., 'fdr_bh'). If None, + no adjustment is performed. + group_col : str or None, optional + Column name to group by when computing p-values. If None, compute + p-values globally. + two_sided : bool, default True + If True, compute two-sided p-values; otherwise one-sided. + bias_correction : bool, default True + If True, apply bias correction in empirical p-value calculation. 
+    winsor : float in (0,0.5) or None, optional
+        If specified, winsorize null statistics at these quantiles before fitting.
+    method : str, default "default"
+        Method for p-value computation. Options are "default" or "tnull_fixed0".
+    n_quantiles : int or None, optional
+        If specified, compute expression quantiles and use them as group_col.
+
+    Returns
+    -------
+    pandas.Series or np.ndarray
+        Empirical p-values aligned to the rows of data_real (adjusted p-values
+        when pval_adj_method is given).
+    """
+
+    if method == "default":
+        pval_func = empirical_pvals_from_null
+    elif method == "tnull_fixed0":
+        pval_func = empirical_pvals_from_tnull_fixed0
+    else:
+        raise ValueError(f"Unknown method: {method}")
+
+    if n_quantiles is not None:
+        assert adata is not None, "adata must be provided when n_quantiles is specified."
- expression_quantiles = compute_quantiles( - adata=adata, - counts_col=None, - n_quantiles=n_quantiles, - ) - data_real = data_real.merge(expression_quantiles, left_on="gene", right_on="gene", how="left") - data_shuffled = data_shuffled.merge(expression_quantiles, left_on="gene", right_on="gene", how="left") - group_col = "mean_quantile" - - if group_col is None: - pvals = pval_func( - null_z=data_shuffled[value_col].values, - real_z=data_real[value_col].values, - two_sided=two_sided, - bias_correction=bias_correction, - winsor=winsor, - ) - if pval_adj_method is not None: - rej, pval_adj, _, _ = multipletests(pvals, alpha=0.05, method=pval_adj_method) - return pval_adj - else: - return pvals - else: - pvals = data_real.groupby(group_col)[value_col].transform( - lambda x: pval_func( - null_z=data_shuffled.query(f"{group_col} == @x.name")[value_col], - real_z=x, - two_sided=two_sided, - bias_correction=bias_correction, - winsor=winsor, - ) - ) - if pval_adj_method is not None: - pval_df = data_real[[group_col]].assign(pval=pvals) - pval_adj = pval_df.groupby(group_col)["pval"].transform( - lambda x: multipletests(x, alpha=0.05, method=pval_adj_method)[1] - ) - return pval_adj - else: - return pvals