Skip to content

[Bug] Segmentation Fault in PartitionTransformParams Pass with Relax IR #17460

Open
@Thrsu

Description

@Thrsu

I encountered a segmentation fault when applying the LiftTransformParams pass to a Relax IR module that performs tensor concatenation and transposition. The crash happens while the PartitionTransformParams pass is executing. PartitionTransformParams is run as part of the LiftTransformParams call (see the log below).

Actual behavior

[12:44:56] /software/tvm/src/runtime/logging.cc:390: TVM_LOG_DEBUG enables VLOG statements in 'ir/transform.cc' up to level 1
[12:44:56] /software/tvm/src/runtime/logging.cc:390: TVM_LOG_DEBUG enables VLOG statements in 'relay/ir/transform.cc' up to level 1
[12:44:57] /software/tvm/src/ir/transform.cc:479: Running pass PartitionTransformParams
[12:44:57] /software/tvm/src/ir/transform.cc:418: PartitionTransformParams: Executing module pass with opt level: 1
Segmentation fault (core dumped)

Steps to reproduce

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

# Reproducer IRModule. It contains:
#   - two private TIR kernels: `concatenate` and `transpose`;
#   - a Relax function that chains them;
#   - a Relax `main` entry point that calls that function.
@I.ir_module
class Module:
    # Concatenates two (1, 4, 64, 64) float32 buffers along axis 0 into T_concat
    # of shape (2, 4, 64, 64). Row 0 is read from `rxplaceholder`; rows with
    # index >= 1 are read from `rxplaceholder_1` (shifted down by one).
    @T.prim_func(private=True)
    def concatenate(rxplaceholder: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"), T_concat: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)):
            with T.block("T_concat"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3], rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(T_concat[v_ax0, v_ax1, v_ax2, v_ax3])
                # Select the source tensor by the axis-0 index.
                T_concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(1) <= v_ax0, rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3], rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3])

    # Permutes a (2, 4, 64, 64) buffer into (2, 64, 64, 4):
    # T_transpose[a0, a1, a2, a3] = rxplaceholder[a0, a3, a1, a2], which moves
    # input axis 1 to the last position.
    @T.prim_func(private=True)
    def transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2]

    # Concatenates `inp_0` with itself along axis 0, giving (2, 4, 64, 64).
    # The result is then transposed to (2, 64, 64, 4).
    # Each step is a `call_tir` into the kernels above.
    @R.function
    def fused_concatenate_transpose(inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
        R.func_attr({"Primitive": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.concatenate, (inp_0, inp_0), out_sinfo=R.Tensor((2, 4, 64, 64), dtype="float32"))
            gv = R.call_tir(cls.transpose, (lv,), out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32"))
            R.output(gv)
        return gv

    # Entry point that forwards its single input to `fused_concatenate_transpose`.
    @R.function
    def main(inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
        # NOTE(review): `main` declares only ONE parameter (`inp_0`), yet
        # "num_input" is 3. This attribute presumably tells LiftTransformParams
        # how many leading parameters are runtime inputs; the remaining
        # parameters would be treated as liftable weights.
        # A value larger than the parameter count looks like the likely trigger
        # of the segfault in PartitionTransformParams. TODO confirm by retrying
        # with {"num_input": 1}.
        R.func_attr({"num_input": 3})
        cls = Module
        with R.dataflow():
            lv: R.Tensor((2, 64, 64, 4), dtype="float32") = cls.fused_concatenate_transpose(inp_0)
            R.output(lv)
        return lv

# Apply LiftTransformParams to the reproducer module. The segfault reported
# above (inside PartitionTransformParams) is raised during this call.
lift_transform_params = relax.transform.LiftTransformParams()
mod = lift_transform_params(Module)

Could you help me confirm what is causing this bug and provide guidance on how to resolve it?
CC @Lunderberg @ezyang

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions