Skip to content

Commit f1d9a5c

Browse files
liujiayi771glutenperfbot
authored andcommitted
Spark sql avg agg function support decimal (facebookincubator#6020)
Revert "Fix decimal agg signature on partial companion function (oap-project#465)" This reverts commit 336d61f.
1 parent 38ea25e commit f1d9a5c

File tree

6 files changed

+527
-43
lines changed

6 files changed

+527
-43
lines changed

velox/functions/lib/aggregates/AverageAggregateBase.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@ namespace facebook::velox::functions::aggregate {
2121
void checkAvgIntermediateType(const TypePtr& type) {
2222
VELOX_USER_CHECK(
2323
type->isRow() || type->isVarbinary(),
24-
"Input type for final average must be row type or varbinary type.");
24+
"Input type for final average must be row type or varbinary type, find {}",
25+
type->toString());
2526
if (type->kind() == TypeKind::VARBINARY) {
2627
return;
2728
}
2829
VELOX_USER_CHECK(
2930
type->childAt(0)->kind() == TypeKind::DOUBLE ||
3031
type->childAt(0)->isLongDecimal(),
31-
"Input type for sum in final average must be double or long decimal type.")
32+
"Input type for sum in final average must be double or long decimal type, find {}",
33+
type->childAt(0)->toString());
3234
VELOX_USER_CHECK_EQ(
3335
type->childAt(1)->kind(),
3436
TypeKind::BIGINT,

velox/functions/lib/aggregates/DecimalAggregate.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ class DecimalAggregate : public exec::Aggregate {
7474
explicit DecimalAggregate(TypePtr resultType) : exec::Aggregate(resultType) {}
7575

7676
int32_t accumulatorFixedWidthSize() const override {
77-
return sizeof(DecimalAggregate);
77+
return sizeof(LongDecimalWithOverflowState);
7878
}
7979

8080
int32_t accumulatorAlignmentSize() const override {
81-
return static_cast<int32_t>(sizeof(int128_t));
81+
return alignof(LongDecimalWithOverflowState);
8282
}
8383

8484
void initializeNewGroups(
@@ -287,7 +287,9 @@ class DecimalAggregate : public exec::Aggregate {
287287
}
288288

289289
virtual TResultType computeFinalValue(
290-
LongDecimalWithOverflowState* accumulator) = 0;
290+
LongDecimalWithOverflowState* accumulator) {
291+
return 0;
292+
};
291293

292294
void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
293295
override {
@@ -329,11 +331,12 @@ class DecimalAggregate : public exec::Aggregate {
329331
accumulator->count += 1;
330332
}
331333

332-
private:
334+
protected:
333335
inline LongDecimalWithOverflowState* decimalAccumulator(char* group) {
334336
return exec::Aggregate::value<LongDecimalWithOverflowState>(group);
335337
}
336338

339+
private:
337340
DecodedVector decodedRaw_;
338341
DecodedVector decodedPartial_;
339342
};

0 commit comments

Comments
 (0)