Gemma capping #34282

Open · wants to merge 49 commits into base: main
Changes from 13 commits
49 commits
85d549a
softcapping
ArthurZucker Jun 28, 2024
eba5191
soft cap before the mask
ArthurZucker Jun 28, 2024
b9e4a54
style
ArthurZucker Jun 28, 2024
514a839
...
ArthurZucker Jun 28, 2024
7544feb
super nit
ArthurZucker Jun 28, 2024
be1b8c3
update
ArthurZucker Oct 21, 2024
0e0511f
fixes
ArthurZucker Oct 21, 2024
03ccc22
update
ArthurZucker Oct 21, 2024
bdda724
small issue with modular
ArthurZucker Oct 21, 2024
a2b6b12
fix modular imports
ArthurZucker Oct 21, 2024
9365c1b
update
ArthurZucker Oct 21, 2024
2108ee3
fixup
ArthurZucker Oct 21, 2024
520120a
simplify a hell lot
ArthurZucker Oct 21, 2024
314ed1f
simplify cleaning imports
ArthurZucker Oct 22, 2024
8830473
finish fixing
ArthurZucker Oct 22, 2024
e4c19d7
update our design
ArthurZucker Oct 22, 2024
7922210
nits
ArthurZucker Oct 22, 2024
fa1319d
Merge branch 'main' of github.com:huggingface/transformers into gemma…
ArthurZucker Nov 1, 2024
43c68f6
use a deprecation cycle
ArthurZucker Nov 1, 2024
1aec944
updates
ArthurZucker Nov 1, 2024
93b53ef
Fix modular (recursive deps need to always be computed after merges!)
Cyrilvallez Nov 1, 2024
6f3cabb
Merge branch 'gemma-capping' of github.com:huggingface/transformers i…
ArthurZucker Nov 1, 2024
a79c4a9
push
ArthurZucker Nov 1, 2024
4c6d299
fix
ArthurZucker Nov 1, 2024
607c45d
update
ArthurZucker Nov 1, 2024
4598bba
fix modular order
Cyrilvallez Nov 1, 2024
5727270
make fix-copies
ArthurZucker Nov 1, 2024
198b4c4
updates
ArthurZucker Nov 1, 2024
3d35151
update
ArthurZucker Nov 1, 2024
da050cd
?
ArthurZucker Nov 1, 2024
e02078c
don't compile for now
ArthurZucker Nov 1, 2024
5861bbf
?
ArthurZucker Nov 4, 2024
8c47da2
fix some stuff
ArthurZucker Nov 4, 2024
09a88d9
donc!
ArthurZucker Nov 4, 2024
c06b530
fix copies
ArthurZucker Nov 4, 2024
89e6f85
update
ArthurZucker Nov 4, 2024
152e0b7
fixup
ArthurZucker Nov 4, 2024
46d8fa7
Merge branch 'main' of github.com:huggingface/transformers into gemma…
ArthurZucker Nov 4, 2024
006e869
?
ArthurZucker Nov 4, 2024
159c65a
fix two tests
ArthurZucker Nov 4, 2024
56ea5b9
fix?
ArthurZucker Nov 4, 2024
4c3deb9
for now, don't use head info
ArthurZucker Nov 4, 2024
9e3609d
eager when output attentoin and sdpa or flash as it's the simplest be…
ArthurZucker Nov 4, 2024
21edaed
fix-copies
ArthurZucker Nov 4, 2024
b5d9819
revert sdpa check
ArthurZucker Nov 4, 2024
5a3dade
Apply suggestions from code review
ArthurZucker Nov 6, 2024
faf433b
Merge branch 'main' of github.com:huggingface/transformers into gemma…
ArthurZucker Nov 6, 2024
1da75e1
rebase, fix-copies and push
ArthurZucker Nov 6, 2024
aca9120
add a slow integration test
ArthurZucker Nov 6, 2024
2 changes: 0 additions & 2 deletions src/transformers/models/gemma/configuration_gemma.py
@@ -19,8 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...configuration_utils import PretrainedConfig


2 changes: 0 additions & 2 deletions src/transformers/models/gemma2/configuration_gemma2.py
@@ -19,8 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...configuration_utils import PretrainedConfig


72 changes: 31 additions & 41 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -19,29 +19,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...cache_utils import Cache, HybridCache
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ...utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torch_greater_or_equal,
logging,
)


if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward

if is_torch_greater_or_equal("2.5"):
from torch.nn.attention.flex_attention import flex_attention
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_gemma2 import Gemma2Config
@@ -402,22 +415,6 @@ def forward(
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)

bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
@@ -441,40 +438,33 @@
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
def tanh_softcap(score, b, h, q_idx, kv_idx):
soft_cap = self.config.attn_logit_softcapping
return soft_cap * torch.tanh(score / soft_cap)

attn_output = torch.nn.functional.scaled_dot_product_attention(
attn_output = flex_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
block_mask=causal_mask,
score_mod=tanh_softcap,
enable_gqa=True,
scale=self.scaling,
return_lse=output_attentions,
)
if output_attentions:
attn_output, attention_scores = attn_output

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)

attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
return attn_output, attention_scores, past_key_value


GEMMA2_ATTENTION_CLASSES = {
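Note (not part of this PR): a minimal, self-contained sketch of the tanh soft-capping trick the hunk above switches to, assuming torch >= 2.5 for flex_attention. The cap value and tensor shapes are illustrative only, and flex_attention is normally wrapped in torch.compile for speed; it is shown uncompiled here for brevity.

import torch
from torch.nn.attention.flex_attention import flex_attention

ATTN_LOGIT_SOFTCAP = 50.0  # Gemma-2 uses 50.0 for attention logits; value here is illustrative


def tanh_softcap(score, b, h, q_idx, kv_idx):
    # Squash every attention logit into (-cap, cap) before masking/softmax.
    return ATTN_LOGIT_SOFTCAP * torch.tanh(score / ATTN_LOGIT_SOFTCAP)


bsz, num_heads, num_kv_heads, q_len, head_dim = 1, 8, 4, 16, 64
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_kv_heads, q_len, head_dim)
value = torch.randn(bsz, num_kv_heads, q_len, head_dim)

attn_output = flex_attention(
    query,
    key,
    value,
    score_mod=tanh_softcap,  # applied to each logit before the softmax
    enable_gqa=True,         # broadcasts the 4 KV heads over the 8 query heads
    scale=head_dim**-0.5,
)
print(attn_output.shape)  # torch.Size([1, 8, 16, 64])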
49 changes: 15 additions & 34 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -30,6 +30,7 @@
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torch_greater_or_equal,
logging,
)
from ..gemma.modeling_gemma import (
@@ -49,6 +50,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward

if is_torch_greater_or_equal("2.5"):
from torch.nn.attention.flex_attention import flex_attention


logger = logging.get_logger(__name__)

@@ -414,22 +418,6 @@ def forward(
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)

bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
@@ -453,40 +441,33 @@
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
def tanh_softcap(score, b, h, q_idx, kv_idx):
soft_cap = self.config.attn_logit_softcapping
return soft_cap * torch.tanh(score / soft_cap)

attn_output = torch.nn.functional.scaled_dot_product_attention(
attn_output = flex_attention(
Contributor:

Isn't it a bit misleading to use flex attn when we have attn_implementation="sdpa"? My concerns would be

  • People that previously used sdpa (forced or not) will suddenly have different torch requirements
  • Sdpa != Flexattn imo, it's a different API, name, and potentially slightly different behaviour
  • Are the slow tests still passing? We should ensure that it's still behaving the same ish in comparison to eager

Wdyt about making another attn implementation option for flex attn specifically? Not sure if this goes over the goal but control over the specific implementation is always appreciated.

Overall excited to see this, great work!

Collaborator Author:

SDPA version of gemma never "worked" TBH!
I'll probably add a new class for flex attention, this was simpler for testing

query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
block_mask=causal_mask,
score_mod=tanh_softcap,
enable_gqa=True,
scale=self.scaling,
return_lse=output_attentions,
)
if output_attentions:
attn_output, attention_scores = attn_output

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)

attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
return attn_output, attention_scores, past_key_value


class Gemma2DecoderLayer(GemmaDecoderLayer):
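On the reviewer's point about matching eager behaviour: a hedged reference sketch (not code from this PR) of "eager" soft-capped attention, following the Gemma-2 eager path (cap the logits, then add the additive causal mask, then softmax). It can serve as a numerical baseline when comparing against the flex_attention path above.

import torch


def eager_softcap_attention(query, key, value, causal_mask=None, softcap=50.0, scale=None):
    # query/key/value: (bsz, num_heads, seq_len, head_dim); KV heads already repeated.
    scale = query.shape[-1] ** -0.5 if scale is None else scale
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale
    # Soft-cap the logits before the mask, bounding them in (-softcap, softcap).
    attn_weights = softcap * torch.tanh(attn_weights / softcap)
    if causal_mask is not None:
        attn_weights = attn_weights + causal_mask  # additive mask, -inf on masked positions
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value)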
2 changes: 1 addition & 1 deletion src/transformers/models/glm/modeling_glm.py
@@ -23,8 +23,8 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
@@ -19,6 +19,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from typing import Union

@@ -19,6 +19,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -208,6 +208,7 @@
is_torch_fp16_available_on_device,
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_greater_or_equal,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_musa_available,
8 changes: 8 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -929,6 +929,14 @@ def is_flash_attn_greater_or_equal(library_version: str):
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)


@lru_cache()
def is_torch_greater_or_equal(library_version: str):
if not _is_package_available("torch"):
return False

return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)


def is_torchdistx_available():
return _torchdistx_available

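A possible usage sketch for the new helper (assuming it is exported from transformers.utils once this PR lands), mirroring the guarded import in the Gemma-2 files above:

from transformers.utils import is_torch_greater_or_equal

if is_torch_greater_or_equal("2.5"):
    from torch.nn.attention.flex_attention import flex_attention
else:
    flex_attention = None  # caller should fall back to another attention implementation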