From 847a227fe967cd6c022bb6d6af2003eb51515e03 Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 19 Nov 2024 19:22:50 +0000 Subject: [PATCH 1/6] [Handshake][Python] Add Python bindings for Handshake --- .../Bindings/Python/dialects/handshake.py | 16 +++++++ lib/Bindings/Python/CMakeLists.txt | 8 ++++ lib/Bindings/Python/dialects/HandshakeOps.td | 14 ++++++ lib/Bindings/Python/dialects/handshake.py | 48 +++++++++++++++++++ 4 files changed, 86 insertions(+) create mode 100644 integration_test/Bindings/Python/dialects/handshake.py create mode 100644 lib/Bindings/Python/dialects/HandshakeOps.td create mode 100644 lib/Bindings/Python/dialects/handshake.py diff --git a/integration_test/Bindings/Python/dialects/handshake.py b/integration_test/Bindings/Python/dialects/handshake.py new file mode 100644 index 000000000000..4655628b55ef --- /dev/null +++ b/integration_test/Bindings/Python/dialects/handshake.py @@ -0,0 +1,16 @@ +# REQUIRES: bindings_python +# RUN: %PYTHON% %s | FileCheck %s + +import circt + +from circt.dialects import hw, handshake +from circt.ir import Context, Location, Module, InsertionPoint, IntegerAttr, IntegerType + +with Context() as ctx, Location.unknown(): + circt.register_dialects(ctx) + m = Module.create() + with InsertionPoint(m.body): + op = handshake.FuncOp.create("foo", [("a", IntegerType.get_signless(8))], + [("x", IntegerType.get_signless(1))]) + # CHECK: handshake.func @foo(i8, ...) -> i1 attributes {argNames = ["a"], resNames = ["x"]} + print(m) diff --git a/lib/Bindings/Python/CMakeLists.txt b/lib/Bindings/Python/CMakeLists.txt index 8f2d07d85cc0..46518d6ba0df 100644 --- a/lib/Bindings/Python/CMakeLists.txt +++ b/lib/Bindings/Python/CMakeLists.txt @@ -92,6 +92,14 @@ declare_mlir_dialect_python_bindings( dialects/esi.py DIALECT_NAME esi) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT CIRCTBindingsPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" + TD_FILE dialects/HandshakeOps.td + SOURCES + dialects/handshake.py + DIALECT_NAME handshake) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT CIRCTBindingsPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" diff --git a/lib/Bindings/Python/dialects/HandshakeOps.td b/lib/Bindings/Python/dialects/HandshakeOps.td new file mode 100644 index 000000000000..66bc69e4bda3 --- /dev/null +++ b/lib/Bindings/Python/dialects/HandshakeOps.td @@ -0,0 +1,14 @@ +//===- HandshakeOps.td - python op bindings gen ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef BINDINGS_PYTHON_HANDSHAKE_OPS +#define BINDINGS_PYTHON_HANDSHAKE_OPS + +include "circt/Dialect/Handshake/Handshake.td" + +#endif diff --git a/lib/Bindings/Python/dialects/handshake.py b/lib/Bindings/Python/dialects/handshake.py new file mode 100644 index 000000000000..78ce7de1bb1f --- /dev/null +++ b/lib/Bindings/Python/dialects/handshake.py @@ -0,0 +1,48 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from __future__ import annotations + +from . import handshake +from ._handshake_ops_gen import * +from ._handshake_ops_gen import _Dialect + +from ..dialects._ods_common import _cext as _ods_cext +from ..ir import ArrayAttr, FunctionType, StringAttr, Type, TypeAttr + +from typing import List, Tuple, Union + +from ._ods_common import ( + equally_sized_accessor as _ods_equally_sized_accessor, + get_default_loc_context as _ods_get_default_loc_context, + get_op_result_or_op_results as _get_op_result_or_op_results, + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + segmented_accessor as _ods_segmented_accessor, +) + +_ods_ir = _ods_cext.ir + + +@_ods_cext.register_operation(_Dialect, replace=True) +class FuncOp(FuncOp): + + @staticmethod + def create(sym_name: Union[StringAttr, str], + args: List[Tuple[str, Type]], + results: List[Tuple[str, Type]], + private: bool = False) -> FuncOp: + if isinstance(sym_name, str): + sym_name = StringAttr.get(sym_name) + input_types = [t for _, t in args] + res_types = [t for _, t in results] + func_type = FunctionType.get(input_types, res_types) + func_type_attr = TypeAttr.get(func_type) + funcop = FuncOp(func_type_attr) + funcop.attributes["sym_name"] = sym_name + funcop.attributes["argNames"] = ArrayAttr.get( + [StringAttr.get(name) for name, _ in args]) + funcop.attributes["resNames"] = ArrayAttr.get( + [StringAttr.get(name) for name, _ in results]) + return funcop From 8720c4a1aaf92d8152c89a429d0a5f96126dc822 Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 19 Nov 2024 21:05:04 +0000 Subject: [PATCH 2/6] [PyCDE][Handshake] Add bindings for Handshake functions --- frontends/PyCDE/src/CMakeLists.txt | 1 + frontends/PyCDE/src/pycde/handshake.py | 146 ++++++++++++++++++++++ frontends/PyCDE/test/test_handshake.py | 42 +++++++ lib/Bindings/Python/dialects/handshake.py | 17 ++- 4 files changed, 201 insertions(+), 5 deletions(-) create mode 100644 frontends/PyCDE/src/pycde/handshake.py create mode 100644 frontends/PyCDE/test/test_handshake.py diff --git a/frontends/PyCDE/src/CMakeLists.txt b/frontends/PyCDE/src/CMakeLists.txt index cbcd48839214..fad232cb8301 100644 --- a/frontends/PyCDE/src/CMakeLists.txt +++ b/frontends/PyCDE/src/CMakeLists.txt @@ -31,6 +31,7 @@ declare_mlir_python_sources(PyCDESources pycde/common.py pycde/system.py pycde/devicedb.py + pycde/handshake.py pycde/instance.py pycde/seq.py pycde/signals.py diff --git a/frontends/PyCDE/src/pycde/handshake.py b/frontends/PyCDE/src/pycde/handshake.py new file mode 100644 index 000000000000..803bef8b099c --- /dev/null +++ b/frontends/PyCDE/src/pycde/handshake.py @@ -0,0 +1,146 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from __future__ import annotations +from typing import Any, List, Optional, Set, Tuple, Dict +import typing + +from .module import Module, ModuleLikeBuilderBase, PortError +from .signals import BitsSignal, ChannelSignal, ClockSignal, Signal +from .system import System +from .support import (get_user_loc, _obj_to_attribute, obj_to_typed_attribute, + create_const_zero) +from .types import Channel + +from .circt.dialects import handshake as raw_handshake +from .circt import ir + + +class FuncBuilder(ModuleLikeBuilderBase): + """Defines how an ESI `PureModule` gets built.""" + + @property + def circt_mod(self): + sys: System = System.current() + ret = sys._op_cache.get_circt_mod(self) + if ret is None: + return sys._create_circt_mod(self) + return ret + + def create_op(self, sys: System, symbol): + if hasattr(self.modcls, "metadata"): + meta = self.modcls.metadata + self.add_metadata(sys, symbol, meta) + else: + self.add_metadata(sys, symbol, None) + + # If there are associated constants, add them to the manifest. + if len(self.constants) > 0: + constants_dict: Dict[str, ir.Attribute] = {} + for name, constant in self.constants.items(): + constant_attr = obj_to_typed_attribute(constant.value, constant.type) + constants_dict[name] = constant_attr + with ir.InsertionPoint(sys.mod.body): + from .dialects.esi import esi + esi.SymbolConstantsOp(symbolRef=ir.FlatSymbolRefAttr.get(symbol), + constants=ir.DictAttr.get(constants_dict)) + + assert len(self.generators) > 0 + + if hasattr(self, "parameters") and self.parameters is not None: + self.attributes["pycde.parameters"] = self.parameters + # If this Module has a generator, it's a real module. + return raw_handshake.FuncOp.create( + symbol, + [(p.name, p.type._type) for p in self.inputs], + [(p.name, p.type._type) for p in self.outputs], + attributes=self.attributes, + loc=self.loc, + ip=sys._get_ip(), + ) + + def generate(self): + """Fill in (generate) this module. Only supports a single generator + currently.""" + if len(self.generators) != 1: + raise ValueError("Must have exactly one generator.") + g: Generator = list(self.generators.values())[0] + + entry_block = self.circt_mod.add_entry_block() + ports = self.generator_port_proxy(entry_block.arguments, self) + with self.GeneratorCtxt(self, ports, entry_block, g.loc): + outputs = g.gen_func(ports) + if outputs is not None: + raise ValueError("Generators must not return a value") + + ports._check_unconnected_outputs() + raw_handshake.ReturnOp([o.value for o in ports._output_values]) + + def instantiate(self, module_inst, inputs, instance_name: str): + """"Instantiate this Func from ESI channels. Check that the input types + match expectations.""" + + port_input_lookup = {port.name: port for port in self.inputs} + circt_inputs: List[Optional[ir.Value]] = [None] * len(self.inputs) + remaining_inputs = set(port_input_lookup.keys()) + clk = None + rst = None + for name, signal in inputs.items(): + if name == "clk": + if not isinstance(signal, ClockSignal): + raise PortError("'clk' must be a clock signal") + clk = signal.value + continue + elif name == "rst": + if not isinstance(signal, BitsSignal): + raise PortError("'rst' must be a Bits(1)") + rst = signal.value + continue + + if name not in port_input_lookup: + raise PortError(f"Input port {name} not found in module") + port = port_input_lookup[name] + if isinstance(signal, ChannelSignal): + # If the input is a channel signal, the types must match. + if signal.type.inner_type != port.type: + raise ValueError( + f"Wrong type on input signal '{name}'. Got '{signal.type}'," + f" expected '{port.type}'") + assert port.idx is not None + circt_inputs[port.idx] = signal.value + remaining_inputs.remove(name) + elif isinstance(signal, Signal): + raise PortError(f"Input {name} must be a channel signal") + else: + raise PortError(f"Port {name} must be a signal") + if clk is None: + raise PortError("Missing 'clk' signal") + if rst is None: + raise PortError("Missing 'rst' signal") + if len(remaining_inputs) > 0: + raise PortError( + f"Missing input signals for ports: {', '.join(remaining_inputs)}") + + circt_mod = self.circt_mod + assert circt_mod is not None + result_types = [Channel(port.type)._type for port in self.outputs] + inst = raw_handshake.ESIInstanceOp( + result_types, + ir.StringAttr(circt_mod.attributes["sym_name"]).value, + instance_name, + clk=clk, + rst=rst, + opOperands=circt_inputs, + loc=get_user_loc()) + inst.operation.verify() + return inst + + +class Func(Module): + """A pure ESI module has no ports and contains only instances of modules with + only ESI ports and connections between said instances. Use ESI services for + external communication.""" + + BuilderType: type[ModuleLikeBuilderBase] = FuncBuilder + _builder: FuncBuilder diff --git a/frontends/PyCDE/test/test_handshake.py b/frontends/PyCDE/test/test_handshake.py new file mode 100644 index 000000000000..3827496ae391 --- /dev/null +++ b/frontends/PyCDE/test/test_handshake.py @@ -0,0 +1,42 @@ +# RUN: %PYTHON% %s | FileCheck %s + +from pycde import (Clock, Output, Input, generator, types, Module) +from pycde.handshake import Func +from pycde.testing import unittestmodule +from pycde.types import Bits, Channel + +# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel, out x : !esi.channel) attributes {output_file = #hw.output_file<"Top.sv", includeReplicatedOps>} { +# CHECK: %0 = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a) : (!esi.channel) -> !esi.channel +# CHECK: hw.output %0 : !esi.channel +# CHECK: } +# CHECK: handshake.func @TestFunc(%arg0: i8, ...) -> i8 attributes {argNames = ["a"], output_file = #hw.output_file<"TestFunc.sv", includeReplicatedOps>, resNames = ["x"]} { +# CHECK: %c15_i8 = hw.constant 15 : i8 +# CHECK: %0 = comb.and bin %arg0, %c15_i8 : i8 +# CHECK: return %0 : i8 +# CHECK: } + + +class TestFunc(Func): + a = Input(Bits(8)) + x = Output(Bits(8)) + + @generator + def build(ports): + ports.x = ports.a & Bits(8)(0xF) + + +BarType = types.struct({"foo": types.i12}, "bar") + + +@unittestmodule() +class Top(Module): + clk = Clock() + rst = Input(Bits(1)) + + a = Input(Channel(Bits(8))) + x = Output(Channel(Bits(8))) + + @generator + def build(ports): + test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a) + ports.x = test.x diff --git a/lib/Bindings/Python/dialects/handshake.py b/lib/Bindings/Python/dialects/handshake.py index 78ce7de1bb1f..e60b70eb5141 100644 --- a/lib/Bindings/Python/dialects/handshake.py +++ b/lib/Bindings/Python/dialects/handshake.py @@ -3,15 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from __future__ import annotations +from typing import Dict, List, Tuple, Union from . import handshake from ._handshake_ops_gen import * from ._handshake_ops_gen import _Dialect from ..dialects._ods_common import _cext as _ods_cext -from ..ir import ArrayAttr, FunctionType, StringAttr, Type, TypeAttr - -from typing import List, Tuple, Union +from ..ir import ArrayAttr, Attribute, FunctionType, StringAttr, Type, TypeAttr from ._ods_common import ( equally_sized_accessor as _ods_equally_sized_accessor, @@ -32,17 +31,25 @@ class FuncOp(FuncOp): def create(sym_name: Union[StringAttr, str], args: List[Tuple[str, Type]], results: List[Tuple[str, Type]], - private: bool = False) -> FuncOp: + attributes: Dict[str, Attribute] = {}, + loc=None, + ip=None) -> FuncOp: if isinstance(sym_name, str): sym_name = StringAttr.get(sym_name) input_types = [t for _, t in args] res_types = [t for _, t in results] func_type = FunctionType.get(input_types, res_types) func_type_attr = TypeAttr.get(func_type) - funcop = FuncOp(func_type_attr) + funcop = FuncOp(func_type_attr, loc=loc, ip=ip) + for k, v in attributes.items(): + funcop.attributes[k] = v funcop.attributes["sym_name"] = sym_name funcop.attributes["argNames"] = ArrayAttr.get( [StringAttr.get(name) for name, _ in args]) funcop.attributes["resNames"] = ArrayAttr.get( [StringAttr.get(name) for name, _ in results]) return funcop + + def add_entry_block(self): + self.body.blocks.append(*self.function_type.value.inputs) + return self.body.blocks[0] From 4677e590bba2aa953414ae341936ac80fc0b11cb Mon Sep 17 00:00:00 2001 From: John Demme Date: Thu, 21 Nov 2024 21:12:05 +0000 Subject: [PATCH 3/6] Adding passes --- frontends/PyCDE/src/pycde/system.py | 11 +++++++++++ frontends/PyCDE/test/test_handshake.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/frontends/PyCDE/src/pycde/system.py b/frontends/PyCDE/src/pycde/system.py index c893c4e7fc7f..5ba2be1fbf34 100644 --- a/frontends/PyCDE/src/pycde/system.py +++ b/frontends/PyCDE/src/pycde/system.py @@ -256,12 +256,23 @@ def get_instance(self, "builtin.module(verify-esi-connections)", lambda sys: sys.generate(), "builtin.module(verify-esi-connections)", + # After all of the pycde code has been executed, we have all the types # defined so we can go through and output the typedefs delcarations. lambda sys: TypeAlias.declare_aliases(sys.mod), + + # Then run all the passes to lower dialects which produce `hw.module`s. + "builtin.module(lower-handshake-to-dc)", + "builtin.module(dc-materialize-forks-sinks)", + "builtin.module(canonicalize)", + "builtin.module(lower-dc-to-hw)", + + # Run ESI manifest passes. "builtin.module(esi-appid-hier{{top={tops} }}, esi-build-manifest{{top={tops} }})", "builtin.module(msft-lower-constructs, msft-lower-instances)", "builtin.module(esi-clean-metadata)", + + # Instaniate hlmems, which could produce new esi connections. "builtin.module(hw.module(lower-seq-hlmem))", "builtin.module(lower-esi-to-physical)", # TODO: support more than just cosim. diff --git a/frontends/PyCDE/test/test_handshake.py b/frontends/PyCDE/test/test_handshake.py index 3827496ae391..b7e18f6d22d1 100644 --- a/frontends/PyCDE/test/test_handshake.py +++ b/frontends/PyCDE/test/test_handshake.py @@ -28,7 +28,7 @@ def build(ports): BarType = types.struct({"foo": types.i12}, "bar") -@unittestmodule() +@unittestmodule(print=True, run_passes=True, print_after_passes=True) class Top(Module): clk = Clock() rst = Input(Bits(1)) From 304ea71ca9d612668cc97a8d357a8e5d4dce9374 Mon Sep 17 00:00:00 2001 From: John Demme Date: Thu, 21 Nov 2024 22:44:17 +0000 Subject: [PATCH 4/6] checkpoint --- frontends/PyCDE/src/pycde/system.py | 1 + frontends/PyCDE/test/test_handshake.py | 5 +++- .../PyCDE/test/test_verilog_readablility.py | 2 +- include/circt-c/Dialect/DC.h | 25 +++++++++++++++++++ lib/Bindings/Python/CIRCTModule.cpp | 2 ++ lib/Bindings/Python/CMakeLists.txt | 1 + lib/CAPI/Dialect/CMakeLists.txt | 9 +++++++ lib/CAPI/Dialect/DC.cpp | 18 +++++++++++++ .../DC/Transforms/DCMaterialization.cpp | 3 ++- test/Dialect/DC/materialize-forks-sinks.mlir | 8 ++++++ 10 files changed, 71 insertions(+), 3 deletions(-) create mode 100644 include/circt-c/Dialect/DC.h create mode 100644 lib/CAPI/Dialect/DC.cpp diff --git a/frontends/PyCDE/src/pycde/system.py b/frontends/PyCDE/src/pycde/system.py index 5ba2be1fbf34..f4504ce34de2 100644 --- a/frontends/PyCDE/src/pycde/system.py +++ b/frontends/PyCDE/src/pycde/system.py @@ -314,6 +314,7 @@ def run_passes(self, debug=False): tcl_file=tcl_file, platform=self.platform).strip() if aplog is not None: + print(f"Running phase #{idx}: {phase}") aplog.write(f"// passes ran: {passes}\n") aplog.flush() pm = passmanager.PassManager.parse(passes) diff --git a/frontends/PyCDE/test/test_handshake.py b/frontends/PyCDE/test/test_handshake.py index b7e18f6d22d1..effe5faf3689 100644 --- a/frontends/PyCDE/test/test_handshake.py +++ b/frontends/PyCDE/test/test_handshake.py @@ -28,7 +28,10 @@ def build(ports): BarType = types.struct({"foo": types.i12}, "bar") -@unittestmodule(print=True, run_passes=True, print_after_passes=True) +@unittestmodule(print=True, + run_passes=True, + print_after_passes=True, + debug=True) class Top(Module): clk = Clock() rst = Input(Bits(1)) diff --git a/frontends/PyCDE/test/test_verilog_readablility.py b/frontends/PyCDE/test/test_verilog_readablility.py index cd48a312df0b..2b31c29ac8bf 100644 --- a/frontends/PyCDE/test/test_verilog_readablility.py +++ b/frontends/PyCDE/test/test_verilog_readablility.py @@ -24,7 +24,7 @@ def build(self): sys = System([WireNames], output_directory=sys.argv[1]) sys.generate() -sys.run_passes() +sys.run_passes(debug=True) sys.print() # CHECK-LABEL: hw.module @WireNames(in %clk : i1, in %data_in : !hw.array<3xi32>, in %sel : i2, out a : i32, out b : i32) # CHECK: %foo__reg1 = sv.reg sym @foo__reg1 : !hw.inout diff --git a/include/circt-c/Dialect/DC.h b/include/circt-c/Dialect/DC.h new file mode 100644 index 000000000000..3608cfb8c75b --- /dev/null +++ b/include/circt-c/Dialect/DC.h @@ -0,0 +1,25 @@ +//===- DC.h - C interface for the DC dialect ----------------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_C_DIALECT_DC_H +#define CIRCT_C_DIALECT_DC_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(DC, dc); +MLIR_CAPI_EXPORTED void registerDCPasses(void); + +#ifdef __cplusplus +} +#endif + +#endif // CIRCT_C_DIALECT_DC_H diff --git a/lib/Bindings/Python/CIRCTModule.cpp b/lib/Bindings/Python/CIRCTModule.cpp index 3a2aa7cb3576..db308e759e85 100644 --- a/lib/Bindings/Python/CIRCTModule.cpp +++ b/lib/Bindings/Python/CIRCTModule.cpp @@ -11,6 +11,7 @@ #include "circt-c/Conversion.h" #include "circt-c/Dialect/Arc.h" #include "circt-c/Dialect/Comb.h" +#include "circt-c/Dialect/DC.h" #include "circt-c/Dialect/Debug.h" #include "circt-c/Dialect/ESI.h" #include "circt-c/Dialect/Emit.h" @@ -40,6 +41,7 @@ namespace py = pybind11; static void registerPasses() { registerArcPasses(); registerCombPasses(); + registerDCPasses(); registerSeqPasses(); registerSVPasses(); registerFSMPasses(); diff --git a/lib/Bindings/Python/CMakeLists.txt b/lib/Bindings/Python/CMakeLists.txt index ca830c492140..4319f91fc407 100644 --- a/lib/Bindings/Python/CMakeLists.txt +++ b/lib/Bindings/Python/CMakeLists.txt @@ -25,6 +25,7 @@ declare_mlir_python_extension(CIRCTBindingsPythonExtension EMBED_CAPI_LINK_LIBS CIRCTCAPIArc CIRCTCAPIComb + CIRCTCAPIDC CIRCTCAPIDebug CIRCTCAPIEmit CIRCTCAPIESI diff --git a/lib/CAPI/Dialect/CMakeLists.txt b/lib/CAPI/Dialect/CMakeLists.txt index 365e74c7c63d..e2b37e8b2f38 100644 --- a/lib/CAPI/Dialect/CMakeLists.txt +++ b/lib/CAPI/Dialect/CMakeLists.txt @@ -3,6 +3,7 @@ set(LLVM_OPTIONAL_SOURCES Arc.cpp CHIRRTL.cpp Comb.cpp + DC.cpp Debug.cpp Emit.cpp ESI.cpp @@ -39,6 +40,14 @@ add_circt_public_c_api_library(CIRCTCAPIComb CIRCTCombTransforms ) +add_circt_public_c_api_library(CIRCTCAPIDC + DC.cpp + + LINK_LIBS PUBLIC + MLIRCAPIIR + CIRCTDC +) + add_circt_public_c_api_library(CIRCTCAPIDebug Debug.cpp diff --git a/lib/CAPI/Dialect/DC.cpp b/lib/CAPI/Dialect/DC.cpp new file mode 100644 index 000000000000..40dd6e04d4a0 --- /dev/null +++ b/lib/CAPI/Dialect/DC.cpp @@ -0,0 +1,18 @@ +//===- DC.cpp - C interface for the DC dialect ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt-c/Dialect/DC.h" +#include "circt/Conversion/Passes.h" +#include "circt/Dialect/DC/DCDialect.h" +#include "circt/Dialect/DC/DCPasses.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/CAPI/Support.h" + +void registerDCPasses() { circt::dc::registerPasses(); } +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(DC, dc, circt::dc::DCDialect) diff --git a/lib/Dialect/DC/Transforms/DCMaterialization.cpp b/lib/Dialect/DC/Transforms/DCMaterialization.cpp index 19c935682242..9ccef06796d6 100644 --- a/lib/Dialect/DC/Transforms/DCMaterialization.cpp +++ b/lib/Dialect/DC/Transforms/DCMaterialization.cpp @@ -107,7 +107,8 @@ static LogicalResult addForkOps(Block &block, OpBuilder &rewriter) { for (auto barg : block.getArguments()) if (!barg.use_empty() && !barg.hasOneUse()) - insertFork(barg, rewriter); + if (isDCTyped(barg)) + insertFork(barg, rewriter); return success(); } diff --git a/test/Dialect/DC/materialize-forks-sinks.mlir b/test/Dialect/DC/materialize-forks-sinks.mlir index 9dfa3e4df040..35b07cc6ed99 100644 --- a/test/Dialect/DC/materialize-forks-sinks.mlir +++ b/test/Dialect/DC/materialize-forks-sinks.mlir @@ -55,3 +55,11 @@ func.func @testUnusedArg(%t: !dc.token, %v : !dc.value) -> () { func.func @testForkOfValue(%v : !dc.value) -> (!dc.value, !dc.value) { return %v, %v : !dc.value, !dc.value } + +// CHECK-LABEL: hw.module @shouldNotChange(in %clk : !seq.clock, in %d : i32) { +// CHECK-NEXT: seq.compreg %d, %clk : i32 +// CHECK-NEXT: seq.compreg %d, %clk : i32 +hw.module @shouldNotChange(in %clk: !seq.clock, in %d: i32) { + seq.compreg %d, %clk : i32 + seq.compreg %d, %clk : i32 +} From b62bc9b733cfb8e50575e012b950c7a9d0483778 Mon Sep 17 00:00:00 2001 From: John Demme Date: Thu, 21 Nov 2024 23:34:09 +0000 Subject: [PATCH 5/6] working! --- frontends/PyCDE/test/test_handshake.py | 14 ++++++++++---- frontends/PyCDE/test/test_polynomial.py | 2 +- lib/Conversion/DCToHW/DCToHW.cpp | 6 ++---- lib/Conversion/HandshakeToDC/HandshakeToDC.cpp | 5 +++-- test/Conversion/DCToHW/basic.mlir | 6 ++++++ 5 files changed, 22 insertions(+), 11 deletions(-) diff --git a/frontends/PyCDE/test/test_handshake.py b/frontends/PyCDE/test/test_handshake.py index effe5faf3689..463085a65092 100644 --- a/frontends/PyCDE/test/test_handshake.py +++ b/frontends/PyCDE/test/test_handshake.py @@ -15,6 +15,15 @@ # CHECK: return %0 : i8 # CHECK: } +# CHECK: hw.module @Top(in %clk : i1, in %rst : i1, in %a : i8, in %a_valid : i1, in %x_ready : i1, out a_ready : i1, out x : i8, out x_valid : i1) +# CHECK: %TestFunc.in0_ready, %TestFunc.out0, %TestFunc.out0_valid = hw.instance "TestFunc" @TestFunc(in0: %a: i8, in0_valid: %a_valid: i1, clk: %clk: i1, rst: %rst: i1, out0_ready: %x_ready: i1) -> (in0_ready: i1, out0: i8, out0_valid: i1) +# CHECK: hw.output %TestFunc.in0_ready, %TestFunc.out0, %TestFunc.out0_valid : i1, i8, i1 +# CHECK: hw.module @TestFunc(in %in0 : i8, in %in0_valid : i1, in %clk : i1, in %rst : i1, in %out0_ready : i1, out in0_ready : i1, out out0 : i8, out out0_valid : i1) +# CHECK: %c15_i8 = hw.constant 15 : i8 +# CHECK: [[R0:%.+]] = comb.and %out0_ready, %in0_valid : i1 +# CHECK: [[R1:%.+]] = comb.and bin %in0, %c15_i8 : i8 +# CHECK: hw.output [[R0]], [[R1]], %in0_valid : i1, i8, i1 + class TestFunc(Func): a = Input(Bits(8)) @@ -28,10 +37,7 @@ def build(ports): BarType = types.struct({"foo": types.i12}, "bar") -@unittestmodule(print=True, - run_passes=True, - print_after_passes=True, - debug=True) +@unittestmodule(print=True, run_passes=True, print_after_passes=True) class Top(Module): clk = Clock() rst = Input(Bits(1)) diff --git a/frontends/PyCDE/test/test_polynomial.py b/frontends/PyCDE/test/test_polynomial.py index a8563ca5a4bb..0b2fd9e74090 100755 --- a/frontends/PyCDE/test/test_polynomial.py +++ b/frontends/PyCDE/test/test_polynomial.py @@ -122,7 +122,7 @@ def construct(self): print("Generating rest...") poly.generate() -poly.run_passes() +poly.run_passes(debug=True) print("=== Final IR...") poly.print() diff --git a/lib/Conversion/DCToHW/DCToHW.cpp b/lib/Conversion/DCToHW/DCToHW.cpp index b87fc15fb1b2..51e279417796 100644 --- a/lib/Conversion/DCToHW/DCToHW.cpp +++ b/lib/Conversion/DCToHW/DCToHW.cpp @@ -837,10 +837,8 @@ static bool isDCType(Type type) { return isa(type); } /// Returns true if the given `op` is considered as legal - i.e. it does not /// contain any dc-typed values. static bool isLegalOp(Operation *op) { - if (auto funcOp = dyn_cast(op)) { - return llvm::none_of(funcOp.getPortTypes(), isDCType) && - llvm::none_of(funcOp.getBodyBlock()->getArgumentTypes(), isDCType); - } + if (auto funcOp = dyn_cast(op)) + return llvm::none_of(funcOp.getPortTypes(), isDCType); bool operandsOK = llvm::none_of(op->getOperandTypes(), isDCType); bool resultsOK = llvm::none_of(op->getResultTypes(), isDCType); diff --git a/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp b/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp index 5f12fd419bec..a7404d1814cd 100644 --- a/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp +++ b/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp @@ -763,7 +763,7 @@ class HandshakeToDCPass void runOnOperation() override { mlir::ModuleOp mod = getOperation(); auto targetModifier = [](mlir::ConversionTarget &target) { - target.addLegalDialect(); + // target.addLegalDialect(); }; auto patternBuilder = [&](TypeConverter &typeConverter, @@ -807,7 +807,8 @@ LogicalResult circt::handshaketodc::runHandshakeToDC( ConversionTarget target(*ctx); target.addIllegalDialect(); target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); // And any user-specified target adjustments if (configureTarget) diff --git a/test/Conversion/DCToHW/basic.mlir b/test/Conversion/DCToHW/basic.mlir index 52a436dbcb85..002c7e600d8c 100644 --- a/test/Conversion/DCToHW/basic.mlir +++ b/test/Conversion/DCToHW/basic.mlir @@ -196,3 +196,9 @@ hw.module @merge(in %first : !dc.token, in %second : !dc.token, out token : !dc. %selected = dc.merge %first, %second hw.output %selected : !dc.value } + +// CHECK: hw.module.extern @ext(in %a : i32, out b : i32) +hw.module.extern @ext(in %a : i32, out b : i32) + +// CHECK: hw.module.extern @extDC(in %a : !esi.channel, out b : i32) +hw.module.extern @extDC(in %a : !dc.value, out b : i32) From 96401090b35ebedb2a11b3b8d951607154bf6b72 Mon Sep 17 00:00:00 2001 From: John Demme Date: Fri, 22 Nov 2024 22:10:00 +0000 Subject: [PATCH 6/6] integration test! --- frontends/PyCDE/integration_test/esi_test.py | 26 ++++++++++++++++++- .../test_software/esi_test.py | 17 ++++++++++++ .../HandshakeToDC/HandshakeToDC.cpp | 6 +---- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/frontends/PyCDE/integration_test/esi_test.py b/frontends/PyCDE/integration_test/esi_test.py index d0a5f9543f66..840912e0a384 100644 --- a/frontends/PyCDE/integration_test/esi_test.py +++ b/frontends/PyCDE/integration_test/esi_test.py @@ -7,11 +7,12 @@ import pycde from pycde import (AppID, Clock, Module, Reset, modparams, generator) from pycde.bsp import cosim -from pycde.common import Constant +from pycde.common import Constant, Input, Output from pycde.constructs import ControlReg, Reg, Wire from pycde.esi import ChannelService, FuncService, MMIO, MMIOReadWriteCmdType from pycde.types import (Bits, Channel, UInt) from pycde.behavioral import If, Else, EndIf +from pycde.handshake import Func import sys @@ -107,6 +108,28 @@ def construct(ports): ChannelService.to_host(AppID("const_producer"), ch) +class JoinFunc(Func): + a = Input(UInt(32)) + b = Input(UInt(32)) + x = Output(UInt(32)) + + @generator + def construct(ports): + ports.x = (ports.a + ports.b).as_uint(32) + + +class Join(Module): + clk = Clock() + rst = Reset() + + @generator + def construct(ports): + a = ChannelService.from_host(AppID("join_a"), UInt(32)) + b = ChannelService.from_host(AppID("join_b"), UInt(32)) + f = JoinFunc(clk=ports.clk, rst=ports.rst, a=a, b=b) + ChannelService.to_host(AppID("join_x"), f.x) + + class Top(Module): clk = Clock() rst = Reset() @@ -118,6 +141,7 @@ def construct(ports): MMIOClient(i)() MMIOReadWriteClient(clk=ports.clk, rst=ports.rst) ConstProducer(clk=ports.clk, rst=ports.rst) + Join(clk=ports.clk, rst=ports.rst) if __name__ == "__main__": diff --git a/frontends/PyCDE/integration_test/test_software/esi_test.py b/frontends/PyCDE/integration_test/test_software/esi_test.py index 1c7e26c53a66..8dc55f663719 100644 --- a/frontends/PyCDE/integration_test/test_software/esi_test.py +++ b/frontends/PyCDE/integration_test/test_software/esi_test.py @@ -143,3 +143,20 @@ def read_offset_check(i: int, add_amt: int): producer.disconnect() print(f"data: {data}") assert data == 42 + +################################################################################ +# Handshake Join +################################################################################ + +a = d.ports[esi.AppID("join_a")].write_port("data") +a.connect() +b = d.ports[esi.AppID("join_b")].write_port("data") +b.connect() +x = d.ports[esi.AppID("join_x")].read_port("data") +x.connect() + +a.write(15) +b.write(24) +xdata = x.read() +print(f"join: {xdata}") +assert xdata == 15 + 24 diff --git a/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp b/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp index a7404d1814cd..25173a9777d6 100644 --- a/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp +++ b/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp @@ -762,10 +762,6 @@ class HandshakeToDCPass public: void runOnOperation() override { mlir::ModuleOp mod = getOperation(); - auto targetModifier = [](mlir::ConversionTarget &target) { - // target.addLegalDialect(); - }; - auto patternBuilder = [&](TypeConverter &typeConverter, handshaketodc::ConvertedOps &convertedOps, RewritePatternSet &patterns) { @@ -774,7 +770,7 @@ class HandshakeToDCPass patterns.add(typeConverter, mod.getContext()); }; - LogicalResult res = runHandshakeToDC(mod, patternBuilder, targetModifier); + LogicalResult res = runHandshakeToDC(mod, patternBuilder, nullptr); if (failed(res)) signalPassFailure(); }