
layer_norm backward problem #40

Open
FindHao opened this issue Nov 6, 2024 · 4 comments

Comments

@FindHao
Member

FindHao commented Nov 6, 2024

The bwd and fwd_bwd tests for layer_norm failed.

The error is: RuntimeError: This backward function was compiled with non-empty donated buffers which requires create_graph=False and retain_graph=False. Please keep backward(create_graph=False, retain_graph=False) across all backward() function calls, or set torch._functorch.config.donated_buffer=False to disable donated buffer.

Test Plan:

% python run.py --op layer_norm --precision fp32 --metrics latency,accuracy,speedup,gpu_peak_mem,mem_footprint --mode fwd_bwd

  3%|█▏        | 1/30 [00:01<00:55,  1.91s/it]
Caught exception, terminating early with partial results
Traceback (most recent call last):
  File "/home/yhao/ptd/tritonbench/tritonbench/utils/triton_op.py", line 716, in run
    y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
                                                  ^^^^^^^^^^^^^^^^^
  File "/home/yhao/ptd/tritonbench/tritonbench/utils/triton_op.py", line 704, in _reduce_benchmarks
    acc[bm_name] = self._do_bench(
                   ^^^^^^^^^^^^^^^
  File "/home/yhao/ptd/tritonbench/tritonbench/utils/triton_op.py", line 923, in _do_bench
    metrics.latency = triton.testing.do_bench(
                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/triton/testing.py", line 106, in do_bench
    fn()
  File "/home/yhao/ptd/tritonbench/tritonbench/utils/triton_op.py", line 627, in <lambda>
    fwd_bwd_fn = lambda: (fwd_fn(), bwd_fn())
                                    ^^^^^^^^
  File "/home/yhao/ptd/tritonbench/tritonbench/operators/layer_norm/operator.py", line 50, in <lambda>
    return lambda: y.backward(dy, retain_graph=True)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/_tensor.py", line 624, in backward
    torch.autograd.backward(
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1705, in backward
    return impl_fn()
           ^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1695, in impl_fn
    out = CompiledFunction._backward_impl(ctx, all_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2022, in _backward_impl
    torch._check(
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/__init__.py", line 1615, in _check
    _check_with(RuntimeError, cond, message)
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/__init__.py", line 1597, in _check_with
    raise error_type(message_evaluated)
RuntimeError: This backward function was compiled with non-empty donated buffers which requires create_graph=False and retain_graph=False. Please keep backward(create_graph=False, retain_graph=False) across all backward() function calls, or set torch._functorch.config.donated_buffer=False to disable donated buffer.
@FindHao
Member Author

FindHao commented Nov 6, 2024

Not sure if setting torch._functorch.config.donated_buffer = False is the correct way to solve it.

@xuzhao9
Contributor

xuzhao9 commented Nov 14, 2024

This error only happens when the number of inputs (--num_input) is greater than 1.

@xuzhao9
Contributor

xuzhao9 commented Dec 2, 2024

Can we get a minimal reproduction of this error and report it to the PyTorch team?

I can reproduce with the following script:

```python
import torch
import torch.nn.functional as F

M = 4096

def pt2_layernorm(*args):
    @torch.compile
    def inner(*args):
        return F.layer_norm(*args)
    return lambda: inner(*args)

for N in [512, 1024]:
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=torch.float32, device="cuda")
    eps = 1e-5
    x.requires_grad_()
    weight = torch.rand(w_shape, dtype=torch.float32, device="cuda", requires_grad=True)
    bias = torch.rand(w_shape, dtype=torch.float32, device="cuda", requires_grad=True)
    yf = pt2_layernorm(x, w_shape, weight, bias, eps)
    y = yf()
    dy = 0.1 * torch.randn_like(y)
    # Second backward on the same compiled graph triggers the RuntimeError.
    for _i in range(2):
        y.backward(dy, retain_graph=True)
```

@FindHao
Member Author

FindHao commented Dec 6, 2024

I found that this issue is more extensive than I initially thought. Other operators (fused_linear_cross_entropy, geglu, and swiglu) are also affected. Considering that disabling donated buffers doesn't accurately represent real-world use cases, I think we need a better approach to benchmarking backward. @xuzhao9
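One possible direction, sketched below with eager-mode layer_norm on CPU for simplicity (this is not tritonbench's actual implementation): re-run the forward pass inside the benchmarked closure so that every backward() consumes a fresh autograd graph and retain_graph=True is never needed. The obvious drawback is that the measurement then includes forward time, which would have to be subtracted out.

```python
import torch
import torch.nn.functional as F

def make_bwd_fn(x, normalized_shape, weight, bias, dy):
    # Re-run forward inside the timed closure so each backward() call
    # consumes a fresh autograd graph; no retain_graph=True is needed,
    # so compiled backwards with donated buffers would also be safe.
    def bwd():
        x.grad = weight.grad = bias.grad = None  # avoid grad accumulation
        y = F.layer_norm(x, normalized_shape, weight, bias)
        y.backward(dy)
    return bwd

x = torch.randn(8, 16, requires_grad=True)
weight = torch.randn(16, requires_grad=True)
bias = torch.randn(16, requires_grad=True)
dy = 0.1 * torch.randn(8, 16)
bwd_fn = make_bwd_fn(x, (16,), weight, bias, dy)
for _ in range(3):  # repeated calls succeed without retain_graph=True
    bwd_fn()
```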

@FindHao FindHao reopened this Dec 6, 2024
facebook-github-bot pushed a commit that referenced this issue Dec 9, 2024
Summary:
It is still a temporary fix for backward benchmarking. Related discussion #40

Pull Request resolved: #104

Reviewed By: xuzhao9

Differential Revision: D66911331

Pulled By: FindHao

fbshipit-source-id: 6b3e5188fb6c929d6fe34aaf3a141bafa92c33f3