Refactor/improvedgpu: Fixing a bug so that the user can now specify GPU training #23

Merged · 3 commits · May 15, 2024
33 changes: 33 additions & 0 deletions docs/customising_training.rst
@@ -5,6 +5,7 @@ This page will show you how to customise the training and evaluation of your fusion models

We will cover the following topics:

* Using GPU
* Early stopping
* Validation metrics
* Batch size
@@ -13,6 +14,38 @@ We will cover the following topics:
* Number of workers in PyTorch DataLoader
* Train/test and cross-validation splitting yourself

Using GPU
------------

If you want to use a GPU to train your model, you can pass the ``training_modifications`` argument to the :func:`~.fusilli.data.prepare_fusion_data` and :func:`~.fusilli.train.train_and_save_models` functions. By default, the model will train on the CPU.

For example, to train on a single GPU, you can do the following:

.. code-block:: python

    from fusilli.data import prepare_fusion_data
    from fusilli.train import train_and_save_models

    # Prepare the data module as usual
    datamodule = prepare_fusion_data(
        prediction_task="binary",
        fusion_model=example_model,
        data_paths=data_paths,
        output_paths=output_path,
    )

    # Request a single GPU via the training_modifications argument
    trained_model_list = train_and_save_models(
        data_module=datamodule,
        fusion_model=example_model,
        training_modifications={"accelerator": "gpu", "devices": 1},
    )
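
Since ``training_modifications`` can also be passed to :func:`~.fusilli.data.prepare_fusion_data` (as noted above), the data-preparation step can be pointed at the GPU in the same way. Below is a minimal sketch of this, assuming the same ``accelerator`` and ``devices`` keys are accepted there:

.. code-block:: python

    # Minimal sketch: assumes prepare_fusion_data accepts the same
    # training_modifications keys as train_and_save_models.
    datamodule = prepare_fusion_data(
        prediction_task="binary",
        fusion_model=example_model,
        data_paths=data_paths,
        output_paths=output_path,
        training_modifications={"accelerator": "gpu", "devices": 1},
    )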

.. warning::

    GPU training is not currently implemented for subspace-based models (as of May 2024).
    The documentation will be updated when this is supported.



Early stopping
--------------
