Skip to content

Commit

Permalink
undo changes in usage.rst
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Oct 29, 2023
1 parent c82695e commit e8f73fe
Showing 1 changed file with 1 addition and 94 deletions.
95 changes: 1 addition & 94 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1315,97 +1315,4 @@ Below is the documentation on the available arguments.
Interval of validation in training
--train-ratio 0.8 Ratio of train dataset. The remaining will be used for valid and test split.
--valid-ratio 0.1 Ratio of validation set after the train data split. The remaining will be test split
--share-model
Model initialization using the Torch API
----------------------------------------

The scikit-learn API provides parametrization to many common use cases.
The Torch API however allows for more flexibility and customization, for e.g.
sampling, criterions, and data loaders.

In this minimal example we show how to initialize a CEBRA model using the Torch API.
Here the :py:class:`cebra.data.single_session.DiscreteDataLoader`
gets initilized which also allows the `prior` to be directly parametrized.

👉 For an example notebook using the Torch API check out the :doc:`demo_notebooks/Demo_Allen`.


.. testcode::

import numpy as np
import cebra.datasets
from cebra import plot_embedding
import torch

if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"

neural_data = cebra.load_data(file="neural_data.npz", key="neural")

discrete_label = cebra.load_data(
file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
)

# 1. Define Cebra Dataset
InputData = cebra.data.TensorDataset(
torch.from_numpy(neural_data).type(torch.FloatTensor),
discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
).to(device)

# 2. Define Cebra Model
neural_model = cebra.models.init(
name="offset10-model",
num_neurons=InputData.input_dimension,
num_units=32,
num_output=2,
).to(device)

InputData.configure_for(neural_model)

# 3. Define Loss Function Criterion and Optimizer
Crit = cebra.models.criterions.LearnableCosineInfoNCE(
temperature=0.001,
min_temperature=0.0001
).to(device)

Opt = torch.optim.Adam(
list(neural_model.parameters()) + list(Crit.parameters()),
lr=0.001,
weight_decay=0,
)

# 4. Initialize Cebra Model
solver = cebra.solver.init(
name="single-session",
model=neural_model,
criterion=Crit,
optimizer=Opt,
tqdm_on=True,
).to(device)

# 5. Define Data Loader
loader = cebra.data.single_session.DiscreteDataLoader(
dataset=InputData, num_steps=10, batch_size=200, prior="uniform"
)

# 6. Fit Model
solver.fit(loader=loader)

# 7. Transform Embedding
TrainBatches = np.lib.stride_tricks.sliding_window_view(
neural_data, neural_model.get_offset().__len__(), axis=0
)

X_train_emb = solver.transform(
torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to(device)
).to(device)

# 8. Plot Embedding
plot_embedding(
X_train_emb,
discrete_label[neural_model.get_offset().__len__() - 1 :, 0],
markersize=10,
)
--share-model

0 comments on commit e8f73fe

Please sign in to comment.