Skip to content

Commit 607dfd8

Browse files
committed
fix(reshape): move target check outside and unify calculate_even_odd logic with reduce
- Moved the `target` condition check to outside of the while loop to improve readability and reduce nesting. - Refactored `_calculate_even_odd` and `calculate_even_odd` to use `functools.reduce`, keeping consistent with other reduce-based implementations in the codebase. - Improved overall code clarity and consistency.
1 parent 4f313c6 commit 607dfd8

File tree

1 file changed

+33
-54
lines changed

1 file changed

+33
-54
lines changed

grassmann_tensor/tensor.py

Lines changed: 33 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -222,17 +222,15 @@ def _reorder_indices(
222222
return len(even), len(odd), reorder, sign.flatten()
223223

224224
def _calculate_even_odd(self) -> tuple[int, int]:
225-
even, odd = 1, 0
226-
for e, o in self.edges:
227-
even, odd = even * e + odd * o, even * o + odd * e
228-
return even, odd
225+
return self.calculate_even_odd(self.edges)
229226

230227
@staticmethod
231228
def calculate_even_odd(edges: tuple[tuple[int, int], ...]) -> tuple[int, int]:
232-
even, odd = 1, 0
233-
for e, o in edges:
234-
even, odd = even * e + odd * o, even * o + odd * e
235-
return even, odd
229+
return functools.reduce(
230+
lambda acc, eo: (acc[0] * eo[0] + acc[1] * eo[1], acc[0] * eo[1] + acc[1] * eo[0]),
231+
edges,
232+
(1, 0),
233+
)
236234

237235
def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor:
238236
"""
@@ -287,15 +285,17 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
287285
"Cannot split none edges into illegal edges"
288286
)
289287

288+
if len(new_shape) == 0:
289+
assert self._calculate_even_odd() == (1, 0), (
290+
"Only pure even edges can be merged into none edges"
291+
)
292+
tensor = self.tensor.reshape(())
293+
return GrassmannTensor(_arrow=(), _edges=(), _tensor=tensor)
294+
290295
cursor_plan: int = 0
291296
cursor_self: int = 0
292297
while cursor_plan != len(new_shape) or cursor_self != self.tensor.dim():
293-
if len(new_shape) == 0:
294-
assert self._calculate_even_odd() == (1, 0), (
295-
"Only pure even edges can be merged into none edges"
296-
)
297-
cursor_self = self.tensor.dim() - 1
298-
elif cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1:
298+
if cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1:
299299
# Does not change
300300
arrow.append(self.arrow[cursor_self])
301301
edges.append(self.edges[cursor_self])
@@ -314,16 +314,12 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
314314
# A trivial self edge
315315
cursor_self += 1
316316
continue
317-
if len(new_shape) == 0:
318-
cursor_new_shape = typing.cast(int | tuple[int, int], tuple())
319-
total = 1
320-
else:
321-
cursor_new_shape = new_shape[cursor_plan]
322-
total = (
323-
cursor_new_shape
324-
if isinstance(cursor_new_shape, int)
325-
else cursor_new_shape[0] + cursor_new_shape[1]
326-
)
317+
cursor_new_shape = new_shape[cursor_plan]
318+
total = (
319+
cursor_new_shape
320+
if isinstance(cursor_new_shape, int)
321+
else cursor_new_shape[0] + cursor_new_shape[1]
322+
)
327323
# one of total and shape[cursor_self] is not trivial, otherwise it should be handled before
328324
if total == self.tensor.shape[cursor_self]:
329325
# We do not know whether it is merging or splitting, check more
@@ -339,9 +335,6 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
339335
cursor_self_finding = cursor_self
340336
cursor_self_found = False
341337
while True:
342-
if len(new_shape) == 0:
343-
cursor_self_found = True
344-
break
345338
cursor_self_finding += 1
346339
if cursor_self_finding == self.tensor.dim():
347340
break
@@ -361,10 +354,6 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
361354
new_cursor_self = cursor_self
362355
self_total = 1
363356
while True:
364-
if len(new_shape) == 0:
365-
new_cursor_self += 1
366-
even, odd, reorder, sign = self._reorder_indices(self.edges)
367-
break
368357
# Try to include more dimension from self
369358
self_total *= self.tensor.shape[new_cursor_self]
370359
new_cursor_self += 1
@@ -386,26 +375,19 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
386375
f"New shape exceeds in merging with edges {self.edges} and new shape {new_shape}."
387376
)
388377
# The merging block [cursor_self, new_cursor_self) has been determined
389-
if len(new_shape) == 0:
390-
arrow = []
391-
edges = []
392-
shape = []
393-
merging_sign.append((cursor_plan, sign))
394-
cursor_self = new_cursor_self
395-
else:
396-
arrow.append(self.arrow[cursor_self])
397-
assert all(
398-
self_arrow == arrow[-1]
399-
for self_arrow in self.arrow[cursor_self:new_cursor_self]
400-
), (
401-
f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}."
402-
)
403-
edges.append((even, odd))
404-
shape.append(total)
405-
merging_sign.append((cursor_plan, sign))
406-
merging_reorder.append((cursor_plan, reorder))
407-
cursor_self = new_cursor_self
408-
cursor_plan += 1
378+
arrow.append(self.arrow[cursor_self])
379+
assert all(
380+
self_arrow == arrow[-1]
381+
for self_arrow in self.arrow[cursor_self:new_cursor_self]
382+
), (
383+
f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}."
384+
)
385+
edges.append((even, odd))
386+
shape.append(total)
387+
merging_sign.append((cursor_plan, sign))
388+
merging_reorder.append((cursor_plan, reorder))
389+
cursor_self = new_cursor_self
390+
cursor_plan += 1
409391
else:
410392
# Splitting between [cursor_plan, new_cursor_plan) and the another side contains dimension as plan_total
411393
new_cursor_plan = cursor_plan
@@ -469,9 +451,6 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
469451

470452
tensor = tensor.reshape(shape)
471453

472-
if len(new_shape) == 0:
473-
return GrassmannTensor(_arrow=tuple(arrow), _edges=tuple(edges), _tensor=tensor)
474-
475454
merging_parity = functools.reduce(
476455
torch.logical_xor,
477456
(

0 commit comments

Comments
 (0)