Skip to content

Commit da5e0de

Browse files
authored
Merge pull request #74 from USTC-KnowledgeComputingLab/fix/reshape-plan-int1-splitting-assert
fix(reshape): fix reshape with head int1 as trivial Close #74
2 parents 492b569 + 1a135df commit da5e0de

File tree

2 files changed

+71
-2
lines changed

2 files changed

+71
-2
lines changed

grassmann_tensor/tensor.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,8 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
297297
return GrassmannTensor(_arrow=(), _edges=(), _tensor=tensor)
298298

299299
if new_shape == (1,) and int(self.tensor.numel()) == 1:
300-
eo = self._calculate_even_odd()
301-
new_shape = (eo,)
300+
even_self, odd_self = self._calculate_even_odd()
301+
new_shape = ((even_self, odd_self),)
302302

303303
cursor_plan: int = 0
304304
cursor_self: int = 0
@@ -318,6 +318,19 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
318318
f"edges={self.edges}, new_shape={new_shape}"
319319
)
320320

321+
if cursor_plan != len(new_shape):
322+
new_shape_check = new_shape[cursor_plan]
323+
if (
324+
isinstance(new_shape_check, int)
325+
and new_shape_check == 1
326+
and self.tensor.shape[cursor_self] != 1
327+
):
328+
arrow.append(False)
329+
edges.append((1, 0))
330+
shape.append(1)
331+
cursor_plan += 1
332+
continue
333+
321334
if cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1:
322335
# Does not change
323336
arrow.append(self.arrow[cursor_self])

tests/reshape_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,59 @@ def test_reshape_trailing_nontrivial_dim_raises() -> None:
231231
a = GrassmannTensor((True,), ((2, 2),), torch.randn([4]))
232232
with pytest.raises(AssertionError, match="New shape exceeds after exhausting self dimensions"):
233233
_ = a.reshape((-1, (2, 2)))
234+
235+
236+
@pytest.mark.parametrize(
237+
"tensor",
238+
[
239+
GrassmannTensor(
240+
(True, True, True, True),
241+
((1, 0), (1, 0), (2, 2), (8, 8)),
242+
torch.randn(1, 1, 4, 16),
243+
),
244+
],
245+
)
246+
@pytest.mark.parametrize(
247+
"shape",
248+
[
249+
(1, 64),
250+
((1, 0), 64),
251+
(-1, 64),
252+
],
253+
)
254+
def test_reshape_trivial_head_equivalence(
255+
tensor: GrassmannTensor,
256+
shape: tuple[int, ...],
257+
) -> None:
258+
baseline_tensor = tensor.reshape((1, 64))
259+
actual_tensor = tensor.reshape(shape)
260+
261+
assert actual_tensor.edges == ((1, 0), (32, 32))
262+
assert torch.allclose(actual_tensor.tensor, baseline_tensor.tensor)
263+
264+
roundtrip_tensor = actual_tensor.reshape(tensor.edges)
265+
assert torch.allclose(roundtrip_tensor.tensor, tensor.tensor)
266+
267+
268+
def test_reshape_head_1_inserts_trivial_when_self_dim_not_one() -> None:
269+
a = GrassmannTensor(
270+
(True, True),
271+
((2, 2), (8, 8)),
272+
torch.randn(4, 16),
273+
)
274+
out = a.reshape((1, 64))
275+
assert out.edges == ((1, 0), (32, 32))
276+
assert out.tensor.shape == (1, 64)
277+
assert out.arrow[0] is False
278+
279+
280+
def test_reshape_plan_exhausted_then_skip_trivial_self_edges() -> None:
281+
a = GrassmannTensor(
282+
(False, False, False),
283+
((2, 2), (1, 0), (1, 0)),
284+
torch.randn(4, 1, 1),
285+
)
286+
out = a.reshape((4,))
287+
assert out.edges == ((2, 2),)
288+
assert out.tensor.shape == (4,)
289+
assert out.arrow == (False,)

0 commit comments

Comments
 (0)