From 86579eb88df9d5fab0a9d71034759170bfc65bb2 Mon Sep 17 00:00:00 2001 From: clee2000 <44682903+clee2000@users.noreply.github.com> Date: Mon, 22 Apr 2024 11:40:21 -0700 Subject: [PATCH] Add deprecation warning for `torch.backends.cuda.sdp_kernel` (#43) * only deprecation warning * typo * move to correct section in readme --- README.md | 7 +++++++ .../deprecated_symbols/checker/sdp_kernel.py | 12 ++++++++++++ .../deprecated_symbols/checker/sdp_kernel.txt | 4 ++++ torchfix/deprecated_symbols.yaml | 5 +++++ 4 files changed, 28 insertions(+) create mode 100644 tests/fixtures/deprecated_symbols/checker/sdp_kernel.py create mode 100644 tests/fixtures/deprecated_symbols/checker/sdp_kernel.txt diff --git a/README.md b/README.md index 1b825e1..9fbb7bd 100644 --- a/README.md +++ b/README.md @@ -103,5 +103,12 @@ Migration guide: `torch.nn.utils.parametrize.cached` before invoking the module in question. +#### torch.backends.cuda.sdp_kernel + +This function is deprecated. Use the `torch.nn.attention.sdpa_kernel` context manager instead. + +Migration guide: +Each boolean input parameter (defaulting to true unless specified) of `sdp_kernel` corresponds to a `SDPBackened`. If the input parameter is true, the corresponding backend should be added to the input list of `sdpa_kernel`. + ## License TorchFix is BSD License licensed, as found in the LICENSE file. diff --git a/tests/fixtures/deprecated_symbols/checker/sdp_kernel.py b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.py new file mode 100644 index 0000000..06d14a8 --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.py @@ -0,0 +1,12 @@ +import torch +from torch.backends import cuda +from torch.backends.cuda import sdp_kernel + +with torch.backends.cuda.sdp_kernel() as context: + pass + +with cuda.sdp_kernel() as context: + pass + +with sdp_kernel() as context: + pass diff --git a/tests/fixtures/deprecated_symbols/checker/sdp_kernel.txt b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.txt new file mode 100644 index 0000000..d18f1ee --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.txt @@ -0,0 +1,4 @@ +3:1 TOR103 Import of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel +5:6 TOR101 Use of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel +8:6 TOR101 Use of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel +11:6 TOR101 Use of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel diff --git a/torchfix/deprecated_symbols.yaml b/torchfix/deprecated_symbols.yaml index b2bf4c5..9cce56d 100644 --- a/torchfix/deprecated_symbols.yaml +++ b/torchfix/deprecated_symbols.yaml @@ -65,6 +65,11 @@ remove_pr: reference: https://github.com/pytorch-labs/torchfix#torchnnutilsweight_norm +- name: torch.backends.cuda.sdp_kernel + deprecate_pr: https://github.com/pytorch/pytorch/pull/114689 + remove_pr: + reference: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel + # functorch - name: functorch.vmap deprecate_pr: TBA