@@ -2894,3 +2894,188 @@ def test_softmax_f32_f32(self) -> None:
28942894 output = torch .ops .cadence ._softmax_f32_f32 (input_tensor , dim = 1 )
28952895 self .assertEqual (output .dtype , torch .float32 )
28962896 self .assertEqual (output .shape , input_tensor .shape )
2897+
2898+ @expand (
2899+ [
2900+ (
2901+ "basic_hidden_dim_4" ,
2902+ torch .tensor ([[1.0 , 2.0 ]], dtype = torch .float32 ), # inputs: 1x2
2903+ torch .tensor (
2904+ [[0.5 , 0.5 , 0.5 , 0.5 ]], dtype = torch .float32
2905+ ), # hidden: 1x4
2906+ torch .ones (
2907+ (12 , 2 ), dtype = torch .int8
2908+ ), # weights_inputs: 12x2 (3*4 x input_dim=2)
2909+ 0.1 , # w_i_scale
2910+ torch .ones ((12 , 4 ), dtype = torch .int8 ), # weights_hidden: 12x4 (3*4 x 4)
2911+ 0.1 , # w_h_scale
2912+ torch .zeros (12 , dtype = torch .int8 ), # bias_inputs: 12
2913+ 0.1 , # b_i_scale
2914+ torch .zeros (12 , dtype = torch .int8 ), # bias_hidden: 12
2915+ 0.1 , # b_h_scale
2916+ ),
2917+ (
2918+ "invalid_batch_size_2" ,
2919+ torch .tensor (
2920+ [[1.0 , 2.0 , 3.0 ], [2.0 , 3.0 , 4.0 ]], dtype = torch .float32
2921+ ), # inputs: 2x3
2922+ torch .tensor (
2923+ [[0.5 , 0.5 , 0.5 , 0.5 ], [0.3 , 0.3 , 0.3 , 0.3 ]], dtype = torch .float32
2924+ ), # hidden: 2x4
2925+ torch .ones ((12 , 3 ), dtype = torch .int8 ), # weights_inputs: 12x3
2926+ 0.1 , # w_i_scale
2927+ torch .ones ((12 , 4 ), dtype = torch .int8 ), # weights_hidden: 12x4
2928+ 0.1 , # w_h_scale
2929+ torch .zeros (12 , dtype = torch .int8 ), # bias_inputs: 12
2930+ 0.1 , # b_i_scale
2931+ torch .zeros (12 , dtype = torch .int8 ), # bias_hidden: 12
2932+ 0.1 , # b_h_scale
2933+ ),
2934+ (
2935+ "non_zero_biases" ,
2936+ torch .tensor ([[1.0 , 1.0 ]], dtype = torch .float32 ), # inputs: 1x2
2937+ torch .zeros ((1 , 4 ), dtype = torch .float32 ), # hidden: 1x4
2938+ torch .ones ((12 , 2 ), dtype = torch .int8 ), # weights_inputs: 12x2
2939+ 0.2 , # w_i_scale
2940+ torch .ones ((12 , 4 ), dtype = torch .int8 ), # weights_hidden: 12x4
2941+ 0.1 , # w_h_scale
2942+ torch .tensor (
2943+ [1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 , 3 , 3 , 3 , 3 ], dtype = torch .int8
2944+ ), # bias_inputs: 12
2945+ 0.1 , # b_i_scale
2946+ torch .tensor (
2947+ [1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 , 3 , 3 , 3 , 3 ], dtype = torch .int8
2948+ ), # bias_hidden: 12
2949+ 0.1 , # b_h_scale
2950+ ),
2951+ (
2952+ "negative_weights" ,
2953+ torch .tensor ([[1.0 , - 1.0 ]], dtype = torch .float32 ), # inputs: 1x2
2954+ torch .tensor (
2955+ [[0.5 , - 0.5 , 0.5 , - 0.5 ]], dtype = torch .float32
2956+ ), # hidden: 1x4
2957+ torch .tensor (
2958+ [[1 , - 1 ], [- 1 , 1 ]] * 6 , dtype = torch .int8
2959+ ), # weights_inputs: 12x2 (alternating pattern)
2960+ 0.1 , # w_i_scale
2961+ torch .tensor (
2962+ [[1 , - 1 , 1 , - 1 ], [- 1 , 1 , - 1 , 1 ]] * 6 , dtype = torch .int8
2963+ ), # weights_hidden: 12x4 (alternating pattern)
2964+ 0.1 , # w_h_scale
2965+ torch .zeros (12 , dtype = torch .int8 ), # bias_inputs: 12
2966+ 0.1 , # b_i_scale
2967+ torch .zeros (12 , dtype = torch .int8 ), # bias_hidden: 12
2968+ 0.1 , # b_h_scale
2969+ ),
2970+ (
2971+ "hidden_dim_8" ,
2972+ torch .tensor ([[1.0 , 2.0 , 3.0 ]], dtype = torch .float32 ), # inputs: 1x3
2973+ torch .tensor (
2974+ [[0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 ]], dtype = torch .float32
2975+ ), # hidden: 1x8
2976+ torch .ones ((24 , 3 ), dtype = torch .int8 ), # weights_inputs: 24x3 (3*8 x 3)
2977+ 0.1 , # w_i_scale
2978+ torch .ones ((24 , 8 ), dtype = torch .int8 ), # weights_hidden: 24x8 (3*8 x 8)
2979+ 0.1 , # w_h_scale
2980+ torch .zeros (24 , dtype = torch .int8 ), # bias_inputs: 24
2981+ 0.1 , # b_i_scale
2982+ torch .zeros (24 , dtype = torch .int8 ), # bias_hidden: 24
2983+ 0.1 , # b_h_scale
2984+ ),
2985+ ]
2986+ )
2987+ def test_quantized_w8a32_gru (
2988+ self ,
2989+ name : str ,
2990+ inputs : torch .Tensor ,
2991+ hidden : torch .Tensor ,
2992+ weights_inputs : torch .Tensor ,
2993+ w_i_scale : float ,
2994+ weights_hidden : torch .Tensor ,
2995+ w_h_scale : float ,
2996+ bias_inputs : torch .Tensor ,
2997+ b_i_scale : float ,
2998+ bias_hidden : torch .Tensor ,
2999+ b_h_scale : float ,
3000+ ) -> None :
3001+
3002+ if name == "invalid_batch_size_2" :
3003+ with self .assertRaises (ValueError ) as context :
3004+ torch .ops .cadence .quantized_w8a32_gru (
3005+ inputs ,
3006+ hidden ,
3007+ weights_inputs ,
3008+ w_i_scale ,
3009+ weights_hidden ,
3010+ w_h_scale ,
3011+ bias_inputs ,
3012+ b_i_scale ,
3013+ bias_hidden ,
3014+ b_h_scale ,
3015+ )
3016+ self .assertIn (
3017+ "Leading dimension of hidden state must be 1" , str (context .exception )
3018+ )
3019+ return
3020+
3021+ output = torch .ops .cadence .quantized_w8a32_gru (
3022+ inputs ,
3023+ hidden ,
3024+ weights_inputs ,
3025+ w_i_scale ,
3026+ weights_hidden ,
3027+ w_h_scale ,
3028+ bias_inputs ,
3029+ b_i_scale ,
3030+ bias_hidden ,
3031+ b_h_scale ,
3032+ )
3033+
3034+ # Verify output properties
3035+ self .assertEqual (
3036+ output .dtype ,
3037+ torch .float32 ,
3038+ f"Output dtype should be float32 in { name } " ,
3039+ )
3040+ self .assertEqual (
3041+ output .shape ,
3042+ (2 , hidden .shape [- 1 ]),
3043+ f"Output shape should match { (2 , hidden .shape [- 1 ])} in { name } " ,
3044+ )
3045+ assert isinstance (output , torch .Tensor )
3046+
3047+ # Verify output is bounded: GRU hidden state is a convex combination of
3048+ # tanh([-1,1]) and previous hidden([-1,1]), so output should be in [-1,1]
3049+ self .assertTrue (
3050+ torch .all (output >= - 1.0 ) and torch .all (output <= 1.0 ),
3051+ f"Output values should be in [-1.1, 1.1] in { name } . Got min={ output .min ():.4f} , max={ output .max ():.4f} " ,
3052+ )
3053+
3054+ def test_quantized_w8a32_gru_invalid_hidden_dim (self ) -> None :
3055+ # Test that non-multiple of 4 hidden dimension raises error
3056+ inputs = torch .tensor ([[1.0 , 2.0 ]], dtype = torch .float32 ) # 1x2
3057+ hidden = torch .tensor (
3058+ [[0.5 , 0.5 , 0.5 ]], dtype = torch .float32
3059+ ) # 1x3 (not divisible by 4)
3060+ weights_inputs = torch .zeros ((9 , 2 ), dtype = torch .int8 ) # 9x2
3061+ weights_hidden = torch .zeros ((9 , 3 ), dtype = torch .int8 ) # 9x3
3062+ bias_inputs = torch .zeros (9 , dtype = torch .int8 )
3063+ bias_hidden = torch .zeros (9 , dtype = torch .int8 )
3064+
3065+ with self .assertRaises (ValueError ) as context :
3066+ torch .ops .cadence .quantized_w8a32_gru (
3067+ inputs ,
3068+ hidden ,
3069+ weights_inputs ,
3070+ 0.1 ,
3071+ weights_hidden ,
3072+ 0.1 ,
3073+ bias_inputs ,
3074+ 0.1 ,
3075+ bias_hidden ,
3076+ 0.1 ,
3077+ )
3078+
3079+ self .assertIn (
3080+ "Hidden dimension must be a multiple of 4" , str (context .exception )
3081+ )
0 commit comments