Skip to content

Commit 4eb3d3d

Browse files
committed
dev(matmul): add support for matmul operator
- Add support for matmul operators `@` and `@=`
- Add test cases for the matmul operators
- Reduce redundant code

Signed-off-by: Gausshj <[email protected]>
1 parent b712ec3 commit 4eb3d3d

File tree

2 files changed

+137
-43
lines changed

2 files changed

+137
-43
lines changed

grassmann_tensor/tensor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,19 @@ def __itruediv__(self, other: typing.Any) -> GrassmannTensor:
756756
return self
757757
return NotImplemented
758758

759+
def __matmul__(self, other: typing.Any) -> GrassmannTensor:
    """Implement the binary ``@`` operator by delegating to :meth:`matmul`.

    Returns ``NotImplemented`` for non-``GrassmannTensor`` operands so that
    Python can fall back to the right operand's reflected handler.
    """
    if not isinstance(other, GrassmannTensor):
        return NotImplemented
    return self.matmul(other)
763+
764+
def __rmatmul__(self, other: typing.Any) -> GrassmannTensor:
    """Reflected ``@``: unconditionally defer.

    Reaching this handler means the left operand did not know how to matmul
    with a ``GrassmannTensor``; no conversion is attempted, so Python turns
    the ``NotImplemented`` into a ``TypeError``.
    """
    return NotImplemented
766+
767+
def __imatmul__(self, other: typing.Any) -> GrassmannTensor:
    """Implement the in-place ``@=`` operator.

    The product is produced via :meth:`matmul` and returned, so ``@=``
    rebinds the left-hand name rather than mutating the original tensor.
    """
    if not isinstance(other, GrassmannTensor):
        return NotImplemented
    return self.matmul(other)
771+
759772
def clone(self) -> GrassmannTensor:
760773
"""
761774
Create a deep copy of the Grassmann tensor.

tests/matmul_test.py

Lines changed: 124 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,29 @@
11
import pytest
22
import torch
3+
import typing
34
from grassmann_tensor import GrassmannTensor
45

56
Broadcast = tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]
67
MatmulCase = tuple[bool, bool, tuple[int, int], tuple[int, int], tuple[int, int]]
78

89

9-
@pytest.mark.parametrize("a_is_vector", [False, True])
10-
@pytest.mark.parametrize("b_is_vector", [False, True])
11-
@pytest.mark.parametrize("normal_arrow_order", [False, True])
12-
@pytest.mark.parametrize(
13-
"broadcast",
14-
[
10+
@pytest.fixture(params=[False, True])
def a_is_vector(request: pytest.FixtureRequest) -> bool:
    """Parametrized flag: whether operand ``a`` is built as a vector."""
    return request.param
13+
14+
15+
@pytest.fixture(params=[False, True])
def b_is_vector(request: pytest.FixtureRequest) -> bool:
    """Parametrized flag: whether operand ``b`` is built as a vector."""
    return request.param
18+
19+
20+
@pytest.fixture(params=[False, True])
def normal_arrow_order(request: pytest.FixtureRequest) -> bool:
    """Parametrized flag: orientation of the contracted arrow pair."""
    return request.param
23+
24+
25+
@pytest.fixture(
26+
params=[
1527
((), (), ()),
1628
((2,), (), (2,)),
1729
((), (3,), (3,)),
@@ -27,9 +39,12 @@
2739
((7, 8), (7, 1), (7, 8)),
2840
],
2941
)
30-
@pytest.mark.parametrize(
31-
"x",
32-
[
42+
def broadcast(request: pytest.FixtureRequest) -> Broadcast:
43+
return request.param
44+
45+
46+
@pytest.fixture(
47+
params=[
3348
(False, False, (1, 1), (1, 1), (1, 1)),
3449
(False, True, (1, 1), (1, 1), (1, 1)),
3550
(True, False, (1, 1), (1, 1), (1, 1)),
@@ -40,6 +55,10 @@
4055
(True, True, (2, 2), (2, 2), (2, 2)),
4156
],
4257
)
58+
def x(request: pytest.FixtureRequest) -> MatmulCase:
59+
return request.param
60+
61+
4362
def test_matmul(
4463
a_is_vector: bool,
4564
b_is_vector: bool,
@@ -101,40 +120,6 @@ def test_matmul(
101120
assert torch.allclose(c.tensor, expected)
102121

103122

104-
@pytest.mark.parametrize("a_is_vector", [False, True])
105-
@pytest.mark.parametrize("b_is_vector", [False, True])
106-
@pytest.mark.parametrize("normal_arrow_order", [False, True])
107-
@pytest.mark.parametrize(
108-
"broadcast",
109-
[
110-
((), (), ()),
111-
((2,), (), (2,)),
112-
((), (3,), (3,)),
113-
((1,), (4,), (4,)),
114-
((5,), (1,), (5,)),
115-
((6,), (6,), (6,)),
116-
((7, 8), (7, 8), (7, 8)),
117-
((1, 8), (7, 8), (7, 8)),
118-
((8,), (7, 8), (7, 8)),
119-
((7, 1), (7, 8), (7, 8)),
120-
((7, 8), (1, 8), (7, 8)),
121-
((7, 8), (8,), (7, 8)),
122-
((7, 8), (7, 1), (7, 8)),
123-
],
124-
)
125-
@pytest.mark.parametrize(
126-
"x",
127-
[
128-
(False, False, (1, 1), (1, 1), (1, 1)),
129-
(False, True, (1, 1), (1, 1), (1, 1)),
130-
(True, False, (1, 1), (1, 1), (1, 1)),
131-
(True, True, (1, 1), (1, 1), (1, 1)),
132-
(False, False, (2, 2), (2, 2), (2, 2)),
133-
(False, True, (2, 2), (2, 2), (2, 2)),
134-
(True, False, (2, 2), (2, 2), (2, 2)),
135-
(True, True, (2, 2), (2, 2), (2, 2)),
136-
],
137-
)
138123
@pytest.mark.parametrize("impure_even_for_broadcast_indices", [1, 2])
139124
def test_matmul_unpure_even(
140125
a_is_vector: bool,
@@ -191,3 +176,99 @@ def test_matmul_unpure_even(
191176
pytest.skip("One of the two tensors needs to have a dimension greater than 2")
192177
with pytest.raises(AssertionError, match="All edges except the last two must be pure even"):
193178
_ = a.matmul(b)
179+
180+
181+
def test_matmul_operator_matmul() -> None:
    """``a @ b`` delegates to ``matmul`` and broadcasts the leading edges.

    Fix: the previous signature requested the ``a_is_vector``/``b_is_vector``/
    ``normal_arrow_order``/``broadcast`` fixtures but never used them —
    ``normal_arrow_order`` was immediately overwritten and the broadcast
    shapes were hard-coded — so the identical case re-ran once per fixture
    combination. The unused parameters are removed; the tested case is the
    same single configuration as before.
    """
    normal_arrow_order = True
    broadcast_a, broadcast_b, broadcast_result = (7, 8), (7, 1), (7, 8)
    arrow_a, arrow_b, edge_a, edge_common, edge_b = True, True, (2, 2), (2, 2), (2, 2)
    dim_a = sum(edge_a)
    dim_common = sum(edge_common)
    dim_b = sum(edge_b)
    a = GrassmannTensor(
        (*(False for _ in broadcast_a), arrow_a, normal_arrow_order),
        (*((i, 0) for i in broadcast_a), edge_a, edge_common),
        torch.randn([*broadcast_a, dim_a, dim_common]),
    ).update_mask()

    b = GrassmannTensor(
        (*(False for _ in broadcast_b), not normal_arrow_order, arrow_b),
        (*((i, 0) for i in broadcast_b), edge_common, edge_b),
        torch.randn([*broadcast_b, dim_common, dim_b]),
    ).update_mask()

    c = a @ b
    expected = a.tensor.matmul(b.tensor)
    assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b)
    assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b)
    assert torch.allclose(c.tensor, expected)
210+
211+
212+
@pytest.fixture(
    params=[
        GrassmannTensor((False, False), ((2, 2), (1, 3)), torch.randn([4, 4])),
        GrassmannTensor((True, False, True), ((1, 1), (2, 2), (3, 1)), torch.randn([2, 4, 4])),
        GrassmannTensor(
            (True, True, False, False), ((1, 2), (2, 2), (1, 1), (3, 1)), torch.randn([3, 4, 2, 4])
        ),
    ]
)
def tensors(request: pytest.FixtureRequest) -> GrassmannTensor:
    """Parametrized sample Grassmann tensors of rank 2, 3, and 4.

    NOTE: the ``params`` list is evaluated once at module import / collection
    time, so the same random tensors are shared by every test using this
    fixture within a session.
    """
    return request.param
223+
224+
225+
@pytest.mark.parametrize(
    "unsupported_type",
    [
        "string",  # string
        None,  # NoneType
        {"key": "value"},  # dict (was a set literal mislabeled as dict)
        [1, 2, 3],  # list
        {1, 2},  # set
        object(),  # arbitrary object
    ],
)
def test_matmul_unsupported_type_raises_typeerror(
    unsupported_type: typing.Any,
    tensors: GrassmannTensor,
) -> None:
    """``@`` and ``@=`` with a non-GrassmannTensor operand raise TypeError.

    Fix: the "dict" case was actually the set literal ``{"key", "value"}``,
    duplicating the ``{1, 2}`` set case and leaving dicts untested; it is now
    a real dict. The operator hooks return ``NotImplemented`` for all of
    these, so Python's dispatch raises ``TypeError``.
    """
    with pytest.raises(TypeError):
        _ = tensors @ unsupported_type

    with pytest.raises(TypeError):
        _ = unsupported_type @ tensors

    with pytest.raises(TypeError):
        tensors @= unsupported_type
248+
249+
250+
def test_matmul_operator_rmatmul() -> None:
    """Exercise the in-place ``@=`` operator on broadcastable tensors.

    NOTE(review): despite the name, this covers ``__imatmul__`` (``@=``),
    not ``__rmatmul__`` — consider renaming in a follow-up.
    """
    normal_arrow_order = True
    broadcast_a, broadcast_b, broadcast_result = (7, 8), (7, 1), (7, 8)
    arrow_a, arrow_b = True, True
    edge_a, edge_common, edge_b = (2, 2), (2, 2), (2, 2)
    dim_a, dim_common, dim_b = sum(edge_a), sum(edge_common), sum(edge_b)

    lhs = GrassmannTensor(
        (*(False for _ in broadcast_a), arrow_a, normal_arrow_order),
        (*((i, 0) for i in broadcast_a), edge_a, edge_common),
        torch.randn([*broadcast_a, dim_a, dim_common]),
    ).update_mask()

    rhs = GrassmannTensor(
        (*(False for _ in broadcast_b), not normal_arrow_order, arrow_b),
        (*((i, 0) for i in broadcast_b), edge_common, edge_b),
        torch.randn([*broadcast_b, dim_common, dim_b]),
    ).update_mask()

    result = lhs
    result @= rhs  # rebinds `result` to the product; `lhs` still names the original operand
    expected = lhs.tensor.matmul(rhs.tensor)
    assert result.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b)
    assert result.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b)
    assert torch.allclose(result.tensor, expected)

0 commit comments

Comments
 (0)