Implements Vera #763
Open
julian-fong wants to merge 68 commits into adapter-hub:main from julian-fong:implement_vera
Commits (68)
93bee1a  Make generation tests generic (TimoImhof)
f51cfdb  Merge remote-tracking branch 'origin/main' into dev/test-refactoring (TimoImhof)
7e65e82  Draft Refactoring AdapterTestBase (TimoImhof)
793cbe5  Merge branch 'adapter-hub:main' into dev/test-refactoring (TimoImhof)
65c3fb7  Replace import class names (TimoImhof)
afdcfdd  Merge branch 'dev/test-refactoring' of https://github.com/TimoImhof/a… (TimoImhof)
ee6166c  Base refactoring: (TimoImhof)
630b722  remove redundant imports (TimoImhof)
0d3577f  Add pytest markers and respective pytest commands (TimoImhof)
1300856  Add draft of README (TimoImhof)
78387db  Refactoring: (TimoImhof)
83d3b32  Fix make quality (TimoImhof)
5e8e1b8  Add gpt2 tests (TimoImhof)
53eb0b9  Fix config union and head tests (TimoImhof)
1dbd412  Fix paths and imports (TimoImhof)
cf4f6a7  remove accidently added prompt tuning from gpt2 and make style (TimoImhof)
b390d61  Revert PromptTuning changes (TimoImhof)
2193aee  Revert "Revert PromptTuning changes" (TimoImhof)
f555484  Re-add missing adapter model tests (TimoImhof)
8dccda2  Refactoring: (TimoImhof)
c665948  Introduce generic test creator function (TimoImhof)
fb425b6  Re-add beit adapter method tests (TimoImhof)
225439c  Refactor & Re-add bertgeneration and bert (TimoImhof)
09f9cdc  Re-add clip tests (TimoImhof)
7934350  Re-add: (TimoImhof)
5f55935  Add more models (TimoImhof)
147c8af  Re-add whisper (TimoImhof)
57c5131  initial commit (julian-fong)
259a268  improved docstring and fixed formatting issues (julian-fong)
b66571c  fixed formatting (julian-fong)
acee994  updates (julian-fong)
18182af  Updates (julian-fong)
f28e508  Updates (julian-fong)
f38b0e3  removed typo (julian-fong)
385cd35  fix black (julian-fong)
46af3fd  updates (julian-fong)
9f3a202  fixed typo (julian-fong)
b2979ce  Changes: (TimoImhof)
ffd21a9  Add debug statements and only execute failing test (TimoImhof)
0dba87c  Add verbose information (TimoImhof)
c333467  check package versions (TimoImhof)
aac4038  More debugging statements (TimoImhof)
0f4c9b6  Merge branch 'adapter-hub:main' into dev/test-refactoring (TimoImhof)
12379e3  Merge branch 'main' into implement_vera (julian-fong)
0c0f7e6  updates (julian-fong)
99cfb68  Merge branch 'implement_vera' of github.com:julian-fong/adapters into… (julian-fong)
9ac515c  Merge branch 'adapter-hub:main' into dev/test-refactoring (TimoImhof)
4af10df  Fix failing test: (TimoImhof)
1229fc5  added review updates (julian-fong)
20ddb5c  apply fix from #770 (julian-fong)
dbd4965  Update README (TimoImhof)
25fe0a9  updated docstring (julian-fong)
7f79832  updated docstring (julian-fong)
d1a4a09  Merge branch 'main' of https://github.com/TimoImhof/adapters into dev… (TimoImhof)
c516464  Fix hf version and clip tests (TimoImhof)
470169f  Merge branch 'adapter-hub:main' into implement_vera (julian-fong)
bb019b0  Merge branch 'adapter-hub:main' into implement_vera (julian-fong)
87c0998  Merge branch 'adapter-hub:main' into dev/test-refactoring (TimoImhof)
2c80a5c  Polish: (TimoImhof)
be69f0a  Merge branch 'main' into dev/test-refactoring (TimoImhof)
f1b1136  Merge branch 'main' into dev/test-refactoring (TimoImhof)
ebdf0a7  Merge remote-tracking branch 'github-desktop-TimoImhof/dev/test-refac… (julian-fong)
cd95c06  configure vera tests to #740 (julian-fong)
a0e578a  fix quality: (julian-fong)
3e7f2a5  update model_mixin.py to use forwardcontext in merge_adapter and rese… (julian-fong)
f62a44d  update black (julian-fong)
ac914c4  updated black (julian-fong)
ef574bf  updates (julian-fong)
@@ -16,6 +16,7 @@
from ..composition import Average, BatchSplit, Parallel, Stack
from ..configuration import LoRAConfig, ModelAdaptersConfig
from ..context import ForwardContext
from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
from .utils import dequantize_bnb_weight

@@ -37,6 +38,7 @@ def __init__(
        lora_B_shape,
        config: LoRAConfig,
        gating_heads: int = 1,
        name: str = None,
    ):
        super().__init__()
        assert config.composition_mode == "add", "LoRA module only supports composition_mode='add'."

@@ -45,6 +47,7 @@ def __init__(
        self.composition_mode = config.composition_mode
        self.attn_matrices = config.attn_matrices
        self.use_gating = config.use_gating
        self.name = name
        # Optional dropout
        if config.dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=config.dropout)

@@ -69,6 +72,9 @@ def __init__(
        elif config.init_weights == "ia3":
            nn.init.ones_(self.lora_A)
            nn.init.ones_(self.lora_B)
        elif config.init_weights == "vera":
            nn.init.kaiming_uniform_(self.lora_A)
            nn.init.kaiming_uniform_(self.lora_B)
        else:
            raise ValueError("Unknown init_weights type: {}".format(config.init_weights))

@@ -112,6 +118,7 @@ def __init__(
        lora_B_shape,
        config: LoRAConfig,
        gating_heads: int = 1,
        name: str = None,
    ):
        super().__init__()
        assert config.composition_mode == "scale", "IA3 module only supports composition_mode='scale'."

@@ -122,6 +129,7 @@ def __init__(
        self.composition_mode = config.composition_mode
        self.attn_matrices = config.attn_matrices
        self.use_gating = config.use_gating
        self.name = name
        # Optional dropout
        if config.dropout > 0.0:
            raise ValueError("IA3 module does not support dropout.")

@@ -133,7 +141,7 @@ def __init__(
        # For compatibility with LoRA, allow all init_weights types here.
        # Usually should be "ia3".
        if config.init_weights == "lora":
-           logger.warning("(IA)^3 module initialized with LoRA zeo init. Ignore if this is intended.")
+           logger.warning("(IA)^3 module initialized with LoRA zero init. Ignore if this is intended.")
            nn.init.zeros_(self.lora_B)
        elif config.init_weights == "bert":
            nn.init.normal_(self.lora_B, std=0.02)
@@ -177,6 +185,116 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
        return hidden_states, gate


class Vera(nn.Module):
    def __init__(
        self,
        lora_A_shape,
        lora_B_shape,
        config: LoRAConfig,
        gating_heads: int = 1,
        name: str = None,
    ):
        super().__init__()
        self.d = config.vera_d
        self.b = config.vera_b
        self.r = config.r
        self.alpha = config.alpha
        self.use_gating = config.use_gating
        self.name = name

        # check to make sure that the `composition_mode` is set to `add`
        self.composition_mode = config.composition_mode
        if self.composition_mode != "add":
            raise ValueError("Vera module only supports composition_mode='add'.")

        # Optional dropout
        if config.dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=config.dropout)

        self.lora_A_shape = lora_A_shape
        self.lora_B_shape = lora_B_shape
        self.d_shape = self.lora_A_shape[0]
        self.b_shape = self.lora_B_shape[0]

        # Actual trainable parameters
        self.vera_D = nn.Parameter(torch.diag(torch.ones(self.d_shape) * self.d))
        self.vera_B = nn.Parameter(torch.diag(torch.ones(self.b_shape) * self.b))
        self.scaling = self.alpha / self.r

        if self.use_gating:
            self.gate = nn.Linear(lora_A_shape[-1], gating_heads)
            nn.init.normal_(self.gate.weight, std=0.02)

    @property
    def delta_w(self) -> torch.Tensor:
        parameters = ForwardContext.get_context().shared_parameters[self.name]
        lora_A = parameters["lora_A"]
        lora_B = parameters["lora_B"]
        return self.vera_B @ lora_B @ self.vera_D @ lora_A

    def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor:
        """Performs the composition operation between existing and injected weights."""
        if scaling is None:
            scaling = self.scaling
        return weights + added * scaling

    def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
        """Inverts the composition operation between existing and injected weights."""
        return weights - added * self.scaling

    def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
        parameters = ForwardContext.get_context().shared_parameters[self.name]
        lora_A = parameters["lora_A"]
        lora_B = parameters["lora_B"]

        if hidden_states is None:
            hidden_states = layer_input

        if getattr(self, "lora_dropout"):
            hidden_states = self.lora_dropout(hidden_states)

        hidden_states = hidden_states @ torch.t(self.vera_B @ lora_B @ self.vera_D @ lora_A)

        if self.use_gating:
            gate = torch.sigmoid(self.gate(layer_input))
            gate = torch.mean(gate, dim=1).unsqueeze(-1)
            hidden_states = hidden_states * gate
        else:
            gate = None
Review comment: as this is likely merged after #770, the same fix from there should be applied here
        hidden_states = hidden_states * self.scaling

        return hidden_states, gate
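For reference, delta_w, com, and forward above implement the VeRA-style update: with the shared frozen random projections A = lora_A and B = lora_B and the trainable diagonal matrices Λ_d = vera_D and Λ_b = vera_B (notation borrowed from the VeRA paper), the module computes, up to the optional gate g (g = 1 when use_gating is off):

\Delta W = \Lambda_b \, B \, \Lambda_d \, A, \qquad
\mathrm{forward}(h) = \frac{\alpha}{r}\; g \odot \bigl(h \,\Delta W^{\top}\bigr), \qquad
\mathrm{com}(W, \Delta W) = W + \frac{\alpha}{r}\, \Delta W

This is only a sketch of the math encoded in the code above; as with the existing LoRA module, adding forward's output back onto the base layer output is left to the surrounding LoRA layer machinery rather than the Vera module itself.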
def init_shared_vera_parameters(model_config, adapter_config, device):
    hidden_size = model_config.hidden_size
    r = adapter_config["r"]

    parameters = nn.ParameterDict()

    # initialize frozen, random tensors A, B
    parameters["lora_A"] = torch.zeros(r, hidden_size).to(device)
    parameters["lora_B"] = torch.zeros(hidden_size, r).to(device)

    if adapter_config["init_weights"] == "lora":
        # initialize A the same way as the default for nn.Linear and B to zero
        nn.init.kaiming_uniform_(parameters["lora_A"], a=math.sqrt(5))
        nn.init.zeros_(parameters["lora_B"])
    elif adapter_config["init_weights"] == "bert":
        nn.init.normal_(parameters["lora_A"], std=0.02)
        nn.init.normal_(parameters["lora_B"], std=0.02)
    elif adapter_config["init_weights"] == "ia3":
        nn.init.ones_(parameters["lora_A"])
        nn.init.ones_(parameters["lora_B"])
    elif adapter_config["init_weights"] == "vera":
        nn.init.kaiming_uniform_(parameters["lora_A"])
        nn.init.kaiming_uniform_(parameters["lora_B"])
    else:
        raise ValueError("Unknown init_weights type: {}".format(adapter_config["init_weights"]))

    return parameters
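To make the shapes concrete: lora_A is (r, hidden_size), lora_B is (hidden_size, r), vera_D is (r, r) and vera_B is (hidden_size, hidden_size), so delta_w is a full (hidden_size, hidden_size) update built from the shared pair plus the per-layer diagonals. A minimal, self-contained sketch with illustrative sizes (plain torch only, not the adapters API):

import torch

hidden_size, r = 768, 8  # illustrative sizes, not taken from the PR

# Shared, frozen random projections (kaiming-uniform, as with init_weights="vera")
lora_A = torch.empty(r, hidden_size)
lora_B = torch.empty(hidden_size, r)
torch.nn.init.kaiming_uniform_(lora_A)
torch.nn.init.kaiming_uniform_(lora_B)

# Per-layer trainable matrices, initialized as d * I and b * I (cf. vera_D, vera_B)
vera_D = torch.diag(torch.ones(r) * 0.1)             # (r, r)
vera_B = torch.diag(torch.ones(hidden_size) * 0.1)   # (hidden_size, hidden_size)

delta_w = vera_B @ lora_B @ vera_D @ lora_A           # (hidden_size, hidden_size)
x = torch.randn(2, 16, hidden_size)                   # (batch, seq_len, hidden_size)
out = x @ delta_w.t()                                 # same shape as x
print(delta_w.shape, out.shape)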
class LoRALayer(AdapterLayerBase):
    adapter_modules_name = "loras"


@@ -202,6 +320,7 @@ def _get_lora_shapes(self, config: LoRAConfig):

    def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
        self.layer_idx = layer_idx

        lora_config = self.adapters_config.match(
            adapter_name,
            config_type=LoRAConfig,

@@ -210,7 +329,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
        )
        if lora_config is not None and self._check_lora_location(lora_config):
            if lora_config.composition_mode == "add":
-               lora_cls = LoRA
+               if isinstance(lora_config.vera_d, float) or isinstance(lora_config.vera_b, float):
+                   lora_cls = Vera
+               else:
+                   lora_cls = LoRA
            elif lora_config.composition_mode == "scale":
                lora_cls = IA3
            else:

@@ -219,7 +341,9 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
                *self._get_lora_shapes(lora_config),
                lora_config,
                gating_heads=self.get_n_heads(lora_config),
                name=adapter_name,
            )

            lora.train(self.training)
            lora = lora.to(self.weight.device)
            self.loras[adapter_name] = lora
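Taken together, the dispatch above selects the Vera module whenever the adapter config carries float vera_d / vera_b values. A minimal usage sketch, assuming this branch adds vera_d, vera_b and the "vera" init_weights option to LoRAConfig as the references to config.vera_d and config.vera_b above imply (model name and values are illustrative, and the exact config surface may differ):

from adapters import AutoAdapterModel, LoRAConfig

model = AutoAdapterModel.from_pretrained("roberta-base")

# Float vera_d / vera_b values route add_adapter to the Vera module above;
# these fields and init_weights="vera" are assumed to exist on this branch.
config = LoRAConfig(r=8, alpha=16, init_weights="vera", vera_d=0.1, vera_b=0.0)
model.add_adapter("vera_adapter", config=config)
model.train_adapter("vera_adapter")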
Review comment: we should also add an assert for composition mode "add" here (same as in LoRA init), just to make sure
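A sketch of what that could look like, mirroring the existing assertion in the LoRA module's __init__ (the placement indicated by the comment is illustrative, not part of this diff):

# inside Vera.__init__ (or wherever the composition mode is validated)
assert config.composition_mode == "add", "Vera module only supports composition_mode='add'."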