+import re
+from typing import Dict, Set
+
+import torch
 import torch.nn as nn
-from peft import PeftModel
+from peft import PeftModel, PeftType
+
+
+def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"):
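+    """Return the LoRA-related parameter names from ``names`` for adapter ``adapter_name``.
+
+    Names of adapter-qualified LoRA weights (e.g. ``lora_A``/``lora_B``) are collapsed to the
+    wrapped module's ``base_layer`` prefix so they can be matched against base-model names.
+    """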
+    config = model.peft_config[adapter_name]
+    if config.peft_type != PeftType.LORA:
+        raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
+    # to_return = lora_state_dict(model, bias=model.peft_config.bias)
+    # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
+    # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
+    bias = config.bias
+    if bias == "none":
+        to_return = {k for k in names if "lora_" in k}
+    elif bias == "all":
+        to_return = {k for k in names if "lora_" in k or "bias" in k}
+    elif bias == "lora_only":
+        to_return = set()
+        for k in names:
+            if "lora_" in k:
+                to_return.add(k)
+                bias_name = k.split("lora_")[0] + "bias"
+                if bias_name in names:
+                    to_return.add(bias_name)
+    else:
+        raise NotImplementedError
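+    # keep only keys that belong to the requested adapter; bias keys carry no adapter name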
+    to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))}
+    if config.use_dora:
+        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
+        # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
+        # we want the state_dict format not to change, we remove the "weight" part.
+        new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"
+
+        def renamed_dora_weights(k):
+            if k.endswith(new_dora_suffix):
+                k = k[:-7]  # remove ".weight"
+            return k
+
+        to_return = {renamed_dora_weights(k) for k in to_return}
+
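+    # collapse adapter-qualified LoRA weight names (e.g. `q_proj.lora_A.default.weight`) to the
+    # wrapped module's `base_layer` prefix so they line up with base-model parameter names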
+    to_return = {re.sub(rf"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return}
+    return to_return
+
+
+class PeftUnwrapMixin:
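+    """Expose a PeftModel's base model with parameter and state-dict names matching the plain model.
+
+    LoRA-wrapped weights live under ``...base_layer.weight`` inside the PEFT model; this wrapper renames
+    them back to the original ``...weight`` keys when reading, and restores the wrapped names when loading.
+    """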
+    def __init__(self, peft_model: PeftModel):
+        self.base_model = peft_model.get_base_model()
+        # peft does not affect buffers
+        self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters()))
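+        # build candidate `.weight` / `.bias` keys for the extracted LoRA layers and map the
+        # `base_layer`-qualified ones back to the original (unwrapped) parameter names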
+        potential_lora_weights = set()
+        for n in self.lora_layers:
+            potential_lora_weights.add(f"{n}.weight")
+            potential_lora_weights.add(f"{n}.bias")
+        self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights}
+        self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()}
+
+    def named_parameters(self):
+        for n, p in self.base_model.named_parameters():
+            if n in self.lora_param_to_origin_param:
+                n = self.lora_param_to_origin_param[n]
+            yield n, p
+
+    def named_buffers(self):
+        return self.base_model.named_buffers()
+
+    @property
+    def _modules(self):
+        return self.base_model._modules
+
+    @property
+    def _non_persistent_buffers_set(self):
+        return self.base_model._non_persistent_buffers_set
+
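+    # rename keys of an incoming base-model-style state dict to the `base_layer` names used by the PEFT model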
+    def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]):
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if k in self.origin_param_to_lora_param:
+                k = self.origin_param_to_lora_param[k]
+            new_state_dict[k] = v
+        return new_state_dict
+
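+    # inverse direction: expose the PEFT model's state dict keyed by base-model parameter names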
+    def state_dict(self):
+        state_dict = {}
+        for k, v in self.base_model.state_dict().items():
+            if k in self.lora_param_to_origin_param:
+                k = self.lora_param_to_origin_param[k]
+            state_dict[k] = v
+        return state_dict
+
+    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
+        state_dict = self.patch_state_dict(state_dict)
+        self.base_model.load_state_dict(state_dict, strict=strict, assign=assign)
+
+    def __hash__(self):
+        return hash(self.base_model)
 
 
 class ModelWrapper(nn.Module):
@@ -23,7 +120,7 @@ def unwrap(self, unwrap_peft: bool = True):
         else:
             model = self.module
         if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
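+            # wrap instead of fully unwrapping so parameter and state-dict names still match the plain model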
+            model = PeftUnwrapMixin(model)
         return model
 
     def forward(self, *args, **kwargs):