Skip to content

Commit 29ef0ad

Browse files
committed
dev(exponential): unify terminology to edges and add test cases
- Unify terminology to `edges` to maintain consistency - Add test cases for assertation
1 parent 1df35c3 commit 29ef0ad

File tree

2 files changed

+26
-12
lines changed

2 files changed

+26
-12
lines changed

grassmann_tensor/tensor.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
569569
_tensor=tensor,
570570
)
571571

572-
def _group_legs(
572+
def _group_edges(
573573
self,
574574
left_legs: typing.Iterable[int],
575575
) -> tuple[GrassmannTensor, tuple[int, ...], tuple[int, ...]]:
@@ -615,7 +615,7 @@ def svd(
615615
if isinstance(cutoff, tuple):
616616
assert len(cutoff) == 2, "The length of cutoff must be 2 if cutoff is a tuple."
617617

618-
tensor, left_legs, right_legs = self._group_legs(free_names_u)
618+
tensor, left_legs, right_legs = self._group_edges(free_names_u)
619619

620620
(even_left, odd_left) = tensor.edges[0]
621621
(even_right, odd_right) = tensor.edges[1]
@@ -718,12 +718,18 @@ def svd(
718718

719719
return U, S, Vh
720720

721+
def _get_inv_order(self, order: tuple[int, ...]) -> tuple[int, ...]:
722+
inv = [0] * self.tensor.dim()
723+
for new_pos, orig_idx in enumerate(order):
724+
inv[orig_idx] = new_pos
725+
return tuple(inv)
726+
721727
def exponential(self, pairs: tuple[int, ...]) -> GrassmannTensor:
722-
tensor, left_legs, right_legs = self._group_legs(pairs)
728+
tensor, left_legs, right_legs = self._group_edges(pairs)
723729

724-
axes_to_reverse = [i for i in range(2) if tensor.arrow[i]]
725-
if axes_to_reverse:
726-
tensor = tensor.reverse(tuple(axes_to_reverse))
730+
edges_to_reverse = [i for i in range(2) if tensor.arrow[i]]
731+
if edges_to_reverse:
732+
tensor = tensor.reverse(tuple(edges_to_reverse))
727733

728734
left_dim, right_dim = tensor.tensor.shape
729735

@@ -744,17 +750,14 @@ def exponential(self, pairs: tuple[int, ...]) -> GrassmannTensor:
744750

745751
tensor_exp = dataclasses.replace(tensor, _tensor=tensor_exp)
746752

747-
if axes_to_reverse:
748-
tensor_exp = tensor_exp.reverse(tuple(axes_to_reverse))
753+
if edges_to_reverse:
754+
tensor_exp = tensor_exp.reverse(tuple(edges_to_reverse))
749755

750756
order = left_legs + right_legs
751757
edges_after_permute = tuple(self.edges[i] for i in order)
752758
tensor_exp = tensor_exp.reshape(edges_after_permute)
753759

754-
inv = [0] * self.tensor.dim()
755-
for new_pos, orig_idx in enumerate(order):
756-
inv[orig_idx] = new_pos
757-
inv_order = tuple(inv)
760+
inv_order = self._get_inv_order(order)
758761

759762
tensor_exp = tensor_exp.permute(inv_order)
760763

tests/exponential_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import pytest
23

34
from grassmann_tensor import GrassmannTensor
45

@@ -17,3 +18,13 @@ def test_exponential_with_empty_parity_block() -> None:
1718
a.exponential((0,))
1819
b = GrassmannTensor((False, True), ((0, 1), (0, 1)), torch.randn(1, 1))
1920
b.exponential((0,))
21+
22+
23+
def test_exponential_assertation() -> None:
24+
a = GrassmannTensor(
25+
(True, True, True, True),
26+
((2, 2), (4, 4), (8, 8), (16, 16)),
27+
torch.randn(4, 8, 16, 32, dtype=torch.float64),
28+
)
29+
with pytest.raises(AssertionError, match="Exponential requires a square operator"):
30+
a.exponential((0, 2))

0 commit comments

Comments
 (0)