Skip to content

Commit

Permalink
Support linalg "unrolling" in dse
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchenye committed Mar 20, 2024
1 parent c7b2ba8 commit ad377d1
Showing 1 changed file with 52 additions and 39 deletions.
91 changes: 52 additions & 39 deletions python/scalehls/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,11 @@ def annotate(target: Value, annotation: str, param=None):
return transform.AnnotateOp(target, annotation, param=param)


def tile(linalg_op_handle: Value, sizes: Sequence[int]):
return linalg_transform.TileUsingForOp(linalg_op_handle, sizes=sizes)
def tile(linalg_op_handle: Value,
sizes: Sequence[int],
interchange: Union[Sequence[int], None] = None):
return linalg_transform.TileUsingForOp(
linalg_op_handle, sizes=sizes, interchange=interchange)


def tile_reduction(linalg_op_handle: Value, sizes: Sequence[int]):
Expand Down Expand Up @@ -400,10 +403,11 @@ def foreach_merge_consecutive_extract_slice_and_convert_to_itensor_read(
convert_extract_slice_to_itensor_read(merge_op.result)


def convert_full_tensor_linalg_generic_to_itensor(
def convert_full_tensor_linalg_op_to_itensor(
linalg_op_handle: Value,
parallel_tile_sizes: Sequence[int],
reduction_tile_sizes: Sequence[int],
unroll_sizes: Sequence[int],
permutation: Sequence[int],
has_input: bool = True,
split_combine_reduction=False):
Expand Down Expand Up @@ -448,35 +452,24 @@ def convert_full_tensor_linalg_generic_to_itensor(
foreach_merge_consecutive_extract_slice_and_convert_to_itensor_read(
matched_input)

# Interchange the loops of the linalg op with the given permutation.
interchange_op = interchange(linalg_op_handle, permutation)
linalg_op_handle = interchange_op.transformed
# Interchange and "unroll" the linalg op with the given unroll sizes.
unroll_op = tile(linalg_op_handle, unroll_sizes, permutation)
linalg_op_handle = unroll_op.tiled_linalg_op
return linalg_op_handle


# ===----------------------------------------------------------------------=== #
# Design Space Exploration Utils
# ===----------------------------------------------------------------------=== #


# ===----------------------------------------------------------------------=== #
# DesignSpaceGraph Class
# ===----------------------------------------------------------------------=== #


class DesignSpaceGraph(nx.Graph):
def __init__(self,
module: Module,
top_name: str = "forward",
default_tile_size: int = 16,
default_unroll_size: int = 4):
def __init__(self, module: Module, top_name: str = "forward"):
super().__init__()
self.module = module
self.top = find_func(self.module, top_name)
if self.top is None:
raise ValueError("top function `" + top_name + "` not found")
self.default_tile_size = default_tile_size
self.default_unroll_size = default_unroll_size

self.add_node(self.top, name=self.top.OPERATION_NAME, id=-1)
for id, op in enumerate(self.top.entry_block):
Expand All @@ -495,21 +488,8 @@ def is_nontrivial_node(node: Operation):
node, (arith.ConstantOp, hls.TensorInitOp, tensor.EmptyOp))

@staticmethod
def get_generic_op_naive_permutation(node: linalg.GenericOp):
loop_properties = extract_loop_properties(node)
numReduction = 0
interchange_permutation = []
for index, (_, type) in enumerate(loop_properties):
if type == "parallel":
interchange_permutation.append(index)
elif type == "reduction":
interchange_permutation.insert(numReduction, index)
numReduction += 1
return interchange_permutation

@staticmethod
def get_generic_op_naive_tile_sizes(node: linalg.GenericOp,
default_tile_size: int = 16):
def get_linalg_op_naive_tile_sizes(node: linalg.GenericOp,
default_tile_size: int = 16):
loop_properties = extract_loop_properties(node)

parallel_tile_sizes = []
Expand All @@ -524,6 +504,34 @@ def get_generic_op_naive_tile_sizes(node: linalg.GenericOp,
reduction_tile_sizes.append(tile_size)
return parallel_tile_sizes, reduction_tile_sizes

@staticmethod
def get_linalg_op_naive_unroll_sizes(node: linalg.GenericOp,
default_unroll_size: int = 2):
loop_properties = extract_loop_properties(node)

unroll_sizes = []
for range, type in loop_properties:
unroll_size = default_unroll_size if range > default_unroll_size else 0
if type == "parallel":
unroll_sizes.append(unroll_size)
elif type == "reduction":
unroll_sizes.append(1)
return unroll_sizes

@staticmethod
def get_linalg_op_naive_permutation(node: linalg.GenericOp):
loop_properties = extract_loop_properties(node)

num_reduction = 0
interchange_permutation = []
for index, (_, type) in enumerate(loop_properties):
if type == "parallel":
interchange_permutation.append(index)
elif type == "reduction":
interchange_permutation.insert(num_reduction, index)
num_reduction += 1
return interchange_permutation

@staticmethod
def get_reshape_op_naive_tile_sizes(
node: Union[tensor.ExpandShapeOp, tensor.CollapseShapeOp],
Expand All @@ -544,17 +552,19 @@ def get_reshape_op_naive_tile_sizes(
"Source tile sizes do not match result tile sizes")
return source_tile_sizes, result_tile_sizes

def naive_exploration(self):
def naive_exploration(self, default_tile_size: int = 16, default_unroll_size: int = 2):
for node, data in self.nodes(data=True):
if isinstance(node, linalg.GenericOp):
data["parallel_tile_sizes"], data["reduction_tile_sizes"] = self.get_generic_op_naive_tile_sizes(
node, default_tile_size=self.default_tile_size)
data["permutation"] = self.get_generic_op_naive_permutation(
data["parallel_tile_sizes"], data["reduction_tile_sizes"] = self.get_linalg_op_naive_tile_sizes(
node, default_tile_size=default_tile_size)
data["unroll_sizes"] = self.get_linalg_op_naive_unroll_sizes(
node, default_unroll_size=default_unroll_size)
data["permutation"] = self.get_linalg_op_naive_permutation(
node)

if isinstance(node, (tensor.ExpandShapeOp, tensor.CollapseShapeOp)):
data["source_tile_sizes"], data["result_tile_sizes"] = self.get_reshape_op_naive_tile_sizes(
node, default_tile_size=self.default_tile_size)
node, default_tile_size=default_unroll_size)

def print_dot(self, file_name: str):
dot = Digraph()
Expand Down Expand Up @@ -586,13 +596,16 @@ def construct_transform_sequence(target: BlockArgument,
raise ValueError("parallel_tile_sizes not found")
if "reduction_tile_sizes" not in data:
raise ValueError("reduction_tile_sizes not found")
if "unroll_sizes" not in data:
raise ValueError("unroll_sizes not found")
if "permutation" not in data:
raise ValueError("permutation not found")

linalg_op_handle = convert_full_tensor_linalg_generic_to_itensor(
linalg_op_handle = convert_full_tensor_linalg_op_to_itensor(
node_handle,
data["parallel_tile_sizes"],
data["reduction_tile_sizes"],
data["unroll_sizes"],
data["permutation"],
len(node.inputs) > 0)
annotate(linalg_op_handle, "id", i64_param(data["id"]))
Expand Down

0 comments on commit ad377d1

Please sign in to comment.