Skip to content

Commit 7ea6574

Browse files
authored
Merge pull request #17 from hameerabbasi/refactor-tests
Refactor tests to use variants instead of flags
2 parents 018b8b0 + a6c4619 commit 7ea6574

File tree

3 files changed

+82
-55
lines changed

3 files changed

+82
-55
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ repos:
1616
exclude: ".ipynb"
1717

1818
- repo: https://github.com/astral-sh/ruff-pre-commit
19-
rev: v0.14.0
19+
rev: v0.14.2
2020
hooks:
2121
- id: ruff-check
2222
args: ["--fix"]

src/complex_tensor/test/test_ops.py

Lines changed: 70 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import torch
4-
from torch._ops import OpOverload
4+
from torch._ops import OpOverload, OpOverloadPacket
55
from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
66
from torch.testing._internal.common_methods_invocations import op_db
77
from torch.testing._internal.common_utils import (
@@ -16,8 +16,9 @@
1616
from complex_tensor.ops._common import ComplexDispatchMode, _as_complex_tensor
1717
from complex_tensor.test.utils import (
1818
COMPLEX_DTYPES,
19+
Descriptor,
1920
TestCase,
20-
TestDescriptor,
21+
Variant,
2122
)
2223

2324
torch._dynamo.config.recompile_limit = float("inf")
@@ -30,12 +31,22 @@
3031
)
3132

3233

33-
def _get_opname_from_aten_op(aten_op):
34+
def _get_opname_from_aten_op(aten_op: OpOverloadPacket) -> str:
3435
if isinstance(aten_op, OpOverload):
3536
aten_op = aten_op.overloadpacket
3637
return aten_op._qualified_op_name.split("::")[-1]
3738

3839

40+
def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
41+
for domain_name in torch.ops:
42+
op_namespace = getattr(torch.ops, domain_name)
43+
op: OpOverloadPacket = getattr(op_namespace, name, None)
44+
if op is not None:
45+
return op
46+
47+
raise RuntimeError(f"No op with {name=} found.")
48+
49+
3950
force_test_names = set(map(_get_opname_from_aten_op, FORCE_TEST_LIST))
4051
implemented_op_names = (
4152
set(map(_get_opname_from_aten_op, COMPLEX_OPS_TABLE.keys())) - force_test_names
@@ -64,62 +75,63 @@ def _get_opname_from_aten_op(aten_op):
6475

6576

6677
SKIPS = {
67-
TestDescriptor(op_name="empty_like"): "Inconsistent output",
78+
Descriptor(op=aten.empty_like, variant=None): "Non-deterministic output",
6879
# This passes with `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=35 ...
6980
# but when the whole test is run, it fails with this exact
7081
# sample.
71-
TestDescriptor(op_name="repeat", compile=True): "Heisenbug",
72-
TestDescriptor(
73-
op_name="allclose", compile=True
82+
Descriptor(op=aten.repeat, compile=True, variant=None): "Heisenbug",
83+
Descriptor(
84+
op=aten.allclose, compile=True, variant=None
7485
): "`aten.allclose` requires data-dependent control-flow",
75-
TestDescriptor(op_name="randn_like"): "Inconsistent output",
76-
TestDescriptor(op_name="angle", gradcheck=True): "Numerical inconsistency",
77-
TestDescriptor(op_name="asinh", gradcheck=True): "Numerical inconsistency",
78-
TestDescriptor(op_name="atanh", gradcheck=True): "Numerical inconsistency",
79-
TestDescriptor(op_name="reciprocal", gradcheck=True): "Numerical inconsistency",
80-
TestDescriptor(op_name="rsqrt", gradcheck=True): "Numerical inconsistency",
81-
TestDescriptor(op_name="select", gradcheck=True): "Numerical inconsistency",
82-
TestDescriptor(op_name="asin", gradcheck=True): "Numerical inconsistency",
83-
TestDescriptor(op_name="log", gradcheck=True): "Numerical inconsistency",
84-
TestDescriptor(op_name="sgn", gradcheck=True): "Numerical inconsistency",
85-
TestDescriptor(op_name="cumprod", gradcheck=True): "Numerical inconsistency",
86-
TestDescriptor(op_name="slice", gradcheck=True): "Numerical inconsistency",
87-
TestDescriptor(op_name="sqrt", gradcheck=True): "Numerical inconsistency",
88-
TestDescriptor(op_name="tan", gradcheck=True): "Numerical inconsistency",
89-
TestDescriptor(op_name="true_divide", gradcheck=True): "Numerical inconsistency",
90-
TestDescriptor(op_name="prod", gradcheck=True): "Numerical inconsistency",
91-
TestDescriptor(op_name="div", gradcheck=True): "Numerical inconsistency",
92-
TestDescriptor(op_name="expm1", gradcheck=True): "Numerical inconsistency",
93-
TestDescriptor(op_name="var", gradcheck=True): "Numerical inconsistency",
94-
TestDescriptor(op_name="bmm", gradcheck=True): "Numerical inconsistency",
95-
TestDescriptor(op_name="diagonal", gradcheck=True): "Numerical inconsistency",
96-
TestDescriptor(op_name="sinh", gradcheck=True): "Numerical inconsistency",
97-
TestDescriptor(op_name="abs", gradcheck=True): "Numerical inconsistency",
98-
TestDescriptor(op_name="sin", gradcheck=True): "Numerical inconsistency",
99-
TestDescriptor(op_name="atan", gradcheck=True): "Numerical inconsistency",
100-
TestDescriptor(op_name="acos", gradcheck=True): "Numerical inconsistency",
101-
TestDescriptor(op_name="acosh", gradcheck=True): "Numerical inconsistency",
102-
TestDescriptor(op_name="cos", gradcheck=True): "Numerical inconsistency",
103-
TestDescriptor(op_name="cosh", gradcheck=True): "Numerical inconsistency",
104-
TestDescriptor(op_name="addmm", gradcheck=True): "Numerical inconsistency",
105-
TestDescriptor(op_name="pow", gradcheck=True): "Numerical inconsistency",
106-
TestDescriptor(op_name="log1p", gradcheck=True): "Numerical inconsistency",
107-
TestDescriptor(op_name="tanh", gradcheck=True): "Numerical inconsistency",
108-
TestDescriptor(op_name="mm", gradcheck=True): "Numerical inconsistency",
109-
TestDescriptor(op_name="mul", gradcheck=True): "Numerical inconsistency",
110-
TestDescriptor(op_name="exp", gradcheck=True): "Numerical inconsistency",
86+
Descriptor(op=aten.randn_like, variant=None): "Non-deterministic output",
87+
Descriptor(op=aten.angle, variant=Variant.GradCheck): "Numerical inconsistency",
88+
Descriptor(op=aten.asinh, variant=Variant.GradCheck): "Numerical inconsistency",
89+
Descriptor(op=aten.atanh, variant=Variant.GradCheck): "Numerical inconsistency",
90+
Descriptor(op=aten.reciprocal, variant=Variant.GradCheck): "Numerical inconsistency",
91+
Descriptor(op=aten.rsqrt, variant=Variant.GradCheck): "Numerical inconsistency",
92+
Descriptor(op=aten.select, variant=Variant.GradCheck): "Numerical inconsistency",
93+
Descriptor(op=aten.asin, variant=Variant.GradCheck): "Numerical inconsistency",
94+
Descriptor(op=aten.log, variant=Variant.GradCheck): "Numerical inconsistency",
95+
Descriptor(op=aten.sgn, variant=Variant.GradCheck): "Numerical inconsistency",
96+
Descriptor(op=aten.cumprod, variant=Variant.GradCheck): "Numerical inconsistency",
97+
Descriptor(op=aten.slice, variant=Variant.GradCheck): "Numerical inconsistency",
98+
Descriptor(op=aten.sqrt, variant=Variant.GradCheck): "Numerical inconsistency",
99+
Descriptor(op=aten.tan, variant=Variant.GradCheck): "Numerical inconsistency",
100+
Descriptor(op=aten.true_divide, variant=Variant.GradCheck): "Numerical inconsistency",
101+
Descriptor(op=aten.prod, variant=Variant.GradCheck): "Numerical inconsistency",
102+
Descriptor(op=aten.div, variant=Variant.GradCheck): "Numerical inconsistency",
103+
Descriptor(op=aten.expm1, variant=Variant.GradCheck): "Numerical inconsistency",
104+
Descriptor(op=aten.var, variant=Variant.GradCheck): "Numerical inconsistency",
105+
Descriptor(op=aten.bmm, variant=Variant.GradCheck): "Numerical inconsistency",
106+
Descriptor(op=aten.diagonal, variant=Variant.GradCheck): "Numerical inconsistency",
107+
Descriptor(op=aten.sinh, variant=Variant.GradCheck): "Numerical inconsistency",
108+
Descriptor(op=aten.abs, variant=Variant.GradCheck): "Numerical inconsistency",
109+
Descriptor(op=aten.sin, variant=Variant.GradCheck): "Numerical inconsistency",
110+
Descriptor(op=aten.atan, variant=Variant.GradCheck): "Numerical inconsistency",
111+
Descriptor(op=aten.acos, variant=Variant.GradCheck): "Numerical inconsistency",
112+
Descriptor(op=aten.acosh, variant=Variant.GradCheck): "Numerical inconsistency",
113+
Descriptor(op=aten.cos, variant=Variant.GradCheck): "Numerical inconsistency",
114+
Descriptor(op=aten.cosh, variant=Variant.GradCheck): "Numerical inconsistency",
115+
Descriptor(op=aten.addmm, variant=Variant.GradCheck): "Numerical inconsistency",
116+
Descriptor(op=aten.pow, variant=Variant.GradCheck): "Numerical inconsistency",
117+
Descriptor(op=aten.log1p, variant=Variant.GradCheck): "Numerical inconsistency",
118+
Descriptor(op=aten.tanh, variant=Variant.GradCheck): "Numerical inconsistency",
119+
Descriptor(op=aten.mm, variant=Variant.GradCheck): "Numerical inconsistency",
120+
Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency",
121+
Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency",
122+
Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency",
111123
}
112124

113125
EXTRA_KWARGS = {
114-
TestDescriptor(op_name="asinh", dtype=torch.complex64, gradcheck=False): {
126+
Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Op): {
115127
"rtol": 2e-5,
116128
"atol": 5e-5,
117129
},
118-
TestDescriptor(op_name="tanh", dtype=torch.complex64, gradcheck=False): {
130+
Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Op): {
119131
"rtol": 1e-4,
120132
"atol": 1e-5,
121133
},
122-
TestDescriptor(op_name="pow", dtype=torch.complex64, gradcheck=False): {
134+
Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Op): {
123135
"rtol": 2e-2,
124136
"atol": 2e-6,
125137
},
@@ -140,8 +152,12 @@ def test_maybe_error(self, device, dtype, op: OpInfo, compile: bool):
140152
self.check_consistency(device, dtype, op, compile)
141153

142154
def check_consistency(self, device: torch.device, dtype, op: OpInfo, compile: bool) -> None:
143-
test_info = TestDescriptor(
144-
op_name=op.name, device=device, dtype=dtype, compile=compile, gradcheck=False
155+
test_info = Descriptor(
156+
op=get_overload_packet_from_name(op.name),
157+
device=device,
158+
dtype=dtype,
159+
compile=compile,
160+
variant=Variant.Op,
145161
)
146162
for xfail_info, reason in SKIPS.items():
147163
if xfail_info.matches(test_info):
@@ -175,8 +191,12 @@ def actual(subclass_sample=subclass_sample):
175191
class TestComplexBwdGradients(TestGradients):
176192
@ops(implemented_op_db, dtypes=OpDTypes.supported_backward, allowed_dtypes=[torch.complex128])
177193
def test_fn_grad(self, device: torch.device, dtype: torch.dtype, op: OpInfo) -> None:
178-
test_info = TestDescriptor(
179-
op_name=op.name, device=device, dtype=dtype, compile=False, gradcheck=True
194+
test_info = Descriptor(
195+
op=get_overload_packet_from_name(op.name),
196+
device=device,
197+
dtype=dtype,
198+
compile=False,
199+
variant=Variant.GradCheck,
180200
)
181201
for xfail_info, reason in SKIPS.items():
182202
if xfail_info.matches(test_info):

src/complex_tensor/test/utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
from collections.abc import Callable
44
from dataclasses import dataclass, field, fields
5+
from enum import Enum, auto
56
from typing import Any
67

78
import torch
9+
from torch._ops import OpOverloadPacket
810
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
911
from torch.utils._pytree import tree_flatten
1012

@@ -13,15 +15,20 @@
1315
COMPLEX_DTYPES = set(COMPLEX_TO_REAL)
1416

1517

18+
class Variant(Enum):
19+
Op = auto()
20+
GradCheck = auto()
21+
22+
1623
@dataclass(frozen=True, kw_only=True)
17-
class TestDescriptor:
18-
op_name: str | None = field(default=None)
24+
class Descriptor:
25+
op: OpOverloadPacket
26+
variant: Variant | None
1927
device: torch.device | None = field(default=None)
2028
dtype: torch.dtype | None = field(default=None)
2129
compile: bool | None = field(default=None)
22-
gradcheck: bool | None = field(default=None)
2330

24-
def matches(self, other: TestDescriptor) -> bool:
31+
def matches(self, other: Descriptor) -> bool:
2532
fields1 = fields(self)
2633
fields2 = fields(other)
2734
if fields1 != fields2:

0 commit comments

Comments
 (0)