Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions src/py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl CoreBPE {
Err(e) => return Err(PyErr::new::<exceptions::PyValueError, _>(e.message)),
};

let buffer = TiktokenBuffer { tokens };
let buffer = TiktokenBuffer::new(tokens);
buffer.into_py_any(py)
}

Expand Down Expand Up @@ -186,6 +186,17 @@ impl CoreBPE {
#[pyclass(frozen)]
struct TiktokenBuffer {
tokens: Vec<Rank>,
// Per the buffer protocol, `view.shape[0]` must be the element count
// (so that `prod(shape) * itemsize == len`). Box it so the pointer we
// hand out in `__getbuffer__` is stable for the lifetime of `self`.
shape: Box<pyo3::ffi::Py_ssize_t>,
}

impl TiktokenBuffer {
fn new(tokens: Vec<Rank>) -> Self {
let shape = Box::new(tokens.len() as pyo3::ffi::Py_ssize_t);
Self { tokens, shape }
}
}

#[pymethods]
Expand All @@ -208,7 +219,8 @@ impl TiktokenBuffer {
let view_ref = &mut *view;
view_ref.obj = slf.clone().into_any().into_ptr();

let data = &slf.borrow().tokens;
let borrowed = slf.borrow();
let data = &borrowed.tokens;
view_ref.buf = data.as_ptr() as *mut std::os::raw::c_void;
view_ref.len = (data.len() * std::mem::size_of::<Rank>()) as isize;
view_ref.readonly = 1;
Expand All @@ -221,7 +233,10 @@ impl TiktokenBuffer {
};
view_ref.ndim = 1;
view_ref.shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND {
&mut view_ref.len
// Element count, not byte length, per PEP 3118.
// The Box lives as long as `self` (kept alive via `view_ref.obj`),
// so this pointer is valid for the buffer view's lifetime.
&*borrowed.shape as *const pyo3::ffi::Py_ssize_t as *mut pyo3::ffi::Py_ssize_t
} else {
std::ptr::null_mut()
};
Expand Down
16 changes: 16 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,19 @@ def test_optional_blobfile_dependency():
assert "blobfile" not in sys.modules
"""
subprocess.check_call([sys.executable, "-c", prog])


def test_tiktoken_buffer_shape_is_element_count():
# Regression test for the buffer protocol's `shape` reporting bytes
# instead of element count, which caused `memoryview(...).tolist()`
# to expand each token into 4 byte-sized integers. See issue #405.
import math

enc = tiktoken.get_encoding("gpt2")
expected = enc.encode("hello world")
buf = enc._core_bpe.encode_to_tiktoken_buffer("hello world", frozenset())

mv = memoryview(buf)
# Per PEP 3118: prod(shape) * itemsize == nbytes
assert math.prod(mv.shape) * mv.itemsize == mv.nbytes
assert mv.tolist() == expected