Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyCDE][Handshake] Add bindings for Handshake functions #7849

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion frontends/PyCDE/integration_test/esi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import pycde
from pycde import (AppID, Clock, Module, Reset, modparams, generator)
from pycde.bsp import cosim
from pycde.common import Constant
from pycde.common import Constant, Input, Output
from pycde.constructs import ControlReg, Reg, Wire
from pycde.esi import ChannelService, FuncService, MMIO, MMIOReadWriteCmdType
from pycde.types import (Bits, Channel, UInt)
from pycde.behavioral import If, Else, EndIf
from pycde.handshake import Func

import sys

Expand Down Expand Up @@ -107,6 +108,28 @@ def construct(ports):
ChannelService.to_host(AppID("const_producer"), ch)


class JoinFunc(Func):
a = Input(UInt(32))
b = Input(UInt(32))
x = Output(UInt(32))

@generator
def construct(ports):
ports.x = (ports.a + ports.b).as_uint(32)


class Join(Module):
clk = Clock()
rst = Reset()

@generator
def construct(ports):
a = ChannelService.from_host(AppID("join_a"), UInt(32))
b = ChannelService.from_host(AppID("join_b"), UInt(32))
f = JoinFunc(clk=ports.clk, rst=ports.rst, a=a, b=b)
ChannelService.to_host(AppID("join_x"), f.x)


class Top(Module):
clk = Clock()
rst = Reset()
Expand All @@ -118,6 +141,7 @@ def construct(ports):
MMIOClient(i)()
MMIOReadWriteClient(clk=ports.clk, rst=ports.rst)
ConstProducer(clk=ports.clk, rst=ports.rst)
Join(clk=ports.clk, rst=ports.rst)


if __name__ == "__main__":
Expand Down
17 changes: 17 additions & 0 deletions frontends/PyCDE/integration_test/test_software/esi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,20 @@ def read_offset_check(i: int, add_amt: int):
producer.disconnect()
print(f"data: {data}")
assert data == 42

################################################################################
# Handshake Join
################################################################################

a = d.ports[esi.AppID("join_a")].write_port("data")
a.connect()
b = d.ports[esi.AppID("join_b")].write_port("data")
b.connect()
x = d.ports[esi.AppID("join_x")].read_port("data")
x.connect()

a.write(15)
b.write(24)
xdata = x.read()
print(f"join: {xdata}")
assert xdata == 15 + 24
1 change: 1 addition & 0 deletions frontends/PyCDE/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ declare_mlir_python_sources(PyCDESources
pycde/common.py
pycde/system.py
pycde/devicedb.py
pycde/handshake.py
pycde/instance.py
pycde/seq.py
pycde/signals.py
Expand Down
146 changes: 146 additions & 0 deletions frontends/PyCDE/src/pycde/handshake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
from typing import Any, List, Optional, Set, Tuple, Dict
import typing

from .module import Module, ModuleLikeBuilderBase, PortError
from .signals import BitsSignal, ChannelSignal, ClockSignal, Signal
from .system import System
from .support import (get_user_loc, _obj_to_attribute, obj_to_typed_attribute,
create_const_zero)
from .types import Channel

from .circt.dialects import handshake as raw_handshake
from .circt import ir


class FuncBuilder(ModuleLikeBuilderBase):
"""Defines how an ESI `PureModule` gets built."""

@property
def circt_mod(self):
sys: System = System.current()
ret = sys._op_cache.get_circt_mod(self)
if ret is None:
return sys._create_circt_mod(self)
return ret

def create_op(self, sys: System, symbol):
if hasattr(self.modcls, "metadata"):
meta = self.modcls.metadata
self.add_metadata(sys, symbol, meta)
else:
self.add_metadata(sys, symbol, None)

# If there are associated constants, add them to the manifest.
if len(self.constants) > 0:
constants_dict: Dict[str, ir.Attribute] = {}
for name, constant in self.constants.items():
constant_attr = obj_to_typed_attribute(constant.value, constant.type)
constants_dict[name] = constant_attr
with ir.InsertionPoint(sys.mod.body):
from .dialects.esi import esi
esi.SymbolConstantsOp(symbolRef=ir.FlatSymbolRefAttr.get(symbol),
constants=ir.DictAttr.get(constants_dict))

assert len(self.generators) > 0

if hasattr(self, "parameters") and self.parameters is not None:
self.attributes["pycde.parameters"] = self.parameters
# If this Module has a generator, it's a real module.
return raw_handshake.FuncOp.create(
symbol,
[(p.name, p.type._type) for p in self.inputs],
[(p.name, p.type._type) for p in self.outputs],
attributes=self.attributes,
loc=self.loc,
ip=sys._get_ip(),
)

def generate(self):
"""Fill in (generate) this module. Only supports a single generator
currently."""
if len(self.generators) != 1:
raise ValueError("Must have exactly one generator.")
g: Generator = list(self.generators.values())[0]

entry_block = self.circt_mod.add_entry_block()
ports = self.generator_port_proxy(entry_block.arguments, self)
with self.GeneratorCtxt(self, ports, entry_block, g.loc):
outputs = g.gen_func(ports)
if outputs is not None:
raise ValueError("Generators must not return a value")

ports._check_unconnected_outputs()
raw_handshake.ReturnOp([o.value for o in ports._output_values])

def instantiate(self, module_inst, inputs, instance_name: str):
""""Instantiate this Func from ESI channels. Check that the input types
match expectations."""

port_input_lookup = {port.name: port for port in self.inputs}
circt_inputs: List[Optional[ir.Value]] = [None] * len(self.inputs)
remaining_inputs = set(port_input_lookup.keys())
clk = None
rst = None
for name, signal in inputs.items():
if name == "clk":
if not isinstance(signal, ClockSignal):
raise PortError("'clk' must be a clock signal")
clk = signal.value
continue
elif name == "rst":
if not isinstance(signal, BitsSignal):
raise PortError("'rst' must be a Bits(1)")
rst = signal.value
continue

if name not in port_input_lookup:
raise PortError(f"Input port {name} not found in module")
port = port_input_lookup[name]
if isinstance(signal, ChannelSignal):
# If the input is a channel signal, the types must match.
if signal.type.inner_type != port.type:
raise ValueError(
f"Wrong type on input signal '{name}'. Got '{signal.type}',"
f" expected '{port.type}'")
assert port.idx is not None
circt_inputs[port.idx] = signal.value
remaining_inputs.remove(name)
elif isinstance(signal, Signal):
raise PortError(f"Input {name} must be a channel signal")
else:
raise PortError(f"Port {name} must be a signal")
if clk is None:
raise PortError("Missing 'clk' signal")
if rst is None:
raise PortError("Missing 'rst' signal")
if len(remaining_inputs) > 0:
raise PortError(
f"Missing input signals for ports: {', '.join(remaining_inputs)}")

circt_mod = self.circt_mod
assert circt_mod is not None
result_types = [Channel(port.type)._type for port in self.outputs]
inst = raw_handshake.ESIInstanceOp(
result_types,
ir.StringAttr(circt_mod.attributes["sym_name"]).value,
instance_name,
clk=clk,
rst=rst,
opOperands=circt_inputs,
loc=get_user_loc())
inst.operation.verify()
return inst


class Func(Module):
"""A pure ESI module has no ports and contains only instances of modules with
only ESI ports and connections between said instances. Use ESI services for
external communication."""

BuilderType: type[ModuleLikeBuilderBase] = FuncBuilder
_builder: FuncBuilder
12 changes: 12 additions & 0 deletions frontends/PyCDE/src/pycde/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,23 @@ def get_instance(self,
"builtin.module(verify-esi-connections)",
lambda sys: sys.generate(),
"builtin.module(verify-esi-connections)",

# After all of the pycde code has been executed, we have all the types
# defined so we can go through and output the typedefs delcarations.
lambda sys: TypeAlias.declare_aliases(sys.mod),

# Then run all the passes to lower dialects which produce `hw.module`s.
"builtin.module(lower-handshake-to-dc)",
"builtin.module(dc-materialize-forks-sinks)",
"builtin.module(canonicalize)",
"builtin.module(lower-dc-to-hw)",

# Run ESI manifest passes.
"builtin.module(esi-appid-hier{{top={tops} }}, esi-build-manifest{{top={tops} }})",
"builtin.module(msft-lower-constructs, msft-lower-instances)",
"builtin.module(esi-clean-metadata)",

# Instaniate hlmems, which could produce new esi connections.
"builtin.module(hw.module(lower-seq-hlmem))",
"builtin.module(lower-esi-to-physical)",
# TODO: support more than just cosim.
Expand Down Expand Up @@ -303,6 +314,7 @@ def run_passes(self, debug=False):
tcl_file=tcl_file,
platform=self.platform).strip()
if aplog is not None:
print(f"Running phase #{idx}: {phase}")
aplog.write(f"// passes ran: {passes}\n")
aplog.flush()
pm = passmanager.PassManager.parse(passes)
Expand Down
51 changes: 51 additions & 0 deletions frontends/PyCDE/test/test_handshake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# RUN: %PYTHON% %s | FileCheck %s

from pycde import (Clock, Output, Input, generator, types, Module)
from pycde.handshake import Func
from pycde.testing import unittestmodule
from pycde.types import Bits, Channel

# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, out x : !esi.channel<i8>) attributes {output_file = #hw.output_file<"Top.sv", includeReplicatedOps>} {
# CHECK: %0 = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a) : (!esi.channel<i8>) -> !esi.channel<i8>
# CHECK: hw.output %0 : !esi.channel<i8>
# CHECK: }
# CHECK: handshake.func @TestFunc(%arg0: i8, ...) -> i8 attributes {argNames = ["a"], output_file = #hw.output_file<"TestFunc.sv", includeReplicatedOps>, resNames = ["x"]} {
# CHECK: %c15_i8 = hw.constant 15 : i8
# CHECK: %0 = comb.and bin %arg0, %c15_i8 : i8
# CHECK: return %0 : i8
# CHECK: }

# CHECK: hw.module @Top(in %clk : i1, in %rst : i1, in %a : i8, in %a_valid : i1, in %x_ready : i1, out a_ready : i1, out x : i8, out x_valid : i1)
# CHECK: %TestFunc.in0_ready, %TestFunc.out0, %TestFunc.out0_valid = hw.instance "TestFunc" @TestFunc(in0: %a: i8, in0_valid: %a_valid: i1, clk: %clk: i1, rst: %rst: i1, out0_ready: %x_ready: i1) -> (in0_ready: i1, out0: i8, out0_valid: i1)
# CHECK: hw.output %TestFunc.in0_ready, %TestFunc.out0, %TestFunc.out0_valid : i1, i8, i1
# CHECK: hw.module @TestFunc(in %in0 : i8, in %in0_valid : i1, in %clk : i1, in %rst : i1, in %out0_ready : i1, out in0_ready : i1, out out0 : i8, out out0_valid : i1)
# CHECK: %c15_i8 = hw.constant 15 : i8
# CHECK: [[R0:%.+]] = comb.and %out0_ready, %in0_valid : i1
# CHECK: [[R1:%.+]] = comb.and bin %in0, %c15_i8 : i8
# CHECK: hw.output [[R0]], [[R1]], %in0_valid : i1, i8, i1


class TestFunc(Func):
a = Input(Bits(8))
x = Output(Bits(8))

@generator
def build(ports):
ports.x = ports.a & Bits(8)(0xF)


BarType = types.struct({"foo": types.i12}, "bar")


@unittestmodule(print=True, run_passes=True, print_after_passes=True)
class Top(Module):
clk = Clock()
rst = Input(Bits(1))

a = Input(Channel(Bits(8)))
x = Output(Channel(Bits(8)))

@generator
def build(ports):
test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a)
ports.x = test.x
2 changes: 1 addition & 1 deletion frontends/PyCDE/test/test_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def construct(self):

print("Generating rest...")
poly.generate()
poly.run_passes()
poly.run_passes(debug=True)

print("=== Final IR...")
poly.print()
Expand Down
2 changes: 1 addition & 1 deletion frontends/PyCDE/test/test_verilog_readablility.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def build(self):

sys = System([WireNames], output_directory=sys.argv[1])
sys.generate()
sys.run_passes()
sys.run_passes(debug=True)
sys.print()
# CHECK-LABEL: hw.module @WireNames(in %clk : i1, in %data_in : !hw.array<3xi32>, in %sel : i2, out a : i32, out b : i32)
# CHECK: %foo__reg1 = sv.reg sym @foo__reg1 : !hw.inout<i32>
Expand Down
25 changes: 25 additions & 0 deletions include/circt-c/Dialect/DC.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//===- DC.h - C interface for the DC dialect ----------------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef CIRCT_C_DIALECT_DC_H
#define CIRCT_C_DIALECT_DC_H

#include "mlir-c/IR.h"

#ifdef __cplusplus
extern "C" {
#endif

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(DC, dc);
MLIR_CAPI_EXPORTED void registerDCPasses(void);

#ifdef __cplusplus
}
#endif

#endif // CIRCT_C_DIALECT_DC_H
2 changes: 2 additions & 0 deletions lib/Bindings/Python/CIRCTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "circt-c/Conversion.h"
#include "circt-c/Dialect/Arc.h"
#include "circt-c/Dialect/Comb.h"
#include "circt-c/Dialect/DC.h"
#include "circt-c/Dialect/Debug.h"
#include "circt-c/Dialect/ESI.h"
#include "circt-c/Dialect/Emit.h"
Expand Down Expand Up @@ -40,6 +41,7 @@ namespace py = pybind11;
static void registerPasses() {
registerArcPasses();
registerCombPasses();
registerDCPasses();
registerSeqPasses();
registerSVPasses();
registerFSMPasses();
Expand Down
1 change: 1 addition & 0 deletions lib/Bindings/Python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ declare_mlir_python_extension(CIRCTBindingsPythonExtension
EMBED_CAPI_LINK_LIBS
CIRCTCAPIArc
CIRCTCAPIComb
CIRCTCAPIDC
CIRCTCAPIDebug
CIRCTCAPIEmit
CIRCTCAPIESI
Expand Down
Loading
Loading