@@ -256,33 +256,44 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
256256        cursor_plan : int  =  0 
257257        cursor_self : int  =  0 
258258        while  cursor_plan  !=  len (new_shape ) or  cursor_self  !=  self .tensor .dim ():
259-             if  new_shape [cursor_plan ] ==  - 1 :
259+             if  len (new_shape ) ==  0 :
260+                 assert  all (edge  ==  (0 , 1 ) or  edge  ==  (1 , 0 ) for  edge  in  self .edges ), (
261+                     f"Edge must be (0, 1) or (1, 0) but got { self .edges }  
262+                 )
263+                 cursor_self  =  self .tensor .dim () -  1 
264+             elif  cursor_plan  !=  len (new_shape ) and  new_shape [cursor_plan ] ==  - 1 :
260265                # Does not change 
261266                arrow .append (self .arrow [cursor_self ])
262267                edges .append (self .edges [cursor_self ])
263268                shape .append (self .tensor .shape [cursor_self ])
264269                cursor_self  +=  1 
265270                cursor_plan  +=  1 
266271                continue 
267-             if  new_shape [cursor_plan ] ==  (1 , 0 ):
268-                 # An  trivial plan edge 
272+             elif   cursor_plan   !=   len ( new_shape )  and  new_shape [cursor_plan ] ==  (1 , 0 ):
273+                 # A  trivial plan edge 
269274                arrow .append (False )
270275                edges .append ((1 , 0 ))
271276                shape .append (1 )
272277                cursor_plan  +=  1 
273278                continue 
274-             if  self .edges [cursor_self ] ==  (1 , 0 ):
275-                 # An  trivial self edge 
279+             elif   cursor_self   !=   self . tensor . dim ()  and  self .edges [cursor_self ] ==  (1 , 0 ):
280+                 # A  trivial self edge 
276281                cursor_self  +=  1 
277282                continue 
278-             cursor_new_shape  =  new_shape [cursor_plan ]
279-             total  =  (
280-                 cursor_new_shape 
281-                 if  isinstance (cursor_new_shape , int )
282-                 else  cursor_new_shape [0 ] +  cursor_new_shape [1 ]
283-             )
283+             if  len (new_shape ) ==  0 :
284+                 cursor_new_shape  =  typing .cast (int  |  tuple [int , int ], tuple ())
285+                 total  =  1 
286+             else :
287+                 cursor_new_shape  =  new_shape [cursor_plan ]
288+                 total  =  (
289+                     cursor_new_shape 
290+                     if  isinstance (cursor_new_shape , int )
291+                     else  cursor_new_shape [0 ] +  cursor_new_shape [1 ]
292+                 )
284293            # one of total and shape[cursor_self] is not trivial, otherwise it should be handled before 
285-             if  total  ==  self .tensor .shape [cursor_self ]:
294+             if  self .tensor .dim () ==  0 :
295+                 merging  =  False 
296+             elif  total  ==  self .tensor .shape [cursor_self ]:
286297                # We do not know whether it is merging or splitting, check more 
287298                if  isinstance (cursor_new_shape , int ) or  cursor_new_shape  ==  self .edges [cursor_self ]:
288299                    # If the new shape is exactly the same as the current edge, we treat it as no change 
@@ -296,6 +307,9 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
296307                cursor_self_finding  =  cursor_self 
297308                cursor_self_found  =  False 
298309                while  True :
310+                     if  len (new_shape ) ==  0 :
311+                         cursor_self_found  =  True 
312+                         break 
299313                    cursor_self_finding  +=  1 
300314                    if  cursor_self_finding  ==  self .tensor .dim ():
301315                        break 
@@ -306,15 +320,19 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
306320                        break 
307321                    break 
308322                merging  =  cursor_self_found 
309-             if  total  >  self .tensor .shape [cursor_self ]:
323+             elif  total  >  self .tensor .shape [cursor_self ]:
310324                merging  =  True 
311-             if  total  <  self .tensor .shape [cursor_self ]:
325+             elif  total  <  self .tensor .shape [cursor_self ]:
312326                merging  =  False 
313327            if  merging :
314328                # Merging between [cursor_self, new_cursor_self) and the another side contains dimension as self_total 
315329                new_cursor_self  =  cursor_self 
316330                self_total  =  1 
317331                while  True :
332+                     if  len (new_shape ) ==  0 :
333+                         new_cursor_self  +=  1 
334+                         even , odd , reorder , sign  =  self ._reorder_indices (self .edges )
335+                         break 
318336                    # Try to include more dimension from self 
319337                    self_total  *=  self .tensor .shape [new_cursor_self ]
320338                    new_cursor_self  +=  1 
@@ -336,19 +354,26 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
336354                        f"New shape exceeds in merging with edges { self .edges } { new_shape }  
337355                    )
338356                # The merging block [cursor_self, new_cursor_self) has been determined 
339-                 arrow .append (self .arrow [cursor_self ])
340-                 assert  all (
341-                     self_arrow  ==  arrow [- 1 ]
342-                     for  self_arrow  in  self .arrow [cursor_self :new_cursor_self ]
343-                 ), (
344-                     f"Cannot merge edges with different arrows { self .arrow [cursor_self :new_cursor_self ]}  
345-                 )
346-                 edges .append ((even , odd ))
347-                 shape .append (total )
348-                 merging_sign .append ((cursor_plan , sign ))
349-                 merging_reorder .append ((cursor_plan , reorder ))
350-                 cursor_self  =  new_cursor_self 
351-                 cursor_plan  +=  1 
357+                 if  len (new_shape ) ==  0 :
358+                     arrow  =  []
359+                     edges  =  []
360+                     shape  =  []
361+                     merging_sign .append ((cursor_plan , sign ))
362+                     cursor_self  =  new_cursor_self 
363+                 else :
364+                     arrow .append (self .arrow [cursor_self ])
365+                     assert  all (
366+                         self_arrow  ==  arrow [- 1 ]
367+                         for  self_arrow  in  self .arrow [cursor_self :new_cursor_self ]
368+                     ), (
369+                         f"Cannot merge edges with different arrows { self .arrow [cursor_self :new_cursor_self ]}  
370+                     )
371+                     edges .append ((even , odd ))
372+                     shape .append (total )
373+                     merging_sign .append ((cursor_plan , sign ))
374+                     merging_reorder .append ((cursor_plan , reorder ))
375+                     cursor_self  =  new_cursor_self 
376+                     cursor_plan  +=  1 
352377            else :
353378                # Splitting between [cursor_plan, new_cursor_plan) and the another side contains dimension as plan_total 
354379                new_cursor_plan  =  cursor_plan 
@@ -362,15 +387,23 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
362387                    plan_total  *=  new_cursor_new_shape [0 ] +  new_cursor_new_shape [1 ]
363388                    new_cursor_plan  +=  1 
364389                    # One dimension included, check if we can stop 
365-                     if  plan_total  ==  self .tensor .shape [cursor_self ]:
366-                         # new_shape block has been verified to be always tuple[int, int] before 
390+                     if  self .tensor .dim () ==  0 :
367391                        even , odd , reorder , sign  =  self ._reorder_indices (
368-                             typing .cast (
369-                                 tuple [tuple [int , int ], ...], new_shape [cursor_plan :new_cursor_plan ]
370-                             )
392+                             typing .cast (tuple [tuple [int , int ], ...], new_shape )
371393                        )
372-                         if  (even , odd ) ==  self .edges [cursor_self ]:
373-                             break 
394+                         new_cursor_plan  =  len (new_shape )
395+                         break 
396+                     else :
397+                         if  plan_total  ==  self .tensor .shape [cursor_self ]:
398+                             # new_shape block has been verified to be always tuple[int, int] before 
399+                             even , odd , reorder , sign  =  self ._reorder_indices (
400+                                 typing .cast (
401+                                     tuple [tuple [int , int ], ...],
402+                                     new_shape [cursor_plan :new_cursor_plan ],
403+                                 )
404+                             )
405+                             if  (even , odd ) ==  self .edges [cursor_self ]:
406+                                 break 
374407                    # For some reason we cannot stop here, continue to include more dimension, check something before continue 
375408                    assert  plan_total  <=  self .tensor .shape [cursor_self ], (
376409                        f"Dimension mismatch in splitting with edges { self .edges } { new_shape }  
@@ -382,12 +415,16 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
382415                for  i  in  range (cursor_plan , new_cursor_plan ):
383416                    # new_shape block has been verified to be always tuple[int, int] in the loop 
384417                    new_cursor_new_shape  =  typing .cast (tuple [int , int ], new_shape [i ])
385-                     arrow .append (self .arrow [cursor_self ])
418+                     if  self .tensor .dim () ==  0 :
419+                         arrow .append (False )
420+                     else :
421+                         arrow .append (self .arrow [cursor_self ])
386422                    edges .append (new_cursor_new_shape )
387423                    shape .append (new_cursor_new_shape [0 ] +  new_cursor_new_shape [1 ])
388424                splitting_reorder .append ((cursor_self , reorder ))
389425                splitting_sign .append ((cursor_self , sign ))
390-                 cursor_self  +=  1 
426+                 if  self .tensor .dim () !=  0 :
427+                     cursor_self  +=  1 
391428                cursor_plan  =  new_cursor_plan 
392429
393430        tensor  =  self .tensor 
@@ -402,14 +439,17 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
402439            (
403440                self ._unsqueeze (sign , index , self .tensor .dim ())
404441                for  index , sign  in  splitting_sign 
405-                 if  self .arrow [index ]
442+                 if  self .tensor . dim ()  !=   0   and   self . arrow [index ]
406443            ),
407444            torch .zeros ([], dtype = torch .bool , device = self .tensor .device ),
408445        )
409446        tensor  =  torch .where (splitting_parity , - tensor , + tensor )
410447
411448        tensor  =  tensor .reshape (shape )
412449
450+         if  len (new_shape ) ==  0 :
451+             return  GrassmannTensor (_arrow = tuple (arrow ), _edges = tuple (edges ), _tensor = tensor )
452+ 
413453        merging_parity  =  functools .reduce (
414454            torch .logical_xor ,
415455            (
0 commit comments