From fbaaae0fd1de35df330feae6f6f6d60ef8c818fd Mon Sep 17 00:00:00 2001 From: Maxim Date: Thu, 27 Jul 2023 18:19:26 +0300 Subject: [PATCH 1/3] Added aggregation results in QueryResultsWrapper --- .../lib/include/queryresults_wrapper.h | 3 ++ pyreindexer/lib/src/rawpyreindexer.cc | 33 +++++++++++++++++++ pyreindexer/lib/src/rawpyreindexer.h | 2 ++ pyreindexer/query_results.py | 8 +++++ pyreindexer/tests/tests/test_sql.py | 15 +++++++++ 5 files changed, 61 insertions(+) diff --git a/pyreindexer/lib/include/queryresults_wrapper.h b/pyreindexer/lib/include/queryresults_wrapper.h index 3749c5d..37d7214 100644 --- a/pyreindexer/lib/include/queryresults_wrapper.h +++ b/pyreindexer/lib/include/queryresults_wrapper.h @@ -29,6 +29,9 @@ class QueryResultsWrapper { db_->FetchResults(*this); } + const std::vector& GetAggregationResults() const& { return qresPtr.GetAggregationResults(); } + const std::vector& GetAggregationResults() const&& = delete; + private: friend DBInterface; diff --git a/pyreindexer/lib/src/rawpyreindexer.cc b/pyreindexer/lib/src/rawpyreindexer.cc index c650af4..42be998 100644 --- a/pyreindexer/lib/src/rawpyreindexer.cc +++ b/pyreindexer/lib/src/rawpyreindexer.cc @@ -425,4 +425,37 @@ static PyObject *QueryResultsWrapperDelete(PyObject *self, PyObject *args) { Py_RETURN_NONE; } + +static PyObject *GetAggregationResults(PyObject *self, PyObject *args) { + uintptr_t qresWrapperAddr; + + if (!PyArg_ParseTuple(args, "k", &qresWrapperAddr)) { + return NULL; + } + + QueryResultsWrapper *qresWrapper = getQueryResultsWrapper(qresWrapperAddr); + + const auto &aggResults = qresWrapper->GetAggregationResults(); + WrSerializer wrSer; + wrSer << "["; + for (size_t i = 0; i < aggResults.size(); ++i) { + if (i > 0) { + wrSer << ','; + } + aggResults[i].GetJSON(wrSer); + } + wrSer << "]"; + + PyObject *dictFromJson = nullptr; + try { + dictFromJson = PyObjectFromJson(reindexer::giftStr(wrSer.Slice())); // stolen ref + } catch (const Error &err) { + Py_XDECREF(dictFromJson); + + return Py_BuildValue("is{}", err.code(), err.what().c_str()); + } + + return Py_BuildValue("isO", errOK, "", dictFromJson); +} + } // namespace pyreindexer diff --git a/pyreindexer/lib/src/rawpyreindexer.h b/pyreindexer/lib/src/rawpyreindexer.h index b3a18ac..799d1e3 100644 --- a/pyreindexer/lib/src/rawpyreindexer.h +++ b/pyreindexer/lib/src/rawpyreindexer.h @@ -73,6 +73,7 @@ static PyObject *EnumNamespaces(PyObject *self, PyObject *args); static PyObject *QueryResultsWrapperIterate(PyObject *self, PyObject *args); static PyObject *QueryResultsWrapperDelete(PyObject *self, PyObject *args); +static PyObject *GetAggregationResults(PyObject *self, PyObject *args); // clang-format off static PyMethodDef module_methods[] = { @@ -98,6 +99,7 @@ static PyMethodDef module_methods[] = { {"query_results_iterate", QueryResultsWrapperIterate, METH_VARARGS, "get query result"}, {"query_results_delete", QueryResultsWrapperDelete, METH_VARARGS, "free query results buffer"}, + {"get_agg_results", GetAggregationResults, METH_VARARGS, "get aggregation results"}, {NULL, NULL, 0, NULL} }; diff --git a/pyreindexer/query_results.py b/pyreindexer/query_results.py index 32c90ab..8901f51 100644 --- a/pyreindexer/query_results.py +++ b/pyreindexer/query_results.py @@ -76,3 +76,11 @@ def _close_iterator(self): self.qres_iter_count = 0 self.api.query_results_delete(self.qres_wrapper_ptr) + + + def get_agg_results(self): + """Returns aggregation results for the current query + + """ + + return self.api.get_agg_results(self.qres_wrapper_ptr) \ No newline at end of file diff --git a/pyreindexer/tests/tests/test_sql.py b/pyreindexer/tests/tests/test_sql.py index ad1e497..ca9c94b 100644 --- a/pyreindexer/tests/tests/test_sql.py +++ b/pyreindexer/tests/tests/test_sql.py @@ -63,3 +63,18 @@ def test_sql_select_with_syntax_error(self, namespace, index, item): assert_that(calling(sql_query).with_args(namespace, query), raises(Exception, matching=has_string(string_contains_in_order( "Expected", "but found"))), "Error wasn't raised when syntax was incorrect") + + def test_sql_select_with_aggregations(self, namespace, index, items): + # Given("Create namespace with item") + db, namespace_name = namespace + # When ("Insert items into namespace") + for _ in range(5): + db.item_insert(namespace_name, {"id": 100}, ["id=serial()"]) + + select_result = db.select(f'SELECT min(id), max(id), avg(id) FROM {namespace_name}').get_agg_results()[2] + expected_values = {"min":1,"max":10,"avg":5.5} + + # Then ("Check that returned agg results are correct") + for agg in select_result: + assert_that(agg['value'], equal_to(expected_values[agg['type']]), + f"Incorrect aggregation result for {agg['type']}") \ No newline at end of file From 06b9234fbbbc98654793cbb6c2ad75da4419bca9 Mon Sep 17 00:00:00 2001 From: Maksim Bogatyrev Date: Mon, 7 Oct 2024 18:05:53 +0300 Subject: [PATCH 2/3] Fix by PR comments --- pyreindexer/query_results.py | 6 +++++- pyreindexer/tests/tests/test_sql.py | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pyreindexer/query_results.py b/pyreindexer/query_results.py index fc237a4..be70d93 100644 --- a/pyreindexer/query_results.py +++ b/pyreindexer/query_results.py @@ -83,4 +83,8 @@ def get_agg_results(self): """ - return self.api.get_agg_results(self.qres_wrapper_ptr) \ No newline at end of file + self.err_code, self.err_msg, res = self.api.get_agg_results( + self.qres_wrapper_ptr) + if self.err_code: + raise Exception(self.err_msg) + return res \ No newline at end of file diff --git a/pyreindexer/tests/tests/test_sql.py b/pyreindexer/tests/tests/test_sql.py index ca9c94b..44e9849 100644 --- a/pyreindexer/tests/tests/test_sql.py +++ b/pyreindexer/tests/tests/test_sql.py @@ -71,7 +71,9 @@ def test_sql_select_with_aggregations(self, namespace, index, items): for _ in range(5): db.item_insert(namespace_name, {"id": 100}, ["id=serial()"]) - select_result = db.select(f'SELECT min(id), max(id), avg(id) FROM {namespace_name}').get_agg_results()[2] + select_result = db.select(f'SELECT min(id), max(id), avg(id) FROM {namespace_name}').get_agg_results() + assert_that(select_result, not empty(), "Aggregation result should not be empty") + expected_values = {"min":1,"max":10,"avg":5.5} # Then ("Check that returned agg results are correct") From 15c4827e3f8c606044e2df840278ffabab05b75d Mon Sep 17 00:00:00 2001 From: Maksim Bogatyrev Date: Mon, 7 Oct 2024 18:11:16 +0300 Subject: [PATCH 3/3] Small fix in test --- pyreindexer/tests/tests/test_sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyreindexer/tests/tests/test_sql.py b/pyreindexer/tests/tests/test_sql.py index 44e9849..08c96ce 100644 --- a/pyreindexer/tests/tests/test_sql.py +++ b/pyreindexer/tests/tests/test_sql.py @@ -72,7 +72,7 @@ def test_sql_select_with_aggregations(self, namespace, index, items): db.item_insert(namespace_name, {"id": 100}, ["id=serial()"]) select_result = db.select(f'SELECT min(id), max(id), avg(id) FROM {namespace_name}').get_agg_results() - assert_that(select_result, not empty(), "Aggregation result should not be empty") + assert_that(len(select_result), 3, "The aggregation result must contain 3 elements") expected_values = {"min":1,"max":10,"avg":5.5}