Skip to content

Commit

Permalink
Support returning results for foreach_transform API
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchenye committed Mar 19, 2024
1 parent 20ad02b commit 8370521
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions python/scalehls/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,12 @@ def match_linalg_result(
return match_result_op.results[0]


def foreach_transform():
def foreach_transform(result_types: Sequence[Type] = []):
"""
A decorator to build a `transform.foreach` op containing the ops built by
the decorated function. `transform.foreach` op is used to transform multiple
ops contained by a single handle.
ops contained by a single handle. `result_types` must be the same as the
return types of decorated function.
The decorated function must have a `BlockArgument` as its first argument,
which is the handle of the target op to be transformed. Any other arguments
Expand All @@ -200,12 +201,12 @@ def foreach_transform():
def decorator(body_builder: Callable[..., None]):
@wraps(body_builder)
def wrapper(target: Value, *args, **kwargs):
foreach = transform.ForeachOp([], target)
foreach = transform.ForeachOp(result_types, target)
foreach_block = Block.create_at_start(foreach.body, [target.type])
with InsertionPoint.at_block_begin(foreach_block):
body_builder(
results = body_builder(
foreach_block.arguments[0], *args, **kwargs) # type: ignore
transform.YieldOp()
transform.YieldOp(results)
return wrapper
return decorator

Expand Down

0 comments on commit 8370521

Please sign in to comment.