internal assert failure from sync_information.cpp #4052

Open
crcrpar opened this issue Mar 9, 2025 · 3 comments
crcrpar commented Mar 9, 2025

environment

  • Container: pjnl-20250309
  • CUDA devices: H100 80GB HBM3

repro script

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[64], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[16, 64], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T5 = fd.ops.broadcast_in_dim(T0, shape=[16, 64], broadcast_dims=[1])
    T6 = fd.ops.cast(T1, dtype=DataType.Float)
    T7 = fd.ops.cast(T5, dtype=DataType.Float)
    T8 = fd.ops.add(T6, T7)
    T9 = fd.ops.mul(T8, T8)
    T10 = fd.ops.mul(T9, T8)
    S11 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T12 = fd.ops.mul(S11, T8)
    S13 = fd.define_scalar(0.0447150, dtype=DataType.Double)
    T14 = fd.ops.mul(S13, T10)
    T15 = fd.ops.add(T8, T14)
    S16 = fd.define_scalar(0.797885, dtype=DataType.Double)
    T17 = fd.ops.mul(S16, T15)
    T18 = fd.ops.tanh(T17)
    S19 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T20 = fd.ops.add(S19, T18)
    T21 = fd.ops.mul(T12, T20)
    T22 = fd.ops.abs(T21)
    T23 = fd.ops.max(T22, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T24 = fd.ops.cast(T23, dtype=DataType.Double)
    T25 = fd.ops.ne(T24, T24)
    S26 = fd.define_scalar(1.00000e-12, dtype=DataType.Double)
    T27 = fd.ops.gt(T24, S26)
    S28 = fd.define_scalar(1.00000e-12, dtype=DataType.Double)
    T29 = fd.ops.where(T27, T24, S28)
    T30 = fd.ops.cast(T29, dtype=DataType.Double)
    T31 = fd.ops.where(T25, T24, T30)
    T32 = fd.ops.cast(T31, dtype=DataType.Double)
    S33 = fd.define_scalar(448.000, dtype=DataType.Double)
    T34 = fd.ops.reciprocal(T32)
    T35 = fd.ops.mul(S33, T34)
    T36 = fd.ops.cast(T35, dtype=DataType.Float)
    T40 = fd.ops.broadcast_in_dim(T36, shape=[16, 64], broadcast_dims=[])
    T41 = fd.ops.mul(T21, T40)
    T42 = fd.ops.ne(T41, T41)
    S43 = fd.define_scalar(-448.000, dtype=DataType.Double)
    T44 = fd.ops.gt(T41, S43)
    S45 = fd.define_scalar(-448.000, dtype=DataType.Double)
    T46 = fd.ops.where(T44, T41, S45)
    T47 = fd.ops.cast(T46, dtype=DataType.Float)
    T48 = fd.ops.where(T42, T41, T47)
    T49 = fd.ops.cast(T48, dtype=DataType.Float)
    T50 = fd.ops.ne(T49, T49)
    S51 = fd.define_scalar(448.000, dtype=DataType.Double)
    T52 = fd.ops.lt(T49, S51)
    S53 = fd.define_scalar(448.000, dtype=DataType.Double)
    T54 = fd.ops.where(T52, T49, S53)
    T55 = fd.ops.cast(T54, dtype=DataType.Float)
    T56 = fd.ops.where(T50, T49, T55)
    T57 = fd.ops.cast(T56, dtype=DataType.Float)
    fd.add_output(T36)
    fd.add_output(T57)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.testing.make_tensor((64,), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((16, 64), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)
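
For readers trying to follow what the fusion computes: as far as I can tell, it is a tanh-approximated GELU of the biased input, followed by a float8-style amax scale and clamp. The eager-mode sketch below is only my reading of the definition above (the function name, and collapsing the intermediate double casts into fp32 math, are mine, not part of the report):

import torch

def reference(t0: torch.Tensor, t1: torch.Tensor):
    # T8: cast the bf16 inputs to fp32 and add the broadcast bias
    x = t1.float() + t0.float()
    # T21: tanh-approximated GELU, 0.5 * x * (1 + tanh(0.797885 * (x + 0.044715 * x^3)))
    gelu = 0.5 * x * (1.0 + torch.tanh(0.797885 * (x + 0.044715 * x * x * x)))
    # T23..T31: amax of |gelu|, clamped below at 1e-12 unless it is NaN
    amax = gelu.abs().amax()
    amax = torch.where(amax.isnan(), amax, amax.clamp_min(1e-12))
    # T36: scale = 448 / amax, the first fusion output
    scale = 448.0 / amax
    # T41..T57: scaled values clamped to [-448, 448], with NaNs passed through
    scaled = gelu * scale
    clamped = torch.where(scaled.isnan(), scaled, scaled.clamp(-448.0, 448.0))
    return scale, clamped  # corresponds to outputs T36 and T57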

error

I got the following error with NVFUSER_DISABLE=parallel_compile python repro.py:

Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/repro.py", line 67, in <module>
    fd.execute(inputs)
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 317, in execute
    out_tensors, out_shardings = self._execute(
                                 ^^^^^^^^^^^^^^
RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp":812, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found between TV62 (T62_l___bfloat[iS460{1}, iUS462{1}, ithreadIdx.x463{16}_p, iV459{4}]) and TV2(T2_l___bfloat[iS209{1}, iUS211{1}, ithreadIdx.x212{16}_p, iS208{4}] ca_pos( 4 )). Producer is required to be in Global, Shared or Tensor Memory based on parallelization strategy. RAW flags: (threadIdx.x)
Exception raised from SyncMap at /opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp:812 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x103 (0x7fdb2f992cd1 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x62 (0x7fdb2fdde932 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #2: nvfuser::SyncMap::SyncMap(nvfuser::Fusion*) + 0x1909 (0x7fdb2fc638e9 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x483ffc (0x7fdb2fc97ffc in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #4: nvfuser::GpuLower::GpuLower(nvfuser::Fusion*, nvfuser::CompileParams const&) + 0xf59 (0x7fdb2fc9a059 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #5: nvfuser::CompiledKernel::CompiledKernel(nvfuser::Fusion*, nvfuser::CompileParams, c10::Device, nvfuser::SchedulerType, long, long, long, long, std::vector<std::function<void (nvfuser::GpuLower*)>, std::allocator<std::function<void (nvfuser::GpuLower*)> > > const&, std::vector<std::function<void (nvfuser::kir::Kernel*)>, std::allocator<std::function<void (nvfuser::kir::Kernel*)> > > const&) + 0xbc (0x7fdb3013e56c in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #6: nvfuser::KernelExecutor::compile(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::SchedulerType) + 0x29e (0x7fdb3014f9fe in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x944fa0 (0x7fdb30158fa0 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x976eb0 (0x7fdb3018aeb0 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #9: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x63f (0x7fdb3018d61f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #10: nvfuser::FusionExecutorCache::runFusionWithInputs(nvfuser::KernelArgumentHolder, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x139 (0x7fdb301833e9 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #11: nvfuser::python_frontend::FusionDefinition::execute(nvfuser::KernelArgumentHolder, std::optional<signed char>, bool, bool, bool, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >) const + 0xa20 (0x7fdb30360490 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x21f4ed (0x7fdb2fa334ed in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #13: <unknown function> + 0x2eedfd (0x7fdb2fb02dfd in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #14: <unknown function> + 0x20cb53 (0x7fdb2fa20b53 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #15: python() [0x58208f]
<omitting python frames>
frame #19: python() [0x608b42]
frame #20: python() [0x6b4e93]
frame #25: <unknown function> + 0x2a1ca (0x7fecc687a1ca in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #26: __libc_start_main + 0x8b (0x7fecc687a28b in /usr/lib/x86_64-linux-gnu/libc.so.6)

context

I ran into this while running test_torchao_float8_linear of test_tensor_subclass.py (filtered with -k "nvfuser and bfloat16 and true"), with torchao v0.7.0.


naoyam commented Mar 13, 2025

This seems to expose a non-trivial bug in how the normalization schedulers cache inputs. I need to think about it more deeply. In the meantime, if you really need this repro to work now, the following should work around the issue for this repro:

diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp
index 65ec6f9e6..c4f2777ca 100644
--- a/csrc/scheduler/utils.cpp
+++ b/csrc/scheduler/utils.cpp
@@ -703,7 +703,8 @@ PersistentBufferInfo persistentBuffers(Fusion* fusion) {
        persistent_buffer_info.projectable_persistent_buffers.end()});
   for (auto expr : all_exprs) {
     if (expr->isA<UnaryOp>() &&
-        expr->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::Exp) {
+        (expr->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::Exp ||
+         expr->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::Tanh)) {
       persistent_buffer_info.projection_with_exp_op = true;
     }

Note that this is really just an ad-hoc WAR for this particular repro.


crcrpar commented Mar 13, 2025

The model is defined as follows:

    model = nn.Sequential(
        nn.Linear(in_features, out_features, bias=bias),
        nn.GELU(approximate="tanh"),
        nn.Linear(out_features, out_features, bias=bias),
    ).to(device=device, dtype=torch_dtype)

Would you expect this error not to appear with a linear -> exp -> linear model?
I'm asking mainly to check whether I'm reading the patch correctly.
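
For concreteness, such a variant might look like the sketch below. This is only an illustration of the question, not code from the test: the Exp module is a hypothetical stand-in for the GELU, and the sizes, device, and dtype are placeholders for the test's parameters.

import torch
import torch.nn as nn

class Exp(nn.Module):
    # hypothetical activation standing in for nn.GELU(approximate="tanh")
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.exp(x)

# placeholder values; the actual test parametrizes these
in_features, out_features, bias = 64, 64, True
device, torch_dtype = "cuda", torch.bfloat16

model_exp = nn.Sequential(
    nn.Linear(in_features, out_features, bias=bias),
    Exp(),
    nn.Linear(out_features, out_features, bias=bias),
).to(device=device, dtype=torch_dtype)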


naoyam commented Mar 13, 2025

> Would you expect this error not to appear with a linear -> exp -> linear model?

Not sure. It turned out the issue can be rather pervasive. I created a separate issue to clarify what's happening (#4074).
