Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions grassmann_tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,19 @@ def __itruediv__(self, other: typing.Any) -> GrassmannTensor:
return self
return NotImplemented

def __matmul__(self, other: typing.Any) -> GrassmannTensor:
    """Implement the binary ``@`` operator.

    Delegates to :meth:`matmul` when *other* is a :class:`GrassmannTensor`.
    For any other operand type, returns ``NotImplemented`` so the
    interpreter can try the reflected operation or raise ``TypeError``.
    """
    if not isinstance(other, GrassmannTensor):
        return NotImplemented
    return self.matmul(other)

def __rmatmul__(self, other: typing.Any) -> GrassmannTensor:
    # Reflected ``@`` is never supported: the GrassmannTensor-on-the-left
    # case is already handled by __matmul__, so reaching __rmatmul__ means
    # the left operand is not a GrassmannTensor.  Always returning
    # NotImplemented makes ``other @ tensor`` raise TypeError for any such
    # operand (the unsupported-type test below relies on this).
    return NotImplemented

def __imatmul__(self, other: typing.Any) -> GrassmannTensor:
    """Implement the in-place ``@=`` operator.

    Delegates to :meth:`matmul` for GrassmannTensor operands — note this
    returns matmul's result object, which Python rebinds to the left-hand
    name (presumably a new tensor rather than a true in-place mutation —
    matmul's own contract governs this).  Any other operand type yields
    ``NotImplemented``, so ``tensor @= other`` raises ``TypeError``.
    """
    if not isinstance(other, GrassmannTensor):
        return NotImplemented
    return self.matmul(other)

def clone(self) -> GrassmannTensor:
"""
Create a deep copy of the Grassmann tensor.
Expand Down
249 changes: 206 additions & 43 deletions tests/matmul_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
import pytest
import torch
import typing
from grassmann_tensor import GrassmannTensor

# Batch shapes for a broadcasted matmul case:
# (batch dims of a, batch dims of b, batch dims of the broadcast result).
Broadcast = tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]
# One matmul case: (arrow_a, arrow_b, edge_a, edge_common, edge_b),
# where each edge is an (even, odd) dimension pair.
MatmulCase = tuple[bool, bool, tuple[int, int], tuple[int, int], tuple[int, int]]


@pytest.mark.parametrize("a_is_vector", [False, True])
@pytest.mark.parametrize("b_is_vector", [False, True])
@pytest.mark.parametrize("normal_arrow_order", [False, True])
@pytest.mark.parametrize(
"broadcast",
[
@pytest.fixture(params=[False, True])
def a_is_vector(request: pytest.FixtureRequest) -> bool:
    """Whether operand ``a`` is built as a vector (one edge) instead of a matrix."""
    return request.param


@pytest.fixture(params=[False, True])
def b_is_vector(request: pytest.FixtureRequest) -> bool:
    """Whether operand ``b`` is built as a vector (one edge) instead of a matrix."""
    return request.param


@pytest.fixture(params=[False, True])
def normal_arrow_order(request: pytest.FixtureRequest) -> bool:
    """Arrow orientation of the contracted edge pair; False flips both arrows."""
    return request.param


@pytest.fixture(
params=[
((), (), ()),
((2,), (), (2,)),
((), (3,), (3,)),
Expand All @@ -27,9 +39,12 @@
((7, 8), (7, 1), (7, 8)),
],
)
@pytest.mark.parametrize(
"x",
[
def broadcast(request: pytest.FixtureRequest) -> Broadcast:
return request.param


@pytest.fixture(
params=[
(False, False, (1, 1), (1, 1), (1, 1)),
(False, True, (1, 1), (1, 1), (1, 1)),
(True, False, (1, 1), (1, 1), (1, 1)),
Expand All @@ -40,6 +55,10 @@
(True, True, (2, 2), (2, 2), (2, 2)),
],
)
def x(request: pytest.FixtureRequest) -> MatmulCase:
return request.param


def test_matmul(
a_is_vector: bool,
b_is_vector: bool,
Expand Down Expand Up @@ -101,40 +120,6 @@ def test_matmul(
assert torch.allclose(c.tensor, expected)


@pytest.mark.parametrize("a_is_vector", [False, True])
@pytest.mark.parametrize("b_is_vector", [False, True])
@pytest.mark.parametrize("normal_arrow_order", [False, True])
@pytest.mark.parametrize(
"broadcast",
[
((), (), ()),
((2,), (), (2,)),
((), (3,), (3,)),
((1,), (4,), (4,)),
((5,), (1,), (5,)),
((6,), (6,), (6,)),
((7, 8), (7, 8), (7, 8)),
((1, 8), (7, 8), (7, 8)),
((8,), (7, 8), (7, 8)),
((7, 1), (7, 8), (7, 8)),
((7, 8), (1, 8), (7, 8)),
((7, 8), (8,), (7, 8)),
((7, 8), (7, 1), (7, 8)),
],
)
@pytest.mark.parametrize(
"x",
[
(False, False, (1, 1), (1, 1), (1, 1)),
(False, True, (1, 1), (1, 1), (1, 1)),
(True, False, (1, 1), (1, 1), (1, 1)),
(True, True, (1, 1), (1, 1), (1, 1)),
(False, False, (2, 2), (2, 2), (2, 2)),
(False, True, (2, 2), (2, 2), (2, 2)),
(True, False, (2, 2), (2, 2), (2, 2)),
(True, True, (2, 2), (2, 2), (2, 2)),
],
)
@pytest.mark.parametrize("impure_even_for_broadcast_indices", [1, 2])
def test_matmul_unpure_even(
a_is_vector: bool,
Expand Down Expand Up @@ -191,3 +176,181 @@ def test_matmul_unpure_even(
pytest.skip("One of the two tensors needs to have a dimension greater than 2")
with pytest.raises(AssertionError, match="All edges except the last two must be pure even"):
_ = a.matmul(b)


def test_matmul_operator_matmul(
    a_is_vector: bool,
    b_is_vector: bool,
    normal_arrow_order: bool,
    broadcast: Broadcast,
    x: MatmulCase,
) -> None:
    """``a @ b`` must produce the same data as ``torch.matmul`` (up to signs).

    Covers vector/matrix operands on either side, both arrow orders, and a
    range of batch-broadcast shapes, then checks the result's arrow, edges
    and raw tensor values.
    """
    broadcast_a, broadcast_b, broadcast_result = broadcast
    arrow_a, arrow_b, edge_a, edge_common, edge_b = x
    # Vector operands carry no batch dimensions in these cases.
    if a_is_vector and broadcast_a != ():
        pytest.skip("Vector a cannot be broadcasted")
    if b_is_vector and broadcast_b != ():
        pytest.skip("Vector b cannot be broadcasted")
    dim_a = sum(edge_a)
    dim_common = sum(edge_common)
    dim_b = sum(edge_b)
    # Batch edges are pure even ((i, 0)).  The contracted edge's arrows are
    # chosen opposite on a and b; normal_arrow_order selects which way round.
    if a_is_vector:
        a = GrassmannTensor(
            (*(False for _ in broadcast_a), True if normal_arrow_order else False),
            (*((i, 0) for i in broadcast_a), edge_common),
            torch.randn([*broadcast_a, dim_common]),
        ).update_mask()
    else:
        a = GrassmannTensor(
            (*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False),
            (*((i, 0) for i in broadcast_a), edge_a, edge_common),
            torch.randn([*broadcast_a, dim_a, dim_common]),
        ).update_mask()
    if b_is_vector:
        b = GrassmannTensor(
            (*(False for _ in broadcast_b), False if normal_arrow_order else True),
            (*((i, 0) for i in broadcast_b), edge_common),
            torch.randn([*broadcast_b, dim_common]),
        ).update_mask()
    else:
        b = GrassmannTensor(
            (*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b),
            (*((i, 0) for i in broadcast_b), edge_common, edge_b),
            torch.randn([*broadcast_b, dim_common, dim_b]),
        ).update_mask()
    c = a @ b
    expected = a.tensor.matmul(b.tensor)
    # NOTE(review): for matrix @ matrix with the non-normal arrow order the
    # contraction is expected to pick up a Grassmann sign on the odd-odd
    # block of the result — confirm this matches matmul's sign convention.
    if not a_is_vector and not b_is_vector and not normal_arrow_order:
        expected[..., edge_a[0] :, edge_b[0] :] *= -1
    # Vector operands drop their edge (and its arrow) from the result; batch
    # edges of the result are pure even over the broadcast shape.
    if a_is_vector:
        if b_is_vector:
            assert c.arrow == tuple(False for _ in broadcast_result)
            assert c.edges == tuple((i, 0) for i in broadcast_result)
        else:
            assert c.arrow == (*(False for _ in broadcast_result), arrow_b)
            assert c.edges == (*((i, 0) for i in broadcast_result), edge_b)
    else:
        if b_is_vector:
            assert c.arrow == (*(False for _ in broadcast_result), arrow_a)
            assert c.edges == (*((i, 0) for i in broadcast_result), edge_a)
        else:
            assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b)
            assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b)
    assert torch.allclose(c.tensor, expected)


@pytest.fixture(
    params=[
        # Pair of rank-2 tensors (plain matrices).
        (
            GrassmannTensor((False, False), ((2, 2), (1, 3)), torch.randn([4, 4])),
            GrassmannTensor((False, False), ((2, 2), (1, 3)), torch.randn([4, 4])),
        ),
        # Pair of rank-3 tensors (one batch dimension).
        (
            GrassmannTensor((True, False, True), ((1, 1), (2, 2), (3, 1)), torch.randn([2, 4, 4])),
            GrassmannTensor((True, False, True), ((1, 1), (2, 2), (3, 1)), torch.randn([2, 4, 4])),
        ),
        # Pair of rank-4 tensors (two batch dimensions).
        (
            GrassmannTensor(
                (True, True, False, False),
                ((1, 2), (2, 2), (1, 1), (3, 1)),
                torch.randn([3, 4, 2, 4]),
            ),
            GrassmannTensor(
                (True, True, False, False),
                ((1, 2), (2, 2), (1, 1), (3, 1)),
                torch.randn([3, 4, 2, 4]),
            ),
        ),
    ]
)
def tensors(request: pytest.FixtureRequest) -> tuple[GrassmannTensor, GrassmannTensor]:
    """A pair of identically-shaped GrassmannTensors of rank 2, 3 or 4."""
    return request.param


@pytest.mark.parametrize(
    "unsupported_type",
    [
        "string",  # str
        None,  # NoneType
        {"key": "value"},  # dict (was a set literal {"key", "value"} mis-commented as dict)
        [1, 2, 3],  # list
        {1, 2},  # set
        object(),  # arbitrary object
    ],
)
def test_matmul_unsupported_type_raises_typeerror(
    unsupported_type: typing.Any,
    tensors: tuple[GrassmannTensor, GrassmannTensor],
) -> None:
    """``@`` with a non-GrassmannTensor operand must raise ``TypeError``.

    All three operator paths are checked: ``__matmul__``, ``__rmatmul__``
    and ``__imatmul__`` each return ``NotImplemented`` for foreign types,
    which the interpreter turns into a ``TypeError``.
    """
    tensor_a, _ = tensors

    # Left operand is the GrassmannTensor: __matmul__ rejects the type.
    with pytest.raises(TypeError):
        _ = tensor_a @ unsupported_type

    # Right operand is the GrassmannTensor: __rmatmul__ always rejects.
    with pytest.raises(TypeError):
        _ = unsupported_type @ tensor_a

    # Augmented assignment: __imatmul__ rejects, tensor_a stays bound.
    with pytest.raises(TypeError):
        tensor_a @= unsupported_type


def test_matmul_operator_rmatmul(
    a_is_vector: bool,
    b_is_vector: bool,
    normal_arrow_order: bool,
    broadcast: Broadcast,
    x: MatmulCase,
) -> None:
    """``c @= b`` must match ``torch.matmul`` on the raw data (up to signs).

    NOTE(review): despite the "rmatmul" in its name, this test exercises the
    augmented assignment ``@=`` — i.e. ``__imatmul__``, not ``__rmatmul__``
    (which unconditionally returns NotImplemented).  Consider renaming to
    ``test_matmul_operator_imatmul``.
    """
    broadcast_a, broadcast_b, broadcast_result = broadcast
    arrow_a, arrow_b, edge_a, edge_common, edge_b = x
    # Vector operands carry no batch dimensions in these cases.
    if a_is_vector and broadcast_a != ():
        pytest.skip("Vector a cannot be broadcasted")
    if b_is_vector and broadcast_b != ():
        pytest.skip("Vector b cannot be broadcasted")
    dim_a = sum(edge_a)
    dim_common = sum(edge_common)
    dim_b = sum(edge_b)
    # Batch edges are pure even ((i, 0)).  The contracted edge's arrows are
    # chosen opposite on a and b; normal_arrow_order selects which way round.
    if a_is_vector:
        a = GrassmannTensor(
            (*(False for _ in broadcast_a), True if normal_arrow_order else False),
            (*((i, 0) for i in broadcast_a), edge_common),
            torch.randn([*broadcast_a, dim_common]),
        ).update_mask()
    else:
        a = GrassmannTensor(
            (*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False),
            (*((i, 0) for i in broadcast_a), edge_a, edge_common),
            torch.randn([*broadcast_a, dim_a, dim_common]),
        ).update_mask()
    if b_is_vector:
        b = GrassmannTensor(
            (*(False for _ in broadcast_b), False if normal_arrow_order else True),
            (*((i, 0) for i in broadcast_b), edge_common),
            torch.randn([*broadcast_b, dim_common]),
        ).update_mask()
    else:
        b = GrassmannTensor(
            (*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b),
            (*((i, 0) for i in broadcast_b), edge_common, edge_b),
            torch.randn([*broadcast_b, dim_common, dim_b]),
        ).update_mask()
    # ``@=`` rebinds c to the matmul result; a itself must stay usable below,
    # which the later a.tensor reference relies on.
    c = a
    c @= b
    expected = a.tensor.matmul(b.tensor)
    # NOTE(review): for matrix @ matrix with the non-normal arrow order the
    # contraction is expected to pick up a Grassmann sign on the odd-odd
    # block of the result — confirm this matches matmul's sign convention.
    if not a_is_vector and not b_is_vector and not normal_arrow_order:
        expected[..., edge_a[0] :, edge_b[0] :] *= -1
    # Vector operands drop their edge (and its arrow) from the result; batch
    # edges of the result are pure even over the broadcast shape.
    if a_is_vector:
        if b_is_vector:
            assert c.arrow == tuple(False for _ in broadcast_result)
            assert c.edges == tuple((i, 0) for i in broadcast_result)
        else:
            assert c.arrow == (*(False for _ in broadcast_result), arrow_b)
            assert c.edges == (*((i, 0) for i in broadcast_result), edge_b)
    else:
        if b_is_vector:
            assert c.arrow == (*(False for _ in broadcast_result), arrow_a)
            assert c.edges == (*((i, 0) for i in broadcast_result), edge_a)
        else:
            assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b)
            assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b)
    assert torch.allclose(c.tensor, expected)