diff --git a/python/scalehls/transforms.py b/python/scalehls/transforms.py
index 552c6cc0..efd7c7cb 100644
--- a/python/scalehls/transforms.py
+++ b/python/scalehls/transforms.py
@@ -19,6 +19,83 @@ from typing import Sequence, Optional, Union, Callable, Mapping
 from functools import wraps
 
 
+
+# ===----------------------------------------------------------------------=== #
+# Compilation Passes
+# ===----------------------------------------------------------------------=== #
+
+
+def apply_transform_sequence(
+        module: Module,
+        sequence: transform.NamedSequenceOp):
+    pm = PassManager.parse(
+        "builtin.module("
+        "transform-interpreter{entry-point=" + sequence.sym_name.value + "},"
+        "cse, canonicalize)")
+    pm.run(module.operation)
+
+
+def apply_linalg_optimization_passes(module: Module):
+    pm = PassManager()
+    add_linalg_transform_passes(pm)
+    pm.run(module.operation)
+
+
+def apply_reduce_full_tensor_to_itensor(module: Module):
+    pm = PassManager.parse(
+        "builtin.module(func.func(scalehls-reduce-full-tensor-to-itensor),"
+        "cse, canonicalize)")
+    pm.run(module.operation)
+
+
+def apply_materialize_itensor(module: Module, enable_packing: bool = False):
+    enable_packing_str = "true" if enable_packing else "false"
+    pm = PassManager.parse(
+        "builtin.module(func.func(scalehls-materialize-itensor{"
+        "enable-packing=" + enable_packing_str + "}),"
+        "cse, canonicalize)")
+    pm.run(module.operation)
+
+
+def apply_scalarize_itensor(module: Module):
+    pm = PassManager.parse(
+        "builtin.module(func.func(scalehls-scalarize-itensor),"
+        "cse, canonicalize)")
+    pm.run(module.operation)
+
+
+def apply_lower_itensor_to_stream(module: Module):
+    pm = PassManager.parse(
+        "builtin.module(func.func(scalehls-lower-itensor-to-stream),"
+        "cse, canonicalize)")
+    pm.run(module.operation)
+
+
+def apply_comprehensive_bufferize_passes(module: Module):
+    pm = PassManager()
+    add_comprehensive_bufferize_passes(pm)
+    pm.run(module.operation)
+
+
+def apply_schedule_dataflow(module: Module):
+    pm = PassManager.parse(
+        "builtin.module(func.func(scalehls-schedule-dataflow),"
+        "cse, canonicalize)")
+    pm.run(module.operation)
+
+
+def apply_convert_dataflow_to_func_passes(module: Module):
+    pm = PassManager()
+    add_convert_dataflow_to_func_passes(pm)
+    pm.run(module.operation)
+
+
+def get_module_cpp_str(module: Module):
+    buf = io.StringIO()
+    emit_hlscpp(module, buf)
+    return buf.getvalue()
+
+
 # ===----------------------------------------------------------------------=== #
 # General Utils
 # ===----------------------------------------------------------------------=== #
@@ -50,14 +127,21 @@ def extract_iterator_type(s: str):
     return loop_properties
 
 
+def find_func(module: Module, name: str) -> Optional[func.FuncOp]:
+    for op in module.body:
+        if isinstance(op, func.FuncOp):
+            if op.name.value == name:
+                return op
+    return None
+
+
 # ===----------------------------------------------------------------------=== #
 # Transform Utils
 # ===----------------------------------------------------------------------=== #
 
 
-def transform_sequence(
-        name: str = "__transform_main",
-        result_types: Sequence[Type] = []):
+def transform_sequence(name: str = "__transform_main",
+                       result_types: Sequence[Type] = []):
     """
     A decorator to construct a `transform.named_sequence` op containing the ops
     built by the decorated function. `result_types` must be the same as the
@@ -171,10 +255,9 @@ def match_linalg_init(linalg_op_handle: BlockArgument, op_name: str):
 
 
 @match_linalg_with_conditions()
-def match_linalg_result(
-        linalg_op_handle: BlockArgument,
-        op_name: str,
-        position: int = 0):
+def match_linalg_result(linalg_op_handle: BlockArgument,
+                        op_name: str,
+                        position: int = 0):
     """
     The returned handle may contain multiple users of the matched result, which
     may need to be transformed with `foreach_transform` decorated function.
@@ -212,7 +295,7 @@ def wrapper(target: Value, *args, **kwargs):
 
 
 # ===----------------------------------------------------------------------=== #
-# Transform Utils
+# Transform Operation Utils
 # ===----------------------------------------------------------------------=== #
 
 
@@ -318,252 +401,230 @@ def foreach_merge_consecutive_extract_slice_and_convert_to_itensor_read(
 
 
 def convert_full_tensor_linalg_generic_to_itensor(
-        target: Value,
+        linalg_op_handle: Value,
         parallel_tile_sizes: Sequence[int],
         reduction_tile_sizes: Sequence[int],
         permutation: Sequence[int],
         has_input: bool = True,
-        combine_split_reduction=False):
-    tile_op = tile(target, parallel_tile_sizes)
-    target = tile_op.tiled_linalg_op
+        split_combine_reduction=False):
+    """
+    This is the main function driving the transformation of a linalg.generic op
+    with full tensor semantics to a sequence of ops with itensor semantics.
+    """
+    # Tile the linalg op with the given parallel tile sizes. Nested SCF loops
+    # are generated for parallel tile sizes that are greater than 0.
+    tile_op = tile(linalg_op_handle, parallel_tile_sizes)
+    linalg_op_handle = tile_op.tiled_linalg_op
 
-    matched_init = match_linalg_init(target, "tensor.extract_slice")
+    # Convert the linalg `init` operand to a `tensor.init` op.
+    matched_init = match_linalg_init(linalg_op_handle, "tensor.extract_slice")
     foreach_convert_extract_slice_to_tensor_init(matched_init)
 
-    matched_result = match_linalg_result(target, "tensor.insert_slice")
+    # Convert each `insert_slice` op to an `itensor.write` op.
+    matched_result = match_linalg_result(
+        linalg_op_handle, "tensor.insert_slice")
     foreach_convert_insert_slice_to_itensor_write(matched_result)
 
+    # Tile the linalg op with the given reduction tile sizes if applicable.
+    # Again, nested SCF loops are generated for reduction tile sizes that are
+    # greater than 0. If `split_combine_reduction` is set to True, the
+    # reduction tiling will generate two nested loops: splitting loops and
+    # combining loops. Please refer to the MLIR documentation of
+    # TileReductionUsingForOp for more details.
     if any(size > 0 for size in reduction_tile_sizes):
-        if (combine_split_reduction):
-            tile_reduction_op = tile_reduction(target, reduction_tile_sizes)
-            target = tile_reduction_op.split_linalg_op
+        if (split_combine_reduction):
+            tile_reduction_op = tile_reduction(
+                linalg_op_handle, reduction_tile_sizes)
+            linalg_op_handle = tile_reduction_op.split_linalg_op
             convert_fill_to_tensor_init(tile_reduction_op.fill_op)
         else:
-            tile_reduction_op = tile(target, reduction_tile_sizes)
-            target = tile_reduction_op.tiled_linalg_op
+            tile_reduction_op = tile(linalg_op_handle, reduction_tile_sizes)
+            linalg_op_handle = tile_reduction_op.tiled_linalg_op
 
+    # Convert each `extract_slice` op to an `itensor.read` op.
     if has_input:
-        matched_input = match_linalg_input(target, "tensor.extract_slice")
+        matched_input = match_linalg_input(
+            linalg_op_handle, "tensor.extract_slice")
         foreach_merge_consecutive_extract_slice_and_convert_to_itensor_read(
             matched_input)
 
-    interchange_op = interchange(target, permutation)
-    target = interchange_op.transformed
-    return target
+    # Interchange the loops of the linalg op with the given permutation.
+    interchange_op = interchange(linalg_op_handle, permutation)
+    linalg_op_handle = interchange_op.transformed
+    return linalg_op_handle
 
 
 # ===----------------------------------------------------------------------=== #
-# Computation Graph Utils
+# Design Space Exploration Utils
 # ===----------------------------------------------------------------------=== #
 
 
-def is_nontrivial_node(node: Operation):
-    return not isinstance(node, (arith.ConstantOp, hls.TensorInitOp, tensor.EmptyOp))
-
-
-def construct_graph(module: Module, ):
-    def find_func(module: Module, name: str) -> Optional[func.FuncOp]:
-        for op in module.body:
-            if isinstance(op, func.FuncOp):
-                if op.name.value == name:
-                    return op
-        return None
-
-    g = nx.Graph()
-    f = find_func(module, "forward")
-    if f is None:
-        raise ValueError("forward function not found")
-
-    g.add_node(f, name=f.OPERATION_NAME, id=-1)
-    for id, op in enumerate(f.entry_block):
-        g.add_node(op, name=op.OPERATION_NAME, id=id)
-        op.attributes["id"] = i64_attr(id)
-        for operand in op.operands:
-            parent = operand.owner.owner if isinstance(
-                operand.owner, Block) else operand.owner
-            if not g.has_node(parent):
-                raise ValueError("parent node not found")
-            g.add_edge(parent, op, value=operand)
-
-    return g
-
-
-def print_graph(g: nx.Graph, name: str):
-    dot = Digraph()
-    for node, data in g.nodes(data=True):
-        if is_nontrivial_node(data["name"]):
-            dot.node(data["name"] + str(data["id"]))
-    for prev, next, data in g.edges(data=True):
-        prev_data = g.nodes[prev]
-        next_data = g.nodes[next]
-        if is_nontrivial_node(prev) and is_nontrivial_node(next):
-            dot.edge(prev_data["name"] + str(prev_data["id"]),
-                     next_data["name"] + str(next_data["id"]))
-
-    dot.render(name, format='png', cleanup=True)
-
-
 # ===----------------------------------------------------------------------=== #
-# Design Space Exploration Utils
+# DesignSpaceGraph Class
 # ===----------------------------------------------------------------------=== #
 
 
-def get_generic_op_naive_permutation(node: linalg.GenericOp):
-    loop_properties = extract_loop_properties(node)
-    numReduction = 0
-    interchange_permutation = []
-    for index, (_, type) in enumerate(loop_properties):
-        if type == "parallel":
-            interchange_permutation.append(index)
-        elif type == "reduction":
-            interchange_permutation.insert(numReduction, index)
-            numReduction += 1
-    return interchange_permutation
-
-
-def get_generic_op_naive_tile_sizes(
-        node: linalg.GenericOp,
-        default_tile_size: int = 16):
-    loop_properties = extract_loop_properties(node)
-
-    parallel_tile_sizes = []
-    reduction_tile_sizes = []
-    for range, type in loop_properties:
-        tile_size = default_tile_size if range > default_tile_size else 0
-        if type == "parallel":
-            parallel_tile_sizes.append(tile_size)
-            reduction_tile_sizes.append(0)
-        elif type == "reduction":
-            parallel_tile_sizes.append(0)
-            reduction_tile_sizes.append(tile_size)
-    return parallel_tile_sizes, reduction_tile_sizes
-
-
-def get_reshape_op_naive_tile_sizes(
-        node: Union[tensor.ExpandShapeOp, tensor.CollapseShapeOp],
-        default_tile_size: int = 16):
-    source_tile_sizes = []
-    result_tile_sizes = []
-    for source_dim_size in node.src.type.shape:
-        tile_size = default_tile_size if source_dim_size > default_tile_size else 1
-        source_tile_sizes.append(tile_size)
-    for result_dim_size in node.result.type.shape:
-        tile_size = default_tile_size if result_dim_size > default_tile_size else 1
-        result_tile_sizes.append(tile_size)
-
-    # Support more flexible tile sizes.
-    if (functools.reduce(operator.mul, source_tile_sizes, 1) !=
-            functools.reduce(operator.mul, result_tile_sizes, 1)):
-        raise ValueError("Source tile sizes do not match result tile sizes")
-    return source_tile_sizes, result_tile_sizes
+class DesignSpaceGraph(nx.Graph):
+    def __init__(self,
+                 module: Module,
+                 top_name: str = "forward",
+                 default_tile_size: int = 16,
+                 default_unroll_size: int = 4):
+        super().__init__()
+        self.module = module
+        self.top = find_func(self.module, top_name)
+        if self.top is None:
+            raise ValueError("top function `" + top_name + "` not found")
+        self.default_tile_size = default_tile_size
+        self.default_unroll_size = default_unroll_size
+
+        self.add_node(self.top, name=self.top.OPERATION_NAME, id=-1)
+        for id, op in enumerate(self.top.entry_block):
+            self.add_node(op, name=op.OPERATION_NAME, id=id)
+            op.attributes["id"] = i64_attr(id)
+            for operand in op.operands:
+                parent = operand.owner.owner if isinstance(
+                    operand.owner, Block) else operand.owner
+                if not self.has_node(parent):
+                    raise ValueError("parent node not found")
+                self.add_edge(parent, op, value=operand)
+
+    @staticmethod
+    def is_nontrivial_node(node: Operation):
+        return not isinstance(
+            node, (arith.ConstantOp, hls.TensorInitOp, tensor.EmptyOp))
+
+    @staticmethod
+    def get_generic_op_naive_permutation(node: linalg.GenericOp):
+        loop_properties = extract_loop_properties(node)
+        numReduction = 0
+        interchange_permutation = []
+        for index, (_, type) in enumerate(loop_properties):
+            if type == "parallel":
+                interchange_permutation.append(index)
+            elif type == "reduction":
+                interchange_permutation.insert(numReduction, index)
+                numReduction += 1
+        return interchange_permutation
+
+    @staticmethod
+    def get_generic_op_naive_tile_sizes(node: linalg.GenericOp,
+                                        default_tile_size: int = 16):
+        loop_properties = extract_loop_properties(node)
+
+        parallel_tile_sizes = []
+        reduction_tile_sizes = []
+        for range, type in loop_properties:
+            tile_size = default_tile_size if range > default_tile_size else 0
+            if type == "parallel":
+                parallel_tile_sizes.append(tile_size)
+                reduction_tile_sizes.append(0)
+            elif type == "reduction":
+                parallel_tile_sizes.append(0)
+                reduction_tile_sizes.append(tile_size)
+        return parallel_tile_sizes, reduction_tile_sizes
+
+    @staticmethod
+    def get_reshape_op_naive_tile_sizes(
+            node: Union[tensor.ExpandShapeOp, tensor.CollapseShapeOp],
+            default_tile_size: int = 16):
+        source_tile_sizes = []
+        result_tile_sizes = []
+        for source_dim_size in node.src.type.shape:
+            tile_size = default_tile_size if source_dim_size > default_tile_size else 1
+            source_tile_sizes.append(tile_size)
+        for result_dim_size in node.result.type.shape:
+            tile_size = default_tile_size if result_dim_size > default_tile_size else 1
+            result_tile_sizes.append(tile_size)
+
+        # Support more flexible tile sizes.
+        if (functools.reduce(operator.mul, source_tile_sizes, 1) !=
+                functools.reduce(operator.mul, result_tile_sizes, 1)):
+            raise ValueError(
+                "Source tile sizes do not match result tile sizes")
+        return source_tile_sizes, result_tile_sizes
+
+    def naive_exploration(self):
+        for node, data in self.nodes(data=True):
+            if isinstance(node, linalg.GenericOp):
+                data["parallel_tile_sizes"], data["reduction_tile_sizes"] = self.get_generic_op_naive_tile_sizes(
+                    node, default_tile_size=self.default_tile_size)
+                data["permutation"] = self.get_generic_op_naive_permutation(
+                    node)
+
+            if isinstance(node, (tensor.ExpandShapeOp, tensor.CollapseShapeOp)):
+                data["source_tile_sizes"], data["result_tile_sizes"] = self.get_reshape_op_naive_tile_sizes(
+                    node, default_tile_size=self.default_tile_size)
+
+    def print_dot(self, file_name: str):
+        dot = Digraph()
+        for node, data in self.nodes(data=True):
+            if self.is_nontrivial_node(node):
+                dot.node(data["name"] + str(data["id"]))
+        for prev, next, data in self.edges(data=True):
+            prev_data = self.nodes[prev]
+            next_data = self.nodes[next]
+            if self.is_nontrivial_node(prev) and self.is_nontrivial_node(next):
+                dot.edge(prev_data["name"] + str(prev_data["id"]),
+                         next_data["name"] + str(next_data["id"]))
+        dot.render(file_name, format='png', cleanup=True)
 
 
 @transform_sequence()
-def construct_design_space_exploration_transform_sequence(
-        target: BlockArgument, module_graph: nx.Graph):
-    for node, data in module_graph.nodes(data=True):
+def construct_transform_sequence(target: BlockArgument,
+                                 graph: DesignSpaceGraph):
+    """
+    This function constructs a transform sequence to transform the target
+    function based on the given design space graph.
+    """
+    for node, data in graph.nodes(data=True):
         node_handle = match(target, [data["name"]], {
-            "id": i64_attr(data["id"])})
+                            "id": i64_attr(data["id"])})
 
         if isinstance(node, linalg.GenericOp):
-            permutation = get_generic_op_naive_permutation(node)
-            parallel_tile_sizes, reduction_tile_sizes = get_generic_op_naive_tile_sizes(
-                node, default_tile_size=16)
-
-            stream_node_handle = convert_full_tensor_linalg_generic_to_itensor(
-                node_handle, parallel_tile_sizes, reduction_tile_sizes, permutation, len(node.inputs) > 0)
-            annotate(stream_node_handle, "id", i64_param(data["id"]))
+            if "parallel_tile_sizes" not in data:
+                raise ValueError("parallel_tile_sizes not found")
+            if "reduction_tile_sizes" not in data:
+                raise ValueError("reduction_tile_sizes not found")
+            if "permutation" not in data:
+                raise ValueError("permutation not found")
+
+            linalg_op_handle = convert_full_tensor_linalg_generic_to_itensor(
+                node_handle,
+                data["parallel_tile_sizes"],
+                data["reduction_tile_sizes"],
+                data["permutation"],
+                len(node.inputs) > 0)
+            annotate(linalg_op_handle, "id", i64_param(data["id"]))
 
         if isinstance(node, tensor.ExpandShapeOp):
-            source_tile_sizes, result_tile_sizes = get_reshape_op_naive_tile_sizes(
-                node, default_tile_size=16)
+            if "source_tile_sizes" not in data:
+                raise ValueError("source_tile_sizes not found")
+            if "result_tile_sizes" not in data:
+                raise ValueError("result_tile_sizes not found")
 
             convert_op = convert_expand_shape_to_itensor_reassociate(
-                node_handle, source_tile_sizes, result_tile_sizes)
+                node_handle,
+                data["source_tile_sizes"],
+                data["result_tile_sizes"])
             annotate(convert_op.itensor_reassociate, "id",
                      i64_param(data["id"]))
 
         if isinstance(node, tensor.CollapseShapeOp):
-            source_tile_sizes, result_tile_sizes = get_reshape_op_naive_tile_sizes(
-                node, default_tile_size=16)
+            if "source_tile_sizes" not in data:
+                raise ValueError("source_tile_sizes not found")
+            if "result_tile_sizes" not in data:
+                raise ValueError("result_tile_sizes not found")
 
             convert_op = convert_collapse_shape_to_itensor_reassociate(
-                node_handle, source_tile_sizes, result_tile_sizes)
+                node_handle,
+                data["source_tile_sizes"],
+                data["result_tile_sizes"])
             annotate(convert_op.itensor_reassociate, "id",
                      i64_param(data["id"]))
     return []
 
 
-# ===----------------------------------------------------------------------=== #
-# Transform Passes
-# ===----------------------------------------------------------------------=== #
-
-
-def apply_linalg_transform_passes(module: Module):
-    pm = PassManager()
-    add_linalg_transform_passes(pm)
-    pm.run(module.operation)
-
-
-def apply_transform_sequence(
-        module: Module,
-        sequence: transform.NamedSequenceOp):
-    pm = PassManager.parse(
-        "builtin.module("
-        "transform-interpreter{entry-point=" + sequence.sym_name.value + "},"
-        "cse, canonicalize)")
-    pm.run(module.operation)
-
-def apply_reduce_full_tensor_to_itensor(module: Module):
-    pm = PassManager.parse(
-        "builtin.module(func.func(scalehls-reduce-full-tensor-to-itensor),"
-        "cse, canonicalize)")
-    pm.run(module.operation)
-
-
-def apply_materialize_itensor(module: Module, enable_packing: bool = False):
-    enable_packing_str = "true" if enable_packing else "false"
-    pm = PassManager.parse(
-        "builtin.module(func.func(scalehls-materialize-itensor{"
-        "enable-packing=" + enable_packing_str + "}),"
-        "cse, canonicalize)")
-    pm.run(module.operation)
-
-
-def apply_scalarize_itensor(module: Module):
-    pm = PassManager.parse(
-        "builtin.module(func.func(scalehls-scalarize-itensor),"
-        "cse, canonicalize)")
-    pm.run(module.operation)
-
-
-def apply_lower_itensor_to_stream(module: Module):
-    pm = PassManager.parse(
-        "builtin.module(func.func(scalehls-lower-itensor-to-stream),"
-        "cse, canonicalize)")
-    pm.run(module.operation)
-
-
-def apply_comprehensive_bufferize_passes(module: Module):
-    pm = PassManager()
-    add_comprehensive_bufferize_passes(pm)
-    pm.run(module.operation)
-
-
-def apply_schedule_dataflow(module: Module):
-    pm = PassManager.parse(
-        "builtin.module(func.func(scalehls-schedule-dataflow),"
-        "cse, canonicalize)")
-    pm.run(module.operation)
-
-
-def apply_convert_dataflow_to_func_passes(module: Module):
-    pm = PassManager()
-    add_convert_dataflow_to_func_passes(pm)
-    pm.run(module.operation)
-
-
-def get_module_cpp_str(module: Module):
-    buf = io.StringIO()
-    emit_hlscpp(module, buf)
-    return buf.getvalue()
+def apply_design_space(graph: DesignSpaceGraph):
+    apply_transform_sequence(
+        graph.module, construct_transform_sequence(graph.module, graph))
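
Usage note (not part of the patch): the snippet below is a minimal sketch of how the APIs introduced above might be chained end to end. It only calls functions defined in this diff, but the `scalehls.ir` import path, the `forward.mlir` input file, and the exact pass ordering (which mirrors the order of the functions in the new "Compilation Passes" section) are assumptions, not taken from the source.

```python
from scalehls.ir import Context, Module  # assumed import path
from scalehls import transforms

with Context():
    # `forward.mlir` is a hypothetical linalg-on-tensors input whose top
    # function is named `forward` (the DesignSpaceGraph default).
    with open("forward.mlir") as f:
        module = Module.parse(f.read())

    # Build the op graph of the top function and pick a naive design point
    # (tile sizes and loop permutations) for every linalg/reshape node.
    graph = transforms.DesignSpaceGraph(module)
    graph.naive_exploration()
    graph.print_dot("forward_graph")  # optional: render the graph as PNG

    # Materialize the design point as a transform sequence and apply it
    # through the transform dialect interpreter.
    transforms.apply_design_space(graph)

    # Lower step by step (ordering assumed): itensor conversion and
    # materialization, scalarization, stream lowering, bufferization,
    # dataflow scheduling, and outlining into functions.
    transforms.apply_linalg_optimization_passes(module)
    transforms.apply_reduce_full_tensor_to_itensor(module)
    transforms.apply_materialize_itensor(module, enable_packing=True)
    transforms.apply_scalarize_itensor(module)
    transforms.apply_lower_itensor_to_stream(module)
    transforms.apply_comprehensive_bufferize_passes(module)
    transforms.apply_schedule_dataflow(module)
    transforms.apply_convert_dataflow_to_func_passes(module)

    # Emit the final design as HLS C++.
    print(transforms.get_module_cpp_str(module))
```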