Skip to content

Commit 44d01bd

Browse files
committed
dev(svd): add support for svd
- Add support for svd, perform SVD separately on even/odd parity blocks - Add support for two type of cutoff: int and tuple[int, int] - Add support for grassmann tensor with empty parity block - Add parameterized test cases for svd
1 parent 290bcd3 commit 44d01bd

File tree

3 files changed

+453
-1
lines changed

3 files changed

+453
-1
lines changed

grassmann_tensor/tensor.py

Lines changed: 165 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import dataclasses
1010
import functools
1111
import typing
12+
import math
1213
import torch
1314

1415

@@ -295,9 +296,28 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
295296
tensor = self.tensor.reshape(())
296297
return GrassmannTensor(_arrow=(), _edges=(), _tensor=tensor)
297298

299+
if new_shape == (1,) and int(self.tensor.numel()) == 1:
300+
eo = self._calculate_even_odd()
301+
new_shape = (eo,)
302+
298303
cursor_plan: int = 0
299304
cursor_self: int = 0
300305
while cursor_plan != len(new_shape) or cursor_self != self.tensor.dim():
306+
if cursor_self == self.tensor.dim() and cursor_plan != len(new_shape):
307+
new_shape_check = new_shape[cursor_plan]
308+
if (isinstance(new_shape_check, int) and new_shape_check == 1) or (
309+
new_shape_check == (1, 0)
310+
):
311+
arrow.append(False)
312+
edges.append((1, 0))
313+
shape.append(1)
314+
cursor_plan += 1
315+
continue
316+
raise AssertionError(
317+
"New shape exceeds after exhausting self dimensions: "
318+
f"edges={self.edges}, new_shape={new_shape}"
319+
)
320+
301321
if cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1:
302322
# Does not change
303323
arrow.append(self.arrow[cursor_self])
@@ -306,7 +326,11 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
306326
cursor_self += 1
307327
cursor_plan += 1
308328
continue
309-
elif cursor_plan != len(new_shape) and new_shape[cursor_plan] == (1, 0):
329+
elif (
330+
cursor_plan != len(new_shape)
331+
and new_shape[cursor_plan] == (1, 0)
332+
and cursor_plan < len(new_shape) - 1
333+
):
310334
# A trivial plan edge
311335
arrow.append(False)
312336
edges.append((1, 0))
@@ -532,6 +556,146 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
532556
_tensor=tensor,
533557
)
534558

559+
def svd(
560+
self,
561+
free_names_u: tuple[int, ...],
562+
*,
563+
cutoff: int | None | tuple[int, int] = None,
564+
) -> tuple[GrassmannTensor, GrassmannTensor, GrassmannTensor]:
565+
"""
566+
This function is used to computes the singular value decomposition of a grassmann tensor.
567+
The SVD are implemented by follow steps:
568+
1. Split the legs into left and right;
569+
2. Merge the tensor with two groups.
570+
3. Split the block tensor into two parts.
571+
4. Compute the singular value decomposition.
572+
5. Use cutoff to keep the largest cutoff singular values (globally across even/odd blocks).
573+
6. Contract U, S and Vh.
574+
7. Split the legs into original left and right.
575+
The returned tensors U and V are not unique, nor are they continuous with respect to self.
576+
Due to this lack of uniqueness, different hardware and software may compute different singular vectors.
577+
Gradients computed using U or Vh will only be finite when A does not have repeated singular values.
578+
Furthermore, if the distance between any two singular values is close to zero, the gradient
579+
will be numerically unstable, as it depends on the singular values
580+
"""
581+
left_legs = tuple(int(i) for i in free_names_u)
582+
right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs)
583+
assert set(left_legs) | set(right_legs) == set(range(self.tensor.dim())), (
584+
"Left/right must cover all tensor legs."
585+
)
586+
587+
if isinstance(cutoff, tuple):
588+
assert len(cutoff) == 2, "The length of cutoff must be 2 if cutoff is a tuple."
589+
590+
order = left_legs + right_legs
591+
tensor = self.permute(order)
592+
593+
left_dim = math.prod(tensor.tensor.shape[: len(left_legs)])
594+
right_dim = math.prod(tensor.tensor.shape[len(left_legs) :])
595+
596+
tensor = tensor.reshape((left_dim, right_dim))
597+
598+
(even_left, odd_left) = tensor.edges[0]
599+
(even_right, odd_right) = tensor.edges[1]
600+
even_tensor = tensor.tensor[:even_left, :even_right]
601+
odd_tensor = tensor.tensor[even_left:, even_right:]
602+
603+
if even_tensor.numel() > 0:
604+
U_even, S_even, Vh_even = torch.linalg.svd(even_tensor, full_matrices=False)
605+
else:
606+
U_even = even_tensor.new_zeros((even_left, 0))
607+
S_even = even_tensor.new_zeros((0,))
608+
Vh_even = even_tensor.new_zeros((0, even_right))
609+
610+
if odd_tensor.numel() > 0:
611+
U_odd, S_odd, Vh_odd = torch.linalg.svd(odd_tensor, full_matrices=False)
612+
else:
613+
U_odd = odd_tensor.new_zeros((odd_left, 0))
614+
S_odd = odd_tensor.new_zeros((0,))
615+
Vh_odd = odd_tensor.new_zeros((0, odd_right))
616+
617+
n_even, n_odd = S_even.shape[0], S_odd.shape[0]
618+
619+
if cutoff is None:
620+
k_even, k_odd = n_even, n_odd
621+
elif isinstance(cutoff, int):
622+
if n_even == 0 and n_odd == 0:
623+
raise RuntimeError("Both parity block are empty. Can not form SVD.")
624+
assert cutoff > 0, f"Cutoff must be greater than 0, but got {cutoff}"
625+
k_even = min(cutoff, n_even)
626+
k_odd = min(cutoff, n_odd)
627+
elif isinstance(cutoff, tuple):
628+
assert len(cutoff) == 2, "The length of cutoff must be 2 if cutoff is a tuple."
629+
if n_even == 0 and n_odd == 0:
630+
raise RuntimeError("Both parity block are empty. Can not form SVD.")
631+
k_even = max(0, min(int(cutoff[0]), n_even))
632+
k_odd = max(0, min(int(cutoff[1]), n_odd))
633+
else:
634+
raise ValueError(
635+
f"Cutoff must be an integer or a tuple of two integers, but got {cutoff}"
636+
)
637+
638+
assert (k_even > 0 or n_even == 0) and (k_odd > 0 or n_odd == 0), (
639+
"Per-block cutoff must be compatible with available singulars"
640+
)
641+
642+
keep_even = torch.zeros(n_even, dtype=torch.bool, device=S_even.device)
643+
keep_odd = torch.zeros(n_odd, dtype=torch.bool, device=S_odd.device)
644+
if k_even > 0:
645+
keep_even[:k_even] = True
646+
if k_odd > 0:
647+
keep_odd[:k_odd] = True
648+
649+
U_even_trunc = U_even[:, keep_even]
650+
S_even_trunc = S_even[keep_even]
651+
Vh_even_trunc = Vh_even[keep_even, :]
652+
653+
U_odd_trunc = U_odd[:, keep_odd]
654+
S_odd_trunc = S_odd[keep_odd]
655+
Vh_odd_trunc = Vh_odd[keep_odd, :]
656+
657+
U_tensor = torch.block_diag(U_even_trunc, U_odd_trunc) # type: ignore[no-untyped-call]
658+
S_tensor = torch.cat([S_even_trunc, S_odd_trunc], dim=0)
659+
Vh_tensor = torch.block_diag(Vh_even_trunc, Vh_odd_trunc) # type: ignore[no-untyped-call]
660+
661+
U_edges = (
662+
(U_even_trunc.shape[0], U_odd_trunc.shape[0]),
663+
(U_even_trunc.shape[1], U_odd_trunc.shape[1]),
664+
)
665+
S_edges = (
666+
(U_even_trunc.shape[1], U_odd_trunc.shape[1]),
667+
(Vh_even_trunc.shape[0], Vh_odd_trunc.shape[0]),
668+
)
669+
Vh_edges = (
670+
(Vh_even_trunc.shape[0], Vh_odd_trunc.shape[0]),
671+
(Vh_even_trunc.shape[1], Vh_odd_trunc.shape[1]),
672+
)
673+
674+
U = GrassmannTensor(_arrow=(True, True), _edges=U_edges, _tensor=U_tensor)
675+
S = GrassmannTensor(
676+
_arrow=(
677+
False,
678+
True,
679+
),
680+
_edges=S_edges,
681+
_tensor=torch.diag(S_tensor),
682+
)
683+
Vh = GrassmannTensor(_arrow=(False, True), _edges=Vh_edges, _tensor=Vh_tensor)
684+
# Split
685+
left_arrow = [self.arrow[i] for i in left_legs]
686+
left_edges = [self.edges[i] for i in left_legs]
687+
688+
right_arrow = [self.arrow[i] for i in right_legs]
689+
right_edges = [self.edges[i] for i in right_legs]
690+
691+
U = U.reshape((*left_edges, U_edges[1]))
692+
U._arrow = tuple(left_arrow + [True])
693+
694+
Vh = Vh.reshape((Vh_edges[0], *right_edges))
695+
Vh._arrow = tuple([False] + right_arrow)
696+
697+
return U, S, Vh
698+
535699
def __post_init__(self) -> None:
536700
assert len(self._arrow) == self._tensor.dim(), (
537701
f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})."

tests/reshape_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,37 @@ def test_reshape_with_none_edge_assertion() -> None:
197197
_ = GrassmannTensor((), (), torch.tensor(2333)).reshape((1, -1))
198198
with pytest.raises(AssertionError, match="Ambiguous integer dim"):
199199
_ = GrassmannTensor((), (), torch.tensor(2333)).reshape((2, 2))
200+
201+
202+
@pytest.mark.parametrize(
203+
"arrow, edges, tensor",
204+
[
205+
((True, True), ((0, 1), (0, 1)), torch.tensor([[2333]])),
206+
((True, True, True), ((0, 1), (1, 0), (0, 1)), torch.tensor([[[2333]]])),
207+
],
208+
)
209+
@pytest.mark.parametrize(
210+
"shape",
211+
[
212+
(1,),
213+
(1, 1),
214+
(1, 1, 1),
215+
(1, 1, 1, 1),
216+
],
217+
)
218+
def test_reshape_with_one_dimension(
219+
arrow: tuple[bool, ...],
220+
edges: tuple[tuple[int, int], ...],
221+
tensor: torch.Tensor,
222+
shape: tuple[int, ...],
223+
) -> None:
224+
a = GrassmannTensor(arrow, edges, tensor).reshape(shape)
225+
assert (
226+
len(a.arrow) == len(shape) and len(a.edges) == len(shape) and a.tensor.dim() == len(shape)
227+
)
228+
229+
230+
def test_reshape_trailing_nontrivial_dim_raises() -> None:
231+
a = GrassmannTensor((True,), ((2, 2),), torch.randn([4]))
232+
with pytest.raises(AssertionError, match="New shape exceeds after exhausting self dimensions"):
233+
_ = a.reshape((-1, (2, 2)))

0 commit comments

Comments
 (0)