diff --git a/docs/source/usage.rst b/docs/source/usage.rst
index 1fc0b13f..ac98b6b3 100644
--- a/docs/source/usage.rst
+++ b/docs/source/usage.rst
@@ -1315,97 +1315,4 @@ Below is the documentation on the available arguments.
                            Interval of validation in training
     --train-ratio 0.8      Ratio of train dataset. The remaining will be used for valid and test split.
     --valid-ratio 0.1      Ratio of validation set after the train data split. The remaining will be test split
-    --share-model
-
-Model initialization using the Torch API
-----------------------------------------
-
-The scikit-learn API provides parametrization to many common use cases.
-The Torch API, however, allows for more flexibility and customization, e.g.
-for sampling, criterions, and data loaders.
-
-In this minimal example we show how to initialize a CEBRA model using the Torch API.
-Here the :py:class:`cebra.data.single_session.DiscreteDataLoader`
-gets initialized, which also allows the ``prior`` to be directly parametrized.
-
-👉 For an example notebook using the Torch API check out the :doc:`demo_notebooks/Demo_Allen`.
-
-
-.. testcode::
-
-    import numpy as np
-    import cebra.datasets
-    from cebra import plot_embedding
-    import torch
-
-    if torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    neural_data = cebra.load_data(file="neural_data.npz", key="neural")
-
-    discrete_label = cebra.load_data(
-        file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
-    )
-
-    # 1. Define Cebra Dataset
-    InputData = cebra.data.TensorDataset(
-        torch.from_numpy(neural_data).type(torch.FloatTensor),
-        discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
-    ).to(device)
-
-    # 2. Define Cebra Model
-    neural_model = cebra.models.init(
-        name="offset10-model",
-        num_neurons=InputData.input_dimension,
-        num_units=32,
-        num_output=2,
-    ).to(device)
-
-    InputData.configure_for(neural_model)
-
-    # 3. Define Loss Function Criterion and Optimizer
-    Crit = cebra.models.criterions.LearnableCosineInfoNCE(
-        temperature=0.001,
-        min_temperature=0.0001
-    ).to(device)
-
-    Opt = torch.optim.Adam(
-        list(neural_model.parameters()) + list(Crit.parameters()),
-        lr=0.001,
-        weight_decay=0,
-    )
-
-    # 4. Initialize Cebra Model
-    solver = cebra.solver.init(
-        name="single-session",
-        model=neural_model,
-        criterion=Crit,
-        optimizer=Opt,
-        tqdm_on=True,
-    ).to(device)
-
-    # 5. Define Data Loader
-    loader = cebra.data.single_session.DiscreteDataLoader(
-        dataset=InputData, num_steps=10, batch_size=200, prior="uniform"
-    )
-
-    # 6. Fit Model
-    solver.fit(loader=loader)
-
-    # 7. Transform Embedding
-    TrainBatches = np.lib.stride_tricks.sliding_window_view(
-        neural_data, neural_model.get_offset().__len__(), axis=0
-    )
-
-    X_train_emb = solver.transform(
-        torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to(device)
-    ).to(device)
-
-    # 8. Plot Embedding
-    plot_embedding(
-        X_train_emb,
-        discrete_label[neural_model.get_offset().__len__() - 1 :, 0],
-        markersize=10,
-    )
+    --share-model
\ No newline at end of file