FSDP does not work well with PEFT, and a mixed precision error when using PEFT for fine-tuning #837

Open
1 of 2 tasks
blurmemo opened this issue Jan 11, 2025 · 2 comments
Open
1 of 2 tasks

Comments

@blurmemo
Copy link

blurmemo commented Jan 11, 2025

System Info

GPUs: A100 80G x1 and A100 40G x2
Pytorch version: latest
CUDA version: 11.8

Information

  • The official example scripts
  • My own modified scripts

🐛 Describe the bug

Bug 1

Description: mixed precision error when using PEFT for fine-tuning
GPUs: device 5 is an A100 80G; devices 3 and 4 are A100 40G
Command: CUDA_VISIBLE_DEVICES=5,3,4 torchrun --standalone --nnodes=1 --nproc-per-node=3 finetuning.py
Solved: Yes
Reason: when the model is first loaded, its dtype is set to fp16 or bf16. However, when PEFT (LoRA) is used for fine-tuning, the LoRA modules are loaded in fp32. The mixed precision error therefore comes from the base model being fp16/bf16 while the LoRA modules are fp32.
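
For reference, a minimal sketch of the fix (assuming a Hugging Face base model plus peft's LoRA; the model path and LoRA settings below are placeholders):

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Base model loaded in bf16; LoRA adapters are created in fp32 by default.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/model", torch_dtype=torch.bfloat16
)
peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, peft_config)

# Cast the trainable (LoRA) parameters to the base model dtype so the whole
# model uses a single dtype instead of a bf16/fp32 mix.
for param in model.parameters():
    if param.requires_grad and param.dtype == torch.float32:
        param.data = param.data.to(torch.bfloat16)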

Bug 2

Description: FSDP does not work well with PEFT; CUDA out of memory
GPUs: device 5 is an A100 80G; devices 3 and 4 are A100 40G
Command: CUDA_VISIBLE_DEVICES=5,3,4 torchrun --standalone --nnodes=1 --nproc-per-node=3 finetuning.py
Solved: No
Question: I wrote my own finetuning.py by imitating the official example script and did not modify the FSDP core code. When I run the command, I always get CUDA out of memory. Watching GPU memory, I see that all three GPUs load the full model weights, about 41G each. I am very confused: with the FULL_SHARD ShardingStrategy, the model weights are not split across the GPUs; instead, every GPU loads the full weights. Is my config set incorrectly, or is this just how it works?
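
For what it's worth, the pattern the official recipe uses to avoid materializing the full weights on every rank is low_cpu_fsdp: only rank 0 loads real weights, the other ranks build the model on the meta device, and sync_module_states broadcasts rank 0's weights after sharding. A rough sketch of that pattern (the model path is a placeholder, I have not verified this exact flow for Mllama, and process group setup via torchrun is omitted):

import os
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, MllamaForConditionalGeneration

model_name = "/PATH/models/llama/llama_vision_11B_instruct/hf"  # placeholder
rank = int(os.environ["RANK"])

# Only rank 0 materializes real weights; the other ranks build the model on the
# meta device so they never hold a full copy before sharding.
if rank == 0:
    model = MllamaForConditionalGeneration.from_pretrained(model_name)
else:
    with torch.device("meta"):
        model = MllamaForConditionalGeneration(AutoConfig.from_pretrained(model_name))

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,  # broadcast rank 0's weights into each rank's shard
    param_init_fn=(
        (lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
        if rank != 0
        else None
    ),
)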

Error logs

Bug 2

W0111 04:50:12.577000 2572097 site-packages/torch/distributed/run.py:793] 
W0111 04:50:12.577000 2572097 site-packages/torch/distributed/run.py:793] *****************************************
W0111 04:50:12.577000 2572097 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0111 04:50:12.577000 2572097 site-packages/torch/distributed/run.py:793] *****************************************
/PATH/lib/python3.10/site-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/PATH/lib/python3.10/site-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/PATH/lib/python3.10/site-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:16<00:00,  3.25s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:16<00:00,  3.30s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:16<00:00,  3.35s/it]
trainable params: 5,898,240 || all params: 10,676,119,075 || trainable%: 0.0552
--> Model /PATH/models/llama/llama_vision_11B_instruct/hf

--> /PATH/models/llama/llama_vision_11B_instruct/hf has 10670.220835 Million params

After freezing the model:
--> /PATH/models/llama/llama_vision_11B_instruct/hf has 9806.653456 Million trainable params

--> Model state after freezing:
    vision_model: Frozen
    language_model: Unfrozen
    multi_modal_projector: Unfrozen

trainable params: 5,898,240 || all params: 10,676,119,075 || trainable%: 0.0552
trainable params: 5,898,240 || all params: 10,676,119,075 || trainable%: 0.0552
bFloat16 enabled for mixed precision - using bfSixteen policy
[rank1]: Traceback (most recent call last):
[rank1]:   File "/PATH/tools/finetuning.py", line 438, in <module>
[rank1]:     fire.Fire(main)
[rank1]:   File "/PATH/lib/python3.10/site-packages/fire/core.py", line 135, in Fire
[rank1]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank1]:   File "/PATH/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire
[rank1]:     component, remaining_args = _CallAndUpdateTrace(
[rank1]:   File "/PATH/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank1]:     component = fn(*varargs, **kwargs)
[rank1]:   File "/data0/home/ening/NICA/cogmllm/src/cogmllm/tools/finetuning.py", line 282, in main
[rank1]:     model = FSDP(
[rank1]:   File "/PATH/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
[rank1]:     _init_param_handle_from_module(
[rank1]:   File "/PATH/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 636, in _init_param_handle_from_module
[rank1]:     _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank1]:   File "/PATH/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 648, in _init_param_handle_from_params
[rank1]:     handle = FlatParamHandle(
[rank1]:   File "/PATH/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 584, in __init__
[rank1]:     self._init_flat_param_and_metadata(
[rank1]:   File "/PATH/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 739, in _init_flat_param_and_metadata
[rank1]:     self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
[rank1]:   File "/PATH/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 852, in flatten_tensors_into_flat_param
[rank1]:     flat_param_data = self.flatten_tensors(tensors, aligned_numel)
[rank1]:   File "/PATH/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 844, in flatten_tensors
[rank1]:     return torch.cat(flat_tensors, dim=0)
[rank1]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 19.88 GiB. GPU 1 has a total capacity of 39.38 GiB of which 18.80 GiB is free. Including non-PyTorch memory, this process has 20.57 GiB memory in use. Of the allocated memory 19.89 GiB is allocated by PyTorch, and 208.63 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
W0111 04:51:07.712000 2572097 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2572179 closing signal SIGTERM
W0111 04:51:07.713000 2572097 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2572181 closing signal SIGTERM
E0111 04:51:10.196000 2572097 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 1 (pid: 2572180) of binary: /PATH/bin/python
Traceback (most recent call last):
  File "/PATH/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/PATH/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "/PATH/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
    run(args)
  File "/PATH/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "/PATH/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/PATH/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
finetuning.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-01-11_04:51:07
  host      : host
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 2572180)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Expected behavior

Bug 2

With the FULL_SHARD sharding strategy, I expect FSDP to shard the model weights across the three GPUs instead of every GPU holding the full set of weights (about 41G each), so wrapping the 11B model should not run out of memory on the 40G cards. If my config is incorrect, I would like to know which setting causes every rank to load the full weights.
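
If sharding works as expected, each rank should only hold roughly 1/world_size of the parameters right after the FSDP wrap. A quick sanity check I would add after model = FSDP(...) in the code below (it only reports the local shard size and currently allocated memory):

import torch
import torch.distributed as dist

# With FULL_SHARD, each rank's view of model.parameters() covers only its shard.
local_numel = sum(p.numel() for p in model.parameters())
alloc_gb = torch.cuda.memory_allocated() / 1024**3
print(f"rank {dist.get_rank()}: local param numel = {local_numel:,}, "
      f"allocated = {alloc_gb:.1f} GiB")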

Code:

def setup_wandb(train_config, fsdp_config, **kwargs):
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        )
    from cogmllm.configs import wandb_config as WANDB_CONFIG
    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    run.config.update(fsdp_config, allow_val_change=True)
    return run


def main(**kwargs):
    # init train and fsdp config with default
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    # update config through cmd-args
    update_config((train_config, fsdp_config), **kwargs)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)
    np.random.seed(train_config.seed)

    # initialize the process group when using FSDP across multiple nodes/GPUs
    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        # global rank 0 corresponds to the first GPU across all machines
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    # initialize distributed training
    if torch.distributed.is_initialized():
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None

    if train_config.use_wandb:
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # setting quantization configs
    bnb_config = None
    if train_config.quantization:
        if isinstance(train_config.quantization, bool):
            warn(
                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
                FutureWarning)
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError("8bit quantization is not supported with FSDP, please use 4bit quantization")

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)

    use_cache = False if train_config.enable_fsdp else None

    # model_name is the directory containing the Hugging Face config.json
    config = AutoConfig.from_pretrained(train_config.model_name)
    # multimodal model
    if config.model_type == "mllama":
        is_vision = True
        # load model checkpoint
        if train_config.use_fast_model:
            model = FastMllamaForConditionalGeneration.from_pretrained
        else:
            model = MllamaForConditionalGeneration.from_pretrained
        model = model(
            train_config.model_name,
            quantization_config=bnb_config,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                train_config.device_map  # default is auto
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            # torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
        )
        # load image and text processor
        processor = AutoProcessor.from_pretrained(
            train_config.model_name
            if train_config.tokenizer_name is None
            else train_config.tokenizer_name
        )
        processor.tokenizer.padding_side = 'right'
        # enable gradient checkpointing:
        # drop some intermediate activations in the forward pass and recompute them in the backward pass
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    # language model
    elif config.model_type == "llama":
        is_vision = False
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            use_cache=use_cache,  # the KV cache is only useful for inference, not training
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                train_config.device_map
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            # torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
        )
    else:
        raise ValueError(f"Model type {config.model_type} is not supported. Please use llama or mllama model.")

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name
        if train_config.tokenizer_name is None
        else train_config.tokenizer_name
    )
    if not tokenizer.pad_token_id:    # pad_token_id=128004, eos_token_id=128009
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between tokenizer vocab size and model token_embedding matrix,
    # throw a warning and then expand the embedding matrix
    # tokenizer vocab size=128257, model embedding matrix=128264
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))

    # on a single GPU, rank is 0
    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if train_config.enable_fsdp and fsdp_config.pure_bf16 and not train_config.quantization:
        model.to(torch.bfloat16)

    if train_config.freeze_vision:
        freeze_LLM_vision(model)
        # print model size and frozen layers after freezing layers
        print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

    # PEFT (default: LoRA)
    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and setup its configuration
        if train_config.from_peft_checkpoint:  # load from peft local checkpoint
            model = PeftModel.from_pretrained(
                model, train_config.from_peft_checkpoint, is_trainable=True
            )
            peft_config = model.peft_config
        # Generate the peft config and start fine-tuning from original model
        else:
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()

    if train_config.use_fp16:
        model.to(torch.float16)
    else:
        model.to(torch.bfloat16)

    # a single node may not need HSDP
    hsdp_device_mesh_plan = None
    if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size
        )
        print("HSDP device mesh is ready")

    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)

        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
            freeze_LLM_only(model)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

        # wrapping_policy decides which modules FSDP wraps
        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        # when using PEFT, create the FSDP wrapper for
        # MllamaSelfAttentionDecoderLayer,
        # MllamaCrossAttentionDecoderLayer,
        # and MllamaVisionEncoderLayer in vision models
        if is_vision:
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                model,
                [
                    MllamaSelfAttentionDecoderLayer,
                    MllamaCrossAttentionDecoderLayer,
                    # MllamaVisionEncoderLayer,
                ],
            )
        else:
            # Create the FSDP wrapper for LlamaDecoderLayer in text models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
        device_id = 0
        if torch.cuda.is_available():
            device_id = torch.cuda.current_device()
        if train_config.freeze_LLM_only or train_config.freeze_vision:
            use_orig_params = True
        else:
            use_orig_params = False
        model = FSDP(
            model,
            auto_wrap_policy=(
                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
            ),
            cpu_offload=(
                CPUOffload(offload_params=True)
                if fsdp_config.fsdp_cpu_offload
                else None
            ),
            mixed_precision=(
                mixed_precision_policy if not fsdp_config.pure_bf16 else None
            ),
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(
                (
                    lambda module: module.to_empty(
                        device=torch.device("cuda"), recurse=False
                    )
                )
                if train_config.low_cpu_fsdp and rank != 0
                else None
            ),
            use_orig_params=use_orig_params,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    # single-GPU path (no FSDP, no quantization)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if torch.cuda.is_available():
            model.to(train_config.device_map)

    dataset_config = generate_dataset_config(train_config, kwargs)

    if is_vision:
        # MllamaProcessor
        dataset_processer = processor
    else:
        dataset_processer = tokenizer

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="val",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    # for vision + text datasets, batching_strategy must be padding
    if train_config.batching_strategy == "packing":
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        else:
            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

    # set train and val dataloader
    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
    print("length of dataset_train", len(dataset_train))


    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator


    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            if is_vision:
                raise ValueError("Packing is not supported for vision datasets")
            else:
                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")

        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
            raise ValueError(
                "The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
        else:
            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")

    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )
    if not train_config.enable_fsdp or rank == 0:
        [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
        if train_config.use_wandb:
            for k, v in results.items():
                wandb_run.summary[k] = v


if __name__ == '__main__':
    fire.Fire(main)

blurmemo (Author) commented:

@HamidShojanazeri

HamidShojanazeri (Contributor) commented:

cc: @mreso
