Skip to content

Commit

Permalink
[PyCDE][Handshake] Add bindings for Handshake functions (#7849)
Browse files Browse the repository at this point in the history
Exposes the Handshake dialect's FuncOp in the same style as an HWModuleOp and instantiates them via ESI channels.
  • Loading branch information
teqdruid authored Nov 27, 2024
1 parent ced18f0 commit 8283dcb
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 33 deletions.
26 changes: 25 additions & 1 deletion frontends/PyCDE/integration_test/esi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import pycde
from pycde import (AppID, Clock, Module, Reset, modparams, generator)
from pycde.bsp import cosim
from pycde.common import Constant
from pycde.common import Constant, Input, Output
from pycde.constructs import ControlReg, Reg, Wire
from pycde.esi import ChannelService, FuncService, MMIO, MMIOReadWriteCmdType
from pycde.types import (Bits, Channel, UInt)
from pycde.behavioral import If, Else, EndIf
from pycde.handshake import Func

import sys

Expand Down Expand Up @@ -107,6 +108,28 @@ def construct(ports):
ChannelService.to_host(AppID("const_producer"), ch)


class JoinAddFunc(Func):
a = Input(UInt(32))
b = Input(UInt(32))
x = Output(UInt(32))

@generator
def construct(ports):
ports.x = (ports.a + ports.b).as_uint(32)


class Join(Module):
clk = Clock()
rst = Reset()

@generator
def construct(ports):
a = ChannelService.from_host(AppID("join_a"), UInt(32))
b = ChannelService.from_host(AppID("join_b"), UInt(32))
f = JoinAddFunc(clk=ports.clk, rst=ports.rst, a=a, b=b)
ChannelService.to_host(AppID("join_x"), f.x)


class Top(Module):
clk = Clock()
rst = Reset()
Expand All @@ -118,6 +141,7 @@ def construct(ports):
MMIOClient(i)()
MMIOReadWriteClient(clk=ports.clk, rst=ports.rst)
ConstProducer(clk=ports.clk, rst=ports.rst)
Join(clk=ports.clk, rst=ports.rst)


if __name__ == "__main__":
Expand Down
17 changes: 17 additions & 0 deletions frontends/PyCDE/integration_test/test_software/esi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,20 @@ def read_offset_check(i: int, add_amt: int):
producer.disconnect()
print(f"data: {data}")
assert data == 42

################################################################################
# Handshake JoinAddFunc tests
################################################################################

a = d.ports[esi.AppID("join_a")].write_port("data")
a.connect()
b = d.ports[esi.AppID("join_b")].write_port("data")
b.connect()
x = d.ports[esi.AppID("join_x")].read_port("data")
x.connect()

a.write(15)
b.write(24)
xdata = x.read()
print(f"join: {xdata}")
assert xdata == 15 + 24
1 change: 1 addition & 0 deletions frontends/PyCDE/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ declare_mlir_python_sources(PyCDESources
pycde/common.py
pycde/system.py
pycde/devicedb.py
pycde/handshake.py
pycde/instance.py
pycde/seq.py
pycde/signals.py
Expand Down
126 changes: 126 additions & 0 deletions frontends/PyCDE/src/pycde/handshake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
from typing import List, Optional, Dict

from .module import Module, ModuleLikeBuilderBase, PortError
from .signals import BitsSignal, ChannelSignal, ClockSignal, Signal
from .system import System
from .support import get_user_loc, obj_to_typed_attribute
from .types import Channel

from .circt.dialects import handshake as raw_handshake
from .circt import ir


class FuncBuilder(ModuleLikeBuilderBase):
"""Defines how a handshake function gets built."""

def create_op(self, sys: System, symbol):
"""Callback for creating a handshake.func op."""

self.create_op_common(sys, symbol)

assert len(self.generators) > 0

if hasattr(self, "parameters") and self.parameters is not None:
self.attributes["pycde.parameters"] = self.parameters
# If this Module has a generator, it's a real module.
return raw_handshake.FuncOp.create(
symbol,
[(p.name, p.type._type) for p in self.inputs],
[(p.name, p.type._type) for p in self.outputs],
attributes=self.attributes,
loc=self.loc,
ip=sys._get_ip(),
)

def generate(self):
"""Fill in (generate) this module. Only supports a single generator
currently."""
if len(self.generators) != 1:
raise ValueError("Must have exactly one generator.")
g: Generator = list(self.generators.values())[0]

entry_block = self.circt_mod.add_entry_block()
ports = self.generator_port_proxy(entry_block.arguments, self)
with self.GeneratorCtxt(self, ports, entry_block, g.loc):
outputs = g.gen_func(ports)
if outputs is not None:
raise ValueError("Generators must not return a value")

ports._check_unconnected_outputs()
raw_handshake.ReturnOp([o.value for o in ports._output_values])

def instantiate(self, module_inst, inputs, instance_name: str):
""""Instantiate this Func from ESI channels. Check that the input types
match expectations."""

port_input_lookup = {port.name: port for port in self.inputs}
circt_inputs: List[Optional[ir.Value]] = [None] * len(self.inputs)
remaining_inputs = set(port_input_lookup.keys())
clk = None
rst = None
for name, signal in inputs.items():
if name == "clk":
if not isinstance(signal, ClockSignal):
raise PortError("'clk' must be a clock signal")
clk = signal.value
continue
elif name == "rst":
if not isinstance(signal, BitsSignal):
raise PortError("'rst' must be a Bits(1)")
rst = signal.value
continue

if name not in port_input_lookup:
raise PortError(f"Input port {name} not found in module")
port = port_input_lookup[name]
if isinstance(signal, ChannelSignal):
# If the input is a channel signal, the types must match.
if signal.type.inner_type != port.type:
raise ValueError(
f"Wrong type on input signal '{name}'. Got '{signal.type}',"
f" expected '{port.type}'")
assert port.idx is not None
circt_inputs[port.idx] = signal.value
remaining_inputs.remove(name)
elif isinstance(signal, Signal):
raise PortError(f"Input {name} must be a channel signal")
else:
raise PortError(f"Port {name} must be a signal")
if clk is None:
raise PortError("Missing 'clk' signal")
if rst is None:
raise PortError("Missing 'rst' signal")
if len(remaining_inputs) > 0:
raise PortError(
f"Missing input signals for ports: {', '.join(remaining_inputs)}")

circt_mod = self.circt_mod
assert circt_mod is not None
result_types = [Channel(port.type)._type for port in self.outputs]
inst = raw_handshake.ESIInstanceOp(
result_types,
ir.StringAttr(circt_mod.attributes["sym_name"]).value,
instance_name,
clk=clk,
rst=rst,
opOperands=circt_inputs,
loc=get_user_loc())
inst.operation.verify()
return inst


class Func(Module):
"""A Handshake function is intended to implicitly model dataflow. If can
contain any combinational operation and offers a software-like (HLS) approach
to hardware design.
The PyCDE interface to it (this class) is considered experimental. Use at your
own risk and test the resulting RTL thoroughly."""

BuilderType: type[ModuleLikeBuilderBase] = FuncBuilder
_builder: FuncBuilder
71 changes: 39 additions & 32 deletions frontends/PyCDE/src/pycde/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,20 @@ def inputs(self) -> List[Input]:
def outputs(self) -> List[Output]:
return [p for p in self.ports if isinstance(p, Output)]

@property
def circt_mod(self):
"""Get the raw CIRCT operation for the module definition. DO NOT store the
returned value!!! It needs to get reaped after the current action (e.g.
instantiation, generation). Memory safety when interacting with native code
can be painful."""

from .system import System
sys: System = System.current()
ret = sys._op_cache.get_circt_mod(self)
if ret is None:
return sys._create_circt_mod(self)
return ret

def go(self):
"""Execute the analysis and mutation to make a `ModuleLike` class operate
as such."""
Expand Down Expand Up @@ -400,6 +414,30 @@ def add_metadata(self, sys, symbol: str, meta: Optional[Metadata]):
for k, v in meta.misc.items():
meta_op.attributes[k] = _obj_to_attribute(v)

def create_op_common(self, sys, symbol):
"""Do common work for creating a module-like op. This includes adding
metadata and adding parameters to the attributes."""

if hasattr(self.modcls, "metadata"):
meta = self.modcls.metadata
self.add_metadata(sys, symbol, meta)
else:
self.add_metadata(sys, symbol, None)

# If there are associated constants, add them to the manifest.
if len(self.constants) > 0:
constants_dict: Dict[str, ir.Attribute] = {}
for name, constant in self.constants.items():
constant_attr = obj_to_typed_attribute(constant.value, constant.type)
constants_dict[name] = constant_attr
with ir.InsertionPoint(sys.mod.body):
from .dialects.esi import esi
esi.SymbolConstantsOp(symbolRef=ir.FlatSymbolRefAttr.get(symbol),
constants=ir.DictAttr.get(constants_dict))

if hasattr(self, "parameters") and self.parameters is not None:
self.attributes["pycde.parameters"] = self.parameters

class GeneratorCtxt:
"""Provides an context which most genertors need."""

Expand Down Expand Up @@ -457,43 +495,12 @@ def __init__(cls, name, bases, dct: Dict):
class ModuleBuilder(ModuleLikeBuilderBase):
"""Defines how a `Module` gets built. Extend the base class and customize."""

@property
def circt_mod(self):
"""Get the raw CIRCT operation for the module definition. DO NOT store the
returned value!!! It needs to get reaped after the current action (e.g.
instantiation, generation). Memory safety when interacting with native code
can be painful."""

from .system import System
sys: System = System.current()
ret = sys._op_cache.get_circt_mod(self)
if ret is None:
return sys._create_circt_mod(self)
return ret

def create_op(self, sys, symbol):
"""Callback for creating a module op."""

if hasattr(self.modcls, "metadata"):
meta = self.modcls.metadata
self.add_metadata(sys, symbol, meta)
else:
self.add_metadata(sys, symbol, None)

# If there are associated constants, add them to the manifest.
if len(self.constants) > 0:
constants_dict: Dict[str, ir.Attribute] = {}
for name, constant in self.constants.items():
constant_attr = obj_to_typed_attribute(constant.value, constant.type)
constants_dict[name] = constant_attr
with ir.InsertionPoint(sys.mod.body):
from .dialects.esi import esi
esi.SymbolConstantsOp(symbolRef=ir.FlatSymbolRefAttr.get(symbol),
constants=ir.DictAttr.get(constants_dict))
self.create_op_common(sys, symbol)

if len(self.generators) > 0:
if hasattr(self, "parameters") and self.parameters is not None:
self.attributes["pycde.parameters"] = self.parameters
# If this Module has a generator, it's a real module.
return hw.HWModuleOp(
symbol,
Expand Down
11 changes: 11 additions & 0 deletions frontends/PyCDE/src/pycde/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,23 @@ def get_instance(self,
"builtin.module(verify-esi-connections)",
lambda sys: sys.generate(),
"builtin.module(verify-esi-connections)",

# After all of the pycde code has been executed, we have all the types
# defined so we can go through and output the typedefs delcarations.
lambda sys: TypeAlias.declare_aliases(sys.mod),

# Then run all the passes to lower dialects which produce `hw.module`s.
"builtin.module(lower-handshake-to-dc)",
"builtin.module(dc-materialize-forks-sinks)",
"builtin.module(canonicalize)",
"builtin.module(lower-dc-to-hw)",

# Run ESI manifest passes.
"builtin.module(esi-appid-hier{{top={tops} }}, esi-build-manifest{{top={tops} }})",
"builtin.module(msft-lower-constructs, msft-lower-instances)",
"builtin.module(esi-clean-metadata)",

# Instaniate hlmems, which could produce new esi connections.
"builtin.module(hw.module(lower-seq-hlmem))",
"builtin.module(lower-esi-to-physical)",
# TODO: support more than just cosim.
Expand Down
42 changes: 42 additions & 0 deletions frontends/PyCDE/test/test_handshake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# RUN: %PYTHON% %s | FileCheck %s

from pycde import (Clock, Output, Input, generator, types, Module)
from pycde.handshake import Func
from pycde.testing import unittestmodule
from pycde.types import Bits, Channel

# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, out x : !esi.channel<i8>)
# CHECK: [[R0:%.+]] = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a) : (!esi.channel<i8>) -> !esi.channel<i8>
# CHECK: hw.output [[R0]] : !esi.channel<i8>
# CHECK: }
# CHECK: handshake.func @TestFunc(%arg0: i8, ...) -> i8
# CHECK: %c15_i8 = hw.constant 15 : i8
# CHECK: %0 = comb.and bin %arg0, %c15_i8 : i8
# CHECK: return %0 : i8
# CHECK: }


class TestFunc(Func):
a = Input(Bits(8))
x = Output(Bits(8))

@generator
def build(ports):
ports.x = ports.a & Bits(8)(0xF)


BarType = types.struct({"foo": types.i12}, "bar")


@unittestmodule(print=True, run_passes=True)
class Top(Module):
clk = Clock()
rst = Input(Bits(1))

a = Input(Channel(Bits(8)))
x = Output(Channel(Bits(8)))

@generator
def build(ports):
test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a)
ports.x = test.x

0 comments on commit 8283dcb

Please sign in to comment.