Skip to content

Commit 456b5cc

Browse files
committed
fix reference count leak
1 parent 541e368 commit 456b5cc

File tree

3 files changed

+49
-33
lines changed

3 files changed

+49
-33
lines changed

Lib/test/test_defaultdict.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import unittest
66

77
from collections import defaultdict
8-
from threading import Condition, Thread
98

109
def foobar():
1110
return list
@@ -187,31 +186,5 @@ def test_union(self):
187186
with self.assertRaises(TypeError):
188187
i |= None
189188

190-
def test_default_race(self):
191-
cv = Condition()
192-
key = "default_race_key"
193-
194-
def default_factory(cv: Condition, cv_flag: list[bool]):
195-
with cv:
196-
while not cv_flag[0]:
197-
cv.wait()
198-
return "default_value"
199-
ready_flag = [False]
200-
test_dict = defaultdict(lambda: default_factory(cv, ready_flag))
201-
202-
def writer(cv: Condition, cv_flag: list[bool], race_dict: dict):
203-
with cv:
204-
race_dict[key] = "writer_value"
205-
cv_flag[0] = True
206-
cv.notify()
207-
208-
default_factory_thread = Thread(target=lambda: test_dict[key])
209-
writer_thread = Thread(target=lambda: writer(cv, ready_flag, test_dict))
210-
default_factory_thread.start()
211-
writer_thread.start()
212-
default_factory_thread.join()
213-
writer_thread.join()
214-
self.assertEqual(test_dict[key], "writer_value")
215-
216189
if __name__ == "__main__":
217190
unittest.main()
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import unittest
2+
from collections import defaultdict
3+
4+
from threading import Barrier, Thread
5+
from unittest import TestCase
6+
7+
try:
8+
import _testcapi
9+
except ImportError:
10+
_testcapi = None
11+
12+
from test.support import threading_helper
13+
14+
@threading_helper.requires_working_threading()
15+
class TestDefaultDict(TestCase):
16+
def test_default_factory_race(self):
17+
wait_barrier = Barrier(2)
18+
write_barrier = Barrier(2)
19+
key = "default_race_key"
20+
21+
def default_factory():
22+
wait_barrier.wait()
23+
write_barrier.wait()
24+
return "default_value"
25+
26+
test_dict = defaultdict(default_factory)
27+
28+
def writer():
29+
wait_barrier.wait()
30+
test_dict[key] = "writer_value"
31+
write_barrier.wait()
32+
33+
default_factory_thread = Thread(target=lambda: test_dict[key])
34+
writer_thread = Thread(target=writer)
35+
default_factory_thread.start()
36+
writer_thread.start()
37+
default_factory_thread.join()
38+
writer_thread.join()
39+
self.assertEqual(test_dict[key], "writer_value")
40+
41+
42+
if __name__ == "__main__":
43+
unittest.main()

Modules/_collectionsmodule.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,12 +2232,12 @@ defdict_missing(PyObject *op, PyObject *key)
22322232
if (value == NULL)
22332233
return value;
22342234
PyObject* result;
2235-
int res = PyDict_SetDefaultRef(op, key, value, &result);
2236-
if (res != 0) {
2237-
// when res < 0, result will be NULL
2238-
// when res > 0, result is a new reference to the existing value
2239-
Py_DECREF(value);
2240-
}
2235+
PyDict_SetDefaultRef(op, key, value, &result);
2236+
// when PyDict_SetDefaultRef() < 0, result will be NULL
2237+
// when PyDict_SetDefaultRef() == 0, result is a new reference of default value
2238+
// when PyDict_SetDefaultRef() > 0, result is a new reference to the existing value
2239+
// so the value reference must be decref'ed in all cases
2240+
Py_DECREF(value);
22412241
return result;
22422242
}
22432243

0 commit comments

Comments
 (0)