
SD3.5-Large DreamBooth Training - Over 80GB VRAM Usage #10412

Open
deman311 opened this issue Dec 30, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@deman311

Describe the bug

⚠️ We are running out of memory on step 0.

❕It does work without '--train_text_encoder'. There seems to be a memory leak or an issue with training the text encoders with the current script / model.
❓Does it make sense that the model uses over 80GB of VRAM? (See the rough estimate below.)
❓Do you have any recommendations for decreasing VRAM usage, other than:
- 8-bit Adam
- Mixed precision fp16
- xformers (which doesn't work with SD3.5)
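
For context on the 80GB question, here is a rough, hedged back-of-the-envelope estimate of optimizer state alone, using approximate parameter counts for the SD3.5-Large transformer and its three text encoders (weights, gradients and activations come on top of this):

GB = 1024**3
# Approximate parameter counts: MMDiT ~8B, T5-XXL encoder ~4.8B, CLIP-G ~0.7B, CLIP-L ~0.1B
params = 8.0e9 + 4.8e9 + 0.7e9 + 0.1e9
adamw_fp32_state = params * 8 / GB   # plain AdamW: two fp32 moments per parameter
adamw_8bit_state = params * 2 / GB   # AdamW8bit: two uint8 moments per parameter (approx.)
print(f"AdamW state ~{adamw_fp32_state:.0f} GB vs AdamW8bit state ~{adamw_8bit_state:.0f} GB")

With every text encoder trainable, full-precision optimizer state by itself would already exceed 80GB under these approximate counts, while even the 8-bit variant leaves only limited headroom for weights, gradients and activations.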

💡Idea:
After successfully training with the kohya-ss scripts (Relevant Repo), I have deduced that the issue might be that the DreamBooth script here is not using 8-bit Adam properly; the flag may be ignored, or there may be a bug in the implementation itself. In kohya-ss, the single parameter that caused a massive surge in VRAM was disabling the AdamW8bit optimizer; otherwise the parameters were essentially the same.
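
For reference, this is the usual pattern the diffusers DreamBooth scripts follow when --use_8bit_adam is passed; it is only a sketch, with args and params_to_optimize as placeholders rather than the exact variables from train_dreambooth_sd3.py:

import torch

if args.use_8bit_adam:
    import bitsandbytes as bnb
    optimizer_class = bnb.optim.AdamW8bit  # ~2 bytes of optimizer state per parameter
else:
    optimizer_class = torch.optim.AdamW    # two fp32 moments = 8 bytes per parameter

optimizer = optimizer_class(
    params_to_optimize,                    # transformer (and text encoder) parameter groups
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

Note that the traceback below does show the bitsandbytes optimizer being entered, so the 8-bit path does appear to be active; the extra usage may instead come from what is being trained rather than from how.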

Kohya-ss Parameters for reference 📝

# Models

pretrained_model_name_or_path = "/kohya_ss/models/sd3.5_large.safetensors"

# Captioning

cache_latents = true
caption_dropout_every_n_epochs = 0
caption_dropout_rate = 0
caption_extension = ".txt"
clip_skip = 1
keep_tokens = 0

# Text Encoder Training

use_t5xxl_cache_only = true
t5xxl_dtype = "fp16"
train_text_encoder = true

# Learning Rates 

learning_rate = 5e-6
learning_rate_te1 = 1e-5 
learning_rate_te2 = 1e-5
loss_type = "l2"
lr_scheduler = "cosine"
lr_scheduler_args = []
lr_scheduler_num_cycles = 1
lr_scheduler_power = 0.5
lr_warmup_steps = 0
optimizer_type = "AdamW8bit"

# Batch Sizes

text_encoder_batch_size = 1
train_batch_size = 1
epoch = 1
persistent_data_loader_workers = 0
max_data_loader_n_workers = 0

# Buckets, Noise & SNR

max_bucket_reso = 2048
min_bucket_reso = 256
bucket_no_upscale = true
bucket_reso_steps = 64
huber_c = 0.1
huber_schedule = "snr"
min_snr_gamma = 5
prior_loss_weight = 1
max_timestep = 1000
multires_noise_discount = 0.3
multires_noise_iterations = 0
noise_offset = 0
noise_offset_type = "Original"
adaptive_noise_scale = 0

# SD3 Logits

mode_scale = 1.29
weighting_scheme = "logit_normal"
logit_mean = 0
logit_std = 1

# VRAM Optimization

resolution = "512,512"
max_token_length = 75
max_train_steps = 800
mem_eff_attn = true
mixed_precision = "fp16"
full_fp16 = true
gradient_accumulation_steps = 1
gradient_checkpointing = true
xformers = true
dynamo_backend = "no"

# Sampling

sample_every_n_epochs = 50
sample_sampler = "euler"

# Model Saving

save_every_n_steps = 200
save_model_as = "diffusers"
save_precision = "fp16"

# General

output_name = "last"
log_with = "tensorboard"

Reproduction

We are running the following command in a Jupyter notebook:

!accelerate launch train_dreambooth_sd3.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-3.5-large" \
  --output_dir="sd_outputs" \
  --instance_data_dir="ogo" \
  --instance_prompt="the face of ogo person" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --checkpointing_steps=200 \
  --learning_rate=2e-6 \
  --text_encoder_lr=1e-6 \
  --train_text_encoder \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=800 \
  --seed="0" \
  --use_8bit_adam \
  --mixed_precision="fp16"
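
Worth noting: the OOM message in the logs below itself suggests setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True. That only mitigates allocator fragmentation and will not shrink the model or optimizer footprint, but it is cheap to try in a cell before the launch, since environment variables set in the kernel are inherited by ! subprocesses:

import os
# Fragmentation mitigation only, as suggested by the PyTorch OOM message below;
# it does not reduce actual model/optimizer memory usage.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"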

Logs

2024-12-02 12:36:35.615846: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1733142995.629356 226993 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733142995.633681 226993 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
12/02/2024 12:36:39 - INFO - __main__ - Distributed environment: DistributedType.NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

You set add_prefix_space. The tokenizer needs to be converted from the slow tokenizers
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type t5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'base_shift', 'max_image_seq_len', 'max_shift', 'base_image_seq_len', 'invert_sigmas', 'use_dynamic_shifting'} was not found in config. Values will be initialized to default values.
Downloading shards: 100%|███████████████████████| 2/2 [00:00<00:00, 3450.68it/s]
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:03<00:00, 1.73s/it]
Fetching 2 files: 100%|█████████████████████████| 2/2 [00:00<00:00, 7476.48it/s]
{'dual_attention_layers'} was not found in config. Values will be initialized to default values.
12/02/2024 12:37:04 - INFO - __main__ - ***** Running training *****
12/02/2024 12:37:04 - INFO - __main__ - Num examples = 1
12/02/2024 12:37:04 - INFO - __main__ - Num batches each epoch = 1
12/02/2024 12:37:04 - INFO - __main__ - Num Epochs = 800
12/02/2024 12:37:04 - INFO - __main__ - Instantaneous batch size per device = 1
12/02/2024 12:37:04 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 2
12/02/2024 12:37:04 - INFO - __main__ - Gradient Accumulation steps = 2
12/02/2024 12:37:04 - INFO - __main__ - Total optimization steps = 800
Steps: 0%| | 0/800 [00:00<?, ?it/s]Traceback (most recent call last):
File "/home/azureuser/Picturethis/Dima/train_dreambooth_sd3.py", line 1811, in
main(args)
File "/home/azureuser/Picturethis/Dima/train_dreambooth_sd3.py", line 1666, in main
optimizer.step()
File "/home/azureuser/mambaforge/envs/picturevenv/lib/python3.11/site-packages/accelerate/optimizer.py", line 171, in step
self.optimizer.step(closure)
File "/home/azureuser/mambaforge/envs/picturevenv/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
return func.__get__(opt, opt.__class__)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/mambaforge/envs/picturevenv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 487, in wrapper
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/mambaforge/envs/picturevenv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/mambaforge/envs/picturevenv/lib/python3.11/site-packages/bitsandbytes/optim/optimizer.py", line 288, in step
self.init_state(group, p, gindex, pindex)
File "/home/azureuser/mambaforge/envs/picturevenv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/mambaforge/envs/picturevenv/lib/python3.11/site-packages/bitsandbytes/optim/optimizer.py", line 474, in init_state
state["state2"] = self.get_state_buffer(p, dtype=torch.uint8)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/mambaforge/envs/picturevenv/lib/python3.11/site-packages/bitsandbytes/optim/optimizer.py", line 328, in get_state_buffer
return torch.zeros_like(p, dtype=dtype, device=p.device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 79.15 GiB of which 10.62 MiB is free. Process 68964 has 530.00 MiB memory in use. Including non-PyTorch memory, this process has 78.45 GiB memory in use. Of the allocated memory 75.60 GiB is allocated by PyTorch, and 2.35 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Steps: 0%| | 0/800 [00:02<?, ?it/s]
Traceback (most recent call last):
File "/home/azureuser/mambaforge/envs/picturevenv/bin/accelerate", line 8, in
sys.exit(main())
^^^^^^
File "/home/azureuser/mambaforge/envs/picturevenv/lib/python3.11/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
args.func(args)
File "/home/azureuser/mambaforge/envs/picturevenv/lib/python3.11/site-packages/accelerate/commands/launch.py", line 1168, in launch_command
simple_launcher(args)
File "/home/azureuser/mambaforge/envs/picturevenv/lib/python3.11/site-packages/accelerate/commands/launch.py", line 763, in simple_launcher
raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/azureuser/mambaforge/envs/picturevenv/bin/python3.11', 'train_dreambooth_sd3.py', '--pretrained_model_name_or_path=stabilityai/stable-diffusion-3.5-large', '--output_dir=sd_outputs', '--instance_data_dir=ogo', '--instance_prompt=the face of ogo person', '--resolution=512', '--train_batch_size=1', '--gradient_accumulation_steps=2', '--gradient_checkpointing', '--checkpointing_steps=200', '--learning_rate=2e-6', '--text_encoder_lr=1e-6', '--lr_scheduler=constant', '--lr_warmup_steps=0', '--max_train_steps=800', '--seed=0', '--use_8bit_adam']' returned non-zero exit status 1.

System Info

System 🖥️

An A100 (80 GB) Azure remote server.
Running the code from a Jupyter notebook.

Libraries 📚

torch==2.5.1+cu124
torchvision==0.20.0+cu124
xformers==0.0.28.post2

bitsandbytes==0.44.0
tensorboard==2.15.2
tensorflow==2.15.0.post1
onnxruntime-gpu==1.19.2

accelerate==0.33.0
aiofiles==23.2.1
altair==4.2.2
dadaptation==3.2
diffusers[torch]==0.25.0
easygui==0.98.3
einops==0.7.0
fairscale==0.4.13
ftfy==6.1.1
gradio==5.4.0
huggingface-hub==0.25.2
imagesize==1.4.1
invisible-watermark==0.2.0
lion-pytorch==0.0.6
lycoris_lora==3.1.0
omegaconf==2.3.0
onnx==1.16.1
prodigyopt==1.0
protobuf==3.20.3
open-clip-torch==2.20.0
opencv-python==4.10.0.84
pytorch-lightning==1.9.0
rich>=13.7.1
safetensors==0.4.4
schedulefree==1.2.7
scipy==1.11.4
# for T5XXL tokenizer (SD3/FLUX)
sentencepiece==0.2.0
timm==0.6.12
tk==0.1.0
toml==0.10.2
transformers==4.44.2
voluptuous==0.13.1
wandb==0.18.0

Who can help?

No response

@deman311 deman311 added the bug Something isn't working label Dec 30, 2024
@hlky
Collaborator

hlky commented Dec 30, 2024

kohya-ss has a separate option to enable training t5xxl, whereas --train_text_encoder in train_dreambooth_sd3.py enables training for all text encoders; this may account for the difference in usage if the other parameters are the same. We could consider a similar option to enable t5xxl training separately from CLIP.
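
In the meantime, a rough way to approximate the kohya-ss behaviour is to freeze the T5 encoder while leaving the CLIP encoders trainable. This is only a sketch: the names follow the text_encoder_one/two/three convention used in train_dreambooth_sd3.py and may need adjusting, and the T5 parameters also have to be left out of the optimizer's parameter groups.

# Sketch: train only the CLIP text encoders and keep T5-XXL frozen,
# mirroring kohya-ss's separate t5xxl switch. Names assume the
# text_encoder_one/two/three convention of train_dreambooth_sd3.py.
if args.train_text_encoder:
    text_encoder_one.requires_grad_(True)     # CLIP-L
    text_encoder_two.requires_grad_(True)     # CLIP-G
    text_encoder_three.requires_grad_(False)  # T5-XXL stays frozen
    # Also exclude text_encoder_three.parameters() from the optimizer's
    # parameter groups, or optimizer state will still be allocated for it.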
