Commit

Fix LLamaTokenizer init error | Add merge_and_unload() func for LoraModel | Improve forward func of Linear/MergedLinear class
hoanganhpham1006 committed Sep 23, 2024
1 parent 6a0c18d commit 6f51d95
Showing 3 changed files with 87 additions and 58 deletions.
1 change: 1 addition & 0 deletions examples/models/llama/llama.py
@@ -13,3 +13,4 @@
model.save("./llama_weights")

# If you want to load the model just do BaseModel.load("./llama_weights")
# If you want to merge the lora weights with the base model, you can do model.merge_lora()
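A minimal sketch of the workflow those two comments describe (the import path is assumed from xTuring's examples and is not part of this diff):

    from xturing.models import BaseModel

    # Reload the weights saved above, then fold the LoRA updates into the base model
    model = BaseModel.load("./llama_weights")
    model.merge_lora()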
7 changes: 3 additions & 4 deletions src/xturing/engines/llama_utils/llama.py
@@ -51,17 +51,16 @@ def __init__(
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
-        super().__init__(
-            bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
-        )
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.decode_with_prefix_space = decode_with_prefix_space
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
self._no_prefix_space_tokens = None

+        super().__init__(
+            bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
+        )
""" Initialisation"""

@property
137 changes: 83 additions & 54 deletions src/xturing/engines/lora_engine/lora.py
@@ -493,6 +493,48 @@ def from_pretrained(cls, model, saved_dir):
set_peft_model_state_dict(model, adapters_weights)
model.eval()
return model

def merge_and_unload(self):
"""
Merge the LoRA weights with the base model weights and unload the LoRA modules.
This method should be called after training to use the model without the LoRA modules.
"""
if not getattr(self.model, 'is_loaded_in_8bit', False):
# Collect the modules to merge
modules_to_merge = []
for name, module in self.model.named_modules():
if isinstance(module, LoraLayer) and hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
modules_to_merge.append((name, module))

# Merge the collected modules
for name, module in modules_to_merge:
# Merge weights
delta_weight = (module.lora_B.weight @ module.lora_A.weight) * module.scaling
if getattr(module, 'fan_in_fan_out', False):
delta_weight = delta_weight.T
module.weight.data += delta_weight

# Delete the LoRA parameters
del module.lora_A
del module.lora_B

# Set r to zero to indicate that LoRA is unloaded
module.r = 0
module.merged = True

# Remove references to LoRA attributes
if hasattr(module, 'lora_dropout'):
del module.lora_dropout
if hasattr(module, 'scaling'):
del module.scaling

# Disable LoRA modules
self.disable_adapter_layers()

# Reset the model to evaluation mode
self.model.eval()

return self.model
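The loop above adds (lora_B.weight @ lora_A.weight) * scaling to each LoRA layer's weight (transposed when fan_in_fan_out is set), after which the base Linear computes the adapted projection on its own. A minimal stand-alone sketch of that update with plain torch tensors (toy shapes, not xTuring's classes):

    import torch

    out_features, in_features, r, scaling = 8, 16, 4, 0.5
    weight = torch.randn(out_features, in_features)   # stands in for module.weight.data
    lora_A_weight = torch.randn(r, in_features)       # stands in for module.lora_A.weight
    lora_B_weight = torch.randn(out_features, r)      # stands in for module.lora_B.weight

    # Same update the loop performs via `module.weight.data += delta_weight`
    delta_weight = (lora_B_weight @ lora_A_weight) * scaling   # shape (out_features, in_features)
    weight += delta_weight   # the layer now carries the merged projection by itself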


# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
@@ -588,8 +630,9 @@ def reset_parameters(self):

def train(self, mode: bool = True):
nn.Linear.train(self, mode)
-        self.lora_A.train(mode)
-        self.lora_B.train(mode)
+        if self.r > 0:
+            self.lora_A.train(mode)
+            self.lora_B.train(mode)
if not mode and self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
@@ -601,7 +644,7 @@ def train(self, mode: bool = True):
)
self.merged = True
elif self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
+            # Ensure the weights are not merged
if self.r > 0:
self.weight.data -= (
transpose(
@@ -611,39 +654,29 @@ def eval(self):
)
self.merged = False


def eval(self):
nn.Linear.eval(self)
-        self.lora_A.eval()
-        self.lora_B.eval()
+        if self.r > 0:
+            self.lora_A.eval()
+            self.lora_B.eval()

def forward(self, x: torch.Tensor):
if self.disable_adapters:
if self.r > 0 and self.merged:
self.weight.data -= (
-                    transpose(
-                        self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out
-                    )
-                    * self.scaling
+                    transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
)
self.merged = False

-            return F.linear(
-                x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
-            )
-        elif self.r > 0 and not self.merged:
-            result = F.linear(
-                x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
-            )
-            if self.r > 0:
-                loraoutput = (
-                    self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
-                )
-                result = result + loraoutput
-            return result
-        else:
-            return F.linear(
-                x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
-            )
+            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
+
+        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
+        if self.r > 0 and not self.merged:
+            lora_output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
+            result += lora_output
+
+        return result


class MergedLinear(nn.Linear, LoraLayer):
@@ -755,35 +788,31 @@ def eval(self):
def forward(self, x: torch.Tensor):
if self.disable_adapters:
if self.r > 0 and self.merged and any(self.enable_lora):
-                delta_w = (
-                    F.conv1d(
-                        self.lora_A.weight.data.unsqueeze(0),
-                        self.lora_B.weight.data,
-                        groups=sum(self.enable_lora),
-                    )
-                    .squeeze(0)
-                    .transpose(-2, -1)
-                )
-                self.weight.data -= transpose(
-                    self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out
-                )
-                self.merged = False
-            return F.linear(
-                x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
-            )
-        elif self.merged:
-            return F.linear(
-                x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
-            )
-        else:
-            result = F.linear(
-                x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
-            )
-            if self.r > 0:
-                after_A = self.lora_A(self.lora_dropout(x))
-                after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
-                result += self.zero_pad(after_B) * self.scaling
-            return result
+                self._unmerge_weights()
+            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
+
+        if self.merged:
+            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
+
+        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
+        if self.r > 0:
+            after_A = self.lora_A(self.lora_dropout(x))
+            after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
+            result += self.zero_pad(after_B) * self.scaling
+        return result
+
+    def _unmerge_weights(self):
+        delta_w = (
+            F.conv1d(
+                self.lora_A.weight.data.unsqueeze(0),
+                self.lora_B.weight.data,
+                groups=sum(self.enable_lora),
+            )
+            .squeeze(0)
+            .transpose(-2, -1)
+        )
+        self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
+        self.merged = False


if is_bnb_available():
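The Linear and MergedLinear changes above all rely on one identity: adding the low-rank branch to the activations gives the same result as adding scaling * (lora_B.weight @ lora_A.weight) to the weight matrix, which is why the layers can merge weights in eval() and unmerge them again in train()/forward() without changing outputs. A quick self-contained check of that identity (plain torch, fan_in_fan_out assumed False, dropout omitted):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    in_features, out_features, r, scaling = 16, 8, 4, 0.5
    base = nn.Linear(in_features, out_features)
    lora_A = nn.Linear(in_features, r, bias=False)
    lora_B = nn.Linear(r, out_features, bias=False)

    x = torch.randn(2, in_features)

    # Unmerged path, as in Linear.forward when r > 0 and not merged
    unmerged = F.linear(x, base.weight, base.bias) + lora_B(lora_A(x)) * scaling

    # Merged path, as after train(False) or merge_and_unload()
    merged_weight = base.weight + (lora_B.weight @ lora_A.weight) * scaling
    merged = F.linear(x, merged_weight, base.bias)

    assert torch.allclose(unmerged, merged, atol=1e-5)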
