Skip to content

Commit a18d1be

Browse files
liujiayi771glutenperfbot
authored andcommitted
Spark sql sum agg function support decimal (facebookincubator#5372)
1 parent def485b commit a18d1be

File tree

3 files changed

+744
-2
lines changed

3 files changed

+744
-2
lines changed
Lines changed: 388 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,388 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
#include "velox/exec/Aggregate.h"
18+
#include "velox/expression/FunctionSignature.h"
19+
#include "velox/vector/FlatVector.h"
20+
21+
namespace facebook::velox::functions::aggregate::sparksql {
22+
23+
struct DecimalSum {
24+
int128_t sum{0};
25+
int64_t overflow{0};
26+
bool isEmpty{true};
27+
28+
void mergeWith(const DecimalSum& other) {
29+
this->overflow += other.overflow;
30+
this->overflow +=
31+
DecimalUtil::addWithOverflow(this->sum, other.sum, this->sum);
32+
this->isEmpty &= other.isEmpty;
33+
}
34+
};
35+
36+
template <typename TInputType, typename TResultType>
37+
class DecimalSumAggregate : public exec::Aggregate {
38+
public:
39+
explicit DecimalSumAggregate(TypePtr resultType, TypePtr sumType)
40+
: exec::Aggregate(resultType), sumType_(sumType) {}
41+
42+
int32_t accumulatorFixedWidthSize() const override {
43+
return sizeof(DecimalSum);
44+
}
45+
46+
int32_t accumulatorAlignmentSize() const override {
47+
return alignof(DecimalSum);
48+
}
49+
50+
void initializeNewGroups(
51+
char** groups,
52+
folly::Range<const vector_size_t*> indices) override {
53+
setAllNulls(groups, indices);
54+
for (auto i : indices) {
55+
new (groups[i] + offset_) DecimalSum();
56+
}
57+
}
58+
59+
int128_t computeFinalValue(DecimalSum* decimalSum, bool& overflow) {
60+
int128_t sum = decimalSum->sum;
61+
if ((decimalSum->overflow == 1 && decimalSum->sum < 0) ||
62+
(decimalSum->overflow == -1 && decimalSum->sum > 0)) {
63+
sum = static_cast<int128_t>(
64+
DecimalUtil::kOverflowMultiplier * decimalSum->overflow +
65+
decimalSum->sum);
66+
} else {
67+
if (decimalSum->overflow != 0) {
68+
overflow = true;
69+
return 0;
70+
}
71+
}
72+
73+
auto [resultPrecision, resultScale] =
74+
getDecimalPrecisionScale(*sumType_.get());
75+
overflow = !DecimalUtil::valueInPrecisionRange(sum, resultPrecision);
76+
return sum;
77+
}
78+
79+
void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
80+
override {
81+
VELOX_CHECK_EQ((*result)->encoding(), VectorEncoding::Simple::FLAT);
82+
auto vector = (*result)->as<FlatVector<TResultType>>();
83+
VELOX_CHECK(vector);
84+
vector->resize(numGroups);
85+
uint64_t* rawNulls = getRawNulls(vector);
86+
87+
TResultType* rawValues = vector->mutableRawValues();
88+
for (auto i = 0; i < numGroups; ++i) {
89+
char* group = groups[i];
90+
if (isNull(group)) {
91+
vector->setNull(i, true);
92+
} else {
93+
clearNull(rawNulls, i);
94+
auto* decimalSum = accumulator(group);
95+
if (decimalSum->isEmpty) {
96+
// If isEmpty is true, we should set null.
97+
vector->setNull(i, true);
98+
} else {
99+
bool overflow = false;
100+
auto result = (TResultType)computeFinalValue(decimalSum, overflow);
101+
if (overflow) {
102+
// Sum should be set to null on overflow.
103+
vector->setNull(i, true);
104+
} else {
105+
rawValues[i] = result;
106+
}
107+
}
108+
}
109+
}
110+
}
111+
112+
void extractAccumulators(
113+
char** groups,
114+
int32_t numGroups,
115+
facebook::velox::VectorPtr* result) override {
116+
VELOX_CHECK_EQ((*result)->encoding(), VectorEncoding::Simple::ROW);
117+
auto rowVector = (*result)->as<RowVector>();
118+
auto sumVector = rowVector->childAt(0)->asFlatVector<TResultType>();
119+
auto isEmptyVector = rowVector->childAt(1)->asFlatVector<bool>();
120+
121+
rowVector->resize(numGroups);
122+
sumVector->resize(numGroups);
123+
isEmptyVector->resize(numGroups);
124+
125+
TResultType* rawSums = sumVector->mutableRawValues();
126+
// Bool uses compact representation, use mutableRawValues<uint64_t>
127+
// and bits::setBit instead.
128+
auto* rawIsEmpty = isEmptyVector->mutableRawValues<uint64_t>();
129+
uint64_t* rawNulls = getRawNulls(rowVector);
130+
131+
for (auto i = 0; i < numGroups; ++i) {
132+
char* group = groups[i];
133+
clearNull(rawNulls, i);
134+
if (isNull(group)) {
135+
rawSums[i] = 0;
136+
bits::setBit(rawIsEmpty, i, true);
137+
} else {
138+
auto* decimalSum = accumulator(group);
139+
bool overflow = false;
140+
auto result = (TResultType)computeFinalValue(decimalSum, overflow);
141+
if (overflow) {
142+
// Sum should be set to null on overflow, and
143+
// isEmpty should be set to false.
144+
sumVector->setNull(i, true);
145+
bits::setBit(rawIsEmpty, i, false);
146+
} else {
147+
rawSums[i] = result;
148+
bits::setBit(rawIsEmpty, i, decimalSum->isEmpty);
149+
}
150+
}
151+
}
152+
}
153+
154+
void addRawInput(
155+
char** groups,
156+
const facebook::velox::SelectivityVector& rows,
157+
const std::vector<VectorPtr>& args,
158+
bool /* mayPushdown */) override {
159+
decodedRaw_.decode(*args[0], rows);
160+
if (decodedRaw_.isConstantMapping()) {
161+
if (!decodedRaw_.isNullAt(0)) {
162+
auto value = decodedRaw_.valueAt<TInputType>(0);
163+
rows.applyToSelected([&](vector_size_t i) {
164+
updateNonNullValue(groups[i], value, false);
165+
});
166+
}
167+
} else if (decodedRaw_.mayHaveNulls()) {
168+
rows.applyToSelected([&](vector_size_t i) {
169+
if (decodedRaw_.isNullAt(i)) {
170+
return;
171+
}
172+
updateNonNullValue(
173+
groups[i], decodedRaw_.valueAt<TInputType>(i), false);
174+
});
175+
} else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) {
176+
auto data = decodedRaw_.data<TInputType>();
177+
rows.applyToSelected([&](vector_size_t i) {
178+
updateNonNullValue<false>(groups[i], data[i], false);
179+
});
180+
} else {
181+
rows.applyToSelected([&](vector_size_t i) {
182+
updateNonNullValue(
183+
groups[i], decodedRaw_.valueAt<TInputType>(i), false);
184+
});
185+
}
186+
}
187+
188+
void addSingleGroupRawInput(
189+
char* group,
190+
const SelectivityVector& rows,
191+
const std::vector<VectorPtr>& args,
192+
bool /* mayPushdown */) override {
193+
decodedRaw_.decode(*args[0], rows);
194+
if (decodedRaw_.isConstantMapping()) {
195+
if (!decodedRaw_.isNullAt(0)) {
196+
auto value = decodedRaw_.valueAt<TInputType>(0);
197+
rows.template applyToSelected(
198+
[&](vector_size_t i) { updateNonNullValue(group, value, false); });
199+
}
200+
} else if (decodedRaw_.mayHaveNulls()) {
201+
rows.applyToSelected([&](vector_size_t i) {
202+
if (!decodedRaw_.isNullAt(i)) {
203+
updateNonNullValue(group, decodedRaw_.valueAt<TInputType>(i), false);
204+
}
205+
});
206+
} else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) {
207+
auto data = decodedRaw_.data<TInputType>();
208+
DecimalSum decimalSum;
209+
rows.applyToSelected([&](vector_size_t i) {
210+
decimalSum.overflow += DecimalUtil::addWithOverflow(
211+
decimalSum.sum, data[i], decimalSum.sum);
212+
decimalSum.isEmpty = false;
213+
});
214+
mergeAccumulators(group, decimalSum);
215+
} else {
216+
DecimalSum decimalSum;
217+
rows.applyToSelected([&](vector_size_t i) {
218+
decimalSum.overflow += DecimalUtil::addWithOverflow(
219+
decimalSum.sum, decodedRaw_.valueAt<TInputType>(i), decimalSum.sum);
220+
decimalSum.isEmpty = false;
221+
});
222+
mergeAccumulators(group, decimalSum);
223+
}
224+
}
225+
226+
void addIntermediateResults(
227+
char** groups,
228+
const SelectivityVector& rows,
229+
const std::vector<VectorPtr>& args,
230+
bool /* mayPushdown */) override {
231+
decodedPartial_.decode(*args[0], rows);
232+
VELOX_CHECK_EQ(
233+
decodedPartial_.base()->encoding(), VectorEncoding::Simple::ROW);
234+
auto baseRowVector = dynamic_cast<const RowVector*>(decodedPartial_.base());
235+
auto sumVector = baseRowVector->childAt(0)->as<SimpleVector<TResultType>>();
236+
auto isEmptyVector = baseRowVector->childAt(1)->as<SimpleVector<bool>>();
237+
238+
if (decodedPartial_.isConstantMapping()) {
239+
if (!decodedPartial_.isNullAt(0)) {
240+
auto decodedIndex = decodedPartial_.index(0);
241+
if (isIntermediateResultOverflow(
242+
isEmptyVector, sumVector, decodedIndex)) {
243+
rows.applyToSelected([&](vector_size_t i) { setNull(groups[i]); });
244+
} else {
245+
auto sum = sumVector->valueAt(decodedIndex);
246+
auto isEmpty = isEmptyVector->valueAt(decodedIndex);
247+
rows.applyToSelected([&](vector_size_t i) {
248+
clearNull(groups[i]);
249+
updateNonNullValue(groups[i], sum, isEmpty);
250+
});
251+
}
252+
}
253+
} else if (decodedPartial_.mayHaveNulls()) {
254+
rows.applyToSelected([&](vector_size_t i) {
255+
if (decodedPartial_.isNullAt(i)) {
256+
return;
257+
}
258+
auto decodedIndex = decodedPartial_.index(i);
259+
if (isIntermediateResultOverflow(
260+
isEmptyVector, sumVector, decodedIndex)) {
261+
setNull(groups[i]);
262+
} else {
263+
auto sum = sumVector->valueAt(decodedIndex);
264+
auto isEmpty = isEmptyVector->valueAt(decodedIndex);
265+
updateNonNullValue(groups[i], sum, isEmpty);
266+
}
267+
});
268+
} else {
269+
rows.applyToSelected([&](vector_size_t i) {
270+
clearNull(groups[i]);
271+
auto decodedIndex = decodedPartial_.index(i);
272+
if (isIntermediateResultOverflow(
273+
isEmptyVector, sumVector, decodedIndex)) {
274+
setNull(groups[i]);
275+
} else {
276+
auto sum = sumVector->valueAt(decodedIndex);
277+
auto isEmpty = isEmptyVector->valueAt(decodedIndex);
278+
updateNonNullValue(groups[i], sum, isEmpty);
279+
}
280+
});
281+
}
282+
}
283+
284+
void addSingleGroupIntermediateResults(
285+
char* group,
286+
const SelectivityVector& rows,
287+
const std::vector<VectorPtr>& args,
288+
bool /* mayPushdown */) override {
289+
decodedPartial_.decode(*args[0], rows);
290+
VELOX_CHECK_EQ(
291+
decodedPartial_.base()->encoding(), VectorEncoding::Simple::ROW);
292+
auto baseRowVector = dynamic_cast<const RowVector*>(decodedPartial_.base());
293+
auto sumVector = baseRowVector->childAt(0)->as<SimpleVector<TResultType>>();
294+
auto isEmptyVector = baseRowVector->childAt(1)->as<SimpleVector<bool>>();
295+
if (decodedPartial_.isConstantMapping()) {
296+
if (!decodedPartial_.isNullAt(0)) {
297+
auto decodedIndex = decodedPartial_.index(0);
298+
if (isIntermediateResultOverflow(
299+
isEmptyVector, sumVector, decodedIndex)) {
300+
setNull(group);
301+
} else {
302+
auto sum = sumVector->valueAt(decodedIndex);
303+
auto isEmpty = isEmptyVector->valueAt(decodedIndex);
304+
if (rows.hasSelections()) {
305+
clearNull(group);
306+
}
307+
rows.applyToSelected([&](vector_size_t i) {
308+
updateNonNullValue(group, sum, isEmpty);
309+
});
310+
}
311+
}
312+
} else if (decodedPartial_.mayHaveNulls()) {
313+
rows.applyToSelected([&](vector_size_t i) {
314+
if (decodedPartial_.isNullAt(i)) {
315+
return;
316+
}
317+
auto decodedIndex = decodedPartial_.index(i);
318+
if (isIntermediateResultOverflow(
319+
isEmptyVector, sumVector, decodedIndex)) {
320+
setNull(group);
321+
return;
322+
} else {
323+
clearNull(group);
324+
auto sum = sumVector->valueAt(decodedIndex);
325+
auto isEmpty = isEmptyVector->valueAt(decodedIndex);
326+
updateNonNullValue(group, sum, isEmpty);
327+
}
328+
});
329+
} else {
330+
if (rows.hasSelections()) {
331+
clearNull(group);
332+
}
333+
rows.applyToSelected([&](vector_size_t i) {
334+
auto decodedIndex = decodedPartial_.index(i);
335+
if (isIntermediateResultOverflow(
336+
isEmptyVector, sumVector, decodedIndex)) {
337+
setNull(group);
338+
return;
339+
} else {
340+
auto sum = sumVector->valueAt(decodedIndex);
341+
auto isEmpty = isEmptyVector->valueAt(decodedIndex);
342+
updateNonNullValue(group, sum, isEmpty);
343+
}
344+
});
345+
}
346+
}
347+
348+
private:
349+
template <bool tableHasNulls = true>
350+
inline void updateNonNullValue(char* group, TResultType value, bool isEmpty) {
351+
if constexpr (tableHasNulls) {
352+
exec::Aggregate::clearNull(group);
353+
}
354+
auto decimalSum = accumulator(group);
355+
decimalSum->overflow +=
356+
DecimalUtil::addWithOverflow(decimalSum->sum, value, decimalSum->sum);
357+
decimalSum->isEmpty &= isEmpty;
358+
}
359+
360+
template <bool tableHasNulls = true>
361+
inline void mergeAccumulators(char* group, DecimalSum other) {
362+
if constexpr (tableHasNulls) {
363+
exec::Aggregate::clearNull(group);
364+
}
365+
auto decimalSum = accumulator(group);
366+
decimalSum->mergeWith(other);
367+
}
368+
369+
inline DecimalSum* accumulator(char* group) {
370+
return exec::Aggregate::value<DecimalSum>(group);
371+
}
372+
373+
inline bool isIntermediateResultOverflow(
374+
const SimpleVector<bool>* isEmptyVector,
375+
const SimpleVector<TResultType>* sumVector,
376+
vector_size_t index) {
377+
// If isEmpty is false and sum is null, it means this intermediate
378+
// result has an overflow. The final accumulator of this group will
379+
// be null.
380+
return !isEmptyVector->valueAt(index) && sumVector->isNullAt(index);
381+
}
382+
383+
DecodedVector decodedRaw_;
384+
DecodedVector decodedPartial_;
385+
TypePtr sumType_;
386+
};
387+
388+
} // namespace facebook::velox::functions::aggregate::sparksql

0 commit comments

Comments
 (0)