Skip to content

Commit c6fca59

Browse files
committed
chore(svd): improve svd function and add test cases
- Remove redundant parameters full matrics. - Remove redundant code. - Add support for new type of cutoff: tuple[int, int] - Add support for grassmann tensor with empty parity block
1 parent 79e5d1f commit c6fca59

File tree

2 files changed

+62
-25
lines changed

2 files changed

+62
-25
lines changed

grassmann_tensor/tensor.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,19 @@ def _reorder_indices(
221221
sign = (count & 2).to(dtype=torch.bool)
222222
return len(even), len(odd), reorder, sign.flatten()
223223

224+
def _calculate_even_odd(self) -> tuple[int, int]:
225+
even, odd = 1, 0
226+
for e, o in self.edges:
227+
even, odd = even * e + odd * o, even * o + odd * e
228+
return even, odd
229+
230+
@staticmethod
231+
def calculate_even_odd(edges: tuple[tuple[int, int], ...]) -> tuple[int, int]:
232+
even, odd = 1, 0
233+
for e, o in edges:
234+
even, odd = even * e + odd * o, even * o + odd * e
235+
return even, odd
236+
224237
def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor:
225238
"""
226239
Reshape the Grassmann tensor, which may split or merge edges.
@@ -253,12 +266,33 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
253266
merging_reorder: list[tuple[int, torch.Tensor]] = []
254267
merging_sign: list[tuple[int, torch.Tensor]] = []
255268

269+
original_self_is_scalar = self.tensor.dim() == 0
270+
if original_self_is_scalar:
271+
normalized: list[tuple[int, int]] = []
272+
for item in new_shape:
273+
if item == -1:
274+
raise AssertionError("Cannot use -1 when reshaping from a scalar")
275+
if isinstance(item, int):
276+
if item != 1:
277+
raise AssertionError(
278+
f"Ambiguous integer dim {item} from scalar. "
279+
"Use explicit (even, odd) pairs, or only use 1 for trivial edges."
280+
)
281+
normalized.append((1, 0))
282+
else:
283+
normalized.append(item)
284+
new_shape = tuple(normalized)
285+
edges_only = typing.cast(tuple[tuple[int, int], ...], new_shape)
286+
assert self.calculate_even_odd(edges_only) == (1, 0), (
287+
"Cannot split none edges into illegal edges"
288+
)
289+
256290
cursor_plan: int = 0
257291
cursor_self: int = 0
258292
while cursor_plan != len(new_shape) or cursor_self != self.tensor.dim():
259293
if len(new_shape) == 0:
260-
assert all(edge == (0, 1) or edge == (1, 0) for edge in self.edges), (
261-
f"Edge must be (0, 1) or (1, 0) but got {self.edges}"
294+
assert self._calculate_even_odd() == (1, 0), (
295+
"Only pure even edges can be merged into none edges"
262296
)
263297
cursor_self = self.tensor.dim() - 1
264298
elif cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1:
@@ -291,9 +325,7 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
291325
else cursor_new_shape[0] + cursor_new_shape[1]
292326
)
293327
# one of total and shape[cursor_self] is not trivial, otherwise it should be handled before
294-
if self.tensor.dim() == 0:
295-
merging = False
296-
elif total == self.tensor.shape[cursor_self]:
328+
if total == self.tensor.shape[cursor_self]:
297329
# We do not know whether it is merging or splitting, check more
298330
if isinstance(cursor_new_shape, int) or cursor_new_shape == self.edges[cursor_self]:
299331
# If the new shape is exactly the same as the current edge, we treat it as no change
@@ -387,23 +419,16 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
387419
plan_total *= new_cursor_new_shape[0] + new_cursor_new_shape[1]
388420
new_cursor_plan += 1
389421
# One dimension included, check if we can stop
390-
if self.tensor.dim() == 0:
422+
if plan_total == self.tensor.shape[cursor_self]:
423+
# new_shape block has been verified to be always tuple[int, int] before
391424
even, odd, reorder, sign = self._reorder_indices(
392-
typing.cast(tuple[tuple[int, int], ...], new_shape)
393-
)
394-
new_cursor_plan = len(new_shape)
395-
break
396-
else:
397-
if plan_total == self.tensor.shape[cursor_self]:
398-
# new_shape block has been verified to be always tuple[int, int] before
399-
even, odd, reorder, sign = self._reorder_indices(
400-
typing.cast(
401-
tuple[tuple[int, int], ...],
402-
new_shape[cursor_plan:new_cursor_plan],
403-
)
425+
typing.cast(
426+
tuple[tuple[int, int], ...],
427+
new_shape[cursor_plan:new_cursor_plan],
404428
)
405-
if (even, odd) == self.edges[cursor_self]:
406-
break
429+
)
430+
if (even, odd) == self.edges[cursor_self]:
431+
break
407432
# For some reason we cannot stop here, continue to include more dimension, check something before continue
408433
assert plan_total <= self.tensor.shape[cursor_self], (
409434
f"Dimension mismatch in splitting with edges {self.edges} and new shape {new_shape}."
@@ -415,10 +440,7 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
415440
for i in range(cursor_plan, new_cursor_plan):
416441
# new_shape block has been verified to be always tuple[int, int] in the loop
417442
new_cursor_new_shape = typing.cast(tuple[int, int], new_shape[i])
418-
if self.tensor.dim() == 0:
419-
arrow.append(False)
420-
else:
421-
arrow.append(self.arrow[cursor_self])
443+
arrow.append(self.arrow[cursor_self])
422444
edges.append(new_cursor_new_shape)
423445
shape.append(new_cursor_new_shape[0] + new_cursor_new_shape[1])
424446
splitting_reorder.append((cursor_self, reorder))

tests/reshape_test.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,5 +180,20 @@ def test_reshape_equal_edges_nontrivial_merging_with_other_edge() -> None:
180180
def test_reshape_with_none() -> None:
181181
a = GrassmannTensor((), (), torch.tensor(2333)).reshape(((1, 0), (1, 0))).reshape(())
182182
assert len(a.arrow) == 0 and len(a.edges) == 0 and a.tensor.dim() == 0
183-
b = GrassmannTensor((), (), torch.tensor(2333)).reshape(((0, 1), (0, 1))).reshape(())
183+
b = GrassmannTensor((), (), torch.tensor(2333)).reshape(((1, 0), (1, 0))).reshape(())
184184
assert len(b.arrow) == 0 and len(b.edges) == 0 and b.tensor.dim() == 0
185+
c = GrassmannTensor((), (), torch.tensor(2333)).reshape((1, 1))
186+
assert len(c.arrow) == 2 and len(c.edges) == 2 and c.tensor.dim() == 2
187+
188+
189+
def test_reshape_with_none_edge_assertion() -> None:
190+
with pytest.raises(AssertionError, match="Only pure even edges can be merged into none edges"):
191+
_ = GrassmannTensor((True, True), ((0, 1), (1, 0)), torch.tensor([[2333]])).reshape(())
192+
with pytest.raises(AssertionError, match="Cannot split none edges into illegal edges"):
193+
_ = GrassmannTensor((), (), torch.tensor(2333)).reshape(((0, 1),))
194+
with pytest.raises(AssertionError, match="Cannot split none edges into illegal edges"):
195+
_ = GrassmannTensor((), (), torch.tensor(2333)).reshape(((0, 1), (1, 0)))
196+
with pytest.raises(AssertionError, match="Cannot use -1 when reshaping from a scalar"):
197+
_ = GrassmannTensor((), (), torch.tensor(2333)).reshape((1, -1))
198+
with pytest.raises(AssertionError, match="Ambiguous integer dim"):
199+
_ = GrassmannTensor((), (), torch.tensor(2333)).reshape((2, 2))

0 commit comments

Comments
 (0)