
[BUG] AttributeError: 'NoneType' object has no attribute 'swap_folder' #4998

Open
maywind23 opened this issue Jan 23, 2024 · 2 comments · May be fixed by #6963
Assignees
Labels
bug (Something isn't working), training

Comments

@maywind23

Describe the bug
I tried to train Mixtral-8x7B with ZeRO 3 and got an error when saving a checkpoint, as follows:

Traceback (most recent call last):
  File "/home/maywind/Google Drive/finllm/training_parallel/mixtral_mp_fp16_train_lora.py", line 192, in <module>
    main()
  File "/home/maywind/Google Drive/finllm/training_parallel/mixtral_mp_fp16_train_lora.py", line 185, in main
    trainer.train()
  File "/home/maywind/anaconda3/envs/finllm/lib/python3.10/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/home/maywind/anaconda3/envs/finllm/lib/python3.10/site-packages/transformers/trainer.py", line 1914, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/home/maywind/anaconda3/envs/finllm/lib/python3.10/site-packages/transformers/trainer.py", line 2279, in _maybe_log_save_evaluate
    self._save_checkpoint(model, trial, metrics=metrics)
  File "/home/maywind/anaconda3/envs/finllm/lib/python3.10/site-packages/transformers/trainer.py", line 2359, in _save_checkpoint
    self._save_optimizer_and_scheduler(staging_output_dir)
  File "/home/maywind/anaconda3/envs/finllm/lib/python3.10/site-packages/transformers/trainer.py", line 2453, in _save_optimizer_and_scheduler
    self.model_wrapped.save_checkpoint(output_dir)
  File "/home/maywind/anaconda3/envs/finllm/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 3117, in save_checkpoint
    offload_dir = self.optimizer.optimizer_swapper.swap_folder
AttributeError: 'NoneType' object has no attribute 'swap_folder'

To Reproduce

  1. Main code:
    deepspeed_config = "./ds_zero3_offload_nvme_mixtral.json"
    deepspeed.init_distributed(dist_backend="nccl", timeout=datetime.timedelta(seconds=5400))

    training_args = TrainingArguments(
        output_dir='./finetuned_mixtral_8x7B_bf16',
        logging_steps=20,
        # max_steps=10000,
        num_train_epochs=1,
        per_device_train_batch_size=3*4,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        weight_decay=0.01,
        warmup_steps=500,
        save_steps=20,
        # fp16=True,
        bf16=True,
        deepspeed=deepspeed_config,
        torch_compile=True,
        load_best_model_at_end=True,
        evaluation_strategy="steps",
        remove_unused_columns=False,
    )

    model = MixtralForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
    )
    deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
    ...
  2. ds_config:
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true,
            "ratio": 0.3,
            "buffer_count": 4,
            "fast_init": false
        },
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/media/maywind/Data",
            "pin_memory": true,
            "buffer_count": 30,
            "buffer_size": 3e8,
            "max_in_cpu": 2e10
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e8,
        "stage3_gather_16bit_weights_on_model_save": "auto"
    },
    "gradient_clipping": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "wall_clock_breakdown": false
}
  3. Run the script:
export CUDA_VISIBLE_DEVICES=0,1,2
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1
export NCCL_DEBUG=INFO
accelerate launch mixtral_mp_fp16_train_lora.py

Expected behavior
The checkpoint should be saved successfully.

System info:

  • OS: Ubuntu 23.04
  • GPU: 3× RTX 4090
  • DeepSpeed 0.13.0; transformers 4.36.2; accelerate 0.26.1
  • Single node

I have updated to the latest DeepSpeed (0.13.0), which supports saving checkpoints under ZeRO 3, but it still fails for the Mixtral-8x7B model. It seems that the error AttributeError: 'NoneType' object has no attribute 'swap_folder' is caused by self.optimizer.optimizer_swapper being None?
Could you please look into this issue? Any help or comments are most welcome!

@maywind23 maywind23 added the bug (Something isn't working) and training labels Jan 23, 2024
@maywind23
Author

@tjruwase @mrwyattii @loadams
I tested a fine-tuning run of llama2-70B in the same environment (hardware and software), and it was interrupted by the same error AttributeError: 'NoneType' object has no attribute 'swap_folder'. Please see the error above for details.
I would greatly appreciate any guidance or assistance you can provide, and I am happy to offer more information if needed. Thank you very much for your time and effort.

@V1ki

V1ki commented Sep 24, 2024

I hit the same error. It happens because offload_param is set to nvme while offload_optimizer is on cpu, which breaks the checkpoint-saving flow. See the code in deepspeed/runtime/zero/stage3.py:

    def _configure_offloading(self, offload_optimizer_config, offload_param_config):
        ###################### offload optimizer setup ##################################
        if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none:
            self.offload_optimizer = True
            self.offload_optimizer_pin_memory = offload_optimizer_config.pin_memory
            self.swap_optimizer = offload_optimizer_config.device == OffloadDeviceEnum.nvme
            self.offload_optimizer_fast_init = offload_optimizer_config.fast_init

        ###################### offload param setup ##################################
        if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
            self.offload_param = True
            self.offload_param_pin_memory = offload_param_config.pin_memory
            self.params_in_nvme_and_cpu = offload_param_config.device == OffloadDeviceEnum.nvme
            self.max_params_in_cpu = offload_param_config.max_in_cpu
            print_rank_0(
                f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}",
                force=False)

Because offload_optimizer's device is cpu, swap_optimizer is False, and, crucially, optimizer_swapper is never initialized:

        # Optimizer tensor swapping
        if self.swap_optimizer:
            self._configure_tensor_swapping(offload_optimizer_config, aio_config)
... 
    def _configure_tensor_swapping(self, offload_optimizer_config, aio_config):
        nvme_swap_folder = os.path.join(offload_optimizer_config.nvme_path, 'zero_stage_3')
        os.makedirs(nvme_swap_folder, exist_ok=True)
        if dist.get_rank() == 0:
            logger.info(f'Tensor Swapping: Adding optimizer tensors')

        swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config.pipeline else PartitionedOptimizerSwapper

        self.optimizer_swapper = swapper_type(swap_config=offload_optimizer_config,
                                              aio_config=aio_config,
                                              base_folder=nvme_swap_folder,
                                              optimizer=self.optimizer,
                                              largest_numel=max(self.fp16_partitioned_groups_flat_numel),
                                              device=self.device,
                                              dtype=torch.float32,
                                              timers=self.timers)
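
To make the failure mode concrete, here is a minimal, self-contained Python sketch (hypothetical class and variable names, not DeepSpeed code) of the combination above: offload_param on NVMe, offload_optimizer on CPU, and an engine-side NVMe check that then dereferences the never-created swapper.

    # Minimal illustration of the failure mode; names are hypothetical.
    class FakeStage3Optimizer:
        def __init__(self, offload_optimizer_device, offload_param_device):
            self.swap_optimizer = offload_optimizer_device == "nvme"      # False for "cpu"
            self.params_in_nvme_and_cpu = offload_param_device == "nvme"  # True
            # Only _configure_tensor_swapping() would set this, and it only runs
            # when swap_optimizer is True, so it stays None here.
            self.optimizer_swapper = None

    opt = FakeStage3Optimizer(offload_optimizer_device="cpu", offload_param_device="nvme")

    # The NVMe check passes because the *parameters* are on NVMe ...
    if opt.swap_optimizer or opt.params_in_nvme_and_cpu:
        # ... but the optimizer swapper was never created, so this raises
        # AttributeError: 'NoneType' object has no attribute 'swap_folder'
        offload_dir = opt.optimizer_swapper.swap_folder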

That is where the error comes from, so I tried changing zero_has_nvme_offload to:


    def zero_has_nvme_offload(self):
        if not hasattr(self.optimizer, "optimizer_swapper"):
            return False
        return self.optimizer.swap_optimizer or self.optimizer.params_in_nvme_and_cpu

whereas the original code is:

    def zero_has_nvme_offload(self):
        if not hasattr(self.optimizer, "swap_optimizer"):
            return False
        return self.optimizer.swap_optimizer or self.optimizer.params_in_nvme_and_cpu

This change may fix the error.

The other workaround is to set the offload_optimizer device to nvme.
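
A minimal sketch of that second workaround, assuming you reuse the NVMe path already configured for offload_param (nvme_path, buffer_count, and fast_init here are illustrative and should be tuned for your setup), would replace the offload_optimizer block of the ds_config with something like:

    "offload_optimizer": {
        "device": "nvme",
        "nvme_path": "/media/maywind/Data",
        "pin_memory": true,
        "buffer_count": 4,
        "fast_init": false
    }

With the optimizer on NVMe, swap_optimizer becomes True, _configure_tensor_swapping runs and creates optimizer_swapper, so save_checkpoint has a valid swap_folder to read.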

@tjruwase tjruwase self-assigned this Jan 20, 2025
@tjruwase tjruwase linked a pull request Jan 20, 2025 that will close this issue