A library-independent forward pass inspector for neural nets.

adonath/clouseau


Clouseau: the forward pass inspector

A library-independent forward pass inspector for neural nets. The tool is designed to be used with PyTorch and JAX (other libraries may follow). It lets you register hooks on the forward pass of a model and write the resulting activations to a file for later inspection. This is useful for debugging models, or for porting a model from one framework to another and checking that both implementations agree at every stage.
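For readers new to forward hooks, here is a minimal, framework-free sketch of the mechanism described above. It is not Clouseau's implementation: `Layer`, `HookHandle`, and `save_output` are hypothetical names, and the `(module, input, output)` hook arguments only loosely mirror PyTorch's `nn.Module.register_forward_hook`.

```python
from typing import Any, Callable, List


class HookHandle:
    """Hypothetical handle that detaches a registered hook via `remove()`."""

    def __init__(self, hooks: List[Callable], hook: Callable):
        self._hooks, self._hook = hooks, hook

    def remove(self) -> None:
        if self._hook in self._hooks:
            self._hooks.remove(self._hook)


class Layer:
    """Hypothetical minimal layer: applies `fn`, then runs its forward hooks."""

    def __init__(self, name: str, fn: Callable[[Any], Any]):
        self.name, self.fn = name, fn
        self._forward_hooks: List[Callable] = []

    def register_forward_hook(self, hook: Callable) -> HookHandle:
        self._forward_hooks.append(hook)
        return HookHandle(self._forward_hooks, hook)

    def __call__(self, x: Any) -> Any:
        out = self.fn(x)
        for hook in self._forward_hooks:
            hook(self, x, out)  # each hook sees the layer, its input and its output
        return out


activations = {}


def save_output(layer, x, out):
    # Record the layer's output under its name for later inspection
    activations[layer.name] = out


layers = [
    Layer("dense1", lambda v: [2.0 * a for a in v]),
    Layer("act1", lambda v: [max(a, 0.0) for a in v]),
]
handles = [layer.register_forward_hook(save_output) for layer in layers]

x = [1.0, -3.0, 0.5]
for layer in layers:
    x = layer(x)
print(activations)  # {'dense1': [2.0, -6.0, 1.0], 'act1': [2.0, 0.0, 1.0]}

for h in handles:
    h.remove()  # detach the hooks once recording is done
```

Returning a handle keeps registration and removal symmetric, so hooks can be attached for a single recorded pass and cleanly detached afterwards.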

Installation

pip install clouseau

Usage

JAX / Equinox Example

You can use the inspector as a context manager to record the forward pass of a model. The following example shows how to use the inspector with a model from the Equinox library:

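import jax
import equinox as eqx
from clouseau import inspector

# One PRNG key per Linear layer
keys = jax.random.split(jax.random.PRNGKey(918832), 3)

# eqx.nn.Linear takes its PRNG key as a keyword argument, and eqx.nn.Sequential
# calls every layer with a `key` argument, so plain functions such as
# jax.nn.relu are wrapped in eqx.nn.Lambda
model = eqx.nn.Sequential([
    eqx.nn.Linear(764, 100, key=keys[0]),
    eqx.nn.Lambda(jax.nn.relu),
    eqx.nn.Linear(100, 50, key=keys[1]),
    eqx.nn.Lambda(jax.nn.relu),
    eqx.nn.Linear(50, 10, key=keys[2]),
    eqx.nn.Lambda(jax.nn.sigmoid),
])
x = jax.random.normal(jax.random.PRNGKey(0), (764,))

# Record the activations of a single forward pass to a safetensors file
with inspector.tail(model, path="activations.safetensors") as m:
    m(x)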
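To make the context-manager pattern concrete, here is a rough, dependency-free sketch of what a recorder like `inspector.tail` might do: wrap each named layer while the `with` block is active, restore the original layers on exit, and write everything captured to a file. `TinyModel` and `record` are hypothetical; JSON is used only to keep the sketch self-contained, whereas Clouseau writes safetensors files and its actual implementation may differ.

```python
import json
import os
import tempfile
from contextlib import contextmanager


class TinyModel:
    """Hypothetical stand-in for a framework module with named sub-layers."""

    def __init__(self):
        self.layers = {
            "double": lambda v: [2 * a for a in v],
            "shift": lambda v: [a - 1 for a in v],
        }

    def __call__(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x


@contextmanager
def record(model, path):
    """Wrap each layer to capture its output; on exit, restore the layers
    and write the captured activations to `path` as JSON."""
    original = dict(model.layers)
    captured = {}

    def wrap(name, fn):
        def hooked(x):
            out = fn(x)
            captured[name] = out  # store this layer's output under its name
            return out
        return hooked

    model.layers = {name: wrap(name, fn) for name, fn in original.items()}
    try:
        yield model
    finally:
        model.layers = original  # remove the wrappers
        with open(path, "w") as f:
            json.dump(captured, f)


path = os.path.join(tempfile.mkdtemp(), "activations.json")
model = TinyModel()
with record(model, path) as m:
    m([1, 2, 3])

with open(path) as f:
    print(json.load(f))  # {'double': [2, 4, 6], 'shift': [1, 3, 5]}
```

Writing in the `finally` clause means activations recorded before an exception are still saved, which is often exactly what you want when debugging a failing forward pass.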

Then, in an interactive session, inspect the recorded activations:

from clouseau import inspector

inspector.magnify("activations.safetensors")

This opens the file and renders a hierarchical Treescope view of the activations.
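If you prefer a plain-text overview, a safetensors file can also be read back as a dictionary of arrays: `safetensors.numpy.load_file(path)` returns a `dict` mapping tensor names to NumPy arrays. The `summarize` helper below is a hypothetical sketch, not part of Clouseau, that lists each activation's name, shape, mean, and standard deviation. It is demonstrated on an in-memory dictionary whose key names are illustrative assumptions, so it runs without a recorded file.

```python
import numpy as np


def summarize(activations):
    """Return one summary line per activation: name, shape, mean and std."""
    lines = []
    for name in sorted(activations):
        a = np.asarray(activations[name])
        lines.append(
            f"{name:<20} shape={tuple(a.shape)!s:<12} "
            f"mean={a.mean():+.4f} std={a.std():.4f}"
        )
    return lines


# Assumption: in practice this dict would come from
# safetensors.numpy.load_file("activations.safetensors")
activations = {
    "layers.0": np.array([1.0, -1.0, 3.0]),
    "layers.1": np.array([1.0, 0.0, 3.0]),
}
for line in summarize(activations):
    print(line)
```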

PyTorch Example

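from collections import OrderedDict

import torch
from torch import nn
from clouseau import inspector

# nn.Sequential accepts named layers only as an OrderedDict, not a plain dict
model = nn.Sequential(OrderedDict({
    "dense1": nn.Linear(764, 100),
    "act1": nn.ReLU(),
    "dense2": nn.Linear(100, 50),
    "act2": nn.ReLU(),
    "output": nn.Linear(50, 10),
    "outact": nn.Sigmoid(),
}))

x = torch.randn((764,))

# Record the activations of a single forward pass
with inspector.tail(model) as m:
    m(x)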
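When porting a model between frameworks, the recorded activations can be compared layer by layer. Note that the JAX and PyTorch models above are initialized independently, so their activations only match once the weights have been copied from one to the other. Layer names also differ between frameworks, so the hypothetical `compare_activations` helper below takes an explicit name mapping; the key names used here are illustrative assumptions, not necessarily the keys Clouseau writes. It reports the maximum absolute difference per layer and whether the layer passes `numpy.allclose` with the given tolerances.

```python
import numpy as np


def compare_activations(ref, other, name_map, rtol=1e-5, atol=1e-6):
    """Compare two sets of activations layer by layer.

    ref, other: dicts mapping layer names to arrays.
    name_map:   maps names in `ref` to the corresponding names in `other`.
    Returns a dict: ref name -> (max_abs_diff, within_tolerance).
    """
    report = {}
    for ref_name, other_name in name_map.items():
        a = np.asarray(ref[ref_name])
        b = np.asarray(other[other_name])
        if a.shape != b.shape:
            report[ref_name] = (float("inf"), False)  # shapes disagree
            continue
        max_diff = float(np.max(np.abs(a - b)))
        report[ref_name] = (max_diff, bool(np.allclose(a, b, rtol=rtol, atol=atol)))
    return report


# Assumption: in practice each dict would be loaded from a recorded file,
# e.g. with safetensors.numpy.load_file(...)
jax_acts = {"layers.0": np.array([0.5, -1.0]), "layers.1": np.array([0.5, 0.0])}
torch_acts = {"dense1": np.array([0.5, -1.0]), "act1": np.array([0.5, 0.1])}

report = compare_activations(
    jax_acts, torch_acts, {"layers.0": "dense1", "layers.1": "act1"}
)
for name, (diff, ok) in report.items():
    print(f"{name}: max|diff|={diff:.2e} {'OK' if ok else 'MISMATCH'}")
```

Walking the layers in order makes the first mismatching layer easy to spot, which usually points directly at the operation that was ported incorrectly.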

For more advanced usage, including filtering by layer type, please refer to the documentation.
