Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

[SQL-DS-CACHE-89][POAE7-1016] Fix agg result when column value could be null #90

Merged
merged 1 commit into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion oap-ape/ape-native/src/utils/AggExpression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,28 @@ int AggExpression::ExecuteWithParam(int batchSize,
std::vector<int8_t>& outBuffers) {
if (!done) {
child->ExecuteWithParam(batchSize, dataBuffers, nullBuffers, outBuffers);
done = true;
}
return 0;
}

// Appends the COUNT aggregate to `result` as a single LongType value.
//
// - count(*) / count(1): the child is a LiteralExpression, so every row counts
//   and the batch size stored in batchSize_ is emitted directly.
// - count(expr): on the first call (while `done` is false) the child is
//   evaluated and every row whose null-vector entry is non-zero is counted;
//   later calls re-emit the stored `count`.
void Count::getResult(DecimalVector& result) {
  // Set the type up front so the count(*) early return reports LongType too.
  result.type = ResultType::LongType;
  if (typeid(*child) == typeid(LiteralExpression)) {  // for count(*) or count(1)
    result.data.push_back(arrow::BasicDecimal128(batchSize_));
    return;
  }
  if (!done) {
    done = true;
    auto tmp = DecimalVector();
    child->getResult(tmp);
    ARROW_LOG(INFO) << "count node child size: " << tmp.data.size();
    if (tmp.nullVector) {
      // A non-zero entry marks a row that contributes to the count.
      for (size_t i = 0; i < tmp.data.size(); i++) {
        if (tmp.nullVector->at(i)) count++;
      }
    } else if (!tmp.data.empty()) {
      // DecimalVector::nullVector defaults to nullptr; dereferencing it would
      // be undefined behaviour, so report the problem instead of crashing.
      ARROW_LOG(ERROR) << "count node child returned data without a null vector";
    }
  }
  result.data.push_back(arrow::BasicDecimal128(count));
}

int ArithmeticExpression::ExecuteWithParam(int batchSize,
const std::vector<int64_t>& dataBuffers,
const std::vector<int64_t>& nullBuffers,
Expand All @@ -89,6 +106,9 @@ int AttributeReferenceExpression::ExecuteWithParam(
done = true;
int64_t dataPtr = dataBuffers[columnIndex];
int64_t nullPtr = nullBuffers[columnIndex];
std::vector<uint8_t> nullVec(batchSize);
std::memcpy(nullVec.data(), (uint8_t*)nullPtr, batchSize);
result.nullVector = std::make_shared<std::vector<uint8_t>>(nullVec);
parquet::Type::type columnType = (*schema)[columnIndex].getColType();
if (isDecimalType(dataType)) {
int precision, scale;
Expand Down
101 changes: 65 additions & 36 deletions oap-ape/ape-native/src/utils/AggExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class RootAggExpression : public WithResultExpression {
child->setSchema(schema);
}

void getResult(DecimalVector& result) { child->getResult(result); }
void getResult(DecimalVector& result) override { child->getResult(result); }

private:
bool isDistinct;
Expand All @@ -101,6 +101,18 @@ class AggExpression : public WithResultExpression {
done = false;
child->reset();
}
// Returns this aggregate's result through `result`.
// While `done` is false the value is computed by getResultInternal() into
// resultCache and `done` is set; every call then copies resultCache into
// `result`, so repeated calls do not recompute until `done` is cleared again.
void getResult(DecimalVector& result) override {
  const bool alreadyComputed = done;
  done = true;  // marked before computing, same order as the cached path expects
  if (!alreadyComputed) {
    getResultInternal(resultCache);
  }
  result = resultCache;
}

// Hook that computes the aggregate value into `result`; getResult() calls it
// with resultCache as the target. This base version only emits a log line,
// as its message indicates it is not meant to be reached.
virtual void getResultInternal(DecimalVector& result) {
  ARROW_LOG(INFO) << "should never be called";
}

void setSchema(std::shared_ptr<std::vector<Schema>> schema_) {
schema = schema_;
Expand All @@ -109,18 +121,22 @@ class AggExpression : public WithResultExpression {

protected:
std::shared_ptr<WithResultExpression> child;
DecimalVector resultCache;
};

class Sum : public AggExpression {
public:
~Sum() {}
void getResult(DecimalVector& result) override {
void getResultInternal(DecimalVector& result) override {
auto tmp = DecimalVector();
child->getResult(tmp);
arrow::BasicDecimal128 out;
for (auto e : tmp.data) {
out += e;
for (int i = 0; i < tmp.data.size(); i++) {
if (tmp.nullVector->at(i)) {
out += tmp.data[i];
}
}
result.data.clear();
result.data.push_back(out);
result.precision = 38; // tmp.precision;
result.scale = tmp.scale;
Expand All @@ -131,11 +147,17 @@ class Sum : public AggExpression {
class Min : public AggExpression {
public:
~Min() {}
void getResult(DecimalVector& result) override {

void getResultInternal(DecimalVector& result) override {
auto tmp = DecimalVector();
child->getResult(tmp);
arrow::BasicDecimal128 out(tmp.data[0]);
for (auto e : tmp.data) out = out < e ? out : e;
for (int i = 0; i < tmp.data.size(); i++) {
if (tmp.nullVector->at(i)) {
out = out < tmp.data[i] ? out : tmp.data[i];
}
}
result.data.clear();
result.data.push_back(out);
result.precision = tmp.precision;
result.scale = tmp.scale;
Expand All @@ -146,11 +168,16 @@ class Min : public AggExpression {
class Max : public AggExpression {
public:
~Max() {}
void getResult(DecimalVector& result) override {
void getResultInternal(DecimalVector& result) override {
auto tmp = DecimalVector();
child->getResult(tmp);
arrow::BasicDecimal128 out(tmp.data[0]);
for (auto e : tmp.data) out = out > e ? out : e;
for (int i = 0; i < tmp.data.size(); i++) {
if (tmp.nullVector->at(i)) {
out = out > tmp.data[i] ? out : tmp.data[i];
}
}
result.data.clear();
result.data.push_back(out);
result.precision = tmp.precision;
result.scale = tmp.scale;
Expand All @@ -161,42 +188,21 @@ class Max : public AggExpression {
class Count : public AggExpression {
public:
~Count() {}
void getResult(DecimalVector& result) override {
result.data.push_back(arrow::BasicDecimal128(count));
result.type = ResultType::LongType;
}

void getResult(DecimalVector& result) override;
int ExecuteWithParam(int batchSize, const std::vector<int64_t>& dataBuffers,
const std::vector<int64_t>& nullBuffers,
std::vector<int8_t>& outBuffers) override {
if (!done) {
done = true;
count = batchSize;
count = 0;
batchSize_ = batchSize; // for count(*)
child->ExecuteWithParam(batchSize, dataBuffers, nullBuffers, outBuffers);
}
return 0;
}

private:
int count = 0;
};

class Avg : public AggExpression {
public:
~Avg() {}
void getResult(DecimalVector& result) override {
// should never be called
auto tmp = DecimalVector();
child->getResult(tmp);
arrow::BasicDecimal128 sum;
for (auto e : tmp.data) {
sum += e;
}
result.data.push_back(sum);
result.data.push_back(arrow::BasicDecimal128(tmp.data.size()));
result.precision = 38; // tmp.precision;
result.scale = tmp.scale;
result.type = GetResultType(dataType);
}
int batchSize_ = 0;
};

class ArithmeticExpression : public WithResultExpression {
Expand Down Expand Up @@ -284,16 +290,23 @@ class Add : public ArithmeticExpression {
arrow::BasicDecimal128 out = left.data[0] + right.data[i];
result.data.push_back(out);
}
result.nullVector = right.nullVector;
} else if (right.data.size() == 1) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] + right.data[0];
result.data.push_back(out);
}
result.nullVector = left.nullVector;
} else if (left.data.size() == right.data.size()) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] + right.data[i];
result.data.push_back(out);
}
std::vector<uint8_t> nullVec(left.data.size());
for (int i = 0; i < left.data.size(); i++) {
nullVec[i] = left.nullVector->at(i) & right.nullVector->at(i);
}
result.nullVector = std::make_shared<std::vector<uint8_t>>(nullVec);
} else {
ARROW_LOG(ERROR) << "Oops...why left and right has different size?";
}
Expand Down Expand Up @@ -331,16 +344,23 @@ class Sub : public ArithmeticExpression {
arrow::BasicDecimal128 out = left.data[0] - right.data[i];
result.data.push_back(out);
}
result.nullVector = right.nullVector;
} else if (right.data.size() == 1) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] - right.data[0];
result.data.push_back(out);
}
result.nullVector = left.nullVector;
} else if (left.data.size() == right.data.size()) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] - right.data[i];
result.data.push_back(out);
}
std::vector<uint8_t> nullVec(left.data.size());
for (int i = 0; i < left.data.size(); i++) {
nullVec[i] = left.nullVector->at(i) & right.nullVector->at(i);
}
result.nullVector = std::make_shared<std::vector<uint8_t>>(nullVec);
} else {
ARROW_LOG(ERROR) << "Oops...why left and right has different size?";
}
Expand Down Expand Up @@ -378,22 +398,30 @@ class Multiply : public ArithmeticExpression {
arrow::BasicDecimal128 out = left.data[0] * right.data[i];
result.data.push_back(out);
}
result.nullVector = right.nullVector;
} else if (right.data.size() == 1) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] * right.data[0];
result.data.push_back(out);
}
result.nullVector = left.nullVector;
} else if (left.data.size() == right.data.size()) {
for (int i = 0; i < left.data.size(); i++) {
arrow::BasicDecimal128 out = left.data[i] * right.data[i];
result.data.push_back(out);
}
std::vector<uint8_t> nullVec(left.data.size());
for (int i = 0; i < left.data.size(); i++) {
nullVec[i] = left.nullVector->at(i) & right.nullVector->at(i);
}
result.nullVector = std::make_shared<std::vector<uint8_t>>(nullVec);
} else {
ARROW_LOG(ERROR) << "Oops...why left and right has different size?";
}
}
};

// TODO: Implement Divide and Mod.
class Divide : public ArithmeticExpression {
public:
~Divide() {}
Expand Down Expand Up @@ -427,6 +455,7 @@ class AttributeReferenceExpression : public WithResultExpression {
}
res.precision = result.precision;
res.scale = result.scale;
res.nullVector = result.nullVector;
}

void setAttribute(std::string columnName_, std::string dataType_, std::string castType_,
Expand Down Expand Up @@ -458,6 +487,8 @@ class LiteralExpression : public WithResultExpression {
res.data.push_back(value);
res.precision = precision_;
res.scale = scale_;
std::vector<uint8_t> nullVec{1};
res.nullVector = std::make_shared<std::vector<uint8_t>>(nullVec);
}
void setAttribute(std::string dataType_, std::string valueString_) {
dataType = dataType_;
Expand Down Expand Up @@ -511,8 +542,6 @@ class Gen {
return std::make_shared<Max>();
else if (name.compare("Min") == 0)
return std::make_shared<Min>();
else if (name.compare("Average") == 0)
return std::make_shared<Avg>();
else if (name.compare("Count") == 0)
return std::make_shared<Count>();

Expand Down
2 changes: 2 additions & 0 deletions oap-ape/ape-native/src/utils/DecimalConvertor.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ struct DecimalVector {
int32_t precision;
int32_t scale;
ResultType type;
std::shared_ptr<std::vector<uint8_t>> nullVector = nullptr;
void operator=(const DecimalVector& lhs) {
this->data = lhs.data;
this->precision = lhs.precision;
this->scale = lhs.scale;
this->type = lhs.type;
this->nullVector = lhs.nullVector;
}
};

Expand Down