[Bug] Segmentation Fault in PartitionTransformParams Pass with Relax IR #17460

Open
Thrsu opened this issue Oct 11, 2024 · 0 comments
Labels: needs-triage, type: bug

Thrsu (Contributor) commented Oct 11, 2024

I encountered a segmentation fault when applying LiftTransformParams to a Relax IR module that performs tensor concatenation and transposition. As the log below shows, the crash happens inside the PartitionTransformParams pass, which LiftTransformParams invokes internally.

Actual behavior

[12:44:56] /software/tvm/src/runtime/logging.cc:390: TVM_LOG_DEBUG enables VLOG statements in 'ir/transform.cc' up to level 1
[12:44:56] /software/tvm/src/runtime/logging.cc:390: TVM_LOG_DEBUG enables VLOG statements in 'relay/ir/transform.cc' up to level 1
[12:44:57] /software/tvm/src/ir/transform.cc:479: Running pass PartitionTransformParams
[12:44:57] /software/tvm/src/ir/transform.cc:418: PartitionTransformParams: Executing module pass with opt level: 1
Segmentation fault (core dumped)
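
For reference, the VLOG lines above come from running with TVM_LOG_DEBUG set; a minimal sketch of how to get the same logging (assuming, as the logging.cc messages suggest, that comma-separated "file=level" entries are read when tvm is first imported) is:

import os

# TVM_LOG_DEBUG must be set before tvm is imported; each "file=level"
# entry enables VLOG statements in the named source file up to that level.
os.environ["TVM_LOG_DEBUG"] = "ir/transform.cc=1,relay/ir/transform.cc=1"

import tvm  # noqa: E402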

Steps to reproduce

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def concatenate(rxplaceholder: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"), T_concat: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)):
            with T.block("T_concat"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3], rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(T_concat[v_ax0, v_ax1, v_ax2, v_ax3])
                T_concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(1) <= v_ax0, rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3], rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3])

    @T.prim_func(private=True)
    def transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2]

    @R.function
    def fused_concatenate_transpose(inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
        R.func_attr({"Primitive": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.concatenate, (inp_0, inp_0), out_sinfo=R.Tensor((2, 4, 64, 64), dtype="float32"))
            gv = R.call_tir(cls.transpose, (lv,), out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32"))
            R.output(gv)
        return gv

    @R.function
    def main(inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
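        # Note: "num_input" is 3 even though main takes only one parameter.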
        R.func_attr({"num_input": 3})
        cls = Module
        with R.dataflow():
            lv: R.Tensor((2, 64, 64, 4), dtype="float32") = cls.fused_concatenate_transpose(inp_0)
            R.output(lv)
        return lv

mod = Module
mod = relax.transform.LiftTransformParams()(mod)
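
One detail that may be relevant: main declares "num_input": 3 but takes a single parameter. Whether this mismatch is what trips PartitionTransformParams is an assumption on my part; a sketch of a variant with a consistent attribute, for comparison, would be:

# Hypothetical comparison: make "num_input" match main's actual parameter
# count (1) before lifting. Note this mutates Module in place, and whether
# it avoids the crash is untested here.
mod2 = Module
mod2["main"] = mod2["main"].with_attr("num_input", 1)
mod2 = relax.transform.LiftTransformParams()(mod2)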

Could you help me confirm what is causing this bug and provide guidance on how to resolve it?
CC @Lunderberg @ezyang
