Add support for data parallel QLoRA training via DeepSpeed Zero stages 0, 1 and 2. #3728

Open

wants to merge 49 commits into base: master

Commits (49)
- `af00bba` WIP: deepspeed stage 2 (tgaddair, Sep 25, 2023)
- `70bde7d` Place on current device (tgaddair, Sep 25, 2023)
- `4d89ddb` Fixed device placement (tgaddair, Sep 25, 2023)
- `7d6e708` Merge branch 'master' into ds-stage2 (arnavgarg1, Sep 27, 2023)
- `fcd7d3b` Fix issue with distributed eval metric_fn placement (arnavgarg1, Sep 27, 2023)
- `4b43aef` Add workaround for checkpoint saving based on stage (arnavgarg1, Sep 27, 2023)
- `17a71f0` Gate logging statements with is_coordinator() so they only show up fo… (arnavgarg1, Sep 27, 2023)
- `7d65619` Gate more logging with coordinator barrier (arnavgarg1, Sep 29, 2023)
- `a30ba99` Latest push (arnavgarg1, Sep 29, 2023)
- `c8c273b` Working e2e, but not everything is correct (arnavgarg1, Oct 3, 2023)
- `1aa8efe` Clarification comment (arnavgarg1, Oct 3, 2023)
- `8c525e0` Clean up (arnavgarg1, Oct 6, 2023)
- `b7d68ae` Docstring (arnavgarg1, Oct 6, 2023)
- `78da3d9` resolve merge conflictts (arnavgarg1, Oct 6, 2023)
- `bd3341f` More cleanup (arnavgarg1, Oct 6, 2023)
- `ada045b` Set default optimization stage to 3 (arnavgarg1, Oct 6, 2023)
- `c997d33` Merge branch 'master' into ds-stage2 (arnavgarg1, Oct 6, 2023)
- `a3546f5` Merge branch 'master' into ds-stage2 (arnavgarg1, Oct 10, 2023)
- `5284420` Compatibility, but most of this is not needed with some re-architecting (arnavgarg1, Oct 10, 2023)
- `d0c174e` Add filelock around model from_pretrained call (arnavgarg1, Oct 10, 2023)
- `e68a02e` Add dynamic device_map setting based on backend and zero stage (arnavgarg1, Oct 10, 2023)
- `fec3ef4` Comments (arnavgarg1, Oct 11, 2023)
- `65855ec` Merge branch 'master' into ds-stage2 (arnavgarg1, Oct 13, 2023)
- `48f353d` Conditional model loading in LLM base class (arnavgarg1, Oct 13, 2023)
- `9ba3347` Add TODO to fix issue (arnavgarg1, Oct 13, 2023)
- `b4fdc0e` Revert to what was working (arnavgarg1, Oct 13, 2023)
- `eb33afd` Add utility functions to simplify (arnavgarg1, Oct 13, 2023)
- `b3c2496` Working e2s DS stage 2 (arnavgarg1, Oct 13, 2023)
- `b14bbeb` Simplify (arnavgarg1, Oct 13, 2023)
- `f519672` Minor modification for ds stage 3 compatibility (arnavgarg1, Oct 13, 2023)
- `b34ef05` Merge branch 'master' into ds-stage2 (arnavgarg1, Oct 13, 2023)
- `799fcd6` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 13, 2023)
- `ecf2113` Working with DS stage 3 as wellgit add ludwig/ (arnavgarg1, Oct 16, 2023)
- `6f70747` Cleanup (arnavgarg1, Oct 16, 2023)
- `2eea6a2` Cleanup (arnavgarg1, Oct 16, 2023)
- `d9d49ce` More cleanup (arnavgarg1, Oct 16, 2023)
- `cd73bd1` Refactor (arnavgarg1, Oct 16, 2023)
- `483d752` Add Data parallel QloRA example training script (arnavgarg1, Oct 16, 2023)
- `f0099b4` Add basic unit tests (arnavgarg1, Oct 16, 2023)
- `a379712` Log artifact dir (arnavgarg1, Oct 16, 2023)
- `855d2c1` Fix example script (arnavgarg1, Oct 16, 2023)
- `3a11709` Comments (arnavgarg1, Oct 16, 2023)
- `e8cdf27` Resolve merge conflicts (arnavgarg1, Oct 17, 2023)
- `2cf7f4b` Resolve comments (arnavgarg1, Oct 17, 2023)
- `832eb93` Comment with relevant doc snippets (arnavgarg1, Oct 17, 2023)
- `fbba8cb` Add config validation check and more tests for quantization and backe… (arnavgarg1, Oct 18, 2023)
- `b5b3dd7` Address comments (arnavgarg1, Oct 19, 2023)
- `2e06e5f` Merge branch 'master' into ds-stage2 (arnavgarg1, Oct 19, 2023)
- `46fae60` Address nit (arnavgarg1, Oct 19, 2023)
59 changes: 59 additions & 0 deletions examples/llm_qlora_data_parallel/README.md
@@ -0,0 +1,59 @@
# Data-Parallel QLoRA Fine-Tuning

If you have a single-node, multi-GPU setup and a large dataset on which you would like to fine-tune a model using QLoRA, you can use DeepSpeed ZeRO Stage 0, 1, or 2.

## DeepSpeed Background

As a refresher, here is what each DeepSpeed Zero stage corresponds to:

- **Stage 0**: Disabled, i.e., no partitioning of optimizer state, gradients, or model parameters. You can still perform optimizer and parameter offloading, as well as training with bf16 or fp16.
- **Stage 1**: The optimizer states (e.g., for the Adam optimizer, the 32-bit weights and the first and second moment estimates) are partitioned across the processes, so that each process updates only its partition.
- **Stage 2**: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states.
- **Stage 3**: The 16-bit model parameters are partitioned across the processes. ZeRO-3 will automatically collect and partition them during the forward and backward passes.

_NOTE: Data-parallel QLoRA training only works with DeepSpeed stages \<= 2. As of DeepSpeed 0.10.3, DeepSpeed is not compatible with partitioning/sharding of quantized weights when the weights are a mixture of dtypes_. See:

- https://github.com/microsoft/DeepSpeed/issues/4295
- https://github.com/microsoft/DeepSpeed/issues/3620

In particular, this comment summarizes the issue well:

> Some code for ZeRO3 assumes that all parameters in a model has the same dtype. This model has uint8 and float32 parameters and it throws the error.
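
To make the dtype issue concrete, here is a minimal sketch (not part of this PR) that loads a small model with 4-bit quantization and counts its parameter dtypes. The model name, and the use of `transformers`/`bitsandbytes` directly rather than through Ludwig, are illustrative assumptions.

```python
# Minimal sketch: show that a 4-bit-quantized model holds a mixture of parameter
# dtypes, which breaks ZeRO-3's uniform-dtype assumption.
# Assumes `transformers`, `accelerate`, and `bitsandbytes` are installed and a CUDA GPU is available.
from collections import Counter

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # small model used purely for illustration
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map={"": 0},
)

# Quantized linear weights are stored as torch.uint8, while embeddings, norms, etc.
# remain in a floating-point dtype, so the model contains mixed dtypes.
print(Counter(str(p.dtype) for p in model.parameters()))
```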

## Example Config

The example `train.py` uses DeepSpeed Stage 2 with the Ray backend, configured as follows, to fine-tune a model on a natural-language-to-code generation task via instruction fine-tuning.

```yaml
backend:
  type: ray
  trainer:
    use_gpu: true
    strategy:
      type: deepspeed
      zero_optimization:
        stage: 2
```

In most cases, Stage 2 lets you train large models in a distributed fashion across multiple GPUs. If you want to use Stage 0 or 1 instead, simply replace `stage: 2` with the desired ZeRO optimization stage.
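
For example, a hypothetical Stage 1 variant of the same backend block is sketched below. The `offload_optimizer` section is an illustrative assumption based on DeepSpeed's `zero_optimization` options (it is not used in the example script) and assumes the `zero_optimization` block is forwarded to DeepSpeed unchanged.

```yaml
backend:
  type: ray
  trainer:
    use_gpu: true
    strategy:
      type: deepspeed
      zero_optimization:
        stage: 1
        # Assumption: offload optimizer state to CPU memory (DeepSpeed ZeRO-Offload);
        # not part of the original example.
        offload_optimizer:
          device: cpu
```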

## DeepSpeed Zero Stage Benefits

### Benefits of DeepSpeed Stage 0

- **Ease of Use**: Stage 0 is relatively easy to set up and use, making it a good starting point for users looking for memory-efficient training without the complexity of more advanced optimization techniques.
- **Gradient Accumulation**: Stage 0 enables gradient accumulation, which is beneficial for simulating larger batch sizes even on hardware with memory constraints. This can lead to more stable model training and potentially faster convergence.
- **Mixed Precision Training**: It supports mixed-precision training, which uses lower-precision data types (e.g., float16) to reduce memory usage while maintaining training stability (a minimal sketch follows this list).
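
Gradient accumulation is configured on the Ludwig trainer rather than on DeepSpeed itself; the sketch below mirrors the trainer settings used in the example script later in this PR. The `bf16` block is an assumption about how a mixed-precision setting could be passed through the `deepspeed` strategy (it is not part of the original example).

```yaml
trainer:
  type: finetune
  batch_size: 1
  gradient_accumulation_steps: 4  # effective batch size of 4 per device

backend:
  type: ray
  trainer:
    use_gpu: true
    strategy:
      type: deepspeed
      # Assumption: bf16 mixed precision, mirroring the raw DeepSpeed JSON config.
      bf16:
        enabled: true
      zero_optimization:
        stage: 0
```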

### Benefits of DeepSpeed Stage 1

- **Optimizer State Partitioning**: Stage 1 is primarily focused on partitioning the optimizer state, allowing you to train very large models that wouldn't fit within a single GPU's memory.
- **Memory Efficiency**: It efficiently manages memory by dividing the optimizer state into segments distributed across multiple GPUs. This makes training larger models feasible.
- **Stepping Stone to Larger Models**: Stage 1 is especially valuable when the optimizer state no longer fits alongside the model on a single GPU, making it a natural step before scaling up to the more advanced stages for distributed training.
- **Limited Configuration Complexity**: It introduces memory efficiency while maintaining a relatively simple configuration setup compared to the more advanced stages like Stage 2 and Stage 3.

### Benefits of DeepSpeed Stage 2

- **Training Extremely Large Models**: ZeRO Stage 2 partitions both the gradients and the optimizer state to reduce memory requirements significantly. By contrast, Stage 0 and Stage 1 do not have the same level of memory optimization to handle models of such magnitude.
- **Advanced Distributed Training**: ZeRO Stage 2 optimizes communication, gradient aggregation, and synchronization across GPUs and nodes, making it well suited to training large models efficiently in a distributed environment. This capability goes beyond Stage 0 and is more sophisticated than Stage 1, helping achieve faster training times and handle larger workloads.
144 changes: 144 additions & 0 deletions examples/llm_qlora_data_parallel/train.py
@@ -0,0 +1,144 @@
import logging
import os

import numpy as np
import pandas as pd
import yaml

from ludwig.api import LudwigModel
from ludwig.datasets import code_alpaca

np.random.seed(123)


# Llama-2-7b-hf requires HUGGING_FACE_HUB_TOKEN to be set as an environment variable
# You can get a token at https://huggingface.co/settings/tokens
if "HUGGING_FACE_HUB_TOKEN" not in os.environ:
    raise ValueError(
        "Please set your Hugging Face Hub token as an environment variable using `export "
        "HUGGING_FACE_HUB_TOKEN=your_token`. You can get a token at https://huggingface.co/settings/tokens"
    )

fine_tuning_config = yaml.safe_load(
    """
model_type: llm
base_model: meta-llama/Llama-2-7b-hf

input_features:
  - name: instruction
    type: text

output_features:
  - name: output
    type: text

prompt:
  template: >-
    Below is an instruction that describes a task, paired with an input
    that provides further context. Write a response that appropriately
    completes the request.

    ### Instruction: {instruction}

    ### Input: {input}

    ### Response:

generation:
  temperature: 0.1
  max_new_tokens: 256

adapter:
  type: lora

quantization:
  bits: 4

preprocessing:
  split:
    type: random
    probabilities:
      - 0.9
      - 0.05
      - 0.05
  global_max_sequence_length: 512
  sample_size: 1000

backend:
  type: ray
  trainer:
    use_gpu: true
    strategy:
      type: deepspeed
      zero_optimization:
        stage: 2

trainer:
  type: finetune
  epochs: 3
  batch_size: 1
  eval_batch_size: 1
  enable_gradient_checkpointing: true
  gradient_accumulation_steps: 4
  learning_rate: 0.0001
  learning_rate_scheduler:
    decay: cosine
    warmup_fraction: 0.03
"""
)

df = code_alpaca.load(split=False)
model = LudwigModel(config=fine_tuning_config, logging_level=logging.INFO)

(
    train_stats,  # dictionary containing training statistics
    preprocessed_data,  # tuple of Ludwig Dataset objects of pre-processed training data
    output_directory,  # location of training results stored on disk
) = model.train(
    dataset=df,
    experiment_name="code_alpaca_instruct",
    model_name="llama2_7b",
)

# List contents of output directory
print("Contents of output directory:", output_directory)
for item in os.listdir(output_directory):
print("\t", item)

# Run Inference
print("Predict")
prediction_df = pd.DataFrame(
    [
        {
            "instruction": "Create an array of length 5 which contains all even numbers between 1 and 10.",
            "input": "",
        },
        {
            "instruction": "Create an array of length 15 containing numbers divisible by 3 up to 45.",
            "input": "",
        },
        {
            "instruction": "Create a nested loop to print every combination of numbers between 0-9",
            "input": "",
        },
        {
            "instruction": "Generate a function that computes the sum of the numbers in a given list",
            "input": "",
        },
        {
            "instruction": "Create a class to store student names, ages and grades.",
            "input": "",
        },
        {
            "instruction": "Print out the values in the following dictionary.",
            "input": "my_dict = {\n 'name': 'John Doe',\n 'age': 32,\n 'city': 'New York'\n}",
        },
    ]
)
preds, _ = model.predict(dataset=prediction_df)
preds = preds.compute()
for input_with_prediction in zip(prediction_df["instruction"], prediction_df["input"], preds["output_response"]):
print(f"Instruction: {input_with_prediction[0]}")
print(f"Input: {input_with_prediction[1]}")
print(f"Generated Output: {input_with_prediction[2][0]}")
print("\n\n")