Skip to content

Commit

Permalink
Fix buffer protocol implementation
Browse files Browse the repository at this point in the history
According to the buffer protocol, `ndim` is a _required_ field [1], and
should always be set correctly. Additionally, `shape` should be set if
flags includes `PyBUF_ND` or higher [2]. The current implementation only
set those fields if flags was `PyBUF_STRIDES`.

[1] https://docs.python.org/3/c-api/buffer.html#request-independent-fields
[2] https://docs.python.org/3/c-api/buffer.html#shape-strides-suboffsets
  • Loading branch information
QuLogic committed Oct 22, 2024
1 parent f7e14e9 commit 9f36b54
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 3 deletions.
7 changes: 4 additions & 3 deletions include/pybind11/detail/class.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,9 +602,9 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla
return -1;
}
view->obj = obj;
view->ndim = 1;
view->internal = info;
view->buf = info->ptr;
view->ndim = (int) info->ndim;
view->itemsize = info->itemsize;
view->len = view->itemsize;
for (auto s : info->shape) {
Expand All @@ -614,10 +614,11 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
view->format = const_cast<char *>(info->format.c_str());
}
if ((flags & PyBUF_ND) == PyBUF_ND) {
view->shape = info->shape.data();
}
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
view->ndim = (int) info->ndim;
view->strides = info->strides.data();
view->shape = info->shape.data();
}
Py_INCREF(view->obj);
return 0;
Expand Down
48 changes: 48 additions & 0 deletions tests/test_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,4 +268,52 @@ TEST_SUBMODULE(buffers, m) {
});

m.def("get_buffer_info", [](const py::buffer &buffer) { return buffer.request(); });

// Expose Py_buffer for testing.
m.attr("PyBUF_SIMPLE") = PyBUF_SIMPLE;
m.attr("PyBUF_ND") = PyBUF_ND;
m.attr("PyBUF_STRIDES") = PyBUF_STRIDES;
m.attr("PyBUF_INDIRECT") = PyBUF_INDIRECT;

m.def("get_py_buffer", [](const py::object &object, int flags) {
Py_buffer buffer;
memset(&buffer, 0, sizeof(Py_buffer));
if (PyObject_GetBuffer(object.ptr(), &buffer, flags) == -1) {
throw py::error_already_set();
}

auto SimpleNamespace = py::module_::import("types").attr("SimpleNamespace");
py::object result = SimpleNamespace("len"_a = buffer.len,
"readonly"_a = buffer.readonly,
"itemsize"_a = buffer.itemsize,
"format"_a = buffer.format,
"ndim"_a = buffer.ndim,
"shape"_a = py::none(),
"strides"_a = py::none(),
"suboffsets"_a = py::none());
if (buffer.shape != nullptr) {
py::list l;
for (auto i = 0; i < buffer.ndim; i++) {
l.append(buffer.shape[i]);
}
py::setattr(result, "shape", l);
}
if (buffer.strides != nullptr) {
py::list l;
for (auto i = 0; i < buffer.ndim; i++) {
l.append(buffer.strides[i]);
}
py::setattr(result, "strides", l);
}
if (buffer.suboffsets != nullptr) {
py::list l;
for (auto i = 0; i < buffer.ndim; i++) {
l.append(buffer.suboffsets[i]);
}
py::setattr(result, "suboffsets", l);
}

PyBuffer_Release(&buffer);
return result;
});
}
37 changes: 37 additions & 0 deletions tests/test_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,40 @@ def test_buffer_exception():
memoryview(m.BrokenMatrix(1, 1))
assert isinstance(excinfo.value.__cause__, RuntimeError)
assert "for context" in str(excinfo.value.__cause__)


def test_to_pybuffer():
mat = m.Matrix(5, 4)

info = m.get_py_buffer(mat, m.PyBUF_SIMPLE)
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
assert info.len == mat.rows() * mat.cols() * info.itemsize
assert info.ndim == 2
assert info.shape is None
assert info.strides is None
assert info.suboffsets is None
assert not info.readonly
info = m.get_py_buffer(mat, m.PyBUF_ND)
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
assert info.len == mat.rows() * mat.cols() * info.itemsize
assert info.ndim == 2
assert info.shape == [5, 4]
assert info.strides is None
assert info.suboffsets is None
assert not info.readonly
info = m.get_py_buffer(mat, m.PyBUF_STRIDES)
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
assert info.len == mat.rows() * mat.cols() * info.itemsize
assert info.ndim == 2
assert info.shape == [5, 4]
assert info.strides == [4 * info.itemsize, info.itemsize]
assert info.suboffsets is None
assert not info.readonly
info = m.get_py_buffer(mat, m.PyBUF_INDIRECT)
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
assert info.len == mat.rows() * mat.cols() * info.itemsize
assert info.ndim == 2
assert info.shape == [5, 4]
assert info.strides == [4 * info.itemsize, info.itemsize]
assert info.suboffsets is None # Should be filled in here, but we don't use it.
assert not info.readonly

0 comments on commit 9f36b54

Please sign in to comment.