Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Lifted parameter bindings may also be function outputs #17232

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/relax/transform/lift_transform_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,15 +488,24 @@ class LocalLiftableBindingCollector : public BaseLiftableBindingCollector {

} else {
info_.required_at_runtime.insert(binding->var);
for (const auto& upstream_var : FreeVars(bound_value)) {
info_.required_at_runtime.insert(upstream_var);
}
for (const auto& tir_var : FreeSymbolicVars(bound_value)) {
info_.required_at_runtime.insert(tir_var);
}

// Visit the bound value for expressions that must be computable
// at runtime, to populate the `required_at_runtime` set of
// variables. Populating it from `FreeVars(bound_value)` would
// not be sufficient, because it would omit variables that are
// used in a function's output.
VisitExpr(bound_value);
}
}

// Mark every Relax variable encountered while visiting runtime-required
// expressions as itself required at runtime.  Because the collector now
// visits bound values (rather than only taking FreeVars of each binding),
// this also catches variables that appear solely in a function's output.
void VisitExpr_(const VarNode* op) override {
  auto var = GetRef<Var>(op);
  info_.required_at_runtime.insert(var);
}

LocalCollectInfo info_;
};

Expand Down
192 changes: 192 additions & 0 deletions tests/python/relax/test_transform_lift_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1821,5 +1821,197 @@ def main_transform_params(params: R.Tuple([R.Tensor([16], "int32")])):
tvm.ir.assert_structural_equal(after, Expected)


def test_lift_binding_that_produces_function_output():
    """A lifted binding may itself be the function's output.

    ``B_t`` is computed entirely from the compile-time parameter ``B``
    (the parameters after ``num_input``), so it is lifted into
    ``main_transform_params`` even though ``B_t`` is returned directly
    from ``main`` rather than being consumed by another binding.  After
    the transform, ``main`` no longer performs any computation at all.
    """

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(
            A: R.Tensor([16], "int32"),
            B: R.Tensor([16, 16], "int32"),
        ):
            # Only `A` is a runtime input; `B` is a compile-time parameter.
            R.func_attr({"num_input": 1})
            with R.dataflow():
                B_t = R.permute_dims(B)
                R.output(B_t)
            return B_t

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            A: R.Tensor([16], "int32"),
            B_t: R.Tensor([16, 16], "int32"),
        ):
            # `B_t` is now received pre-computed; `main` just returns it.
            R.func_attr({"num_input": 1})
            return B_t

        @R.function
        def main_transform_params(params: R.Tuple([R.Tensor([16, 16], "int32")])):
            R.func_attr({"num_input": 0})
            with R.dataflow():
                B = params[0]
                B_t = R.permute_dims(B)
                output = (B_t,)
                R.output(output)
            return output

    mod = Before
    after = relax.transform.LiftTransformParams()(mod)
    tvm.ir.assert_structural_equal(after, Expected)


def test_lift_binding_that_produces_part_of_function_output():
    """The function's output may mix runtime and compile-time values.

    ``B_t`` depends only on the compile-time parameter ``B`` and is
    lifted, while ``C`` also depends on the runtime input ``A`` and
    must stay in ``main``.  The lifted ``B_t`` still appears in
    ``main``'s output tuple, now supplied as a transformed parameter.
    """

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(
            A: R.Tensor([16], "int32"),
            B: R.Tensor([16, 16], "int32"),
        ):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                B_t = R.permute_dims(B)
                C = R.matmul(A, B_t)
                R.output(C, B_t)
            return (C, B_t)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            A: R.Tensor([16], "int32"),
            B_t: R.Tensor([16, 16], "int32"),
        ):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # Only the runtime-dependent computation remains.
                C = R.matmul(A, B_t)
                R.output(C)

            gv = (C, B_t)
            return gv

        @R.function
        def main_transform_params(params: R.Tuple([R.Tensor([16, 16], "int32")])):
            R.func_attr({"num_input": 0})
            with R.dataflow():
                B = params[0]
                B_t = R.permute_dims(B)
                output = (B_t,)
                R.output(output)
            return output

    mod = Before
    after = relax.transform.LiftTransformParams()(mod)
    tvm.ir.assert_structural_equal(after, Expected)


def test_lift_shared_binding_that_produces_part_of_function_output():
    """A shared transform may also produce part of the function output.

    As in the per-function case, ``B_t`` depends only on the
    compile-time parameter and is lifted even though it appears in
    ``main``'s output tuple.  With ``shared_transform=True`` the
    lifted function is named ``transform_params`` rather than
    ``main_transform_params``.
    """

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(
            A: R.Tensor([16], "int32"),
            B: R.Tensor([16, 16], "int32"),
        ):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                B_t = R.permute_dims(B)
                C = R.matmul(A, B_t)
                R.output(C, B_t)
            return (C, B_t)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            A: R.Tensor([16], "int32"),
            B_t: R.Tensor([16, 16], "int32"),
        ):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                C = R.matmul(A, B_t)
                R.output(C)

            gv = (C, B_t)
            return gv

        # Shared transforms use the module-wide name `transform_params`.
        @R.function
        def transform_params(params: R.Tuple([R.Tensor([16, 16], "int32")])):
            R.func_attr({"num_input": 0})
            with R.dataflow():
                B = params[0]
                B_t = R.permute_dims(B)
                output = (B_t,)
                R.output(output)
            return output

    mod = Before
    after = relax.transform.LiftTransformParams(shared_transform=True)(mod)
    tvm.ir.assert_structural_equal(after, Expected)


def test_lift_all_bindings_from_dataflow_block():
    """A dataflow block whose bindings are all lifted is removed.

    Every binding in the first dataflow block of ``main`` can be
    computed at compile time, so the entire block is lifted into
    ``main_transform_params``.  The resulting ``main`` keeps only the
    impure ``R.print`` call and the second, runtime-dependent
    dataflow block.
    """

    @tvm.script.ir_module
    class Before:
        @R.function(pure=False)
        def main(
            A: R.Tensor([16], "int32"),
            B: R.Tensor([16, 16], "int32"),
        ):
            R.func_attr({"num_input": 1})
            # This block depends only on the compile-time parameter `B`.
            with R.dataflow():
                B_t = R.permute_dims(B)
                R.output(B_t)

            # Impure call between the two dataflow blocks.
            _ = R.print(format="impure func")

            with R.dataflow():
                C = R.matmul(A, B_t)
                R.output(C)

            return C

    @tvm.script.ir_module
    class Expected:
        @R.function(pure=False)
        def main(
            A: R.Tensor([16], "int32"),
            B_t: R.Tensor([16, 16], "int32"),
        ):
            R.func_attr({"num_input": 1})
            # The first dataflow block was lifted out entirely.
            _ = R.print(format="impure func")
            with R.dataflow():
                C = R.matmul(A, B_t)
                R.output(C)

            return C

        @R.function
        def main_transform_params(params: R.Tuple([R.Tensor([16, 16], "int32")])):
            R.func_attr({"num_input": 0})
            with R.dataflow():
                B = params[0]
                B_t = R.permute_dims(B)
                output = (B_t,)
                R.output(output)
            return output

    mod = Before
    after = relax.transform.LiftTransformParams()(mod)
    tvm.ir.assert_structural_equal(after, Expected)


if __name__ == "__main__":
    # Allow running this test file directly as a script.
    tvm.testing.main()
Loading