@@ -221,6 +221,19 @@ def _reorder_indices(
221221 sign = (count & 2 ).to (dtype = torch .bool )
222222 return len (even ), len (odd ), reorder , sign .flatten ()
223223
224+ def _calculate_even_odd (self ) -> tuple [int , int ]:
225+ even , odd = 1 , 0
226+ for e , o in self .edges :
227+ even , odd = even * e + odd * o , even * o + odd * e
228+ return even , odd
229+
230+ @staticmethod
231+ def calculate_even_odd (edges : tuple [tuple [int , int ], ...]) -> tuple [int , int ]:
232+ even , odd = 1 , 0
233+ for e , o in edges :
234+ even , odd = even * e + odd * o , even * o + odd * e
235+ return even , odd
236+
224237 def reshape (self , new_shape : tuple [int | tuple [int , int ], ...]) -> GrassmannTensor :
225238 """
226239 Reshape the Grassmann tensor, which may split or merge edges.
@@ -253,12 +266,33 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
253266 merging_reorder : list [tuple [int , torch .Tensor ]] = []
254267 merging_sign : list [tuple [int , torch .Tensor ]] = []
255268
269+ original_self_is_scalar = self .tensor .dim () == 0
270+ if original_self_is_scalar :
271+ normalized : list [tuple [int , int ]] = []
272+ for item in new_shape :
273+ if item == - 1 :
274+ raise AssertionError ("Cannot use -1 when reshaping from a scalar" )
275+ if isinstance (item , int ):
276+ if item != 1 :
277+ raise AssertionError (
278+ f"Ambiguous integer dim { item } from scalar. "
279+ "Use explicit (even, odd) pairs, or only use 1 for trivial edges."
280+ )
281+ normalized .append ((1 , 0 ))
282+ else :
283+ normalized .append (item )
284+ new_shape = tuple (normalized )
285+ edges_only = typing .cast (tuple [tuple [int , int ], ...], new_shape )
286+ assert self .calculate_even_odd (edges_only ) == (1 , 0 ), (
287+ "Cannot split none edges into illegal edges"
288+ )
289+
256290 cursor_plan : int = 0
257291 cursor_self : int = 0
258292 while cursor_plan != len (new_shape ) or cursor_self != self .tensor .dim ():
259293 if len (new_shape ) == 0 :
260- assert all ( edge == ( 0 , 1 ) or edge == (1 , 0 ) for edge in self . edges ), (
261- f"Edge must be (0, 1) or (1, 0) but got { self . edges } "
294+ assert self . _calculate_even_odd () == (1 , 0 ), (
295+ "Only pure even edges can be merged into none edges"
262296 )
263297 cursor_self = self .tensor .dim () - 1
264298 elif cursor_plan != len (new_shape ) and new_shape [cursor_plan ] == - 1 :
@@ -291,9 +325,7 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
291325 else cursor_new_shape [0 ] + cursor_new_shape [1 ]
292326 )
293327 # one of total and shape[cursor_self] is not trivial, otherwise it should be handled before
294- if self .tensor .dim () == 0 :
295- merging = False
296- elif total == self .tensor .shape [cursor_self ]:
328+ if total == self .tensor .shape [cursor_self ]:
297329 # We do not know whether it is merging or splitting, check more
298330 if isinstance (cursor_new_shape , int ) or cursor_new_shape == self .edges [cursor_self ]:
299331 # If the new shape is exactly the same as the current edge, we treat it as no change
@@ -387,23 +419,16 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
387419 plan_total *= new_cursor_new_shape [0 ] + new_cursor_new_shape [1 ]
388420 new_cursor_plan += 1
389421 # One dimension included, check if we can stop
390- if self .tensor .dim () == 0 :
422+ if plan_total == self .tensor .shape [cursor_self ]:
423+ # new_shape block has been verified to be always tuple[int, int] before
391424 even , odd , reorder , sign = self ._reorder_indices (
392- typing .cast (tuple [tuple [int , int ], ...], new_shape )
393- )
394- new_cursor_plan = len (new_shape )
395- break
396- else :
397- if plan_total == self .tensor .shape [cursor_self ]:
398- # new_shape block has been verified to be always tuple[int, int] before
399- even , odd , reorder , sign = self ._reorder_indices (
400- typing .cast (
401- tuple [tuple [int , int ], ...],
402- new_shape [cursor_plan :new_cursor_plan ],
403- )
425+ typing .cast (
426+ tuple [tuple [int , int ], ...],
427+ new_shape [cursor_plan :new_cursor_plan ],
404428 )
405- if (even , odd ) == self .edges [cursor_self ]:
406- break
429+ )
430+ if (even , odd ) == self .edges [cursor_self ]:
431+ break
407432 # For some reason we cannot stop here, continue to include more dimension, check something before continue
408433 assert plan_total <= self .tensor .shape [cursor_self ], (
409434 f"Dimension mismatch in splitting with edges { self .edges } and new shape { new_shape } ."
@@ -415,10 +440,7 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
415440 for i in range (cursor_plan , new_cursor_plan ):
416441 # new_shape block has been verified to be always tuple[int, int] in the loop
417442 new_cursor_new_shape = typing .cast (tuple [int , int ], new_shape [i ])
418- if self .tensor .dim () == 0 :
419- arrow .append (False )
420- else :
421- arrow .append (self .arrow [cursor_self ])
443+ arrow .append (self .arrow [cursor_self ])
422444 edges .append (new_cursor_new_shape )
423445 shape .append (new_cursor_new_shape [0 ] + new_cursor_new_shape [1 ])
424446 splitting_reorder .append ((cursor_self , reorder ))
0 commit comments