
Unwrapping for generation with FSDP #3425

Open
2 of 4 tasks
VityaVitalich opened this issue Mar 6, 2025 · 1 comment

VityaVitalich commented Mar 6, 2025

System Info

accelerate==1.4.0

compute_environment: LOCAL_MACHINE
debug: true
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: SHARD_GRAD_OP
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

I attach the code to reproduce the error I am facing; the simple script below is enough. The model is wrapped with Accelerate and then unwrapped, but generation is impossible because the parameters are still sharded. The workaround with FSDP.summon_full_params() does not work for me, since I need to switch between generation and training multiple times, and doing it that way breaks the model (a sketch of that switching pattern is given under Expected behavior below).

import torch

from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def main():
    # ✅ 1. Load config
    config = {
        'model_path': 'Qwen/Qwen2.5-Math-1.5B-Instruct',
        'cache_dir': './'
    }

    # ✅ 2. Initialize Accelerator (with FSDP)
    accelerator = Accelerator()

    # ✅ 3. Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        config["model_path"], use_fast=True, cache_dir=config["cache_dir"]
    )

    # ✅ 4. Load Model
    print(f"🔄 Loading model from: {config['model_path']}...")
    model = AutoModelForCausalLM.from_pretrained(
        config["model_path"],
        torch_dtype=torch.bfloat16,  # Match mixed precision settings
        cache_dir=config["cache_dir"]
    )

    # ✅ 5. Prepare model for distributed training
    model = accelerator.prepare_model(model)
    print("✅ Model wrapped with Accelerator")

    print("\n🔍 Model Parameter Shapes:")
    print("   🔹 embed_tokens:", model.model.embed_tokens, model.model.embed_tokens.weight.shape)
    print("   🔹 lm_head:", model.lm_head, model.lm_head.weight.shape)

    # ✅ 6. Unwrap model to check if FSDP is correctly handling parameters
    unwrapped_model = accelerator.unwrap_model(model)
    print(unwrapped_model.__class__.__name__)

    # ✅ 7. Verify Embeddings & LM Head
    print("\n🔍 Model Parameter Shapes:")
    print("   🔹 embed_tokens:", unwrapped_model.model.embed_tokens.weight.shape)
    print("   🔹 lm_head:", unwrapped_model.lm_head.weight.shape)

    # ✅ 8. Ensure `lm_head` is correctly restored (not flattened)
    if len(unwrapped_model.lm_head.weight.shape) != 2:
        print("❌ ERROR: `lm_head` is incorrectly shaped (should be 2D)")
    else:
        print("✅ `lm_head` has correct shape")

    # ✅ 9. Ensure `embed_tokens` is not sharded (should be `[vocab_size, hidden_dim]`)
    if len(unwrapped_model.model.embed_tokens.weight.shape) != 2:
        print("❌ ERROR: `embed_tokens` is incorrectly shaped (should be 2D)")
    else:
        print("✅ `embed_tokens` has correct shape")

    # ✅ 10. Run a Test Generation
    prompt = "hello"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(accelerator.device)

    print("\n🚀 Running Test Generation...")
    with torch.no_grad():
        # with FSDP.summon_full_params(model):  # also tried this, see the description above
        print("\n🔍 Model Parameter Shapes:")
        print("   🔹 embed_tokens:", model.model.embed_tokens.weight.shape)
        print("   🔹 lm_head:", model.lm_head.weight.shape)
        output_ids = model.generate(inputs=input_ids, max_new_tokens=3)

    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(f"📝 Generated Output: {output_text}")

    print("\n🎉 Model unwrapping test completed successfully!")


if __name__ == "__main__":
    main()

I obtain the following outputs

🔍 Model Parameter Shapes:
   🔹 embed_tokens: torch.Size([116686080])
   🔹 lm_head: torch.Size([116686080])
❌ ERROR: `lm_head` is incorrectly shaped (should be 2D)
❌ ERROR: `embed_tokens` is incorrectly shaped (should be 2D)

🚀 Running Test Generation...

🔍 Model Parameter Shapes:
   🔹 embed_tokens: torch.Size([116686080])
   🔹 lm_head: torch.Size([116686080])
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Qwen2ForCausalLM

🔍 Model Parameter Shapes:
   🔹 embed_tokens: torch.Size([116687616])
   🔹 lm_head: torch.Size([116687616])
❌ ERROR: `lm_head` is incorrectly shaped (should be 2D)
❌ ERROR: `embed_tokens` is incorrectly shaped (should be 2D)

🚀 Running Test Generation...

🔍 Model Parameter Shapes:
   🔹 embed_tokens: torch.Size([116687616])
   🔹 lm_head: torch.Size([116687616])

And the following error

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/data/v.moskvoretskii/QAC/test_score.py", line 92, in <module>
[rank1]:     main()
[rank1]:   File "/home/data/v.moskvoretskii/QAC/test_score.py", line 84, in main
[rank1]:     output_ids = model.generate(inputs=input_ids, max_new_tokens=3)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2223, in generate
[rank1]:     result = self._sample(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 3211, in _sample
[rank1]:     outputs = self(**model_inputs, return_dict=True)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 820, in forward
[rank1]:     return model_forward(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 808, in __call__
[rank1]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 856, in forward
[rank1]:     outputs = self.model(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 535, in forward
[rank1]:     inputs_embeds = self.embed_tokens(input_ids)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py", line 163, in forward
[rank1]:     return F.embedding(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 2264, in embedding
[rank1]:     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank1]: RuntimeError: 'weight' must be 2-D

Expected behavior

I would expect the model to be able to generate after unwrapping.
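
For clarity, the train/generate switching that breaks the model for me looks roughly like the sketch below. This is only an illustration, not code from my project: `train_step`, `dataloader`, and `optimizer` are placeholders for my real training loop.

    for batch in dataloader:
        # Training phase: FSDP gathers and reshards parameters per forward/backward.
        loss = train_step(model, batch)  # placeholder for the real loss computation
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        # Generation phase: gather full parameters only for the duration of generate().
        with torch.no_grad(), FSDP.summon_full_params(model, writeback=False):
            output_ids = model.generate(inputs=input_ids, max_new_tokens=3)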

S1ro1 (Member) commented Mar 8, 2025

FSDP params are sharded outside of forward/backward, so you can't do that even with unwrap_model, which only returns the underlying module. The only option (to my knowledge) is summon_full_params. Maybe share how that breaks for you.
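
A minimal sketch of that approach (untested here, and assuming `model` is the FSDP-wrapped module returned by accelerator.prepare_model and `input_ids` is already on the correct device):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Gather the full (unsharded) parameters only for the duration of generate().
    # writeback=False since generation does not modify the parameters.
    with torch.no_grad(), FSDP.summon_full_params(model, writeback=False):
        output_ids = model.generate(inputs=input_ids, max_new_tokens=3)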
