diff --git a/src/complex_tensor/ops/_common.py b/src/complex_tensor/ops/_common.py
index 0ed4bd9..f6da649 100644
--- a/src/complex_tensor/ops/_common.py
+++ b/src/complex_tensor/ops/_common.py
@@ -183,7 +183,7 @@ def impl(
         out_flat = [ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False)]
         return tree_unflatten(out_flat, u_spec)
 
-    func_name = f"{str(op).split('.', 1)}_impl"
+    func_name = f"{str(op).split('.', 1)[1]}_impl"
     impl.__name__ = func_name
     impl.__qualname__ = func_name
 
diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py
index ff5f715..9b1f8e3 100644
--- a/src/complex_tensor/ops/aten.py
+++ b/src/complex_tensor/ops/aten.py
@@ -92,7 +92,7 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b
 ]
 
 for simple_op in SIMPLE_OPS_LIST:
-    globals()[f"{str(simple_op).split('.', 1)}_impl"] = register_simple(simple_op)
+    globals()[f"{str(simple_op).split('.', 1)[1]}_impl"] = register_simple(simple_op)
 
 # TODO (hameerabbasi): Not being tested
 SIMPLE_FORCE_TESTED_OPS = [
diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py
index f61d9f3..1ae830c 100644
--- a/src/complex_tensor/test/test_ops.py
+++ b/src/complex_tensor/test/test_ops.py
@@ -1,6 +1,9 @@
 from __future__ import annotations
 
+from typing import Any
+
 import torch
+import torch.distributed as dist
 from torch._ops import OpOverload, OpOverloadPacket
 from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
 from torch.testing._internal.common_methods_invocations import op_db
@@ -120,6 +123,54 @@ def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
     Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency",
     Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency",
     Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency",
+    Descriptor(
+        op=aten.any, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.all, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.allclose, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.conj_physical, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten._conj_physical, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.cumprod, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.index_add, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.diagonal_scatter, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.flip, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.masked_fill, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.masked_scatter, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.rsub, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.ne, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.squeeze, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(op=aten.index_select, variant=Variant.Distributed): "Sharding propagation failed",
+    Descriptor(op=aten.real, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.imag, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.isfinite, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.transpose, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.view_as_real, variant=Variant.Distributed): "No scalar support",
 }
 
 EXTRA_KWARGS = {
@@ -135,8 +186,38 @@ def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
         "rtol": 2e-2,
         "atol": 2e-6,
     },
+    Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 2e-5,
+        "atol": 5e-5,
+    },
+    Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 1e-4,
+        "atol": 1e-5,
+    },
+    Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 2e-2,
+        "atol": 2e-6,
+    },
+    Descriptor(op=aten.tan, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 2e-6,
+        "atol": 1e-2,
+    },
 }
 
+STORE = dist.HashStore()
+dist.init_process_group(store=STORE, rank=0, world_size=1)
+DEVICE_MESH = dist.init_device_mesh("cpu", mesh_shape=(1,))
+
+
+def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any:
+    if not isinstance(arg, torch.Tensor):
+        return arg
+
+    return dist.tensor.DTensor.from_local(_as_complex_tensor(arg), device_mesh=DEVICE_MESH)
+
+
+TRANSFORM_FUNCS = {Variant.Op: _as_complex_tensor, Variant.Distributed: _as_complex_dtensor}
+
 
 class TestComplexTensor(TestCase):
     _default_dtype_check_enabled = True
@@ -144,20 +225,26 @@ class TestComplexTensor(TestCase):
     @parametrize("compile", [False, True])
     @ops(implemented_op_db, dtypes=OpDTypes.supported, allowed_dtypes=list(COMPLEX_DTYPES))
     def test_consistency(self, device, dtype, op: OpInfo, compile: bool):
-        self.check_consistency(device, dtype, op, compile)
+        self.check_consistency(device, dtype, op, compile, Variant.Op)
 
     @parametrize("compile", [False, True])
     @ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
     def test_maybe_error(self, device, dtype, op: OpInfo, compile: bool):
-        self.check_consistency(device, dtype, op, compile)
+        self.check_consistency(device, dtype, op, compile, Variant.Op)
 
-    def check_consistency(self, device: torch.device, dtype, op: OpInfo, compile: bool) -> None:
+    @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
+    def test_distributed(self, device, dtype, op: OpInfo):
+        self.check_consistency(device, dtype, op, False, Variant.Distributed)
+
+    def check_consistency(
+        self, device: torch.device, dtype, op: OpInfo, compile: bool, variant: Variant
+    ) -> None:
         test_info = Descriptor(
             op=get_overload_packet_from_name(op.name),
             device=device,
             dtype=dtype,
             compile=compile,
-            variant=Variant.Op,
+            variant=variant,
         )
         for xfail_info, reason in SKIPS.items():
             if xfail_info.matches(test_info):
@@ -174,12 +261,14 @@ def check_consistency(self, device: torch.device, dtype, op: OpInfo, compile: bo
         if compile:
            op = torch.compile(op, fullgraph=True)
 
+        transform_fn = TRANSFORM_FUNCS[variant]
+
         for sample_input in sample_inputs:
 
            def expected(sample_input=sample_input):
                return op_eager(sample_input.input, *sample_input.args, **sample_input.kwargs)
 
-            subclass_sample = sample_input.transform(_as_complex_tensor)
+            subclass_sample = sample_input.transform(transform_fn)
 
            def actual(subclass_sample=subclass_sample):
                return op(subclass_sample.input, *subclass_sample.args, **subclass_sample.kwargs)
diff --git a/src/complex_tensor/test/utils.py b/src/complex_tensor/test/utils.py
index b74ceb7..7c003c0 100644
--- a/src/complex_tensor/test/utils.py
+++ b/src/complex_tensor/test/utils.py
@@ -6,6 +6,7 @@ from typing import Any
 
 import torch
+import torch.distributed as dist
 from torch._ops import OpOverloadPacket
 from torch.testing._internal.common_utils import TestCase as PytorchTestCase
 from torch.utils._pytree import tree_flatten
@@ -18,6 +19,14 @@ class Variant(Enum):
     Op = auto()
     GradCheck = auto()
+    Distributed = auto()
+
+
+def _as_local(arg: dist.tensor.DTensor | Any) -> torch.Tensor | Any:
+    if not isinstance(arg, dist.tensor.DTensor):
+        return arg
+
+    return arg.full_tensor()
 
 
 @dataclass(frozen=True, kw_only=True)
@@ -79,7 +88,7 @@ def assertSameResult(
             spec_e, spec_a, "Both functions must return a result with the same tree structure."
         )
         for value_e, value_a in zip(flattened_e, flattened_a, strict=True):
-            value_e = _as_interleaved(value_e)
-            value_a = _as_interleaved(value_a)
+            value_e = _as_interleaved(_as_local(value_e))
+            value_a = _as_interleaved(_as_local(value_a))
             self.assertEqual(value_e, value_a, *args, **kwargs)
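
Note (not part of the diff): a minimal sketch of the path the new Variant.Distributed tests exercise, assuming a single-rank CPU mesh as in the test setup above. The `from complex_tensor import ComplexTensor` import path and the direct real/imag construction are assumptions for illustration, not APIs confirmed by this diff.

# Rough sketch of the DTensor-over-ComplexTensor round trip checked by test_distributed.
import torch
import torch.distributed as dist

from complex_tensor import ComplexTensor  # assumed export location

# Single-rank process group and 1-element CPU mesh, mirroring the test module setup.
store = dist.HashStore()
dist.init_process_group(store=store, rank=0, world_size=1)
mesh = dist.init_device_mesh("cpu", mesh_shape=(1,))

z = torch.randn(4, dtype=torch.complex64)
local = ComplexTensor(z.real, z.imag)  # two-tensor construction, as in ops/_common.py
dz = dist.tensor.DTensor.from_local(local, device_mesh=mesh)

out = torch.mul(dz, dz)   # dispatches through DTensor, then the ComplexTensor kernels
print(out.full_tensor())  # unwrap to a local tensor, as _as_local does before comparison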