@@ -356,6 +356,14 @@ class Pad(OpFromGraph):
356356 Wrapper Op for Pad graphs
357357 """
358358
359+ def __init__ (self , inputs , outputs , pad_mode , reflect_type = None , kind = None ):
360+ self .pad_mode = pad_mode
361+ self .reflect_type = reflect_type
362+ self .kind = kind
363+ self .reflect_type = reflect_type
364+
365+ super ().__init__ (inputs = inputs , outputs = outputs )
366+
359367
360368def pad (x : TensorLike , pad_width : TensorLike , mode : PadMode = "constant" , ** kwargs ):
361369 if any (value not in allowed_kwargs [mode ] for value in kwargs .keys ()):
@@ -388,9 +396,6 @@ def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwar
388396 stat_length = as_tensor (stat_length , name = "stat_length" )
389397 inputs += [stat_length ]
390398
391- attrs .update (
392- {"stat_func" : stat_func , "stat_length_input" : stat_length is not None }
393- )
394399 outputs = _stat_pad (x , pad_width , stat_func , stat_length )
395400
396401 elif mode == "linear_ramp" :
@@ -401,15 +406,14 @@ def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwar
401406 outputs = _linear_ramp_pad (x , pad_width , end_values )
402407
403408 elif mode == "wrap" :
404- attrs .update ({"kind" : "wrap" })
405409 outputs = _looping_pad (x , pad_width , kind = "wrap" )
406410
407411 elif mode == "symmetric" :
408412 reflect_type = kwargs .pop ("reflect_type" , "even" )
409413 if reflect_type == "odd" :
410414 raise NotImplementedError ("Odd reflection not implemented" )
411415
412- attrs .update ({"kind " : reflect_type })
416+ attrs .update ({"reflect_type " : reflect_type })
413417 outputs = _looping_pad (x , pad_width , kind = "symmetric" )
414418
415419 elif mode == "reflect" :
@@ -421,11 +425,7 @@ def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwar
421425 else :
422426 raise ValueError (f"Invalid mode: { mode } " )
423427
424- op = Pad (inputs = inputs , outputs = [outputs ])(* inputs ) # type: ignore
425-
426- setattr (op , "pad_mode" , mode )
427- for pad_arg , value in attrs .items ():
428- setattr (op , pad_arg , value )
428+ op = Pad (inputs = inputs , outputs = [outputs ], pad_mode = mode , ** attrs )(* inputs )
429429 return op
430430
431431
0 commit comments