Skip to content

Commit 58555d8

Browse files
authored
Implement garbage collection for Rust data holding refs to Python objects (#416)
1 parent e976b81 commit 58555d8

File tree

6 files changed

+135
-2
lines changed

6 files changed

+135
-2
lines changed

src/serializers/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::fmt::Debug;
22

33
use pyo3::prelude::*;
44
use pyo3::types::{PyBytes, PyDict};
5+
use pyo3::{PyTraverseError, PyVisit};
56

67
use crate::build_context::BuildContext;
78
use crate::validators::SelfValidator;
@@ -136,6 +137,21 @@ impl SchemaSerializer {
136137
self.serializer, self.slots
137138
)
138139
}
140+
141+
// Garbage-collection traversal hook for `SchemaSerializer`.
// Hands the visitor to the root serializer first, then to every entry in
// `slots`; the first error returned by any of them is propagated via `?`.
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
142+
self.serializer.py_gc_traverse(&visit)?;
143+
for slot in self.slots.iter() {
144+
slot.py_gc_traverse(&visit)?;
145+
}
146+
Ok(())
147+
}
148+
149+
// Garbage-collection clear hook for `SchemaSerializer`: forwards
// `py_gc_clear` to the root serializer and to every slot.
// NOTE(review): the only `py_gc_clear` visible in this diff is the no-op trait
// default — confirm whether any serializer actually releases its Python
// references here.
fn __clear__(&mut self) {
150+
self.serializer.py_gc_clear();
151+
for slot in self.slots.iter_mut() {
152+
slot.py_gc_clear();
153+
}
154+
}
139155
}
140156

141157
#[allow(clippy::too_many_arguments)]

src/serializers/shared.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ use std::borrow::Cow;
22
use std::fmt::Debug;
33

44
use pyo3::exceptions::PyTypeError;
5-
use pyo3::intern;
65
use pyo3::prelude::*;
76
use pyo3::types::{PyDict, PySet};
7+
use pyo3::{intern, PyTraverseError, PyVisit};
88

99
use enum_dispatch::enum_dispatch;
1010
use serde::Serialize;
@@ -225,6 +225,10 @@ impl BuildSerializer for CombinedSerializer {
225225

226226
#[enum_dispatch(CombinedSerializer)]
227227
pub(crate) trait TypeSerializer: Send + Sync + Clone + Debug {
228+
/// Default GC traversal: visits nothing and returns `Ok(())`.
/// Presumably meant to be overridden by serializers that hold Python
/// objects (as `ModelSerializer` does) — confirm for other implementors.
fn py_gc_traverse(&self, _visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
229+
Ok(())
230+
}
231+
/// Default GC clear: does nothing.
fn py_gc_clear(&mut self) {}
228232
fn to_python(
229233
&self,
230234
value: &PyAny,

src/serializers/type_serializers/model.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ impl ModelSerializer {
5555
}
5656

5757
impl TypeSerializer for ModelSerializer {
58+
/// GC traversal for `ModelSerializer`: reports the held `class` object to the
/// visitor, then recurses into the inner `serializer`. Errors from either
/// step are propagated via `?`.
fn py_gc_traverse(&self, visit: &pyo3::PyVisit<'_>) -> Result<(), pyo3::PyTraverseError> {
59+
visit.call(&self.class)?;
60+
self.serializer.py_gc_traverse(visit)?;
61+
Ok(())
62+
}
63+
5864
fn to_python(
5965
&self,
6066
value: &PyAny,

src/validators/mod.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use std::fmt::Debug;
22

33
use enum_dispatch::enum_dispatch;
44

5-
use pyo3::intern;
65
use pyo3::once_cell::GILOnceCell;
76
use pyo3::prelude::*;
87
use pyo3::types::{PyAny, PyDict};
8+
use pyo3::{intern, PyTraverseError, PyVisit};
99

1010
use crate::build_context::BuildContext;
1111
use crate::build_tools::{py_err, py_error_type, SchemaDict, SchemaError};
@@ -211,6 +211,22 @@ impl SchemaValidator {
211211
self.slots,
212212
)
213213
}
214+
215+
// Garbage-collection traversal hook for `SchemaValidator`.
// Order: the root validator, then the stored `schema` object, then every
// entry in `slots`. The first error encountered is propagated via `?`.
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
216+
self.validator.py_gc_traverse(&visit)?;
217+
visit.call(&self.schema)?;
218+
for slot in self.slots.iter() {
219+
slot.py_gc_traverse(&visit)?;
220+
}
221+
Ok(())
222+
}
223+
224+
// Garbage-collection clear hook for `SchemaValidator`: forwards
// `py_gc_clear` to the root validator and to every slot.
// NOTE(review): `self.schema` is visited in `__traverse__` but is not touched
// here, and the only `py_gc_clear` visible in this diff is the no-op trait
// default — confirm that this is intentional.
fn __clear__(&mut self) {
225+
self.validator.py_gc_clear();
226+
for slot in self.slots.iter_mut() {
227+
slot.py_gc_clear();
228+
}
229+
}
214230
}
215231

216232
impl SchemaValidator {
@@ -548,6 +564,10 @@ pub enum CombinedValidator {
548564
/// validators defined in `build_validator` also need `EXPECTED_TYPE` as a const, but that can't be part of the trait
549565
#[enum_dispatch(CombinedValidator)]
550566
pub trait Validator: Send + Sync + Clone + Debug {
567+
/// Default GC traversal: visits nothing and returns `Ok(())`.
/// Presumably meant to be overridden by validators that hold Python
/// objects (as `ModelValidator` does) — confirm for other implementors.
fn py_gc_traverse(&self, _visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
568+
Ok(())
569+
}
570+
/// Default GC clear: does nothing.
fn py_gc_clear(&mut self) {}
551571
/// Do the actual validation for this schema/type
552572
fn validate<'s, 'data>(
553573
&'s self,

src/validators/model.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ impl BuildValidator for ModelValidator {
6565
}
6666

6767
impl Validator for ModelValidator {
68+
/// GC traversal for `ModelValidator`: reports the held `class` object to the
/// visitor, then recurses into the inner `validator`. Errors from either
/// step are propagated via `?`.
fn py_gc_traverse(&self, visit: &pyo3::PyVisit<'_>) -> Result<(), pyo3::PyTraverseError> {
69+
visit.call(&self.class)?;
70+
self.validator.py_gc_traverse(visit)?;
71+
Ok(())
72+
}
6873
fn validate<'s, 'data>(
6974
&'s self,
7075
py: Python<'data>,

tests/test_garbage_collection.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import gc
2+
import platform
3+
from typing import Any
4+
from weakref import WeakValueDictionary
5+
6+
import pytest
7+
8+
from pydantic_core import SchemaSerializer, SchemaValidator, core_schema
9+
10+
11+
# Regression test: a class that stores a SchemaSerializer built from a schema
# referring back to that same class must still be collectable by the GC.
# Marked as an expected failure on PyPy (see the linked PyPy issue).
@pytest.mark.xfail(
12+
condition=platform.python_implementation() == 'PyPy', reason='https://foss.heptapod.net/pypy/pypy/-/issues/3899'
13+
)
14+
def test_gc_schema_serializer() -> None:
15+
# test for https://github.com/pydantic/pydantic/issues/5136
16+
class BaseModel:
17+
__schema__: SchemaSerializer
18+
19+
# Each subclass gets its own serializer whose model schema holds `cls`,
# creating a class -> serializer -> class reference cycle.
def __init_subclass__(cls) -> None:
20+
cls.__schema__ = SchemaSerializer(
21+
core_schema.model_schema(
22+
cls,
23+
core_schema.typed_dict_schema(
24+
# Field 'x' refers back to the definition registered below as ref='model'.
{'x': core_schema.typed_dict_field(core_schema.definition_reference_schema('model'))}
25+
),
26+
ref='model',
27+
)
28+
)
29+
30+
# Weak-valued mapping: an entry disappears once its class object is freed.
cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary()
31+
32+
for _ in range(10_000):
33+
34+
class MyModel(BaseModel):
35+
pass
36+
37+
cache[id(MyModel)] = MyModel
38+
39+
# Remove the local name so the loop variable does not keep the class alive.
del MyModel
40+
41+
# Explicitly run a collection for generations 0, 1 and 2.
gc.collect(0)
42+
gc.collect(1)
43+
gc.collect(2)
44+
45+
# Every class created in the loop must have been collected.
assert len(cache) == 0
46+
47+
48+
# Regression test: a class that stores a SchemaValidator built from a schema
# referring back to that same class must still be collectable by the GC.
# Marked as an expected failure on PyPy (see the linked PyPy issue).
@pytest.mark.xfail(
49+
condition=platform.python_implementation() == 'PyPy', reason='https://foss.heptapod.net/pypy/pypy/-/issues/3899'
50+
)
51+
def test_gc_schema_validator() -> None:
52+
# test for https://github.com/pydantic/pydantic/issues/5136
53+
class BaseModel:
54+
__validator__: SchemaValidator
55+
56+
# Each subclass gets its own validator whose model schema holds `cls`,
# creating a class -> validator -> class reference cycle.
def __init_subclass__(cls) -> None:
57+
cls.__validator__ = SchemaValidator(
58+
core_schema.model_schema(
59+
cls,
60+
core_schema.typed_dict_schema(
61+
# Field 'x' refers back to the definition registered below as ref='model'.
{'x': core_schema.typed_dict_field(core_schema.definition_reference_schema('model'))}
62+
),
63+
ref='model',
64+
)
65+
)
66+
67+
# Weak-valued mapping: an entry disappears once its class object is freed.
cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary()
68+
69+
for _ in range(10_000):
70+
71+
class MyModel(BaseModel):
72+
pass
73+
74+
cache[id(MyModel)] = MyModel
75+
76+
# Remove the local name so the loop variable does not keep the class alive.
del MyModel
77+
78+
# Explicitly run a collection for generations 0, 1 and 2.
gc.collect(0)
79+
gc.collect(1)
80+
gc.collect(2)
81+
82+
# Every class created in the loop must have been collected.
assert len(cache) == 0

0 commit comments

Comments
 (0)