Commit c511401

Merge pull request #75 from USTC-KnowledgeComputingLab/dev/add-support-for-exponential
dev(exponential): add support for exponential Close #75
2 parents da5e0de + f55d7f4 commit c511401

2 files changed (+97, -13 lines)

grassmann_tensor/tensor.py

Lines changed: 67 additions & 13 deletions
@@ -569,6 +569,27 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
             _tensor=tensor,
         )
 
+    def _group_edges(
+        self,
+        left_legs: typing.Iterable[int],
+    ) -> tuple[GrassmannTensor, tuple[int, ...], tuple[int, ...]]:
+        left_legs = tuple(int(i) for i in left_legs)
+        right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs)
+        assert set(left_legs) | set(right_legs) == set(range(self.tensor.dim())), (
+            "Left/right must cover all tensor legs."
+        )
+
+        order = left_legs + right_legs
+
+        tensor = self.permute(order)
+
+        left_dim = math.prod(tensor.tensor.shape[: len(left_legs)])
+        right_dim = math.prod(tensor.tensor.shape[len(left_legs) :])
+
+        tensor = tensor.reshape((left_dim, right_dim))
+
+        return tensor, left_legs, right_legs
+
     def svd(
         self,
         free_names_u: tuple[int, ...],
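
The new `_group_edges` helper factors out the leg-grouping step that `svd` previously performed inline: the chosen legs are permuted to the front, the remaining legs to the back, and each group is flattened into one matrix dimension. A minimal standalone sketch of the same permute-and-flatten pattern on a plain `torch.Tensor` (hypothetical example, not from the commit; the library method additionally tracks edges and arrows):

import math

import torch

# Move legs (0, 3) of a rank-4 tensor to the front, then flatten the two
# groups into a matrix: (2, 3, 4, 5) -> (2, 5, 3, 4) -> (10, 12).
t = torch.randn(2, 3, 4, 5)
left_legs = (0, 3)
right_legs = tuple(i for i in range(t.dim()) if i not in left_legs)

order = left_legs + right_legs
permuted = t.permute(order)

left_dim = math.prod(permuted.shape[: len(left_legs)])
right_dim = math.prod(permuted.shape[len(left_legs):])

matrix = permuted.reshape(left_dim, right_dim)
assert matrix.shape == (10, 12)
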
@@ -591,22 +612,10 @@ def svd(
         Furthermore, if the distance between any two singular values is close to zero, the gradient
         will be numerically unstable, as it depends on the singular values
         """
-        left_legs = tuple(int(i) for i in free_names_u)
-        right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs)
-        assert set(left_legs) | set(right_legs) == set(range(self.tensor.dim())), (
-            "Left/right must cover all tensor legs."
-        )
-
         if isinstance(cutoff, tuple):
             assert len(cutoff) == 2, "The length of cutoff must be 2 if cutoff is a tuple."
 
-        order = left_legs + right_legs
-        tensor = self.permute(order)
-
-        left_dim = math.prod(tensor.tensor.shape[: len(left_legs)])
-        right_dim = math.prod(tensor.tensor.shape[len(left_legs) :])
-
-        tensor = tensor.reshape((left_dim, right_dim))
+        tensor, left_legs, right_legs = self._group_edges(free_names_u)
 
         (even_left, odd_left) = tensor.edges[0]
         (even_right, odd_right) = tensor.edges[1]
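
Each edge of the grouped matrix carries an `(even, odd)` dimension pair, with the even sector stored first, which is what the `(even_left, odd_left)` / `(even_right, odd_right)` unpacking above relies on. A hypothetical sketch, with made-up dimensions, of how such a layout splits a matrix into its four parity blocks:

import torch

even_left, odd_left = 2, 3    # (even, odd) dimensions of the row edge
even_right, odd_right = 2, 3  # (even, odd) dimensions of the column edge

m = torch.randn(even_left + odd_left, even_right + odd_right)

even_even = m[:even_left, :even_right]  # parity-preserving block
odd_odd = m[even_left:, even_right:]    # parity-preserving block
even_odd = m[:even_left, even_right:]   # parity-changing block
odd_even = m[even_left:, :even_right]   # parity-changing block

The `exponential` method added below exponentiates only the two parity-preserving blocks, which presumes the operator being exponentiated is parity-even (its parity-changing blocks are not used).
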
@@ -709,6 +718,51 @@ def svd(
 
         return U, S, Vh
 
+    def _get_inv_order(self, order: tuple[int, ...]) -> tuple[int, ...]:
+        inv = [0] * self.tensor.dim()
+        for new_position, origin_idx in enumerate(order):
+            inv[origin_idx] = new_position
+        return tuple(inv)
+
+    def exponential(self, pairs: tuple[int, ...]) -> GrassmannTensor:
+        tensor, left_legs, right_legs = self._group_edges(pairs)
+
+        edges_to_reverse = [i for i in range(2) if tensor.arrow[i]]
+        if edges_to_reverse:
+            tensor = tensor.reverse(tuple(edges_to_reverse))
+
+        left_dim, right_dim = tensor.tensor.shape
+
+        assert left_dim == right_dim, (
+            f"Exponential requires a square operator, but got {left_dim} x {right_dim}."
+        )
+
+        (even_left, odd_left) = tensor.edges[0]
+        (even_right, odd_right) = tensor.edges[1]
+
+        even_tensor = tensor.tensor[:even_left, :even_right]
+        odd_tensor = tensor.tensor[even_left:, even_right:]
+
+        even_tensor_exp = torch.linalg.matrix_exp(even_tensor)
+        odd_tensor_exp = torch.linalg.matrix_exp(odd_tensor)
+
+        tensor_exp = torch.block_diag(even_tensor_exp, odd_tensor_exp)  # type: ignore[no-untyped-call]
+
+        tensor_exp = dataclasses.replace(tensor, _tensor=tensor_exp)
+
+        if edges_to_reverse:
+            tensor_exp = tensor_exp.reverse(tuple(edges_to_reverse))
+
+        order = left_legs + right_legs
+        edges_after_permute = tuple(self.edges[i] for i in order)
+        tensor_exp = tensor_exp.reshape(edges_after_permute)
+
+        inv_order = self._get_inv_order(order)
+
+        tensor_exp = tensor_exp.permute(inv_order)
+
+        return tensor_exp
+
     def __post_init__(self) -> None:
         assert len(self._arrow) == self._tensor.dim(), (
             f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})."

tests/exponential_test.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+import torch
+import pytest
+
+from grassmann_tensor import GrassmannTensor
+
+
+def test_exponential() -> None:
+    a = GrassmannTensor(
+        (True, True, True, True),
+        ((4, 4), (8, 8), (4, 4), (8, 8)),
+        torch.randn(8, 16, 8, 16, dtype=torch.float64),
+    )
+    a.exponential((0, 3))
+
+
+def test_exponential_with_empty_parity_block() -> None:
+    a = GrassmannTensor((False, True), ((1, 0), (1, 0)), torch.randn(1, 1))
+    a.exponential((0,))
+    b = GrassmannTensor((False, True), ((0, 1), (0, 1)), torch.randn(1, 1))
+    b.exponential((0,))
+
+
+def test_exponential_assertation() -> None:
+    a = GrassmannTensor(
+        (True, True, True, True),
+        ((2, 2), (4, 4), (8, 8), (16, 16)),
+        torch.randn(4, 8, 16, 32, dtype=torch.float64),
+    )
+    with pytest.raises(AssertionError, match="Exponential requires a square operator"):
+        a.exponential((0, 2))
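
The new tests exercise shapes, an empty parity block, and the square-operator assertion, but none of them checks values. One further property test along the lines below could be worth considering (hypothetical, not part of the commit, and assuming the raw block tensor is exposed via `.tensor` as elsewhere in the class): the exponential of a zero operator on a purely even space should be the identity.

import torch

from grassmann_tensor import GrassmannTensor


def test_exponential_of_zero_is_identity() -> None:
    # Hypothetical extra check: exp(0) on a purely even operator is the identity matrix.
    a = GrassmannTensor((False, False), ((2, 0), (2, 0)), torch.zeros(2, 2))
    result = a.exponential((0,))
    assert torch.allclose(result.tensor, torch.eye(2))
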
