Default arguments in DebertaConfig disable relative attention, contrary to the docs and deberta-base #35335

bauwenst opened this issue Dec 19, 2024 · 6 comments

@bauwenst

System Info

transformers 4.47.0

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The documentation for DebertaConfig says:

> Instantiating a configuration with the defaults will yield a similar configuration to that of the DeBERTa microsoft/deberta-base architecture.

Yet, the most important part of DeBERTa, namely the relative attention, is disabled by default in the model and in the config:

```python
self.relative_attention = getattr(config, "relative_attention", False)
```

```
relative_attention (`bool`, *optional*, defaults to `False`):
    Whether use relative position encoding.
max_relative_positions (`int`, *optional*, defaults to -1):
    The range of relative positions `[-max_position_embeddings, max_position_embeddings]`. Use the same value
    as `max_position_embeddings`.
```

Even when users explicitly request a given max_relative_positions, relative attention stays disabled as long as relative_attention is False, because the whole block is gated on it:

```python
if self.relative_attention:
    self.max_relative_positions = getattr(config, "max_relative_positions", -1)
    if self.max_relative_positions < 1:
        self.max_relative_positions = config.max_position_embeddings
    self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)
    if "c2p" in self.pos_att_type:
        self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
    if "p2c" in self.pos_att_type:
        self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
```

And indeed:

```python
from transformers import DebertaConfig

config = DebertaConfig()
print(config.relative_attention)
```

This prints False, and when you instantiate a new DeBERTa model, e.g.

```python
from transformers import DebertaConfig, DebertaForMaskedLM

print(DebertaForMaskedLM._from_config(DebertaConfig()))
print(DebertaForMaskedLM._from_config(DebertaConfig(max_relative_positions=512)))
```

...there are no relative positional embeddings in the model, only absolute positional embeddings. This model will also not do any disentangled attention.
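For contrast, explicitly enabling relative attention (together with the c2p/p2c attention types that microsoft/deberta-base uses) does create the relative position machinery. A quick check, assuming the current module layout (model.deberta.encoder):

```python
from transformers import DebertaConfig, DebertaForMaskedLM

# Enable relative attention explicitly, mirroring microsoft/deberta-base.
config = DebertaConfig(relative_attention=True, pos_att_type=["c2p", "p2c"])
model = DebertaForMaskedLM._from_config(config)

# Present with relative_attention=True, absent with the default config.
print(model.deberta.encoder.rel_embeddings)
```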

Expected behavior

Conform to the documentation by setting relative_attention=True in the DebertaConfig by default.

I would also add a warning when relative attention is False, so that users know very clearly that despite using a DeBERTa model, they are not getting the core feature offered by DeBERTa, namely the relative attention.
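A rough sketch of what such a check could look like, purely to illustrate the idea (the helper name is hypothetical, not existing transformers API; maintainers would presumably wire this into the DeBERTa modeling code instead):

```python
import warnings

from transformers import DebertaConfig


def warn_if_relative_attention_disabled(config) -> None:
    # Hypothetical helper: flag configs that silently degrade DeBERTa to
    # BERT-style absolute-position attention.
    if not getattr(config, "relative_attention", False):
        warnings.warn(
            "relative_attention is False: this DeBERTa model will use absolute "
            "position embeddings only and no disentangled attention.",
            UserWarning,
        )


warn_if_relative_attention_disabled(DebertaConfig())  # warns today, since the default is False
```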

bauwenst added the bug label on Dec 19, 2024
@Rocketknight1
Member

Hi @bauwenst, I would prefer not to update the default here, because it might cause an unexpected change in behaviour for users. Instead, I think updating the documentation to explain that relative attention is not enabled by default in the config class is probably the right course of action.

@bauwenst
Author

I understand your point @Rocketknight1, but consider this:

  • Configs are only instantiated from scratch for new training runs, rather than recurrently in the life cycle of a model (unlike, e.g., the default arguments in AutoModel.from_pretrained). The people whose implementations would be affected are those who are now getting ready to train DeBERTa AND don't want relative attention AND don't explicitly disable it with DebertaConfig(relative_attention=False) AND will update to a newer version of transformers before running their code. That is an exceedingly small number of cases.
  • The fact that technical debt exists does not justify keeping it around for "backwards compatibility" reasons, which, again, aren't even really needed here.

I know of at least two academic researchers who have had to throw out all their DeBERTa experiments. I avoided this only by accident, because I initialised DebertaConfig.from_pretrained() rather than DebertaConfig(), but I could easily have had to trash weeks of GPU hours. As you put it, the current HuggingFace defaults themselves "cause an unexpected change in behaviour for users": most users expect DeBERTa to have disentangled attention, because that is what the paper is about and what the name implies. The current defaults are the unexpected behaviour. Keeping them means assuming that less harm is done by tricking users who want disentangled attention than by tricking the tiny group of users described above. I don't see how that trade-off makes sense.
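For concreteness, here is how the two initialisation paths diverge (assuming the hosted microsoft/deberta-base config enables relative attention, as it currently does):

```python
from transformers import DebertaConfig

# Config downloaded from the Hub: relative attention is on.
pretrained = DebertaConfig.from_pretrained("microsoft/deberta-base")
print(pretrained.relative_attention)  # True

# Config built from the class defaults: relative attention is silently off.
fresh = DebertaConfig()
print(fresh.relative_attention)  # False
```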

@Rocketknight1
Member

Hi @bauwenst, you're right that this is a key feature of DeBERTa, but I think we prefer to keep the default as-is, both for backward compatibility reasons and because it matches the code in the original DeBERTa repo.

However, you were totally right to point out that setting this to False does not in fact yield "a similar configuration to that of the DeBERTa microsoft/deberta-base architecture", as the docs suggest. If you want to open a PR to correct the docs, we'd be happy to accept it!

@bauwenst
Author

bauwenst commented Dec 20, 2024

> both for backward compatibility reasons

To reiterate my previous comment: keeping the current defaults means you are assuming that less harm is done by tricking users who want disentangled attention than by tricking the tiny number of users above. I don't see how this trade-off makes sense.

> but also because this matches the code in the original DeBERTa repo

I fail to see how this is a case of backwards compatibility.

  1. People who used the original repo are definitely not the same people who use transformers. They don't live in the same ecosystem.
  2. In transformers there is a dedicated config class for DeBERTa, namely DebertaConfig. You may notice that the original repo hesitates about whether the relative_attention field even exists (with getattr(config, "relative_attention", False)) rather than just using config.relative_attention. Have you considered why this is? Your point seems to be that they intended the default DeBERTa to have relative_attention be False. That is not right. The reason is that they did not have a dedicated config class for DeBERTa: they wrote one config class whose fields were those of BERT (which is why relative_attention only appears in the docstring of the file you linked to, but is not a field that appears in the constructor). Extra fields for DeBERTa were added directly in the config JSONs. If you look at the configs they provide, you'll see that for every config that isn't the BERT-base config, relative_attention is True.

In other words: by saying you keep it False for backwards compatibility with the DeBERTa repo, you are saying that you are keeping it False so that DeBERTa's default config is a BERT config. This makes no sense, because DebertaConfig is not intended to replace BertConfig. What that docstring you linked to is saying is not that the default value in a config for DeBERTa should be False. It is saying that because DeBERTa is a modified BERT and BERT configs don't contain the relative_attention field, the modified BERT constructor will assume that you want to construct BERT if you use an old config that has no relative_attention field.

Put even more plainly: the default value of False, which is now apparently sacred technical debt in transformers, was itself grandfathered in from Microsoft's own technical debt. There is zero benefit to keeping it this way.

@Rocketknight1
Member

Regarding backward compatibility, we're just keeping it False because people who are running code that instantiates the class will find their model suddenly changing when they upgrade transformers, which we would prefer not to do without a strong reason. Config classes are not intended to have perfectly optimal default settings; we assume that users advanced enough to define their own configs and train models from scratch can figure out how to read a docstring and set a few kwargs.

As a result, I really do think the problem is just that one misleading line in the documentation, as I mentioned! You're welcome to submit a PR to fix it, but I don't want to keep going back and forth over the default value.

@bauwenst
Author

bauwenst commented Dec 20, 2024

> perfectly optimal default settings

The current default DeBERTa model in transformers is a BERT model, not a DeBERTa model.

That's not a matter of perfectionism. That's just incorrect. You are offering a BERT model and calling it DeBERTa.

It makes very little sense to change documentation to reflect that you have a bug rather than just fixing the bug. Here's a better idea: let's figure out a system where an additional field is added to the DebertaConfig class such that

  1. New calls to DebertaConfig result in relative_attention=True unless the user supplies relative_attention=False.
  2. Old configs loaded with DebertaConfig.from_pretrained result in relative_attention=False unless the config contains relative_attention=True.
  3. New configs loaded with from_pretrained result in relative_attention=True unless the config contains relative_attention=False.

This would give you your supposed "backward compatibility" while actually instantiating DeBERTa from the default DebertaConfig.
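A rough sketch of what such a scheme could look like, purely to illustrate the idea (the deberta_config_version marker and the subclass are hypothetical, not existing transformers API; the real change would of course live inside DebertaConfig itself):

```python
from transformers import DebertaConfig


class VersionedDebertaConfig(DebertaConfig):
    # Hypothetical sketch of the three rules above; not actual transformers code.

    def __init__(self, relative_attention=True, deberta_config_version=2, **kwargs):
        # Rule 1: fresh constructor calls default to relative_attention=True.
        super().__init__(relative_attention=relative_attention, **kwargs)
        self.deberta_config_version = deberta_config_version

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        # Rule 2: old JSONs (no version marker) keep the legacy False default
        # unless they explicitly set relative_attention=True.
        # Rule 3: new JSONs carry the marker and default to True.
        if "deberta_config_version" not in config_dict:
            config_dict = {"relative_attention": False, **config_dict}
        return super().from_dict(config_dict, **kwargs)
```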
