Skip to content

Commit c9aaa58

Browse files
committed
Merge remote-tracking branch 'upstream/main' into issue-9140-new-inferwidths
2 parents e8df104 + daaa75e commit c9aaa58

File tree

154 files changed

+7283
-1539
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

154 files changed

+7283
-1539
lines changed

frontends/PyCDE/integration_test/esitester.py

Lines changed: 749 additions & 195 deletions
Large diffs are not rendered by default.

frontends/PyCDE/integration_test/test_software/esitester.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

frontends/PyCDE/src/pycde/bsp/common.py

Lines changed: 291 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
from ..common import Clock, Input, InputChannel, Output, OutputChannel, Reset
99
from ..constructs import (AssignableSignal, ControlReg, Counter, Mux, NamedWire,
10-
Reg, Wire)
10+
Wire)
1111
from .. import esi
1212
from ..module import Module, generator, modparams
13-
from ..signals import BitsSignal, BundleSignal, ChannelSignal, ClockSignal
13+
from ..signals import BitsSignal, ChannelSignal, StructSignal
1414
from ..support import clog2
1515
from ..types import (Array, Bits, Bundle, BundledChannel, Channel,
1616
ChannelDirection, StructType, Type, UInt)
@@ -93,6 +93,210 @@ def build(ports):
9393
return HeaderMMIO
9494

9595

96+
@modparams
def ChannelDemuxN_HalfStage_ReadyBlocking(
    data_type: Type, num_outs: int,
    next_sel_width: int) -> type["ChannelDemuxNImpl"]:
  """N-way channel demultiplexer for valid/ready signaling.

  The input channel carries a struct `{sel, next_sel, data}`. Each message is
  routed to the output chosen by `sel` and re-emitted as `{next_sel, data}`.
  Every output has its own message register and valid register; the input is
  ready only while the register of the currently selected output is empty.

  Args:
    data_type: Type of the `data` field carried through the demux.
    num_outs: Number of output channels (`output_0` .. `output_{num_outs-1}`).
    next_sel_width: Width of the `next_sel` field which is passed through
      untouched to the selected output.

  Returns:
    The module class implementing the demultiplexer.

  Raises:
    ValueError: If `num_outs` is less than 1 or `next_sel_width` is negative.
  """

  # Raise rather than assert so the checks survive `python -O`.
  if num_outs < 1:
    raise ValueError("num_outs must be at least 1.")
  if next_sel_width < 0:
    raise ValueError("next_sel_width must be non-negative.")

  class ChannelDemuxNImpl(Module):
    clk = Clock()
    rst = Reset()

    # The input channel carries the selection along with the data.
    InPayloadType = StructType([
        ("sel", Bits(clog2(num_outs))),
        ("next_sel", Bits(next_sel_width)),
        ("data", data_type),
    ])
    inp = Input(Channel(InPayloadType))
    OutPayloadType = StructType([
        ("next_sel", Bits(next_sel_width)),
        ("data", data_type),
    ])
    # One output channel of OutPayloadType per destination. The ports are
    # injected into the class namespace since their count is a parameter.
    for i in range(num_outs):
      locals()[f"output_{i}"] = Output(Channel(OutPayloadType))

    @generator
    def generate(ports) -> None:
      # Half-stage demux: one register per output channel. Input is ready
      # when the currently selected output register is empty (not valid).
      clk = ports.clk
      rst = ports.rst
      sel_width = clog2(num_outs)
      out_type = ChannelDemuxNImpl.OutPayloadType

      # Unwrap input with backpressure from the selected output register.
      input_ready = Wire(Bits(1), name="input_ready")
      in_payload, in_valid = ports.inp.unwrap(input_ready)
      in_sel = in_payload.sel
      in_next_sel = in_payload.next_sel
      in_data = in_payload.data

      # Build a purely combinational expression OR_i((sel == i) & valid_i)
      # by accumulation so that no Wire is assigned more than once.
      selected_valid_expr = Bits(1)(0)

      for i in range(num_outs):
        # Shared by the write condition and the selected-valid term.
        sel_match = in_sel == Bits(sel_width)(i)

        # Write when an input transaction targets this output. `input_ready`
        # already guarantees that this output is not holding data.
        will_write = Wire(Bits(1), name=f"will_write_{i}")
        will_write.assign(in_valid & input_ready & sel_match)

        # Data and next_sel registers.
        out_msg_reg = out_type({
            "next_sel": in_next_sel,
            "data": in_data
        }).reg(clk=clk, rst=rst, ce=will_write, name=f"out{i}_msg_reg")

        # Valid register, cleared on a successful downstream consume.
        consume = Wire(Bits(1), name=f"consume_{i}")
        valid_reg = ControlReg(
            clk=clk,
            rst=rst,
            asserts=[will_write],
            resets=[consume],
            name=f"out{i}_valid_reg",
        )

        # Channel wrapper.
        ch_sig, ch_ready = Channel(out_type).wrap(out_msg_reg, valid_reg)
        setattr(ports, f"output_{i}", ch_sig)
        consume.assign(valid_reg & ch_ready)

        # Accumulate the selected-valid expression.
        selected_valid_expr = (selected_valid_expr |
                               (sel_match & valid_reg)).as_bits()

      # Input ready only when the selected output has no valid data latched.
      input_ready.assign((selected_valid_expr ^ Bits(1)(1)).as_bits())

    def get_out(self, index: int) -> ChannelSignal:
      """Return the output channel `output_{index}` of this instance."""
      return getattr(self, f"output_{index}")

  return ChannelDemuxNImpl
187+
188+
189+
@modparams
def ChannelDemuxTree_HalfStage_ReadyBlocking(
    data_type: Type, num_outs: int,
    branching_factor_log2: int) -> type["ChannelDemuxTree"]:
  """Pipelined N-way channel demultiplexer for valid/ready signaling.

  Built as a tree of ChannelDemuxN_HalfStage_ReadyBlocking modules to reduce
  fanout pressure. Supports maximum half-throughput to save complexity and
  area. The input channel carries a struct `{sel, data}`; the outputs carry
  `data` only.

  Args:
    data_type: Type of the messages delivered on the outputs.
    num_outs: Requested number of outputs. Internally replaced by
      `2**clog2(num_outs)` so the tree is always built for a power of two;
      the module exposes that many `output_<i>` ports.
    branching_factor_log2: log2 of the fanout of each tree node, i.e. the
      number of selection bits consumed per level.

  Returns:
    The module class implementing the demux tree.

  Raises:
    ValueError: If `branching_factor_log2` is less than 1.
  """

  # Validate up front and with a real exception: with a branching factor of
  # zero the level-building loop below would never make progress, and an
  # `assert` would be stripped under `python -O`.
  if branching_factor_log2 < 1:
    raise ValueError("branching_factor_log2 must be at least 1.")

  root_sel_width = clog2(num_outs)
  # Simplify the algorithm by making sure num_outs is a power of two.
  num_outs = 2**root_sel_width
  sel_width = branching_factor_log2
  fanout = 2**sel_width

  class ChannelDemuxTree(Module):
    clk = Clock()
    rst = Reset()
    # The input embeds the selection bits alongside the data.
    InPayloadType = StructType([
        ("sel", Bits(clog2(num_outs))),
        ("data", data_type),
    ])
    inp = Input(Channel(InPayloadType))

    # Outputs (data only). The ports are injected into the class namespace
    # since their count is a parameter.
    for i in range(num_outs):
      locals()[f"output_{i}"] = Output(Channel(data_type))

    @generator
    def build(ports) -> None:
      if num_outs == 1:
        # Strip the selection bits and return the single channel.
        setattr(ports, "output_0", ports.inp.transform(lambda p: p.data))
        return

      def payload_type(level_sel_width: int, next_sel_width: int) -> Type:
        """Struct type of a message entering a tree node."""
        return StructType([
            ("sel", Bits(level_sel_width)),
            ("next_sel", Bits(next_sel_width)),
            ("data", data_type),
        ])

      def next_sel_width_calc(curr_sel_width) -> int:
        """Selection bits left over after one level consumes its share."""
        return max(curr_sel_width - sel_width, 0)

      def payload_next(curr_msg: StructSignal) -> StructSignal:
        """Given current level payload, produce next level payload by
        stripping off the top selection bits."""

        remaining_width = curr_msg.next_sel.type.width
        next_sel_width = next_sel_width_calc(remaining_width)
        new_sel_width = min(remaining_width, sel_width)
        return payload_type(new_sel_width, next_sel_width)({
            # Use the MSB bits of next_sel as the next level selection.
            "sel": (curr_msg.next_sel[next_sel_width:]
                    if remaining_width > 0 else Bits(0)(0)),
            "next_sel": (curr_msg.next_sel[:next_sel_width]
                         if next_sel_width > 0 else Bits(0)(0)),
            "data": curr_msg.data,
        })

      # Seed the tree: nothing selected yet, all selection bits pending.
      current_channels: List[ChannelSignal] = [
          ports.inp.transform(lambda m: payload_type(0, root_sel_width)({
              "sel": Bits(0)(0),
              "next_sel": m.sel,
              "data": m.data,
          }))
      ]

      curr_sel_width = root_sel_width
      level = 0
      while len(current_channels) < num_outs:
        next_level: List[ChannelSignal] = []
        # The last level may need fewer outputs than a full fanout.
        level_num_outs = min(2**curr_sel_width, fanout)
        for i, c in enumerate(current_channels):
          dmux = ChannelDemuxN_HalfStage_ReadyBlocking(
              data_type,
              num_outs=level_num_outs,
              next_sel_width=next_sel_width_calc(curr_sel_width),
          )(
              clk=ports.clk,
              rst=ports.rst,
              inp=c.transform(payload_next),
              instance_name=f"demux_l{level}_i{i}",
          )
          next_level.extend(dmux.get_out(j) for j in range(level_num_outs))
        current_channels = next_level
        curr_sel_width -= sel_width
        level += 1

      for i in range(num_outs):
        # Strip off the next_sel bits for the final output.
        setattr(
            ports,
            f"output_{i}",
            current_channels[i].transform(lambda p: p.data),
        )

    def get_out(self, index: int) -> ChannelSignal:
      """Return the output channel `output_{index}` of this instance."""
      return getattr(self, f"output_{index}")

  return ChannelDemuxTree
298+
299+
96300
class ChannelMMIO(esi.ServiceImplementation):
97301
"""MMIO service implementation with MMIO bundle interfaces. Should be
98302
relatively easy to adapt to physical interfaces by wrapping the wires to
@@ -205,11 +409,25 @@ def build_read(ports, manifest_loc: int, table: Dict[int, AssignableSignal]):
205409

206410
# Build the demux/mux and assign the results of each appropriately.
207411
read_clients_clog2 = clog2(len(table))
208-
client_cmd_channels = esi.ChannelDemux(
209-
sel=sel_bits.pad_or_truncate(read_clients_clog2),
210-
input=client_cmd_chan,
211-
num_outs=len(table),
212-
instance_name="client_cmd_demux")
412+
# Combine selection bits and command channel payload into a struct channel for the demux tree.
413+
TreeInType = StructType([
414+
("sel", Bits(read_clients_clog2)),
415+
("data", client_cmd_chan.type.inner_type),
416+
])
417+
sel_bits_truncated = sel_bits.pad_or_truncate(read_clients_clog2)
418+
combined_cmd_chan = client_cmd_chan.transform(
419+
lambda cmd, _sel=sel_bits_truncated: TreeInType({
420+
"sel": _sel,
421+
"data": cmd
422+
}))
423+
demux_inst = ChannelDemuxTree_HalfStage_ReadyBlocking(
424+
client_cmd_chan.type.inner_type, len(table), branching_factor_log2=2)(
425+
clk=ports.clk,
426+
rst=ports.rst,
427+
inp=combined_cmd_chan,
428+
instance_name="client_cmd_demux",
429+
)
430+
client_cmd_channels = [demux_inst.get_out(i) for i in range(len(table))]
213431
client_data_channels = []
214432
for (idx, offset) in enumerate(sorted(table.keys())):
215433
bundle_wire = table[offset]
@@ -553,7 +771,9 @@ def TaggedWriteGearbox(input_bitwidth: int,
553771

554772
if output_bitwidth % 8 != 0:
555773
raise ValueError("Output bitwidth must be a multiple of 8.")
556-
input_pad_bits = 8 - (input_bitwidth % 8)
774+
input_pad_bits = 0
775+
if input_bitwidth % 8 != 0:
776+
input_pad_bits = 8 - (input_bitwidth % 8)
557777
input_padded_bitwidth = input_bitwidth + input_pad_bits
558778

559779
class TaggedWriteGearboxImpl(Module):
@@ -667,6 +887,57 @@ def build(ports):
667887
return TaggedWriteGearboxImpl
668888

669889

890+
@modparams
def EmitEveryN(message_type: Type, N: int) -> type['EmitEveryNImpl']:
  """Forward one message out of every N received.

  The forwarded message is the last of each group of N inputs.

  Args:
    message_type: Type of the messages on both the input and output channel.
    N: Group size; must be >= 1. With N == 1 the module is a pass-through.

  Raises:
    ValueError: If N is less than 1.
  """

  if N < 1:
    raise ValueError("N must be >= 1")

  class EmitEveryNImpl(Module):
    clk = Clock()
    rst = Reset()
    in_ = InputChannel(message_type)
    out = OutputChannel(message_type)

    @generator
    def build(ports):
      accept_input = Wire(Bits(1))
      msg, msg_valid = ports.in_.unwrap(accept_input)
      accepted = msg_valid & accept_input

      if N == 1:
        # Nothing to count: hand every message straight to the output.
        passthrough, sink_ready = EmitEveryNImpl.out.type.wrap(msg, msg_valid)
        accept_input.assign(sink_ready)
        ports.out = passthrough
        return

      # Count accepted messages; the count restarts after each emit.
      # NOTE(review): an input accepted in the same cycle the output is
      # consumed drives both `increment` and `clear` -- confirm Counter's
      # priority so that message is not dropped from the next group's count.
      count_width = clog2(N)
      emitted = Wire(Bits(1))
      msg_counter = Counter(count_width)(
          clk=ports.clk,
          rst=ports.rst,
          increment=accepted,
          clear=emitted,
      )

      # Hold on to the most recently accepted message.
      held_msg = msg.reg(ports.clk, ports.rst, ce=accepted, name="last_msg")

      # The output becomes valid once the N-th message of a group is taken.
      group_done = (msg_counter.out == UInt(count_width)(N - 1)) & accepted
      pending = ControlReg(ports.clk, ports.rst, [group_done], [emitted])

      result_chan, sink_ready = EmitEveryNImpl.out.type.wrap(held_msg, pending)
      # Refuse new input only while an output is pending and not yet taken.
      blocked = pending & ~sink_ready
      accept_input.assign(~blocked)
      # The pending flag (and the counter) clear on a successful emit.
      emitted.assign(pending & sink_ready)

      ports.out = result_chan

  return EmitEveryNImpl
939+
940+
670941
def HostMemWriteProcessor(
671942
write_width: int, hostmem_module,
672943
reqs: List[esi._OutputBundleSetter]) -> type["HostMemWriteProcessorImpl"]:
@@ -695,6 +966,9 @@ class HostMemWriteProcessorImpl(Module):
695966

696967
@generator
697968
def build(ports):
969+
clk = ports.clk
970+
rst = ports.rst
971+
698972
# If there are no write clients, just create a no-op write bundle
699973
if len(reqs) == 0:
700974
req, _ = Channel(hostmem_module.UpstreamWriteReq).wrap(
@@ -731,8 +1005,8 @@ def build(ports):
7311005
# Pack up the bundle and assign the request channel.
7321006
write_req_bundle_type = esi.HostMem.write_req_bundle_type(
7331007
client_type.data)
734-
bundle_sig, froms = write_req_bundle_type.pack(
735-
ackTag=demuxed_acks.get_out(idx))
1008+
input_flit_ack = Wire(upstream_ack_tag.type)
1009+
bundle_sig, froms = write_req_bundle_type.pack(ackTag=input_flit_ack)
7361010

7371011
gearbox_mod = TaggedWriteGearbox(client_type.data.bitwidth, write_width)
7381012
gearbox_in_type = gearbox_mod.in_.type.inner_type
@@ -755,6 +1029,13 @@ def build(ports):
7551029
"data": m.data,
7561030
"valid_bytes": m.valid_bytes
7571031
})))
1032+
1033+
# Count the number of acks received from hostmem for this client
1034+
# and only send one back to the client per input.
1035+
ack_every_n = EmitEveryN(upstream_ack_tag.type, gearbox_mod.num_chunks)(
1036+
clk=clk, rst=rst, in_=demuxed_acks.get_out(idx))
1037+
input_flit_ack.assign(ack_every_n.out)
1038+
7581039
# Set the port for the client request.
7591040
setattr(ports, HostMemWriteProcessorImpl.reqPortMap[req], bundle_sig)
7601041

frontends/PyCDE/src/pycde/bsp/cosim.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ class ESI_Cosim_UserTopWrapper(Module):
5151
@generator
5252
def build(ports):
5353
user_module(clk=ports.clk, rst=ports.rst)
54+
esi.TelemetryMMIO(esi.Telemetry,
55+
appid=esi.AppID("__telemetry"),
56+
clk=ports.clk,
57+
rst=ports.rst)
58+
5459
if emulate_dma:
5560
ChannelEngineService(OneItemBuffersToHost, OneItemBuffersFromHost)(
5661
None,

0 commit comments

Comments
 (0)