Skip to content

Commit

Permalink
[PyCDE][Handshake] Add bindings for Handshake functions
Browse files Browse the repository at this point in the history
  • Loading branch information
teqdruid committed Nov 19, 2024
1 parent 79dce73 commit 358b677
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 5 deletions.
1 change: 1 addition & 0 deletions frontends/PyCDE/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ declare_mlir_python_sources(PyCDESources
pycde/common.py
pycde/system.py
pycde/devicedb.py
pycde/handshake.py
pycde/instance.py
pycde/seq.py
pycde/signals.py
Expand Down
146 changes: 146 additions & 0 deletions frontends/PyCDE/src/pycde/handshake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
from typing import Any, List, Optional, Set, Tuple, Dict
import typing

from .module import Module, ModuleLikeBuilderBase, PortError
from .signals import BitsSignal, ChannelSignal, ClockSignal, Signal
from .system import System
from .support import (get_user_loc, _obj_to_attribute, obj_to_typed_attribute,
create_const_zero)
from .types import Channel

from .circt.dialects import handshake as raw_handshake
from .circt import ir


class FuncBuilder(ModuleLikeBuilderBase):
"""Defines how an ESI `PureModule` gets built."""

@property
def circt_mod(self):
sys: System = System.current()
ret = sys._op_cache.get_circt_mod(self)
if ret is None:
return sys._create_circt_mod(self)
return ret

def create_op(self, sys: System, symbol):
if hasattr(self.modcls, "metadata"):
meta = self.modcls.metadata
self.add_metadata(sys, symbol, meta)
else:
self.add_metadata(sys, symbol, None)

# If there are associated constants, add them to the manifest.
if len(self.constants) > 0:
constants_dict: Dict[str, ir.Attribute] = {}
for name, constant in self.constants.items():
constant_attr = obj_to_typed_attribute(constant.value, constant.type)
constants_dict[name] = constant_attr
with ir.InsertionPoint(sys.mod.body):
from .dialects.esi import esi
esi.SymbolConstantsOp(symbolRef=ir.FlatSymbolRefAttr.get(symbol),
constants=ir.DictAttr.get(constants_dict))

assert len(self.generators) > 0

if hasattr(self, "parameters") and self.parameters is not None:
self.attributes["pycde.parameters"] = self.parameters
# If this Module has a generator, it's a real module.
return raw_handshake.FuncOp.create(
symbol,
[(p.name, p.type._type) for p in self.inputs],
[(p.name, p.type._type) for p in self.outputs],
attributes=self.attributes,
loc=self.loc,
ip=sys._get_ip(),
)

def generate(self):
"""Fill in (generate) this module. Only supports a single generator
currently."""
if len(self.generators) != 1:
raise ValueError("Must have exactly one generator.")
g: Generator = list(self.generators.values())[0]

entry_block = self.circt_mod.add_entry_block()
ports = self.generator_port_proxy(entry_block.arguments, self)
with self.GeneratorCtxt(self, ports, entry_block, g.loc):
outputs = g.gen_func(ports)
if outputs is not None:
raise ValueError("Generators must not return a value")

ports._check_unconnected_outputs()
raw_handshake.ReturnOp([o.value for o in ports._output_values])

def instantiate(self, module_inst, inputs, instance_name: str):
""""Instantiate this Func from ESI channels. Check that the input types
match expectations."""

port_input_lookup = {port.name: port for port in self.inputs}
circt_inputs: List[Optional[ir.Value]] = [None] * len(self.inputs)
remaining_inputs = set(port_input_lookup.keys())
clk = None
rst = None
for name, signal in inputs.items():
if name == "clk":
if not isinstance(signal, ClockSignal):
raise PortError("'clk' must be a clock signal")
clk = signal.value
continue
elif name == "rst":
if not isinstance(signal, BitsSignal):
raise PortError("'rst' must be a Bits(1)")
rst = signal.value
continue

if name not in port_input_lookup:
raise PortError(f"Input port {name} not found in module")
port = port_input_lookup[name]
if isinstance(signal, ChannelSignal):
# If the input is a channel signal, the types must match.
if signal.type.inner_type != port.type:
raise ValueError(
f"Wrong type on input signal '{name}'. Got '{signal.type}',"
f" expected '{port.type}'")
assert port.idx is not None
circt_inputs[port.idx] = signal.value
remaining_inputs.remove(name)
elif isinstance(signal, Signal):
raise PortError(f"Input {name} must be a channel signal")
else:
raise PortError(f"Port {name} must be a signal")
if clk is None:
raise PortError("Missing 'clk' signal")
if rst is None:
raise PortError("Missing 'rst' signal")
if len(remaining_inputs) > 0:
raise PortError(
f"Missing input signals for ports: {', '.join(remaining_inputs)}")

circt_mod = self.circt_mod
assert circt_mod is not None
result_types = [Channel(port.type)._type for port in self.outputs]
inst = raw_handshake.ESIInstanceOp(
result_types,
ir.StringAttr(circt_mod.attributes["sym_name"]).value,
instance_name,
clk=clk,
rst=rst,
opOperands=circt_inputs,
loc=get_user_loc())
inst.operation.verify()
return inst


class Func(Module):
"""A pure ESI module has no ports and contains only instances of modules with
only ESI ports and connections between said instances. Use ESI services for
external communication."""

BuilderType: type[ModuleLikeBuilderBase] = FuncBuilder
_builder: FuncBuilder
42 changes: 42 additions & 0 deletions frontends/PyCDE/test/test_handshake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# RUN: %PYTHON% %s | FileCheck %s

from pycde import (Clock, Output, Input, generator, types, Module)
from pycde.handshake import Func
from pycde.testing import unittestmodule
from pycde.types import Bits, Channel

# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, out x : !esi.channel<i8>) attributes {output_file = #hw.output_file<"Top.sv", includeReplicatedOps>} {
# CHECK: %0 = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a) : (!esi.channel<i8>) -> !esi.channel<i8>
# CHECK: hw.output %0 : !esi.channel<i8>
# CHECK: }
# CHECK: handshake.func @TestFunc(%arg0: i8, ...) -> i8 attributes {argNames = ["a"], output_file = #hw.output_file<"TestFunc.sv", includeReplicatedOps>, resNames = ["x"]} {
# CHECK: %c15_i8 = hw.constant 15 : i8
# CHECK: %0 = comb.and bin %arg0, %c15_i8 : i8
# CHECK: return %0 : i8
# CHECK: }


class TestFunc(Func):
a = Input(Bits(8))
x = Output(Bits(8))

@generator
def build(ports):
ports.x = ports.a & Bits(8)(0xF)


BarType = types.struct({"foo": types.i12}, "bar")


@unittestmodule()
class Top(Module):
clk = Clock()
rst = Input(Bits(1))

a = Input(Channel(Bits(8)))
x = Output(Channel(Bits(8)))

@generator
def build(ports):
test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a)
ports.x = test.x
17 changes: 12 additions & 5 deletions lib/Bindings/Python/dialects/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
from typing import Dict, List, Tuple, Union

from . import handshake
from ._handshake_ops_gen import *
from ._handshake_ops_gen import _Dialect

from ..dialects._ods_common import _cext as _ods_cext
from ..ir import ArrayAttr, FunctionType, StringAttr, Type, TypeAttr

from typing import List, Tuple, Union
from ..ir import ArrayAttr, Attribute, FunctionType, StringAttr, Type, TypeAttr

from ._ods_common import (
equally_sized_accessor as _ods_equally_sized_accessor,
Expand All @@ -32,17 +31,25 @@ class FuncOp(FuncOp):
def create(sym_name: Union[StringAttr, str],
args: List[Tuple[str, Type]],
results: List[Tuple[str, Type]],
private: bool = False) -> FuncOp:
attributes: Dict[str, Attribute] = {},
loc=None,
ip=None) -> FuncOp:
if isinstance(sym_name, str):
sym_name = StringAttr.get(sym_name)
input_types = [t for _, t in args]
res_types = [t for _, t in results]
func_type = FunctionType.get(input_types, res_types)
func_type_attr = TypeAttr.get(func_type)
funcop = FuncOp(func_type_attr)
funcop = FuncOp(func_type_attr, loc=loc, ip=ip)
for k, v in attributes.items():
funcop.attributes[k] = v
funcop.attributes["sym_name"] = sym_name
funcop.attributes["argNames"] = ArrayAttr.get(
[StringAttr.get(name) for name, _ in args])
funcop.attributes["resNames"] = ArrayAttr.get(
[StringAttr.get(name) for name, _ in results])
return funcop

def add_entry_block(self):
self.body.blocks.append(*self.function_type.value.inputs)
return self.body.blocks[0]

0 comments on commit 358b677

Please sign in to comment.