@@ -222,17 +222,15 @@ def _reorder_indices(
222222 return len (even ), len (odd ), reorder , sign .flatten ()
223223
224224 def _calculate_even_odd (self ) -> tuple [int , int ]:
225- even , odd = 1 , 0
226- for e , o in self .edges :
227- even , odd = even * e + odd * o , even * o + odd * e
228- return even , odd
225+ return self .calculate_even_odd (self .edges )
229226
230227 @staticmethod
231228 def calculate_even_odd (edges : tuple [tuple [int , int ], ...]) -> tuple [int , int ]:
232- even , odd = 1 , 0
233- for e , o in edges :
234- even , odd = even * e + odd * o , even * o + odd * e
235- return even , odd
229+ return functools .reduce (
230+ lambda acc , eo : (acc [0 ] * eo [0 ] + acc [1 ] * eo [1 ], acc [0 ] * eo [1 ] + acc [1 ] * eo [0 ]),
231+ edges ,
232+ (1 , 0 ),
233+ )
236234
237235 def reshape (self , new_shape : tuple [int | tuple [int , int ], ...]) -> GrassmannTensor :
238236 """
@@ -287,15 +285,17 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
287285 "Cannot split none edges into illegal edges"
288286 )
289287
288+ if len (new_shape ) == 0 :
289+ assert self ._calculate_even_odd () == (1 , 0 ), (
290+ "Only pure even edges can be merged into none edges"
291+ )
292+ tensor = self .tensor .reshape (())
293+ return GrassmannTensor (_arrow = (), _edges = (), _tensor = tensor )
294+
290295 cursor_plan : int = 0
291296 cursor_self : int = 0
292297 while cursor_plan != len (new_shape ) or cursor_self != self .tensor .dim ():
293- if len (new_shape ) == 0 :
294- assert self ._calculate_even_odd () == (1 , 0 ), (
295- "Only pure even edges can be merged into none edges"
296- )
297- cursor_self = self .tensor .dim () - 1
298- elif cursor_plan != len (new_shape ) and new_shape [cursor_plan ] == - 1 :
298+ if cursor_plan != len (new_shape ) and new_shape [cursor_plan ] == - 1 :
299299 # Does not change
300300 arrow .append (self .arrow [cursor_self ])
301301 edges .append (self .edges [cursor_self ])
@@ -314,16 +314,12 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
314314 # A trivial self edge
315315 cursor_self += 1
316316 continue
317- if len (new_shape ) == 0 :
318- cursor_new_shape = typing .cast (int | tuple [int , int ], tuple ())
319- total = 1
320- else :
321- cursor_new_shape = new_shape [cursor_plan ]
322- total = (
323- cursor_new_shape
324- if isinstance (cursor_new_shape , int )
325- else cursor_new_shape [0 ] + cursor_new_shape [1 ]
326- )
317+ cursor_new_shape = new_shape [cursor_plan ]
318+ total = (
319+ cursor_new_shape
320+ if isinstance (cursor_new_shape , int )
321+ else cursor_new_shape [0 ] + cursor_new_shape [1 ]
322+ )
327323 # one of total and shape[cursor_self] is not trivial, otherwise it should be handled before
328324 if total == self .tensor .shape [cursor_self ]:
329325 # We do not know whether it is merging or splitting, check more
@@ -339,9 +335,6 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
339335 cursor_self_finding = cursor_self
340336 cursor_self_found = False
341337 while True :
342- if len (new_shape ) == 0 :
343- cursor_self_found = True
344- break
345338 cursor_self_finding += 1
346339 if cursor_self_finding == self .tensor .dim ():
347340 break
@@ -361,10 +354,6 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
361354 new_cursor_self = cursor_self
362355 self_total = 1
363356 while True :
364- if len (new_shape ) == 0 :
365- new_cursor_self += 1
366- even , odd , reorder , sign = self ._reorder_indices (self .edges )
367- break
368357 # Try to include more dimension from self
369358 self_total *= self .tensor .shape [new_cursor_self ]
370359 new_cursor_self += 1
@@ -386,26 +375,19 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
386375 f"New shape exceeds in merging with edges { self .edges } and new shape { new_shape } ."
387376 )
388377 # The merging block [cursor_self, new_cursor_self) has been determined
389- if len (new_shape ) == 0 :
390- arrow = []
391- edges = []
392- shape = []
393- merging_sign .append ((cursor_plan , sign ))
394- cursor_self = new_cursor_self
395- else :
396- arrow .append (self .arrow [cursor_self ])
397- assert all (
398- self_arrow == arrow [- 1 ]
399- for self_arrow in self .arrow [cursor_self :new_cursor_self ]
400- ), (
401- f"Cannot merge edges with different arrows { self .arrow [cursor_self :new_cursor_self ]} ."
402- )
403- edges .append ((even , odd ))
404- shape .append (total )
405- merging_sign .append ((cursor_plan , sign ))
406- merging_reorder .append ((cursor_plan , reorder ))
407- cursor_self = new_cursor_self
408- cursor_plan += 1
378+ arrow .append (self .arrow [cursor_self ])
379+ assert all (
380+ self_arrow == arrow [- 1 ]
381+ for self_arrow in self .arrow [cursor_self :new_cursor_self ]
382+ ), (
383+ f"Cannot merge edges with different arrows { self .arrow [cursor_self :new_cursor_self ]} ."
384+ )
385+ edges .append ((even , odd ))
386+ shape .append (total )
387+ merging_sign .append ((cursor_plan , sign ))
388+ merging_reorder .append ((cursor_plan , reorder ))
389+ cursor_self = new_cursor_self
390+ cursor_plan += 1
409391 else :
410392 # Splitting between [cursor_plan, new_cursor_plan) and the another side contains dimension as plan_total
411393 new_cursor_plan = cursor_plan
@@ -469,9 +451,6 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
469451
470452 tensor = tensor .reshape (shape )
471453
472- if len (new_shape ) == 0 :
473- return GrassmannTensor (_arrow = tuple (arrow ), _edges = tuple (edges ), _tensor = tensor )
474-
475454 merging_parity = functools .reduce (
476455 torch .logical_xor ,
477456 (
0 commit comments