Run SIMS in the browser, using h5wasm to read local AnnData (.h5ad) files and ONNX to run the model.
The app opens an .h5ad file in the browser, runs a selected SIMS model, and displays the predictions.
You can view the default ONNX model with Netron.
The front end is a single-page React web app built with Material UI and Vite. No back end is required - just file storage and an HTTP server. The Python pieces all relate to converting PyTorch models into ONNX and then editing the ONNX graph to move as much of the prediction processing into the graph as possible (i.e. the LpNorm and SoftMax of the probabilities), as well as to expose internal nodes such as the encoder output for clustering and the attention masks for explainability. Incremental h5 file reading and inference via ONNX are handled in worker.js, a web worker that attempts to utilize all the cores on the machine: multi-threaded inference plus double-buffered incremental h5 file reading, with sparse data expansion and gene inflation handled inline.
NOTE: installing SIMS from git on a Mac works with Python 3.10; other versions, not so much...
Install dependencies for the python model exporting and webapp:
pip install -r requirements.txt
npm install
Export a SIMS checkpoint to an ONNX file and list of genes:
python scripts/sims-to-onnx.py checkpoints/default.ckpt public/models/
Validate the ONNX export against the original SIMS checkpoint using the Python runtime:
mkdir -p data/validation
python scripts/validate.py checkpoints/default.ckpt public/models/default.onnx public/sample.h5ad
Check a model for compatibility with various ONNX runtimes:
python -m onnxruntime.tools.check_onnx_model_mobile_usability public/models/default.onnx
Serve the web app and exported models locally with auto-reload courtesy of vite:
npm run dev
Display the compute graph using netron:
netron public/models/default.onnx
worker.js uses h5wasm slice() to read data from the cell-by-gene matrix (i.e. X). As these data are typically stored on disk row major (i.e. all the data for a cell is contiguous), we can process the sample incrementally, keeping memory requirements to a minimum. Reading cell by cell from a 5.3GB h5ad file consumed just under 30MB of browser memory. YMMV.
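The same access pattern can be sketched in Python with h5py (worker.js does the equivalent with h5wasm's slice()). The file and matrix below are synthetic stand-ins for an h5ad X dataset:

```python
import numpy as np
import h5py

# Small dense stand-in for an h5ad X matrix (5 cells x 4 genes).
with h5py.File("demo.h5", "w") as f:
    f.create_dataset("X", data=np.arange(20, dtype=np.float32).reshape(5, 4))

# Incremental, row-major reads: one cell (row) at a time, so only a single
# row is resident in memory regardless of total file size.
row_sums = []
with h5py.File("demo.h5", "r") as f:
    X = f["X"]
    for i in range(X.shape[0]):
        row = X[i, :]  # all of a cell's data is contiguous, so this is a cheap read
        row_sums.append(float(row.sum()))

print(row_sums)  # -> [6.0, 22.0, 38.0, 54.0, 70.0]
```

Because rows are contiguous on disk, each slice maps to one sequential read; slicing column-wise (per gene) would instead touch every row's storage and defeat the incremental approach.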
ONNX supports multi-threaded inference. We allocate the machine's total cores minus 2 for inference: this leaves one thread for the main loop so the UI remains responsive, and one thread for ONNX to coordinate via its 'proxy' setting (see worker.js for details).
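The thread budget itself is simple arithmetic; a sketch in Python (the ort.env.wasm names in the comment are how the web runtime is configured, as done in worker.js):

```python
import os

# Leave one core for the UI main loop and one for ONNX's coordination
# thread; everything else goes to inference.
total = os.cpu_count() or 4
inference_threads = max(1, total - 2)

# In ONNX Runtime Web this maps to:
#   ort.env.wasm.numThreads = inference_threads
#   ort.env.wasm.proxy = true
print(inference_threads)
```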
Processed 1,759 cells in 0.18 minutes on a MacBook M3 Pro, or around 10k samples per minute.
The ONNX Web Runtime does have support for GPUs, but unfortunately it doesn't support all operators yet - specifically TopK, LpNormalization, and GatherElements are missing. See sclblonnx.check(graph) for details.
Sometimes the output of an ONNX model will differ from PyTorch due to floating point numerical differences. To determine whether this is the case, export the model as float64 and check whether the results match the original PyTorch float32 model:
torch.onnx.export(
    model.double(),         # convert the entire model to double precision
    sample_input.double(),  # provide a double precision sample input
    ...
)
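The root cause is ordinary floating point rounding; a minimal numpy illustration of float32 and float64 disagreeing on the same arithmetic:

```python
import numpy as np

# At 1e8 the float32 grid spacing is 8, so adding 1.0 is lost entirely,
# while float64 retains it - the same expression, two different answers.
x32 = np.float32(1e8) + np.float32(1.0) - np.float32(1e8)
x64 = np.float64(1e8) + np.float64(1.0) - np.float64(1e8)
print(x32, x64)  # -> 0.0 1.0
```

Inside a model, differences like this accumulate through matmuls and reductions, which is why a float64 export serves as a useful reference for the float32 graph.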
SIMS ONNX single-precision output differed from PyTorch in around 30% of samples due to small numerical differences in TabNet's Sparsemax implementation. After trying several Grok3-crafted rewrites designed to reduce the single-precision numerical differences, I found the simplest solution was to run just the Sparsemax sub-graph in double precision via this change to the forward function:
def forward(self, priors, processed_feat):
    x = self.fc(processed_feat)
    x = self.bn(x)
    x = torch.mul(x, priors)
    # Force ONNX to compute sparsemax in 64-bit to avoid numerical differences
    if torch.onnx.is_in_onnx_export():
        return self.selector(x.to(torch.float64)).to(x.dtype)
    else:
        return self.selector(x)
After this modification the outputs are concordant to 5+ decimal places. Several files in the scripts folder illustrate how to explore the ONNX graph and hunt for places that diverge. YMMV.
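A concordance check of this kind can be sketched with numpy; the arrays below are illustrative stand-ins for the PyTorch and ONNX probability outputs:

```python
import numpy as np

# Illustrative outputs - in practice these come from model(x) and the
# onnxruntime session run on the same input batch.
torch_probs = np.array([0.123456, 0.876544, 0.500001])
onnx_probs = np.array([0.123457, 0.876543, 0.500000])

# "Concordant to 5+ decimal places" == agreement within 1e-5.
diff = np.abs(torch_probs - onnx_probs)
n_divergent = int((diff > 1e-5).sum())
print(n_divergent)  # -> 0
```

Counting divergent samples (rather than asserting exact equality) gives a tolerance-based pass/fail that survives the benign last-digit rounding differences that remain after the Sparsemax fix.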
Open Neural Network Exchange (ONNX)
ONNX Runtime Web (WASM Backend)
ONNX Runtime Web Platform Functionality Details
ONNX Runtime Javascript Examples
Alternative Web ONNX Runtime in Rust
Netron ONNX Graph Display Website
Classify images in a web application with ONNX Runtime Web
anndata/h5ad file structure and on disk format
TabNet Model for attentive tabular learning
Semi supervised pre training with TabNet
Classification of Alzheimer's disease using robust TabNet neural networks on genetic data
Designing interpretable deep learning applications for functional genomics: a quantitative analysis
Assessing GPT-4 for cell type annotation in single-cell RNA-seq analysis