diff --git a/src/py.rs b/src/py.rs index 60b62452..59afabb3 100644 --- a/src/py.rs +++ b/src/py.rs @@ -65,7 +65,7 @@ impl CoreBPE { Err(e) => return Err(PyErr::new::(e.message)), }; - let buffer = TiktokenBuffer { tokens }; + let buffer = TiktokenBuffer::new(tokens); buffer.into_py_any(py) } @@ -186,6 +186,17 @@ impl CoreBPE { #[pyclass(frozen)] struct TiktokenBuffer { tokens: Vec, + // Per the buffer protocol, `view.shape[0]` must be the element count + // (so that `prod(shape) * itemsize == len`). Box it so the pointer we + // hand out in `__getbuffer__` is stable for the lifetime of `self`. + shape: Box, +} + +impl TiktokenBuffer { + fn new(tokens: Vec) -> Self { + let shape = Box::new(tokens.len() as pyo3::ffi::Py_ssize_t); + Self { tokens, shape } + } } #[pymethods] @@ -208,7 +219,8 @@ impl TiktokenBuffer { let view_ref = &mut *view; view_ref.obj = slf.clone().into_any().into_ptr(); - let data = &slf.borrow().tokens; + let borrowed = slf.borrow(); + let data = &borrowed.tokens; view_ref.buf = data.as_ptr() as *mut std::os::raw::c_void; view_ref.len = (data.len() * std::mem::size_of::()) as isize; view_ref.readonly = 1; @@ -221,7 +233,10 @@ impl TiktokenBuffer { }; view_ref.ndim = 1; view_ref.shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND { - &mut view_ref.len + // Element count, not byte length, per PEP 3118. + // The Box lives as long as `self` (kept alive via `view_ref.obj`), + // so this pointer is valid for the buffer view's lifetime. + &*borrowed.shape as *const pyo3::ffi::Py_ssize_t as *mut pyo3::ffi::Py_ssize_t } else { std::ptr::null_mut() }; diff --git a/tests/test_misc.py b/tests/test_misc.py index 0832c8ee..45746d2d 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -28,3 +28,19 @@ def test_optional_blobfile_dependency(): assert "blobfile" not in sys.modules """ subprocess.check_call([sys.executable, "-c", prog]) + + +def test_tiktoken_buffer_shape_is_element_count(): + # Regression test for the buffer protocol's `shape` reporting bytes + # instead of element count, which caused `memoryview(...).tolist()` + # to expand each token into 4 byte-sized integers. See issue #405. + import math + + enc = tiktoken.get_encoding("gpt2") + expected = enc.encode("hello world") + buf = enc._core_bpe.encode_to_tiktoken_buffer("hello world", frozenset()) + + mv = memoryview(buf) + # Per PEP 3118: prod(shape) * itemsize == nbytes + assert math.prod(mv.shape) * mv.itemsize == mv.nbytes + assert mv.tolist() == expected