Dolomite Engine

Introduction

This repository contains code used for pretraining and finetuning IBM's Granite models. It also includes the following key innovations in model architectures, finetuning methods, and systems optimizations:

  1. Saving Memory Using Padding-Free Transformer Layers during Finetuning
    Mayank Mishra
  2. Reducing Transformer Key-Value Cache Size with Cross-Layer Attention
    William Brandon, Mayank Mishra, Aniruddha Nrusimha, Rameswar Panda, Jonathan Ragan-Kelley
  3. Dense Training, Sparse Inference: Rethinking Training of Mixture-of-Experts Language Models
    Bowen Pan, Yikang Shen, Haokun Liu, Mayank Mishra, Gaoyuan Zhang, Aude Oliva, Colin Raffel, Rameswar Panda
  4. NEFTune: Noisy Embeddings Improve Instruction Finetuning
    Neel Jain, Ping-yeh Chiang, Yuxin Wen, John Kirchenbauer, Hong-Min Chu, Gowthami Somepalli, Brian R. Bartoldson, Bhavya Kailkhura, Avi Schwarzschild, Aniruddha Saha, Micah Goldblum, Jonas Geiping, Tom Goldstein
  5. Parallelizing Linear Transformers with the Delta Rule over Sequence Length
    Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, Yoon Kim
  6. Scattered Mixture-of-Experts Implementation
    Shawn Tan, Yikang Shen, Rameswar Panda, Aaron Courville

Getting Started

Run make install to install the requirements for this repository. You might need to install flash-attn.

Distributed finetuning

This repository is meant for finetuning large language models (of any scale) using multiple backends. The following backends are currently supported:

  1. FSDP
  2. DeepSpeed

The repository currently supports only generative models but can easily be extended to non-generative models if needed. Two main classes of HuggingFace models are supported (a minimal loading sketch follows the list):

  1. decoder models (AutoModelForCausalLM) like Granite, Llama, BLOOM, etc.
  2. encoder-decoder models (AutoModelForSeq2SeqLM) like T5, BART, etc.
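
For orientation, here is a minimal sketch of loading one model of each class directly with HuggingFace Transformers; the model names are only examples, and training itself is driven through this repository's config files rather than code like this:

from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

# decoder-only (causal LM), e.g. a Granite code model
causal_model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3b-code-base")
causal_tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3b-code-base")

# encoder-decoder (seq2seq), e.g. T5
seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
seq2seq_tokenizer = AutoTokenizer.from_pretrained("t5-small")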

Please note that this repository doesn't support Tensor Parallel or Pipeline Parallel (yet 😉).

HuggingFace compatible custom models

This repository works with all HuggingFace models (text-to-text only for the moment) out of the box. The checkpoints have to be in safetensors format; if yours are not, you can convert them with tools/pt_to_safetensors.py. If your model_type is gpt_megatron, just change it to gpt_dolomite.
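
The model_type change is just an edit to the checkpoint's config.json. A minimal sketch, assuming a local checkpoint directory (the path below is a placeholder):

import json
from pathlib import Path

# placeholder path to a local checkpoint directory
config_path = Path("path/to/checkpoint/config.json")
config = json.loads(config_path.read_text())

# rename the model type so the checkpoint is picked up as a dolomite model
if config.get("model_type") == "gpt_megatron":
    config["model_type"] = "gpt_dolomite"
    config_path.write_text(json.dumps(config, indent=2))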

Tip

You might be able to enjoy additional memory and compute savings when finetuning your models using the padding-free transformer optimization. This optimization is currently only supported for decoder models and requires converting your model (say Llama-3, for example) to a custom class implemented in this repo. This is completely optional and not required for finetuning. The conversion can be achieved as follows:

from dolomite_engine.hf_models import import_from_huggingface

import_from_huggingface(
    pretrained_model_name_or_path="ibm-granite/granite-3b-code-base",
    save_path="dolomite_compatible_model"
)

Once done training, you can convert the model back to the HF class as:

from dolomite_engine.hf_models import export_to_huggingface

export_to_huggingface(
    pretrained_model_name_or_path="trained_checkpoint",
    save_path="hf_compatible_model",
    model_type="llama",
)

If you are interested in using this optimization outside this repo for some reason, you can do as follows:

import torch
from dolomite_engine.hf_models import GPTDolomiteForCausalLM


# we need unpadded lists here to avoid any useless computation on pad tokens
# this is a bit different from a standard transformer, which takes padded tensors and an attention mask
# if you turn off padding-free transformers, you can use tensor inputs with this class too
input_ids = [[1, 2, 3, 4, 5, 0], [6, 7, 8, 0]]
labels = [[-100, -100, -100, 4, 5, 0], [-100, -100, 8, 0]]

# this will throw a warning saying that the model is of gpt_bigcode class
# ignore the warning
model = GPTDolomiteForCausalLM.from_pretrained(
    <model_path>,
    attn_implementation="flash_attention_2",
    use_padding_free_transformer=True,
).cuda()

loss = model(
    input_ids=input_ids,
    labels=labels,
).loss

Note that padding-free transformers don't support generation; to run generation on the model, you will need to load it without the padding-free transformer optimization.
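
For example, a minimal generation sketch might look like the following; the checkpoint path is a placeholder, the tokenizer is assumed to be saved alongside the converted checkpoint, and the only essential point is that use_padding_free_transformer is omitted:

import torch
from transformers import AutoTokenizer
from dolomite_engine.hf_models import GPTDolomiteForCausalLM

# load without the padding-free optimization so that generation works
model = GPTDolomiteForCausalLM.from_pretrained(
    "dolomite_compatible_model",  # placeholder checkpoint path
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).cuda()

tokenizer = AutoTokenizer.from_pretrained("dolomite_compatible_model")
inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("cuda")

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))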

Usage

The typical training workflow looks like:

  1. Pretraining or finetuning: this is the actual training process.
# for finetuning
sh scripts/finetune.sh configs/sst2/training.yml
# for pretraining
sh scripts/pretrain.sh configs/pretraining-examples/pretrain-1.yml
  2. Inference: run inference on the trained model or the un-trained model.
sh scripts/generate.sh configs/sst2/inference.yml
  3. Unshard the checkpoint: this converts the sharded model that dolomite-engine saves during training into a safetensors checkpoint.
sh scripts/unshard.sh configs/sst2/unshard.yml

Running basic inference

For a simple HuggingFace inference example, refer to tools/inference.py. For an example running tensor parallel inference, refer to tools/tensor_parallel_inference.py.

Using custom datasets

The data directory should obey the following structure:

📦data
 ┣ 📂train
 ┃ ┣ 📜filename1.jsonl
 ┃ ┣ 📜filename2.jsonl
 ┃ ┗ 📜filename3.jsonl
 ┣ 📂val
 ┃ ┣ 📜filename1.jsonl
 ┃ ┣ 📜filename2.jsonl
 ┃ ┗ 📜filename3.jsonl
 ┗ 📂test
   ┣ 📜filename1.jsonl
   ┣ 📜filename2.jsonl
   ┗ 📜filename3.jsonl

Filenames can be anything as long as they contain no whitespace. Each line in each file should be a JSON object (jsonlines format), with entries looking like:

{"input": "The movie sucks", "output": "negative"}
{"input": "The movie was awesome", "output": "positive"}

Note that for the test set, only the input field is needed in each JSON instance; the output field is not required.

All the files in each directory are concatenated to form the respective split.
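
The following is a minimal sketch of writing such a split with the standard json module; the directory and file names are just illustrative examples:

import json
from pathlib import Path

examples = [
    {"input": "The movie sucks", "output": "negative"},
    {"input": "The movie was awesome", "output": "positive"},
]

# write one jsonlines file into the train split (names are illustrative)
train_dir = Path("data/train")
train_dir.mkdir(parents=True, exist_ok=True)
with open(train_dir / "filename1.jsonl", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")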

If you need to reformat the examples, you can use the input_format and output_format arguments. For example, input_format = 'Classify the sentiment of the sentence:\n__input__\nSentiment:' and output_format = ' __output__' reformat the input and output examples to:

INPUT:
Classify the sentiment of the sentence:
The movie sucks
Sentiment:

OUTPUT:
 negative

If you don't need any reformatting, leave input_format and output_format at their default values __input__ and __output__ respectively.

Please note that these formats must be provided at both training and inference time.

Try not to have trailing spaces in input_format; if you need a space between the input and the output, the space should be part of output_format, as in the above example.
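
To make the substitution explicit, here is a small sketch of the template semantics; this illustrates the placeholder behaviour only and is not the repository's internal implementation:

input_format = "Classify the sentiment of the sentence:\n__input__\nSentiment:"
output_format = " __output__"

example = {"input": "The movie sucks", "output": "negative"}

# __input__ and __output__ are placeholders replaced by the raw fields
formatted_input = input_format.replace("__input__", example["input"])
formatted_output = output_format.replace("__output__", example["output"])

print(formatted_input)
# Classify the sentiment of the sentence:
# The movie sucks
# Sentiment:
print(formatted_output)  # " negative" (leading space comes from output_format)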

Tip

Alternatively, you can also add your own dataset class in the repository if you don't want to use the jsonlines format or need custom logic to load your own dataset.

Currently, the repo has the following dataset classes implemented:

AlpacaDataset
DebugDataset
DollyDataset
HuggingFaceDataset
SlimOrcaDataset
SST2Dataset

Using Megatron Dataset outside of this repository

This repository implements the dataloader from Megatron-LM for efficient pretraining. If for some reason you need to use that dataloader outside this repository, take a look at this example.

Supported optimizers

We support all of the following optimizers. The default optimizer is TorchAdamW. Note that using the DeepSpeed or Apex optimizers will require installing the respective pip package.

# https://nvidia.github.io/apex/optimizers.html
from apex.optimizers import FusedAdam as ApexFusedAdam
from apex.optimizers import FusedLAMB as ApexFusedLAMB
from apex.optimizers import FusedNovoGrad as ApexFusedNovoGrad
from apex.optimizers import FusedSGD as ApexFusedSGD

# https://deepspeed.readthedocs.io/en/latest/optimizers.html
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adam import FusedAdam as DeepSpeedFusedAdam
from deepspeed.ops.lamb import FusedLamb as DeepSpeedFusedLAMB
from deepspeed.runtime.fp16.onebit import OnebitAdam as DeepSpeedOnebitAdam
from deepspeed.runtime.fp16.onebit import OnebitLamb as DeepSpeedOnebitLAMB
from deepspeed.runtime.fp16.onebit import ZeroOneAdam as DeepSpeedZeroOneAdam

# https://pytorch.org/docs/stable/optim.html
from torch.optim.adadelta import Adadelta as TorchAdadelta
from torch.optim.adagrad import Adagrad as TorchAdagrad
from torch.optim.adam import Adam as TorchAdam
from torch.optim.adamax import Adamax as TorchAdamax
from torch.optim.adamw import AdamW as TorchAdamW
from torch.optim.asgd import ASGD as TorchASGD
from torch.optim.lbfgs import LBFGS as TorchLBFGS
from torch.optim.nadam import NAdam as TorchNAdam
from torch.optim.radam import RAdam as TorchRAdam
from torch.optim.rmsprop import RMSprop as TorchRMSprop
from torch.optim.rprop import Rprop as TorchRprop
from torch.optim.sgd import SGD as TorchSGD
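
As a reference point, the default TorchAdamW corresponds to plain torch.optim.AdamW. A minimal sketch of constructing it directly in PyTorch, outside this repository's config system and with purely illustrative hyperparameters:

import torch

# toy model just to have parameters to optimize
model = torch.nn.Linear(16, 16)

# TorchAdamW maps to torch.optim.AdamW; lr and weight_decay here are illustrative
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)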

Citation

If you find this repository useful, please consider citing it in your research:

@software{Mishra_Dolomite_Engine_A_2024,
    author = {Mishra, Mayank},
    month = jun,
    title = {{Dolomite Engine: A Hyper-Optimized Library for Pretraining and Finetuning}},
    url = {https://github.com/ibm-granite/dolomite-engine},
    year = {2024}
}
