Skip to content

Commit

Permalink
Add tests on query python objects (#235)
Browse files Browse the repository at this point in the history
* Update certifi for python3

* Fix query on list

* Add assertGreater for test_issue229

* Add -DARROW_JEMALLOC=1 in arrow-cmake

* Fix support ChunkedArray type column

* Add tests of query on Python obj

* Del tests/query_py.py
  • Loading branch information
auxten authored Jun 27, 2024
1 parent f520056 commit dd281e3
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 56 deletions.
2 changes: 1 addition & 1 deletion chdb/build_mac_arm64.sh
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ for PY_VER in 3.9.13 3.10.11 3.11.3 3.12.0; do
exit 1
fi

python3 -m pip install -U pybind11 wheel build tox psutil setuptools pyarrow pandas
python3 -m pip install -U pybind11 wheel build tox psutil setuptools pyarrow pandas certifi
rm -rf ${PROJ_DIR}/buildlib

${PROJ_DIR}/chdb/build.sh
Expand Down
1 change: 1 addition & 0 deletions contrib/arrow-cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ endif()

if (OS_LINUX)
set (ARROW_JEMALLOC ON)
add_definitions(-DARROW_JEMALLOC=1)
message(STATUS "Using jemalloc in arrow lib")
endif()

Expand Down
18 changes: 18 additions & 0 deletions src/Common/PythonUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,26 @@ const void * tryGetPyArray(const py::object & obj, py::handle & result, py::hand
py::array array = obj.attr("to_pandas")();
row_count = py::len(obj);
result = array;
tmp = array;
tmp.inc_ref();
return array.data();
}
else if (type_name == "ChunkedArray")
{
// Try to get the handle of py::array from PyArrow ChunkedArray
py::array array = obj.attr("to_numpy")();
row_count = py::len(obj);
result = array;
tmp = array;
tmp.inc_ref();
return array.data();
}
else if (type_name == "list")
{
// Just set the row count for list
row_count = py::len(obj);
return nullptr;
}

// chdb todo: maybe convert list to py::array?

Expand Down
9 changes: 9 additions & 0 deletions src/Common/PythonUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ inline bool isPyarrowTable(const py::object & obj)
});
}

inline bool hasGetItem(const py::object & obj)
{
return execWithGIL(
[&]()
{
return py::hasattr(obj, "__getitem__");
});
}

// Specific wrappers for common use cases
inline auto castToPyList(const py::object & obj)
{
Expand Down
20 changes: 10 additions & 10 deletions src/Processors/Sources/PythonSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,23 +148,14 @@ ColumnPtr PythonSource::convert_and_insert(const py::object & obj, UInt32 scale)
column = ColumnVector<T>::create();

std::string type_name;
size_t row_count;
size_t row_count = 0;
py::handle py_array;
py::handle tmp;
SCOPE_EXIT({
if (!tmp.is_none())
tmp.dec_ref();
});
const void * data = tryGetPyArray(obj, py_array, tmp, type_name, row_count);
if (!py_array.is_none())
{
if constexpr (std::is_same_v<T, String>)
insert_string_from_array(py_array, column);
else
insert_from_ptr<T>(data, column, 0, row_count);
return column;
}

if (type_name == "list")
{
//reserve the size of the column
Expand All @@ -173,6 +164,15 @@ ColumnPtr PythonSource::convert_and_insert(const py::object & obj, UInt32 scale)
return column;
}

if (!py_array.is_none() && data != nullptr)
{
if constexpr (std::is_same_v<T, String>)
insert_string_from_array(py_array, column);
else
insert_from_ptr<T>(data, column, 0, row_count);
return column;
}

throw Exception(ErrorCodes::BAD_TYPE_OF_FIELD, "Unsupported type {} for value {}", getPyType(obj), castToStr(obj));
}

Expand Down
2 changes: 1 addition & 1 deletion src/TableFunctions/TableFunctionPython.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ py::object find_instances_of_pyreader(const std::string & var_name)
if (dict.contains(var_name))
{
py::object obj = dict[var_name.data()];
if (isInheritsFromPyReader(obj) || isPandasDf(obj) || isPyarrowTable(obj))
if (isInheritsFromPyReader(obj) || isPandasDf(obj) || isPyarrowTable(obj) || hasGetItem(obj))
return obj;
}
}
Expand Down
33 changes: 0 additions & 33 deletions tests/query_py.py

This file was deleted.

40 changes: 29 additions & 11 deletions tests/test_issue229.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import unittest
import threading
from chdb import session

thread_count = 5
insert_count = 15
return_results = [None] * thread_count

def perform_operations(index):
sess = session.Session()
Expand Down Expand Up @@ -28,7 +32,8 @@ def perform_operations(index):
)

# Insert multiple entries into the table
for i in range(15):
for i in range(insert_count):
# print(f"Inserting entry {i} into the table in session {index}")
sess.query(
f"""
INSERT INTO knowledge_base_portal_interface_event
Expand All @@ -37,26 +42,39 @@ def perform_operations(index):
}, "locale": "en", "timestamp": 1717780952772, "event_type": "article_update", "article_id": 7}}]"""
)

print(f"Inserted {insert_count} entries into the table in session {index}")

# Retrieve all entries from the table
results = sess.query(
"SELECT * FROM knowledge_base_portal_interface_event", "JSONObjectEachRow"
)
print("Session Query Result:", results)
return_results[index] = str(results)

# Cleanup session
sess.cleanup()


# Create multiple threads to perform operations
threads = []
for i in range(5):
threads.append(threading.Thread(target=perform_operations, args=(i,)))
class TestIssue229(unittest.TestCase):
def test_issue229(self):
# Create multiple threads to perform operations
threads = []
results = []
for i in range(thread_count):
threads.append(threading.Thread(target=perform_operations, args=(i,)))

for thread in threads:
thread.start()

# Wait for all threads to complete, and collect results returned by each thread
for thread in threads:
thread.join()

for thread in threads:
thread.start()
# Check if all threads have returned results
for i in range(thread_count):
lines = return_results[i].split("\n")
self.assertGreater(len(lines), 2 * insert_count)

for thread in threads:
thread.join()

# for i in range(5):
# perform_operations(i)
if __name__ == "__main__":
unittest.main()
102 changes: 102 additions & 0 deletions tests/test_query_py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#!python3

import unittest
import numpy as np
import pandas as pd
import pyarrow as pa
import chdb


EXPECTED = """"auxten",9
"jerry",7
"tom",5
"""


class myReader(chdb.PyReader):
def __init__(self, data):
self.data = data
self.cursor = 0
super().__init__(data)

def read(self, col_names, count):
print("Python func read", col_names, count, self.cursor)
if self.cursor >= len(self.data["a"]):
return []
block = [self.data[col] for col in col_names]
self.cursor += len(block[0])
return block


class TestQueryPy(unittest.TestCase):
def test_query_py(self):
reader = myReader(
{
"a": [1, 2, 3, 4, 5, 6],
"b": ["tom", "jerry", "auxten", "tom", "jerry", "auxten"],
}
)

ret = chdb.query("SELECT b, sum(a) FROM Python(reader) GROUP BY b ORDER BY b")
self.assertEqual(str(ret), EXPECTED)

def test_query_df(self):
df = pd.DataFrame(
{
"a": [1, 2, 3, 4, 5, 6],
"b": ["tom", "jerry", "auxten", "tom", "jerry", "auxten"],
}
)

ret = chdb.query("SELECT b, sum(a) FROM Python(df) GROUP BY b ORDER BY b")
self.assertEqual(str(ret), EXPECTED)

def test_query_arrow(self):
table = pa.table(
{
"a": pa.array([1, 2, 3, 4, 5, 6]),
"b": pa.array(["tom", "jerry", "auxten", "tom", "jerry", "auxten"]),
}
)

ret = chdb.query(
"SELECT b, sum(a) FROM Python(table) GROUP BY b ORDER BY b", "debug"
)
self.assertEqual(str(ret), EXPECTED)

def test_query_arrow2(self):
t2 = pa.table(
{
"a": [1, 2, 3, 4, 5, 6],
"b": ["tom", "jerry", "auxten", "tom", "jerry", "auxten"],
}
)

ret = chdb.query(
"SELECT b, sum(a) FROM Python(t2) GROUP BY b ORDER BY b", "debug"
)
self.assertEqual(str(ret), EXPECTED)

# def test_query_np(self):
# t3 = {
# "a": np.array([1, 2, 3, 4, 5, 6]),
# "b": np.array(["tom", "jerry", "auxten", "tom", "jerry", "auxten"]),
# }

# ret = chdb.query(
# "SELECT b, sum(a) FROM Python(t3) GROUP BY b ORDER BY b", "debug"
# )
# self.assertEqual(str(ret), EXPECTED)

# def test_query_dict(self):
# data = {
# "a": [1, 2, 3, 4, 5, 6],
# "b": ["tom", "jerry", "auxten", "tom", "jerry", "auxten"],
# }

# ret = chdb.query("SELECT b, sum(a) FROM Python(data) GROUP BY b ORDER BY b")
# self.assertEqual(str(ret), EXPECTED)


if __name__ == "__main__":
unittest.main()

0 comments on commit dd281e3

Please sign in to comment.