diff --git a/cpp/src/gandiva/function_holder_maker_registry.cc b/cpp/src/gandiva/function_holder_maker_registry.cc index 2d9657489670..f45cf2b820f8 100644 --- a/cpp/src/gandiva/function_holder_maker_registry.cc +++ b/cpp/src/gandiva/function_holder_maker_registry.cc @@ -62,6 +62,7 @@ FunctionHolderMakerRegistry::MakerMap FunctionHolderMakerRegistry::DefaultHolder {"to_date", HolderMaker}, {"random", HolderMaker}, {"rand", HolderMaker}, + {"rand_integer", HolderMaker}, {"regexp_replace", HolderMaker}, {"regexp_extract", HolderMaker}, {"castintervalday", HolderMaker}, diff --git a/cpp/src/gandiva/function_registry_math_ops.cc b/cpp/src/gandiva/function_registry_math_ops.cc index 232c7c532600..3bfcfc180e7e 100644 --- a/cpp/src/gandiva/function_registry_math_ops.cc +++ b/cpp/src/gandiva/function_registry_math_ops.cc @@ -103,6 +103,14 @@ std::vector GetMathOpsFunctionRegistry() { "gdv_fn_random", NativeFunction::kNeedsFunctionHolder), NativeFunction("random", {"rand"}, DataTypeVector{int32()}, float64(), kResultNullNever, "gdv_fn_random_with_seed", + NativeFunction::kNeedsFunctionHolder), + NativeFunction("rand_integer", {}, DataTypeVector{}, int32(), kResultNullNever, + "gdv_fn_rand_integer", NativeFunction::kNeedsFunctionHolder), + NativeFunction("rand_integer", {}, DataTypeVector{int32()}, int32(), + kResultNullNever, "gdv_fn_rand_integer_with_range", + NativeFunction::kNeedsFunctionHolder), + NativeFunction("rand_integer", {}, DataTypeVector{int32(), int32()}, int32(), + kResultNullNever, "gdv_fn_rand_integer_with_min_max", NativeFunction::kNeedsFunctionHolder)}; return math_fn_registry_; diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index e1dba4b1ee81..cc5e09284d85 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -70,12 +70,33 @@ double gdv_fn_random(int64_t ptr) { return (*holder)(); } -double gdv_fn_random_with_seed(int64_t ptr, int32_t seed, bool seed_validity) { +double gdv_fn_random_with_seed(int64_t ptr, int32_t /*seed*/, bool /*seed_validity*/) { gandiva::RandomGeneratorHolder* holder = reinterpret_cast(ptr); return (*holder)(); } +int32_t gdv_fn_rand_integer(int64_t ptr) { + gandiva::RandomIntegerGeneratorHolder* holder = + reinterpret_cast(ptr); + return (*holder)(); +} + +int32_t gdv_fn_rand_integer_with_range(int64_t ptr, int32_t /*range*/, + bool /*range_validity*/) { + gandiva::RandomIntegerGeneratorHolder* holder = + reinterpret_cast(ptr); + return (*holder)(); +} + +int32_t gdv_fn_rand_integer_with_min_max(int64_t ptr, int32_t /*min*/, + bool /*min_validity*/, int32_t /*max*/, + bool /*max_validity*/) { + gandiva::RandomIntegerGeneratorHolder* holder = + reinterpret_cast(ptr); + return (*holder)(); +} + bool gdv_fn_in_expr_lookup_int32(int64_t ptr, int32_t value, bool in_validity) { if (!in_validity) { return false; @@ -936,6 +957,22 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("gdv_fn_random_with_seed", types->double_type(), args, reinterpret_cast(gdv_fn_random_with_seed)); + // gdv_fn_rand_integer + args = {types->i64_type()}; + engine->AddGlobalMappingForFunc("gdv_fn_rand_integer", types->i32_type(), args, + reinterpret_cast(gdv_fn_rand_integer)); + + args = {types->i64_type(), types->i32_type(), types->i1_type()}; + engine->AddGlobalMappingForFunc( + "gdv_fn_rand_integer_with_range", types->i32_type(), args, + reinterpret_cast(gdv_fn_rand_integer_with_range)); + + args = {types->i64_type(), types->i32_type(), types->i1_type(), types->i32_type(), + types->i1_type()}; + engine->AddGlobalMappingForFunc( + "gdv_fn_rand_integer_with_min_max", types->i32_type(), args, + reinterpret_cast(gdv_fn_rand_integer_with_min_max)); + // gdv_fn_dec_from_string args = { types->i64_type(), // context diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index bfb34eeb31d8..d6d459f62bd5 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -21,6 +21,8 @@ #include #include +#include + #include "arrow/util/logging.h" #include "gandiva/execution_context.h" #include "gandiva/encrypt_utils_ecb.h" @@ -353,6 +355,14 @@ TEST(TestGdvFnStubs, TestCastVARCHARFromInt64) { out_str = gdv_fn_castVARCHAR_int64_int64(ctx_ptr, 12345, 3, &out_len); EXPECT_EQ(std::string(out_str, out_len), "123"); EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_int64_int64(ctx_ptr, 347, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_int64_int64(ctx_ptr, 347, -1, &out_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length cannot be negative")); + ctx.Reset(); } TEST(TestGdvFnStubs, TestCastVARCHARFromMilliseconds) { @@ -384,6 +394,15 @@ TEST(TestGdvFnStubs, TestCastVARCHARFromMilliseconds) { out_str = gdv_fn_castVARCHAR_date64_int64(ctx_ptr, ts, 4, &out_len); EXPECT_EQ(std::string(out_str, out_len), "2008"); EXPECT_FALSE(ctx.has_error()); + + ts = StringToTimestamp("2021-04-23 10:20:33"); + out_str = gdv_fn_castVARCHAR_date64_int64(ctx_ptr, ts, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_date64_int64(ctx_ptr, ts, -1, &out_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length cannot be negative")); + ctx.Reset(); } TEST(TestGdvFnStubs, TestCastVARCHARFromFloat) { @@ -419,6 +438,14 @@ TEST(TestGdvFnStubs, TestCastVARCHARFromFloat) { out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 1.2345f, 3, &out_len); EXPECT_EQ(std::string(out_str, out_len), "1.2"); EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 1.2345f, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 1.2345f, -1, &out_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length cannot be negative")); + ctx.Reset(); } TEST(TestGdvFnStubs, TestCastVARCHARFromDouble) { @@ -454,6 +481,25 @@ TEST(TestGdvFnStubs, TestCastVARCHARFromDouble) { out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 1.2345, 3, &out_len); EXPECT_EQ(std::string(out_str, out_len), "1.2"); EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 1.2345, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 1.2345, -1, &out_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length cannot be negative")); + ctx.Reset(); + + // test long repeating decimal (1/3) with large buffer + out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 1.0 / 3.0, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "0.3333333333333333"); + EXPECT_FALSE(ctx.has_error()); + + // test exponential notation with large negative exponent (24 chars) + out_str = + gdv_fn_castVARCHAR_float64_int64(ctx_ptr, -1.2345678901234567e-100, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "-1.2345678901234567E-100"); + EXPECT_FALSE(ctx.has_error()); } TEST(TestGdvFnStubs, TestSubstringIndex) { @@ -529,6 +575,21 @@ TEST(TestGdvFnStubs, TestSubstringIndex) { out_str = gdv_fn_substring_index(ctx_ptr, "路学\\L", 8, "\\", 1, -1, &out_len); EXPECT_EQ(std::string(out_str, out_len), "L"); EXPECT_FALSE(ctx.has_error()); + + // Large counts return full string when delimiter not found enough times + out_str = gdv_fn_substring_index(ctx_ptr, "a.b.c", 5, ".", 1, -1000, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "a.b.c"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_substring_index(ctx_ptr, "a.b.c", 5, ".", 1, + std::numeric_limits::max(), &out_len); + EXPECT_EQ(std::string(out_str, out_len), "a.b.c"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_substring_index(ctx_ptr, "a.b.c", 5, ".", 1, + std::numeric_limits::min(), &out_len); + EXPECT_EQ(std::string(out_str, out_len), "a.b.c"); + EXPECT_FALSE(ctx.has_error()); } TEST(TestGdvFnStubs, TestUpper) { diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc b/cpp/src/gandiva/gdv_string_function_stubs.cc index e7982461b439..d271834fb478 100644 --- a/cpp/src/gandiva/gdv_string_function_stubs.cc +++ b/cpp/src/gandiva/gdv_string_function_stubs.cc @@ -18,8 +18,10 @@ #include "gandiva/gdv_function_stubs.h" #include +#include #include #include +#include #include #include @@ -81,91 +83,110 @@ const char* gdv_fn_regexp_extract_utf8_utf8_int32(int64_t ptr, int64_t holder_pt return (*holder)(context, data, data_len, extract_index, out_length); } -#define GDV_FN_CAST_VARLEN_TYPE_FROM_TYPE(IN_TYPE, CAST_NAME, ARROW_TYPE) \ - GANDIVA_EXPORT \ - const char* gdv_fn_cast##CAST_NAME##_##IN_TYPE##_int64( \ - int64_t context, gdv_##IN_TYPE value, int64_t len, int32_t * out_len) { \ - if (len < 0) { \ - gdv_fn_context_set_error_msg(context, "Buffer length cannot be negative"); \ - *out_len = 0; \ - return ""; \ - } \ - if (len == 0) { \ - *out_len = 0; \ - return ""; \ - } \ - arrow::internal::StringFormatter formatter; \ - char* ret = reinterpret_cast( \ - gdv_fn_context_arena_malloc(context, static_cast(len))); \ - if (ret == nullptr) { \ - gdv_fn_context_set_error_msg(context, "Could not allocate memory"); \ - *out_len = 0; \ - return ""; \ - } \ - arrow::Status status = formatter(value, [&](std::string_view v) { \ - int64_t size = static_cast(v.size()); \ - *out_len = static_cast(len < size ? len : size); \ - memcpy(ret, v.data(), *out_len); \ - return arrow::Status::OK(); \ - }); \ - if (!status.ok()) { \ - std::string err = "Could not cast " + std::to_string(value) + " to string"; \ - gdv_fn_context_set_error_msg(context, err.c_str()); \ - *out_len = 0; \ - return ""; \ - } \ - return ret; \ +// The following castVARCHAR macros are optimized to allocate only the actual +// string size instead of the maximum buffer length (which can be 65536+ bytes). + +// Helper: arena allocation + null check +#define GDV_FN_CAST_VARLEN_ALLOC(SIZE) \ + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, SIZE)); \ + if (ret == nullptr) { \ + gdv_fn_context_set_error_msg(context, "Could not allocate memory"); \ + *out_len = 0; \ + return ""; \ + } + +// Helper: function signature + len validation +#define GDV_FN_CAST_VARLEN_PREFIX(IN_TYPE, CAST_NAME) \ + GANDIVA_EXPORT \ + const char* gdv_fn_cast##CAST_NAME##_##IN_TYPE##_int64( \ + int64_t context, gdv_##IN_TYPE value, int64_t len, int32_t * out_len) { \ + if (len < 0) { \ + gdv_fn_context_set_error_msg(context, "Buffer length cannot be negative"); \ + *out_len = 0; \ + return ""; \ + } \ + if (len == 0) { \ + *out_len = 0; \ + return ""; \ + } + +// Macro for integer types (int32/int64). Uses arrow::internal::detail::FormatAllDigits +// to convert digits right-to-left into a small arena allocation (11 bytes for int32, +// 20 for int64). +#define GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER(IN_TYPE, CAST_NAME, ARROW_TYPE) \ + GDV_FN_CAST_VARLEN_PREFIX(IN_TYPE, CAST_NAME) \ + constexpr int32_t max_len = std::numeric_limits::digits10 + 2; \ + GDV_FN_CAST_VARLEN_ALLOC(max_len) \ + char* end = ret + max_len; \ + char* cursor = end; \ + auto uval = arrow::internal::detail::Abs(value); \ + arrow::internal::detail::FormatAllDigits(uval, &cursor); \ + if (value < 0) { \ + arrow::internal::detail::FormatOneChar('-', &cursor); \ + } \ + int32_t slen = static_cast(end - cursor); \ + *out_len = static_cast(len < slen ? len : slen); \ + memmove(ret, cursor, *out_len); \ + return ret; \ } -#define GDV_FN_CAST_VARLEN_TYPE_FROM_REAL(IN_TYPE, CAST_NAME, ARROW_TYPE) \ - GANDIVA_EXPORT \ - const char* gdv_fn_cast##CAST_NAME##_##IN_TYPE##_int64( \ - int64_t context, gdv_##IN_TYPE value, int64_t len, int32_t * out_len) { \ - if (len < 0) { \ - gdv_fn_context_set_error_msg(context, "Buffer length cannot be negative"); \ - *out_len = 0; \ - return ""; \ - } \ - if (len == 0) { \ - *out_len = 0; \ - return ""; \ - } \ - gandiva::GdvStringFormatter formatter; \ - char* ret = reinterpret_cast( \ - gdv_fn_context_arena_malloc(context, static_cast(len))); \ - if (ret == nullptr) { \ - gdv_fn_context_set_error_msg(context, "Could not allocate memory"); \ - *out_len = 0; \ - return ""; \ - } \ - arrow::Status status = formatter(value, [&](std::string_view v) { \ - int64_t size = static_cast(v.size()); \ - *out_len = static_cast(len < size ? len : size); \ - memcpy(ret, v.data(), *out_len); \ - return arrow::Status::OK(); \ - }); \ - if (!status.ok()) { \ - std::string err = "Could not cast " + std::to_string(value) + " to string"; \ - gdv_fn_context_set_error_msg(context, err.c_str()); \ - *out_len = 0; \ - return ""; \ - } \ - return ret; \ +// Helper: invoke formatter callback, copy result to ret, handle errors. +// Used by date64 and float types that rely on arrow::internal::StringFormatter. +#define GDV_FN_CAST_VARLEN_FORMATTER_SUFFIX \ + arrow::Status status = formatter(value, [&](std::string_view v) { \ + int64_t size = static_cast(v.size()); \ + *out_len = static_cast(len < size ? len : size); \ + memcpy(ret, v.data(), *out_len); \ + return arrow::Status::OK(); \ + }); \ + if (!status.ok()) { \ + std::string err = "Could not cast " + std::to_string(value) + " to string"; \ + gdv_fn_context_set_error_msg(context, err.c_str()); \ + *out_len = 0; \ + return ""; \ + } \ + return ret; \ } -#define CAST_VARLEN_TYPE_FROM_NUMERIC(VARLEN_TYPE) \ - GDV_FN_CAST_VARLEN_TYPE_FROM_TYPE(int32, VARLEN_TYPE, Int32Type) \ - GDV_FN_CAST_VARLEN_TYPE_FROM_TYPE(int64, VARLEN_TYPE, Int64Type) \ - GDV_FN_CAST_VARLEN_TYPE_FROM_TYPE(date64, VARLEN_TYPE, Date64Type) \ - GDV_FN_CAST_VARLEN_TYPE_FROM_REAL(float32, VARLEN_TYPE, FloatType) \ +// Macro for date64 type. Output is always "YYYY-MM-DD" = 10 chars max. +#define GDV_FN_CAST_VARLEN_TYPE_FROM_DATE64(IN_TYPE, CAST_NAME, ARROW_TYPE) \ + GDV_FN_CAST_VARLEN_PREFIX(IN_TYPE, CAST_NAME) \ + constexpr int32_t max_date_str_len = 10; \ + int32_t alloc_len = \ + static_cast(len < max_date_str_len ? len : max_date_str_len); \ + GDV_FN_CAST_VARLEN_ALLOC(alloc_len) \ + arrow::internal::StringFormatter formatter; \ + GDV_FN_CAST_VARLEN_FORMATTER_SUFFIX + +// Macro for float types (float32/float64). Uses Java-compatible formatting. +// Max string: "-1.2345678901234567E-308" = 24 chars. +#define GDV_FN_CAST_VARLEN_TYPE_FROM_REAL(IN_TYPE, CAST_NAME, ARROW_TYPE) \ + GDV_FN_CAST_VARLEN_PREFIX(IN_TYPE, CAST_NAME) \ + constexpr int32_t max_real_str_len = 24; \ + int32_t alloc_len = \ + static_cast(len < max_real_str_len ? len : max_real_str_len); \ + GDV_FN_CAST_VARLEN_ALLOC(alloc_len) \ + gandiva::GdvStringFormatter formatter; \ + GDV_FN_CAST_VARLEN_FORMATTER_SUFFIX + +// Use optimized integer macro for int32/int64, date64 macro, and real macro for floats +#define CAST_VARLEN_TYPE_FROM_NUMERIC(VARLEN_TYPE) \ + GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER(int32, VARLEN_TYPE, Int32Type) \ + GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER(int64, VARLEN_TYPE, Int64Type) \ + GDV_FN_CAST_VARLEN_TYPE_FROM_DATE64(date64, VARLEN_TYPE, Date64Type) \ + GDV_FN_CAST_VARLEN_TYPE_FROM_REAL(float32, VARLEN_TYPE, FloatType) \ GDV_FN_CAST_VARLEN_TYPE_FROM_REAL(float64, VARLEN_TYPE, DoubleType) CAST_VARLEN_TYPE_FROM_NUMERIC(VARCHAR) CAST_VARLEN_TYPE_FROM_NUMERIC(VARBINARY) #undef CAST_VARLEN_TYPE_FROM_NUMERIC -#undef GDV_FN_CAST_VARLEN_TYPE_FROM_TYPE +#undef GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER +#undef GDV_FN_CAST_VARLEN_TYPE_FROM_DATE64 #undef GDV_FN_CAST_VARLEN_TYPE_FROM_REAL +#undef GDV_FN_CAST_VARLEN_FORMATTER_SUFFIX +#undef GDV_FN_CAST_VARLEN_ALLOC +#undef GDV_FN_CAST_VARLEN_PREFIX GDV_FORCE_INLINE void gdv_fn_set_error_for_invalid_utf8(int64_t execution_context, char val) { @@ -407,14 +428,17 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt } } - if (static_cast(abs(cnt)) <= static_cast(occ.size()) && cnt > 0) { + // Use int64_t to avoid undefined behavior with abs(INT_MIN) + int64_t abs_cnt = (cnt < 0) ? -static_cast(cnt) : static_cast(cnt); + int64_t occ_size = static_cast(occ.size()); + + if (abs_cnt <= occ_size && cnt > 0) { memcpy(out, txt, occ[cnt - 1]); *out_len = occ[cnt - 1]; return out; - } else if (static_cast(abs(cnt)) <= static_cast(occ.size()) && - cnt < 0) { - int32_t sz = static_cast(occ.size()); - int32_t temp = static_cast(abs(cnt)); + } else if (abs_cnt <= occ_size && cnt < 0) { + int64_t sz = occ_size; + int64_t temp = abs_cnt; memcpy(out, txt + occ[sz - temp] + pat_len, txt_len - occ[sz - temp] - pat_len); *out_len = txt_len - occ[sz - temp] - pat_len; diff --git a/cpp/src/gandiva/precompiled/extended_math_ops.cc b/cpp/src/gandiva/precompiled/extended_math_ops.cc index b2562e955acd..c29f8f2a8684 100644 --- a/cpp/src/gandiva/precompiled/extended_math_ops.cc +++ b/cpp/src/gandiva/precompiled/extended_math_ops.cc @@ -386,16 +386,22 @@ gdv_int64 get_power_of_10(gdv_int32 exp) { FORCE_INLINE gdv_int64 truncate_int64_int32(gdv_int64 in, gdv_int32 out_scale) { + // For int64 (no fractional digits), positive scale is a no-op + if (out_scale >= 0) { + return in; + } + // GetScaleMultiplier only supports scales 0-38 + if (out_scale < -38) { + return 0; + } + bool overflow = false; arrow::BasicDecimal128 decimal = gandiva::decimalops::FromInt64(in, 38, 0, &overflow); arrow::BasicDecimal128 decimal_with_outscale = gandiva::decimalops::Truncate(gandiva::BasicDecimalScalar128(decimal, 38, 0), 38, out_scale, out_scale, &overflow); - if (out_scale < 0) { - out_scale = 0; - } return gandiva::decimalops::ToInt64( - gandiva::BasicDecimalScalar128(decimal_with_outscale, 38, out_scale), &overflow); + gandiva::BasicDecimalScalar128(decimal_with_outscale, 38, 0), &overflow); } FORCE_INLINE diff --git a/cpp/src/gandiva/precompiled/extended_math_ops_test.cc b/cpp/src/gandiva/precompiled/extended_math_ops_test.cc index 7170fad01d25..ad0cb78188c5 100644 --- a/cpp/src/gandiva/precompiled/extended_math_ops_test.cc +++ b/cpp/src/gandiva/precompiled/extended_math_ops_test.cc @@ -22,6 +22,7 @@ #include #include +#include #include "gandiva/execution_context.h" #include "gandiva/precompiled/types.h" @@ -208,6 +209,18 @@ TEST(TestExtendedMathOps, TestTruncate) { EXPECT_EQ(truncate_int64_int32(-1234, -2), -1200); EXPECT_EQ(truncate_int64_int32(8124674407369523212, 0), 8124674407369523212); EXPECT_EQ(truncate_int64_int32(8124674407369523212, -2), 8124674407369523200); + + // Positive scales are no-op for int64 (no fractional digits) + EXPECT_EQ(truncate_int64_int32(12345, std::numeric_limits::max()), 12345); + EXPECT_EQ(truncate_int64_int32(-12345, std::numeric_limits::max()), -12345); + EXPECT_EQ(truncate_int64_int32(12345, 100), 12345); + EXPECT_EQ(truncate_int64_int32(12345, 39), 12345); + + // Scales beyond [-38, 0) truncate all digits + EXPECT_EQ(truncate_int64_int32(12345, std::numeric_limits::min()), 0); + EXPECT_EQ(truncate_int64_int32(12345, -100), 0); + EXPECT_EQ(truncate_int64_int32(12345, -39), 0); + EXPECT_EQ(truncate_int64_int32(-99999, -39), 0); } TEST(TestExtendedMathOps, TestTrigonometricFunctions) { diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc index 0b787f461c21..035d3c8c62e1 100644 --- a/cpp/src/gandiva/precompiled/string_ops.cc +++ b/cpp/src/gandiva/precompiled/string_ops.cc @@ -582,11 +582,13 @@ const char* castVARCHAR_bool_int64(gdv_int64 context, gdv_boolean value, *out_length = 0; return ""; } - const char* out = - reinterpret_cast(gdv_fn_context_arena_malloc(context, 5)); - out = value ? "true" : "false"; - *out_length = value ? ((len > 4) ? 4 : len) : ((len > 5) ? 5 : len); - return out; + if (value) { + *out_length = (len > 4) ? 4 : len; + return "true"; + } else { + *out_length = (len > 5) ? 5 : len; + return "false"; + } } // Truncates the string to given length @@ -1966,6 +1968,23 @@ gdv_int32 evaluate_return_char_length(gdv_int32 text_len, gdv_int32 actual_text_ return return_char_length; } +// Fill a buffer with repeated fill_text using O(log n) doubling strategy +static FORCE_INLINE void fill_buffer_with_pattern(gdv_binary dest, + gdv_int32 total_fill_bytes, + const char* fill_text, + gdv_int32 fill_text_len) { + gdv_int32 initial_copy = std::min(fill_text_len, total_fill_bytes); + memcpy(dest, fill_text, initial_copy); + gdv_int32 written = initial_copy; + while (written * 2 <= total_fill_bytes) { + memcpy(dest + written, dest, written); + written *= 2; + } + if (written < total_fill_bytes) { + memcpy(dest + written, dest, total_fill_bytes - written); + } +} + FORCE_INLINE const char* lpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 text_len, gdv_int32 return_length, const char* fill_text, @@ -1988,48 +2007,49 @@ const char* lpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 // fill into text but "fill_text" is empty, then return text directly. *out_len = text_len; return text; - } else if (return_length < actual_text_len) { + } + if (return_length < actual_text_len) { // case where it truncates the result on return length. *out_len = utf8_byte_pos(context, text, text_len, return_length); return text; - } else { - // case (return_length > actual_text_len) - // case where it needs to copy "fill_text" on the string left. The total number - // of chars to copy is given by (return_length - actual_text_len) - gdv_int32 return_char_length = evaluate_return_char_length( - text_len, actual_text_len, return_length, fill_text, fill_text_len); - char* ret = reinterpret_cast( - gdv_fn_context_arena_malloc(context, return_char_length)); + } + + gdv_int32 chars_to_pad = return_length - actual_text_len; + + // FAST PATH: Single-byte fill (most common - space padding) + if (fill_text_len == 1) { + gdv_int32 out_len_bytes = chars_to_pad + text_len; + gdv_binary ret = + reinterpret_cast(gdv_fn_context_arena_malloc(context, out_len_bytes)); if (ret == nullptr) { gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); *out_len = 0; return ""; } - // try to fulfill the return string with the "fill_text" continuously - int32_t copied_chars_count = 0; - int32_t copied_chars_position = 0; - while (copied_chars_count < return_length - actual_text_len) { - int32_t char_len; - int32_t fill_index; - // for each char, evaluate its length to consider it when mem copying - for (fill_index = 0; fill_index < fill_text_len; fill_index += char_len) { - if (copied_chars_count >= return_length - actual_text_len) { - break; - } - char_len = utf8_char_length(fill_text[fill_index]); - // ignore invalid char on the fill text, considering it as size 1 - if (char_len == 0) char_len += 1; - copied_chars_count++; - } - memcpy(ret + copied_chars_position, fill_text, fill_index); - copied_chars_position += fill_index; - } - // after fulfilling the text, copy the main string - memcpy(ret + copied_chars_position, text, text_len); - *out_len = copied_chars_position + text_len; + memset(ret, fill_text[0], chars_to_pad); + memcpy(ret + chars_to_pad, text, text_len); + *out_len = out_len_bytes; return ret; } + + // GENERAL PATH: Multi-byte fill - use evaluate_return_char_length for buffer size + gdv_int32 return_char_length = evaluate_return_char_length( + text_len, actual_text_len, return_length, fill_text, fill_text_len); + gdv_binary ret = reinterpret_cast( + gdv_fn_context_arena_malloc(context, return_char_length)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + + // Fill padding region using doubling strategy, then append text + gdv_int32 total_fill_bytes = return_char_length - text_len; + fill_buffer_with_pattern(ret, total_fill_bytes, fill_text, fill_text_len); + memcpy(ret + total_fill_bytes, text, text_len); + *out_len = return_char_length; + return ret; } FORCE_INLINE @@ -2054,47 +2074,49 @@ const char* rpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 // fill into text but "fill_text" is empty, then return text directly. *out_len = text_len; return text; - } else if (return_length < actual_text_len) { + } + if (return_length < actual_text_len) { // case where it truncates the result on return length. *out_len = utf8_byte_pos(context, text, text_len, return_length); return text; - } else { - // case (return_length > actual_text_len) - // case where it needs to copy "fill_text" on the string right - gdv_int32 return_char_length = evaluate_return_char_length( - text_len, actual_text_len, return_length, fill_text, fill_text_len); - char* ret = reinterpret_cast( - gdv_fn_context_arena_malloc(context, return_char_length)); + } + + gdv_int32 chars_to_pad = return_length - actual_text_len; + + // FAST PATH: Single-byte fill (most common - space padding) + if (fill_text_len == 1) { + gdv_int32 out_len_bytes = chars_to_pad + text_len; + gdv_binary ret = + reinterpret_cast(gdv_fn_context_arena_malloc(context, out_len_bytes)); if (ret == nullptr) { gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); *out_len = 0; return ""; } - // fulfill the initial text copying the main input string memcpy(ret, text, text_len); - // try to fulfill the return string with the "fill_text" continuously - int32_t copied_chars_count = 0; - int32_t copied_chars_position = 0; - while (actual_text_len + copied_chars_count < return_length) { - int32_t char_len; - int32_t fill_length; - // for each char, evaluate its length to consider it when mem copying - for (fill_length = 0; fill_length < fill_text_len; fill_length += char_len) { - if (actual_text_len + copied_chars_count >= return_length) { - break; - } - char_len = utf8_char_length(fill_text[fill_length]); - // ignore invalid char on the fill text, considering it as size 1 - if (char_len == 0) char_len += 1; - copied_chars_count++; - } - memcpy(ret + text_len + copied_chars_position, fill_text, fill_length); - copied_chars_position += fill_length; - } - *out_len = copied_chars_position + text_len; + memset(ret + text_len, fill_text[0], chars_to_pad); + *out_len = out_len_bytes; return ret; } + + // GENERAL PATH: Multi-byte fill - use evaluate_return_char_length for buffer size + gdv_int32 return_char_length = evaluate_return_char_length( + text_len, actual_text_len, return_length, fill_text, fill_text_len); + gdv_binary ret = reinterpret_cast( + gdv_fn_context_arena_malloc(context, return_char_length)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + + // Copy text first, then fill padding region using doubling strategy + memcpy(ret, text, text_len); + gdv_int32 total_fill_bytes = return_char_length - text_len; + fill_buffer_with_pattern(ret + text_len, total_fill_bytes, fill_text, fill_text_len); + *out_len = return_char_length; + return ret; } FORCE_INLINE diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc index a204627a39a8..ea661585ecb5 100644 --- a/cpp/src/gandiva/precompiled/string_ops_test.cc +++ b/cpp/src/gandiva/precompiled/string_ops_test.cc @@ -417,6 +417,10 @@ TEST(TestStringOps, TestCastBoolToVarchar) { EXPECT_EQ(std::string(out_str, out_len), "false"); EXPECT_FALSE(ctx.has_error()); + out_str = castVARCHAR_bool_int64(ctx_ptr, true, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + castVARCHAR_bool_int64(ctx_ptr, true, -3, &out_len); EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Output buffer length can't be negative")); @@ -1318,6 +1322,99 @@ TEST(TestStringOps, TestLpadString) { out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, -1, &out_len); EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "x", 1, 65536, "😀", 4, &out_len); + EXPECT_EQ(out_len, 65535 * 4 + 1); + EXPECT_FALSE(ctx.has_error()); + EXPECT_EQ(out_str[out_len - 1], 'x'); + EXPECT_EQ(std::string_view(out_str, 4), "😀"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "A", 1, 65536, "哈", 3, &out_len); + EXPECT_EQ(out_len, 65535 * 3 + 1); + EXPECT_FALSE(ctx.has_error()); + EXPECT_EQ(out_str[out_len - 1], 'A'); + EXPECT_EQ(std::string_view(out_str, 3), "哈"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "X", 1, 2, ".", 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ".X"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "Z", 1, 65536, "@", 1, &out_len); + EXPECT_EQ(out_len, 65536); + for (int i = 0; i < 100; i++) { + EXPECT_EQ(out_str[i], '@') << "Mismatch at position " << i; + } + EXPECT_EQ(out_str[out_len - 1], 'Z'); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "END", 3, 11, "ab", 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ababababEND"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "END", 3, 10, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcabcaEND"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "X", 1, 5, "αβ", 4, &out_len); + EXPECT_EQ(out_len, 9); + EXPECT_EQ(std::string(out_str, out_len), "αβαβX"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "Y", 1, 4, "中文", 6, &out_len); + EXPECT_EQ(out_len, 10); + EXPECT_EQ(std::string(out_str, out_len), "中文中Y"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "X", 1, 4, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcX"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "X", 1, 7, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcabcX"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "X", 1, 13, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcabcabcabcX"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "X", 1, 10, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcabcabcX"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "E", 1, 129, "ab", 2, &out_len); + EXPECT_EQ(out_len, 129); + EXPECT_EQ(out_str[0], 'a'); + EXPECT_EQ(out_str[1], 'b'); + EXPECT_EQ(out_str[126], 'a'); + EXPECT_EQ(out_str[127], 'b'); + EXPECT_EQ(out_str[128], 'E'); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "E", 1, 127, "ab", 2, &out_len); + EXPECT_EQ(out_len, 127); + EXPECT_EQ(out_str[0], 'a'); + EXPECT_EQ(out_str[125], 'b'); + EXPECT_EQ(out_str[126], 'E'); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "X", 1, 2, "abc", 3, &out_len); + EXPECT_EQ(out_len, 2); + EXPECT_EQ(std::string(out_str, out_len), "aX"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "Y", 1, 3, "abcde", 5, &out_len); + EXPECT_EQ(out_len, 3); + EXPECT_EQ(std::string(out_str, out_len), "abY"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "Z", 1, 2, "αβ", 4, &out_len); + EXPECT_EQ(out_len, 3); + EXPECT_EQ(std::string(out_str, out_len), "αZ"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "A", 1, 2, "中文字", 9, &out_len); + EXPECT_EQ(out_len, 4); + EXPECT_EQ(std::string(out_str, out_len), "中A"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "B", 1, 3, "中文字", 9, &out_len); + EXPECT_EQ(out_len, 7); + EXPECT_EQ(std::string(out_str, out_len), "中文B"); + + std::string large_text(5000, 'X'); + std::string large_fill; + for (int i = 0; i < 50; ++i) { + large_fill += "α"; + } + out_str = lpad_utf8_int32_utf8(ctx_ptr, large_text.c_str(), 5000, 5001, + large_fill.c_str(), 100, &out_len); + EXPECT_EQ(out_len, 5002); + EXPECT_EQ(std::string(out_str, 2), "α"); + EXPECT_EQ(std::string(out_str + 2, 5000), large_text); } TEST(TestStringOps, TestRpadString) { @@ -1396,6 +1493,99 @@ TEST(TestStringOps, TestRpadString) { out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, -1, &out_len); EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "x", 1, 65536, "😀", 4, &out_len); + EXPECT_EQ(out_len, 1 + 65535 * 4); + EXPECT_FALSE(ctx.has_error()); + EXPECT_EQ(out_str[0], 'x'); + EXPECT_EQ(std::string_view(out_str + out_len - 4, 4), "😀"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "A", 1, 65536, "哈", 3, &out_len); + EXPECT_EQ(out_len, 1 + 65535 * 3); + EXPECT_FALSE(ctx.has_error()); + EXPECT_EQ(out_str[0], 'A'); + EXPECT_EQ(std::string_view(out_str + out_len - 3, 3), "哈"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "X", 1, 2, ".", 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "X."); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "Z", 1, 65536, "@", 1, &out_len); + EXPECT_EQ(out_len, 65536); + EXPECT_EQ(out_str[0], 'Z'); + for (int i = 1; i < 100; i++) { + EXPECT_EQ(out_str[i], '@') << "Mismatch at position " << i; + } + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "BEG", 3, 11, "ab", 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "BEGabababab"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "BEG", 3, 10, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "BEGabcabca"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "X", 1, 5, "αβ", 4, &out_len); + EXPECT_EQ(out_len, 9); + EXPECT_EQ(std::string(out_str, out_len), "Xαβαβ"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "Y", 1, 4, "中文", 6, &out_len); + EXPECT_EQ(out_len, 10); + EXPECT_EQ(std::string(out_str, out_len), "Y中文中"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "X", 1, 4, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Xabc"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "X", 1, 7, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Xabcabc"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "X", 1, 13, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Xabcabcabcabc"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "X", 1, 10, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Xabcabcabc"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "S", 1, 129, "ab", 2, &out_len); + EXPECT_EQ(out_len, 129); + EXPECT_EQ(out_str[0], 'S'); + EXPECT_EQ(out_str[1], 'a'); + EXPECT_EQ(out_str[2], 'b'); + EXPECT_EQ(out_str[127], 'a'); + EXPECT_EQ(out_str[128], 'b'); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "S", 1, 127, "ab", 2, &out_len); + EXPECT_EQ(out_len, 127); + EXPECT_EQ(out_str[0], 'S'); + EXPECT_EQ(out_str[125], 'a'); + EXPECT_EQ(out_str[126], 'b'); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "X", 1, 2, "abc", 3, &out_len); + EXPECT_EQ(out_len, 2); + EXPECT_EQ(std::string(out_str, out_len), "Xa"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "Y", 1, 3, "abcde", 5, &out_len); + EXPECT_EQ(out_len, 3); + EXPECT_EQ(std::string(out_str, out_len), "Yab"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "Z", 1, 2, "αβ", 4, &out_len); + EXPECT_EQ(out_len, 3); + EXPECT_EQ(std::string(out_str, out_len), "Zα"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "A", 1, 2, "中文字", 9, &out_len); + EXPECT_EQ(out_len, 4); + EXPECT_EQ(std::string(out_str, out_len), "A中"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "B", 1, 3, "中文字", 9, &out_len); + EXPECT_EQ(out_len, 7); + EXPECT_EQ(std::string(out_str, out_len), "B中文"); + + std::string large_text(5000, 'X'); + std::string large_fill; + for (int i = 0; i < 50; ++i) { + large_fill += "α"; + } + out_str = rpad_utf8_int32_utf8(ctx_ptr, large_text.c_str(), 5000, 5001, + large_fill.c_str(), 100, &out_len); + EXPECT_EQ(out_len, 5002); + EXPECT_EQ(std::string(out_str, 5000), large_text); + EXPECT_EQ(std::string(out_str + 5000, 2), "α"); } TEST(TestStringOps, TestRtrim) { diff --git a/cpp/src/gandiva/precompiled/time.cc b/cpp/src/gandiva/precompiled/time.cc index e1e9ac44567c..8414d0ed37cf 100644 --- a/cpp/src/gandiva/precompiled/time.cc +++ b/cpp/src/gandiva/precompiled/time.cc @@ -923,13 +923,15 @@ gdv_time32 castTIME_int32(int32_t int_val) { const char* castVARCHAR_timestamp_int64(gdv_int64 context, gdv_timestamp in, gdv_int64 length, gdv_int32* out_len) { - gdv_int64 year = extractYear_timestamp(in); - gdv_int64 month = extractMonth_timestamp(in); - gdv_int64 day = extractDay_timestamp(in); - gdv_int64 hour = extractHour_timestamp(in); - gdv_int64 minute = extractMinute_timestamp(in); - gdv_int64 second = extractSecond_timestamp(in); - gdv_int64 millis = in % MILLIS_IN_SEC; + EpochTimePoint tp(in); + gdv_int64 year = 1900 + tp.TmYear(); + gdv_int64 month = 1 + tp.TmMon(); + gdv_int64 day = tp.TmMday(); + gdv_int64 hour = tp.TmHour(); + gdv_int64 minute = tp.TmMin(); + gdv_int64 second = tp.TmSec(); + // Use TimeOfDay().subseconds() to correctly handle negative timestamps + gdv_int64 millis = tp.TimeOfDay().subseconds().count(); static const int kTimeStampStringLen = 23; const int char_buffer_length = kTimeStampStringLen + 1; // snprintf adds \0 diff --git a/cpp/src/gandiva/precompiled/time_test.cc b/cpp/src/gandiva/precompiled/time_test.cc index 82b38d1b5777..6cfa6acf579d 100644 --- a/cpp/src/gandiva/precompiled/time_test.cc +++ b/cpp/src/gandiva/precompiled/time_test.cc @@ -904,6 +904,24 @@ TEST(TestTime, castVarcharTimestamp) { ts = StringToTimestamp("2-5-1 00:00:04"); out = castVARCHAR_timestamp_int64(context_ptr, ts, 24L, &out_len); EXPECT_EQ(std::string(out, out_len), "0002-05-01 00:00:04.000"); + + // StringToTimestamp doesn't parse milliseconds, so we add them manually + ts = StringToTimestamp("67-5-1 00:00:04") + 920; + out = castVARCHAR_timestamp_int64(context_ptr, ts, 24L, &out_len); + EXPECT_EQ(std::string(out, out_len), "0067-05-01 00:00:04.920"); + + ts = StringToTimestamp("107-10-17 12:20:03") + 900; + out = castVARCHAR_timestamp_int64(context_ptr, ts, 24L, &out_len); + EXPECT_EQ(std::string(out, out_len), "0107-10-17 12:20:03.900"); + + // Test pre-epoch timestamps with 4-digit years + ts = StringToTimestamp("1969-12-31 23:59:59") + 920; + out = castVARCHAR_timestamp_int64(context_ptr, ts, 24L, &out_len); + EXPECT_EQ(std::string(out, out_len), "1969-12-31 23:59:59.920"); + + ts = StringToTimestamp("1899-12-31 23:59:59") + 123; + out = castVARCHAR_timestamp_int64(context_ptr, ts, 24L, &out_len); + EXPECT_EQ(std::string(out, out_len), "1899-12-31 23:59:59.123"); } TEST(TestTime, TestCastTimestampToDate) { diff --git a/cpp/src/gandiva/random_generator_holder.cc b/cpp/src/gandiva/random_generator_holder.cc index 8f80c5826d93..2729c2875ad7 100644 --- a/cpp/src/gandiva/random_generator_holder.cc +++ b/cpp/src/gandiva/random_generator_holder.cc @@ -16,6 +16,9 @@ // under the License. #include "gandiva/random_generator_holder.h" + +#include + #include "gandiva/node.h" namespace gandiva { @@ -40,4 +43,62 @@ Result> RandomGeneratorHolder::Make( return std::shared_ptr(new RandomGeneratorHolder( literal->is_null() ? 0 : std::get(literal->holder()))); } + +Result> RandomIntegerGeneratorHolder::Make( + const FunctionNode& node) { + ARROW_RETURN_IF( + node.children().size() > 2, + Status::Invalid("'rand_integer' function requires at most two parameters")); + + // No params: full int32 range [INT32_MIN, INT32_MAX] + if (node.children().empty()) { + return std::shared_ptr( + new RandomIntegerGeneratorHolder()); + } + + // One param: range [0, range - 1] + if (node.children().size() == 1) { + auto literal = dynamic_cast(node.children().at(0).get()); + ARROW_RETURN_IF( + literal == nullptr, + Status::Invalid("'rand_integer' function requires a literal as parameter")); + ARROW_RETURN_IF( + literal->return_type()->id() != arrow::Type::INT32, + Status::Invalid( + "'rand_integer' function requires an int32 literal as parameter")); + + // NULL range defaults to INT32_MAX (full positive range) + int32_t range = literal->is_null() ? std::numeric_limits::max() + : std::get(literal->holder()); + ARROW_RETURN_IF(range <= 0, + Status::Invalid("'rand_integer' function range must be positive")); + + return std::shared_ptr( + new RandomIntegerGeneratorHolder(range)); + } + + // Two params: min, max [min, max] inclusive + auto min_literal = dynamic_cast(node.children().at(0).get()); + auto max_literal = dynamic_cast(node.children().at(1).get()); + + ARROW_RETURN_IF( + min_literal == nullptr || max_literal == nullptr, + Status::Invalid("'rand_integer' function requires literals as parameters")); + ARROW_RETURN_IF( + min_literal->return_type()->id() != arrow::Type::INT32 || + max_literal->return_type()->id() != arrow::Type::INT32, + Status::Invalid("'rand_integer' function requires int32 literals as parameters")); + + // NULL min defaults to 0, NULL max defaults to INT32_MAX + int32_t min_val = min_literal->is_null() ? 0 : std::get(min_literal->holder()); + int32_t max_val = max_literal->is_null() ? std::numeric_limits::max() + : std::get(max_literal->holder()); + + ARROW_RETURN_IF(min_val > max_val, + Status::Invalid("'rand_integer' function min must be <= max")); + + return std::shared_ptr( + new RandomIntegerGeneratorHolder(min_val, max_val)); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/random_generator_holder.h b/cpp/src/gandiva/random_generator_holder.h index ffab725aa7fc..752e8d242015 100644 --- a/cpp/src/gandiva/random_generator_holder.h +++ b/cpp/src/gandiva/random_generator_holder.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -53,4 +54,36 @@ class GANDIVA_EXPORT RandomGeneratorHolder : public FunctionHolder { std::uniform_real_distribution<> distribution_; }; +/// Function Holder for 'rand_integer' +class GANDIVA_EXPORT RandomIntegerGeneratorHolder : public FunctionHolder { + public: + ~RandomIntegerGeneratorHolder() override = default; + + static Result> Make( + const FunctionNode& node); + + int32_t operator()() { return distribution_(generator_); } + + private: + // Full range: [INT32_MIN, INT32_MAX] + RandomIntegerGeneratorHolder() + : distribution_(std::numeric_limits::min(), + std::numeric_limits::max()) { + generator_.seed(::arrow::internal::GetRandomSeed()); + } + + // Range: [0, range - 1] + explicit RandomIntegerGeneratorHolder(int32_t range) : distribution_(0, range - 1) { + generator_.seed(::arrow::internal::GetRandomSeed()); + } + + // Min/Max: [min, max] inclusive + RandomIntegerGeneratorHolder(int32_t min, int32_t max) : distribution_(min, max) { + generator_.seed(::arrow::internal::GetRandomSeed()); + } + + std::mt19937_64 generator_; + std::uniform_int_distribution distribution_; +}; + } // namespace gandiva diff --git a/cpp/src/gandiva/random_generator_holder_test.cc b/cpp/src/gandiva/random_generator_holder_test.cc index 77b2750f2e95..26677515c275 100644 --- a/cpp/src/gandiva/random_generator_holder_test.cc +++ b/cpp/src/gandiva/random_generator_holder_test.cc @@ -17,8 +17,10 @@ #include "gandiva/random_generator_holder.h" +#include #include +#include #include #include "arrow/testing/gtest_util.h" @@ -87,4 +89,161 @@ TEST_F(TestRandGenHolder, WithInValidSeed) { EXPECT_EQ(random_1(), random_2()); } +// Test that non-literal seed argument is rejected +TEST_F(TestRandGenHolder, NonLiteralSeedRejected) { + auto field_node = std::make_shared(arrow::field("seed", arrow::int32())); + FunctionNode rand_func = {"rand", {field_node}, arrow::float64()}; + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("requires a literal as parameter"), + RandomGeneratorHolder::Make(rand_func).status()); +} + +class TestRandIntGenHolder : public ::testing::Test { + public: + FunctionNode BuildRandIntFunc() { return {"rand_integer", {}, arrow::int32()}; } + + FunctionNode BuildRandIntWithRangeFunc(int32_t range, bool range_is_null) { + auto range_node = std::make_shared(arrow::int32(), LiteralHolder(range), + range_is_null); + return {"rand_integer", {range_node}, arrow::int32()}; + } + + FunctionNode BuildRandIntWithMinMaxFunc(int32_t min, bool min_is_null, int32_t max, + bool max_is_null) { + auto min_node = + std::make_shared(arrow::int32(), LiteralHolder(min), min_is_null); + auto max_node = + std::make_shared(arrow::int32(), LiteralHolder(max), max_is_null); + return {"rand_integer", {min_node, max_node}, arrow::int32()}; + } +}; + +TEST_F(TestRandIntGenHolder, NoParams) { + FunctionNode rand_func = BuildRandIntFunc(); + EXPECT_OK_AND_ASSIGN(auto rand_gen_holder, + RandomIntegerGeneratorHolder::Make(rand_func)); + + auto& random = *rand_gen_holder; + // Generate multiple values and verify they are integers + for (int i = 0; i < 10; i++) { + int32_t val = random(); + EXPECT_GE(val, std::numeric_limits::min()); + EXPECT_LE(val, std::numeric_limits::max()); + } +} + +TEST_F(TestRandIntGenHolder, WithRange) { + FunctionNode rand_func = BuildRandIntWithRangeFunc(100, false); + EXPECT_OK_AND_ASSIGN(auto rand_gen_holder, + RandomIntegerGeneratorHolder::Make(rand_func)); + + auto& random = *rand_gen_holder; + // Generate multiple values and verify they are in range [0, 99] + for (int i = 0; i < 100; i++) { + int32_t val = random(); + EXPECT_GE(val, 0); + EXPECT_LT(val, 100); + } +} + +TEST_F(TestRandIntGenHolder, WithMinMax) { + FunctionNode rand_func = BuildRandIntWithMinMaxFunc(10, false, 20, false); + EXPECT_OK_AND_ASSIGN(auto rand_gen_holder, + RandomIntegerGeneratorHolder::Make(rand_func)); + + auto& random = *rand_gen_holder; + // Generate multiple values and verify they are in range [10, 20] + for (int i = 0; i < 100; i++) { + int32_t val = random(); + EXPECT_GE(val, 10); + EXPECT_LE(val, 20); + } +} + +TEST_F(TestRandIntGenHolder, WithNegativeMinMax) { + FunctionNode rand_func = BuildRandIntWithMinMaxFunc(-50, false, -10, false); + EXPECT_OK_AND_ASSIGN(auto rand_gen_holder, + RandomIntegerGeneratorHolder::Make(rand_func)); + + auto& random = *rand_gen_holder; + // Generate multiple values and verify they are in range [-50, -10] + for (int i = 0; i < 100; i++) { + int32_t val = random(); + EXPECT_GE(val, -50); + EXPECT_LE(val, -10); + } +} + +TEST_F(TestRandIntGenHolder, InvalidRangeZero) { + FunctionNode rand_func = BuildRandIntWithRangeFunc(0, false); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("range must be positive"), + RandomIntegerGeneratorHolder::Make(rand_func).status()); +} + +TEST_F(TestRandIntGenHolder, InvalidRangeNegative) { + FunctionNode rand_func = BuildRandIntWithRangeFunc(-5, false); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("range must be positive"), + RandomIntegerGeneratorHolder::Make(rand_func).status()); +} + +TEST_F(TestRandIntGenHolder, InvalidMinGreaterThanMax) { + FunctionNode rand_func = BuildRandIntWithMinMaxFunc(20, false, 10, false); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("min must be <= max"), + RandomIntegerGeneratorHolder::Make(rand_func).status()); +} + +TEST_F(TestRandIntGenHolder, NullRangeDefaultsToMaxInt) { + FunctionNode rand_func = BuildRandIntWithRangeFunc(0, true); // null range + EXPECT_OK_AND_ASSIGN(auto rand_gen_holder, + RandomIntegerGeneratorHolder::Make(rand_func)); + + auto& random = *rand_gen_holder; + // With NULL range defaulting to INT32_MAX, values should be in [0, INT32_MAX-1] + for (int i = 0; i < 100; i++) { + int32_t val = random(); + EXPECT_GE(val, 0); + EXPECT_LT(val, std::numeric_limits::max()); + } +} + +// Test that non-literal arguments are rejected +TEST_F(TestRandIntGenHolder, NonLiteralRangeRejected) { + // Create a FieldNode instead of LiteralNode for the range parameter + auto field_node = std::make_shared(arrow::field("range", arrow::int32())); + FunctionNode rand_func = {"rand_integer", {field_node}, arrow::int32()}; + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("requires a literal as parameter"), + RandomIntegerGeneratorHolder::Make(rand_func).status()); +} + +TEST_F(TestRandIntGenHolder, NonLiteralMinMaxRejected) { + // Create FieldNodes instead of LiteralNodes for min/max parameters + auto min_field = std::make_shared(arrow::field("min", arrow::int32())); + auto max_literal = + std::make_shared(arrow::int32(), LiteralHolder(100), false); + FunctionNode rand_func = {"rand_integer", {min_field, max_literal}, arrow::int32()}; + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("requires literals as parameters"), + RandomIntegerGeneratorHolder::Make(rand_func).status()); +} + +TEST_F(TestRandIntGenHolder, NullMinMaxDefaults) { + // Test null handling for 2-arg form: NULL min defaults to 0, NULL max defaults to + // INT32_MAX + FunctionNode rand_func = BuildRandIntWithMinMaxFunc(0, true, 0, true); // both null + EXPECT_OK_AND_ASSIGN(auto rand_gen_holder, + RandomIntegerGeneratorHolder::Make(rand_func)); + + auto& random = *rand_gen_holder; + // With NULL min=0, NULL max=INT32_MAX, values should be in [0, INT32_MAX] + for (int i = 0; i < 100; i++) { + int32_t val = random(); + EXPECT_GE(val, 0); + EXPECT_LE(val, std::numeric_limits::max()); + } +} + } // namespace gandiva diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index dc1ac9dfd266..268cb55a6422 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -3678,4 +3678,161 @@ TEST_F(TestProjector, TestExtendedCFunctionThatNeedsContext) { EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0)); } +TEST_F(TestProjector, TestRandomNoArgs) { + // Test random() with no arguments - returns double in [0, 1) + auto dummy_field = field("dummy", arrow::int32()); + auto schema = arrow::schema({dummy_field}); + auto out_field = field("out", arrow::float64()); + + auto rand_node = TreeExprBuilder::MakeFunction("random", {}, arrow::float64()); + auto expr = TreeExprBuilder::MakeExpression(rand_node, out_field); + + std::shared_ptr projector; + ARROW_EXPECT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + int num_records = 100; + auto dummy_array = MakeArrowArrayInt32(std::vector(num_records, 0), + std::vector(num_records, true)); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {dummy_array}); + + arrow::ArrayVector outs; + ARROW_EXPECT_OK(projector->Evaluate(*in_batch, pool_, &outs)); + + // Verify all values are in range [0, 1) + auto result = std::dynamic_pointer_cast(outs.at(0)); + EXPECT_EQ(result->length(), num_records); + EXPECT_EQ(result->null_count(), 0); + for (int i = 0; i < num_records; i++) { + double value = result->Value(i); + EXPECT_GE(value, 0.0); + EXPECT_LT(value, 1.0); + } +} + +TEST_F(TestProjector, TestRandomWithSeed) { + // Test rand(seed) - with seed literal, returns double in [0, 1) + auto dummy_field = field("dummy", arrow::int32()); + auto schema = arrow::schema({dummy_field}); + auto out_field = field("out", arrow::float64()); + + auto seed_literal = TreeExprBuilder::MakeLiteral(static_cast(12345)); + auto rand_node = + TreeExprBuilder::MakeFunction("rand", {seed_literal}, arrow::float64()); + auto expr = TreeExprBuilder::MakeExpression(rand_node, out_field); + + std::shared_ptr projector; + ARROW_EXPECT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + int num_records = 100; + auto dummy_array = MakeArrowArrayInt32(std::vector(num_records, 0), + std::vector(num_records, true)); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {dummy_array}); + + arrow::ArrayVector outs; + ARROW_EXPECT_OK(projector->Evaluate(*in_batch, pool_, &outs)); + + // Verify all values are in range [0, 1) + auto result = std::dynamic_pointer_cast(outs.at(0)); + EXPECT_EQ(result->length(), num_records); + EXPECT_EQ(result->null_count(), 0); + for (int i = 0; i < num_records; i++) { + double value = result->Value(i); + EXPECT_GE(value, 0.0); + EXPECT_LT(value, 1.0); + } +} + +TEST_F(TestProjector, TestRandIntegerNoArgs) { + // Test rand_integer() with no arguments - full int32 range + auto dummy_field = field("dummy", arrow::int32()); + auto schema = arrow::schema({dummy_field}); + auto out_field = field("out", arrow::int32()); + + auto rand_int_node = TreeExprBuilder::MakeFunction("rand_integer", {}, arrow::int32()); + auto expr = TreeExprBuilder::MakeExpression(rand_int_node, out_field); + + std::shared_ptr projector; + ARROW_EXPECT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + int num_records = 100; + auto dummy_array = MakeArrowArrayInt32(std::vector(num_records, 0), + std::vector(num_records, true)); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {dummy_array}); + + arrow::ArrayVector outs; + ARROW_EXPECT_OK(projector->Evaluate(*in_batch, pool_, &outs)); + + // Verify all values are valid int32 (no specific range check for full range) + auto result = std::dynamic_pointer_cast(outs.at(0)); + EXPECT_EQ(result->length(), num_records); + EXPECT_EQ(result->null_count(), 0); +} + +TEST_F(TestProjector, TestRandIntegerWithRange) { + // Test rand_integer(10) - range [0, 9] + auto dummy_field = field("dummy", arrow::int32()); + auto schema = arrow::schema({dummy_field}); + auto out_field = field("out", arrow::int32()); + + auto range_literal = TreeExprBuilder::MakeLiteral(static_cast(10)); + auto rand_int_node = + TreeExprBuilder::MakeFunction("rand_integer", {range_literal}, arrow::int32()); + auto expr = TreeExprBuilder::MakeExpression(rand_int_node, out_field); + + std::shared_ptr projector; + ARROW_EXPECT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + int num_records = 100; + auto dummy_array = MakeArrowArrayInt32(std::vector(num_records, 0), + std::vector(num_records, true)); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {dummy_array}); + + arrow::ArrayVector outs; + ARROW_EXPECT_OK(projector->Evaluate(*in_batch, pool_, &outs)); + + // Verify all values are in range [0, 9] + auto result = std::dynamic_pointer_cast(outs.at(0)); + EXPECT_EQ(result->length(), num_records); + EXPECT_EQ(result->null_count(), 0); + for (int i = 0; i < num_records; i++) { + int32_t value = result->Value(i); + EXPECT_GE(value, 0); + EXPECT_LT(value, 10); + } +} + +TEST_F(TestProjector, TestRandIntegerWithMinMax) { + // Test rand_integer(5, 15) - range [5, 15] inclusive + auto dummy_field = field("dummy", arrow::int32()); + auto schema = arrow::schema({dummy_field}); + auto out_field = field("out", arrow::int32()); + + auto min_literal = TreeExprBuilder::MakeLiteral(static_cast(5)); + auto max_literal = TreeExprBuilder::MakeLiteral(static_cast(15)); + auto rand_int_node = TreeExprBuilder::MakeFunction( + "rand_integer", {min_literal, max_literal}, arrow::int32()); + auto expr = TreeExprBuilder::MakeExpression(rand_int_node, out_field); + + std::shared_ptr projector; + ARROW_EXPECT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + int num_records = 100; + auto dummy_array = MakeArrowArrayInt32(std::vector(num_records, 0), + std::vector(num_records, true)); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {dummy_array}); + + arrow::ArrayVector outs; + ARROW_EXPECT_OK(projector->Evaluate(*in_batch, pool_, &outs)); + + // Verify all values are in range [5, 15] inclusive + auto result = std::dynamic_pointer_cast(outs.at(0)); + EXPECT_EQ(result->length(), num_records); + EXPECT_EQ(result->null_count(), 0); + for (int i = 0; i < num_records; i++) { + int32_t value = result->Value(i); + EXPECT_GE(value, 5); + EXPECT_LE(value, 15); + } +} + } // namespace gandiva