@@ -231,3 +231,59 @@ def test_reshape_trailing_nontrivial_dim_raises() -> None:
231231    a  =  GrassmannTensor ((True ,), ((2 , 2 ),), torch .randn ([4 ]))
232232    with  pytest .raises (AssertionError , match = "New shape exceeds after exhausting self dimensions" ):
233233        _  =  a .reshape ((- 1 , (2 , 2 )))
234+ 
235+ 
236+ @pytest .mark .parametrize ( 
237+     "tensor" , 
238+     [ 
239+         GrassmannTensor ( 
240+             (True , True , True , True ), 
241+             ((1 , 0 ), (1 , 0 ), (2 , 2 ), (8 , 8 )), 
242+             torch .randn (1 , 1 , 4 , 16 ), 
243+         ), 
244+     ], 
245+ ) 
246+ @pytest .mark .parametrize ( 
247+     "shape" , 
248+     [ 
249+         (1 , 64 ), 
250+         ((1 , 0 ), 64 ), 
251+         (- 1 , 64 ), 
252+     ], 
253+ ) 
254+ def  test_reshape_trivial_head_equivalence (
255+     tensor : GrassmannTensor ,
256+     shape : tuple [int , ...],
257+ ) ->  None :
258+     baseline_tensor  =  tensor .reshape ((1 , 64 ))
259+     actual_tensor  =  tensor .reshape (shape )
260+ 
261+     assert  actual_tensor .edges  ==  ((1 , 0 ), (32 , 32 ))
262+     assert  torch .allclose (actual_tensor .tensor , baseline_tensor .tensor )
263+ 
264+     roundtrip_tensor  =  actual_tensor .reshape (tensor .edges )
265+     assert  torch .allclose (roundtrip_tensor .tensor , tensor .tensor )
266+ 
267+ 
268+ def  test_reshape_head_1_inserts_trivial_when_self_dim_not_one () ->  None :
269+     a  =  GrassmannTensor (
270+         (True , True ),
271+         ((2 , 2 ), (8 , 8 )),
272+         torch .randn (4 , 16 ),
273+     )
274+     out  =  a .reshape ((1 , 64 ))
275+     assert  out .edges  ==  ((1 , 0 ), (32 , 32 ))
276+     assert  out .tensor .shape  ==  (1 , 64 )
277+     assert  out .arrow [0 ] is  False 
278+ 
279+ 
280+ def  test_reshape_plan_exhausted_then_skip_trivial_self_edges () ->  None :
281+     a  =  GrassmannTensor (
282+         (False , False , False ),
283+         ((2 , 2 ), (1 , 0 ), (1 , 0 )),
284+         torch .randn (4 , 1 , 1 ),
285+     )
286+     out  =  a .reshape ((4 ,))
287+     assert  out .edges  ==  ((2 , 2 ),)
288+     assert  out .tensor .shape  ==  (4 ,)
289+     assert  out .arrow  ==  (False ,)
0 commit comments