11from __future__ import annotations
22
33import torch
4- from torch ._ops import OpOverload
4+ from torch ._ops import OpOverload , OpOverloadPacket
55from torch .testing ._internal .common_device_type import OpDTypes , instantiate_device_type_tests , ops
66from torch .testing ._internal .common_methods_invocations import op_db
77from torch .testing ._internal .common_utils import (
1616from complex_tensor .ops ._common import ComplexDispatchMode , _as_complex_tensor
1717from complex_tensor .test .utils import (
1818 COMPLEX_DTYPES ,
19+ Descriptor ,
1920 TestCase ,
20- TestDescriptor ,
21+ Variant ,
2122)
2223
2324torch ._dynamo .config .recompile_limit = float ("inf" )
3031)
3132
3233
33- def _get_opname_from_aten_op (aten_op ) :
34+ def _get_opname_from_aten_op (aten_op : OpOverloadPacket ) -> str :
3435 if isinstance (aten_op , OpOverload ):
3536 aten_op = aten_op .overloadpacket
3637 return aten_op ._qualified_op_name .split ("::" )[- 1 ]
3738
3839
40+ def get_overload_packet_from_name (name : str ) -> OpOverloadPacket :
41+ for domain_name in torch .ops :
42+ op_namespace = getattr (torch .ops , domain_name )
43+ op : OpOverloadPacket = getattr (op_namespace , name , None )
44+ if op is not None :
45+ return op
46+
47+ raise RuntimeError (f"No op with { name = } found." )
48+
49+
3950force_test_names = set (map (_get_opname_from_aten_op , FORCE_TEST_LIST ))
4051implemented_op_names = (
4152 set (map (_get_opname_from_aten_op , COMPLEX_OPS_TABLE .keys ())) - force_test_names
@@ -64,62 +75,63 @@ def _get_opname_from_aten_op(aten_op):
6475
6576
6677SKIPS = {
67- TestDescriptor ( op_name = " empty_like" ): "Inconsistent output" ,
78+ Descriptor ( op = aten . empty_like , variant = None ): "Non-deterministic output" ,
6879 # This passes with `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=35 ...
6980 # but when the whole test is run, it fails with this exact
7081 # sample.
71- TestDescriptor ( op_name = " repeat" , compile = True ): "Heisenbug" ,
72- TestDescriptor (
73- op_name = " allclose" , compile = True
82+ Descriptor ( op = aten . repeat , compile = True , variant = None ): "Heisenbug" ,
83+ Descriptor (
84+ op = aten . allclose , compile = True , variant = None
7485 ): "`aten.allclose` requires data-dependent control-flow" ,
75- TestDescriptor (op_name = "randn_like" ): "Inconsistent output" ,
76- TestDescriptor (op_name = "angle" , gradcheck = True ): "Numerical inconsistency" ,
77- TestDescriptor (op_name = "asinh" , gradcheck = True ): "Numerical inconsistency" ,
78- TestDescriptor (op_name = "atanh" , gradcheck = True ): "Numerical inconsistency" ,
79- TestDescriptor (op_name = "reciprocal" , gradcheck = True ): "Numerical inconsistency" ,
80- TestDescriptor (op_name = "rsqrt" , gradcheck = True ): "Numerical inconsistency" ,
81- TestDescriptor (op_name = "select" , gradcheck = True ): "Numerical inconsistency" ,
82- TestDescriptor (op_name = "asin" , gradcheck = True ): "Numerical inconsistency" ,
83- TestDescriptor (op_name = "log" , gradcheck = True ): "Numerical inconsistency" ,
84- TestDescriptor (op_name = "sgn" , gradcheck = True ): "Numerical inconsistency" ,
85- TestDescriptor (op_name = "cumprod" , gradcheck = True ): "Numerical inconsistency" ,
86- TestDescriptor (op_name = "slice" , gradcheck = True ): "Numerical inconsistency" ,
87- TestDescriptor (op_name = "sqrt" , gradcheck = True ): "Numerical inconsistency" ,
88- TestDescriptor (op_name = "tan" , gradcheck = True ): "Numerical inconsistency" ,
89- TestDescriptor (op_name = "true_divide" , gradcheck = True ): "Numerical inconsistency" ,
90- TestDescriptor (op_name = "prod" , gradcheck = True ): "Numerical inconsistency" ,
91- TestDescriptor (op_name = "div" , gradcheck = True ): "Numerical inconsistency" ,
92- TestDescriptor (op_name = "expm1" , gradcheck = True ): "Numerical inconsistency" ,
93- TestDescriptor (op_name = "var" , gradcheck = True ): "Numerical inconsistency" ,
94- TestDescriptor (op_name = "bmm" , gradcheck = True ): "Numerical inconsistency" ,
95- TestDescriptor (op_name = "diagonal" , gradcheck = True ): "Numerical inconsistency" ,
96- TestDescriptor (op_name = "sinh" , gradcheck = True ): "Numerical inconsistency" ,
97- TestDescriptor (op_name = "abs" , gradcheck = True ): "Numerical inconsistency" ,
98- TestDescriptor (op_name = "sin" , gradcheck = True ): "Numerical inconsistency" ,
99- TestDescriptor (op_name = "atan" , gradcheck = True ): "Numerical inconsistency" ,
100- TestDescriptor (op_name = "acos" , gradcheck = True ): "Numerical inconsistency" ,
101- TestDescriptor (op_name = "acosh" , gradcheck = True ): "Numerical inconsistency" ,
102- TestDescriptor (op_name = "cos" , gradcheck = True ): "Numerical inconsistency" ,
103- TestDescriptor (op_name = "cosh" , gradcheck = True ): "Numerical inconsistency" ,
104- TestDescriptor (op_name = "addmm" , gradcheck = True ): "Numerical inconsistency" ,
105- TestDescriptor (op_name = "pow" , gradcheck = True ): "Numerical inconsistency" ,
106- TestDescriptor (op_name = "log1p" , gradcheck = True ): "Numerical inconsistency" ,
107- TestDescriptor (op_name = "tanh" , gradcheck = True ): "Numerical inconsistency" ,
108- TestDescriptor (op_name = "mm" , gradcheck = True ): "Numerical inconsistency" ,
109- TestDescriptor (op_name = "mul" , gradcheck = True ): "Numerical inconsistency" ,
110- TestDescriptor (op_name = "exp" , gradcheck = True ): "Numerical inconsistency" ,
86+ Descriptor (op = aten .randn_like , variant = None ): "Non-deterministic output" ,
87+ Descriptor (op = aten .angle , variant = Variant .GradCheck ): "Numerical inconsistency" ,
88+ Descriptor (op = aten .asinh , variant = Variant .GradCheck ): "Numerical inconsistency" ,
89+ Descriptor (op = aten .atanh , variant = Variant .GradCheck ): "Numerical inconsistency" ,
90+ Descriptor (op = aten .reciprocal , variant = Variant .GradCheck ): "Numerical inconsistency" ,
91+ Descriptor (op = aten .rsqrt , variant = Variant .GradCheck ): "Numerical inconsistency" ,
92+ Descriptor (op = aten .select , variant = Variant .GradCheck ): "Numerical inconsistency" ,
93+ Descriptor (op = aten .asin , variant = Variant .GradCheck ): "Numerical inconsistency" ,
94+ Descriptor (op = aten .log , variant = Variant .GradCheck ): "Numerical inconsistency" ,
95+ Descriptor (op = aten .sgn , variant = Variant .GradCheck ): "Numerical inconsistency" ,
96+ Descriptor (op = aten .cumprod , variant = Variant .GradCheck ): "Numerical inconsistency" ,
97+ Descriptor (op = aten .slice , variant = Variant .GradCheck ): "Numerical inconsistency" ,
98+ Descriptor (op = aten .sqrt , variant = Variant .GradCheck ): "Numerical inconsistency" ,
99+ Descriptor (op = aten .tan , variant = Variant .GradCheck ): "Numerical inconsistency" ,
100+ Descriptor (op = aten .true_divide , variant = Variant .GradCheck ): "Numerical inconsistency" ,
101+ Descriptor (op = aten .prod , variant = Variant .GradCheck ): "Numerical inconsistency" ,
102+ Descriptor (op = aten .div , variant = Variant .GradCheck ): "Numerical inconsistency" ,
103+ Descriptor (op = aten .expm1 , variant = Variant .GradCheck ): "Numerical inconsistency" ,
104+ Descriptor (op = aten .var , variant = Variant .GradCheck ): "Numerical inconsistency" ,
105+ Descriptor (op = aten .bmm , variant = Variant .GradCheck ): "Numerical inconsistency" ,
106+ Descriptor (op = aten .diagonal , variant = Variant .GradCheck ): "Numerical inconsistency" ,
107+ Descriptor (op = aten .sinh , variant = Variant .GradCheck ): "Numerical inconsistency" ,
108+ Descriptor (op = aten .abs , variant = Variant .GradCheck ): "Numerical inconsistency" ,
109+ Descriptor (op = aten .sin , variant = Variant .GradCheck ): "Numerical inconsistency" ,
110+ Descriptor (op = aten .atan , variant = Variant .GradCheck ): "Numerical inconsistency" ,
111+ Descriptor (op = aten .acos , variant = Variant .GradCheck ): "Numerical inconsistency" ,
112+ Descriptor (op = aten .acosh , variant = Variant .GradCheck ): "Numerical inconsistency" ,
113+ Descriptor (op = aten .cos , variant = Variant .GradCheck ): "Numerical inconsistency" ,
114+ Descriptor (op = aten .cosh , variant = Variant .GradCheck ): "Numerical inconsistency" ,
115+ Descriptor (op = aten .addmm , variant = Variant .GradCheck ): "Numerical inconsistency" ,
116+ Descriptor (op = aten .pow , variant = Variant .GradCheck ): "Numerical inconsistency" ,
117+ Descriptor (op = aten .log1p , variant = Variant .GradCheck ): "Numerical inconsistency" ,
118+ Descriptor (op = aten .tanh , variant = Variant .GradCheck ): "Numerical inconsistency" ,
119+ Descriptor (op = aten .mm , variant = Variant .GradCheck ): "Numerical inconsistency" ,
120+ Descriptor (op = aten .dot , variant = Variant .GradCheck ): "Numerical inconsistency" ,
121+ Descriptor (op = aten .mul , variant = Variant .GradCheck ): "Numerical inconsistency" ,
122+ Descriptor (op = aten .exp , variant = Variant .GradCheck ): "Numerical inconsistency" ,
111123}
112124
113125EXTRA_KWARGS = {
114- TestDescriptor ( op_name = " asinh" , dtype = torch .complex64 , gradcheck = False ): {
126+ Descriptor ( op = aten . asinh , dtype = torch .complex64 , variant = Variant . Op ): {
115127 "rtol" : 2e-5 ,
116128 "atol" : 5e-5 ,
117129 },
118- TestDescriptor ( op_name = " tanh" , dtype = torch .complex64 , gradcheck = False ): {
130+ Descriptor ( op = aten . tanh , dtype = torch .complex64 , variant = Variant . Op ): {
119131 "rtol" : 1e-4 ,
120132 "atol" : 1e-5 ,
121133 },
122- TestDescriptor ( op_name = " pow" , dtype = torch .complex64 , gradcheck = False ): {
134+ Descriptor ( op = aten . pow , dtype = torch .complex64 , variant = Variant . Op ): {
123135 "rtol" : 2e-2 ,
124136 "atol" : 2e-6 ,
125137 },
@@ -140,8 +152,12 @@ def test_maybe_error(self, device, dtype, op: OpInfo, compile: bool):
140152 self .check_consistency (device , dtype , op , compile )
141153
142154 def check_consistency (self , device : torch .device , dtype , op : OpInfo , compile : bool ) -> None :
143- test_info = TestDescriptor (
144- op_name = op .name , device = device , dtype = dtype , compile = compile , gradcheck = False
155+ test_info = Descriptor (
156+ op = get_overload_packet_from_name (op .name ),
157+ device = device ,
158+ dtype = dtype ,
159+ compile = compile ,
160+ variant = Variant .Op ,
145161 )
146162 for xfail_info , reason in SKIPS .items ():
147163 if xfail_info .matches (test_info ):
@@ -175,8 +191,12 @@ def actual(subclass_sample=subclass_sample):
175191class TestComplexBwdGradients (TestGradients ):
176192 @ops (implemented_op_db , dtypes = OpDTypes .supported_backward , allowed_dtypes = [torch .complex128 ])
177193 def test_fn_grad (self , device : torch .device , dtype : torch .dtype , op : OpInfo ) -> None :
178- test_info = TestDescriptor (
179- op_name = op .name , device = device , dtype = dtype , compile = False , gradcheck = True
194+ test_info = Descriptor (
195+ op = get_overload_packet_from_name (op .name ),
196+ device = device ,
197+ dtype = dtype ,
198+ compile = False ,
199+ variant = Variant .GradCheck ,
180200 )
181201 for xfail_info , reason in SKIPS .items ():
182202 if xfail_info .matches (test_info ):
0 commit comments