Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RuntimeError from normalization_utils.cpp, could not resolve persistent buffer #4020

Open
riccardofelluga opened this issue Mar 6, 2025 · 0 comments

Comments

@riccardofelluga
Copy link
Contributor

RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/scheduler/normalization_utils.cpp":1544, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Could not resolve persistent buffer: T20_g_float[bS85{1}, iS90{32}rf, iS91{10}rf, iS87{64}, iS88{64}]
Exception raised from getResolutionPointsOf at /opt/pytorch/nvfuser/csrc/scheduler/normalization_utils.cpp:1544 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x103 (0x7f1894f3d769 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x62 (0x7f189538b232 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x9f0213 (0x7f18957af213 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x9f02a1 (0x7f18957af2a1 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0xa7d46d (0x7f189583c46d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0xa08aef (0x7f18957c7aef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0xa14a84 (0x7f18957d3a84 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x63d676 (0x7f18953fc676 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x63da9e (0x7f18953fca9e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x63dbf5 (0x7f18953fcbf5 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x64d8e9 (0x7f189540c8e9 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #11: nvfuser::SegmentCandidateFinder::SegmentCandidateFinder(std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >, nvfuser::KernelArgumentHolder const&, nvfuser::SegmentCandidateFinderOptions, bool) + 0x3a3 (0x7f189540ce03 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x64dfc5 (0x7f189540cfc5 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #13: <unknown function> + 0x64e3ff (0x7f189540d3ff in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #14: <unknown function> + 0x977082 (0x7f1895736082 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #15: <unknown function> + 0x96db08 (0x7f189572cb08 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #16: nvfuser::FusionExecutorCache::runFusionWithInputs(nvfuser::KernelArgumentHolder, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0xbd (0x7f189572d55d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #17: nvfuser::python_frontend::FusionDefinition::execute(nvfuser::KernelArgumentHolder, std::optional<signed char>, bool, bool, bool, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >) const + 0xe4f (0x7f189590a69f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)

Repro:

# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
# torch version: 2.7.0.dev20250303+cu128
# cuda version: 12.8
# nvfuser version: 0.2.26+git4137de7
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    """Build the fusion that reproduces the nvFuser internal assert.

    Running this fusion triggers:
        "Could not resolve persistent buffer" in
        csrc/scheduler/normalization_utils.cpp (getResolutionPointsOf).

    The op sequence was auto-generated (captured from a real workload), so
    tensor numbering is sparse and the exact order of operations matters for
    reproducing the scheduler failure — do not reorder or simplify.

    NOTE(review): the [1, 320, 64, 64] -> [1, 32, 10, 64, 64] reshapes with
    reductions over dims [0, 2, 3, 4] look like a group-norm-style backward
    (32 groups of 10 channels) — inferred from shapes only, not confirmed.

    Args:
        fd: the FusionDefinition being recorded into (used as a context
            manager by the caller).
    """
    # --- Fusion inputs -----------------------------------------------------
    T0 = fd.define_tensor(shape=[1, 320, 64, 64], contiguity=[None, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T1 = fd.define_tensor(shape=[320], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T2 = fd.define_tensor(shape=[1, 32], contiguity=[None, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T3 = fd.define_tensor(shape=[1, 320, 64, 64], contiguity=[None, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T4 = fd.define_tensor(shape=[1, 32, 1, 1, 1], contiguity=[None, True, None, None, None], dtype=DataType.Float, is_cpu=False, stride_order=[4, 3, 2, 1, 0])
    T5 = fd.define_tensor(shape=[1, 320, 64, 64], contiguity=[None, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T6 = fd.define_tensor(shape=[320, 320, 1, 1], contiguity=[True, True, None, None], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T7 = fd.define_tensor(shape=[320], contiguity=[True], dtype=DataType.Half, is_cpu=False, stride_order=[0])
    # Zero-width pad (a no-op shape-wise, but it still inserts a PadOp node
    # into the fusion graph, which may matter for segmentation).
    S8 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T18 = fd.ops.pad(T0, [0, 0, 0, 0, 0, 0, 0, 0], S8)
    T24 = fd.ops.reshape(T1, new_shape=[1, 320, 1, 1])
    T25 = fd.ops.cast(T18, dtype=DataType.Float)
    T31 = fd.ops.broadcast_in_dim(T24, shape=[1, 320, 64, 64], broadcast_dims=[0, 1, 2, 3])
    T38 = fd.ops.broadcast_in_dim(T2, shape=[1, 32, 1, 1, 1], broadcast_dims=[0, 1])
    # Split 320 channels into 32 groups x 10 channels.
    T45 = fd.ops.reshape(T3, new_shape=[1, 32, 10, 64, 64])
    T46 = fd.ops.mul(T31, T25)
    T53 = fd.ops.broadcast_in_dim(T38, shape=[1, 32, 10, 64, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T54 = fd.ops.cast(T45, dtype=DataType.Float)
    T61 = fd.ops.reshape(T46, new_shape=[1, 32, 10, 64, 64])
    T62 = fd.ops.sub(T54, T53)
    T63 = fd.ops.mul(T62, T61)
    # Per-group reductions (keep only the group dim of size 32).
    T64 = fd.ops.sum(T63, dims=[0, 2, 3, 4], keepdim=False, dtype=DataType.Null)
    T71 = fd.ops.broadcast_in_dim(T64, shape=[1, 32, 1, 1, 1], broadcast_dims=[1])
    T78 = fd.ops.broadcast_in_dim(T4, shape=[1, 32, 10, 64, 64], broadcast_dims=[0, 1, 2, 3, 4])
    S79 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T80 = fd.ops.pow(T4, S79)
    S81 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T82 = fd.ops.mul(S81, T71)
    T83 = fd.ops.mul(T78, T61)
    T84 = fd.ops.mul(T82, T80)
    T85 = fd.ops.neg(T83)
    T86 = fd.ops.sum(T84, dims=[0, 2, 3, 4], keepdim=False, dtype=DataType.Null)
    T87 = fd.ops.sum(T85, dims=[0, 2, 3, 4], keepdim=False, dtype=DataType.Null)
    T91 = fd.ops.broadcast_in_dim(T86, shape=[1, 32], broadcast_dims=[1])
    T98 = fd.ops.broadcast_in_dim(T87, shape=[1, 32, 1, 1, 1], broadcast_dims=[1])
    T105 = fd.ops.broadcast_in_dim(T2, shape=[1, 32, 1, 1, 1], broadcast_dims=[0, 1])
    T112 = fd.ops.broadcast_in_dim(T91, shape=[1, 32, 1, 1, 1], broadcast_dims=[0, 1])
    T113 = fd.ops.sum(T98, dims=[0, 2, 3, 4], keepdim=False, dtype=DataType.Null)
    T120 = fd.ops.broadcast_in_dim(T105, shape=[1, 32, 10, 64, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T127 = fd.ops.broadcast_in_dim(T112, shape=[1, 32, 10, 64, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T131 = fd.ops.broadcast_in_dim(T113, shape=[1, 32], broadcast_dims=[1])
    T132 = fd.ops.sub(T54, T120)
    S133 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T134 = fd.ops.mul(S133, T127)
    T141 = fd.ops.broadcast_in_dim(T131, shape=[1, 32, 1, 1, 1], broadcast_dims=[0, 1])
    T142 = fd.ops.mul(T134, T132)
    T149 = fd.ops.broadcast_in_dim(T141, shape=[1, 32, 10, 64, 64], broadcast_dims=[0, 1, 2, 3, 4])
    # 40960 = 10 * 64 * 64 (elements per group); 2.44141e-05 is its
    # reciprocal — consistent with averaging over a group, but unverified.
    S150 = fd.define_scalar(40960.0, dtype=DataType.Double)
    S151 = fd.ops.reciprocal(S150)
    T152 = fd.ops.mul(T142, S151)
    S153 = fd.define_scalar(2.44141e-05, dtype=DataType.Double)
    T154 = fd.ops.mul(S153, T149)
    T155 = fd.ops.add(T154, T152)
    T156 = fd.ops.mul(T62, T78)
    T157 = fd.ops.add(T83, T155)
    # Merge group/channel dims back to [1, 320, 64, 64].
    T163 = fd.ops.reshape(T156, new_shape=[1, 320, 64, 64])
    T164 = fd.ops.cast(T157, dtype=DataType.Half)
    T165 = fd.ops.mul(T163, T25)
    T171 = fd.ops.reshape(T164, new_shape=[1, 320, 64, 64])
    T172 = fd.ops.sum(T25, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    T173 = fd.ops.sum(T165, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    T174 = fd.ops.cast(T171, dtype=DataType.Float)
    T175 = fd.ops.cast(T5, dtype=DataType.Float)
    T176 = fd.ops.permute(T6, dims=[1, 0, 2, 3])
    T182 = fd.ops.broadcast_in_dim(T172, shape=[1, 320, 1, 1], broadcast_dims=[1])
    T188 = fd.ops.broadcast_in_dim(T173, shape=[1, 320, 1, 1], broadcast_dims=[1])
    T189 = fd.ops.add(T175, T174)
    T190 = fd.ops.cast(T7, dtype=DataType.Float)
    T191 = fd.ops.cast(T176, dtype=DataType.Float)
    T194 = fd.ops.reshape(T182, new_shape=[320])
    T197 = fd.ops.reshape(T188, new_shape=[320])
    T198 = fd.ops.cast(T189, dtype=DataType.Half)
    # --- Fusion outputs ----------------------------------------------------
    fd.add_output(T190)
    fd.add_output(T191)
    fd.add_output(T194)
    fd.add_output(T197)
    fd.add_output(T198)

# Record the fusion, then execute it on random CUDA inputs matching the
# shapes/dtypes declared by the fd.define_tensor calls above. Execution is
# what triggers the scheduler assert.
with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

# (shape, dtype) specs for the eight fusion inputs, in definition order.
_input_specs = [
    ((1, 320, 64, 64), torch.float16),
    ((320,), torch.float32),
    ((1, 32), torch.float32),
    ((1, 320, 64, 64), torch.float16),
    ((1, 32, 1, 1, 1), torch.float32),
    ((1, 320, 64, 64), torch.float16),
    ((320, 320, 1, 1), torch.float16),
    ((320,), torch.float16),
]
inputs = [
    torch.testing.make_tensor(shape, dtype=dtype, device='cuda:0')
    for shape, dtype in _input_specs
]
fd.execute(inputs)

Looks similar to #1123 but it comes from a different file :(

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant