Skip to content

Commit b712ec3

Browse files
authored
Merge pull request #67 from USTC-KnowledgeComputingLab/dev/add-support-for-reshape-with-none
Fix issue when reshape with none edge.
2 parents eed595c + c0904be commit b712ec3

File tree

2 files changed

+86
-39
lines changed

2 files changed

+86
-39
lines changed

grassmann_tensor/tensor.py

Lines changed: 77 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -256,33 +256,44 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
256256
cursor_plan: int = 0
257257
cursor_self: int = 0
258258
while cursor_plan != len(new_shape) or cursor_self != self.tensor.dim():
259-
if new_shape[cursor_plan] == -1:
259+
if len(new_shape) == 0:
260+
assert all(edge == (0, 1) or edge == (1, 0) for edge in self.edges), (
261+
f"Edge must be (0, 1) or (1, 0) but got {self.edges}"
262+
)
263+
cursor_self = self.tensor.dim() - 1
264+
elif cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1:
260265
# Does not change
261266
arrow.append(self.arrow[cursor_self])
262267
edges.append(self.edges[cursor_self])
263268
shape.append(self.tensor.shape[cursor_self])
264269
cursor_self += 1
265270
cursor_plan += 1
266271
continue
267-
if new_shape[cursor_plan] == (1, 0):
268-
# An trivial plan edge
272+
elif cursor_plan != len(new_shape) and new_shape[cursor_plan] == (1, 0):
273+
# A trivial plan edge
269274
arrow.append(False)
270275
edges.append((1, 0))
271276
shape.append(1)
272277
cursor_plan += 1
273278
continue
274-
if self.edges[cursor_self] == (1, 0):
275-
# An trivial self edge
279+
elif cursor_self != self.tensor.dim() and self.edges[cursor_self] == (1, 0):
280+
# A trivial self edge
276281
cursor_self += 1
277282
continue
278-
cursor_new_shape = new_shape[cursor_plan]
279-
total = (
280-
cursor_new_shape
281-
if isinstance(cursor_new_shape, int)
282-
else cursor_new_shape[0] + cursor_new_shape[1]
283-
)
283+
if len(new_shape) == 0:
284+
cursor_new_shape = typing.cast(int | tuple[int, int], tuple())
285+
total = 1
286+
else:
287+
cursor_new_shape = new_shape[cursor_plan]
288+
total = (
289+
cursor_new_shape
290+
if isinstance(cursor_new_shape, int)
291+
else cursor_new_shape[0] + cursor_new_shape[1]
292+
)
284293
# one of total and shape[cursor_self] is not trivial, otherwise it should be handled before
285-
if total == self.tensor.shape[cursor_self]:
294+
if self.tensor.dim() == 0:
295+
merging = False
296+
elif total == self.tensor.shape[cursor_self]:
286297
# We do not know whether it is merging or splitting, check more
287298
if isinstance(cursor_new_shape, int) or cursor_new_shape == self.edges[cursor_self]:
288299
# If the new shape is exactly the same as the current edge, we treat it as no change
@@ -296,6 +307,9 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
296307
cursor_self_finding = cursor_self
297308
cursor_self_found = False
298309
while True:
310+
if len(new_shape) == 0:
311+
cursor_self_found = True
312+
break
299313
cursor_self_finding += 1
300314
if cursor_self_finding == self.tensor.dim():
301315
break
@@ -306,15 +320,19 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
306320
break
307321
break
308322
merging = cursor_self_found
309-
if total > self.tensor.shape[cursor_self]:
323+
elif total > self.tensor.shape[cursor_self]:
310324
merging = True
311-
if total < self.tensor.shape[cursor_self]:
325+
elif total < self.tensor.shape[cursor_self]:
312326
merging = False
313327
if merging:
314328
# Merging between [cursor_self, new_cursor_self) and the another side contains dimension as self_total
315329
new_cursor_self = cursor_self
316330
self_total = 1
317331
while True:
332+
if len(new_shape) == 0:
333+
new_cursor_self += 1
334+
even, odd, reorder, sign = self._reorder_indices(self.edges)
335+
break
318336
# Try to include more dimension from self
319337
self_total *= self.tensor.shape[new_cursor_self]
320338
new_cursor_self += 1
@@ -336,19 +354,26 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
336354
f"New shape exceeds in merging with edges {self.edges} and new shape {new_shape}."
337355
)
338356
# The merging block [cursor_self, new_cursor_self) has been determined
339-
arrow.append(self.arrow[cursor_self])
340-
assert all(
341-
self_arrow == arrow[-1]
342-
for self_arrow in self.arrow[cursor_self:new_cursor_self]
343-
), (
344-
f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}."
345-
)
346-
edges.append((even, odd))
347-
shape.append(total)
348-
merging_sign.append((cursor_plan, sign))
349-
merging_reorder.append((cursor_plan, reorder))
350-
cursor_self = new_cursor_self
351-
cursor_plan += 1
357+
if len(new_shape) == 0:
358+
arrow = []
359+
edges = []
360+
shape = []
361+
merging_sign.append((cursor_plan, sign))
362+
cursor_self = new_cursor_self
363+
else:
364+
arrow.append(self.arrow[cursor_self])
365+
assert all(
366+
self_arrow == arrow[-1]
367+
for self_arrow in self.arrow[cursor_self:new_cursor_self]
368+
), (
369+
f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}."
370+
)
371+
edges.append((even, odd))
372+
shape.append(total)
373+
merging_sign.append((cursor_plan, sign))
374+
merging_reorder.append((cursor_plan, reorder))
375+
cursor_self = new_cursor_self
376+
cursor_plan += 1
352377
else:
353378
# Splitting between [cursor_plan, new_cursor_plan) and the another side contains dimension as plan_total
354379
new_cursor_plan = cursor_plan
@@ -362,15 +387,23 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
362387
plan_total *= new_cursor_new_shape[0] + new_cursor_new_shape[1]
363388
new_cursor_plan += 1
364389
# One dimension included, check if we can stop
365-
if plan_total == self.tensor.shape[cursor_self]:
366-
# new_shape block has been verified to be always tuple[int, int] before
390+
if self.tensor.dim() == 0:
367391
even, odd, reorder, sign = self._reorder_indices(
368-
typing.cast(
369-
tuple[tuple[int, int], ...], new_shape[cursor_plan:new_cursor_plan]
370-
)
392+
typing.cast(tuple[tuple[int, int], ...], new_shape)
371393
)
372-
if (even, odd) == self.edges[cursor_self]:
373-
break
394+
new_cursor_plan = len(new_shape)
395+
break
396+
else:
397+
if plan_total == self.tensor.shape[cursor_self]:
398+
# new_shape block has been verified to be always tuple[int, int] before
399+
even, odd, reorder, sign = self._reorder_indices(
400+
typing.cast(
401+
tuple[tuple[int, int], ...],
402+
new_shape[cursor_plan:new_cursor_plan],
403+
)
404+
)
405+
if (even, odd) == self.edges[cursor_self]:
406+
break
374407
# For some reason we cannot stop here, continue to include more dimension, check something before continue
375408
assert plan_total <= self.tensor.shape[cursor_self], (
376409
f"Dimension mismatch in splitting with edges {self.edges} and new shape {new_shape}."
@@ -382,12 +415,16 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
382415
for i in range(cursor_plan, new_cursor_plan):
383416
# new_shape block has been verified to be always tuple[int, int] in the loop
384417
new_cursor_new_shape = typing.cast(tuple[int, int], new_shape[i])
385-
arrow.append(self.arrow[cursor_self])
418+
if self.tensor.dim() == 0:
419+
arrow.append(False)
420+
else:
421+
arrow.append(self.arrow[cursor_self])
386422
edges.append(new_cursor_new_shape)
387423
shape.append(new_cursor_new_shape[0] + new_cursor_new_shape[1])
388424
splitting_reorder.append((cursor_self, reorder))
389425
splitting_sign.append((cursor_self, sign))
390-
cursor_self += 1
426+
if self.tensor.dim() != 0:
427+
cursor_self += 1
391428
cursor_plan = new_cursor_plan
392429

393430
tensor = self.tensor
@@ -402,14 +439,17 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
402439
(
403440
self._unsqueeze(sign, index, self.tensor.dim())
404441
for index, sign in splitting_sign
405-
if self.arrow[index]
442+
if self.tensor.dim() != 0 and self.arrow[index]
406443
),
407444
torch.zeros([], dtype=torch.bool, device=self.tensor.device),
408445
)
409446
tensor = torch.where(splitting_parity, -tensor, +tensor)
410447

411448
tensor = tensor.reshape(shape)
412449

450+
if len(new_shape) == 0:
451+
return GrassmannTensor(_arrow=tuple(arrow), _edges=tuple(edges), _tensor=tensor)
452+
413453
merging_parity = functools.reduce(
414454
torch.logical_xor,
415455
(

tests/reshape_test.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_reshape_trivial_edges(arrow: tuple[bool, ...], plan_range: tuple[int, i
6363
assert a.edges == c.edges
6464

6565

66-
def test_reshape_merging_dimension_mismatch_edges_because_of_nonequal() -> None:
66+
def test_reshape_merging_dimension_mismatch_edges_because_of_unequal() -> None:
6767
arrow = (True, True, True)
6868
edges = ((2, 2), (8, 8), (2, 2))
6969
a = GrassmannTensor(arrow, edges, torch.randn([4, 16, 4]))
@@ -113,7 +113,7 @@ def test_reshape_splitting_shape_type() -> None:
113113
_ = a.reshape((2, (2, 2)))
114114

115115

116-
def test_reshape_splitting_dimension_mismatch_edges_because_of_nonequal() -> None:
116+
def test_reshape_splitting_dimension_mismatch_edges_because_of_unequal() -> None:
117117
arrow = (True,)
118118
edges = ((8, 8),)
119119
a = GrassmannTensor(arrow, edges, torch.randn([16]))
@@ -175,3 +175,10 @@ def test_reshape_equal_edges_nontrivial_merging_with_other_edge() -> None:
175175
edges = ((1, 3), (1, 0), (0, 1), (2, 2))
176176
a = GrassmannTensor(arrow, edges, torch.randn([4, 1, 1, 4]))
177177
_ = a.reshape(((3, 1), (2, 2)))
178+
179+
180+
def test_reshape_with_none() -> None:
181+
a = GrassmannTensor((), (), torch.tensor(2333)).reshape(((1, 0), (1, 0))).reshape(())
182+
assert len(a.arrow) == 0 and len(a.edges) == 0 and a.tensor.dim() == 0
183+
b = GrassmannTensor((), (), torch.tensor(2333)).reshape(((0, 1), (0, 1))).reshape(())
184+
assert len(b.arrow) == 0 and len(b.edges) == 0 and b.tensor.dim() == 0

0 commit comments

Comments
 (0)