
Save, load, host, and share AI model checkpoints without slowing down training. Host on Lightning AI or your own cloud with enterprise-grade access controls.


Lightning-AI/LitModels


Save, share and host AI model checkpoints Lightning fast ⚡


Save, load, host, and share models without slowing down training: LitModels minimizes the slowdown that checkpoint saving adds to a training run. Share models through public links, hosted on Lightning AI or your own cloud, with enterprise-grade access controls.

✅ Checkpoint without slowing training.  ✅ Granular access controls.           
✅ Load models anywhere.                 ✅ Host on Lightning or your own cloud.
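
The README does not describe how LitModels avoids slowing training. One common technique is to hand each checkpoint upload to a background thread so the training loop keeps running. The sketch below illustrates only that general idea and is not LitModels' implementation. `fake_upload` and the checkpoint file names are hypothetical stand-ins for a real network upload and real checkpoint files.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_upload(path: str) -> str:
    """Hypothetical stand-in for a network upload; sleeps to simulate latency."""
    time.sleep(0.2)
    return path


def train(num_epochs: int = 3) -> tuple[list[str], float]:
    """Simulated training loop that uploads a checkpoint after every epoch."""
    with ThreadPoolExecutor(max_workers=1) as pool:
        futures = []
        start = time.perf_counter()
        for epoch in range(num_epochs):
            time.sleep(0.05)  # simulated work for one training epoch
            ckpt = f"epoch={epoch}.ckpt"  # hypothetical checkpoint file name
            # submit() returns immediately; the upload runs on the worker thread
            futures.append(pool.submit(fake_upload, ckpt))
        loop_seconds = time.perf_counter() - start
        # Wait for all pending uploads before training is considered done
        uploaded = [f.result() for f in futures]
    return uploaded, loop_seconds


uploaded, loop_seconds = train()
print(uploaded)  # ['epoch=0.ckpt', 'epoch=1.ckpt', 'epoch=2.ckpt']
print(f"training loop took {loop_seconds:.2f}s; uploads finished afterwards")
```

With blocking uploads, the loop would take roughly 0.75 s instead of about 0.15 s. A real implementation must also ensure that a checkpoint file is completely written, and not overwritten by the next save, before the background upload reads it.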


Quick start

Install LitModels via pip:

pip install litmodels

Toy example (see the Examples section below for real ones):

import litmodels as lm
import torch

# save a model
model = torch.nn.Module()
lm.save_model(model=model, name="model-name")

# load a model
model = lm.load_model(name="model-name")

Examples

PyTorch

Save model:

import torch
from litmodels import save_model

model = torch.nn.Module()
save_model(model=model, name="your_org/your_team/torch-model")

Load model:

from litmodels import load_model

model_ = load_model(name="your_org/your_team/torch-model")

PyTorch Lightning

Save model:

from lightning import Trainer
from litmodels import upload_model
from litmodels.demos import BoringModel

# Configure Lightning Trainer
trainer = Trainer(max_epochs=2)
# Define the model and train it
trainer.fit(BoringModel())

# Upload the best model to cloud storage
checkpoint_path = trainer.checkpoint_callback.best_model_path
# Define the model name - this should be unique to your model
upload_model(model=checkpoint_path, name="<organization>/<teamspace>/<model-name>")

Load model:

from lightning import Trainer
from litmodels import download_model
from litmodels.demos import BoringModel

# Load the model from cloud storage
checkpoint_path = download_model(
    # Define the model name and version - this needs to be unique to your model
    name="<organization>/<teamspace>/<model-name>:<model-version>",
    download_dir="my_models",
)
print(f"model: {checkpoint_path}")

# Resume training from the downloaded checkpoint, extending it to 4 epochs
trainer = Trainer(max_epochs=4)
trainer.fit(BoringModel(), ckpt_path=checkpoint_path)
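
The examples above use names of the form `<organization>/<teamspace>/<model-name>`, optionally followed by `:<model-version>` when downloading. The hypothetical helper below is not part of litmodels. It splits such a string into its parts, which can help catch malformed names before calling `upload_model` or `download_model`. It assumes that a version, when present, follows the first `:`. It also handles only the fully qualified form, so the quick start's bare `model-name` is rejected.

```python
from typing import NamedTuple, Optional


class ModelName(NamedTuple):
    organization: str
    teamspace: str
    model: str
    version: Optional[str]  # None when no ":<model-version>" suffix is given


def parse_model_name(name: str) -> ModelName:
    """Hypothetical helper: split '<organization>/<teamspace>/<model-name>[:<model-version>]'."""
    path, sep, version = name.partition(":")  # split on the first ':' only
    parts = path.split("/")
    if len(parts) != 3 or not all(parts):
        raise ValueError(
            f"expected '<organization>/<teamspace>/<model-name>[:<model-version>]', got {name!r}"
        )
    if sep and not version:
        raise ValueError(f"empty model version in {name!r}")
    organization, teamspace, model = parts
    return ModelName(organization, teamspace, model, version if sep else None)


print(parse_model_name("my-org/my-team/torch-model"))
# ModelName(organization='my-org', teamspace='my-team', model='torch-model', version=None)
print(parse_model_name("my-org/my-team/torch-model:v2").version)  # v2
```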

TensorFlow / Keras

Save model:

from tensorflow import keras

from litmodels import save_model

# Define the model
model = keras.Sequential(
    [
        keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
        keras.layers.Dense(10, name="dense_2"),
    ]
)

# Compile the model
model.compile(optimizer="adam", loss="categorical_crossentropy")

# Save the model
save_model("lightning-ai/jirka/sample-tf-keras-model", model=model)

Load model:

from litmodels import load_model

model_ = load_model(
    "lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model"
)

SKLearn

Save model:

from sklearn import datasets, model_selection, svm
from litmodels import save_model

# Load example dataset
iris = datasets.load_iris()
X, y = iris.data, iris.target

# Split dataset into training and test sets
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train a simple SVC model
model = svm.SVC()
model.fit(X_train, y_train)

# Save the trained model and upload it using litmodels
save_model(model=model, name="your_org/your_team/sklearn-svm-model")

Use model:

from litmodels import load_model

# Download and load the model file from cloud storage
model = load_model(
    name="your_org/your_team/sklearn-svm-model", download_dir="my_models"
)

# Example: run inference with the loaded model
sample_input = [[5.1, 3.5, 1.4, 0.2]]
prediction = model.predict(sample_input)
print(f"Prediction: {prediction}")

Features

PyTorch Lightning Callback

Enhance your training process with an automatic checkpointing callback that uploads the model at the end of each epoch.

from lightning import Trainer
from litmodels.integrations import LightningModelCheckpoint
from litmodels.demos import BoringModel

trainer = Trainer(
    max_epochs=2,
    callbacks=[
        LightningModelCheckpoint(
            # Define the model name - this should be unique to your model
            model_registry="<organization>/<teamspace>/<model-name>",
        )
    ],
)
# BoringModel provides its own toy data (as in the examples above), so no dataset is passed
trainer.fit(BoringModel())

Save any Python class as a checkpoint

Mixin classes package reusable functionality, such as saving and loading, so that any Python class can inherit it. This keeps model persistence consistent and maintainable across multiple models.

Save model:

from litmodels.integrations.mixins import PickleRegistryMixin


class MyModel(PickleRegistryMixin):
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2
        # Your model initialization code
        ...


# Create and push a model instance
model = MyModel(param1=42, param2="hello")
model.upload_model(name="my-org/my-team/my-model")

Load model:

loaded_model = MyModel.download_model(name="my-org/my-team/my-model")
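
To make the mixin idea concrete, here is a minimal local-only sketch of a pickle-based mixin. `LocalPickleMixin`, `save_local`, and `load_local` are hypothetical names. Unlike `PickleRegistryMixin`, this sketch writes to a local directory instead of the Lightning registry, and it makes no claim about how the real mixin works.

```python
import pickle
import tempfile
from pathlib import Path


class LocalPickleMixin:
    """Hypothetical mixin: adds pickle-based save/load to any class, using a local directory."""

    def save_local(self, directory: str, name: str) -> Path:
        path = Path(directory) / f"{name}.pkl"
        path.parent.mkdir(parents=True, exist_ok=True)  # allows names like "org/team/model"
        with path.open("wb") as f:
            # Pickle the instance attributes rather than the object itself
            pickle.dump(vars(self), f)
        return path

    @classmethod
    def load_local(cls, directory: str, name: str):
        path = Path(directory) / f"{name}.pkl"
        with path.open("rb") as f:
            state = pickle.load(f)
        obj = cls.__new__(cls)  # create the instance without calling __init__
        obj.__dict__.update(state)
        return obj


class MyModel(LocalPickleMixin):
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2


# Round-trip through a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    MyModel(param1=42, param2="hello").save_local(tmp, "my-model")
    restored = MyModel.load_local(tmp, "my-model")

print(restored.param1, restored.param2)  # 42 hello
```

Pickling `vars(self)` means the file holds only attribute values, and loading does not call `__init__`. This approach assumes that all state lives in picklable instance attributes. As with any pickle-based format, only load files you trust, because unpickling can execute arbitrary code.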

Save custom PyTorch models

Mixin classes centralize serialization logic, so you don't rewrite it for every model and persistence stays consistent across projects. The download_model method does not need constructor arguments. It reconstructs the model directly from the registry with its saved architecture and weights, which avoids mismatches between how the model was saved and how it is re-initialized.

Save model:

import torch
from litmodels.integrations.mixins import PyTorchRegistryMixin


# Important: PyTorchRegistryMixin must be first in the inheritance order
class MyTorchModel(PyTorchRegistryMixin, torch.nn.Module):
    def __init__(self, input_size, hidden_size=128):
        super().__init__()
        self.linear = torch.nn.Linear(input_size, hidden_size)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        return self.activation(self.linear(x))


# Create and push the model
model = MyTorchModel(input_size=784)
model.upload_model(name="my-org/my-team/torch-model")
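
The README does not explain why `PyTorchRegistryMixin` must come first. As general Python background, and not a description of the library's internals: base-class order determines the method resolution order (MRO). Python looks up a method in the first listed base before the next one. A mixin listed after a base class that defines the same method without calling `super()` is therefore silently skipped. The classes below are hypothetical.

```python
class Base:
    def describe(self) -> str:
        return "Base"  # does not call super(), so lookup stops here


class Mixin:
    def describe(self) -> str:
        # Adds behavior, then defers to the next class in the MRO
        return "Mixin+" + super().describe()


class MixinFirst(Mixin, Base):
    pass


class BaseFirst(Base, Mixin):
    pass


print([c.__name__ for c in MixinFirst.__mro__])  # ['MixinFirst', 'Mixin', 'Base', 'object']
print(MixinFirst().describe())  # Mixin+Base
print(BaseFirst().describe())   # Base (the mixin never runs)
```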

Use the model:

# Pull the model with the same architecture
loaded_model = MyTorchModel.download_model(name="my-org/my-team/torch-model")
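
The README does not show how `download_model` rebuilds `MyTorchModel` without constructor arguments. One way a mixin could do this is to record the arguments passed to `__init__` and store them alongside the model state. The sketch below shows that approach in plain Python, under that assumption; it is not the library's code. `InitArgsMixin`, `to_record`, `from_record`, and `TinyModel` are all hypothetical.

```python
import copy
import functools
import inspect


class InitArgsMixin:
    """Hypothetical mixin that records the arguments each instance was built with."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if "__init__" not in cls.__dict__:
            return  # this class defines no __init__ of its own, so nothing to wrap
        original_init = cls.__init__

        @functools.wraps(original_init)
        def wrapped_init(self, *args, **kw):
            # Only the outermost __init__ call records arguments (handles super().__init__ chains).
            # Assumes __init__ takes only named parameters (no *args/**kwargs).
            if not hasattr(self, "_init_kwargs"):
                bound = inspect.signature(original_init).bind(self, *args, **kw)
                bound.apply_defaults()
                self._init_kwargs = {k: v for k, v in bound.arguments.items() if k != "self"}
            original_init(self, *args, **kw)

        cls.__init__ = wrapped_init

    def to_record(self) -> dict:
        # What a registry could store: constructor arguments plus a copy of the current state
        state = {k: v for k, v in vars(self).items() if k != "_init_kwargs"}
        return {"init_kwargs": copy.deepcopy(self._init_kwargs), "state": copy.deepcopy(state)}

    @classmethod
    def from_record(cls, record: dict):
        # Rebuild with the stored arguments, then restore the saved state on top
        obj = cls(**record["init_kwargs"])
        obj.__dict__.update(copy.deepcopy(record["state"]))
        return obj


class TinyModel(InitArgsMixin):
    def __init__(self, input_size, hidden_size=128):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weights = [0.0] * hidden_size  # stand-in for learned parameters


model = TinyModel(input_size=784)
model.weights[0] = 1.5  # pretend training updated a weight
record = model.to_record()

# The caller never repeats input_size or hidden_size here
restored = TinyModel.from_record(record)
print(restored._init_kwargs)  # {'input_size': 784, 'hidden_size': 128}
print(restored.weights[0])    # 1.5
```

A PyTorch version would more naturally store a `state_dict()` than copy `vars(self)`, because `torch.nn.Module` keeps its parameters and submodules in internal registries rather than as plain attributes.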


Community

💬 Get help on Discord
📋 License: Apache 2.0
