35 changes: 30 additions & 5 deletions src/complex_tensor/test/test_ops.py
@@ -1,6 +1,9 @@
 from __future__ import annotations
 
+from typing import Any
+
 import torch
+import torch.distributed as dist
 from torch._ops import OpOverload, OpOverloadPacket
 from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
 from torch.testing._internal.common_methods_invocations import op_db
@@ -137,27 +140,47 @@ def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
     },
 }
 
+STORE = dist.HashStore()
+dist.init_process_group(store=STORE, rank=0, world_size=1)
+DEVICE_MESH = dist.init_device_mesh("cpu", mesh_shape=(1,))
+
+
+def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any:
+    if not isinstance(arg, torch.Tensor):
+        return arg
+
+    return dist.tensor.DTensor.from_local(_as_complex_tensor(arg), device_mesh=DEVICE_MESH)
+
+
+TRANSFORM_FUNCS = {Variant.Op: _as_complex_tensor, Variant.Distributed: _as_complex_dtensor}
+
 
 class TestComplexTensor(TestCase):
     _default_dtype_check_enabled = True
 
     @parametrize("compile", [False, True])
     @ops(implemented_op_db, dtypes=OpDTypes.supported, allowed_dtypes=list(COMPLEX_DTYPES))
     def test_consistency(self, device, dtype, op: OpInfo, compile: bool):
-        self.check_consistency(device, dtype, op, compile)
+        self.check_consistency(device, dtype, op, compile, Variant.Op)
 
     @parametrize("compile", [False, True])
     @ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
     def test_maybe_error(self, device, dtype, op: OpInfo, compile: bool):
-        self.check_consistency(device, dtype, op, compile)
+        self.check_consistency(device, dtype, op, compile, Variant.Op)
 
-    def check_consistency(self, device: torch.device, dtype, op: OpInfo, compile: bool) -> None:
+    @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
+    def test_distributed(self, device, dtype, op: OpInfo):
+        self.check_consistency(device, dtype, op, False, Variant.Distributed)
+
+    def check_consistency(
+        self, device: torch.device, dtype, op: OpInfo, compile: bool, variant: Variant
+    ) -> None:
         test_info = Descriptor(
             op=get_overload_packet_from_name(op.name),
             device=device,
             dtype=dtype,
             compile=compile,
-            variant=Variant.Op,
+            variant=variant,
         )
         for xfail_info, reason in SKIPS.items():
             if xfail_info.matches(test_info):
@@ -174,12 +197,14 @@ def check_consistency(self, device: torch.device, dtype, op: OpInfo, compile: bo
         if compile:
             op = torch.compile(op, fullgraph=True)
 
+        transform_fn = TRANSFORM_FUNCS[variant]
+
         for sample_input in sample_inputs:
 
             def expected(sample_input=sample_input):
                 return op_eager(sample_input.input, *sample_input.args, **sample_input.kwargs)
 
-            subclass_sample = sample_input.transform(_as_complex_tensor)
+            subclass_sample = sample_input.transform(transform_fn)
 
             def actual(subclass_sample=subclass_sample):
                 return op(subclass_sample.input, *subclass_sample.args, **subclass_sample.kwargs)
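For context, a minimal standalone sketch of the single-rank fixture that the module-level `HashStore` / `init_device_mesh` setup above creates, and of the `DTensor.from_local` wrapping that `_as_complex_dtensor` performs. It uses a plain complex tensor rather than the repo's complex-tensor subclass, and the variable names are illustrative only:

```python
import torch
import torch.distributed as dist

# In-memory rendezvous: a HashStore plus world_size=1 lets the distributed
# variant run inside an ordinary single-process test run.
store = dist.HashStore()
dist.init_process_group(store=store, rank=0, world_size=1)

# A one-element CPU mesh; DTensor.from_local() without explicit placements
# treats the local tensor as replicated across this mesh.
mesh = dist.init_device_mesh("cpu", mesh_shape=(1,))

x = torch.randn(3, 3, dtype=torch.complex64)
dx = dist.tensor.DTensor.from_local(x, device_mesh=mesh)
print(type(dx).__name__)  # DTensor

dist.destroy_process_group()
```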
13 changes: 11 additions & 2 deletions src/complex_tensor/test/utils.py
@@ -6,6 +6,7 @@
 from typing import Any
 
 import torch
+import torch.distributed as dist
 from torch._ops import OpOverloadPacket
 from torch.testing._internal.common_utils import TestCase as PytorchTestCase
 from torch.utils._pytree import tree_flatten
@@ -18,6 +19,14 @@
 class Variant(Enum):
     Op = auto()
     GradCheck = auto()
+    Distributed = auto()
+
+
+def _as_local(arg: dist.tensor.DTensor | Any) -> torch.Tensor | Any:
+    if not isinstance(arg, dist.tensor.DTensor):
+        return arg
+
+    return arg.full_tensor()
 
 
 @dataclass(frozen=True, kw_only=True)
@@ -79,7 +88,7 @@ def assertSameResult(
             spec_e, spec_a, "Both functions must return a result with the same tree structure."
         )
         for value_e, value_a in zip(flattened_e, flattened_a, strict=True):
-            value_e = _as_interleaved(value_e)
-            value_a = _as_interleaved(value_a)
+            value_e = _as_interleaved(_as_local(value_e))
+            value_a = _as_interleaved(_as_local(value_a))
 
             self.assertEqual(value_e, value_a, *args, **kwargs)
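
And a sketch of why `assertSameResult` now routes values through `_as_local`: results produced under the `Distributed` variant are `DTensor`s, so they have to be gathered back into ordinary tensors via `full_tensor()` before comparing against the eager reference. The setup is repeated here (assumed single-rank CPU mesh) so the snippet is self-contained:

```python
import torch
import torch.distributed as dist

store = dist.HashStore()
dist.init_process_group(store=store, rank=0, world_size=1)
mesh = dist.init_device_mesh("cpu", mesh_shape=(1,))

x = torch.randn(4, dtype=torch.complex64)
expected = torch.sin(x)                                    # eager reference result

dx = dist.tensor.DTensor.from_local(x, device_mesh=mesh)   # distributed-variant input
actual = torch.sin(dx)                                     # still a DTensor

# full_tensor() is what _as_local() calls: it collects the (here trivially
# replicated) shards into an ordinary torch.Tensor for comparison.
torch.testing.assert_close(actual.full_tensor(), expected)

dist.destroy_process_group()
```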