Convolutional neural network architecture for supervised registration of barcoded spatial transcriptomics (ST) data.
Extending the work of GridNet, GridNext provides the ability to learn over image data, count data, or both together. As with GridNet, the basic structure is a two-stage convolutional architecture:
- f -- spot classifier applied independently to each measurement
- g -- convolutional correction network applied to the output of f, incorporating spatial information at a bandwidth dictated by network depth and kernel width.
The GridNext library additionally provides functionality for interfacing directly with data from 10x Genomics' Visium platform (through the outputs of Spaceranger and Loupe for count and annotation data, respectively), as well as guidelines for working with custom data in popular formats such as AnnData.
GridNext can be installed with pip:
pip install git+https://github.com/adaly/gridnext/
Or the source code can be downloaded and installed manually:
git clone https://github.com/adaly/gridnext.git
cd gridnext
pip install .
The GridNext library provides extensions of the PyTorch Dataset class for working with ST data:
- CountDataset, PatchDataset, and MultiModalDataset for spot-wise (pre-)training of the f network
- CountGridDataset, PatchGridDataset, and MultiModalGridDataset for training/fine-tuning the g network
The create_visium_dataset function is used to instantiate either spot- or grid-level datasets using either or both of the count and image data modalities from the outputs of Spaceranger (count and image data) and Loupe (spot annotations):
import os, glob
from gridnext.visium_datasets import create_visium_dataset
data_dir = '../data/BA44_testdata'
# Top-level output directories from Spaceranger for each Visium array
spaceranger_dirs = sorted(glob.glob(os.path.join(data_dir, 'spaceranger', '*')))
# Associated full-resolution image files used as inputs to Spaceranger
fullres_image_files = sorted(glob.glob(os.path.join(data_dir, 'fullres_images', '*.jpg')))
# Spot annotation files produced using the Loupe browser
annot_files = sorted(glob.glob(os.path.join(data_dir, 'annotations', '*.csv')))
For count data:
# Spot-wise data (for training f)
spot_dat = create_visium_dataset(spaceranger_dirs, use_count=True, use_image=False, annot_files=annot_files, spatial=False)
len(spot_dat) # n_spots
x_spot, y_spot = spot_dat[0] # x_spot = (n_genes,), y_spot = (,)
# Grid-wise data (for training g)
grid_dat = create_visium_dataset(spaceranger_dirs, use_count=True, use_image=False, annot_files=annot_files, spatial=True)
len(grid_dat) # n_arrays
x_grid, y_grid = grid_dat[0] # x_grid = (n_genes, n_rows_vis, n_cols_vis), y_grid = (n_rows_vis, n_cols_vis)
The first time this is run for a given dataset, it will create a unified list of the n_genes seen across all Visium arrays, and save a unified TSV-formatted (n_genes x n_spots) count file with the suffix .unified.tsv.gz (can be changed with the count_suffix optional argument) in the top-level directory of each Spaceranger output. On subsequent runs, the function will look for these unified count files and use them in constructing count tensors.
Optional arguments:
- minimum_detection_rate -- after generating the unified gene list, drop genes expressed in fewer than this fraction of spots
- select_genes -- list of gene names (ENSEMBL if using the default 10x reference transcriptome) to subset for analysis
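For example, a minimal sketch combining both arguments (the ENSEMBL IDs are placeholders):
# Drop genes detected in fewer than 10% of spots, then subset to genes of interest
spot_dat = create_visium_dataset(
    spaceranger_dirs, use_count=True, use_image=False,
    annot_files=annot_files, spatial=False,
    minimum_detection_rate=0.1,
    select_genes=['ENSG00000204531', 'ENSG00000181449'])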
Alternatively, if you are working with only count data, it may be faster to store your Visium data in an AnnData object using create_visium_anndata:
from gridnext.visium_datasets import create_visium_anndata
adata = create_visium_anndata(spaceranger_dirs, annot_files=annot_files, destfile=PATH_TO_SAVE_ANNDATA)
and then follow the instructions below for loading AnnData objects.
For image data:
patch_size_px = 128 # or patch_size_um = 100 to specify physical width
# Spot-wise data (for training f)
spot_dat = create_visium_dataset(spaceranger_dirs, use_image=True, use_count=False, annot_files=annot_files, fullres_image_files=fullres_image_files, patch_size_px=patch_size_px, spatial=False)
len(spot_dat) # n_spots
x_spot, y_spot = spot_dat[0] # x_spot = (3, patch_size, patch_size)
# Grid-wise data (for training g)
grid_dat = create_visium_dataset(spaceranger_dirs, use_image=True, use_count=False, annot_files=annot_files, fullres_image_files=fullres_image_files, patch_size_px=patch_size_px, spatial=True)
# ...or set patch_size_px=None and patch_size_um=patch_size_um to extract patches of fixed physical width
len(grid_dat) # n_arrays
x_grid, y_grid = grid_dat[0] # x_grid = (n_rows_vis, n_cols_vis, 3, patch_size, patch_size)
The first time this runs for a given dataset, it will create a sub-directory in each Spaceranger output directory with the suffix _patches[patch_size_px] or _patches[patch_size_um] (depending on whether patch_size_px or patch_size_um is passed) containing image patches extracted from each spot location in the array (named as [array_name]_[array_col]_[array_row].jpg). On subsequent runs with the same patch size, the function will look for these patches and use them in constructing image tensors.
Optional arguments:
- img_transforms -- a torchvision.transforms object (or a composition thereof) to be applied to any image patch prior to accession through the Dataset class. This can be used to accommodate patches of different pixel width resulting from passing patch_size_um with images of multiple resolutions.
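For instance, a minimal sketch resizing all patches to a common pixel width (the transform choice here is illustrative, not prescribed by the library):
from torchvision import transforms
# Resize all patches to 128x128 so fixed-physical-width patches from images
# of different resolutions yield tensors of equal shape
xform = transforms.Compose([transforms.Resize((128, 128))])
spot_dat = create_visium_dataset(
    spaceranger_dirs, use_image=True, use_count=False,
    annot_files=annot_files, fullres_image_files=fullres_image_files,
    patch_size_px=None, patch_size_um=100, spatial=False,
    img_transforms=xform)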
GridNext provides functionality to load spatially resolved count data into memory from AnnData. The AnnData object must be structured as follows:
- For spatially resolved data, adata.obs must have columns named x and y for spatial coordinate data.
- Principal components of count data may be stored in adata.layers['X_pca'].
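As a sketch, a conforming object might be prepared like this (assuming spot coordinates are available in adata.obsm['spatial'], as produced by Scanpy's Visium loader):
import scanpy as sc
adata = sc.read_h5ad('my_st_data.h5ad')  # placeholder path
# GridNext expects spot coordinates in adata.obs['x'] and adata.obs['y']
adata.obs['x'] = adata.obsm['spatial'][:, 0]
adata.obs['y'] = adata.obsm['spatial'][:, 1]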
There are two methods of loading such data:
- For smaller datasets (that can be loaded into memory at once), the following functions yield TensorDataset objects with fast accession times:
import scanpy as sc
from gridnext.count_datasets import anndata_to_tensordataset, anndata_arrays_to_tensordataset
adata = sc.read_h5ad(...) # load AnnData object
obs_label = 'AARs' # column in adata.obs containing spot annotations
obs_arr = 'vis_arr' # column in adata.obs containing array names (for multi-array ST data)
h_st, w_st = 78, 64 # height and width of ST array (i.e., number of rows and columns)
vis_coords = True # whether x and y coordinates are in Visium pseudo-hex (True) or cartesian coordinates (False)
use_pcs = False # whether to use principal components (adata.layers['X_pca']) instead of (raw) count data (adata.X)
spot_dat = anndata_to_tensordataset(adata, obs_label=obs_label, use_pcs=use_pcs)
grid_dat = anndata_arrays_to_tensordataset(adata, obs_label=obs_label, use_pcs=use_pcs, obs_arr=obs_arr, h_st=h_st, w_st=w_st, vis_coords=vis_coords)
- For larger datasets that can't fit into memory, we provide subclasses of the PyTorch Dataset class with lazy loading (slower accession times):
from gridnext.count_datasets import AnnDataset, AnnGridDataset
spot_dat = AnnDataset(adata, obs_label=obs_label, use_pcs=use_pcs)
grid_dat = AnnGridDataset(adata, obs_label=obs_label, use_pcs=use_pcs, obs_arr=obs_arr, h_st=h_st, w_st=w_st, vis_coords=vis_coords)
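Either flavor of dataset can then be wrapped in a standard PyTorch DataLoader for training; a minimal sketch (batch sizes are illustrative):
from torch.utils.data import DataLoader
# Spot-level data: batch over individual spots
spot_loader = DataLoader(spot_dat, batch_size=128, shuffle=True)
# Grid-level data: each item is a full array tensor, so keep batches small
grid_loader = DataLoader(grid_dat, batch_size=1, shuffle=True)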
The aforementioned dataset classes (CountDataset/CountGridDataset, PatchDataset/PatchGridDataset, MultiModalDataset/MultiModalGridDataset) can also be instantiated directly by providing count, image, and annotation data in the following file formats:
- Count data: one (genes x spots) matrix per array, stored in tab-delimited format (other delimiters can be used with the cfile_delim keyword argument). The first column should store gene names, which should be standardized across all arrays, and the first row should store spot coordinates in [x]_[y] format.
- Image data: one directory per array containing JPEG-formatted image files (other file formats can be used with the img_ext keyword argument) extracted from each spatial measurement location. Image patch file names should end with [x]_[y].[img_ext] to store spatial information.
- Annotation data: one (categories x spots) one-hot encoded (exactly one "1" per column) binary annotation matrix per array, stored in CSV format (other delimiters can be used with the afile_delim keyword argument). The first column should store category names, which should be standardized across all arrays, and the first row should store spot coordinates in [x]_[y] format. If Visium data are being passed (Visium=True), one can alternatively pass paired lists of Loupe annotation files and Spaceranger position files in lieu of this custom format.
from gridnext.count_datasets import CountDataset, CountGridDataset
from gridnext.image_datasets import PatchDataset, PatchGridDataset
from gridnext.multimodal_datasets import MultiModalDataset, MultiModalGridDataset
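For example, a hypothetical instantiation from per-array files in the custom count format (the positional argument order below is an assumption for illustration; consult the class docstrings for the exact signature):
import glob
# One count file and one annotation file per array, in matching order
count_files = sorted(glob.glob('counts/*.tsv'))
annot_files = sorted(glob.glob('annotations/*.csv'))
# Hypothetical call -- argument names/order are illustrative assumptions
grid_dat = CountGridDataset(count_files, annot_files, cfile_delim='\t', afile_delim=',')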
GridNext requires two model instantiations:
- Spot classifier f, which accepts either a transcriptome (1D), an image patch (3D), or a tuple containing both, and outputs an n_classes-length logit tensor
  - For image data, we provide the gridnext.densenet.DenseNet class for instantiating a DenseNet classifier (see image tutorial Section 1 for an example).
  - For count data, a custom torch.nn.Sequential network with appropriate input and output dimensions should be used (see count tutorial Section 1.2 for an example; a minimal sketch appears after this list).
- Grid classifier g, which accepts either a transcriptomic array tensor (3D), an image array tensor (5D), or a tuple containing both, and outputs an (n_classes, H_ST, W_ST)-shaped logit tensor
  - For Visium data (either count or image), we provide the gridnext.gridnet_models.GridNetHexOddr class. In either instance (see image tutorial Section 2 or count tutorial Section 2.2), instantiation requires:
    - a (pre-trained) spot classifier f
    - the shape of the spot data
    - the shape of the spatial grid
    - the number of classes in the final prediction layer
  - By default, g takes as input the final output layer of f (spot_shape -> f -> n_classes -> g -> n_classes). To instead learn over a penultimate feature layer of f, create a truncated network (e.g., a DenseNet model instantiated with the classify=False option) and instantiate GridNetHexOddr with the f_dim=SHAPE_OF_F_OUTPUT option.
  - For data too large to fit into RAM at once (e.g., tensors of image data), GridNetHexOddr provides the atonce_patch_limit instantiation argument, which splits arrays into mini-batches during training (see example in image tutorial Section 2).
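As referenced in the list above, a minimal sketch of a count-based f network built with torch.nn.Sequential (layer sizes are placeholders; see count tutorial Section 1.2 for the tutorial's own version):
import torch.nn as nn
n_genes = 20000   # placeholder: length of the unified gene list
n_classes = 8     # placeholder: number of annotation categories
# Spot classifier f: maps an (n_genes,) expression vector to n_classes logits
f = nn.Sequential(
    nn.Linear(n_genes, 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, n_classes),
)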
GridNext provides two functions for model training:
- gridnext.training.train_spotwise for training f networks
- gridnext.training.train_gridwise for training g networks
Both functions require the following arguments:
- model: either f or g
- dataloaders: dictionary mapping the keys "train" and "val" to separate torch.utils.data.DataLoader objects for each data fold
- criterion: loss function from torch.nn
- optimizer: optimizer from torch.optim
Both functions additionally accept the following optional arguments:
- num_epochs: number of training epochs (defaults to 10)
- outfile: destination in which to save trained model parameters (updated each iteration)
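Putting this together, a minimal sketch of training f (hyperparameters and the calling convention are illustrative assumptions based on the argument list above):
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from gridnext.training import train_spotwise
# Split the spot-level dataset into training and validation folds
train_set, val_set = random_split(spot_dat, [0.8, 0.2])
dataloaders = {
    'train': DataLoader(train_set, batch_size=128, shuffle=True),
    'val': DataLoader(val_set, batch_size=128),
}
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(f.parameters(), lr=1e-3)
# Trained parameters are saved to outfile during training
train_spotwise(f, dataloaders, criterion, optimizer, num_epochs=10, outfile='f_params.pt')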
See full examples in either tutorial.
By default, corrector models g will output predictions at every location on the spatial grid, regardless of whether tissue is present. For the purposes of evaluating performance and exporting predictions, we provide two utility functions for generating predictions at only foreground locations (e.g., those covered by tissue):
- gridnext.utils.all_fgd_predictions -- generates predictions at all foreground locations in a flattened list; useful for computing receiver operating characteristic and precision-recall curves (see Section 3.2 of the count and image tutorials)
- gridnext.utils.to_loupe_annots -- exports foreground predictions in Loupe format (see Section 3.3 of the count and image tutorials)
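For example, a sketch of evaluating foreground predictions with scikit-learn, assuming all_fgd_predictions accepts the trained g model and a grid-level DataLoader and returns flattened true labels and predicted class probabilities (its actual signature may differ; consult the function docstring):
from sklearn.metrics import roc_auc_score
from gridnext.utils import all_fgd_predictions
# Assumed signature and return values
y_true, y_prob = all_fgd_predictions(g, grid_loader)
# One-vs-rest AUROC over foreground spots only
auroc = roc_auc_score(y_true, y_prob, multi_class='ovr')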