|  | 
| 1 | 1 | import pytest | 
| 2 | 2 | import torch | 
|  | 3 | +import typing | 
| 3 | 4 | from grassmann_tensor import GrassmannTensor | 
| 4 | 5 | 
 | 
| 5 | 6 | Broadcast = tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]] | 
| 6 | 7 | MatmulCase = tuple[bool, bool, tuple[int, int], tuple[int, int], tuple[int, int]] | 
| 7 | 8 | 
 | 
| 8 | 9 | 
 | 
| 9 |  | -@pytest.mark.parametrize("a_is_vector", [False, True]) | 
| 10 |  | -@pytest.mark.parametrize("b_is_vector", [False, True]) | 
| 11 |  | -@pytest.mark.parametrize("normal_arrow_order", [False, True]) | 
| 12 |  | -@pytest.mark.parametrize( | 
| 13 |  | -    "broadcast", | 
| 14 |  | -    [ | 
@pytest.fixture(params=[False, True])
def a_is_vector(request: pytest.FixtureRequest) -> bool:
    """Parametrized flag: treat operand ``a`` as a vector instead of a matrix."""
    return typing.cast(bool, request.param)
|  | 13 | + | 
|  | 14 | + | 
@pytest.fixture(params=[False, True])
def b_is_vector(request: pytest.FixtureRequest) -> bool:
    """Parametrized flag: treat operand ``b`` as a vector instead of a matrix."""
    return typing.cast(bool, request.param)
|  | 18 | + | 
|  | 19 | + | 
@pytest.fixture(params=[False, True])
def normal_arrow_order(request: pytest.FixtureRequest) -> bool:
    """Parametrized flag selecting the arrow orientation of the contracted edge."""
    return typing.cast(bool, request.param)
|  | 23 | + | 
|  | 24 | + | 
|  | 25 | +@pytest.fixture( | 
|  | 26 | +    params=[ | 
| 15 | 27 |         ((), (), ()), | 
| 16 | 28 |         ((2,), (), (2,)), | 
| 17 | 29 |         ((), (3,), (3,)), | 
|  | 
| 27 | 39 |         ((7, 8), (7, 1), (7, 8)), | 
| 28 | 40 |     ], | 
| 29 | 41 | ) | 
| 30 |  | -@pytest.mark.parametrize( | 
| 31 |  | -    "x", | 
| 32 |  | -    [ | 
|  | 42 | +def broadcast(request: pytest.FixtureRequest) -> Broadcast: | 
|  | 43 | +    return request.param | 
|  | 44 | + | 
|  | 45 | + | 
|  | 46 | +@pytest.fixture( | 
|  | 47 | +    params=[ | 
| 33 | 48 |         (False, False, (1, 1), (1, 1), (1, 1)), | 
| 34 | 49 |         (False, True, (1, 1), (1, 1), (1, 1)), | 
| 35 | 50 |         (True, False, (1, 1), (1, 1), (1, 1)), | 
|  | 
| 40 | 55 |         (True, True, (2, 2), (2, 2), (2, 2)), | 
| 41 | 56 |     ], | 
| 42 | 57 | ) | 
|  | 58 | +def x(request: pytest.FixtureRequest) -> MatmulCase: | 
|  | 59 | +    return request.param | 
|  | 60 | + | 
|  | 61 | + | 
| 43 | 62 | def test_matmul( | 
| 44 | 63 |     a_is_vector: bool, | 
| 45 | 64 |     b_is_vector: bool, | 
| @@ -101,40 +120,6 @@ def test_matmul( | 
| 101 | 120 |     assert torch.allclose(c.tensor, expected) | 
| 102 | 121 | 
 | 
| 103 | 122 | 
 | 
| 104 |  | -@pytest.mark.parametrize("a_is_vector", [False, True]) | 
| 105 |  | -@pytest.mark.parametrize("b_is_vector", [False, True]) | 
| 106 |  | -@pytest.mark.parametrize("normal_arrow_order", [False, True]) | 
| 107 |  | -@pytest.mark.parametrize( | 
| 108 |  | -    "broadcast", | 
| 109 |  | -    [ | 
| 110 |  | -        ((), (), ()), | 
| 111 |  | -        ((2,), (), (2,)), | 
| 112 |  | -        ((), (3,), (3,)), | 
| 113 |  | -        ((1,), (4,), (4,)), | 
| 114 |  | -        ((5,), (1,), (5,)), | 
| 115 |  | -        ((6,), (6,), (6,)), | 
| 116 |  | -        ((7, 8), (7, 8), (7, 8)), | 
| 117 |  | -        ((1, 8), (7, 8), (7, 8)), | 
| 118 |  | -        ((8,), (7, 8), (7, 8)), | 
| 119 |  | -        ((7, 1), (7, 8), (7, 8)), | 
| 120 |  | -        ((7, 8), (1, 8), (7, 8)), | 
| 121 |  | -        ((7, 8), (8,), (7, 8)), | 
| 122 |  | -        ((7, 8), (7, 1), (7, 8)), | 
| 123 |  | -    ], | 
| 124 |  | -) | 
| 125 |  | -@pytest.mark.parametrize( | 
| 126 |  | -    "x", | 
| 127 |  | -    [ | 
| 128 |  | -        (False, False, (1, 1), (1, 1), (1, 1)), | 
| 129 |  | -        (False, True, (1, 1), (1, 1), (1, 1)), | 
| 130 |  | -        (True, False, (1, 1), (1, 1), (1, 1)), | 
| 131 |  | -        (True, True, (1, 1), (1, 1), (1, 1)), | 
| 132 |  | -        (False, False, (2, 2), (2, 2), (2, 2)), | 
| 133 |  | -        (False, True, (2, 2), (2, 2), (2, 2)), | 
| 134 |  | -        (True, False, (2, 2), (2, 2), (2, 2)), | 
| 135 |  | -        (True, True, (2, 2), (2, 2), (2, 2)), | 
| 136 |  | -    ], | 
| 137 |  | -) | 
| 138 | 123 | @pytest.mark.parametrize("impure_even_for_broadcast_indices", [1, 2]) | 
| 139 | 124 | def test_matmul_unpure_even( | 
| 140 | 125 |     a_is_vector: bool, | 
| @@ -191,3 +176,99 @@ def test_matmul_unpure_even( | 
| 191 | 176 |         pytest.skip("One of the two tensors needs to have a dimension greater than 2") | 
| 192 | 177 |     with pytest.raises(AssertionError, match="All edges except the last two must be pure even"): | 
| 193 | 178 |         _ = a.matmul(b) | 
|  | 179 | + | 
|  | 180 | + | 
def test_matmul_operator_matmul() -> None:
    """Verify that the ``@`` operator agrees with ``torch.matmul`` on the raw tensors.

    The full parameter space (vector operands, arrow orders, broadcast shapes)
    is already covered by ``test_matmul``; this test pins one representative
    configuration — broadcast ``(7, 8) @ (7, 1)`` with normal arrow order —
    and checks the resulting arrows, edges, and values.

    NOTE(review): the previous version accepted the ``a_is_vector``,
    ``b_is_vector``, ``normal_arrow_order`` and ``broadcast`` fixtures and then
    shadowed them all with hard-coded values, so pytest ran ~100 identical
    copies of this test.  The unused parameters are removed.
    """
    broadcast_a, broadcast_b, broadcast_result = (7, 8), (7, 1), (7, 8)
    arrow_a, arrow_b = True, True
    edge_a, edge_common, edge_b = (2, 2), (2, 2), (2, 2)
    dim_a = sum(edge_a)
    dim_common = sum(edge_common)
    dim_b = sum(edge_b)
    # Broadcast dimensions are encoded as pure-even edges (i, 0) with a False
    # arrow, followed by the (row, contraction) matrix edges.
    a = GrassmannTensor(
        (*(False for _ in broadcast_a), arrow_a, True),
        (*((i, 0) for i in broadcast_a), edge_a, edge_common),
        torch.randn([*broadcast_a, dim_a, dim_common]),
    ).update_mask()

    b = GrassmannTensor(
        (*(False for _ in broadcast_b), False, arrow_b),
        (*((i, 0) for i in broadcast_b), edge_common, edge_b),
        torch.randn([*broadcast_b, dim_common, dim_b]),
    ).update_mask()

    c = a @ b
    expected = a.tensor.matmul(b.tensor)
    assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b)
    assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b)
    assert torch.allclose(c.tensor, expected)
|  | 210 | + | 
|  | 211 | + | 
@pytest.fixture(
    params=[
        GrassmannTensor(arrow, edges, torch.randn(shape))
        for arrow, edges, shape in [
            ((False, False), ((2, 2), (1, 3)), [4, 4]),
            ((True, False, True), ((1, 1), (2, 2), (3, 1)), [2, 4, 4]),
            ((True, True, False, False), ((1, 2), (2, 2), (1, 1), (3, 1)), [3, 4, 2, 4]),
        ]
    ]
)
def tensors(request: pytest.FixtureRequest) -> GrassmannTensor:
    """Provide sample GrassmannTensor instances of rank 2, 3, and 4."""
    return request.param
|  | 223 | + | 
|  | 224 | + | 
@pytest.mark.parametrize(
    "unsupported_type",
    [
        "string",  # str
        None,  # NoneType
        {"key": "value"},  # dict (was a set literal {"key", "value"} mislabeled as dict)
        [1, 2, 3],  # list
        {1, 2},  # set
        object(),  # arbitrary object
    ],
)
def test_matmul_unsupported_type_raises_typeerror(
    unsupported_type: typing.Any,
    tensors: GrassmannTensor,
) -> None:
    """Matmul with a non-tensor operand raises TypeError on both sides of ``@``.

    Checks ``tensor @ other``, ``other @ tensor`` (reflected), and ``@=``.
    """
    with pytest.raises(TypeError):
        _ = tensors @ unsupported_type

    with pytest.raises(TypeError):
        _ = unsupported_type @ tensors

    with pytest.raises(TypeError):
        tensors @= unsupported_type
|  | 248 | + | 
|  | 249 | + | 
def test_matmul_operator_rmatmul() -> None:
    """Verify the augmented-assignment operator ``@=`` on GrassmannTensor.

    ``c = a; c @= b`` must leave ``c`` equal to ``a.matmul(b)``.  The expected
    value is computed *before* the augmented assignment: since ``c`` is bound
    to the same object as ``a``, an in-place ``__imatmul__`` implementation
    would otherwise mutate ``a.tensor`` first and make the check vacuous.

    NOTE(review): despite the name, this exercises ``@=`` (in-place matmul),
    not the reflected ``__rmatmul__`` protocol.
    """
    broadcast_a, broadcast_b, broadcast_result = (7, 8), (7, 1), (7, 8)
    arrow_a, arrow_b = True, True
    edge_a, edge_common, edge_b = (2, 2), (2, 2), (2, 2)
    dim_a = sum(edge_a)
    dim_common = sum(edge_common)
    dim_b = sum(edge_b)
    a = GrassmannTensor(
        (*(False for _ in broadcast_a), arrow_a, True),
        (*((i, 0) for i in broadcast_a), edge_a, edge_common),
        torch.randn([*broadcast_a, dim_a, dim_common]),
    ).update_mask()

    b = GrassmannTensor(
        (*(False for _ in broadcast_b), False, arrow_b),
        (*((i, 0) for i in broadcast_b), edge_common, edge_b),
        torch.randn([*broadcast_b, dim_common, dim_b]),
    ).update_mask()

    # Reference value captured before "@=" can rebind or mutate c (alias of a).
    expected = a.tensor.matmul(b.tensor)
    c = a
    c @= b
    assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b)
    assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b)
    assert torch.allclose(c.tensor, expected)
0 commit comments