Skip to content

Commit

Permalink
[python] Check for list/tuple arguments in registrars (#3518)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnkerl authored Jan 7, 2025
1 parent b3f9df0 commit 439c530
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def _acquire_experiment_mappings(
def from_anndata_appends_on_experiment(
cls,
experiment_uri: str | None,
adatas: Sequence[ad.AnnData],
adatas: Sequence[ad.AnnData] | ad.AnnData,
*,
measurement_name: str,
obs_field_name: str,
Expand All @@ -404,6 +404,13 @@ def from_anndata_appends_on_experiment(
is ``None`` then you will be computing registrations only for the input
``AnnData`` objects. If ``experiment_uri`` is not ``None`` then it is
an error if the experiment is not accessible."""
# typeguard doesn't help at runtime. Check this crucial user-facing API.
if isinstance(adatas, ad.AnnData):
adatas = [adatas]
elif not isinstance(adatas, (list, tuple)):
raise ValueError(
f"adatas must be list or tuple of AnnData, or a single AnnData; got {type(adatas)}"
)

registration_data = cls._acquire_experiment_mappings(
experiment_uri,
Expand Down Expand Up @@ -455,7 +462,7 @@ def from_h5ad_append_on_experiment(
def from_h5ad_appends_on_experiment(
cls,
experiment_uri: str | None,
h5ad_file_names: Sequence[str],
h5ad_file_names: Sequence[str] | str,
*,
measurement_name: str,
obs_field_name: str,
Expand All @@ -465,6 +472,13 @@ def from_h5ad_appends_on_experiment(
) -> Self:
"""Extends registration data from the baseline, already-written SOMA
experiment to include multiple H5AD input files."""
# typeguard doesn't help at runtime. Check this crucial user-facing API.
if isinstance(h5ad_file_names, str):
h5ad_file_names = [h5ad_file_names]
elif not isinstance(h5ad_file_names, (list, tuple)):
raise ValueError(
f"h5ad_file_names must be list or tuple of string, or a single string; got {type(h5ad_file_names)}"
)

registration_data = cls._acquire_experiment_mappings(
experiment_uri,
Expand Down
4 changes: 2 additions & 2 deletions apis/python/src/tiledbsoma/io/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def __init__(
# entrypoints for append-mode soma_joinid registration.
def register_h5ads(
experiment_uri: str | None,
h5ad_file_names: Sequence[str],
h5ad_file_names: Sequence[str] | str,
*,
measurement_name: str,
obs_field_name: str,
Expand All @@ -212,7 +212,7 @@ def register_h5ads(

def register_anndatas(
experiment_uri: str | None,
adatas: Sequence[ad.AnnData],
adatas: Sequence[ad.AnnData] | ad.AnnData,
*,
measurement_name: str,
obs_field_name: str,
Expand Down
62 changes: 62 additions & 0 deletions apis/python/tests/test_registration_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,3 +1342,65 @@ def test_multimodal_names(tmp_path, conftest_pbmc3k_adata):
assert exp.obs.count == len(adata_protein.obs)
assert exp.ms["RNA"].var.count == len(adata_rna.var)
assert exp.ms["protein"].var.count == len(adata_protein.var)


def test_registration_lists_and_tuples(tmp_path):
obs_field_name = "cell_id"
var_field_name = "gene_id"

exp_uri = create_soma_canned(1, obs_field_name, var_field_name)
adata = create_anndata_canned(2, obs_field_name, var_field_name)
h5ad_file_name = create_h5ad_canned(2, obs_field_name, var_field_name)

rd1 = tiledbsoma.io.register_anndatas(
experiment_uri=exp_uri,
adatas=[adata],
measurement_name="measname",
obs_field_name=obs_field_name,
var_field_name=var_field_name,
)

rd2 = tiledbsoma.io.register_anndatas(
experiment_uri=exp_uri,
adatas=(adata,),
measurement_name="measname",
obs_field_name=obs_field_name,
var_field_name=var_field_name,
)

rd3 = tiledbsoma.io.register_anndatas(
experiment_uri=exp_uri,
adatas=adata,
measurement_name="measname",
obs_field_name=obs_field_name,
var_field_name=var_field_name,
)
assert rd1 == rd2
assert rd2 == rd3

rd4 = tiledbsoma.io.register_h5ads(
experiment_uri=exp_uri,
h5ad_file_names=[h5ad_file_name],
measurement_name="measname",
obs_field_name=obs_field_name,
var_field_name=var_field_name,
)

rd5 = tiledbsoma.io.register_h5ads(
experiment_uri=exp_uri,
h5ad_file_names=(h5ad_file_name,),
measurement_name="measname",
obs_field_name=obs_field_name,
var_field_name=var_field_name,
)

rd6 = tiledbsoma.io.register_h5ads(
experiment_uri=exp_uri,
h5ad_file_names=h5ad_file_name,
measurement_name="measname",
obs_field_name=obs_field_name,
var_field_name=var_field_name,
)

assert rd4 == rd5
assert rd5 == rd6

0 comments on commit 439c530

Please sign in to comment.