@@ -11,69 +11,67 @@
 from functools import partial
 import torch.nn.functional as F
 
-class TestEmbedding(NVFuserTest):
-    def test_embedding(self):
-        def fusion_func(
-            fd: FusionDefinition,
-            has_optional_inputs: list[bool],
-            optional_inputs_dtypes: list[DataType]
-        ):
-            input = fd.define_tensor(
-                shape=[-1],
-                contiguity=[True],
-                dtype=DataType.Int,
-                is_cpu=False,
-            )
-            weight = fd.define_tensor(
-                shape=[-1, -1],
-                contiguity=[True, True],
-                dtype=DataType.BFloat16,
-                is_cpu=False,
-            )
-            # padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
-            optional_inputs = [None] * 5
-            for idx in range(len(optional_inputs)):
-                if has_optional_inputs[idx]:
-                    optional_inputs[idx] = fd.define_scalar(value=None, dtype=optional_inputs_dtypes[idx])
-            out = fd.ops.embedding(input, weight, *optional_inputs)
-            fd.add_output(out)
 
-        N, S = 10, 3
-        input = torch.randint(N, (S,), dtype=torch.int64, device='cuda', requires_grad=False)
-        weight = torch.randn(N, S, dtype=torch.bfloat16, device='cuda', requires_grad=True)
-
-        padding_idx_vals = [None, -1, -2]
-        max_norm_vals = [None, 1e-5]
-        norm_type_vals = [None, 2.0, 1.0]
-        scale_grad_by_freq = [None, True]
-        sparse = [None, False, True]
-        optional_inputs_dtypes = [DataType.Int, DataType.Float, DataType.Float, DataType.Bool, DataType.Bool]
+@pytest.mark.parametrize("padding_idx", [None, -2])
+@pytest.mark.parametrize("max_norm", [None, 1e-5])
+@pytest.mark.parametrize("norm_type", [None, 1.0])
+@pytest.mark.parametrize("scale_grad_by_freq", [None, True])
+@pytest.mark.parametrize("sparse", [None, True])
+def test_embedding(
+    padding_idx: None | int,
+    max_norm: None | float,
+    norm_type: None | float,
+    scale_grad_by_freq: None | bool,
+    sparse: None | bool
+):
+    def fusion_func(
+        fd: FusionDefinition,
+        has_optional_inputs: list[bool],
+        optional_inputs_dtypes: list[DataType]
+    ):
+        input = fd.define_tensor(
+            shape=[-1],
+            contiguity=[True],
+            dtype=DataType.Int,
+            is_cpu=False,
+        )
+        weight = fd.define_tensor(
+            shape=[-1, -1],
+            contiguity=[True, True],
+            dtype=DataType.BFloat16,
+            is_cpu=False,
+        )
+        # padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
+        optional_inputs = [None] * 5
+        for idx in range(len(optional_inputs)):
+            if has_optional_inputs[idx]:
+                optional_inputs[idx] = fd.define_scalar(value=None, dtype=optional_inputs_dtypes[idx])
+        out = fd.ops.embedding(input, weight, *optional_inputs)
+        fd.add_output(out)
 
-
-        # TODO: Try to move this to pytest_ops.py. Currently, it does not work since the API between nvFuser and torch differs.
-        for padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse in itertools.product(
-            padding_idx_vals, max_norm_vals, norm_type_vals, scale_grad_by_freq, sparse
-        ):
-            with self.subTest(padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse):
-                # Reset the FusionCache or the fusion would not recompile for all subtests, failing checks in exec_nvfuser.
-                FusionCache.reset()
-                optional_inputs = [padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse]
-                has_optional_inputs = [None] * 5
-                inputs = [input, weight]
-                for idx, param in enumerate(optional_inputs):
-                    if param is not None:
-                        has_optional_inputs[idx] = True
-                        inputs.append(param)
-
-                with FusionDefinition() as fd:
-                    fusion_func(fd,
-                                has_optional_inputs=has_optional_inputs,
-                                optional_inputs_dtypes=optional_inputs_dtypes)
-                nvf_out = fd.execute(inputs)
+    N, S = 10, 3
+    input = torch.randint(N, (S,), dtype=torch.int64, device='cuda', requires_grad=False)
+    weight = torch.randn(N, S, dtype=torch.bfloat16, device='cuda', requires_grad=True)
+    optional_inputs_dtypes = [DataType.Int, DataType.Float, DataType.Float, DataType.Bool, DataType.Bool]
 
-                torch.manual_seed(0)
-                norm_type = 2.0 if norm_type is None else norm_type
-                scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq
-                sparse = False if sparse is None else sparse
-                ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
-                torch.testing.assert_close(nvf_out[0], ref_out)
+    # This is not in pytest_ops.py since the torch API does not accept None values for some arguments.
+    # Differing inputs for the nvFuser and torch APIs cannot be handled within OpInfo.
+    optional_inputs = [padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse]
+    has_optional_inputs = [None] * 5
+    inputs = [input, weight]
+    for idx, param in enumerate(optional_inputs):
+        if param is not None:
+            has_optional_inputs[idx] = True
+            inputs.append(param)
+
+    with FusionDefinition() as fd:
+        fusion_func(fd,
+                    has_optional_inputs=has_optional_inputs,
+                    optional_inputs_dtypes=optional_inputs_dtypes)
+    nvf_out = fd.execute(inputs)
+
+    norm_type = 2.0 if norm_type is None else norm_type
+    scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq
+    sparse = False if sparse is None else sparse
+    ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
+    torch.testing.assert_close(nvf_out[0], ref_out)
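A note on the refactor above: stacked @pytest.mark.parametrize decorators multiply out, so pytest collects the Cartesian product of the five value lists as independent test nodes. A minimal sketch of the equivalent expansion (plain Python, not part of the diff; values copied from the decorators above):

import itertools

# Five parametrize decorators with two values each expand to this grid,
# one collected test per combination:
grid = list(itertools.product(
    [None, -2],    # padding_idx
    [None, 1e-5],  # max_norm
    [None, 1.0],   # norm_type
    [None, True],  # scale_grad_by_freq
    [None, True],  # sparse
))
assert len(grid) == 2 ** 5  # 32 tests, down from 3 * 2 * 3 * 2 * 3 = 108 subTests

Because each combination now runs as its own test with a fresh FusionDefinition, the explicit FusionCache.reset() from the old subTest loop is presumably no longer needed, which is why it does not reappear in the parametrized version.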