diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index b659572f..14ed3e50 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -21,6 +21,13 @@ declare_mlir_python_sources(ScaleHLSPythonSources.Core _mlir_libs/_scalehls.pyi ) +declare_mlir_python_sources(ScaleHLSPythonSources.OpDSL + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/scalehls" + ADD_TO_PARENT ScaleHLSPythonSources + SOURCES_GLOB + opdsl/*.py +) + ################################################################################ # Declare dialect-specific bindings ################################################################################ diff --git a/python/scalehls/opdsl/__init__.py b/python/scalehls/opdsl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/scalehls/opdsl/dump_oplib.py b/python/scalehls/opdsl/dump_oplib.py new file mode 100644 index 00000000..2f651319 --- /dev/null +++ b/python/scalehls/opdsl/dump_oplib.py @@ -0,0 +1,90 @@ +#!/usr/bin/which python +# Command line tool to load an oplib module and dump all of the operations +# it contains in some format. +"""Loads one or more modules containing op definitions and dumps them. + +The dump format can be: + +* `--dump_format=yaml` (default) +* `--dump_format=repr` + +Positional arguments are interpreted as module names (optionally, relative to +this module). Loose module files can be specified via `--file `. + +Sample usage: + # Dump the YAML op definitions for the core named ops (as in the dialect + # source tree). + python -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops + +Note: YAML output is emitted in "document list" format with each operation +as its own "document". Practically, this means that each operation (or group +of composite ops) is emitted with a "---" preceding it, which can be useful +for testing. +""" + +import argparse +import importlib + +from .lang import * +from .lang.config import * +from .lang.yaml_helper import * + + +def create_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Dump an oplib in various formats") + p.add_argument( + "modules", metavar="M", type=str, nargs="*", help="Op module to dump" + ) + p.add_argument( + "--file", metavar="F", type=str, nargs="*", help="Python op file to dump" + ) + p.add_argument( + "--format", + type=str, + dest="format", + default="yaml", + choices=("yaml", "repr"), + help="Format in which to dump", + ) + return p + + +def load_module_from_file(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + m = importlib.util.module_from_spec(spec) + spec.loader.exec_module(m) + return m + + +def main(args): + # Load all configs. + configs = [] + modules = [] + for module_name in args.modules: + modules.append( + importlib.import_module(module_name, package="mlir.dialects.linalg.opdsl") + ) + for i, file_path in enumerate(args.file or []): + modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path)) + for m in modules: + for attr_name, value in m.__dict__.items(): + # TODO: This class layering is awkward. + if isinstance(value, DefinedOpCallable): + try: + linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def) + except Exception as e: + raise ValueError( + f"Could not create LinalgOpConfig from {value.op_def}" + ) from e + configs.extend(linalg_config) + + # Print. 
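+    # Illustrative invocation (the module path is an assumption about how this
+    # package ends up installed; the --file form avoids module resolution):
+    #   python -m scalehls.opdsl.dump_oplib --format=repr --file my_ops.py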
+ if args.format == "yaml": + print(yaml_dump_all(configs)) + elif args.format == "repr": + for config in configs: + print(repr(config)) + + +if __name__ == "__main__": + main(create_arg_parser().parse_args()) diff --git a/python/scalehls/opdsl/lang/__init__.py b/python/scalehls/opdsl/lang/__init__.py new file mode 100644 index 00000000..cf85c885 --- /dev/null +++ b/python/scalehls/opdsl/lang/__init__.py @@ -0,0 +1 @@ +from .dsl import * diff --git a/python/scalehls/opdsl/lang/affine.py b/python/scalehls/opdsl/lang/affine.py new file mode 100644 index 00000000..d2f5632e --- /dev/null +++ b/python/scalehls/opdsl/lang/affine.py @@ -0,0 +1,306 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""DSL for constructing affine expressions and maps. + +These python wrappers allow construction of affine expressions in a more +pythonic fashion that is later instantiated as an IR AffineExpr. Separating the +AST from construction of the map allows for manipulations of symbols and dims +beyond the scope of one expression. + +Affine expression construction: + >>> with _ir.Context(): + ... s = AffineBuildState() + ... (S.K + S.M).build(s) + ... (S.K * S.M).build(s) + ... (S.K // S.M).build(s) + ... (S.K / S.M).build(s) + ... (S.K % 4).build(s) + ... (D.i + D.j * 4).build(s) + ... s + AffineExpr(s0 + s1) + AffineExpr(s0 * s1) + AffineExpr(s0 floordiv s1) + AffineExpr(s0 ceildiv s1) + AffineExpr(s0 mod 4) + AffineExpr(d0 + d1 * 4) + AffineBuildState< + symbols={'K': 0, 'M': 1} + dims={'i': 0, 'j': 1}> + +In the DSL, dimensions and symbols are name-uniqued instances of DimDef and +SymbolDef. There are shortcut "expando" instances that will create a +corresponding DimDef/SymbolDef upon accessing an attribute: + +Referencing a named dimension: + + >>> D.i + Dim(i) + >>> D.a is D.b + False + >>> D.a is D.a + True + +Referencing a named symbol: + + >>> S.foobar + Symbol(foobar) + >>> S.a is S.b + False + >>> S.a is S.a + True +""" + +from typing import Callable, Dict, Optional, Tuple, Union + +from ... import ir as _ir + +__all__ = [ + "AffineBuildState", + "AffineExprDef", + "D", + "DimDef", + "S", + "SymbolDef", +] + + +class AffineBuildState: + """Internal state for the AffineExprDef._create impls. + + Note that a "local" AffineBuildState can be created relative to a "global" + AffineBuildState. In that case, any affine expressions built will inherit + symbol and dim bindings from the global state and will update both as new + ones are discovered. This allows for building expressions across contexts + which share a common symbol and dim space. + """ + + def __init__( + self, + *, + global_state: "AffineBuildState" = None, + allow_new_symbols: bool = True, + allow_new_dims: bool = True, + ): + if not global_state: + self.all_symbols = dict() # type: Dict[str, int] + self.all_dims = dict() # type: Dict[str, int] + else: + # Alias the global dict. + self.all_symbols = global_state.all_symbols + self.all_dims = global_state.all_dims + + # Map of symbols and dims in the current build. 
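+        # For example, after building (S.K + S.M) and (D.i * 4) against this
+        # state, local_symbols is {'K': 0, 'M': 1} and local_dims is {'i': 0}.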
+ self.local_symbols = dict() # type: Dict[str, int] + self.local_dims = dict() # type: Dict[str, int] + self.allow_new_symbols = allow_new_symbols + self.allow_new_dims = allow_new_dims + + def get_dim(self, dimname: str) -> int: + """Gets the dim position given a name.""" + pos = self.all_dims.get(dimname) + if pos is None: + if not self.allow_new_dims: + raise ValueError( + f"New dimensions not allowed in the current affine expression: " + f"Requested '{dimname}', Availble: {self.all_dims}" + ) + pos = len(self.all_dims) + self.all_dims[dimname] = pos + self.local_dims[dimname] = pos + return pos + + def get_symbol(self, symname: str) -> int: + """Geta a symbol position given a name.""" + pos = self.all_symbols.get(symname) + if pos is None: + if not self.allow_new_symbols: + raise ValueError( + f"New symbols not allowed in the current affine expression: " + f"Requested '{symname}', Availble: {self.all_symbols}" + ) + pos = len(self.all_symbols) + self.all_symbols[symname] = pos + self.local_symbols[symname] = pos + return pos + + @property + def local_dim_count(self) -> int: + return len(self.local_dims) + + @property + def local_symbol_count(self) -> int: + return len(self.local_symbols) + + @property + def dim_count(self) -> int: + return len(self.all_dims) + + @property + def symbol_count(self) -> int: + return len(self.all_symbols) + + def __repr__(self): + lines = [f"AffineBuildState<"] + lines.append(f" symbols={self.local_symbols}") + lines.append(f" dims={self.local_dims}>") + return "\n".join(lines) + + +class AffineExprDef: + """Base class for an affine expression being defined.""" + + def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr: + """Builds the corresponding _ir.AffineExpr from the definitions.""" + state = AffineBuildState() if state is None else state + expr = self._create(state) + return expr + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + raise NotImplementedError() + + @staticmethod + def coerce_from(py_value): + if isinstance(py_value, int): + return AffineConstantExpr(py_value) + assert isinstance(py_value, AffineExprDef) + return py_value + + def visit_affine_exprs(self, callback): + """Visits all AffineExprDefs including self.""" + callback(self) + + def __add__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs) + + def __mul__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs) + + def __mod__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs) + + def __floordiv__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs) + + def __truediv__(lhs, rhs): + # TODO: Not really a ceil div - taking liberties for the DSL. 
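+        # e.g. S.K / 2 builds an AffineCeilDivExpr, so it yields 3 for K == 5,
+        # whereas S.K // 2 (floordiv) yields 2.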
+ rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs) + + +class AffineConstantExpr(AffineExprDef): + """An affine constant being defined.""" + + def __init__(self, value: int): + assert isinstance(value, int) + self.value = value + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + return _ir.AffineConstantExpr.get(self.value) + + def __repr__(self): + return f"Const({self.value})" + + +class AffineBinaryExprDef(AffineExprDef): + """An affine binary expression being defined.""" + + def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef): + self.ir_ctor = ir_ctor + self.lhs = lhs + self.rhs = rhs + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state)) + + def visit_affine_exprs(self, callback): + """Visits all AffineExprDefs including self.""" + super().visit_affine_exprs(callback) + self.lhs.visit_affine_exprs(callback) + self.rhs.visit_affine_exprs(callback) + + def __repr__(self): + return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})" + + +class DimDef(AffineExprDef): + """Represents a named dimension.""" + + ALL_DIMS = dict() # type: Dict[str, "DimDef"] + + def __new__(cls, dimname: str): + existing = cls.ALL_DIMS.get(dimname) + if existing is not None: + return existing + new = super().__new__(cls) + new.dimname = dimname + cls.ALL_DIMS[dimname] = new + return new + + def __repr__(self): + return f"Dim({self.dimname})" + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + pos = state.get_dim(self.dimname) + return _ir.AffineDimExpr.get(position=pos) + + @classmethod + def create_expando(cls): + """Create an expando class that creates unique symbols based on attr access.""" + + class ExpandoDims: + def __getattr__(self, n): + return cls(n) + + return ExpandoDims() + + +class SymbolDef(AffineExprDef): + """Represents a named symbol. + + >>> s1 = SymbolDef("s1") + >>> s1 + Symbol(s1) + >>> s2 = SymbolDef("s2") + >>> s1 is s2 + False + >>> s1 is SymbolDef("s1") + True + """ + + ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"] + + def __new__(cls, symname: str): + existing = cls.ALL_SYMBOLS.get(symname) + if existing is not None: + return existing + new = super().__new__(cls) + new.symname = symname + cls.ALL_SYMBOLS[symname] = new + return new + + def __repr__(self): + return f"Symbol({self.symname})" + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + pos = state.get_symbol(self.symname) + return _ir.AffineSymbolExpr.get(position=pos) + + @classmethod + def create_expando(cls): + """Create an expando class that creates unique symbols based on attr access.""" + + class ExpandoSymbols: + def __getattr__(self, n): + return cls(n) + + return ExpandoSymbols() + + +# Global accessor for on-demand dims and symbols. +D = DimDef.create_expando() +S = SymbolDef.create_expando() diff --git a/python/scalehls/opdsl/lang/comprehension.py b/python/scalehls/opdsl/lang/comprehension.py new file mode 100644 index 00000000..d39be2d6 --- /dev/null +++ b/python/scalehls/opdsl/lang/comprehension.py @@ -0,0 +1,844 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Model classes representing a tensor comprehension. + +These classes model the language more at an AST level as evaluated. 
Reasoning +about it typically involves processing this form into config objects that +represent actual op definitions (i.e. YAML). +""" + +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple +from enum import Enum + +from ... import ir as _ir +from .affine import * +from .scalar_expr import * +from .types import * +from .yaml_helper import * + +############################################################################### +# Tensor expression nodes. +############################################################################### + + +class TensorExpression: + """An expression that can appear on the RHS of a comprehension.""" + + def to_scalar_expression(self) -> ScalarExpression: + raise NotImplementedError() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + """Visits all tensor expression reachable by the expression.""" + callback(self) + + def collect_dim_uses(self, uses: Set["DimDef"]): + """Collects all DimDefs reachable through this expression.""" + + def visit_dim_def(dim_def: AffineExprDef): + if isinstance(dim_def, DimDef): + uses.add(dim_def) + + def visit_affine_exprs(expr: "TensorExpression"): + if isinstance(expr, TensorUse): + for ind in expr.indices: + ind.visit_affine_exprs(visit_dim_def) + if isinstance(expr, TensorReduceFn): + for ind in expr.reduce_fn.reduce_dims: + ind.visit_affine_exprs(visit_dim_def) + + self.visit_tensor_exprs(visit_affine_exprs) + + def collect_tensor_uses(self, uses: Set["TensorUse"]): + """Collects all TensorUses reachable through this expression.""" + + def visit_tensor_use(expr: "TensorExpression"): + if isinstance(expr, TensorUse): + uses.add(expr) + + self.visit_tensor_exprs(visit_tensor_use) + + def collect_indices(self, indices: Set["index"]): + """Collects all index accesses reachable through this expression.""" + + def visit_index(expr: "TensorExpression"): + if isinstance(expr, index): + indices.add(expr) + + self.visit_tensor_exprs(visit_index) + + def collect_scalar_uses(self, uses: Set["ScalarDef"]): + """Collects all ScalarDefs reachable through this expression.""" + + def visit_scalar_def(expr: "TensorExpression"): + if isinstance(expr, ScalarDef): + uses.add(expr) + + self.visit_tensor_exprs(visit_scalar_def) + + def __add__(self, rhs: "TensorExpression") -> "TensorExpression": + return BinaryFn.add(self, rhs) + + def __mul__(self, rhs) -> "TensorExpression": + return BinaryFn.mul(self, rhs) + + def __sub__(self, rhs) -> "TensorExpression": + return BinaryFn.sub(self, rhs) + + def __truediv__(self, rhs) -> "TensorExpression": + return BinaryFn.div(self, rhs) + + def __hash__(self): + return hash(id(self)) + + +class TensorUse(TensorExpression): + """A used tensor represented by its (tensor_name, indices). + + Note that forming a comprehension via direct assignment is performed through + __setitem__ on the TensorDef level. 
However, performing a reduction with + compound ops (+=, *=, etc) is done by doing a: + TensorDef.__getitem__ + TensorUse.__iadd__ + TensorDef.__setitem__ + """ + + def __init__(self, operand_def: "OperandDef", indices: Sequence[AffineExprDef]): + self.operand_def = operand_def + self.indices = tuple(indices) + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarArg(self.tensor_name).expr() + + @property + def tensor_name(self) -> str: + name = self.operand_def.name + assert name is not None, "TensorDef not registered with an op" + return name + + def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: + # Computes the reduction dims for implicit reductions. Assumes that the rhs + # is the expression being reduced and self is being reduced into. Any + # indices referenced on the rhs and not in self are considered reduction + # dims and will be ordered as encountered on the rhs. + rhs_dims = set() + lhs_dims = set() + rhs.collect_dim_uses(rhs_dims) + self.collect_dim_uses(lhs_dims) + return rhs_dims - lhs_dims + + def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": + return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs) + + def __repr__(self): + return ( + f"{self.operand_def.name}" f"[{', '.join([repr(i) for i in self.indices])}]" + ) + + +class TensorFn(TensorExpression): + """Application of a tensor function.""" + + def __init__( + self, + kind: "FunctionKind", + name: Optional[str], + operand_def: Optional["OperandDef"], + type_var: Optional[TypeVar], + args: Sequence[TensorExpression], + ): + if bool(name) + bool(operand_def) != 1: + raise ValueError("One of 'name', 'operand_def' must be specified") + self.name = name + self.kind = kind + self.operand_def = operand_def + self.type_var = type_var + self.args = args + + def to_scalar_expression(self) -> ScalarExpression: + if self.operand_def: + assert self.operand_def.name, "TensorFn not registered with an op" + attr_name = self.operand_def.name if self.operand_def else None + args = [arg.to_scalar_expression() for arg in self.args] + return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + super().visit_tensor_exprs(callback) + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + name = self.operand_def.name if self.operand_def else self.name + return ( + f"{self.kind.name}.{name}(type_var={self.type_var}, " + f"args={', '.join(repr(a) for a in self.args)})" + ) + + +class TensorReduceFn(TensorExpression): + """Application of a reduction function. + + This captures the lhs (initial value) separately from the rhs. 
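+
+    For example, in `C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]` the right-hand
+    side becomes a TensorReduceFn over `D.k`; its `lhs` stays None until the
+    enclosing Comprehension binds it to the `C[D.m, D.n]` use.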
+ """ + + def __init__(self, reduce_use: "ReduceFnUse", args: Sequence[TensorExpression]): + self.reduce_use = reduce_use + self.lhs = None # type: Optional[TensorUse] + self.args = args + + def to_scalar_expression(self) -> ScalarExpression: + if self.lhs is None: + raise ValueError( + f"Cannot scalarize a TensorReduceFn that has not been " + f"bound to its lhs: {self}" + ) + full_args = [self.lhs.to_scalar_expression()] + [ + arg.to_scalar_expression() for arg in self.args + ] + fn_name = None + attr_name = None + if self.reduce_use.binary_fn: + fn_name = self.reduce_use.binary_fn.fn_name + if self.reduce_use.binary_attr: + attr_name = self.reduce_use.binary_attr.operand_def.name + return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, full_args).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})" + + +class const(TensorExpression): + """Returns the given constant floating point or integer value.""" + + def __init__(self, value: Any): + with _ir.Context(): + if isinstance(value, float): + self.value = str(_ir.FloatAttr.get_f64(float(value))) + elif isinstance(value, int): + self.value = str( + _ir.IntegerAttr.get( + _ir.IntegerType.get_signless(64), int(value)) + ) + else: + raise ValueError( + f"const requires int or float but got {type(value)}") + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarConst(self.value).expr() + + def __repr__(self): + return f"const({self.value})" + + +class index(TensorExpression): + """Returns the iteration index for a given dimension name. + + Resolves the given dimension name to obtain its position in the iteration + domain of the operation. + """ + + def __init__(self, dim: DimDef): + self.dim_def = dim + self.dim = -1 + + def resolve_dimension_name(self, affine_state: AffineBuildState): + self.dim = affine_state.get_dim(self.dim_def.dimname) + + def to_scalar_expression(self) -> ScalarExpression: + assert self.dim != -1, "Dimension name not resolved" + return ScalarIndex(self.dim).expr() + + def __repr__(self): + return f"index({repr(self.dim)})" + + +############################################################################### +# Function types and function definitions. +############################################################################### + + +class FunctionKind(Enum): + UNARY = 0 + BINARY = 1 + TYPE = 2 + + +class UnaryFnType: + """Unary function. + + A unary function takes one tensor expression and returns the + function evaluation result. + """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg]) + + def __repr__(self): + return f"{self.fn_name}" + + +class UnaryFn: + """Unary function namespace.""" + + exp = UnaryFnType("exp") + log = UnaryFnType("log") + abs = UnaryFnType("abs") + ceil = UnaryFnType("ceil") + floor = UnaryFnType("floor") + negf = UnaryFnType("negf") + + +class BinaryFnType: + """Binary function. + + A binary function takes two tensor expressions and returns the + function evaluation result. 
+ """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1]) + + def __repr__(self): + return f"{self.fn_name}" + + +class BinaryFn: + """Binary function namespace. + + As the integer types are signless, signedness is implement by different + functions that treat integers as signed or unsigned values. + + Examples: + - max -> `arith.MaxSIOp` + - max_unsinged -> `arith.MaxUIOp` + """ + + add = BinaryFnType("add") + sub = BinaryFnType("sub") + mul = BinaryFnType("mul") + div = BinaryFnType("div") + div_unsigned = BinaryFnType("div_unsigned") + max_signed = BinaryFnType("max_signed") + min_signed = BinaryFnType("min_signed") + max_unsigned = BinaryFnType("max_unsigned") + min_unsigned = BinaryFnType("min_unsigned") + + +class TypeFnType: + """Type conversion function. + + A type conversion function takes a target type and a tensor expression and + returns the casted tensor expression. + """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) + + def __repr__(self): + return f"{self.fn_name}" + + +class TypeFn: + """Type conversion function namespace. + + As the integer types are signless, signedness is implement by different cast + functions that treat integers as signed (`cast_signed`) or unsigned + (`cast_unsigned`) values. + + Examples: + - cast_signed(I32 -> I64) -> `arith.ExtSIOp` + - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` + """ + + cast_signed = TypeFnType("cast_signed") + cast_unsigned = TypeFnType("cast_unsigned") + + +class ReduceFnUse: + """Reduction function use. + + A reduction use specifies the reduction function and dimensions. + """ + + def __init__( + self, + binary_fn: Optional[BinaryFnType], + binary_attr: Optional["BinaryFnAttrDef"], + *reduce_dims: DimDef, + ): + if bool(binary_fn) + bool(binary_attr) != 1: + raise ValueError( + "One of 'binary_fn', 'binary_attr' must be specified") + self.binary_fn = binary_fn + self.binary_attr = binary_attr + self.reduce_dims = reduce_dims + + def __call__(self, *args: TensorExpression) -> "TensorReduceFn": + return TensorReduceFn(self, args) + + def __repr__(self): + fn = self.binary_fn if self.binary_fn else self.binary_attr + return f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})" + + +class ReduceFnType: + """Reduction function. + + A binary function that reduces its RHS into its LHS. + """ + + def __init__(self, binary_fn: BinaryFnType): + if not isinstance(binary_fn, BinaryFnType): + raise ValueError( + f"Reduce expected a BinaryFnType but got {binary_fn}") + self.binary_fn = binary_fn + + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(self.binary_fn, None, *reduce_dims) + + def __repr__(self): + return f"reduce_{repr(self.binary_fn)}" + + +class ReduceFn: + add = ReduceFnType(BinaryFn.add) + mul = ReduceFnType(BinaryFn.mul) + max_signed = ReduceFnType(BinaryFn.max_signed) + min_signed = ReduceFnType(BinaryFn.min_signed) + max_unsigned = ReduceFnType(BinaryFn.max_unsigned) + min_unsigned = ReduceFnType(BinaryFn.min_unsigned) + + +############################################################################### +# Operand definitions. 
+############################################################################### + + +class OperandKind(Enum): + INPUT_TENSOR = 0 + SCALAR = 1 + OUTPUT_TENSOR = 2 + INDEX_ATTR = 3 + UNARY_FN_ATTR = 4 + BINARY_FN_ATTR = 5 + TYPE_FN_ATTR = 6 + + +class OperandDef: + """Definition of an operand passed to an operation. + + Keep the meta information of Tensor, Scalar, and Attribute operands and + provide the shared registration functionality. + """ + + def __init__( + self, + kind: OperandKind, + type_var: Optional[TypeVar] = None, + size_exprs: Optional[Sequence[AffineExprDef]] = None, + index_dims: Optional[Sequence[DimDef]] = None, + default_indices: Optional[Sequence[int]] = None, + default_fn: Optional[str] = None, + ): + if type_var and not isinstance(type_var, TypeVar): + raise ValueError( + f"OperandDef requires a TypeVar but got {repr(type_var)}") + self.owner = None # type: Optional["LinalgOpDef"] + self.type_var = type_var + self.size_exprs = size_exprs + self.index_dims = index_dims + self.default_indices = default_indices + self.default_fn = default_fn + self.kind = kind + self.name = None # type: Optional[str] + self.registered_index = -1 # type: int + + def attach(self, index: int, name: str, owner: "LinalgOpDef"): + if self.owner: + raise ValueError( + f"OperandDef already registered with an op: {self}") + self.registered_index = index + self.name = name + self.owner = owner + + def is_input(self) -> bool: + return self.kind == OperandKind.SCALAR or self.kind == OperandKind.INPUT_TENSOR + + def is_tensor(self) -> bool: + return ( + self.kind == OperandKind.INPUT_TENSOR + or self.kind == OperandKind.OUTPUT_TENSOR + ) + + def is_attribute(self) -> bool: + return ( + self.kind == OperandKind.INDEX_ATTR + or self.kind == OperandKind.UNARY_FN_ATTR + or self.kind == OperandKind.BINARY_FN_ATTR + or self.kind == OperandKind.TYPE_FN_ATTR + ) + + def __hash__(self): + return hash(id(self)) + + def __repr__(self): + return ( + f"{self.name}:OperandDef(kind={self.kind.name}, " + f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, " + f"index_dims={self.index_dims}, " + f"default_indices={self.default_indices}, " + f"default_fn={self.default_fn})" + ) + + +class TensorDef: + """Tensor operand definition. + + Tensor operands are indexed using the associated indexing_map when forwarded + to the body of the structured op. A unique name identifies the tensor operands + and an index determines their position in the operation's parameter list. A + tensor definition takes type, a shape, and an optional flag to mark output + tensors. Additionally, a tuple of index dimensions may be used to map the + tensor to the loop dimensions of the operation. This mapping is needed to + compute the indexing map of shape-only tensors that have no uses. 
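+
+    A sketch of typical usage (the type variables `T1`, `T2`, and `U` are
+    assumed to come from the accompanying types module):
+
+      @linalg_structured_op
+      def matmul(A=TensorDef(T1, S.M, S.K),
+                 B=TensorDef(T2, S.K, S.N),
+                 C=TensorDef(U, S.M, S.N, output=True)):
+        domain(D.m, D.n, D.k)
+        C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
+            U, B[D.k, D.n])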
+ """ + + def __init__( + self, + type_var: TypeVar, + *shape: AffineExprDef, + index_dims: Optional[Sequence[DimDef]] = None, + output: bool = False, + ): + if index_dims and len(shape) != len(index_dims): + raise ValueError( + f"Expected the shape rank {len(shape)} to match the " + f"number of index_dims {len(index_dims)}" + ) + if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims): + raise ValueError( + f"TensorDef requires index dims of type DimDef but " f"got {index_dims}" + ) + kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR + self.operand_def = OperandDef( + kind, type_var=type_var, size_exprs=shape, index_dims=index_dims + ) + + def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse: + assert self.operand_def.owner, "TensorDef is not registered with an op" + state = AffineBuildState( + global_state=self.operand_def.owner._affine_state, allow_new_symbols=False + ) + if not isinstance(dims, tuple): + dims = (dims,) # Handle single subscript case. + # Special case: (None) is a 0d-scalar use. + if dims == (None,): + dims = () + + exprs = [] + for expr_def in dims: + if not isinstance(expr_def, AffineExprDef): + raise KeyError( + "A TensorDef can only be subscripted by a tuple of affine dims" + ) + exprs.append(expr_def) + return TensorUse(self.operand_def, exprs) + + def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression): + """Creates a new 1:1 comprehension by binding this tensor to an expression. + + Note that due to the way assignment works in Python, we have to capture + direct assignment as a setitem on the TensorDef. + """ + if not isinstance(value, TensorExpression): + raise ValueError( + f"Only TensorExpressions can be assigned to TensorDefs. " + f"Got: {repr(value)}" + ) + use = self[dims] + comp = Comprehension((use, value)) + self.operand_def.owner.comprehensions.append(comp) + + +class ScalarDef(TensorExpression): + """Scalar operand definition. + + Scalar operands are forwarded to the body of the structured op as they are. + A unique name identifies the scalars and an index determines their position in + the operation's parameter list. + """ + + def __init__(self, type_var: TypeVar): + self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var) + + @property + def scalar_name(self) -> str: + name = self.operand_def.name + assert name is not None, "ScalarDef not registered with an op" + return name + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarArg(self.scalar_name).expr() + + +class IndexAttrDef: + """Index attribute definition. + + Index attributes provide a way to define and set symbols that can be used in + indexing expressions. Every attribute specifies a tuple of symbols that at + compile-time are replaced by integer values as well as their default values. + """ + + def __init__(self, *sizes: SymbolDef, default: Sequence[int]): + if any(not isinstance(size, SymbolDef) for size in sizes): + raise ValueError( + f"IndexAttrDef requires sizes of type SymbolDef " f"but got {sizes}" + ) + if any(not isinstance(default_val, int) for default_val in default): + raise ValueError( + f"IndexAttrDef requires default values of type int " + f"but got {default}" + ) + if len(sizes) != len(default): + raise ValueError( + f"IndexAttrDef expects {len(sizes)} default values " + f"but got {len(default)}" + ) + self.operand_def = OperandDef( + OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default + ) + + +class UnaryFnAttrDef: + """Unary function attribute definition. 
+ + Unary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default unary function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "UnaryFnType"): + if not isinstance(default, UnaryFnType): + raise ValueError( + f"UnaryFnAttrDef requires default of type UnaryFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name + ) + + def __call__(self, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg]) + + +class BinaryFnAttrDef: + """Binary function attribute definition. + + Binary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default binary function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "BinaryFnType"): + if not isinstance(default, BinaryFnType): + raise ValueError( + f"BinaryFnAttrDef requires default of type BinaryFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name + ) + + def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, [arg0, arg1]) + + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(None, self, *reduce_dims) + + +class TypeFnAttrDef: + """Type conversion function attribute definition. + + Type conversion function attributes provide a way to make type conversions + parameterizable. Every attribute specifies a default type conversion function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "TypeFnType"): + if not isinstance(default, TypeFnType): + raise ValueError( + f"TypeFnAttrDef requires default of type TypeFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name + ) + + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg]) + + +############################################################################### +# Operation definition. +############################################################################### + + +class Comprehension: + """Represents a single comprehension.""" + + def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]): + self.definitions = list() # List[TensorUse] + self.values = list() # List[TensorExpression] + + # Find the lhs to reduction rhs. 
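+        # e.g. for `C[D.m] += A[D.m, D.k]` the bound value is a TensorReduceFn
+        # whose lhs is still unset; attach it to the assigned TensorUse here.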
+ for assign, value in bindings: + if isinstance(value, TensorReduceFn): + if value.lhs: + raise ValueError( + f"Reduction expression already assigns: {value}") + value.lhs = assign + self.definitions.append(assign) + self.values.append(value) + + @property + def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]: + """Gets the reduction dims for the comprehension or None.""" + result = set() + for use in self.values: + if isinstance(use, TensorReduceFn): + result.add(use.reduce_use.reduce_dims) + else: + result.add(tuple()) + return result + + def __repr__(self): + if len(self.definitions) > 1: + defs_repr = f"({', '.join(repr(d) for d in self.definitions)})" + values_repr = f"({', '.join(repr(v) for v in self.values)})" + else: + defs_repr = f"{repr(self.definitions[0])}" + values_repr = f"{repr(self.values[0])}" + + return f"{defs_repr} = {values_repr}" + + +class OpInterfaceDef: + """An interface that an op implements.""" + + def __init__(self, cpp_name: str): + self.cpp_name = cpp_name + + +ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") +ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface") +FillOpInterface = OpInterfaceDef("LinalgFillOpInterface") + + +class OpDefinitionDef: + """A method that an op implements.""" + + def __init__(self, def_name: str): + self.def_name = def_name + + +Canonicalizer = OpDefinitionDef("hasCanonicalizer") + + +class OpMetadataDef(YAMLObject): + """Metadata about the op (generally not behavior impacting).""" + + yaml_tag = "!LinalgOpMetadata" + + def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]): + self.name = name + self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name + self.doc = doc + self.implements = [] # type: List[OpInterfaceDef] + self.defines = [] # type: List[OpDefinitionsDef] + + def to_yaml_custom_dict(self): + d = dict( + name=self.name, + cpp_class_name=self.cpp_class_name, + doc=self.doc, + ) + if self.implements: + d["implements"] = [intr.cpp_name for intr in self.implements] + if self.defines: + d["defines"] = [defi.def_name for defi in self.defines] + return d + + +class LinalgOpDef: + """Definition of a linalg op.""" + + def __init__( + self, name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None + ): + self.metadata = OpMetadataDef( + name=name, cpp_class_name=cpp_class_name, doc=doc) + self.registered_operands = dict() # type: Dict[str, OperandDef] + self.domain = list() # type: List[DimDef] + self.comprehensions = list() # type: List[Comprehension] + self._affine_state = AffineBuildState() + + def add_operand(self, name: str, operand: OperandDef): + """Registers an operand.""" + if name in self.registered_operands: + raise ValueError( + f"The operand {name} is already registered " + f"to {self.registered_operands['name']}" + ) + structured_op_methods = [ + "inputs", + "outputs", + "result_tensors", + "region", + "iterator_types", + "indexing_maps", + "getRegionBuilder", + "getLibraryCallName", + ] + if operand.is_attribute() and name in structured_op_methods: + raise ValueError( + f"The attribute name {name} conflicts with a structured " + f"op method name" + ) + # Ensure output tensors are registered after input tensors and scalars and + # attributes are registered after all other operand types. 
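+        # i.e. the required registration order is inputs/scalars first, then
+        # output tensors, then attributes.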
+ if operand.is_input() and any( + not op_def.is_input() for op_def in self.registered_operands.values() + ): + raise ValueError( + f"Input {name} registered after an output or attribute") + if operand.kind == OperandKind.OUTPUT_TENSOR and any( + op_def.is_attribute() for op_def in self.registered_operands.values() + ): + raise ValueError(f"Output {name} registered after an attribute") + operand.attach(len(self.registered_operands), name, self) + self.registered_operands[name] = operand + + def __repr__(self): + lines = [ + f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"] + for name, operand in self.registered_operands.items(): + lines.append(f" {operand}") + if self.comprehensions: + lines[-1] += " {" + for comprehension in self.comprehensions: + lines.append(f" {comprehension}") + lines.append("}") + return "\n".join(lines) diff --git a/python/scalehls/opdsl/lang/config.py b/python/scalehls/opdsl/lang/config.py new file mode 100644 index 00000000..4c3ae59d --- /dev/null +++ b/python/scalehls/opdsl/lang/config.py @@ -0,0 +1,492 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Represents configured ops as emitted for code generation. + +Classes in this module generally are directly serializable to YAML for use +by the code generator. + +TODO: These should just be dumb containers or serialization code but they +currently encode too many details of how the language is interpreted. Move this +to helpers on the comprehension objects themselves. +""" + +from typing import Dict, Optional + +from ... import ir as _ir +from .comprehension import * +from .yaml_helper import * + +__all__ = ["LinalgStructuredOpConfig", "LinalgOpConfig", "OperandDefConfig"] + + +def _serialize_affine_map(affine_map: _ir.AffineMap) -> str: + with affine_map.context: + # Affine map printing/parsing is via an AffineMap attr. 
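+        # e.g. a projection map round-trips as
+        # "affine_map<(d0, d1, d2) -> (d0, d2)>".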
+ attr = _ir.AffineMapAttr.get(affine_map) + return str(attr) + + +class TensorUseConfig: + """Wrapper around a TensorUse with additional context-bound state.""" + + def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap): + self.tensor_use = tensor_use + self.indexing_map = indexing_map + + def __repr__(self): + return f"Use({self.tensor_use}, indexing_map={self.indexing_map})" + + +class OperandDefConfig(YAMLObject): + """Wrapper containing an operand definition with additional state.""" + + yaml_tag = "!LinalgOperandDefConfig" + + def __init__( + self, + operand_def: OperandDef, + shape_map: Optional[_ir.AffineMap] = None, + index_attr_map: Optional[_ir.AffineMap] = None, + ): + self.operand_def = operand_def + self.shape_map = shape_map # type: Optional[_ir.AffineMap] + self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap] + self.indexing_map = None # type: Optional[_ir.AffineMap] + + @property + def name(self) -> str: + return self.operand_def.name + + @property + def kind(self) -> OperandKind: + return self.operand_def.kind + + @property + def type_var(self) -> TypeVar: + return self.operand_def.type_var + + def to_yaml_custom_dict(self): + self_dict = dict( + name=self.name, kind=self.operand_def.kind.name.lower()) + if self.type_var: + self_dict["type_var"] = self.type_var.name + if self.shape_map: + self_dict["shape_map"] = _serialize_affine_map(self.shape_map) + if self.index_attr_map: + self_dict["index_attr_map"] = _serialize_affine_map( + self.index_attr_map) + if self.operand_def.default_indices: + self_dict["default_indices"] = self.operand_def.default_indices + if self.operand_def.default_fn: + self_dict["default_fn"] = self.operand_def.default_fn + return self_dict + + def __repr__(self): + return ( + f"OperandDefConfig({self.operand_def}, " + f"shape_map={self.shape_map}, " + f"index_attr_map={self.index_attr_map}, " + f"indexing_map={self.indexing_map})" + ) + + +class LinalgIndexingMapsConfig(YAMLObject): + """Abstracts the style of indexing maps that the op exports. + + Presently only static (tied to the op name) indexing maps are supported. In + the future, it is expected that we will have additional variants: + - Dynamic based on attributes + - Dynamic based on operands + Each is expected to require a different variant of specification. 
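+
+    For the static variant, the emitted YAML lists one serialized map per
+    operand, for example (matmul-like maps, shown for illustration):
+
+      static_indexing_maps:
+      - affine_map<(d0, d1, d2) -> (d0, d2)>
+      - affine_map<(d0, d1, d2) -> (d2, d1)>
+      - affine_map<(d0, d1, d2) -> (d0, d1)>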
+ """ + + yaml_tag = "!LinalgIndexingMapsConfig" + + def __init__(self, static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None): + self.static_indexing_maps = static_indexing_maps + + def to_yaml_custom_dict(self): + if self.static_indexing_maps is not None: + return dict( + static_indexing_maps=[ + _serialize_affine_map(m) for m in self.static_indexing_maps + ] + ) + raise ValueError( + f"LinalgIndexingMapsConfig must have one type of indexing map" f"(got none)" + ) + + +class LinalgStructuredOpConfig(YAMLObject): + """Configuration for metadata sufficient to construct a linalg named op.""" + + yaml_tag = "!LinalgStructuredOpConfig" + + def __init__( + self, + comprehension: Comprehension, + domain: Sequence[DimDef], + registered_operands: Sequence[OperandDef], + context: Optional[_ir.Context] = None, + ): + self.context = context if context is not None else _ir.Context() + self.affine_state = AffineBuildState() + self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]] + self.operands = dict() # type: Dict[OperandDef, OperandDefConfig] + self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] + + # Compute the ordered set of writes and collect the tensor, capture, dims, + # and index uses. + collected_tensor_uses = set() + collected_scalar_uses = set() + collected_dim_uses = set() + collected_indices = set() + for write_use, read_use in zip(comprehension.definitions, comprehension.values): + self.writes.append((write_use, read_use)) + + for write_use, read_use in self.writes: + collected_tensor_uses.add(write_use) + read_use.collect_tensor_uses(collected_tensor_uses) + read_use.collect_scalar_uses(collected_scalar_uses) + read_use.collect_dim_uses(collected_dim_uses) + write_use.collect_dim_uses(collected_dim_uses) + read_use.collect_indices(collected_indices) + + # Set domain to the sorted list of uses if no domain annotation is given. + if not domain: + domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname) + + # Verify the domain dimensions match the used dimensions. + if len(domain) != len(collected_dim_uses) or any( + dim not in collected_dim_uses for dim in domain + ): + raise ValueError( + f"Expected the annotated domain dimensions {domain} to " + f"match the set of dimension used by the tensor " + f"comprehension {collected_dim_uses}" + ) + + # Instantiate the dimensions in the given order. + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + for dim in domain: + dim.build(state=local_state) + + # Collect all attribute definitions. + collected_attr_defs = list() + for operand in registered_operands: + if operand.is_attribute(): + collected_attr_defs.append(operand) + + # Collect all tensors with manual indexing annotation. + collected_index_defs = list() + for operand in registered_operands: + if operand.index_dims: + if any(dim not in collected_dim_uses for dim in operand.index_dims): + raise ValueError( + f"Expected all index dims {operand.index_dims} of " + f"operand {operand.name} to have uses." + ) + collected_index_defs.append(operand) + + # Collect the operand definitions of all tensor/scalar uses, attributes, and + # shape-only tensors. 
+ all_operand_defs = list() + for use in collected_tensor_uses: + all_operand_defs.append(use.operand_def) + for use in collected_scalar_uses: + all_operand_defs.append(use.operand_def) + for definition in collected_attr_defs: + all_operand_defs.append(definition) + for definition in collected_index_defs: + all_operand_defs.append(definition) + + # Add all operands in registration order to ensure the symbols are + # registered in the order they appear. + all_operand_defs = sorted( + all_operand_defs, key=lambda operand_def: operand_def.registered_index + ) + for operand_def in all_operand_defs: + self.add_operand(operand_def) + + # Add all shape-only tensor index_dim annotations and all tensor uses. + for definition in collected_index_defs: + self.add_indexed_operand(definition) + for use in collected_tensor_uses: + self.add_tensor_use(use) + + # Normalize all shape and indexing maps now that full count of dims and + # symbols are known. + for cuse in self.uses.values(): + cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) + for definition in collected_index_defs: + self.operands[definition].indexing_map = self._normalize_affine_map( + self.operands[definition].indexing_map + ) + for operand_config in self.operands.values(): + if operand_config.shape_map: + operand_config.shape_map = self._normalize_affine_map( + operand_config.shape_map, with_dims=False + ) + if operand_config.index_attr_map: + operand_config.index_attr_map = self._normalize_affine_map( + operand_config.index_attr_map, with_dims=False + ) + + # Now for each write use, propagate the indexing maps from the use to the + # tensor, ensuring that there are not conflicts. + for write_use, _ in self.writes: + write_tensor_config = self.operands[write_use.operand_def] + if write_tensor_config.indexing_map: + raise ValueError( + f"Unexpected multi-write to a single tensor: {write_tensor_config}" + ) + write_tensor_config.indexing_map = self.uses[write_use].indexing_map + + # For each read use, propagate the indexing maps from the use to the + # tensor, ensuring that there are not conflicts. + for _, read_expr in self.writes: + read_uses = set() # type: Set[TensorUse] + read_expr.collect_tensor_uses(read_uses) + for read_use in read_uses: + read_operand_config = self.operands[read_use.operand_def] + if ( + read_operand_config.indexing_map + and read_operand_config.indexing_map + != self.uses[read_use].indexing_map + ): + raise ValueError( + f"Unexpected multi-read of a tensor with different accesses:" + f"{read_operand_config} vs {read_use}" + ) + read_operand_config.indexing_map = self.uses[read_use].indexing_map + + # Set the indexing map of all scalar uses to the empty map. + for operand_config in self.operands.values(): + if operand_config.operand_def.kind == OperandKind.SCALAR: + operand_config.indexing_map = self._get_scalar_map() + + # Check all registered tensor and scalar operands have an indexing map. + for operand in registered_operands: + if operand.is_attribute(): + continue + if not (operand in self.operands and self.operands[operand].indexing_map): + raise ValueError( + f"Failed to compute an indexing map for operand " f"{operand.name}" + ) + + # Collect reduction dims and ensure all the same. + all_reduction_dims = set(comprehension.all_reduction_dims) + if len(all_reduction_dims) != 1: + raise ValueError( + f"All writes within a generic must have the same reduction " + f"dims. 
Got: {all_reduction_dims}" + ) + self.reduction_dims = next(iter(all_reduction_dims)) + + # Check the index dimension exists and resolve. + for index in collected_indices: + if index.dim_def.dimname not in self.affine_state.all_dims: + raise ValueError( + f"The dimension {index.dim_def.dimname} is not part of the " + f"iteration domain {self.affine_state.all_dims}" + ) + index.resolve_dimension_name(self.affine_state) + + # Generate the scalar assignments (used to build a body). + self.assignments = [ + ScalarAssign(write_use.tensor_name, + read_expr.to_scalar_expression()) + for write_use, read_expr in self.writes + ] + + @property + def ordered_operands(self) -> Sequence[OperandDefConfig]: + return sorted( + self.operands.values(), + key=lambda operand: operand.operand_def.registered_index, + ) + + @property + def ordered_dims(self) -> Sequence[Tuple[str, int]]: + """Gets the ordered list of dim bindings (symbolic name, position). + + TODO: The original parser relies on parse ordering to arrive at the + iterator types, but that ordering is not defined on the Python side, so + this may be ambiguous. + """ + return list(self.affine_state.all_dims.items()) + + @property + def indexing_maps(self) -> Sequence[_ir.AffineMap]: + return [o.indexing_map for o in self.ordered_operands if o.indexing_map] + + @property + def iterator_types(self) -> Sequence[str]: + def get_type(symbolic_name, position): + for reduction_dim_expr in self.reduction_dims: + if reduction_dim_expr.dimname == symbolic_name: + return "reduction" + return "parallel" + + return [get_type(*dim) for dim in self.ordered_dims] + + def add_operand(self, operand_def: OperandDef): + if operand_def in self.operands: + return + if not (operand_def.is_tensor() or operand_def.kind == OperandKind.INDEX_ATTR): + self.operands[operand_def] = OperandDefConfig(operand_def) + return + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_dims=False + ) + exprs = [] + for expr in operand_def.size_exprs: + exprs.append(expr.build(state=local_state)) + assert local_state.local_dim_count == 0 + affine_map = _ir.AffineMap.get( + dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs + ) + if operand_def.kind == OperandKind.INDEX_ATTR: + self.operands[operand_def] = OperandDefConfig( + operand_def, index_attr_map=affine_map + ) + else: + self.operands[operand_def] = OperandDefConfig( + operand_def, shape_map=affine_map + ) + + def add_indexed_operand(self, operand_def: OperandDef): + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + exprs = [] + for expr in operand_def.index_dims: + exprs.append(expr.build(state=local_state)) + self.operands[operand_def].indexing_map = _ir.AffineMap.get( + dim_count=local_state.dim_count, + symbol_count=local_state.symbol_count, + exprs=exprs, + ) + + def add_tensor_use(self, tensor_use: TensorUse): + if tensor_use in self.uses: + return + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + exprs = [] + for expr in tensor_use.indices: + exprs.append(expr.build(state=local_state)) + indexing_map = _ir.AffineMap.get( + dim_count=local_state.dim_count, + symbol_count=local_state.symbol_count, + exprs=exprs, + ) + + use_config = TensorUseConfig(tensor_use, indexing_map) + self.uses[tensor_use] = use_config + + def _get_scalar_map(self) -> _ir.AffineMap: + """Create an empty affine map used to index a scalar.""" + with self.context: + 
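+            # Scalars are indexed by a result-less map over the full domain,
+            # e.g. (d0, d1) -> ().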
return _ir.AffineMap.get( + dim_count=self.affine_state.dim_count, + symbol_count=self.affine_state.symbol_count, + exprs=list(), + ) + + def _normalize_affine_map( + self, affine_map: _ir.AffineMap, with_dims: bool = True + ) -> _ir.AffineMap: + """Normalizes an indexing map to have the max known symbols and dims.""" + with self.context: + return _ir.AffineMap.get( + dim_count=self.affine_state.dim_count if with_dims else 0, + symbol_count=self.affine_state.symbol_count, + exprs=list(affine_map.results), + ) + + def to_yaml_custom_dict(self): + self_dict = dict(args=self.ordered_operands) + # TODO: Refactor the hierarchy internally when supporting more + # than static (preserving this serialized form). + self_dict["indexing_maps"] = LinalgIndexingMapsConfig( + static_indexing_maps=self.indexing_maps + ) + self_dict["iterator_types"] = self.iterator_types + self_dict["assignments"] = self.assignments + return self_dict + + def __repr__(self): + lines = [ + f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"] + lines.append("operands=[") + for def_config in self.ordered_operands: + lines.append(f" {repr(def_config)}") + lines.append("], indexing_maps=[") + for m in self.indexing_maps: + lines.append(f" {repr(m)}") + lines.append(f"], iterator_types=[") + for t in self.iterator_types: + lines.append(f" {t}") + lines.append("])") + return "\n".join(lines) + + +class LinalgOpConfig(YAMLObject): + """Container for any supported linalg op type. + + This includes the concrete type by name for ease of parsing by systems + that ignore tags. + """ + + yaml_tag = "!LinalgOpConfig" + + def __init__( + self, + metadata: OpMetadataDef, + *, + structured_op: Optional[LinalgStructuredOpConfig] = None, + ): + self.metadata = metadata + self.structured_op = structured_op + + def to_yaml_custom_dict(self): + self_dict = dict( + metadata=self.metadata, + ) + if self.structured_op: + self_dict["structured_op"] = self.structured_op + return self_dict + + @staticmethod + def from_linalg_op_def( + op_def: LinalgOpDef, context: Optional[_ir.Context] = None + ) -> Sequence["LinalgOpConfig"]: + """Expands a LinalgOpDef into corresponding Linalg configured ops.""" + # TODO: Many LinalgOpDef patterns need to expand to multiple generics. + assert len(op_def.comprehensions) == 1, "Only one comprehension supported" + return [ + LinalgOpConfig( + op_def.metadata, + structured_op=LinalgStructuredOpConfig( + op_def.comprehensions[0], + op_def.domain, + op_def.registered_operands.values(), + context, + ), + ), + ] + + def __repr__(self): + return ( + f"LinalgOpConfig(metadata={self.metadata},\n" + f"structured_op={self.structured_op})" + ) diff --git a/python/scalehls/opdsl/lang/dsl.py b/python/scalehls/opdsl/lang/dsl.py new file mode 100644 index 00000000..c23dc637 --- /dev/null +++ b/python/scalehls/opdsl/lang/dsl.py @@ -0,0 +1,203 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Dict, List, Sequence, Union + +from contextlib import contextmanager +import functools +import inspect +import threading + +from ... 
import ir +from ...dialects._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) +from .comprehension import * +from .config import * +from .emitter import * + +_CONTEXT = threading.local() + +StructuredOpOuts = Union[ + ir.Operation, + ir.OpView, + ir.OpResultList, + Sequence[Union[ir.Value, ir.Operation, ir.OpView]], +] + + +@contextmanager +def bind_op_def(op_def: LinalgOpDef): + if hasattr(_CONTEXT, "current_op_def"): + raise ValueError("Cannot recursively define an operation") + _CONTEXT.current_op_def = op_def + try: + yield op_def + finally: + del _CONTEXT.current_op_def + + +def current_op_def() -> LinalgOpDef: + try: + return _CONTEXT.current_op_def + except AttributeError: + raise ValueError( + "Attempt to access the current op definition being defined " + "but none is set. Did you mean to call this in an op definition?" + ) + + +def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList: + if isinstance(outs, (ir.Operation, ir.OpView)): + return _get_op_results_or_values(outs) + elif isinstance(outs, ir.OpResultList): + return outs + + return [_get_op_result_or_value(o) for o in outs] + + +class DefinedOpCallable: + """Callable that wraps any defined op function.""" + + def __init__(self, op_name: str, op_def: LinalgOpDef): + self.op_name = op_name + self.op_def = op_def + + def __call__( + self, + *ins: Union[ir.Operation, ir.OpView, ir.Value], + outs: StructuredOpOuts, + **kwargs, + ): + """Emits the corresponding op definition as IR. + + Most arguments are passed through to the underlying emitter. The following + keyword argument is interpreted here: + emit_generic: Emits a generic form as appropriate (default True). If + False, a named form is emitted (which must have been built in to the + compiler). + """ + emit_generic = kwargs.pop("emit_generic", False) + if not isinstance(emit_generic, bool): + raise ValueError( + f"The named argument 'emit_generic' needs to be " + f" of type bool but got {type(emit_generic)}" + ) + + op_configs = LinalgOpConfig.from_linalg_op_def( + self.op_def, context=ir.Context.current + ) + + if len(op_configs) != 1: + # TODO: Support composite ops. + raise NotImplementedError( + f"Emission of composite linalg ops not supported: {op_configs}" + ) + + ctx = ir.Context.current + linalgDialect = ctx.get_dialect_descriptor("linalg") + fully_qualified_name = "linalg." + self.op_name + emit_generic = emit_generic or not ctx.is_registered_operation( + fully_qualified_name + ) + + op_config = op_configs[0] + out_values = _prepare_structured_op_outs(outs) + in_values = [_get_op_result_or_value(i) for i in ins] + if op_config.structured_op: + if emit_generic: + return emit_generic_structured_op( + op_config.structured_op, *in_values, outs=out_values, **kwargs + ) + else: + return emit_named_structured_op( + op_config.structured_op, + self.op_name, + self.op_def.metadata.cpp_class_name, + *in_values, + outs=out_values, + **kwargs, + ) + + raise NotImplementedError( + f"Emission of linalg op type not supported: {op_config}" + ) + + +def linalg_structured_op( + dsl_func=None, *, op_name=None, op_class_name=None +) -> DefinedOpCallable: + if dsl_func is None: + # Curry the keyword args in for delayed application. + return functools.partial( + linalg_structured_op, op_name=op_name, op_class_name=op_class_name + ) + # Determine default names by introspecting the function. + if op_name is None: + op_name = dsl_func.__name__ + if op_class_name is None: + # Camel case it. 
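+        # e.g. "matmul" -> "MatmulOp", "conv_2d_nchw" -> "Conv2DNchwOp".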
+ op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" + + op_def = LinalgOpDef( + name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc( + dsl_func) + ) + + # Extract arguments and TensorDefs from the signature. + dsl_func_args = list() + sig = inspect.signature(dsl_func) + for param_name, param in sig.parameters.items(): + param_default = param.default + if isinstance( + param_default, + ( + TensorDef, + ScalarDef, + IndexAttrDef, + UnaryFnAttrDef, + BinaryFnAttrDef, + TypeFnAttrDef, + ), + ): + op_def.add_operand(param_name, param_default.operand_def) + else: + raise ValueError( + f"@linalg_structured_op function parameters must be defaulted as " + f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): " + f"Found {param_name}: {param_default}" + ) + dsl_func_args.append(param_default) + + # Invoke the DSL func to finish populating the op definition. + with bind_op_def(op_def): + dsl_func(*dsl_func_args) + + # TODO: The returned callable should be an IR emitter but that is not + # upstreamed yet. + return DefinedOpCallable(op_name, op_def) + + +def domain(*dimensions: DimDef): + if any(not isinstance(d, DimDef) for d in dimensions): + raise ValueError( + f"Expected dimensions of type DimDef but got {dimensions}") + current_op_def().domain.extend(dimensions) + + +def implements(*interfaces: OpInterfaceDef): + if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces): + raise ValueError( + f"Expected interfaces of type OpInterfaceDef but got {interfaces}" + ) + current_op_def().metadata.implements.extend(interfaces) + + +def defines(*definitions: OpDefinitionDef): + if any(not isinstance(defi, OpDefinitionDef) for defi in definitions): + raise ValueError( + f"Expected definitions of type OpDefinitionDef but got {definitions}" + ) + current_op_def().metadata.defines.extend(definitions) diff --git a/python/scalehls/opdsl/lang/emitter.py b/python/scalehls/opdsl/lang/emitter.py new file mode 100644 index 00000000..390cdff0 --- /dev/null +++ b/python/scalehls/opdsl/lang/emitter.py @@ -0,0 +1,661 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Callable, Dict, List, Sequence, Tuple, Union + +from ...ir import * + +from ...dialects import func +from ...dialects import linalg +from ...dialects import math +from ...dialects import arith +from ...dialects import complex +from ...dialects._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) + +from .scalar_expr import * +from .config import * +from .comprehension import * +import numpy as np + +__all__ = [ + "emit_generic_structured_op", + "emit_named_structured_op", + "ValueList", +] + +# Type aliases. 
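+# A ValueList is either a plain sequence of `Value`s or an op's `OpResultList`,
+# mirroring the forms accepted for `outs` by the builder wrappers in dsl.py
+# (see `_prepare_structured_op_outs`).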
+ValueList = Union[Sequence[Value], OpResultList] + + +def isa(cls: Type, ty: Type): + try: + cls(ty) + return True + except ValueError: + return False + + +def prepare_common_structured_op( + op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: ValueList, + **attrs: Union[Sequence[int], TypeFnType], +): + all_arg_defs = op_config.ordered_operands + in_arg_defs = [ + d + for d in all_arg_defs + if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR] + ] + out_arg_defs = [d for d in all_arg_defs if d.kind == + OperandKind.OUTPUT_TENSOR] + index_attr_arg_defs = [ + d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR] + fn_attr_arg_defs = [ + d + for d in all_arg_defs + if d.kind + in [ + OperandKind.UNARY_FN_ATTR, + OperandKind.BINARY_FN_ATTR, + OperandKind.TYPE_FN_ATTR, + ] + ] + + # Verify outs is a sequence or a list of results. + if not isinstance(outs, (Sequence, OpResultList)): + raise ValueError( + f"Expected named argument outs to have type Sequence or " + f"OpResultLis but got {type(outs)}" + ) + + # Arity validation. + if len(ins) != len(in_arg_defs): + raise ValueError( + f"Expected {len(in_arg_defs)} inputs but got " f"{len(ins)} for {op_config}" + ) + if outs and len(outs) != len(out_arg_defs): + raise ValueError( + f"Expected {len(out_arg_defs)} outputs but got " + f"{len(outs)} for {op_config}" + ) + + # Compute a replacement list for all index attribute symbols. + expressions = [] # type: Sequence[AffineExpr] + replacements = [] # type: Sequence[AffineExpr] + for index_attr in index_attr_arg_defs: + index_attr_vals = index_attr.operand_def.default_indices + if index_attr.name in attrs: + index_attr_vals = attrs.get(index_attr.name) + assert index_attr_vals, "Index attribute has no value" + if not all(isinstance(value, int) for value in index_attr_vals): + raise ValueError( + f"Attribute {index_attr.name} needs to be of type " + f"Sequence[int] but got {type(index_attr_vals)}" + ) + results = index_attr.index_attr_map.results # type: AffineExprList + if len(index_attr_vals) != len(results): + raise ValueError( + f"Attribute {index_attr.name} has length {len(results)} " + f"but got {len(index_attr_vals)} values" + ) + for expr, value in zip(results, index_attr_vals): + expressions.append(expr) + replacements.append(AffineConstantExpr.get(value)) + + # Replace all index attribute symbols by their value. + # TODO: Add support for shape symbols. + indexing_maps = [] # type: Sequence[AffineMap] + for curr in op_config.indexing_maps: + for expression, replacement in zip(expressions, replacements): + curr = curr.replace(expression, replacement, + curr.n_dims, curr.n_symbols) + indexing_maps.append(curr) + + # TODO: Linalg verification does not currently allow symbols. + # Compress them for now and verify none are left. + indexing_maps = AffineMap.compress_unused_symbols( + indexing_maps, Context.current) + if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps): + raise ValueError( + f"Expected indexing_maps to use no symbols after " + f"replacement and compression but got {indexing_maps}" + ) + + outs, out_types = _infer_structured_outs( + op_config, in_arg_defs, ins, out_arg_defs, outs + ) + + result_types = [t for t in out_types if isa(RankedTensorType, t)] + + # Initialize the type dictionary with the predefined types. 
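+    # Other type variables (e.g. T1, T2, U) are bound further below from the
+    # element types of the actual operands; illustratively, a tensor<4x8xf16>
+    # passed for a TensorDef(T1, ...) operand binds T1 -> f16.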
+ type_mapping = dict() # type: Dict[str, Type] + type_mapping["F32"] = F32Type.get() + type_mapping["F64"] = F64Type.get() + type_mapping["I32"] = IntegerType.get_signless(32) + type_mapping["I64"] = IntegerType.get_signless(64) + + # Extract type vars for input/output based types. + block_arg_types = list() # type: List[Type] + for arg_def, arg_element_type in zip( + in_arg_defs + out_arg_defs, _get_types_from_values(*ins, *outs) + ): + _add_type_mapping(arg_def, arg_element_type, + type_mapping, block_arg_types) + + # Emit the generic op. + # TODO: Support emission of pure memref form. + indexing_maps_attr = ArrayAttr.get( + [AffineMapAttr.get(am) for am in indexing_maps]) + iterator_types_attr = ArrayAttr.get( + [ + Attribute.parse(f"#linalg.iterator_type<{s}>") + for s in op_config.iterator_types + ] + ) + + # Compute the index attributes used when emitting a named structured op. + index_attrs = {} # type: Dict[str, DenseElementAttr] + for index_attr in index_attr_arg_defs: + index_attr_vals = attrs.get(index_attr.name) + # Only forward attributes set to a non-default value. + if index_attr_vals: + array = np.array(index_attr_vals, dtype=np.int64) + index_attrs[index_attr.name] = DenseElementsAttr.get(array) + + # Compute the function attribute mapping. + fn_attr_mapping = {} + for fn_attr in fn_attr_arg_defs: + attr_val = fn_attr.operand_def.default_fn + attr_kind = fn_attr.kind + if fn_attr.name in attrs: + fn = attrs.get(fn_attr.name) + if attr_kind == OperandKind.UNARY_FN_ATTR: + if not isinstance(fn, UnaryFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"UnaryFnType but got {type(attr_val)}" + ) + elif attr_kind == OperandKind.BINARY_FN_ATTR: + if not isinstance(fn, BinaryFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"BinaryFnType but got {type(attr_val)}" + ) + else: + if not isinstance(fn, TypeFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"TypeFnType but got {type(attr_val)}" + ) + attr_val = fn.fn_name + assert attr_val, "Function attribute has no value" + fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind) + + return ( + all_arg_defs, + in_arg_defs, + out_arg_defs, + outs, + result_types, + type_mapping, + indexing_maps_attr, + iterator_types_attr, + index_attrs, + fn_attr_mapping, + block_arg_types, + ) + + +def emit_generic_structured_op( + op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: ValueList, + **attrs: Sequence[int], +): + ( + all_arg_defs, + in_arg_defs, + out_arg_defs, + outs, + result_types, + type_mapping, + indexing_maps_attr, + iterator_types_attr, + index_attrs, + fn_attr_mapping, + block_arg_types, + ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs) + + # An operation that accesses only scalars and scalar/rank zero tensors is + # rank polymorhpic. We implement rank polymorphism by generating different + # indexing maps and iterators that match the rank of the first output tensor. + # An operation is rank polymorphic if the iteration domain has rank zero. 
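+    # Illustrative example: `elemwise_unary` has a rank-zero iteration domain,
+    # so applying it to a rank-2 output tensor yields two parallel iterators,
+    # identity indexing maps for the rank-2 operands, and a constant rank-0
+    # map for scalar (or rank-0 tensor) operands.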
+ if not iterator_types_attr: + rank = ShapedType(outs[0].type).rank + iterator_types_attr = ArrayAttr.get( + [Attribute.parse("#linalg.iterator_type")] * rank + ) + scalar_map = AffineMap.get(rank, 0, []) + tensor_map = AffineMap.get_identity(rank) + indexing_maps = [] + for arg_def in all_arg_defs: + if arg_def.operand_def.kind == OperandKind.SCALAR: + indexing_maps.append(scalar_map) + if arg_def.operand_def.is_tensor(): + idx = arg_def.operand_def.registered_index + if idx < len(ins) and ShapedType(ins[idx].type).rank == 0: + indexing_maps.append(scalar_map) + else: + indexing_maps.append(tensor_map) + indexing_maps_attr = ArrayAttr.get( + [AffineMapAttr.get(am) for am in indexing_maps] + ) + + generic_op = linalg.GenericOp( + result_tensors=result_types, + inputs=ins, + outputs=outs, + indexing_maps=indexing_maps_attr, + iterator_types=iterator_types_attr, + doc=None, # TODO: Make optional. + library_call=None, + ) # TODO: Make optional. + + # Construct the body. + block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs) + block = generic_op.regions[0].blocks.append(*block_arg_types) + block_arg_mapping = dict(zip(block_arg_names, block.arguments)) + with InsertionPoint(block): + body_builder = _BodyBuilder( + type_mapping, block_arg_mapping, fn_attr_mapping) + for assignment in op_config.assignments: + body_builder.assign(assignment) + body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) + + if len(result_types) == 1: + return generic_op.result + else: + return generic_op.results + + +def emit_named_structured_op( + op_config: LinalgStructuredOpConfig, + op_name: str, + op_class_name: str, + *ins: Value, + outs: ValueList, + **attrs: Sequence[int], +): + ( + all_arg_defs, + in_arg_defs, + out_arg_defs, + outs, + result_types, + type_mapping, + indexing_maps_attr, + iterator_types_attr, + index_attrs, + fn_attr_mapping, + block_arg_types, + ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs) + + # If we get here, there must exist a builtin class `op_class_name`. + ctx = Context.current + fully_qualified_name = "linalg." + op_name + if ( + not ctx.is_registered_operation(fully_qualified_name) + or not op_class_name in linalg.__dict__.keys() + ): + raise NotImplementedError( + f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}" + ) + + # Set the index attributes used to compute the indexing maps. + named_op = getattr(linalg, op_class_name)(ins, outs, result_types) + for name, value in index_attrs.items(): + named_op.operation.attributes[name] = value + + # Compute the function attributes by combining operand kind and function name. 
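+    # Illustrative example: passing `fun=BinaryFn.mul` for a BinaryFnAttrDef
+    # records ("mul", OperandKind.BINARY_FN_ATTR) in fn_attr_mapping, which
+    # the loop below turns into the attribute #linalg.binary_fn<mul>.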
+ for name, (fn_name, kind) in fn_attr_mapping.items(): + assert kind.name.lower().endswith("_attr") + enum_name = kind.name.lower()[:-5] + named_op.operation.attributes[name] = Attribute.parse( + f"#linalg.{enum_name}<{fn_name}>" + ) + + linalg.fill_builtin_region(named_op.operation) + + if len(result_types) == 1: + return named_op.result + else: + return named_op.results + + +class _BodyBuilder: + """Constructs a structured op body by evaluating assignments.""" + + def __init__( + self, + type_mapping: Dict[str, Type], + block_arg_mapping: Dict[str, Value], + fn_attr_mapping: Dict[str, str], + ): + self.type_mapping = type_mapping + self.block_arg_mapping = block_arg_mapping + self.fn_attr_mapping = fn_attr_mapping + self.yield_mapping = dict() # type: Dict[str, Value] + + def assign(self, assignment: ScalarAssign): + if assignment.arg in self.yield_mapping: + raise ValueError( + f"Multiple assignments to the same argument are forbidden: " + f"{assignment}" + ) + self.yield_mapping[assignment.arg] = self.expression(assignment.value) + + def expression(self, expr: ScalarExpression) -> Value: + if expr.scalar_arg: + try: + return self.block_arg_mapping[expr.scalar_arg.arg] + except KeyError: + raise ValueError( + f"Argument {expr.scalar_arg.arg} is not bound for " + f"this structured op." + ) + elif expr.scalar_const: + value_attr = Attribute.parse(expr.scalar_const.value) + return arith.ConstantOp(value_attr.type, value_attr).result + elif expr.scalar_index: + dim_attr = IntegerAttr.get( + IntegerType.get_signless(64), expr.scalar_index.dim + ) + return linalg.IndexOp(dim_attr).result + elif expr.scalar_fn: + kind = expr.scalar_fn.kind.name.lower() + fn_name = expr.scalar_fn.fn_name + if expr.scalar_fn.attr_name: + fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name] + fn = self._get_function(f"_{kind}_{fn_name}") + operand_values = [ + self.expression(operand) for operand in expr.scalar_fn.operands + ] + if expr.scalar_fn.kind == FunctionKind.TYPE: + operand_values = [ + expr.scalar_fn.type_var.name] + operand_values + return fn(*operand_values) + raise NotImplementedError( + f"Unimplemented scalar body expression: {expr}") + + def yield_outputs(self, *output_names: str): + output_values = [] + for n in output_names: + try: + output_values.append(self.yield_mapping[n]) + except KeyError: + raise ValueError( + f"Body assignments do not assign all outputs: " f"missing '{n}'" + ) + linalg.YieldOp(output_values) + + def _get_function(self, fn_name: str) -> Callable: + try: + fn = getattr(self, f"{fn_name}") + except AttributeError: + raise ValueError(f"Function '{fn_name}' is not a known function") + return fn + + def _cast( + self, type_var_name: str, operand: Value, is_unsigned_cast: bool = False + ) -> Value: + try: + to_type = self.type_mapping[type_var_name] + except KeyError: + raise ValueError( + f"Unbound type variable '{type_var_name}' (" + f"expected one of {self.type_mapping.keys()}" + ) + if operand.type == to_type: + return operand + if _is_integer_type(to_type): + return self._cast_to_integer(to_type, operand, is_unsigned_cast) + elif _is_floating_point_type(to_type): + return self._cast_to_floating_point(to_type, operand, is_unsigned_cast) + + def _cast_to_integer( + self, to_type: Type, operand: Value, is_unsigned_cast: bool + ) -> Value: + to_width = IntegerType(to_type).width + operand_type = operand.type + if _is_floating_point_type(operand_type): + if is_unsigned_cast: + return arith.FPToUIOp(to_type, operand).result + return arith.FPToSIOp(to_type, operand).result 
+ if _is_index_type(operand_type): + return arith.IndexCastOp(to_type, operand).result + # Assume integer. + from_width = IntegerType(operand_type).width + if to_width > from_width: + if is_unsigned_cast: + return arith.ExtUIOp(to_type, operand).result + return arith.ExtSIOp(to_type, operand).result + elif to_width < from_width: + return arith.TruncIOp(to_type, operand).result + raise ValueError( + f"Unable to cast body expression from {operand_type} to " f"{to_type}" + ) + + def _cast_to_floating_point( + self, to_type: Type, operand: Value, is_unsigned_cast: bool + ) -> Value: + operand_type = operand.type + if _is_integer_type(operand_type): + if is_unsigned_cast: + return arith.UIToFPOp(to_type, operand).result + return arith.SIToFPOp(to_type, operand).result + # Assume FloatType. + to_width = _get_floating_point_width(to_type) + from_width = _get_floating_point_width(operand_type) + if to_width > from_width: + return arith.ExtFOp(to_type, operand).result + elif to_width < from_width: + return arith.TruncFOp(to_type, operand).result + raise ValueError( + f"Unable to cast body expression from {operand_type} to " f"{to_type}" + ) + + def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value: + return self._cast(type_var_name, operand, False) + + def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: + return self._cast(type_var_name, operand, True) + + def _unary_exp(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.ExpOp(x).result + raise NotImplementedError("Unsupported 'exp' operand: {x}") + + def _unary_log(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.LogOp(x).result + raise NotImplementedError("Unsupported 'log' operand: {x}") + + def _unary_abs(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.AbsFOp(x).result + raise NotImplementedError("Unsupported 'abs' operand: {x}") + + def _unary_ceil(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.CeilOp(x).result + raise NotImplementedError("Unsupported 'ceil' operand: {x}") + + def _unary_floor(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.FloorOp(x).result + raise NotImplementedError("Unsupported 'floor' operand: {x}") + + def _unary_negf(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return arith.NegFOp(x).result + if _is_complex_type(x.type): + return complex.NegOp(x).result + raise NotImplementedError("Unsupported 'negf' operand: {x}") + + def _binary_add(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.AddFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.AddIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.AddOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}") + + def _binary_sub(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.SubFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.SubIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.SubOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}") + + def _binary_mul(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MulFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MulIOp(lhs, rhs).result + if 
_is_complex_type(lhs.type): + return complex.MulOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}") + + def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MaxFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MaxSIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}") + + def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MaxFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MaxUIOp(lhs, rhs).result + raise NotImplementedError( + "Unsupported 'max_unsigned' operands: {lhs}, {rhs}") + + def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MinFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MinSIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}") + + def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MinFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MinUIOp(lhs, rhs).result + raise NotImplementedError( + "Unsupported 'min_unsigned' operands: {lhs}, {rhs}") + + +def _infer_structured_outs( + op_config: LinalgStructuredOpConfig, + in_arg_defs: Sequence[OperandDefConfig], + ins: Sequence[Value], + out_arg_defs: Sequence[OperandDefConfig], + outs: Union[Sequence[Value], OpResultList], +) -> Tuple[ValueList, List[Type]]: + """Infers implicit outs and output types. + + Respects existing contents of outs if not empty. + + Returns: + normalized outs, output types + """ + # If outs were explicitly provided, we accept them verbatim. + if outs: + return outs, [out.type for out in outs] + + raise NotImplementedError( + f"Output tensor inference not yet supported for " "structured ops" + ) + + +def _get_types_from_values(*values: Value) -> Sequence[Type]: + types = [] + for v in values: + types.append(v.type) + return types + + +def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]: + return [odc.operand_def.name for odc in operand_configs] + + +def _add_type_mapping( + operand_config: OperandDefConfig, + operand_type: Type, + type_mapping: Dict[str, Type], + block_arg_types: Sequence[Type], +): + element_or_self_type = operand_type + # Get the element type for tensor operands and the type itself for scalars. + if operand_config.shape_map: + try: + element_or_self_type = ShapedType(operand_type).element_type + except Exception as e: + raise ValueError( + f"Expected ShapedType but got {operand_type}") from e + name = operand_config.type_var.name + if name in type_mapping: + if type_mapping[name] != element_or_self_type: + raise ValueError( + f"Cannot overwrite type mapping {name} = " + f"{type_mapping[name]} by type {element_or_self_type}" + ) + type_mapping[name] = element_or_self_type + block_arg_types.append(element_or_self_type) + + +def _is_complex_type(t: Type) -> bool: + return ComplexType.isinstance(t) + + +def _is_floating_point_type(t: Type) -> bool: + # TODO: Create a FloatType in the Python API and implement the switch + # there. 
+ return ( + F64Type.isinstance(t) + or F32Type.isinstance(t) + or F16Type.isinstance(t) + or BF16Type.isinstance(t) + ) + + +def _is_integer_type(t: Type) -> bool: + return IntegerType.isinstance(t) + + +def _is_index_type(t: Type) -> bool: + return IndexType.isinstance(t) + + +def _get_floating_point_width(t: Type) -> int: + # TODO: Create a FloatType in the Python API and implement the switch + # there. + if F64Type.isinstance(t): + return 64 + if F32Type.isinstance(t): + return 32 + if F16Type.isinstance(t): + return 16 + if BF16Type.isinstance(t): + return 16 + raise NotImplementedError(f"Unhandled floating point type switch {t}") diff --git a/python/scalehls/opdsl/lang/scalar_expr.py b/python/scalehls/opdsl/lang/scalar_expr.py new file mode 100644 index 00000000..86853994 --- /dev/null +++ b/python/scalehls/opdsl/lang/scalar_expr.py @@ -0,0 +1,166 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Models DAGs of scalar math expressions. + +Used for generating region bodies at the "math" level where they are still type +polymorphic. This is modeled to be polymorphic by attribute name for interop +with serialization schemes that are just plain-old-dicts. + +These classes are typically not user accessed and are created as a by-product +of interpreting a comprehension DSL and model the operations to perform in the +op body. The class hierarchy is laid out to map well to a form of YAML that +can be easily consumed from the C++ side, not necessarily for ergonomics. +""" + +from typing import Optional, Sequence + +from .comprehension import * +from .types import * +from .yaml_helper import * + +__all__ = [ + "ScalarAssign", + "ScalarFn", + "ScalarArg", + "ScalarConst", + "ScalarIndex", + "ScalarExpression", +] + + +class ScalarFn: + """A type of ScalarExpression that applies a function.""" + + def __init__( + self, + kind: "FunctionKind", + fn_name: Optional[str], + attr_name: Optional[str], + type_var: Optional["TypeVar"], + operands: Sequence["ScalarExpression"], + ): + if bool(fn_name) + bool(attr_name) != 1: + raise ValueError("One of 'fn_name', 'attr_name' must be specified") + self.kind = kind + self.fn_name = fn_name + self.attr_name = attr_name + self.type_var = type_var + self.operands = operands + + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_fn=self) + + def __repr__(self): + name = self.fn_name if self.fn_name else self.attr_name + return ( + f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, " + f"operands=[{', '.join(self.operands)}])" + ) + + +class ScalarArg: + """A type of ScalarExpression that references a named argument.""" + + def __init__(self, arg: str): + self.arg = arg + + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_arg=self) + + def __repr__(self): + return f"(ScalarArg({self.arg})" + + +class ScalarConst: + """A type of ScalarExpression representing a constant.""" + + def __init__(self, value: str): + self.value = value + + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_const=self) + + def __repr__(self): + return f"(ScalarConst({self.value})" + + +class ScalarIndex: + """A type of ScalarExpression accessing an iteration index.""" + + def __init__(self, dim: int): + self.dim = dim + + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_index=self) + + def __repr__(self): + return 
f"(ScalarIndex({self.dim})" + + +class ScalarExpression(YAMLObject): + """An expression on scalar values. + + Can be one of: + - ScalarFn + - ScalarArg + - ScalarConst + - ScalarIndex + """ + + yaml_tag = "!ScalarExpression" + + def __init__( + self, + scalar_fn: Optional[ScalarFn] = None, + scalar_arg: Optional[ScalarArg] = None, + scalar_const: Optional[ScalarConst] = None, + scalar_index: Optional[ScalarIndex] = None, + ): + if ( + bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + bool(scalar_index) + ) != 1: + raise ValueError( + "One of 'scalar_fn', 'scalar_arg', 'scalar_const', or " + "'scalar_index' must be specified" + ) + self.scalar_fn = scalar_fn + self.scalar_arg = scalar_arg + self.scalar_const = scalar_const + self.scalar_index = scalar_index + + def to_yaml_custom_dict(self): + if self.scalar_fn: + scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower()) + if self.scalar_fn.fn_name: + scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name + if self.scalar_fn.attr_name: + scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name + if self.scalar_fn.type_var: + scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name + scalar_fn_dict["operands"] = list(self.scalar_fn.operands) + return dict(scalar_fn=scalar_fn_dict) + elif self.scalar_arg: + return dict(scalar_arg=self.scalar_arg.arg) + elif self.scalar_const: + return dict(scalar_const=self.scalar_const.value) + elif self.scalar_index: + return dict(scalar_index=self.scalar_index.dim) + else: + raise ValueError(f"Unexpected ScalarExpression type: {self}") + + +class ScalarAssign(YAMLObject): + """An assignment to a named argument (LHS of a comprehension).""" + + yaml_tag = "!ScalarAssign" + + def __init__(self, arg: str, value: ScalarExpression): + self.arg = arg + self.value = value + + def to_yaml_custom_dict(self): + return dict(arg=self.arg, value=self.value) + + def __repr__(self): + return f"ScalarAssign({self.arg}, {self.value})" diff --git a/python/scalehls/opdsl/lang/types.py b/python/scalehls/opdsl/lang/types.py new file mode 100644 index 00000000..4f36029b --- /dev/null +++ b/python/scalehls/opdsl/lang/types.py @@ -0,0 +1,79 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Facility for symbolically referencing type variables. + +Type variables are instances of the TypeVar class, which is uniqued by name. +An "expando" accessor `TV` is provided that generates a named TypeVar for +any attribute access: + + >>> TV.T + TypeVar(T) + >>> TV.T is TV.U + False + >>> TV.T is TV.T + True +""" + +from enum import Enum +from typing import Dict + +__all__ = [ + "TypeVar", + "TV", + # Predefined types. + "I32", + "I64", + "F32", + "F64", + # TypeVar aliases. + "T", + "U", + "V", +] + + +class TypeVar: + """A replaceable type variable. + + Type variables are uniqued by name. 
+ """ + + ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"] + + def __new__(cls, name: str): + existing = cls.ALL_TYPEVARS.get(name) + if existing is not None: + return existing + new = super().__new__(cls) + new.name = name + cls.ALL_TYPEVARS[name] = new + return new + + def __repr__(self): + return f"TypeVar({self.name})" + + @classmethod + def create_expando(cls): + """Create an expando class that creates unique type vars on attr access.""" + + class ExpandoTypeVars: + def __getattr__(self, n): + return cls(n) + + return ExpandoTypeVars() + + +# Expando access via TV.foo +TV = TypeVar.create_expando() + +# Predefined types. +I32 = TV.I32 +I64 = TV.I64 +F32 = TV.F32 +F64 = TV.F64 + +# Some common type name aliases. +T = TV.T +U = TV.U +V = TV.V diff --git a/python/scalehls/opdsl/lang/yaml_helper.py b/python/scalehls/opdsl/lang/yaml_helper.py new file mode 100644 index 00000000..1672656b --- /dev/null +++ b/python/scalehls/opdsl/lang/yaml_helper.py @@ -0,0 +1,53 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""YAML serialization is routed through here to centralize common logic.""" + +import sys + +try: + import yaml +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"This tool requires PyYAML but it was not installed. " + f"Recommend: {sys.executable} -m pip install PyYAML" + ) from e + +__all__ = [ + "yaml_dump", + "yaml_dump_all", + "YAMLObject", +] + + +class YAMLObject(yaml.YAMLObject): + @classmethod + def to_yaml(cls, dumper, self): + """Default to a custom dictionary mapping.""" + return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict()) + + def to_yaml_custom_dict(self): + raise NotImplementedError() + + def as_linalg_yaml(self): + return yaml_dump(self) + + +def multiline_str_representer(dumper, data): + if len(data.splitlines()) > 1: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + else: + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + +yaml.add_representer(str, multiline_str_representer) + + +def yaml_dump(data, sort_keys=False, **kwargs): + return yaml.dump(data, sort_keys=sort_keys, **kwargs) + + +def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs): + return yaml.dump_all( + data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs + ) diff --git a/python/scalehls/opdsl/ops/__init__.py b/python/scalehls/opdsl/ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/scalehls/opdsl/ops/core_named_ops.py b/python/scalehls/opdsl/ops/core_named_ops.py new file mode 100644 index 00000000..08818b21 --- /dev/null +++ b/python/scalehls/opdsl/ops/core_named_ops.py @@ -0,0 +1,1575 @@ +from ..lang import * + +T1 = TV.T1 +T2 = TV.T2 + +Batch = S.Batch + + +@linalg_structured_op +def copy( + I=TensorDef(T1), + O=TensorDef(U, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Copies the tensor elementwise. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + defines(Canonicalizer) + O[None] = cast(U, I[None]) + + +@linalg_structured_op +def elemwise_unary( + I=TensorDef(T1), + O=TensorDef(U, output=True), + fun=UnaryFnAttrDef(default=UnaryFn.exp), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Applies the unary function fun elementwise. 
+ + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + O[None] = fun(cast(U, I[None])) + + +@linalg_structured_op +def exp( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies exp(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.exp(I[None]) + + +@linalg_structured_op +def log( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies log(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.log(I[None]) + + +@linalg_structured_op +def abs( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies abs(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.abs(I[None]) + + +@linalg_structured_op +def ceil( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies ceil(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.ceil(I[None]) + + +@linalg_structured_op +def floor( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies floor(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.floor(I[None]) + + +@linalg_structured_op +def negf( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies negf(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.negf(I[None]) + + +@linalg_structured_op +def elemwise_binary( + lhs=TensorDef(T1), + rhs=TensorDef(T2), + O=TensorDef(U, output=True), + fun=BinaryFnAttrDef(default=BinaryFn.add), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Applies the binary function fun elementwise. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) + + +@linalg_structured_op +def add( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Adds two tensors elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.add` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.add(lhs[None], rhs[None]) + + +@linalg_structured_op +def sub( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Subtracts two tensors elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.sub` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.sub(lhs[None], rhs[None]) + + +@linalg_structured_op +def mul( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Multiplies two tensors elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. 
+ + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.mul` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.mul(lhs[None], rhs[None]) + + +@linalg_structured_op +def div( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Divides the first tensor by the second tensor, elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.div` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.div(lhs[None], rhs[None]) + + +@linalg_structured_op +def div_unsigned( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Divides the first tensor by the second tensor, elementwise. For integer + types, performs an unsigned division. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.div` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = lhs[None] / rhs[None] + + +@linalg_structured_op +def max( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Takes the max (signed) between two inputs, elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.div` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.max_signed(lhs[None], rhs[None]) + + +@linalg_structured_op +def matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + + +@linalg_structured_op +def matmul_unsigned( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), +): + """Performs an unsigned matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
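+
+    A hypothetical call through the generated builder (the value names are
+    illustrative, not part of this definition) looks like:
+
+      matmul_unsigned(lhs_tensor, rhs_tensor, outs=[acc_tensor])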
+ """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned( + U, B[D.k, D.n] + ) + + +@linalg_structured_op +def quantized_matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, S.M, S.N, output=True), +): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul. + """ + domain(D.m, D.n, D.k) + C[D.m, D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * ( + TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp) + ) + + +@linalg_structured_op +def matmul_transpose_a(A=TensorDef(T1, S.K, S.N), + B=TensorDef(T2, S.K, S.M), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): + """Performs a matrix multiplication of two 2D inputs with lhs operand + transposed. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n]) + + +@linalg_structured_op +def matmul_transpose_b(A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.N, S.K), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): + """Performs a matrix multiplication of two 2D inputs with rhs operand + transposed. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k]) + + +@linalg_structured_op +def mmt4d( + lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), + rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), + accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True), +): + """Performs a matrix-matrix-transpose multiplication of two 4D inputs. + + Differences from linalg.matmul: + * The right hand side is transposed, whence the 't' in 'mmt'. + * The input and output tensors have a 4D shape instead of a 2D shape. They + are interpreted as 2D matrices with one level of 2D tile subdivision, + whence the 2+2=4 dimensions. The inner tile dimensions are identified with + '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads + as: MxK tiles, each of shape M0xK0. + """ + domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) + implements(ContractionOpInterface) + accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( + TV.AccumType, lhs[D.m, D.k, D.m0, D.k0] + ) * TypeFn.cast_signed(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) + + +@linalg_structured_op +def batch_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.M, S.N, output=True), +): + """Performs a batched matrix multiplication of two 3D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k, D.n] + ) + + +@linalg_structured_op +def batch_matmul_transpose_a(A=TensorDef(T1, Batch, S.K, S.M), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.M, S.N, output=True)): + """Performs a batched matrix multiplication of two 3D inputs where lhs operand + has its non-batch dimensions transposed. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) \ + * TypeFn.cast_signed(U, B[D.b, D.k, D.n]) + + +@linalg_structured_op +def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.N, S.K), + C=TensorDef(U, Batch, S.M, S.N, output=True)): + """Performs a batched matrix multiplication of two 3D inputs where rhs operand + has its non-batch dimensions transposed. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.m, + D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.n, D.k]) + + +@linalg_structured_op +def quantized_batch_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, Batch, S.M, S.N, output=True), +): + """Performs a batched matrix multiplication of two 3D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul. + """ + domain(D.b, D.m, D.n, D.k) + C[D.b, D.m, D.n] += ( + TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp) + ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) + + +@linalg_structured_op +def batch_reduce_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), +): + """Performs a batch-reduce matrix multiplication of two 3D inputs. + The partial multiplication results are reduced into a 2D output. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += TypeFn.cast_signed( + U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n]) + ) + + +@linalg_structured_op +def matvec( + A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True) +): + """Performs a matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n) + implements(ContractionOpInterface) + x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n]) + + +@linalg_structured_op +def vecmat( + y=TensorDef(T1, S.M), A=TensorDef(T2, S.M, S.N), x=TensorDef(U, S.N, output=True) +): + """Performs a vector-matrix multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.n, D.m) + implements(ContractionOpInterface) + x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n]) + + +@linalg_structured_op +def batch_matvec( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K), + C=TensorDef(U, Batch, S.M, output=True), +): + """Performs a batched matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.k) + implements(ContractionOpInterface) + C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k] + ) + + +@linalg_structured_op +def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)): + """Performs a dot product of two vectors to a scalar result. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ContractionOpInterface) + C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) + + +@linalg_structured_op +def conv_1d( + I=TensorDef(T1, S.OW + S.KW), + K=TensorDef(T2, S.KW), + O=TensorDef(U, S.OW, output=True), +): + """Performs 1-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.ow, D.kw) + O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kw]) + + +@linalg_structured_op +def conv_2d( + I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW), + K=TensorDef(T2, S.KH, S.KW), + O=TensorDef(U, S.OH, S.OW, output=True), +): + """Performs 2-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.oh, D.ow, D.kh, D.kw) + O[D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.oh + D.kh, D.ow + D.kw] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw]) + + +@linalg_structured_op +def conv_3d( + I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW), + K=TensorDef(T2, S.KD, S.KH, S.KW), + O=TensorDef(U, S.OD, S.OH, S.OW, output=True), +): + """Performs 3-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) + O[D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw] + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw]) + + +@linalg_structured_op +def conv_1d_nwc_wcf( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.f, D.kw, D.c) + O[D.n, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f]) + + +@linalg_structured_op +def conv_1d_ncw_fcw( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KW), + O=TensorDef(U, S.N, S.F, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs 1-D convolution. + + Layout: + * Input: NCW. + * Kernel: FCW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.ow, D.c, D.kw) + O[D.n, D.f, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kw]) + + +@linalg_structured_op +def conv_2d_nhwc_hwcf( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NHWC. + * Kernel: HWCF. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) + + +@linalg_structured_op +def conv_2d_nhwc_fhwc( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.F, S.KH, S.KW, S.C), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NHWC. + * Kernel: FHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) + + +@linalg_structured_op +def conv_2d_nhwc_hwcf_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution with zero point offsets. + + Layout: + * Input: NHWC. + * Kernel: HWCF. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp)) + + +@linalg_structured_op +def conv_2d_nchw_fchw( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NCHW. + * Kernel: FCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) + + +@linalg_structured_op +def conv_2d_ngchw_fgchw( + I=TensorDef( + T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW + ), + K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution. + + Layout: + * Input: NGCHW. + * Kernel: FGCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) + + +@linalg_structured_op +def conv_3d_ndhwc_dhwcf( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) + + +@linalg_structured_op +def conv_3d_ndhwc_dhwcf_q( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution with zero point offsets. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.f] += ( + TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + - TypeFn.cast_signed(U, IZp) + ) * ( + TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) + - TypeFn.cast_signed(U, KZp) + ) + + +@linalg_structured_op +def conv_3d_ncdhw_fcdhw( + I=TensorDef( + T1, + S.N, + S.C, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + ), + K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + ], + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw]) + + +@linalg_structured_op +def depthwise_conv_1d_nwc_wc( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KW, S.IC), + O=TensorDef(U, S.N, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.kw) + O[D.n, D.ow, D.ic] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kw, D.ic]) + + +@linalg_structured_op +def depthwise_conv_1d_ncw_cw( + I=TensorDef(T1, S.N, S.IC, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.IC, S.KW), + O=TensorDef(U, S.N, S.IC, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.kw) + O[D.n, D.ic, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.ic, D.kw]) + + +@linalg_structured_op +def depthwise_conv_1d_nwc_wcm( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OW, S.IC, S.CM, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.cm, D.kw) + O[D.n, D.ow, D.ic, D.cm] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm]) + + +@linalg_structured_op +def depthwise_conv_2d_nhwc_hwc( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) + + +@linalg_structured_op +def depthwise_conv_2d_nchw_chw( + I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.IC, S.KH, S.KW), + O=TensorDef(U, S.N, S.IC, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw]) + + +@linalg_structured_op +def depthwise_conv_2d_nhwc_hwc_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast_signed(U, KZp)) + + +@linalg_structured_op +def depthwise_conv_2d_nhwc_hwcm( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
+
+
+@linalg_structured_op
+def depthwise_conv_2d_nhwc_hwcm_q(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.ic, D.cm] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast_signed(U, KZp))
+
+
+@linalg_structured_op
+def depthwise_conv_3d_ndhwc_dhwc(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.IC,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.IC, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs depth-wise 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
+    O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+            D.ic,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
+
+
+@linalg_structured_op
+def depthwise_conv_3d_ncdhw_cdhw(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.IC,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+    ),
+    K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.IC, S.OD, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs depth-wise 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
+    O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.ic,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw])
+
+
+@linalg_structured_op
+def depthwise_conv_3d_ndhwc_dhwcm(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.IC,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.IC, S.CM, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs depth-wise 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
+    O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+            D.ic,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
+
+
+@linalg_structured_op
+def pooling_nhwc_sum(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs sum pooling.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: HW.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+    )
+
+
+@linalg_structured_op
+def pooling_nchw_sum(
+    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs sum pooling.
+
+    Layout:
+      * Input: NCHW.
+      * Kernel: HW.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
+    O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+    )
+
+
+@linalg_structured_op
+def pooling_nhwc_max(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) + + +@linalg_structured_op +def pooling_nhwc_max_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw]( + TypeFn.cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) + + +@linalg_structured_op +def pooling_nchw_max( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) + O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + ], + ) + ) + + +@linalg_structured_op +def pooling_nhwc_min( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) + + +@linalg_structured_op +def pooling_nhwc_min_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw]( + TypeFn.cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) + + +@linalg_structured_op +def pooling_nwc_sum( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs sum pooling. + + Layout: + * Input: NWC. + * Kernel: W. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + + +@linalg_structured_op +def pooling_ncw_sum( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.C, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs sum pooling. + + Layout: + * Input: NCW. + * Kernel: W. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] += TypeFn.cast_signed(U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) + + +@linalg_structured_op +def pooling_nwc_max( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]]( + TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) + + +@linalg_structured_op +def pooling_nwc_max_unsigned( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.max_unsigned[[D.kw]]( + TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) + + +@linalg_structured_op +def pooling_ncw_max( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.C, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.ow * S.SW + D.kw * S.DW, + ], + ) + ) + + +@linalg_structured_op +def pooling_nwc_min( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]]( + TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) + + +@linalg_structured_op +def pooling_nwc_min_unsigned( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.min_unsigned[[D.kw]]( + TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) + + +@linalg_structured_op +def pooling_ndhwc_sum( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3D sum pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) + O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + + +@linalg_structured_op +def pooling_ndhwc_max( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3D max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
+    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.od * S.SD + D.kd * S.DD,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+                D.c,
+            ],
+        )
+    )
+
+
+@linalg_structured_op
+def pooling_ndhwc_min(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.C,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3D min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
+    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.od * S.SD + D.kd * S.DD,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+                D.c,
+            ],
+        )
+    )
+
+
+@linalg_structured_op
+def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+    """Fills the output tensor with the given value.
+
+    Works for arbitrary ranked output tensors since the operation performs scalar
+    accesses only and is thus rank polymorphic. Numeric casting is performed on
+    the value operand, promoting it to the same data type as the output.
+    """
+    implements(FillOpInterface)
+    defines(Canonicalizer)
+    O[None] = TypeFn.cast_signed(U, value)
+
+
+@linalg_structured_op
+def fill_rng_2d(
+    min=ScalarDef(F64),
+    max=ScalarDef(F64),
+    seed=ScalarDef(I32),
+    O=TensorDef(T, S.M, S.N, output=True),
+):
+    """Fills the output tensor with pseudo random numbers.
+
+    The operation generates pseudo random numbers using a linear congruential
+    generator. It provides no guarantees regarding the distribution of the
+    generated random numbers. Instead of generating the random numbers
+    sequentially, it instantiates one random number generator per data element
+    and runs them in parallel. The seed operand and the indices of the data
+    element seed the random number generation. The min and max operands limit
+    the range of the generated random numbers.
+    """
+    domain(D.m, D.n)
+    multiplier = TypeFn.cast_signed(I32, const(1103515245))
+    increment = TypeFn.cast_signed(I32, const(12345))
+    rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
+    rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
+    inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
+    offset = TypeFn.cast_signed(F64, const(2147483647))
+    scaling = (max - min) * inv_range
+    O[D.m, D.n] = TypeFn.cast_signed(
+        T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min
+    )
diff --git a/test/python/import.py b/test/python/import.py
index ec3b135b..f1d4a9e5 100644
--- a/test/python/import.py
+++ b/test/python/import.py
@@ -5,3 +5,4 @@ from scalehls.dialects import linalg
 from scalehls.dialects import hls
 from scalehls import uip
+from scalehls.opdsl.lang import *
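
The OpDSL definitions above are declarative: each `O[...] += ...` comprehension, together with the `TensorDef` shape expressions, determines the generated linalg op. For readers less familiar with the notation, the sketch below spells out the scalar semantics of `conv_2d_nchw_fchw` in plain NumPy. It is illustrative only and not part of the patch; the function name is hypothetical, and it assumes the input is already sized to cover the strided/dilated accesses implied by `S.OH * S.SH + S.KH * S.DH` and `S.OW * S.SW + S.KW * S.DW`.

```python
# Hypothetical reference-only sketch (not part of this patch) of the scalar
# semantics expressed by conv_2d_nchw_fchw above.
import numpy as np


def conv_2d_nchw_fchw_ref(I, K, O, strides=(1, 1), dilations=(1, 1)):
    """I: (N, C, IH, IW), K: (F, C, KH, KW), O: (N, F, OH, OW); accumulates into O."""
    SH, SW = strides
    DH, DW = dilations
    N, F, OH, OW = O.shape
    _, C, KH, KW = K.shape
    for n in range(N):
        for f in range(F):
            for oh in range(OH):
                for ow in range(OW):
                    for c in range(C):
                        for kh in range(KH):
                            for kw in range(KW):
                                # Mirrors I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
                                #            D.ow * S.SW + D.kw * S.DW] and the
                                # cast_signed promotion to the output type.
                                O[n, f, oh, ow] += O.dtype.type(
                                    I[n, c, oh * SH + kh * DH, ow * SW + kw * DW]
                                ) * O.dtype.type(K[f, c, kh, kw])
    return O
```

The same index arithmetic recurs, with layout permutations, in every convolution, depthwise-convolution, and pooling definition above; only the operand layouts and the presence of channel-multiplier or zero-point operands change.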
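The pooling definitions use the same strided/dilated access pattern but reduce with `ReduceFn` instead of accumulating a product, and their kernel operand contributes only window extents via `index_dims`. A reference-only sketch of `pooling_nhwc_max` (the function name, explicit `window` argument, and pre-initialized output are assumptions, not part of the patch):

```python
# Hypothetical reference-only sketch (not part of this patch) of
# pooling_nhwc_max above: the kernel supplies only the window extents
# (index_dims=[D.kh, D.kw]); the output holds the running signed maximum.
import numpy as np


def pooling_nhwc_max_ref(I, O, window=(2, 2), strides=(1, 1), dilations=(1, 1)):
    """I: (N, IH, IW, C); O: (N, OH, OW, C), pre-initialized (e.g. to -inf)."""
    KH, KW = window
    SH, SW = strides
    DH, DW = dilations
    N, OH, OW, C = O.shape
    for n in range(N):
        for oh in range(OH):
            for ow in range(OW):
                for c in range(C):
                    for kh in range(KH):
                        for kw in range(KW):
                            # ReduceFn.max_signed[D.kh, D.kw] over the window,
                            # with the same strided/dilated input access.
                            O[n, oh, ow, c] = max(
                                O[n, oh, ow, c],
                                O.dtype.type(
                                    I[n, oh * SH + kh * DH, ow * SW + kw * DW, c]
                                ),
                            )
    return O
```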
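Finally, `fill_rng_2d` packs a small linear congruential generator into the comprehension, one generator per output element. The sketch below restates it in plain Python using the constants from the op body (1103515245, 12345, 2.3283064e-10, 2147483647). The function name is hypothetical and not part of the patch; note that the DSL evaluates `rand1`/`rand2` in `I32` with wraparound, which plain Python integers do not model, so this is only an approximation.

```python
# Hypothetical reference-only sketch (not part of this patch) of fill_rng_2d
# above. Unlike the DSL, which computes rand1/rand2 in i32 with wraparound,
# plain Python integers do not wrap, so values only approximate the op.
def fill_rng_2d_ref(min_val, max_val, seed, M, N):
    multiplier, increment = 1103515245, 12345
    inv_range, offset = 2.3283064e-10, 2147483647.0
    scaling = (max_val - min_val) * inv_range
    out = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            # One generator per element, seeded by the element indices and the
            # seed operand: two chained LCG steps.
            rand1 = (m + seed) * multiplier + increment
            rand2 = (n + rand1) * multiplier + increment
            out[m][n] = (offset + float(rand2)) * scaling + min_val
    return out
```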