Skip to content

Commit 8044475

Browse files
committed
adding param allowPrecisionLoss
Signed-off-by: Yuan Zhou <[email protected]>
1 parent 96448fe commit 8044475

File tree

3 files changed

+50
-13
lines changed

3 files changed

+50
-13
lines changed

velox/core/QueryConfig.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ class QueryConfig {
101101
static constexpr const char* kCastMatchStructByName =
102102
"cast_match_struct_by_name";
103103

104+
// This flags forces to bound the decimal precision.
105+
static constexpr const char* kAllowPrecisionLoss = "allow_precision_loss";
106+
104107
/// Used for backpressure to block local exchange producers when the local
105108
/// exchange buffer reaches or exceeds this size.
106109
static constexpr const char* kMaxLocalExchangeBufferSize =
@@ -496,6 +499,10 @@ class QueryConfig {
496499
return get<bool>(kCastMatchStructByName, false);
497500
}
498501

502+
bool isAllowPrecisionLoss() const {
503+
return get<bool>(kAllowPrecisionLoss, true);
504+
}
505+
499506
bool codegenEnabled() const {
500507
return get<bool>(kCodegenEnabled, false);
501508
}

velox/functions/sparksql/DecimalArithmetic.cpp

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -416,11 +416,14 @@ class Addition {
416416
uint8_t aPrecision,
417417
uint8_t aScale,
418418
uint8_t bPrecision,
419-
uint8_t bScale) {
419+
uint8_t bScale,
420+
bool allowPrecisionLoss) {
420421
auto precision = std::max(aPrecision - aScale, bPrecision - bScale) +
421422
std::max(aScale, bScale) + 1;
422423
auto scale = std::max(aScale, bScale);
423-
return DecimalUtil::adjustPrecisionScale(precision, scale);
424+
return allowPrecisionLoss
425+
? DecimalUtil::adjustPrecisionScale(precision, scale)
426+
: DecimalUtil::bounded(precision, scale);
424427
}
425428
};
426429

@@ -464,9 +467,10 @@ class Subtraction {
464467
uint8_t aPrecision,
465468
uint8_t aScale,
466469
uint8_t bPrecision,
467-
uint8_t bScale) {
470+
uint8_t bScale,
471+
bool allowPrecisionLoss) {
468472
return Addition::computeResultPrecisionScale(
469-
aPrecision, aScale, bPrecision, bScale);
473+
aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss);
470474
}
471475
};
472476

@@ -566,9 +570,12 @@ class Multiply {
566570
uint8_t aPrecision,
567571
uint8_t aScale,
568572
uint8_t bPrecision,
569-
uint8_t bScale) {
570-
return DecimalUtil::adjustPrecisionScale(
571-
aPrecision + bPrecision + 1, aScale + bScale);
573+
uint8_t bScale,
574+
const bool allowPrecisionLoss) {
575+
return allowPrecisionLoss
576+
? DecimalUtil::adjustPrecisionScale(
577+
aPrecision + bPrecision + 1, aScale + bScale)
578+
: DecimalUtil::bounded(aPrecision + bPrecision + 1, aScale + bScale);
572579
}
573580

574581
private:
@@ -616,10 +623,22 @@ class Divide {
616623
uint8_t aPrecision,
617624
uint8_t aScale,
618625
uint8_t bPrecision,
619-
uint8_t bScale) {
620-
auto scale = std::max(6, aScale + bPrecision + 1);
621-
auto precision = aPrecision - aScale + bScale + scale;
622-
return DecimalUtil::adjustPrecisionScale(precision, scale);
626+
uint8_t bScale,
627+
bool allowPrecisionLoss) {
628+
if (allowPrecisionLoss) {
629+
auto scale = std::max(6, aScale + bPrecision + 1);
630+
auto precision = aPrecision - aScale + bScale + scale;
631+
return DecimalUtil::adjustPrecisionScale(precision, scale);
632+
} else {
633+
auto intDig = std::min(38, aPrecision - aScale + bScale);
634+
auto decDig = std::min(38, std::max(6, aScale + bPrecision + 1));
635+
auto diff = (intDig + decDig) - 38;
636+
if (diff > 0) {
637+
decDig -= diff / 2 + 1;
638+
intDig = 38 - decDig;
639+
}
640+
return DecimalUtil::bounded(intDig + decDig, decDig);
641+
}
623642
}
624643
};
625644

@@ -689,13 +708,14 @@ template <typename Operation>
689708
std::shared_ptr<exec::VectorFunction> createDecimalFunction(
690709
const std::string& name,
691710
const std::vector<exec::VectorFunctionArg>& inputArgs,
692-
const core::QueryConfig& /*config*/) {
711+
const core::QueryConfig& config) {
693712
const auto& aType = inputArgs[0].type;
694713
const auto& bType = inputArgs[1].type;
695714
const auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
696715
const auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
716+
const bool allowPrecisionLoss = config.isAllowPrecisionLoss();
697717
const auto [rPrecision, rScale] = Operation::computeResultPrecisionScale(
698-
aPrecision, aScale, bPrecision, bScale);
718+
aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss);
699719
const uint8_t aRescale =
700720
Operation::computeRescaleFactor(aScale, bScale, rScale);
701721
const uint8_t bRescale =

velox/functions/sparksql/DecimalUtil.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ class DecimalUtil {
4646
}
4747
}
4848

49+
/// This method is used when
50+
/// `spark.sql.decimalOperations.allowPrecisionLoss` is set to false.
51+
inline static std::pair<uint8_t, uint8_t> bounded(
52+
uint8_t rPrecision,
53+
uint8_t rScale) {
54+
return {
55+
std::min(static_cast<int32_t>(rPrecision), 38),
56+
std::min(static_cast<int32_t>(rScale), 38)};
57+
}
58+
4959
/// @brief Convert int256 value to int64 or int128, set overflow to true if
5060
/// value cannot convert to specific type.
5161
/// @return The converted value.

0 commit comments

Comments
 (0)