[Bug] [Relax] Build fails when applying dlight.gpu.GeneralReduction to R.nn.group_norm with dynamic shapes and R.reshape #17531

Yumin-gd commented Nov 14, 2024

Actual behavior

Building the TVMScript module below with dlight.gpu.GeneralReduction() fails with the following error:
InternalError: Check failed: (!divisor.is_const(0)) is false: Find divide by zero

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def reshape_norm(
        inp_0: R.Tensor((1, 512, "w", "h"), dtype="float16"),
        inp_1: R.Tensor((512,), dtype="float16"),
        inp_2: R.Tensor((512,), dtype="float16"),
    ) -> R.Tensor((1, 512, "w * h"), dtype="float16"):
        w = T.int64()
        h = T.int64()
        with R.dataflow():
            lv = R.reshape(inp_0, R.shape([1, 512, w * h]))
            lv1 = R.nn.group_norm(
                data=lv, gamma=inp_1, beta=inp_2,
                num_groups=32, channel_axis=1, axes=[2],
                epsilon=9.9999999999999995e-07, center=True, scale=True,
            )
            R.output(lv1)
        return lv1
  • However, if I change the shapes of the input tensor inp_0 and the output tensor to (1, 512, "n") and remove the R.reshape operation, the build completes successfully (see the sketch after this list).
  • The failure is specific to dl.gpu.GeneralReduction(): with it removed from the schedule list, the build completes with the other dlight schedules.
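
For reference, a minimal sketch of the passing variant described above (the names ModuleOK and norm_only are hypothetical; everything else is copied from the failing module):

from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class ModuleOK:
    @R.function
    def norm_only(
        inp_0: R.Tensor((1, 512, "n"), dtype="float16"),
        inp_1: R.Tensor((512,), dtype="float16"),
        inp_2: R.Tensor((512,), dtype="float16"),
    ) -> R.Tensor((1, 512, "n"), dtype="float16"):
        # Same group_norm call as in the failing module, but fed directly
        # with a (1, 512, n) input instead of a reshaped (1, 512, w * h) one.
        with R.dataflow():
            lv1 = R.nn.group_norm(
                data=inp_0, gamma=inp_1, beta=inp_2,
                num_groups=32, channel_axis=1, axes=[2],
                epsilon=9.9999999999999995e-07, center=True, scale=True,
            )
            R.output(lv1)
        return lv1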

Environment

  • TVM Version: v0.18.0
  • Commit Hash: 30b7b1c

Steps to reproduce

import tvm
from tvm import relax
import tvm.dlight as dl

@tvm.transform.module_pass(opt_level=0)
def dynshape_build_pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
    seq = tvm.transform.Sequential(
        [
            relax.backend.DispatchSampling(),
            relax.backend.DispatchSortScan(),
            relax.transform.LegalizeOps(),
            # The error is triggered while this step applies
            # dl.gpu.GeneralReduction() (see 'Actual behavior').
            dl.ApplyDefaultSchedule(
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
            relax.transform.RewriteDataflowReshape(),
            relax.transform.ToNonDataflow(),
            relax.transform.RemovePurityChecking(),
            relax.transform.CallTIRRewrite(),
            relax.transform.StaticPlanBlockMemory(),
            relax.transform.RewriteCUDAGraph(),
            relax.transform.LowerAllocTensor(),
            relax.transform.KillAfterLastUse(),
            relax.transform.LowerRuntimeBuiltin(),
            relax.transform.ComputePrimValue(),
            relax.transform.VMShapeLower(),
            relax.transform.AttachGlobalSymbol(),
        ],
    )
    mod = seq(mod)
    return mod

# `Module` is the TVMScript module from 'Actual behavior' above
mod = Module
mod = relax.get_pipeline()(mod)
target = tvm.target.Target("cuda")
ex = relax.build(mod, target=target, pipeline=dynshape_build_pipeline)
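
If it helps with triage, the failure can presumably be isolated to the scheduling step alone, since the build completes once GeneralReduction is removed. A minimal sketch, assuming the InternalError is raised while dl.gpu.GeneralReduction() schedules the legalized group_norm PrimFunc:

# Isolation sketch (assumption: the divide-by-zero check fires during
# dlight scheduling of the legalized module, not in a later pass).
mod = relax.transform.LegalizeOps()(Module)
with tvm.target.Target("cuda"):
    mod = dl.ApplyDefaultSchedule(dl.gpu.GeneralReduction())(mod)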

cc @junrushao
