Skip to content

Commit

Permalink
add sdpa to OPT (#33298)
Browse files Browse the repository at this point in the history
* add sdpa to OPT

* chore: remove redundant whitespace in OPTDecoder class

* fixup

* bug fix

* add sdpa and attention generate test

* fixup

* Refactor OPTAttention forward method for improved readability and maintainability

* undo refactor for _shape and key,val states

* add OPT to doc, fixup didn't find it for some reason

* change order

* change default attn_implemntation in testing to eager

* [run-slow] opt

* change test_eager_matches_sdpa_generate to the one llama

* Update default attention implementation in testing common

* [run-slow] opt

* remove uneeded print

* [run-slow] opt

* refactor model testers to have attn_implementation="eager"

* [run-slow] opt

* convert test_eager_matches_sdpa_generate to opt-350M

* bug fix when creating mask for opt

* [run-slow] opt

* if layer head mask default to eager

* if head mask is not none fall to eager

* [run-slow] opt

* Update src/transformers/models/opt/modeling_opt.py

Co-authored-by: amyeroberts <[email protected]>

* Clean up Unpack imports (#33631)

clean up Unpack imports

* Fix DPT /Dinov2 sdpa regression on main (#33660)

* fallback to eager if output attentions.

* fix copies

* handle dependency errors in check_imports (#33622)

* handle dependency errors in check_imports

* change log level to warning

* add back self.max_position_embeddings = config.max_position_embeddings (#33550)

* add back self.max_position_embeddings = config.max_position_embeddings

* fix-copies

* Fix Llava conversion for LlavaQwen2ForCausalLM with Clip vision tower (#33613)

fix llavaqwen2 model conversion

* Uniformize kwargs for Udop processor and update docs (#33628)

* Add optional kwargs and uniformize udop

* cleanup Unpack

* nit Udop

* Generation: deprecate `PreTrainedModel` inheriting from `GenerationMixin`  (#33203)

* Enable BNB multi-backend support (#31098)

* enable cpu bnb path

* fix style

* fix code style

* fix 4 bit path

* Update src/transformers/utils/import_utils.py

Co-authored-by: Aarni Koskela <[email protected]>

* add multi backend refactor tests

* fix style

* tweak 4bit quantizer + fix corresponding tests

* tweak 8bit quantizer + *try* fixing corresponding tests

* fix dequant bnb 8bit

* account for Intel CPU in variability of expected outputs

* enable cpu and xpu device map

* further tweaks to account for Intel CPU

* fix autocast to work with both cpu + cuda

* fix comments

* fix comments

* switch to testing_utils.torch_device

* allow for xpu in multi-gpu tests

* fix tests 4bit for CPU NF4

* fix bug with is_torch_xpu_available needing to be called as func

* avoid issue where test reports attr err due to other failure

* fix formatting

* fix typo from resolving of merge conflict

* polish based on last PR review

Co-authored-by: Marc Sun <[email protected]>

* fix CI

* Update src/transformers/integrations/integration_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/integrations/integration_utils.py

Co-authored-by: Arthur <[email protected]>

* fix error log

* fix error msg

* add \n in error log

* make quality

* rm bnb cuda restriction in doc

* cpu model don't need dispatch

* fix doc

* fix style

* check cuda avaliable in testing

* fix tests

* Update docs/source/en/model_doc/chameleon.md

Co-authored-by: Marc Sun <[email protected]>

* Update docs/source/en/model_doc/llava_next.md

Co-authored-by: Aarni Koskela <[email protected]>

* Update tests/quantization/bnb/test_4bit.py

Co-authored-by: Aarni Koskela <[email protected]>

* Update tests/quantization/bnb/test_4bit.py

Co-authored-by: Aarni Koskela <[email protected]>

* fix doc

* fix check multibackends

* fix import sort

* remove check torch in bnb

* docs: update bitsandbytes references with multi-backend info

* docs: fix small mistakes in bnb paragraph

* run formatting

* reveret bnb check

* move bnb multi-backend check to import_utils

* Update src/transformers/utils/import_utils.py

Co-authored-by: Aarni Koskela <[email protected]>

* fix bnb check

* minor fix for bnb

* check lib first

* fix code style

* Revert "run formatting"

This reverts commit ac108c6.

* fix format

* give warning when bnb version is low and no cuda found]

* fix device assignment check to be multi-device capable

* address akx feedback on get_avlbl_dev fn

* revert partially, as we don't want the function that public, as docs would be too much (enforced)

---------

Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus von Koeller <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Arthur <[email protected]>

* Fix error string after refactoring into get_chat_template (#33652)

* Fix error string after refactoring into get_chat_template

* Take suggestion from CR

Co-authored-by: Matt <[email protected]>

---------

Co-authored-by: Matt <[email protected]>

* uniformize git processor (#33668)

* uniformize git processor

* update doctring

* Modular `transformers`: modularity and inheritance for new model additions (#33248)

* update exampel

* update

* push the converted diff files for testing and ci

* correct one example

* fix class attributes and docstring

* nits

* oups

* fixed config!

* update

* nitd

* class attributes are not matched against the other, this is missing

* fixed overwriting self.xxx now onto the attributes I think

* partial fix, now order with docstring

* fix docstring order?

* more fixes

* update

* fix missing docstrings!

* examples don't all work yet

* fixup

* nit

* updated

* hick

* update

* delete

* update

* update

* update

* fix

* all default

* no local import

* fix more diff

* some fix related to "safe imports"

* push fixed

* add helper!

* style

* add a check

* all by default

* add the

* update

* FINALLY!

* nit

* fix config dependencies

* man that is it

* fix fix

* update diffs

* fix the last issue

* re-default to all

* alll the fixes

* nice

* fix properties vs setter

* fixup

* updates

* update dependencies

* make sure to install what needs to be installed

* fixup

* quick fix for now

* fix!

* fixup

* update

* update

* updates

* whitespaces

* nit

* fix

* simplify everything, and make it file agnostic (should work for image processors)

* style

* finish fixing all import issues

* fixup

* empty modeling should not be written!

* Add logic to find who depends on what

* update

* cleanup

* update

* update gemma to support positions

* some small nits

* this is the correct docstring for gemma2

* fix merging of docstrings

* update

* fixup

* update

* take doc into account

* styling

* update

* fix hidden activation

* more fixes

* final fixes!

* fixup

* fixup instruct  blip video

* update

* fix bugs

* align gemma2 with the rest as well

* updats

* revert

* update

* more reversiom

* grind

* more

* arf

* update

* order will matter

* finish del stuff

* update

* rename to modular

* fixup

* nits

* update makefile

* fixup

* update order of the checks!

* fix

* fix docstring that has a call inside

* fiix conversion check

* style

* add some initial documentation

* update

* update doc

* some fixup

* updates

* yups

* Mostly todo gimme a minut

* update

* fixup

* revert some stuff

* Review docs for the modular transformers (#33472)

Docs

* good update

* fixup

* mmm current updates lead to this code

* okay, this fixes it

* cool

* fixes

* update

* nit

* updates

* nits

* fix doc

* update

* revert bad changes

* update

* updates

* proper update

* update

* update?

* up

* update

* cool

* nits

* nits

* bon bon

* fix

* ?

* minimise changes

* update

* update

* update

* updates?

* fixed gemma2

* kind of a hack

* nits

* update

* remove `diffs` in favor of `modular`

* fix make fix copies

---------

Co-authored-by: Lysandre Debut <[email protected]>

* Fix CIs post merging modular transformers (#33681)

update

* Fixed docstring for cohere model regarding unavailability of prune_he… (#33253)

* Fixed docstring for cohere model regarding unavailability of prune_head() methods

The docstring mentions that cohere model supports prune_heads() methods. I have fixed the docstring by explicitly mentioning that it doesn't support that functionality.

* Update src/transformers/models/cohere/modeling_cohere.py

---------

Co-authored-by: Lysandre Debut <[email protected]>

* Generation tests: update imagegpt input name, remove unused functions (#33663)

* Improve Error Messaging for Flash Attention 2 on CPU (#33655)

Update flash-attn error message on CPU

Rebased to latest branch

* Gemma2: fix config initialization (`cache_implementation`) (#33684)

* Fix ByteLevel alphabet missing when Sequence pretokenizer is used (#33556)

* Fix ByteLevel alphabet missing when Sequence pretokenizer is used

* Fixed formatting with `ruff`.

* Uniformize kwargs for image-text-to-text processors (#32544)

* uniformize FUYU processor kwargs

* Uniformize instructblip processor kwargs

* Fix processor kwargs and tests Fuyu, InstructBlip, Kosmos2

* Uniformize llava_next processor

* Fix save_load test for processor with chat_template only as extra init args

* Fix import Unpack

* Fix Fuyu Processor import

* Fix FuyuProcessor import

* Fix FuyuProcessor

* Add defaults for specific kwargs kosmos2

* Fix Udop to return BatchFeature instead of BatchEncoding and uniformize kwargs

* Add tests processor Udop

* remove Copied from in processing Udop as change of input orders caused by BatchEncoding -> BatchFeature

* Fix overwrite tests kwargs processors

* Add warnings and BC for changes in processor inputs order, change docs, add BC for text_pair as arg for Udop

* Fix processing test fuyu

* remove unnecessary pad_token check in instructblip ProcessorTest

* Fix BC tests and cleanup

* FIx imports fuyu

* Uniformize Pix2Struct

* Fix wrong name for FuyuProcessorKwargs

* Fix slow tests reversed inputs align fuyu llava-next, change udop warning

* Fix wrong logging import udop

* Add check images text input order

* Fix copies

* change text pair handling when positional arg

* rebase on main, fix imports in test_processing_common

* remove optional args and udop uniformization from this PR

* fix failing tests

* remove unnecessary test, fix processing utils and test processing common

* cleanup Unpack

* cleanup

* fix conflict grounding dino

* 🚨🚨 Setting default behavior of assisted decoding (#33657)

* tests: fix pytorch tensor placement errors (#33485)

This commit fixes the following errors:
* Fix "expected all tensors to be on the same device" error
* Fix "can't convert device type tensor to numpy"

According to pytorch documentation torch.Tensor.numpy(force=False)
performs conversion only if tensor is on CPU (plus few other restrictions)
which is not the case. For our case we need force=True since we just
need a data and don't care about tensors coherency.

Fixes: #33517
See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html

Signed-off-by: Dmitry Rogozhkin <[email protected]>

* bump tokenizers, fix added tokens fast (#32535)

* update based on tokenizers release

* update

* nits

* update

* revert re addition

* don't break that yet

* fmt

* revert unwanted

* update tokenizers version

* update dep table

* update

* update in conversion script as well

* some fix

* revert

* fully revert

* fix training

* remove set trace

* fixup

* update

* update

* [Pixtral] Improve docs, rename model (#33491)

* Improve docs, rename model

* Fix style

* Update repo id

* fix code quality after merge

* HFQuantizer implementation for compressed-tensors library (#31704)

* Add compressed-tensors HFQuantizer implementation

* flag serializable as False

* run

* revive lines deleted by ruff

* fixes to load+save from sparseml, edit config to quantization_config, and load back

* address satrat comment

* compressed_tensors to compressed-tensors and revert back is_serializable

* rename quant_method from sparseml to compressed-tensors

* tests

* edit tests

* clean up tests

* make style

* cleanup

* cleanup

* add test skip for when compressed tensors is not installed

* remove pydantic import + style

* delay torch import in test

* initial docs

* update main init for compressed tensors config

* make fix-copies

* docstring

* remove fill_docstring

* Apply suggestions from code review

Co-authored-by: Marc Sun <[email protected]>

* review comments

* review comments

* comments - suppress warnings on state dict load, tests, fixes

* bug-fix - remove unnecessary call to apply quant lifecycle

* run_compressed compatability

* revert changes not needed for compression

* no longer need unexpected keys fn

* unexpected keys not needed either

* Apply suggestions from code review

Co-authored-by: Marc Sun <[email protected]>

* add to_diff_dict

* update docs and expand testing

* Update _toctree.yml with compressed-tensors

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Arthur <[email protected]>

* update doc

* add note about saving a loaded model

---------

Co-authored-by: George Ohashi <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Sara Adkins <[email protected]>
Co-authored-by: Sara Adkins <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Dipika <[email protected]>

* update model card for opt

* add batch size to inference table

* [slow-run] opt

* [run-slow] opt

---------

Signed-off-by: Dmitry Rogozhkin <[email protected]>
Co-authored-by: Avishai Elmakies <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Pablo Montalvo <[email protected]>
Co-authored-by: chengchengpei <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Yoni Gozlan <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: jiqing-feng <[email protected]>
Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus von Koeller <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Tibor Reiss <[email protected]>
Co-authored-by: Matt <[email protected]>
Co-authored-by: Lysandre Debut <[email protected]>
Co-authored-by: Muhammad Naufil <[email protected]>
Co-authored-by: sizhky <[email protected]>
Co-authored-by: Umar Butler <[email protected]>
Co-authored-by: Jonathan Mamou <[email protected]>
Co-authored-by: Dmitry Rogozhkin <[email protected]>
Co-authored-by: NielsRogge <[email protected]>
Co-authored-by: Arthur Zucker <[email protected]>
Co-authored-by: Benjamin Fineran <[email protected]>
Co-authored-by: George Ohashi <[email protected]>
Co-authored-by: Sara Adkins <[email protected]>
Co-authored-by: Sara Adkins <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Dipika <[email protected]>
  • Loading branch information
1 parent 69b5ccb commit a265600
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 26 deletions.
67 changes: 67 additions & 0 deletions docs/source/en/model_doc/opt.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,73 @@ Below is an expected speedup diagram that compares pure inference time between t
</div>


### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```python
from transformers import OPTForCausalLM
model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (L40S-45GB, PyTorch 2.4.0, OS Debian GNU/Linux 11) using `float16` with
[facebook/opt-350m](https://huggingface.co/facebook/opt-350m), we saw the
following speedups during training and inference.

### Training

| batch_size | seq_len | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|--------------:|-----------:|:------------------------------|-----------------------------:|:---------------|:-----------------------|----------------------:|:------------------|
| 1 | 128 | 0.047 | 0.037 | 26.360 | 1474.611 | 1474.32 | 0.019 |
| 1 | 256 | 0.046 | 0.037 | 24.335 | 1498.541 | 1499.49 | -0.063 |
| 1 | 512 | 0.046 | 0.037 | 24.959 | 1973.544 | 1551.35 | 27.215 |
| 1 | 1024 | 0.062 | 0.038 | 65.135 | 4867.113 | 1698.35 | 186.578 |
| 1 | 2048 | 0.230 | 0.039 | 483.933 | 15662.224 | 2715.75 | 476.718 |
| 2 | 128 | 0.045 | 0.037 | 20.455 | 1498.164 | 1499.49 | -0.089 |
| 2 | 256 | 0.046 | 0.037 | 24.027 | 1569.367 | 1551.35 | 1.161 |
| 2 | 512 | 0.045 | 0.037 | 20.965 | 3257.074 | 1698.35 | 91.778 |
| 2 | 1024 | 0.122 | 0.038 | 225.958 | 9054.405 | 2715.75 | 233.403 |
| 2 | 2048 | 0.464 | 0.067 | 593.646 | 30572.058 | 4750.55 | 543.548 |
| 4 | 128 | 0.045 | 0.037 | 21.918 | 1549.448 | 1551.35 | -0.123 |
| 4 | 256 | 0.044 | 0.038 | 18.084 | 2451.768 | 1698.35 | 44.361 |
| 4 | 512 | 0.069 | 0.037 | 84.421 | 5833.180 | 2715.75 | 114.791 |
| 4 | 1024 | 0.262 | 0.062 | 319.475 | 17427.842 | 4750.55 | 266.860 |
| 4 | 2048 | OOM | 0.062 | Eager OOM | OOM | 4750.55 | Eager OOM |
| 8 | 128 | 0.044 | 0.037 | 18.436 | 2049.115 | 1697.78 | 20.694 |
| 8 | 256 | 0.048 | 0.036 | 32.887 | 4222.567 | 2715.75 | 55.484 |
| 8 | 512 | 0.153 | 0.06 | 154.862 | 10985.391 | 4750.55 | 131.245 |
| 8 | 1024 | 0.526 | 0.122 | 330.697 | 34175.763 | 8821.18 | 287.428 |
| 8 | 2048 | OOM | 0.122 | Eager OOM | OOM | 8821.18 | Eager OOM |

### Inference

| batch_size | seq_len | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|--------------:|-----------:|--------------------------------:|-------------------------------:|---------------:|------------------:|---------------:|-----------------:|
| 1 | 128 | 11.634 | 8.647 | 34.546 | 717.676 | 717.674 | 0 |
| 1 | 256 | 11.593 | 8.86 | 30.851 | 742.852 | 742.845 | 0.001 |
| 1 | 512 | 11.515 | 8.816 | 30.614 | 798.232 | 799.593 | -0.17 |
| 1 | 1024 | 11.556 | 8.915 | 29.628 | 917.265 | 895.538 | 2.426 |
| 2 | 128 | 12.724 | 11.002 | 15.659 | 762.434 | 762.431 | 0 |
| 2 | 256 | 12.704 | 11.063 | 14.83 | 816.809 | 816.733 | 0.009 |
| 2 | 512 | 12.757 | 10.947 | 16.535 | 917.383 | 918.339 | -0.104 |
| 2 | 1024 | 13.018 | 11.018 | 18.147 | 1162.65 | 1114.81 | 4.291 |
| 4 | 128 | 12.739 | 10.959 | 16.243 | 856.335 | 856.483 | -0.017 |
| 4 | 256 | 12.718 | 10.837 | 17.355 | 957.298 | 957.674 | -0.039 |
| 4 | 512 | 12.813 | 10.822 | 18.393 | 1158.44 | 1158.45 | -0.001 |
| 4 | 1024 | 13.416 | 11.06 | 21.301 | 1653.42 | 1557.19 | 6.18 |
| 8 | 128 | 12.763 | 10.891 | 17.193 | 1036.13 | 1036.51 | -0.036 |
| 8 | 256 | 12.89 | 11.104 | 16.085 | 1236.98 | 1236.87 | 0.01 |
| 8 | 512 | 13.327 | 10.939 | 21.836 | 1642.29 | 1641.78 | 0.031 |
| 8 | 1024 | 15.181 | 11.175 | 35.848 | 2634.98 | 2443.35 | 7.843 |

## OPTConfig

Expand Down
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel)
* [OPT](https://huggingface.co/docs/transformers/en/model_doc/opt)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
Expand Down
177 changes: 152 additions & 25 deletions src/transformers/models/opt/modeling_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
Expand Down Expand Up @@ -121,7 +124,7 @@ def __init__(
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
Expand Down Expand Up @@ -368,9 +371,108 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value


class OPTSdpaAttention(OPTAttention):
"""
OPT sdpa attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
The only required change would be on the forward pass where it needs to correctly call the public API of sdpa
attention and deal with padding tokens in case the input contains any of them.
"""

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions or layer_head_mask is not None:
logger.warning_once(
"OPTModel is using SDPA attention, which currently does not support output_attentions=True."
'failing back to eager attention. remove warning using attn_implementation="eager".'
)

return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
key_value_states=key_value_states,
) # TODO after merge add position_ids=position_ids
is_cross_attention = key_value_states is not None

bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states) * self.scaling
query_states = self._shape(query_states, -1, bsz)

# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

# shape now is (bsz, num_heads, seq_len, head_dim), all are continuous

causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
# this model uses the scaling factor in the query projection for some reason, but not in Q@K^T
# so we need to scale to remove scaling in SDPA to have similar results with eager.
# Maybe needs a change in the model to remove scaling in query projection
scale=1.0,
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)

return attn_output, None, past_key_value


OPT_ATTENTION_CLASSES = {
"eager": OPTAttention,
"flash_attention_2": OptFlashAttention2,
"sdpa": OPTSdpaAttention,
}


Expand Down Expand Up @@ -499,6 +601,7 @@ class OPTPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["OPTDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True

def _init_weights(self, module):
std = self.config.init_std
Expand Down Expand Up @@ -620,6 +723,7 @@ def __init__(self, config: OPTConfig):

self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"

self.gradient_checkpointing = False
# Initialize weights and apply final processing
Expand All @@ -631,6 +735,49 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embed_tokens = value

def _update_causal_mask(
self,
inputs_embeds: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
):
"""
Updates the causal mask for the decoder.
"""
batch_size, seq_length = input_shape
mask_seq_length = past_key_values_length + seq_length
if self._use_flash_attention_2:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if attention_mask is None
else attention_mask
)

return causal_attention_mask, attention_mask

if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
if self._use_sdpa and not output_attentions and head_mask is None:
causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
else:
causal_attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)

return causal_attention_mask, attention_mask

def forward(
self,
input_ids: torch.LongTensor = None,
Expand Down Expand Up @@ -718,32 +865,12 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

batch_size, seq_length = input_shape
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length

causal_attention_mask, attention_mask = self._update_causal_mask(
inputs_embeds, input_shape, past_key_values_length, attention_mask, head_mask, output_attentions
)
# embed positions
if self._use_flash_attention_2:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if attention_mask is None
else attention_mask
)
else:
# 4d mask is passed through the layers
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
causal_attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)

if position_ids is None:
position_ids = torch.cumsum(attention_mask, dim=1)
Expand Down
Loading

0 comments on commit a265600

Please sign in to comment.