4
4
import anndata as ad
5
5
import scanpy as sc
6
6
from scipy .sparse import csr_matrix
7
- import numpy as np
8
-
9
-
10
7
11
8
# Command-line interface for merging several h5ad datasets.
#
# --input               one or more h5ad files to merge
# --base                optional reference dataset; when given, --output_transfer
#                       must be supplied as well (checked after parsing)
# --output_intersection all cells, restricted to the gene intersection
# --output_union        all cells, outer join of genes
# --output_transfer     cells that are not part of the base dataset
# --min_cells           minimum number of cells a gene needs to be kept
# --custom_genes        extra genes to carry over from the outer join
parser = argparse.ArgumentParser(description="Merge datasets")
# The input list is iterated unconditionally further down, so make its absence
# a clear argparse usage error instead of a TypeError on None.
parser.add_argument("--input", help="Input file", type=str, nargs="+", required=True)
# Option strings must not contain trailing whitespace: argparse derives the
# attribute name from them, and the script reads args.base, args.output_union
# and args.output_transfer.
parser.add_argument("--base", help="Base dataset to use as reference", type=str, required=False)
parser.add_argument("--output_intersection", help="Output file containing all cells but gene intersection", type=str, required=True)
parser.add_argument("--output_union", help="Output file, outer join of cells and genes", type=str, required=True)
parser.add_argument("--output_transfer", help="Output file, cells to project onto base", type=str, required=False)
parser.add_argument("--min_cells", help='Minimum number of cells to keep a gene', type=int, required=False, default=50)
# Default to an empty list so that omitting the flag means "no extra genes"
# rather than None.
parser.add_argument("--custom_genes", help="Additional genes to include", type=str, nargs="*", default=[])

args = parser.parse_args()
21
18
22
19
# Validate the argument combination before touching any file, so a missing
# --output_transfer is reported without first loading every input dataset.
if args.base and not args.output_transfer:
    raise ValueError("Transfer file required when using base dataset")

# Load every input dataset into memory.
datasets = [ad.read_h5ad(f) for f in args.input]

if args.base:
    # The base dataset goes first in the list passed to the concatenation.
    adata_base = ad.read_h5ad(args.base)
    datasets = [adata_base] + datasets

# Two merged views of the same cells: one with anndata's default gene join
# (used as the gene intersection) and one with the outer join of genes.
adata_intersection = ad.concat(datasets)
adata_union = ad.concat(datasets, join="outer")
26
30
27
# --custom_genes is optional: argparse leaves it as None when the flag is
# omitted, so fall back to an empty list. dict.fromkeys drops repeated names
# while keeping the order given on the command line.
requested_genes = list(dict.fromkeys(args.custom_genes or []))

# Only genes that are missing from the intersection but present in the outer
# join can (and need to) be added.
additional_genes = [
    gene
    for gene in requested_genes
    if gene not in adata_intersection.var_names and gene in adata_union.var_names
]

# Add custom genes from outer join to the intersection
if additional_genes:
    adata_additional = adata_union[adata_intersection.obs_names, additional_genes]
    adata_concatenated = ad.concat([adata_intersection, adata_additional], join="outer", axis=1)
    # Re-attach the per-cell annotations from the intersection object after the
    # gene-axis concatenation.
    # NOTE(review): presumably needed because obs/obsm are not carried over by
    # default when concatenating along axis=1 - confirm for the anndata version in use.
    adata_concatenated.obs, adata_concatenated.obsm = adata_intersection.obs, adata_intersection.obsm
    adata_intersection = adata_concatenated
35
39
36
40
# Convert to CSR matrix
adata_intersection.X = csr_matrix(adata_intersection.X)
adata_union.X = csr_matrix(adata_union.X)

# Filter cells with no counts
# The mask is computed on the intersection and applied to both objects, so a
# cell without counts in the shared genes is dropped from the union as well.
# NOTE(review): this relies on both objects listing the cells in the same order.
cell_mask, _ = sc.pp.filter_cells(adata_intersection, min_genes=1, inplace=False)
# Subsetting an AnnData yields a view; both objects are modified in place
# below (gene filtering, obs columns, layers), so materialise them explicitly
# instead of relying on an implicit copy on first write.
adata_intersection = adata_intersection[cell_mask, :].copy()
adata_union = adata_union[cell_mask, :].copy()

# Filter genes with too few occurrences in outer join
sc.pp.filter_genes(adata_union, min_cells=args.min_cells)
51
51
52
# Prefix the batch and patient labels with the dataset label, joined by "_".
dataset_prefix = adata_intersection.obs["dataset"].astype(str) + "_"
for obs_key in ("batch", "patient"):
    adata_intersection.obs[obs_key] = dataset_prefix + adata_intersection.obs[obs_key].astype(str)
54
54
55
55
def to_Florent_case (s : str ):
56
56
corrected = s .lower ().strip ()
@@ -77,25 +77,25 @@ def to_Florent_case(s: str):
77
77
78
78
return corrected [0 ].upper () + corrected [1 :]
79
79
80
# Normalise every textual obs column (object or categorical dtype) with
# to_Florent_case and store it as a categorical; other dtypes are left alone.
for column in adata_intersection.obs.columns:
    values = adata_intersection.obs[column]
    if values.dtype.name not in ("category", "object"):
        continue
    # Replace missing entries before the string conversion: astype(str) turns a
    # missing value into the literal text "nan", after which fillna has nothing
    # left to fill. Going through object dtype sidesteps the restriction that a
    # categorical cannot take a value outside its categories.
    filled = values.astype(object).where(values.notna(), "Unknown")
    # Convert first to string and then to category
    adata_intersection.obs[column] = filled.astype(str).apply(to_Florent_case).astype("category")

# The outer-join object gets the same, already normalised, cell annotations.
adata_union.obs = adata_intersection.obs

# Keep the current matrices available under the "counts" layer.
adata_intersection.layers["counts"] = adata_intersection.X
adata_union.layers["counts"] = adata_union.X
91
if args.base:
    # Cells whose name does not occur in the base dataset are the ones to
    # project onto it.
    is_base_cell = adata_intersection.obs.index.isin(adata_base.obs.index)
    # Copy so that the obs assignment below works on an independent object
    # rather than on a view of adata_intersection.
    adata_transfer = adata_intersection[~is_base_cell].copy()

    # Take the known cell types from the base cells as they appear in the
    # merged object: those labels went through the same normalisation as the
    # labels being checked, whereas adata_base.obs still holds the raw values.
    known_celltypes = set(adata_intersection.obs.loc[is_base_cell, "cell_type"].unique())
    # Any label the base dataset does not contain is replaced by "Unknown".
    adata_transfer.obs["cell_type"] = adata_transfer.obs["cell_type"].map(
        lambda label: label if label in known_celltypes else "Unknown"
    )

    adata_transfer.write_h5ad(args.output_transfer)

# Both merged objects are always written, with or without a base dataset.
adata_intersection.write_h5ad(args.output_intersection)
adata_union.write_h5ad(args.output_union)
0 commit comments