
sims-web

Run SIMS in the browser using h5wasm to read local AnnData (.h5ad) files and ONNX to run the model.

Opens an h5ad file in the browser, runs a selected SIMS model, and displays predictions.

You can view the default ONNX model via netron

(Screenshot: the default SIMS model graph rendered in netron)

Architecture

The front end is a single-page React web app using Material UI and Vite with no back end - just file storage and an HTTP server are required. The Python pieces all relate to converting PyTorch models into ONNX and then editing the ONNX graph to move as much of the prediction processing into the graph as possible (i.e. the LpNorm and SoftMax of probabilities), as well as to expose internal nodes such as the encoder output for clustering and the attention masks for explainability. Incremental h5 file reading and inference via ONNX are handled in worker.js, a web worker that attempts to utilize all the cores on the machine: it runs multi-threaded inference and double-buffered incremental h5 file reading, with support for sparse data expansion and gene inflation inline.
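As a flavor of the graph editing involved, here is a minimal sketch of exposing an internal tensor as an extra graph output using the onnx Python package. The tensor name below is hypothetical; the real node names and edits live in the scripts folder:

import onnx
from onnx import helper, TensorProto

model = onnx.load("public/models/default.onnx")

# Promote a hypothetical internal tensor (e.g. the encoder output) to a
# graph output so the web app can read it for clustering. The shape is
# left unspecified here.
encoder_out = helper.make_tensor_value_info("encoder_output", TensorProto.FLOAT, None)
model.graph.output.append(encoder_out)

onnx.save(model, "public/models/default-with-encoder.onnx")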

Developing

NOTE: Installing SIMS from git on a Mac works with Python 3.10; other versions, not so much...

Install dependencies for the python model exporting and webapp:

pip install -r requirements.txt
npm install

Export a SIMS checkpoint to an ONNX file and a list of genes:

python scripts/sims-to-onnx.py checkpoints/default.ckpt public/models/

Validate the output of SIMS to ONNX using the python runtime:

mkdir -p data/validation
python scripts/validate.py checkpoints/default.ckpt public/models/default.onnx public/sample.h5ad
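Under the hood the validation boils down to running the same input through both runtimes and comparing the outputs. A minimal, self-contained sketch of that kind of check, using a toy model rather than SIMS:

import numpy as np
import torch
import onnxruntime as ort

# Toy stand-in for the SIMS model.
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Softmax(dim=1)).eval()
x = torch.randn(3, 8)

torch.onnx.export(model, x, "/tmp/toy.onnx",
                  input_names=["input"], output_names=["probs"])

with torch.no_grad():
    torch_out = model(x).numpy()

session = ort.InferenceSession("/tmp/toy.onnx")
onnx_out = session.run(None, {"input": x.numpy()})[0]

# Raises if the two runtimes disagree beyond the tolerances.
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
print("outputs concordant")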

Check a model for compatibility with various ONNX runtimes:

python -m onnxruntime.tools.check_onnx_model_mobile_usability public/models/default.onnx

Serve the web app and exported models locally with auto-reload courtesy of Vite:

npm run dev

Display the compute graph using netron:

netron public/models/default.onnx

Memory Requirements

worker.js uses h5wasm's slice() to read data from the cell-by-gene matrix (i.e. X). As these data are typically stored on disk in row-major order (i.e. all data for a cell is contiguous), we can process the sample incrementally, keeping memory requirements to a minimum. Reading cell by cell from a 5.3 GB h5ad file consumed just under 30 MB of browser memory. YMMV.
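For reference, the same access pattern expressed with h5py in Python (the batch size and a dense X layout are assumptions; sparse h5ad files store X as data/indices/indptr arrays instead, which worker.js expands inline):

import h5py

batch = 256
with h5py.File("public/sample.h5ad", "r") as f:
    X = f["X"]  # dense cell-by-gene matrix in this sketch
    n_cells = X.shape[0]
    for start in range(0, n_cells, batch):
        rows = X[start:start + batch, :]  # reads only this contiguous slab
        # ... run inference on rows, then let the buffer be garbage collected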

Performance

ONNX supports multithreaded inference. We allocate (total cores on the machine - 2) threads for inference. This leaves one thread for the main loop, so the UI can remain responsive, and one thread for ONNX to coordinate via its 'proxy' setting (see worker.js for details).
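worker.js configures this through the web runtime's thread settings; for comparison, the equivalent knobs in the Python onnxruntime API look like this (a sketch, not what the web app runs):

import os
import onnxruntime as ort

opts = ort.SessionOptions()
# Leave two cores free: one for the main/UI loop, one for coordination.
opts.intra_op_num_threads = max(1, (os.cpu_count() or 2) - 2)

session = ort.InferenceSession("public/models/default.onnx", sess_options=opts)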

Processed 1759 cells in 0.18 minutes on a MacBook M3 Pro or around 10k samples per minute.

Leveraging a GPU

ONNX Runtime Web does have support for GPUs, but unfortunately it doesn't support all operators yet. Specifically, TopK, LpNormalization and GatherElements are not supported. See sclblonnx.check(graph) for details.
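A rough way to scan a model for these operators with the plain onnx package (sclblonnx.check performs a more thorough version of this check):

import onnx

# Operators the web GPU backend currently lacks, per the note above.
UNSUPPORTED = {"TopK", "LpNormalization", "GatherElements"}

model = onnx.load("public/models/default.onnx")
used = {node.op_type for node in model.graph.node}
print("problem ops:", sorted(used & UNSUPPORTED))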

ONNX vs. PyTorch Concordance

Sometimes the output of an ONNX model will differ from PyTorch due to floating point numerical differences. To determine if this is the case, you can export the model as float64 and see if the results match the original PyTorch float32 output:

torch.onnx.export(
    model.double(),  # convert the entire model to double precision
    sample_input.double(),  # provide a double precision sample input
    ...  # remaining export arguments (output path, input/output names, etc.)
)

SIMS ONNX single-precision output differed from PyTorch in around 30% of samples due to small numerical differences in TabNet's Sparsemax implementation. After trying several Grok3-crafted rewrites designed to reduce the single-precision numerical differences, I found the simplest solution was to run just the Sparsemax sub-graph in double precision via this change to the forward function:

def forward(self, priors, processed_feat):
    x = self.fc(processed_feat)  # fully connected layer
    x = self.bn(x)  # batch norm
    x = torch.mul(x, priors)  # apply the prior attention scales
    # Force onnx to compute sparsemax in 64 bit to avoid numerical differences
    if torch.onnx.is_in_onnx_export():
        return self.selector(x.to(torch.float64)).to(x.dtype)
    else:
        return self.selector(x)

After this modification the outputs are concordant to 5+ decimal places. Several files in the scripts folder illustrate how to explore the ONNX graph to hunt for places that diverge; YMMV.
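The scripts take roughly this approach: expose every intermediate tensor as a graph output, run the float32 and float64 exports on the same input, and diff tensor by tensor to see where they first drift apart. A condensed sketch (the file names and gene count are illustrative):

import numpy as np
import onnx
import onnxruntime as ort

def expose_all(path):
    # Promote every internal tensor to a graph output so it can be inspected.
    model = onnx.load(path)
    existing = {o.name for o in model.graph.output}
    for node in model.graph.node:
        for name in node.output:
            if name and name not in existing:
                model.graph.output.append(onnx.ValueInfoProto(name=name))
                existing.add(name)
    return ort.InferenceSession(model.SerializeToString())

s32 = expose_all("public/models/default.onnx")      # float32 export
s64 = expose_all("public/models/default-f64.onnx")  # hypothetical float64 export

n_genes = 16000  # illustrative; use the model's actual gene count
x = np.random.rand(1, n_genes).astype(np.float32)
out32 = dict(zip((o.name for o in s32.get_outputs()), s32.run(None, {s32.get_inputs()[0].name: x})))
out64 = dict(zip((o.name for o in s64.get_outputs()), s64.run(None, {s64.get_inputs()[0].name: x.astype(np.float64)})))

# The first tensor with a large difference is where the graphs diverge.
for name in out32:
    if name in out64:
        diff = np.abs(np.asarray(out32[name], dtype=np.float64) - np.asarray(out64[name], dtype=np.float64))
        print(name, diff.max())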

References

Open Neural Network Exchange (ONNX)

ONNX Runtime Web (WASM Backend)

ONNX Runtime Web Platform Functionality Details

ONNX Runtime Javascript Examples

Alternative Web ONNX Runtime in Rust

ONNX Simplifier

Netron ONNX Graph Display Website

Graphical ONNX Editor

Classify images in a web application with ONNX Runtime Web

h5wasm

anndata/h5ad file structure and on disk format

SIMS Streamlit App and Source

TabNet Model for attentive tabular learning

Semi supervised pre training with TabNet

Self Supervised TabNet

Classification of Alzheimer's disease using robust TabNet neural networks on genetic data

Designing interpretable deep learning applications for functional genomics: a quantitative analysis

Assessing GPT-4 for cell type annotation in single-cell RNA-seq analysis

10x Space Ranger HDF5 Feature-Barcode Matrix Format
