From 21fd5473bf83140be8fe6c473f579961116cfc25 Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 19 Nov 2024 21:05:04 +0000 Subject: [PATCH] [PyCDE][Handshake] Add bindings for Handshake functions --- frontends/PyCDE/src/CMakeLists.txt | 1 + frontends/PyCDE/src/pycde/handshake.py | 146 ++++++++++++++++++++++ frontends/PyCDE/test/test_handshake.py | 50 ++++++++ lib/Bindings/Python/dialects/handshake.py | 17 ++- 4 files changed, 209 insertions(+), 5 deletions(-) create mode 100644 frontends/PyCDE/src/pycde/handshake.py create mode 100644 frontends/PyCDE/test/test_handshake.py diff --git a/frontends/PyCDE/src/CMakeLists.txt b/frontends/PyCDE/src/CMakeLists.txt index cbcd48839214..fad232cb8301 100644 --- a/frontends/PyCDE/src/CMakeLists.txt +++ b/frontends/PyCDE/src/CMakeLists.txt @@ -31,6 +31,7 @@ declare_mlir_python_sources(PyCDESources pycde/common.py pycde/system.py pycde/devicedb.py + pycde/handshake.py pycde/instance.py pycde/seq.py pycde/signals.py diff --git a/frontends/PyCDE/src/pycde/handshake.py b/frontends/PyCDE/src/pycde/handshake.py new file mode 100644 index 000000000000..803bef8b099c --- /dev/null +++ b/frontends/PyCDE/src/pycde/handshake.py @@ -0,0 +1,146 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from __future__ import annotations +from typing import Any, List, Optional, Set, Tuple, Dict +import typing + +from .module import Module, ModuleLikeBuilderBase, PortError +from .signals import BitsSignal, ChannelSignal, ClockSignal, Signal +from .system import System +from .support import (get_user_loc, _obj_to_attribute, obj_to_typed_attribute, + create_const_zero) +from .types import Channel + +from .circt.dialects import handshake as raw_handshake +from .circt import ir + + +class FuncBuilder(ModuleLikeBuilderBase): + """Defines how an ESI `PureModule` gets built.""" + + @property + def circt_mod(self): + sys: System = System.current() + ret = sys._op_cache.get_circt_mod(self) + if ret is None: + return sys._create_circt_mod(self) + return ret + + def create_op(self, sys: System, symbol): + if hasattr(self.modcls, "metadata"): + meta = self.modcls.metadata + self.add_metadata(sys, symbol, meta) + else: + self.add_metadata(sys, symbol, None) + + # If there are associated constants, add them to the manifest. + if len(self.constants) > 0: + constants_dict: Dict[str, ir.Attribute] = {} + for name, constant in self.constants.items(): + constant_attr = obj_to_typed_attribute(constant.value, constant.type) + constants_dict[name] = constant_attr + with ir.InsertionPoint(sys.mod.body): + from .dialects.esi import esi + esi.SymbolConstantsOp(symbolRef=ir.FlatSymbolRefAttr.get(symbol), + constants=ir.DictAttr.get(constants_dict)) + + assert len(self.generators) > 0 + + if hasattr(self, "parameters") and self.parameters is not None: + self.attributes["pycde.parameters"] = self.parameters + # If this Module has a generator, it's a real module. + return raw_handshake.FuncOp.create( + symbol, + [(p.name, p.type._type) for p in self.inputs], + [(p.name, p.type._type) for p in self.outputs], + attributes=self.attributes, + loc=self.loc, + ip=sys._get_ip(), + ) + + def generate(self): + """Fill in (generate) this module. Only supports a single generator + currently.""" + if len(self.generators) != 1: + raise ValueError("Must have exactly one generator.") + g: Generator = list(self.generators.values())[0] + + entry_block = self.circt_mod.add_entry_block() + ports = self.generator_port_proxy(entry_block.arguments, self) + with self.GeneratorCtxt(self, ports, entry_block, g.loc): + outputs = g.gen_func(ports) + if outputs is not None: + raise ValueError("Generators must not return a value") + + ports._check_unconnected_outputs() + raw_handshake.ReturnOp([o.value for o in ports._output_values]) + + def instantiate(self, module_inst, inputs, instance_name: str): + """"Instantiate this Func from ESI channels. Check that the input types + match expectations.""" + + port_input_lookup = {port.name: port for port in self.inputs} + circt_inputs: List[Optional[ir.Value]] = [None] * len(self.inputs) + remaining_inputs = set(port_input_lookup.keys()) + clk = None + rst = None + for name, signal in inputs.items(): + if name == "clk": + if not isinstance(signal, ClockSignal): + raise PortError("'clk' must be a clock signal") + clk = signal.value + continue + elif name == "rst": + if not isinstance(signal, BitsSignal): + raise PortError("'rst' must be a Bits(1)") + rst = signal.value + continue + + if name not in port_input_lookup: + raise PortError(f"Input port {name} not found in module") + port = port_input_lookup[name] + if isinstance(signal, ChannelSignal): + # If the input is a channel signal, the types must match. + if signal.type.inner_type != port.type: + raise ValueError( + f"Wrong type on input signal '{name}'. Got '{signal.type}'," + f" expected '{port.type}'") + assert port.idx is not None + circt_inputs[port.idx] = signal.value + remaining_inputs.remove(name) + elif isinstance(signal, Signal): + raise PortError(f"Input {name} must be a channel signal") + else: + raise PortError(f"Port {name} must be a signal") + if clk is None: + raise PortError("Missing 'clk' signal") + if rst is None: + raise PortError("Missing 'rst' signal") + if len(remaining_inputs) > 0: + raise PortError( + f"Missing input signals for ports: {', '.join(remaining_inputs)}") + + circt_mod = self.circt_mod + assert circt_mod is not None + result_types = [Channel(port.type)._type for port in self.outputs] + inst = raw_handshake.ESIInstanceOp( + result_types, + ir.StringAttr(circt_mod.attributes["sym_name"]).value, + instance_name, + clk=clk, + rst=rst, + opOperands=circt_inputs, + loc=get_user_loc()) + inst.operation.verify() + return inst + + +class Func(Module): + """A pure ESI module has no ports and contains only instances of modules with + only ESI ports and connections between said instances. Use ESI services for + external communication.""" + + BuilderType: type[ModuleLikeBuilderBase] = FuncBuilder + _builder: FuncBuilder diff --git a/frontends/PyCDE/test/test_handshake.py b/frontends/PyCDE/test/test_handshake.py new file mode 100644 index 000000000000..359f3f55e41d --- /dev/null +++ b/frontends/PyCDE/test/test_handshake.py @@ -0,0 +1,50 @@ +# RUN: %PYTHON% %s | FileCheck %s + +from pycde import (Clock, Output, Input, generator, types, Module) +from pycde.handshake import Func +from pycde.testing import unittestmodule +from pycde.types import Bits, Channel + +# CHECK-LABEL: hw.module @Top() +# CHECK: %c7_i12 = hw.constant 7 : i12 +# CHECK: hw.struct_create (%c7_i12) : !hw.struct +# CHECK: %c42_i8 = hw.constant 42 : i8 +# CHECK: %c45_i8 = hw.constant 45 : i8 +# CHECK: hw.array_create %c45_i8, %c42_i8 : i8 +# CHECK: %c5_i8 = hw.constant 5 : i8 +# CHECK: %c7_i12_0 = hw.constant 7 : i12 +# CHECK: hw.struct_create (%c7_i12_0) : !hw.typealias<@pycde::@bar, !hw.struct> +# CHECK: %Taps.taps = hw.instance "Taps" sym @Taps @Taps() -> (taps: !hw.array<3xi8>) +# CHECK: hw.output +# CHECK-LABEL: hw.module @Taps(out taps : !hw.array<3xi8>) +# CHECK: %c-53_i8 = hw.constant -53 : i8 +# CHECK: %c100_i8 = hw.constant 100 : i8 +# CHECK: %c23_i8 = hw.constant 23 : i8 +# CHECK: [[R0:%.+]] = hw.array_create %c23_i8, %c100_i8, %c-53_i8 : i8 +# CHECK: hw.output [[R0]] : !hw.array<3xi8> + + +class TestFunc(Func): + a = Input(Bits(8)) + x = Output(Bits(8)) + + @generator + def build(ports): + ports.x = ports.a & Bits(8)(0xF) + + +BarType = types.struct({"foo": types.i12}, "bar") + + +@unittestmodule() +class Top(Module): + clk = Clock() + rst = Input(Bits(1)) + + a = Input(Channel(Bits(8))) + x = Output(Channel(Bits(8))) + + @generator + def build(ports): + test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a) + ports.x = test.x diff --git a/lib/Bindings/Python/dialects/handshake.py b/lib/Bindings/Python/dialects/handshake.py index 78ce7de1bb1f..e60b70eb5141 100644 --- a/lib/Bindings/Python/dialects/handshake.py +++ b/lib/Bindings/Python/dialects/handshake.py @@ -3,15 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from __future__ import annotations +from typing import Dict, List, Tuple, Union from . import handshake from ._handshake_ops_gen import * from ._handshake_ops_gen import _Dialect from ..dialects._ods_common import _cext as _ods_cext -from ..ir import ArrayAttr, FunctionType, StringAttr, Type, TypeAttr - -from typing import List, Tuple, Union +from ..ir import ArrayAttr, Attribute, FunctionType, StringAttr, Type, TypeAttr from ._ods_common import ( equally_sized_accessor as _ods_equally_sized_accessor, @@ -32,17 +31,25 @@ class FuncOp(FuncOp): def create(sym_name: Union[StringAttr, str], args: List[Tuple[str, Type]], results: List[Tuple[str, Type]], - private: bool = False) -> FuncOp: + attributes: Dict[str, Attribute] = {}, + loc=None, + ip=None) -> FuncOp: if isinstance(sym_name, str): sym_name = StringAttr.get(sym_name) input_types = [t for _, t in args] res_types = [t for _, t in results] func_type = FunctionType.get(input_types, res_types) func_type_attr = TypeAttr.get(func_type) - funcop = FuncOp(func_type_attr) + funcop = FuncOp(func_type_attr, loc=loc, ip=ip) + for k, v in attributes.items(): + funcop.attributes[k] = v funcop.attributes["sym_name"] = sym_name funcop.attributes["argNames"] = ArrayAttr.get( [StringAttr.get(name) for name, _ in args]) funcop.attributes["resNames"] = ArrayAttr.get( [StringAttr.get(name) for name, _ in results]) return funcop + + def add_entry_block(self): + self.body.blocks.append(*self.function_type.value.inputs) + return self.body.blocks[0]