Skip to content

Commit 1ba95b7

Browse files
committed
Fix signature
1 parent fcd23a5 commit 1ba95b7

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

velox/functions/sparksql/aggregates/AverageAggregate.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,9 @@ exec::AggregateRegistrationResult registerAverage(
429429
auto inputScale = inputType->asShortDecimal().scale();
430430
auto sumType =
431431
DECIMAL(std::min(38, inputPrecision + 10), inputScale);
432-
if (exec::isPartialOutput(step)) {
432+
if (exec::isPartialOutput(step) ||
433+
(step == core::AggregationNode::Step::kSingle &&
434+
resultType->isRow())) {
433435
return std::make_unique<
434436
DecimalAverageAggregate<int64_t, int64_t>>(
435437
resultType, sumType);

velox/functions/sparksql/aggregates/SumAggregate.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@ using SumAggregate = SumAggregateBase<TInput, TAccumulator, ResultType, true>;
2929
TypePtr getDecimalSumType(
3030
const TypePtr& resultType,
3131
core::AggregationNode::Step step) {
32-
return exec::isPartialOutput(step) ? resultType->childAt(0) : resultType;
32+
if (exec::isPartialOutput(step)) {
33+
return resultType->childAt(0);
34+
}
35+
if (step == core::AggregationNode::Step::kSingle && resultType->isRow()) {
36+
return resultType->childAt(0);
37+
}
38+
return resultType;
3339
}
3440
} // namespace
3541

velox/functions/sparksql/tests/DecimalUtilTest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class DecimalUtilTest : public testing::Test {
3030
R expectedResult,
3131
bool expectedOverflow) {
3232
R r;
33-
bool overflow;
33+
bool overflow = false;
3434
DecimalUtil::divideWithRoundUp<R, A, B>(r, a, b, aRescale, overflow);
3535
ASSERT_EQ(overflow, expectedOverflow);
3636
ASSERT_EQ(r, expectedResult);

0 commit comments

Comments
 (0)