77
88from ..common import Clock , Input , InputChannel , Output , OutputChannel , Reset
99from ..constructs import (AssignableSignal , ControlReg , Counter , Mux , NamedWire ,
10- Reg , Wire )
10+ Wire )
1111from .. import esi
1212from ..module import Module , generator , modparams
13- from ..signals import BitsSignal , BundleSignal , ChannelSignal , ClockSignal
13+ from ..signals import BitsSignal , ChannelSignal , StructSignal
1414from ..support import clog2
1515from ..types import (Array , Bits , Bundle , BundledChannel , Channel ,
1616 ChannelDirection , StructType , Type , UInt )
@@ -93,6 +93,210 @@ def build(ports):
9393 return HeaderMMIO
9494
9595
96+ @modparams
97+ def ChannelDemuxN_HalfStage_ReadyBlocking (
98+ data_type : Type , num_outs : int ,
99+ next_sel_width : int ) -> type ["ChannelDemuxNImpl" ]:
100+ """N-way channel demultiplexer for valid/ready signaling. Contains
101+ valid/ready registers on the output channels. The selection signal is now
102+ embedded in the input channel payload as a struct {sel, data}. Input
103+ signals ready when the selected output register is empty."""
104+
105+ assert num_outs >= 1 , "num_outs must be at least 1."
106+
107+ class ChannelDemuxNImpl (Module ):
108+ clk = Clock ()
109+ rst = Reset ()
110+
111+ # Input channel now carries selection along with data.
112+ InPayloadType = StructType ([
113+ ("sel" , Bits (clog2 (num_outs ))),
114+ ("next_sel" , Bits (next_sel_width )),
115+ ("data" , data_type ),
116+ ])
117+ inp = Input (Channel (InPayloadType ))
118+ OutPayloadType = StructType ([
119+ ("next_sel" , Bits (next_sel_width )),
120+ ("data" , data_type ),
121+ ])
122+ # Outputs are channels of OutPayloadType, which includes both 'next_sel' and 'data' fields.
123+ for i in range (num_outs ):
124+ locals ()[f"output_{ i } " ] = Output (Channel (OutPayloadType ))
125+
126+ @generator
127+ def generate (ports ) -> None :
128+ # Half-stage demux: one register per output channel. Input is ready
129+ # when the currently selected output register is empty (not valid).
130+ clk = ports .clk
131+ rst = ports .rst
132+ sel_width = clog2 (num_outs )
133+
134+ # Unwrap input with backpressure from selected output register.
135+ input_ready = Wire (Bits (1 ), name = "input_ready" )
136+ in_payload , in_valid = ports .inp .unwrap (input_ready )
137+ in_sel = in_payload .sel
138+ in_next_sel = in_payload .next_sel
139+ in_data = in_payload .data
140+
141+ # Track per-output valid regs and build a purely combinational
142+ # expression 'selected_valid_expr' = OR_i((sel==i)&valid_i). Avoid
143+ # assigning to a Wire multiple times.
144+ valid_regs : List [BitsSignal ] = []
145+ selected_valid_expr = Bits (1 )(0 )
146+
147+ for i in range (num_outs ):
148+ # Write when input transaction targets this output and output not holding data yet.
149+ will_write = Wire (Bits (1 ), name = f"will_write_{ i } " )
150+ write_cond = (in_valid & input_ready & (in_sel == Bits (sel_width )(i )))
151+ will_write .assign (write_cond )
152+
153+ # Data and next_sel registers.
154+ out_msg_reg = ChannelDemuxNImpl .OutPayloadType ({
155+ "next_sel" : in_next_sel ,
156+ "data" : in_data
157+ }).reg (clk = clk , rst = rst , ce = will_write , name = f"out{ i } _msg_reg" )
158+
159+ # Valid register cleared on successful downstream consume.
160+ consume = Wire (Bits (1 ), name = f"consume_{ i } " )
161+ valid_reg = ControlReg (
162+ clk = clk ,
163+ rst = rst ,
164+ asserts = [will_write ],
165+ resets = [consume ],
166+ name = f"out{ i } _valid_reg" ,
167+ )
168+ valid_regs .append (valid_reg )
169+
170+ # Channel wrapper.
171+ ch_sig , ch_ready = Channel (ChannelDemuxNImpl .OutPayloadType ).wrap (
172+ out_msg_reg , valid_reg )
173+ setattr (ports , f"output_{ i } " , ch_sig )
174+ consume .assign (valid_reg & ch_ready )
175+
176+ # Accumulate selected_valid expression.
177+ selected_valid_expr = (selected_valid_expr | (
178+ (in_sel == Bits (sel_width )(i )) & valid_reg )).as_bits ()
179+
180+ # Input ready only when selected output has no valid data latched.
181+ input_ready .assign ((selected_valid_expr ^ Bits (1 )(1 )).as_bits ())
182+
183+ def get_out (self , index : int ) -> ChannelSignal :
184+ return getattr (self , f"output_{ index } " )
185+
186+ return ChannelDemuxNImpl
187+
188+
189+ @modparams
190+ def ChannelDemuxTree_HalfStage_ReadyBlocking (
191+ data_type : Type , num_outs : int ,
192+ branching_factor_log2 : int ) -> type ["ChannelDemuxTree" ]:
193+ """Pipelined N-way channel demultiplexer for valid/ready signaling. This
194+ implementation uses a tree structure of
195+ ChannelDemuxN_HalfStage_ReadyBlocking modules to reduce fanout pressure.
196+ Supports maximum half-throughput to save complexity and area.
197+ """
198+
199+ root_sel_width = clog2 (num_outs )
200+ # Simplify algorithm by making sure num_outs is a power of two.
201+ num_outs = 2 ** root_sel_width
202+ sel_width = branching_factor_log2
203+ fanout = 2 ** sel_width
204+
205+ class ChannelDemuxTree (Module ):
206+ clk = Clock ()
207+ rst = Reset ()
208+ # Input now embeds selection bits alongside data.
209+ InPayloadType = StructType ([
210+ ("sel" , Bits (clog2 (num_outs ))),
211+ ("data" , data_type ),
212+ ])
213+ inp = Input (Channel (InPayloadType ))
214+
215+ # Outputs (data only).
216+ for i in range (num_outs ):
217+ locals ()[f"output_{ i } " ] = Output (Channel (data_type ))
218+
219+ @generator
220+ def build (ports ) -> None :
221+ assert branching_factor_log2 > 0
222+ if num_outs == 1 :
223+ # Strip selection bits and return single channel.
224+ setattr (ports , "output_0" , ports .inp .transform (lambda p : p .data ))
225+ return
226+
227+ def payload_type (sel_width : int , next_sel_width : int ) -> Type :
228+ return StructType ([
229+ ("sel" , Bits (sel_width )),
230+ ("next_sel" , Bits (next_sel_width )),
231+ ("data" , data_type ),
232+ ])
233+
234+ def next_sel_width_calc (curr_sel_width ) -> int :
235+ return max (curr_sel_width - sel_width , 0 )
236+
237+ def payload_next (curr_msg : StructSignal ) -> StructSignal :
238+ """Given current level payload, produce next level payload by
239+ stripping off the top selection bits."""
240+
241+ next_sel_width = next_sel_width_calc (curr_msg .next_sel .type .width )
242+ curr_sel_width = curr_msg .next_sel .type .width
243+ new_sel_width = min (curr_sel_width , sel_width )
244+ return payload_type (
245+ new_sel_width ,
246+ next_sel_width ,
247+ )({
248+ # Use the MSB bits of next_sel as the next level selection.
249+ "sel" : (curr_msg .next_sel [next_sel_width :]
250+ if curr_sel_width > 0 else Bits (0 )(0 )),
251+ "next_sel" : (curr_msg .next_sel [:next_sel_width ]
252+ if next_sel_width > 0 else Bits (0 )(0 )),
253+ "data" : curr_msg .data ,
254+ })
255+
256+ current_channels : List [ChannelSignal ] = [
257+ ports .inp .transform (lambda m : payload_type (0 , root_sel_width )({
258+ "sel" : Bits (0 )(0 ),
259+ "next_sel" : m .sel ,
260+ "data" : m .data ,
261+ }))
262+ ]
263+
264+ curr_sel_width = root_sel_width
265+ level = 0
266+ while len (current_channels ) < num_outs :
267+ next_level : List [ChannelSignal ] = []
268+ level_num_outs = min (2 ** curr_sel_width , fanout )
269+ for i , c in enumerate (current_channels ):
270+ dmux = ChannelDemuxN_HalfStage_ReadyBlocking (
271+ data_type ,
272+ num_outs = level_num_outs ,
273+ next_sel_width = next_sel_width_calc (curr_sel_width ),
274+ )(
275+ clk = ports .clk ,
276+ rst = ports .rst ,
277+ inp = c .transform (payload_next ),
278+ instance_name = f"demux_l{ level } _i{ i } " ,
279+ )
280+ for j in range (level_num_outs ):
281+ next_level .append (dmux .get_out (j ))
282+ current_channels = next_level
283+ curr_sel_width -= sel_width
284+ level += 1
285+
286+ for i in range (num_outs ):
287+ # Strip off next_sel bits for final output.
288+ setattr (
289+ ports ,
290+ f"output_{ i } " ,
291+ current_channels [i ].transform (lambda p : p .data ),
292+ )
293+
294+ def get_out (self , index : int ) -> ChannelSignal :
295+ return getattr (self , f"output_{ index } " )
296+
297+ return ChannelDemuxTree
298+
299+
96300class ChannelMMIO (esi .ServiceImplementation ):
97301 """MMIO service implementation with MMIO bundle interfaces. Should be
98302 relatively easy to adapt to physical interfaces by wrapping the wires to
@@ -205,11 +409,25 @@ def build_read(ports, manifest_loc: int, table: Dict[int, AssignableSignal]):
205409
206410 # Build the demux/mux and assign the results of each appropriately.
207411 read_clients_clog2 = clog2 (len (table ))
208- client_cmd_channels = esi .ChannelDemux (
209- sel = sel_bits .pad_or_truncate (read_clients_clog2 ),
210- input = client_cmd_chan ,
211- num_outs = len (table ),
212- instance_name = "client_cmd_demux" )
412+ # Combine selection bits and command channel payload into a struct channel for the demux tree.
413+ TreeInType = StructType ([
414+ ("sel" , Bits (read_clients_clog2 )),
415+ ("data" , client_cmd_chan .type .inner_type ),
416+ ])
417+ sel_bits_truncated = sel_bits .pad_or_truncate (read_clients_clog2 )
418+ combined_cmd_chan = client_cmd_chan .transform (
419+ lambda cmd , _sel = sel_bits_truncated : TreeInType ({
420+ "sel" : _sel ,
421+ "data" : cmd
422+ }))
423+ demux_inst = ChannelDemuxTree_HalfStage_ReadyBlocking (
424+ client_cmd_chan .type .inner_type , len (table ), branching_factor_log2 = 2 )(
425+ clk = ports .clk ,
426+ rst = ports .rst ,
427+ inp = combined_cmd_chan ,
428+ instance_name = "client_cmd_demux" ,
429+ )
430+ client_cmd_channels = [demux_inst .get_out (i ) for i in range (len (table ))]
213431 client_data_channels = []
214432 for (idx , offset ) in enumerate (sorted (table .keys ())):
215433 bundle_wire = table [offset ]
@@ -553,7 +771,9 @@ def TaggedWriteGearbox(input_bitwidth: int,
553771
554772 if output_bitwidth % 8 != 0 :
555773 raise ValueError ("Output bitwidth must be a multiple of 8." )
556- input_pad_bits = 8 - (input_bitwidth % 8 )
774+ input_pad_bits = 0
775+ if input_bitwidth % 8 != 0 :
776+ input_pad_bits = 8 - (input_bitwidth % 8 )
557777 input_padded_bitwidth = input_bitwidth + input_pad_bits
558778
559779 class TaggedWriteGearboxImpl (Module ):
@@ -667,6 +887,57 @@ def build(ports):
667887 return TaggedWriteGearboxImpl
668888
669889
890+ @modparams
891+ def EmitEveryN (message_type : Type , N : int ) -> type ['EmitEveryNImpl' ]:
892+ """Emit (forward) one message for every N input messages. The emitted message
893+ is the last one of the N received. N must be >= 1."""
894+
895+ if N < 1 :
896+ raise ValueError ("N must be >= 1" )
897+
898+ class EmitEveryNImpl (Module ):
899+ clk = Clock ()
900+ rst = Reset ()
901+ in_ = InputChannel (message_type )
902+ out = OutputChannel (message_type )
903+
904+ @generator
905+ def build (ports ):
906+ ready_for_in = Wire (Bits (1 ))
907+ in_data , in_valid = ports .in_ .unwrap (ready_for_in )
908+ xact = in_valid & ready_for_in
909+
910+ # Fast path: N == 1 -> pass-through.
911+ if N == 1 :
912+ out_chan , out_ready = EmitEveryNImpl .out .type .wrap (in_data , in_valid )
913+ ready_for_in .assign (out_ready )
914+ ports .out = out_chan
915+ return
916+
917+ counter_width = clog2 (N )
918+ increment = xact
919+ clear = Wire (Bits (1 ))
920+ counter = Counter (counter_width )(clk = ports .clk ,
921+ rst = ports .rst ,
922+ increment = increment ,
923+ clear = clear )
924+
925+ # Capture last message of the group.
926+ last_msg = in_data .reg (ports .clk , ports .rst , ce = xact , name = "last_msg" )
927+
928+ hit_last = (counter .out == UInt (counter_width )(N - 1 )) & xact
929+ out_valid = ControlReg (ports .clk , ports .rst , [hit_last ], [clear ])
930+
931+ out_chan , out_ready = EmitEveryNImpl .out .type .wrap (last_msg , out_valid )
932+ # Stall input while waiting for downstream to accept the aggregated output.
933+ ready_for_in .assign (~ (out_valid & ~ out_ready ))
934+ clear .assign (out_valid & out_ready ) # Clear after successful emit.
935+
936+ ports .out = out_chan
937+
938+ return EmitEveryNImpl
939+
940+
670941def HostMemWriteProcessor (
671942 write_width : int , hostmem_module ,
672943 reqs : List [esi ._OutputBundleSetter ]) -> type ["HostMemWriteProcessorImpl" ]:
@@ -695,6 +966,9 @@ class HostMemWriteProcessorImpl(Module):
695966
696967 @generator
697968 def build (ports ):
969+ clk = ports .clk
970+ rst = ports .rst
971+
698972 # If there's no write clients, just create a no-op write bundle
699973 if len (reqs ) == 0 :
700974 req , _ = Channel (hostmem_module .UpstreamWriteReq ).wrap (
@@ -731,8 +1005,8 @@ def build(ports):
7311005 # Pack up the bundle and assign the request channel.
7321006 write_req_bundle_type = esi .HostMem .write_req_bundle_type (
7331007 client_type .data )
734- bundle_sig , froms = write_req_bundle_type . pack (
735- ackTag = demuxed_acks . get_out ( idx ) )
1008+ input_flit_ack = Wire ( upstream_ack_tag . type )
1009+ bundle_sig , froms = write_req_bundle_type . pack ( ackTag = input_flit_ack )
7361010
7371011 gearbox_mod = TaggedWriteGearbox (client_type .data .bitwidth , write_width )
7381012 gearbox_in_type = gearbox_mod .in_ .type .inner_type
@@ -755,6 +1029,13 @@ def build(ports):
7551029 "data" : m .data ,
7561030 "valid_bytes" : m .valid_bytes
7571031 })))
1032+
1033+ # Count the number of acks received from hostmem for this client
1034+ # and only send one back to the client per input.
1035+ ack_every_n = EmitEveryN (upstream_ack_tag .type , gearbox_mod .num_chunks )(
1036+ clk = clk , rst = rst , in_ = demuxed_acks .get_out (idx ))
1037+ input_flit_ack .assign (ack_every_n .out )
1038+
7581039 # Set the port for the client request.
7591040 setattr (ports , HostMemWriteProcessorImpl .reqPortMap [req ], bundle_sig )
7601041
0 commit comments