@@ -2894,3 +2894,188 @@ def test_softmax_f32_f32(self) -> None:
         output = torch.ops.cadence._softmax_f32_f32(input_tensor, dim=1)
         self.assertEqual(output.dtype, torch.float32)
         self.assertEqual(output.shape, input_tensor.shape)
+
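+    # Each case below stacks the three GRU gate blocks row-wise, so the weight
+    # matrices have 3 * hidden_dim rows (e.g. 12 for hidden_dim=4) and the
+    # bias vectors have 3 * hidden_dim entries.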
+    @expand(
+        [
+            (
+                "basic_hidden_dim_4",
+                torch.tensor([[1.0, 2.0]], dtype=torch.float32),  # inputs: 1x2
+                torch.tensor(
+                    [[0.5, 0.5, 0.5, 0.5]], dtype=torch.float32
+                ),  # hidden: 1x4
+                torch.ones(
+                    (12, 2), dtype=torch.int8
+                ),  # weights_inputs: 12x2 (3*4 x input_dim=2)
+                0.1,  # w_i_scale
+                torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4 (3*4 x 4)
+                0.1,  # w_h_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "invalid_batch_size_2",
+                torch.tensor(
+                    [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], dtype=torch.float32
+                ),  # inputs: 2x3
+                torch.tensor(
+                    [[0.5, 0.5, 0.5, 0.5], [0.3, 0.3, 0.3, 0.3]], dtype=torch.float32
+                ),  # hidden: 2x4
+                torch.ones((12, 3), dtype=torch.int8),  # weights_inputs: 12x3
+                0.1,  # w_i_scale
+                torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4
+                0.1,  # w_h_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "non_zero_biases",
+                torch.tensor([[1.0, 1.0]], dtype=torch.float32),  # inputs: 1x2
+                torch.zeros((1, 4), dtype=torch.float32),  # hidden: 1x4
+                torch.ones((12, 2), dtype=torch.int8),  # weights_inputs: 12x2
+                0.2,  # w_i_scale
+                torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4
+                0.1,  # w_h_scale
+                torch.tensor(
+                    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
+                ),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.tensor(
+                    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
+                ),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "negative_weights",
+                torch.tensor([[1.0, -1.0]], dtype=torch.float32),  # inputs: 1x2
+                torch.tensor(
+                    [[0.5, -0.5, 0.5, -0.5]], dtype=torch.float32
+                ),  # hidden: 1x4
+                torch.tensor(
+                    [[1, -1], [-1, 1]] * 6, dtype=torch.int8
+                ),  # weights_inputs: 12x2 (alternating pattern)
+                0.1,  # w_i_scale
+                torch.tensor(
+                    [[1, -1, 1, -1], [-1, 1, -1, 1]] * 6, dtype=torch.int8
+                ),  # weights_hidden: 12x4 (alternating pattern)
+                0.1,  # w_h_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "hidden_dim_8",
+                torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32),  # inputs: 1x3
+                torch.tensor(
+                    [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]], dtype=torch.float32
+                ),  # hidden: 1x8
+                torch.ones((24, 3), dtype=torch.int8),  # weights_inputs: 24x3 (3*8 x 3)
+                0.1,  # w_i_scale
+                torch.ones((24, 8), dtype=torch.int8),  # weights_hidden: 24x8 (3*8 x 8)
+                0.1,  # w_h_scale
+                torch.zeros(24, dtype=torch.int8),  # bias_inputs: 24
+                0.1,  # b_i_scale
+                torch.zeros(24, dtype=torch.int8),  # bias_hidden: 24
+                0.1,  # b_h_scale
+            ),
+        ]
+    )
+    def test_quantized_w8a32_gru(
+        self,
+        name: str,
+        inputs: torch.Tensor,
+        hidden: torch.Tensor,
+        weights_inputs: torch.Tensor,
+        w_i_scale: float,
+        weights_hidden: torch.Tensor,
+        w_h_scale: float,
+        bias_inputs: torch.Tensor,
+        b_i_scale: float,
+        bias_hidden: torch.Tensor,
+        b_h_scale: float,
+    ) -> None:
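+        """Run cadence.quantized_w8a32_gru (int8 weights, fp32 activations)
+        and check output dtype, shape, and value bounds; the invalid-batch
+        case is expected to raise before any output is produced."""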
+
+        if name == "invalid_batch_size_2":
+            with self.assertRaises(ValueError) as context:
+                torch.ops.cadence.quantized_w8a32_gru(
+                    inputs,
+                    hidden,
+                    weights_inputs,
+                    w_i_scale,
+                    weights_hidden,
+                    w_h_scale,
+                    bias_inputs,
+                    b_i_scale,
+                    bias_hidden,
+                    b_h_scale,
+                )
+            self.assertIn(
+                "Leading dimension of hidden state must be 1", str(context.exception)
+            )
+            return
+
+        output = torch.ops.cadence.quantized_w8a32_gru(
+            inputs,
+            hidden,
+            weights_inputs,
+            w_i_scale,
+            weights_hidden,
+            w_h_scale,
+            bias_inputs,
+            b_i_scale,
+            bias_hidden,
+            b_h_scale,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            torch.float32,
+            f"Output dtype should be float32 in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            (2, hidden.shape[-1]),
+            f"Output shape should match {(2, hidden.shape[-1])} in {name}",
+        )
+        assert isinstance(output, torch.Tensor)
+
+        # Verify output is bounded: GRU hidden state is a convex combination of
+        # tanh([-1, 1]) and the previous hidden state ([-1, 1]), so output should
+        # be in [-1, 1]
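+        # (standard GRU update: h_t = (1 - z_t) * n_t + z_t * h_{t-1}, where
+        # z_t = sigmoid(.) lies in [0, 1] and n_t = tanh(.) lies in [-1, 1],
+        # so h_t stays in [-1, 1] whenever h_{t-1} does)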
+        self.assertTrue(
+            torch.all(output >= -1.0) and torch.all(output <= 1.0),
+            f"Output values should be in [-1, 1] in {name}, "
+            f"got min={output.min():.4f}, max={output.max():.4f}",
+        )
+
+    def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
+        # Test that a hidden dimension that is not a multiple of 4 raises an error
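+        # (the multiple-of-4 requirement presumably reflects the kernel's
+        # 4-wide vectorization; that rationale is an assumption, the test
+        # only relies on the error raised below)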
+        inputs = torch.tensor([[1.0, 2.0]], dtype=torch.float32)  # 1x2
+        hidden = torch.tensor(
+            [[0.5, 0.5, 0.5]], dtype=torch.float32
+        )  # 1x3 (not divisible by 4)
+        weights_inputs = torch.zeros((9, 2), dtype=torch.int8)  # 9x2
+        weights_hidden = torch.zeros((9, 3), dtype=torch.int8)  # 9x3
+        bias_inputs = torch.zeros(9, dtype=torch.int8)
+        bias_hidden = torch.zeros(9, dtype=torch.int8)
+
+        with self.assertRaises(ValueError) as context:
+            torch.ops.cadence.quantized_w8a32_gru(
+                inputs,
+                hidden,
+                weights_inputs,
+                0.1,
+                weights_hidden,
+                0.1,
+                bias_inputs,
+                0.1,
+                bias_hidden,
+                0.1,
+            )
+
+        self.assertIn(
+            "Hidden dimension must be a multiple of 4", str(context.exception)
+        )