-
Notifications
You must be signed in to change notification settings - Fork 302
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PyCDE][Handshake] Add bindings for Handshake functions (#7849)
Exposes the Handshake dialect's FuncOp in the same style as an HWModuleOp and instantiates them via ESI channels.
- Loading branch information
Showing
7 changed files
with
261 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from __future__ import annotations | ||
from typing import List, Optional, Dict | ||
|
||
from .module import Module, ModuleLikeBuilderBase, PortError | ||
from .signals import BitsSignal, ChannelSignal, ClockSignal, Signal | ||
from .system import System | ||
from .support import get_user_loc, obj_to_typed_attribute | ||
from .types import Channel | ||
|
||
from .circt.dialects import handshake as raw_handshake | ||
from .circt import ir | ||
|
||
|
||
class FuncBuilder(ModuleLikeBuilderBase): | ||
"""Defines how a handshake function gets built.""" | ||
|
||
def create_op(self, sys: System, symbol): | ||
"""Callback for creating a handshake.func op.""" | ||
|
||
self.create_op_common(sys, symbol) | ||
|
||
assert len(self.generators) > 0 | ||
|
||
if hasattr(self, "parameters") and self.parameters is not None: | ||
self.attributes["pycde.parameters"] = self.parameters | ||
# If this Module has a generator, it's a real module. | ||
return raw_handshake.FuncOp.create( | ||
symbol, | ||
[(p.name, p.type._type) for p in self.inputs], | ||
[(p.name, p.type._type) for p in self.outputs], | ||
attributes=self.attributes, | ||
loc=self.loc, | ||
ip=sys._get_ip(), | ||
) | ||
|
||
def generate(self): | ||
"""Fill in (generate) this module. Only supports a single generator | ||
currently.""" | ||
if len(self.generators) != 1: | ||
raise ValueError("Must have exactly one generator.") | ||
g: Generator = list(self.generators.values())[0] | ||
|
||
entry_block = self.circt_mod.add_entry_block() | ||
ports = self.generator_port_proxy(entry_block.arguments, self) | ||
with self.GeneratorCtxt(self, ports, entry_block, g.loc): | ||
outputs = g.gen_func(ports) | ||
if outputs is not None: | ||
raise ValueError("Generators must not return a value") | ||
|
||
ports._check_unconnected_outputs() | ||
raw_handshake.ReturnOp([o.value for o in ports._output_values]) | ||
|
||
def instantiate(self, module_inst, inputs, instance_name: str): | ||
""""Instantiate this Func from ESI channels. Check that the input types | ||
match expectations.""" | ||
|
||
port_input_lookup = {port.name: port for port in self.inputs} | ||
circt_inputs: List[Optional[ir.Value]] = [None] * len(self.inputs) | ||
remaining_inputs = set(port_input_lookup.keys()) | ||
clk = None | ||
rst = None | ||
for name, signal in inputs.items(): | ||
if name == "clk": | ||
if not isinstance(signal, ClockSignal): | ||
raise PortError("'clk' must be a clock signal") | ||
clk = signal.value | ||
continue | ||
elif name == "rst": | ||
if not isinstance(signal, BitsSignal): | ||
raise PortError("'rst' must be a Bits(1)") | ||
rst = signal.value | ||
continue | ||
|
||
if name not in port_input_lookup: | ||
raise PortError(f"Input port {name} not found in module") | ||
port = port_input_lookup[name] | ||
if isinstance(signal, ChannelSignal): | ||
# If the input is a channel signal, the types must match. | ||
if signal.type.inner_type != port.type: | ||
raise ValueError( | ||
f"Wrong type on input signal '{name}'. Got '{signal.type}'," | ||
f" expected '{port.type}'") | ||
assert port.idx is not None | ||
circt_inputs[port.idx] = signal.value | ||
remaining_inputs.remove(name) | ||
elif isinstance(signal, Signal): | ||
raise PortError(f"Input {name} must be a channel signal") | ||
else: | ||
raise PortError(f"Port {name} must be a signal") | ||
if clk is None: | ||
raise PortError("Missing 'clk' signal") | ||
if rst is None: | ||
raise PortError("Missing 'rst' signal") | ||
if len(remaining_inputs) > 0: | ||
raise PortError( | ||
f"Missing input signals for ports: {', '.join(remaining_inputs)}") | ||
|
||
circt_mod = self.circt_mod | ||
assert circt_mod is not None | ||
result_types = [Channel(port.type)._type for port in self.outputs] | ||
inst = raw_handshake.ESIInstanceOp( | ||
result_types, | ||
ir.StringAttr(circt_mod.attributes["sym_name"]).value, | ||
instance_name, | ||
clk=clk, | ||
rst=rst, | ||
opOperands=circt_inputs, | ||
loc=get_user_loc()) | ||
inst.operation.verify() | ||
return inst | ||
|
||
|
||
class Func(Module): | ||
"""A Handshake function is intended to implicitly model dataflow. If can | ||
contain any combinational operation and offers a software-like (HLS) approach | ||
to hardware design. | ||
The PyCDE interface to it (this class) is considered experimental. Use at your | ||
own risk and test the resulting RTL thoroughly.""" | ||
|
||
BuilderType: type[ModuleLikeBuilderBase] = FuncBuilder | ||
_builder: FuncBuilder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# RUN: %PYTHON% %s | FileCheck %s | ||
|
||
from pycde import (Clock, Output, Input, generator, types, Module) | ||
from pycde.handshake import Func | ||
from pycde.testing import unittestmodule | ||
from pycde.types import Bits, Channel | ||
|
||
# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, out x : !esi.channel<i8>) | ||
# CHECK: [[R0:%.+]] = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a) : (!esi.channel<i8>) -> !esi.channel<i8> | ||
# CHECK: hw.output [[R0]] : !esi.channel<i8> | ||
# CHECK: } | ||
# CHECK: handshake.func @TestFunc(%arg0: i8, ...) -> i8 | ||
# CHECK: %c15_i8 = hw.constant 15 : i8 | ||
# CHECK: %0 = comb.and bin %arg0, %c15_i8 : i8 | ||
# CHECK: return %0 : i8 | ||
# CHECK: } | ||
|
||
|
||
class TestFunc(Func): | ||
a = Input(Bits(8)) | ||
x = Output(Bits(8)) | ||
|
||
@generator | ||
def build(ports): | ||
ports.x = ports.a & Bits(8)(0xF) | ||
|
||
|
||
BarType = types.struct({"foo": types.i12}, "bar") | ||
|
||
|
||
@unittestmodule(print=True, run_passes=True) | ||
class Top(Module): | ||
clk = Clock() | ||
rst = Input(Bits(1)) | ||
|
||
a = Input(Channel(Bits(8))) | ||
x = Output(Channel(Bits(8))) | ||
|
||
@generator | ||
def build(ports): | ||
test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a) | ||
ports.x = test.x |