@@ -221,8 +221,16 @@ def _produce_kif(layer: Layer) -> KIF_t:
221221
@_produce_kif.register
def _(layer: Input):
    """Produce the ``(k, i, f)`` arrays for an ``Input`` layer.

    All three arrays have the layer's output shape and dtype ``np.int16``.

    * If the layer's ``trusted`` attribute is truthy, the arrays are filled
      from the precision of the layer's output variable: ``signed``,
      ``integer - signed`` and ``fractional`` respectively.
    * Otherwise ``k`` is all ones and ``i`` / ``f`` are filled with 126.
      In this branch ``i`` and ``f`` are the *same* array object.

    Returns:
        The tuple ``(k, i, f)``.
    """
    out_shape = get_output_shape(layer)

    if not layer.attributes.get('trusted', False):
        # NOTE(review): 126 looks like a deliberately wide placeholder for an
        # input whose precision has not been fixed yet -- confirm with callers.
        k = np.ones(out_shape, dtype=np.int16)
        i = f = np.full(out_shape, 126, dtype=np.int16)
        return k, i, f

    precision: FixedPrecisionType = layer.get_output_variable().type.precision
    sign_bit = precision.signed
    # The sign bit is tracked separately in ``k``, so it is taken out of ``i``.
    int_bits = precision.integer - precision.signed
    frac_bits = precision.fractional
    k, i, f = (np.full(out_shape, bits, dtype=np.int16) for bits in (sign_bit, int_bits, frac_bits))
    return k, i, f
227235
228236
@@ -630,8 +638,8 @@ def kif_arrs_to_ints(arr: tuple[np.ndarray, np.ndarray, np.ndarray]):
630638 return tuple (int (np .max (a )) for a in arr )
631639
632640
633- def produce_kif (layer : Layer ) -> KIF_t :
634- if layer .attributes .get ('_produce_kif' ):
641+ def produce_kif (layer : Layer , force_reset = False ) -> KIF_t :
642+ if layer .attributes .get ('_produce_kif' ) and not force_reset :
635643 return layer .attributes ['_produce_kif' ]
636644 kif = _produce_kif (layer )
637645 layer .attributes ['_produce_kif' ] = kif
@@ -885,7 +893,9 @@ def transform(self, model: 'ModelGraph'):
885893 for node in model .graph .values ():
886894 if node .attributes .get ('bit_exact_transformed' ):
887895 continue
888- produce_kif (node ) # Shrink FixedPointQuantizer bits when possible to be used in backward flow (requested_kif).
896+ produce_kif (
897+ node , force_reset = True
898+ ) # Shrink FixedPointQuantizer bits when possible to be used in backward flow (requested_kif).
889899
890900 for node in model .graph .values ():
891901 if node .attributes .get ('bit_exact_transformed' ):
@@ -894,14 +904,31 @@ def transform(self, model: 'ModelGraph'):
894904 node .attributes ['bit_exact_transformed' ] = True
895905
896906 for node in model .graph .values ():
897- if node . attributes . get ( '_produce_kif' ) :
907+ if '_produce_kif' in node . attributes :
898908 del node .attributes ['_produce_kif' ]
899- if node . attributes . get ( '_request_kif' ) :
909+ if '_request_kif' in node . attributes :
900910 del node .attributes ['_request_kif' ]
901911
902912 return True
903913
904914
def get_output_layers_and_quantizers(
    node: Layer, layers: list | None = None, quantizers: list | None = None
) -> tuple[list[Layer], list[FixedPointQuantizer]]:
    """Collect the layers and quantizers found downstream of ``node``.

    The consumers returned by ``get_output_layers`` are walked recursively:

    * a ``FixedPointQuantizer`` is recorded in ``quantizers`` and ends that
      branch of the walk;
    * a ``Reshape``, ``Transpose`` or ``Concatenate`` is recorded in
      ``layers`` and its own consumers are walked in turn;
    * anything else is rejected.

    Each layer is recorded (and traversed) only once, even when it can be
    reached through several paths, e.g. two branches merging in a
    ``Concatenate``.

    Args:
        node: Layer whose consumers are inspected.
        layers: Optional accumulator for intermediate layers; extended in
            place when given, otherwise a new list is created.
        quantizers: Optional accumulator for quantizers; extended in place
            when given, otherwise a new list is created.

    Returns:
        The ``(layers, quantizers)`` accumulators, in first-visit order.

    Raises:
        ValueError: If a consumer is neither a quantizer nor one of the
            supported intermediate layer types.
    """
    layers = [] if layers is None else layers
    quantizers = [] if quantizers is None else quantizers

    for consumer in get_output_layers(node):
        if isinstance(consumer, FixedPointQuantizer):
            # Identity comparison: do not rely on layers defining __eq__/__hash__.
            if not any(seen is consumer for seen in quantizers):
                quantizers.append(consumer)
        elif isinstance(consumer, (Reshape, Transpose, Concatenate)):
            if any(seen is consumer for seen in layers):
                # Already visited through another path; its subtree is collected.
                continue
            layers.append(consumer)
            get_output_layers_and_quantizers(consumer, layers, quantizers)
        else:
            # Name the offending consumer as well as the layer feeding it.
            raise ValueError(
                f'Layer {node.name} ({node.class_name}) unexpected input layer chain: '
                f'followed by unsupported layer {consumer.name} ({consumer.class_name}).'
            )
    return layers, quantizers
930+
931+
905932class FixInputPrecision (OptimizerPass ):
906933 def match (self , node : Layer ):
907934 if not isinstance (node , Input ):
@@ -911,21 +938,17 @@ def match(self, node: Layer):
911938 return node .get_output_variable ().type .precision .width > 100
912939
913940 def transform (self , model , node : Layer ):
914- out_layers : list [FixedPointQuantizer ] = get_output_layers (node ) # type: ignore
915- for layer in out_layers :
916- assert isinstance (
917- layer , FixedPointQuantizer
918- ), f'Input { node .name } connected to non-quantizer { layer .name } with non-trivial configuration'
941+ layers , out_quantizers = get_output_layers_and_quantizers (node )
919942
920- if len (out_layers ) == 0 : # Input connected to nothing
943+ if len (out_quantizers ) == 0 : # Input connected to nothing
921944 new_type = to_hls4ml_fixed (0 , 0 , 1 , f'{ node .name } _t' )
922945 node .get_output_variable ().type = new_type
923946 node .model .config .layer_name_precision [node .name ] = str (new_type )
924947 return False
925948
926- sat_modes = [l .SAT for l in out_layers ]
949+ sat_modes = [l .SAT for l in out_quantizers ]
927950 sat_modes_set = set (sat_modes )
928- rnd_modes = [l .RND for l in out_layers ]
951+ rnd_modes = [l .RND for l in out_quantizers ]
929952 rnd_modes_set = set (rnd_modes )
930953 illegal_sat_modes = sat_modes_set - {'WRAP' , 'SAT' , 'SAT_SYM' }
931954 illegal_rnd_modes = rnd_modes_set - {'TRN' , 'RND' }
@@ -936,7 +959,7 @@ def transform(self, model, node: Layer):
936959 if illegal_rnd_modes :
937960 warn (f'Saturation mode { illegal_rnd_modes } may compromise bit-exactness. Forcing at maximum 24 fractional bits.' )
938961
939- kifs = [_produce_kif (l ) for l in out_layers ]
962+ kifs = [_produce_kif (l ) for l in out_quantizers ]
940963 i = np .max ([np .max (i ) for _ , i , _ in kifs ])
941964 k = np .max ([np .max (k ) for k , _ , _ in kifs ])
942965 if illegal_rnd_modes :
@@ -951,4 +974,15 @@ def transform(self, model, node: Layer):
951974 new_type .precision .saturation_mode = 'SAT'
952975 node .get_output_variable ().type = new_type
953976 node .model .config .layer_name_precision [node .name ] = str (new_type )
977+ node .attributes ['trusted' ] = True
978+
979+ for layer in layers :
980+ produce_kif (layer , force_reset = True )
981+ for layer in layers :
982+ register_precision (layer )
983+ for layer in layers :
984+ if '_produce_kif' in layer .attributes :
985+ del layer .attributes ['_produce_kif' ]
986+ if '_request_kif' in layer .attributes :
987+ del layer .attributes ['_request_kif' ]
954988 return False
0 commit comments