Skip to content

Commit 1df35c3

Browse files
committed
dev(exponential): add support for exponential
- Add support for exponential - Add test cases for exponential
1 parent da5e0de commit 1df35c3

File tree

2 files changed

+83
-13
lines changed

2 files changed

+83
-13
lines changed

grassmann_tensor/tensor.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,27 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
569569
_tensor=tensor,
570570
)
571571

572+
def _group_legs(
573+
self,
574+
left_legs: typing.Iterable[int],
575+
) -> tuple[GrassmannTensor, tuple[int, ...], tuple[int, ...]]:
576+
left_legs = tuple(int(i) for i in left_legs)
577+
right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs)
578+
assert set(left_legs) | set(right_legs) == set(range(self.tensor.dim())), (
579+
"Left/right must cover all tensor legs."
580+
)
581+
582+
order = left_legs + right_legs
583+
584+
tensor = self.permute(order)
585+
586+
left_dim = math.prod(tensor.tensor.shape[: len(left_legs)])
587+
right_dim = math.prod(tensor.tensor.shape[len(left_legs) :])
588+
589+
tensor = tensor.reshape((left_dim, right_dim))
590+
591+
return tensor, left_legs, right_legs
592+
572593
def svd(
573594
self,
574595
free_names_u: tuple[int, ...],
@@ -591,22 +612,10 @@ def svd(
591612
Furthermore, if the distance between any two singular values is close to zero, the gradient
592613
will be numerically unstable, as it depends on the singular values
593614
"""
594-
left_legs = tuple(int(i) for i in free_names_u)
595-
right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs)
596-
assert set(left_legs) | set(right_legs) == set(range(self.tensor.dim())), (
597-
"Left/right must cover all tensor legs."
598-
)
599-
600615
if isinstance(cutoff, tuple):
601616
assert len(cutoff) == 2, "The length of cutoff must be 2 if cutoff is a tuple."
602617

603-
order = left_legs + right_legs
604-
tensor = self.permute(order)
605-
606-
left_dim = math.prod(tensor.tensor.shape[: len(left_legs)])
607-
right_dim = math.prod(tensor.tensor.shape[len(left_legs) :])
608-
609-
tensor = tensor.reshape((left_dim, right_dim))
618+
tensor, left_legs, right_legs = self._group_legs(free_names_u)
610619

611620
(even_left, odd_left) = tensor.edges[0]
612621
(even_right, odd_right) = tensor.edges[1]
@@ -709,6 +718,48 @@ def svd(
709718

710719
return U, S, Vh
711720

721+
def exponential(self, pairs: tuple[int, ...]) -> GrassmannTensor:
722+
tensor, left_legs, right_legs = self._group_legs(pairs)
723+
724+
axes_to_reverse = [i for i in range(2) if tensor.arrow[i]]
725+
if axes_to_reverse:
726+
tensor = tensor.reverse(tuple(axes_to_reverse))
727+
728+
left_dim, right_dim = tensor.tensor.shape
729+
730+
assert left_dim == right_dim, (
731+
f"Exponential requires a square operator, but got {left_dim} x {right_dim}."
732+
)
733+
734+
(even_left, odd_left) = tensor.edges[0]
735+
(even_right, odd_right) = tensor.edges[1]
736+
737+
even_tensor = tensor.tensor[:even_left, :even_right]
738+
odd_tensor = tensor.tensor[even_left:, even_right:]
739+
740+
even_tensor_exp = torch.linalg.matrix_exp(even_tensor)
741+
odd_tensor_exp = torch.linalg.matrix_exp(odd_tensor)
742+
743+
tensor_exp = torch.block_diag(even_tensor_exp, odd_tensor_exp) # type: ignore[no-untyped-call]
744+
745+
tensor_exp = dataclasses.replace(tensor, _tensor=tensor_exp)
746+
747+
if axes_to_reverse:
748+
tensor_exp = tensor_exp.reverse(tuple(axes_to_reverse))
749+
750+
order = left_legs + right_legs
751+
edges_after_permute = tuple(self.edges[i] for i in order)
752+
tensor_exp = tensor_exp.reshape(edges_after_permute)
753+
754+
inv = [0] * self.tensor.dim()
755+
for new_pos, orig_idx in enumerate(order):
756+
inv[orig_idx] = new_pos
757+
inv_order = tuple(inv)
758+
759+
tensor_exp = tensor_exp.permute(inv_order)
760+
761+
return tensor_exp
762+
712763
def __post_init__(self) -> None:
713764
assert len(self._arrow) == self._tensor.dim(), (
714765
f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})."

tests/exponential_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
3+
from grassmann_tensor import GrassmannTensor
4+
5+
6+
def test_exponential() -> None:
7+
a = GrassmannTensor(
8+
(True, True, True, True),
9+
((4, 4), (8, 8), (4, 4), (8, 8)),
10+
torch.randn(8, 16, 8, 16, dtype=torch.float64),
11+
)
12+
a.exponential((0, 3))
13+
14+
15+
def test_exponential_with_empty_parity_block() -> None:
16+
a = GrassmannTensor((False, True), ((1, 0), (1, 0)), torch.randn(1, 1))
17+
a.exponential((0,))
18+
b = GrassmannTensor((False, True), ((0, 1), (0, 1)), torch.randn(1, 1))
19+
b.exponential((0,))

0 commit comments

Comments
 (0)