@@ -256,33 +256,44 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
256256 cursor_plan : int = 0
257257 cursor_self : int = 0
258258 while cursor_plan != len (new_shape ) or cursor_self != self .tensor .dim ():
259- if new_shape [cursor_plan ] == - 1 :
259+ if len (new_shape ) == 0 :
260+ assert all (edge == (0 , 1 ) or edge == (1 , 0 ) for edge in self .edges ), (
261+ f"Edge must be (0, 1) or (1, 0) but got { self .edges } "
262+ )
263+ cursor_self = self .tensor .dim () - 1
264+ elif cursor_plan != len (new_shape ) and new_shape [cursor_plan ] == - 1 :
260265 # Does not change
261266 arrow .append (self .arrow [cursor_self ])
262267 edges .append (self .edges [cursor_self ])
263268 shape .append (self .tensor .shape [cursor_self ])
264269 cursor_self += 1
265270 cursor_plan += 1
266271 continue
267- if new_shape [cursor_plan ] == (1 , 0 ):
268- # An trivial plan edge
272+ elif cursor_plan != len ( new_shape ) and new_shape [cursor_plan ] == (1 , 0 ):
273+ # A trivial plan edge
269274 arrow .append (False )
270275 edges .append ((1 , 0 ))
271276 shape .append (1 )
272277 cursor_plan += 1
273278 continue
274- if self .edges [cursor_self ] == (1 , 0 ):
275- # An trivial self edge
279+ elif cursor_self != self . tensor . dim () and self .edges [cursor_self ] == (1 , 0 ):
280+ # A trivial self edge
276281 cursor_self += 1
277282 continue
278- cursor_new_shape = new_shape [cursor_plan ]
279- total = (
280- cursor_new_shape
281- if isinstance (cursor_new_shape , int )
282- else cursor_new_shape [0 ] + cursor_new_shape [1 ]
283- )
283+ if len (new_shape ) == 0 :
284+ cursor_new_shape = typing .cast (int | tuple [int , int ], tuple ())
285+ total = 1
286+ else :
287+ cursor_new_shape = new_shape [cursor_plan ]
288+ total = (
289+ cursor_new_shape
290+ if isinstance (cursor_new_shape , int )
291+ else cursor_new_shape [0 ] + cursor_new_shape [1 ]
292+ )
284293 # one of total and shape[cursor_self] is not trivial, otherwise it should be handled before
285- if total == self .tensor .shape [cursor_self ]:
294+ if self .tensor .dim () == 0 :
295+ merging = False
296+ elif total == self .tensor .shape [cursor_self ]:
286297 # We do not know whether it is merging or splitting, check more
287298 if isinstance (cursor_new_shape , int ) or cursor_new_shape == self .edges [cursor_self ]:
288299 # If the new shape is exactly the same as the current edge, we treat it as no change
@@ -296,6 +307,9 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
296307 cursor_self_finding = cursor_self
297308 cursor_self_found = False
298309 while True :
310+ if len (new_shape ) == 0 :
311+ cursor_self_found = True
312+ break
299313 cursor_self_finding += 1
300314 if cursor_self_finding == self .tensor .dim ():
301315 break
@@ -306,15 +320,19 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
306320 break
307321 break
308322 merging = cursor_self_found
309- if total > self .tensor .shape [cursor_self ]:
323+ elif total > self .tensor .shape [cursor_self ]:
310324 merging = True
311- if total < self .tensor .shape [cursor_self ]:
325+ elif total < self .tensor .shape [cursor_self ]:
312326 merging = False
313327 if merging :
314328 # Merging between [cursor_self, new_cursor_self) and the another side contains dimension as self_total
315329 new_cursor_self = cursor_self
316330 self_total = 1
317331 while True :
332+ if len (new_shape ) == 0 :
333+ new_cursor_self += 1
334+ even , odd , reorder , sign = self ._reorder_indices (self .edges )
335+ break
318336 # Try to include more dimension from self
319337 self_total *= self .tensor .shape [new_cursor_self ]
320338 new_cursor_self += 1
@@ -336,19 +354,26 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
336354 f"New shape exceeds in merging with edges { self .edges } and new shape { new_shape } ."
337355 )
338356 # The merging block [cursor_self, new_cursor_self) has been determined
339- arrow .append (self .arrow [cursor_self ])
340- assert all (
341- self_arrow == arrow [- 1 ]
342- for self_arrow in self .arrow [cursor_self :new_cursor_self ]
343- ), (
344- f"Cannot merge edges with different arrows { self .arrow [cursor_self :new_cursor_self ]} ."
345- )
346- edges .append ((even , odd ))
347- shape .append (total )
348- merging_sign .append ((cursor_plan , sign ))
349- merging_reorder .append ((cursor_plan , reorder ))
350- cursor_self = new_cursor_self
351- cursor_plan += 1
357+ if len (new_shape ) == 0 :
358+ arrow = []
359+ edges = []
360+ shape = []
361+ merging_sign .append ((cursor_plan , sign ))
362+ cursor_self = new_cursor_self
363+ else :
364+ arrow .append (self .arrow [cursor_self ])
365+ assert all (
366+ self_arrow == arrow [- 1 ]
367+ for self_arrow in self .arrow [cursor_self :new_cursor_self ]
368+ ), (
369+ f"Cannot merge edges with different arrows { self .arrow [cursor_self :new_cursor_self ]} ."
370+ )
371+ edges .append ((even , odd ))
372+ shape .append (total )
373+ merging_sign .append ((cursor_plan , sign ))
374+ merging_reorder .append ((cursor_plan , reorder ))
375+ cursor_self = new_cursor_self
376+ cursor_plan += 1
352377 else :
353378 # Splitting between [cursor_plan, new_cursor_plan) and the another side contains dimension as plan_total
354379 new_cursor_plan = cursor_plan
@@ -362,15 +387,23 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
362387 plan_total *= new_cursor_new_shape [0 ] + new_cursor_new_shape [1 ]
363388 new_cursor_plan += 1
364389 # One dimension included, check if we can stop
365- if plan_total == self .tensor .shape [cursor_self ]:
366- # new_shape block has been verified to be always tuple[int, int] before
390+ if self .tensor .dim () == 0 :
367391 even , odd , reorder , sign = self ._reorder_indices (
368- typing .cast (
369- tuple [tuple [int , int ], ...], new_shape [cursor_plan :new_cursor_plan ]
370- )
392+ typing .cast (tuple [tuple [int , int ], ...], new_shape )
371393 )
372- if (even , odd ) == self .edges [cursor_self ]:
373- break
394+ new_cursor_plan = len (new_shape )
395+ break
396+ else :
397+ if plan_total == self .tensor .shape [cursor_self ]:
398+ # new_shape block has been verified to be always tuple[int, int] before
399+ even , odd , reorder , sign = self ._reorder_indices (
400+ typing .cast (
401+ tuple [tuple [int , int ], ...],
402+ new_shape [cursor_plan :new_cursor_plan ],
403+ )
404+ )
405+ if (even , odd ) == self .edges [cursor_self ]:
406+ break
374407 # For some reason we cannot stop here, continue to include more dimension, check something before continue
375408 assert plan_total <= self .tensor .shape [cursor_self ], (
376409 f"Dimension mismatch in splitting with edges { self .edges } and new shape { new_shape } ."
@@ -382,12 +415,16 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
382415 for i in range (cursor_plan , new_cursor_plan ):
383416 # new_shape block has been verified to be always tuple[int, int] in the loop
384417 new_cursor_new_shape = typing .cast (tuple [int , int ], new_shape [i ])
385- arrow .append (self .arrow [cursor_self ])
418+ if self .tensor .dim () == 0 :
419+ arrow .append (False )
420+ else :
421+ arrow .append (self .arrow [cursor_self ])
386422 edges .append (new_cursor_new_shape )
387423 shape .append (new_cursor_new_shape [0 ] + new_cursor_new_shape [1 ])
388424 splitting_reorder .append ((cursor_self , reorder ))
389425 splitting_sign .append ((cursor_self , sign ))
390- cursor_self += 1
426+ if self .tensor .dim () != 0 :
427+ cursor_self += 1
391428 cursor_plan = new_cursor_plan
392429
393430 tensor = self .tensor
@@ -402,14 +439,17 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
402439 (
403440 self ._unsqueeze (sign , index , self .tensor .dim ())
404441 for index , sign in splitting_sign
405- if self .arrow [index ]
442+ if self .tensor . dim () != 0 and self . arrow [index ]
406443 ),
407444 torch .zeros ([], dtype = torch .bool , device = self .tensor .device ),
408445 )
409446 tensor = torch .where (splitting_parity , - tensor , + tensor )
410447
411448 tensor = tensor .reshape (shape )
412449
450+ if len (new_shape ) == 0 :
451+ return GrassmannTensor (_arrow = tuple (arrow ), _edges = tuple (edges ), _tensor = tensor )
452+
413453 merging_parity = functools .reduce (
414454 torch .logical_xor ,
415455 (
0 commit comments