-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Open
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Description
I encountered a segmentation fault when applying the LiftTransformParams pass to a Relax IR module that performs tensor concatenation and transposition. The crash occurs while the PartitionTransformParams pass is executing; that pass is run by the LiftTransformParams call in the reproducer below (see the log).
Actual behavior
[12:44:56] /software/tvm/src/runtime/logging.cc:390: TVM_LOG_DEBUG enables VLOG statements in 'ir/transform.cc' up to level 1
[12:44:56] /software/tvm/src/runtime/logging.cc:390: TVM_LOG_DEBUG enables VLOG statements in 'relay/ir/transform.cc' up to level 1
[12:44:57] /software/tvm/src/ir/transform.cc:479: Running pass PartitionTransformParams
[12:44:57] /software/tvm/src/ir/transform.cc:418: PartitionTransformParams: Executing module pass with opt level: 1
Segmentation fault (core dumped)
Steps to reproduce
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
@I.ir_module
class Module:
    # Minimal IRModule reproducing the crash: two TIR kernels (concatenate,
    # transpose), a fused Relax function that calls both, and a `main` entry
    # point that calls the fused function.

    @T.prim_func(private=True)
    def concatenate(rxplaceholder: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"), T_concat: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32")):
        # Joins two (1, 4, 64, 64) buffers along axis 0 into a (2, 4, 64, 64)
        # buffer: index 0 comes from `rxplaceholder`, index 1 from `rxplaceholder_1`.
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)):
            with T.block("T_concat"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3], rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(T_concat[v_ax0, v_ax1, v_ax2, v_ax3])
                # Select the second input once v_ax0 >= 1 (shifted back by 1), else the first.
                T_concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(1) <= v_ax0, rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3], rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3])

    @T.prim_func(private=True)
    def transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32")):
        # Axis permutation: out[a0, a1, a2, a3] = in[a0, a3, a1, a2], turning a
        # (2, 4, 64, 64) buffer into a (2, 64, 64, 4) buffer.
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2]

    @R.function
    def fused_concatenate_transpose(inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
        # Concatenates `inp_0` with itself, then applies the transpose kernel.
        R.func_attr({"Primitive": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.concatenate, (inp_0, inp_0), out_sinfo=R.Tensor((2, 4, 64, 64), dtype="float32"))
            gv = R.call_tir(cls.transpose, (lv,), out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32"))
            R.output(gv)
        return gv

    @R.function
    def main(inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
        # NOTE(review): `num_input` is 3, but `main` declares only ONE parameter
        # (`inp_0`). This mismatch is the likely trigger of the segfault in
        # PartitionTransformParams. A plausible mechanism is an out-of-range
        # parameter access; confirm against the pass source. Setting it to 1
        # (or dropping the attribute) is expected to avoid the crash. Ideally the
        # pass would reject an oversized `num_input` with a diagnostic rather
        # than crashing. It is left as-is here so the reproducer still fails.
        R.func_attr({"num_input": 3})
        cls = Module
        with R.dataflow():
            lv: R.Tensor((2, 64, 64, 4), dtype="float32") = cls.fused_concatenate_transpose(inp_0)
            R.output(lv)
        return lv
# The IRModule under test. It is then passed to LiftTransformParams; per the log
# above, the segfault happens while PartitionTransformParams runs inside that call.
mod = Module
mod = relax.transform.LiftTransformParams()(mod)

Could you help me confirm what is causing this bug and provide guidance on how to resolve it?
CC @Lunderberg @ezyang
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug