Skip to content

Commit 5c0dcb3

Browse files
authored
gh-146059: Call fast_save_leave() in pickle save_frozenset() (#146173)
Add more pickle tests: test also nested structures.
1 parent e44993a commit 5c0dcb3

File tree

2 files changed

+108
-4
lines changed

2 files changed

+108
-4
lines changed

Lib/test/pickletester.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
# kind of outer loop.
5858
protocols = range(pickle.HIGHEST_PROTOCOL + 1)
5959

60+
FAST_NESTING_LIMIT = 50
61+
6062

6163
# Return True if opcode code appears in the pickle, else False.
6264
def opcode_in_pickle(code, pickle):
@@ -4552,6 +4554,98 @@ def __reduce__(self):
45524554
expected = "changed size during iteration"
45534555
self.assertIn(expected, str(e))
45544556

4557+
def fast_save_enter(self, create_data, minprotocol=0):
4558+
# gh-146059: Check that fast_save() is called when
4559+
# fast_save_enter() is called.
4560+
if not hasattr(self, "pickler"):
4561+
self.skipTest("need Pickler class")
4562+
4563+
data = [create_data(i) for i in range(FAST_NESTING_LIMIT * 2)]
4564+
data = {"key": data}
4565+
protocols = range(minprotocol, pickle.HIGHEST_PROTOCOL + 1)
4566+
for proto in protocols:
4567+
with self.subTest(proto=proto):
4568+
buf = io.BytesIO()
4569+
pickler = self.pickler(buf, protocol=proto)
4570+
# Enable fast mode (disables memo, enables cycle detection)
4571+
pickler.fast = 1
4572+
pickler.dump(data)
4573+
4574+
buf.seek(0)
4575+
data2 = self.unpickler(buf).load()
4576+
self.assertEqual(data2, data)
4577+
4578+
def test_fast_save_enter_tuple(self):
4579+
self.fast_save_enter(lambda i: (i,))
4580+
4581+
def test_fast_save_enter_list(self):
4582+
self.fast_save_enter(lambda i: [i])
4583+
4584+
def test_fast_save_enter_frozenset(self):
4585+
self.fast_save_enter(lambda i: frozenset([i]))
4586+
4587+
def test_fast_save_enter_set(self):
4588+
self.fast_save_enter(lambda i: set([i]))
4589+
4590+
def test_fast_save_enter_frozendict(self):
4591+
if self.py_version < (3, 15):
4592+
self.skipTest('need frozendict')
4593+
self.fast_save_enter(lambda i: frozendict(key=i), minprotocol=2)
4594+
4595+
def test_fast_save_enter_dict(self):
4596+
self.fast_save_enter(lambda i: {"key": i})
4597+
4598+
def deep_nested_struct(self, seed, create_nested,
4599+
minprotocol=0, compare_equal=True,
4600+
depth=FAST_NESTING_LIMIT * 2):
4601+
# gh-146059: Check that fast_save() is called when
4602+
# fast_save_enter() is called.
4603+
if not hasattr(self, "pickler"):
4604+
self.skipTest("need Pickler class")
4605+
4606+
data = seed
4607+
for i in range(depth):
4608+
data = create_nested(data)
4609+
data = {"key": data}
4610+
protocols = range(minprotocol, pickle.HIGHEST_PROTOCOL + 1)
4611+
for proto in protocols:
4612+
with self.subTest(proto=proto):
4613+
buf = io.BytesIO()
4614+
pickler = self.pickler(buf, protocol=proto)
4615+
# Enable fast mode (disables memo, enables cycle detection)
4616+
pickler.fast = 1
4617+
pickler.dump(data)
4618+
4619+
buf.seek(0)
4620+
data2 = self.unpickler(buf).load()
4621+
if compare_equal:
4622+
self.assertEqual(data2, data)
4623+
4624+
def test_deep_nested_struct_tuple(self):
4625+
self.deep_nested_struct((1,), lambda data: (data,))
4626+
4627+
def test_deep_nested_struct_list(self):
4628+
self.deep_nested_struct([1], lambda data: [data])
4629+
4630+
def test_deep_nested_struct_frozenset(self):
4631+
self.deep_nested_struct(frozenset((1,)),
4632+
lambda data: frozenset((1, data)))
4633+
4634+
def test_deep_nested_struct_set(self):
4635+
self.deep_nested_struct({1}, lambda data: {K(data)},
4636+
depth=FAST_NESTING_LIMIT+1,
4637+
compare_equal=False)
4638+
4639+
def test_deep_nested_struct_frozendict(self):
4640+
if self.py_version < (3, 15):
4641+
self.skipTest('need frozendict')
4642+
self.deep_nested_struct(frozendict(x=1),
4643+
lambda data: frozendict(x=data),
4644+
minprotocol=2)
4645+
4646+
def test_deep_nested_struct_dict(self):
4647+
self.deep_nested_struct({'x': 1}, lambda data: {'x': data})
4648+
45554649

45564650
class BigmemPickleTests:
45574651

Modules/_pickle.c

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3671,16 +3671,13 @@ save_set(PickleState *state, PicklerObject *self, PyObject *obj)
36713671
}
36723672

36733673
static int
3674-
save_frozenset(PickleState *state, PicklerObject *self, PyObject *obj)
3674+
save_frozenset_impl(PickleState *state, PicklerObject *self, PyObject *obj)
36753675
{
36763676
PyObject *iter;
36773677

36783678
const char mark_op = MARK;
36793679
const char frozenset_op = FROZENSET;
36803680

3681-
if (self->fast && !fast_save_enter(self, obj))
3682-
return -1;
3683-
36843681
if (self->proto < 4) {
36853682
PyObject *items;
36863683
PyObject *reduce_value;
@@ -3751,6 +3748,19 @@ save_frozenset(PickleState *state, PicklerObject *self, PyObject *obj)
37513748
return 0;
37523749
}
37533750

3751+
static int
3752+
save_frozenset(PickleState *state, PicklerObject *self, PyObject *obj)
3753+
{
3754+
if (self->fast && !fast_save_enter(self, obj)) {
3755+
return -1;
3756+
}
3757+
int status = save_frozenset_impl(state, self, obj);
3758+
if (self->fast && !fast_save_leave(self, obj)) {
3759+
return -1;
3760+
}
3761+
return status;
3762+
}
3763+
37543764
static int
37553765
fix_imports(PickleState *st, PyObject **module_name, PyObject **global_name)
37563766
{

0 commit comments

Comments
 (0)