Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[c++] Fix offsets for nullable columns #3611

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apis/python/src/tiledbsoma/managed_query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ void load_managed_query(py::module& m) {
data.size(),
(const void*)data_info.ptr,
static_cast<uint64_t*>(nullptr),
static_cast<uint8_t*>(nullptr));
std::nullopt);
py::gil_scoped_acquire acquire;
})
.def(
Expand Down
48 changes: 48 additions & 0 deletions apis/python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,3 +2050,51 @@ def test_arrow_table_sliced_writer(tmp_path):
np.testing.assert_array_equal(pdf["myenumint"], pydict["myenumint"])
np.testing.assert_array_equal(pdf["myenumstr"], pydict["myenumstr"])
np.testing.assert_array_equal(pdf["myenumbool"], pydict["myenumbool"])


def test_arrow_table_validity_with_slicing(tmp_path):
    """Nullable columns round-trip both for a whole-table write and for a
    write performed as two consecutive slices of the same table."""
    uri = tmp_path.as_posix()
    num_rows = 10

    schema = pa.schema(
        [
            ("myint", pa.int32()),
            ("mystring", pa.large_string()),
            ("mybool", pa.bool_()),
            ("myenum", pa.dictionary(pa.int64(), pa.large_string())),
        ]
    )

    # Every non-index column carries nulls at different row positions.
    pydict = {
        "soma_joinid": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        "myint": [1, 2, 3, 4, 5, 6, None, 8, None, None],
        "mystring": ["g1", "g2", "g3", None, "g2", "g3", "g1", None, "g3", "g1"],
        "mybool": [True, True, True, False, True, False, None, False, None, None],
        "myenum": pd.Categorical(
            ["g1", "g2", "g3", None, "g2", "g3", "g1", None, "g3", "g1"]
        ),
    }

    table = pa.Table.from_pydict(pydict)
    domain = ((0, np.iinfo(np.int64).max - 2050),)

    def check_read_matches_table():
        # Read everything back and compare each nullable column to the input.
        with soma.DataFrame.open(uri) as sdf:
            pdf = sdf.read().concat().to_pandas()
            for column in ("myint", "mystring", "mybool", "myenum"):
                assert pdf[column].compare(table[column].to_pandas()).empty

    # Pass 1: write the table in one piece.
    with soma.DataFrame.create(uri, schema=schema, domain=domain) as sdf:
        sdf.write(table)
    check_read_matches_table()

    # Pass 2: write the same rows again as two slices of the table.
    mid = num_rows // 2
    with soma.DataFrame.open(uri, "w") as sdf:
        sdf.write(table[:mid])
        sdf.write(table[mid:])
    check_read_matches_table()
10 changes: 3 additions & 7 deletions libtiledbsoma/src/soma/column_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class ColumnBuffer {
uint64_t num_elems,
const void* data,
T* offsets,
uint8_t* validity = nullptr) {
std::optional<std::vector<uint8_t>> validity = std::nullopt) {
nguyenv marked this conversation as resolved.
Show resolved Hide resolved
num_cells_ = num_elems;

// Ensure the offset type is either uint32_t* or uint64_t*
Expand All @@ -147,12 +147,8 @@ class ColumnBuffer {
}

if (is_nullable_) {
if (validity != nullptr) {
for (uint64_t i = 0; i < num_elems; ++i) {
uint8_t byte = validity[i / 8];
uint8_t bit = (byte >> (i % 8)) & 0x01;
validity_.push_back(bit);
}
if (validity.has_value()) {
validity_ = *validity;
} else {
validity_.assign(num_elems, 1); // Default all to valid (1)
}
Expand Down
59 changes: 37 additions & 22 deletions libtiledbsoma/src/soma/managed_query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "utils/common.h"
#include "utils/logger.h"
#include "utils/util.h"

namespace tiledbsoma {

using namespace tiledb;
Expand Down Expand Up @@ -806,7 +807,7 @@ void ManagedQuery::_cast_dictionary_values(
array->length,
(const void*)index_to_value.data(),
(uint64_t*)nullptr,
(uint8_t*)value_array->buffers[0]);
std::nullopt); // validities are set by index column
}

template <>
Expand Down Expand Up @@ -863,7 +864,7 @@ void ManagedQuery::_cast_dictionary_values<std::string>(
value_offsets.size() - 1,
(const void*)index_to_value.data(),
(uint64_t*)value_offsets.data(),
(uint8_t*)value_array->buffers[0]);
std::nullopt); // validities are set by index column
}

template <>
Expand All @@ -876,8 +877,7 @@ void ManagedQuery::_cast_dictionary_values<bool>(
auto value_array = array->dictionary;

std::vector<int64_t> indexes = _get_index_vector(schema, array);
std::vector<uint8_t> values = util::cast_bit_to_uint8(
value_schema, value_array);
std::vector<uint8_t> values = _cast_bool_data(value_schema, value_array);
std::vector<uint8_t> index_to_value;

for (auto i : indexes) {
Expand All @@ -889,7 +889,7 @@ void ManagedQuery::_cast_dictionary_values<bool>(
array->length,
(const void*)index_to_value.data(),
(uint64_t*)nullptr,
(uint8_t*)value_array->buffers[0]);
std::nullopt); // validities are set by index column
}

template <typename UserType>
Expand Down Expand Up @@ -982,13 +982,9 @@ bool ManagedQuery::_cast_column_aux<std::string>(
array->n_buffers));
}

const char* data = (const char*)array->buffers[2];
uint8_t* validity = (uint8_t*)array->buffers[0];
const void* data = array->buffers[2];
std::optional<std::vector<uint8_t>> validity = _cast_validity_buffer(array);

// If this is a table-slice, slice into the validity buffer.
if (validity != nullptr) {
validity += array->offset;
}
// If this is a table-slice, do *not* slice into the data
// buffer since it is indexed via offsets, which we slice
// into below.
Expand All @@ -997,14 +993,12 @@ bool ManagedQuery::_cast_column_aux<std::string>(
(strcmp(schema->format, "Z") == 0)) {
// If this is a table-slice, slice into the offsets buffer.
uint64_t* offset = (uint64_t*)array->buffers[1] + array->offset;
setup_write_column(
schema->name, array->length, (const void*)data, offset, validity);
setup_write_column(schema->name, array->length, data, offset, validity);

} else {
// If this is a table-slice, slice into the offsets buffer.
uint32_t* offset = (uint32_t*)array->buffers[1] + array->offset;
setup_write_column(
schema->name, array->length, (const void*)data, offset, validity);
setup_write_column(schema->name, array->length, data, offset, validity);
}
return false;
}
Expand All @@ -1014,18 +1008,14 @@ bool ManagedQuery::_cast_column_aux<bool>(
ArrowSchema* schema, ArrowArray* array, ArraySchemaEvolution se) {
(void)se; // se is unused in bool specialization

auto casted = util::cast_bit_to_uint8(schema, array);
uint8_t* validity = (uint8_t*)array->buffers[0];
if (validity != nullptr) {
validity += array->offset;
}
auto casted = _cast_bool_data(schema, array);

setup_write_column(
schema->name,
array->length,
(const void*)casted.data(),
(uint64_t*)nullptr,
(uint8_t*)validity);
_cast_validity_buffer(array));
return false;
}

Expand Down Expand Up @@ -1102,7 +1092,7 @@ bool ManagedQuery::_extend_and_evolve_schema(
if (strcmp(value_schema->format, "b") == 0) {
// Specially handle Boolean types as their representation in Arrow (bit)
// is different from what is in TileDB (uint8_t)
auto casted = util::cast_bit_to_uint8(value_schema, value_array);
auto casted = _cast_bool_data(value_schema, value_array);
enums_in_write.assign(casted.data(), casted.data() + num_elems);
} else {
// General case
Expand Down Expand Up @@ -1257,4 +1247,29 @@ bool ManagedQuery::_extend_and_evolve_schema<std::string>(
}
return false;
}

/**
 * @brief Unpack the bit-packed values of an Arrow Boolean array into a
 * vector of uint8_t.
 *
 * @param schema the ArrowSchema, whose format must be "b"
 * @param array the ArrowArray holding the Boolean data; `array->length`
 * elements are unpacked starting at `array->offset`
 * @return std::vector<uint8_t> the unpacked values; empty when the array has
 * no elements and no data buffer
 * @throws TileDBSOMAError if the format is not "b", or if the data buffer is
 * null although the array is non-empty
 */
std::vector<uint8_t> ManagedQuery::_cast_bool_data(
    ArrowSchema* schema, ArrowArray* array) {
    if (strcmp(schema->format, "b") != 0) {
        throw TileDBSOMAError(std::format(
            "_cast_bit_to_uint8 expected column format to be 'b' but saw "
            "{}",
            schema->format));
    }

    // Read the values from buffers[2] when the array carries three buffers,
    // otherwise from buffers[1].
    uint8_t* data;
    if (array->n_buffers == 3) {
        data = (uint8_t*)array->buffers[2];
    } else {
        data = (uint8_t*)array->buffers[1];
    }

    // util::bitmap_to_uint8 returns std::nullopt for a null bitmap, and
    // dereferencing an empty optional is undefined behaviour. Handle the
    // null buffer here instead.
    if (data == nullptr) {
        if (array->length == 0) {
            return {};
        }
        throw TileDBSOMAError(
            "_cast_bool_data received a null data buffer for a non-empty "
            "Boolean array");
    }

    return *util::bitmap_to_uint8(data, array->length, array->offset);
}

/**
 * @brief Convert the validity bitmap in `array->buffers[0]` with
 * util::bitmap_to_uint8, covering `array->length` elements starting at
 * element `array->offset`.
 *
 * @param array the ArrowArray whose validity bitmap is converted
 * @return std::optional<std::vector<uint8_t>> the converted validity values,
 * or std::nullopt when the array has no validity bitmap
 */
std::optional<std::vector<uint8_t>> ManagedQuery::_cast_validity_buffer(
    ArrowArray* array) {
    auto* bitmap = static_cast<uint8_t*>(
        const_cast<void*>(array->buffers[0]));
    const auto num_elems = array->length;
    const auto first_elem = array->offset;
    return util::bitmap_to_uint8(bitmap, num_elems, first_elem);
}
}; // namespace tiledbsoma
42 changes: 29 additions & 13 deletions libtiledbsoma/src/soma/managed_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,16 +249,17 @@ class ManagedQuery {
* @param name Column name
* @param num_elems Number of array elements in buffer
* @param data Pointer to the data buffer
* If the data type is Boolean, the data has already been casted to uint8
* @param offsets Pointer to the offsets buffer
* @param validity Pointer to the validity buffer
* @param validity Vector of validity buffer casted to uint8
*/
template <typename T>
void setup_write_column(
std::string_view name,
uint64_t num_elems,
const void* data,
T* offsets,
uint8_t* validity) {
std::optional<std::vector<uint8_t>> validity = std::nullopt) {
// Ensure the offset type is either uint32_t* or uint64_t*
static_assert(
std::is_same_v<T, uint32_t> || std::is_same_v<T, uint64_t>,
Expand Down Expand Up @@ -625,10 +626,6 @@ class ManagedQuery {
} else {
buf = (UserType*)array->buffers[1] + array->offset;
}
uint8_t* validity = (uint8_t*)array->buffers[0];
if (validity != nullptr) {
validity += array->offset;
}

bool has_attr = schema_->has_attribute(schema->name);
if (has_attr && attr_has_enum(schema->name)) {
Expand Down Expand Up @@ -656,7 +653,7 @@ class ManagedQuery {
casted_values.size(),
(const void*)casted_values.data(),
(uint64_t*)nullptr,
validity);
_cast_validity_buffer(array));

// Return false because we do not extend the enumeration
return false;
Expand Down Expand Up @@ -794,17 +791,12 @@ class ManagedQuery {
std::vector<DiskIndexType> casted_indexes(
shifted_indexes.begin(), shifted_indexes.end());

uint8_t* validity = (uint8_t*)index_array->buffers[0];
if (validity != nullptr) {
validity += index_array->offset;
}

setup_write_column(
column_name,
casted_indexes.size(),
(const void*)casted_indexes.data(),
(uint64_t*)nullptr,
(uint8_t*)validity);
_cast_validity_buffer(index_array));
}

bool _extend_enumeration(
Expand Down Expand Up @@ -851,6 +843,30 @@ class ManagedQuery {
bool attr_has_enum(std::string attr_name) {
return get_enum_label_on_attr(attr_name).has_value();
}

/**
 * @brief Take an Arrow schema and array containing Boolean data stored as
 * bits and return a vector containing the uint8_t representation. The
 * conversion covers `array->length` elements starting at `array->offset`.
 *
 * @param schema the ArrowSchema which must be format 'b'
 * @param array the ArrowArray holding Boolean data
 * @return std::vector<uint8_t> the uint8_t representation of the values
 * @throws TileDBSOMAError if the schema format is not 'b'
 */
std::vector<uint8_t> _cast_bool_data(
ArrowSchema* schema, ArrowArray* array);

/**
 * @brief Take the validity buffer (in bits) of an Arrow array and unpack
 * it, starting at the array's offset, into a std::vector<uint8_t> with one
 * entry per element. The result is a new vector, not a view of the Arrow
 * buffer and not a shifted bitmap. If the validity buffer is null, then
 * return a nullopt.
 *
 * @param array the ArrowArray providing the validity buffer, the number of
 * elements and the offset of the first element
 * @return std::optional<std::vector<uint8_t>> the unpacked validity values,
 * or std::nullopt when there is no validity buffer
 */
std::optional<std::vector<uint8_t>> _cast_validity_buffer(
ArrowArray* array);
};

// These are all specializations to string/bool of various methods
Expand Down
14 changes: 12 additions & 2 deletions libtiledbsoma/src/soma/soma_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,12 @@ void SOMAArray::set_column_data(
const void* data,
uint64_t* offsets,
uint8_t* validity) {
mq_->setup_write_column(name, num_elems, data, offsets, validity);
mq_->setup_write_column(
name,
num_elems,
data,
offsets,
util::bitmap_to_uint8(validity, num_elems));
};

void SOMAArray::set_column_data(
Expand All @@ -285,7 +290,12 @@ void SOMAArray::set_column_data(
const void* data,
uint32_t* offsets,
uint8_t* validity) {
mq_->setup_write_column(name, num_elems, data, offsets, validity);
mq_->setup_write_column(
name,
num_elems,
data,
offsets,
util::bitmap_to_uint8(validity, num_elems));
};

uint64_t SOMAArray::ndim() const {
Expand Down
23 changes: 6 additions & 17 deletions libtiledbsoma/src/utils/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,15 @@ std::string rstrip_uri(std::string_view uri) {
return std::regex_replace(std::string(uri), std::regex("/+$"), "");
}

std::vector<uint8_t> cast_bit_to_uint8(ArrowSchema* schema, ArrowArray* array) {
if (strcmp(schema->format, "b") != 0) {
throw TileDBSOMAError(std::format(
"_cast_bit_to_uint8 expected column format to be 'b' but saw {}",
schema->format));
std::optional<std::vector<uint8_t>> bitmap_to_uint8(
uint8_t* bitmap, size_t length, size_t offset) {
if (bitmap == nullptr) {
return std::nullopt;
}

uint8_t* data;
if (array->n_buffers == 3) {
data = (uint8_t*)array->buffers[2];
} else {
data = (uint8_t*)array->buffers[1];
}

std::vector<uint8_t> casted(array->length);
std::vector<uint8_t> casted(length);
ArrowBitsUnpackInt8(
data,
array->offset,
array->length,
reinterpret_cast<int8_t*>(casted.data()));
bitmap, offset, length, reinterpret_cast<int8_t*>(casted.data()));
return casted;
}

Expand Down
15 changes: 8 additions & 7 deletions libtiledbsoma/src/utils/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,16 @@ bool is_tiledb_uri(std::string_view uri);
std::string rstrip_uri(std::string_view uri);

/**
* @brief Take an arrow schema and array containing bool
* data in bits and return a vector containing the uint8_t
* representation
* @brief Take a bitmap and return a vector containing the uint8_t
* representation. If the bitmap is null, then return a nullopt.
*
* @param schema the ArrowSchema which must be format 'b'
* @param array the ArrowArray holding Boolean data
* @return std::vector<uint8_t>
* @param bitmap Pointer to the start of the bitmap
* @param length Total number of elements
* @param offset Optionally offset the data
* @return std::optional<std::vector<uint8_t>>
*/
std::vector<uint8_t> cast_bit_to_uint8(ArrowSchema* schema, ArrowArray* array);
std::optional<std::vector<uint8_t>> bitmap_to_uint8(
uint8_t* bitmap, size_t length, size_t offset = 0);

} // namespace tiledbsoma::util

Expand Down