diff --git a/velox/common/base/Exceptions.h b/velox/common/base/Exceptions.h index cac50d615de6..1a8fd6bb0bb9 100644 --- a/velox/common/base/Exceptions.h +++ b/velox/common/base/Exceptions.h @@ -65,11 +65,13 @@ template static_assert( !std::is_same_v, "BUG: we should not pass std::string by value to veloxCheckFail"); - LOG(ERROR) << "Line: " << args.file << ":" << args.line - << ", Function:" << args.function - << ", Expression: " << args.expression << " " << s - << ", Source: " << args.errorSource - << ", ErrorCode: " << args.errorCode; + if constexpr (!std::is_same_v) { + LOG(ERROR) << "Line: " << args.file << ":" << args.line + << ", Function:" << args.function + << ", Expression: " << args.expression << " " << s + << ", Source: " << args.errorSource + << ", ErrorCode: " << args.errorCode; + } throw Exception( args.file, diff --git a/velox/docs/functions/spark/math.rst b/velox/docs/functions/spark/math.rst index 3e674647b4ce..bb2b40042345 100644 --- a/velox/docs/functions/spark/math.rst +++ b/velox/docs/functions/spark/math.rst @@ -77,11 +77,29 @@ Mathematical Functions .. spark:function:: rand() -> double - Returns a random value with independent and identically distributed uniformly distributed values in [0, 1). :: + Returns a random value with uniformly distributed values in [0, 1). :: SELECT rand(); -- 0.9629742951434543 - SELECT rand(0); -- 0.7604953758285915 - SELECT rand(null); -- 0.7604953758285915 + +.. spark:function:: rand(seed, partitionIndex) -> double + + Returns a random value with uniformly distributed values in [0, 1) using a seed formed + by combining user-specified ``seed`` and framework provided ``partitionIndex``. The + framework is responsible for deterministic partitioning of the data and assigning unique + ``partitionIndex`` to each thread (in a deterministic way). + ``seed`` must be constant. NULL ``seed`` is identical to zero ``seed``. ``partitionIndex`` + cannot be NULL. 
:: + + SELECT rand(0); -- 0.5488135024422883 + SELECT rand(NULL); -- 0.5488135024422883 + +.. spark:function:: random() -> double + + An alias for ``rand()``. + +.. spark:function:: random(seed, partitionIndex) -> double + + An alias for ``rand(seed, partitionIndex)``. .. spark:function:: remainder(n, m) -> [same as n] diff --git a/velox/expression/CastExpr.cpp b/velox/expression/CastExpr.cpp index 2e9a0de5e99d..cfb56b4c537d 100644 --- a/velox/expression/CastExpr.cpp +++ b/velox/expression/CastExpr.cpp @@ -47,7 +47,8 @@ template void applyCastKernel( vector_size_t row, const SimpleVector::NativeType>* input, - FlatVector::NativeType>* result) { + FlatVector::NativeType>* result, + const std::string& sessionTzName) { if constexpr (ToKind == TypeKind::VARCHAR || ToKind == TypeKind::VARBINARY) { std::string output; if (input->type()->isDecimal()) { @@ -61,6 +62,11 @@ void applyCastKernel( auto writer = exec::StringWriter<>(result, row); writer.copy_from(output); writer.finalize(); + } else if constexpr ( + FromKind == TypeKind::TIMESTAMP && ToKind == TypeKind::DATE) { + auto output = util::Converter::cast( + input->valueAt(row), sessionTzName); + result->set(row, output); } else { if (input->type()->isDecimal()) { auto output = util::Converter::cast( @@ -221,6 +227,10 @@ void applyCastPrimitives( const auto& queryConfig = context.execCtx()->queryCtx()->queryConfig(); const bool isCastIntAllowDecimal = queryConfig.isCastIntAllowDecimal(); auto* inputSimpleVector = input.as>(); + std::string sessionTzName = ""; + if (queryConfig.adjustTimestampToTimezone()) { + sessionTzName = queryConfig.sessionTimezone(); + } if (!queryConfig.isCastToIntByTruncate()) { context.applyToSelectedNoThrow(rows, [&](int row) { @@ -228,10 +238,10 @@ void applyCastPrimitives( // Passing a false truncate flag if (isCastIntAllowDecimal) { applyCastKernel( - row, inputSimpleVector, resultFlatVector); + row, inputSimpleVector, resultFlatVector, sessionTzName); } else { applyCastKernel( - row, 
inputSimpleVector, resultFlatVector); + row, inputSimpleVector, resultFlatVector, sessionTzName); } } catch (const VeloxRuntimeError& re) { VELOX_FAIL( @@ -253,10 +263,10 @@ void applyCastPrimitives( // Passing a true truncate flag if (isCastIntAllowDecimal) { applyCastKernel( - row, inputSimpleVector, resultFlatVector); + row, inputSimpleVector, resultFlatVector, sessionTzName); } else { applyCastKernel( - row, inputSimpleVector, resultFlatVector); + row, inputSimpleVector, resultFlatVector, sessionTzName); } } catch (const VeloxRuntimeError& re) { VELOX_FAIL( @@ -279,7 +289,7 @@ void applyCastPrimitives( if constexpr (ToKind == TypeKind::TIMESTAMP) { // If user explicitly asked us to adjust the timezone. if (queryConfig.adjustTimestampToTimezone()) { - auto sessionTzName = queryConfig.sessionTimezone(); + // sessionTzName was already read from queryConfig earlier in this function. if (!sessionTzName.empty()) { // locate_zone throws runtime_error if the timezone couldn't be found // (so we're safe to dereference the pointer).
diff --git a/velox/expression/ComplexWriterTypes.h b/velox/expression/ComplexWriterTypes.h index 1ab6a242e10a..df2a40a5e117 100644 --- a/velox/expression/ComplexWriterTypes.h +++ b/velox/expression/ComplexWriterTypes.h @@ -428,7 +428,6 @@ class MapWriter { std::tuple, PrimitiveWriter> operator[]( vector_size_t index) { - static_assert(std_interface, "operator [] not allowed for this map"); VELOX_DCHECK_LT(index, length_, "out of bound access"); return { PrimitiveWriter{keysVector_, innerOffset_ + index}, diff --git a/velox/expression/tests/SparkExpressionFuzzerTest.cpp b/velox/expression/tests/SparkExpressionFuzzerTest.cpp index 73eacbfcbe2d..166b7231c144 100644 --- a/velox/expression/tests/SparkExpressionFuzzerTest.cpp +++ b/velox/expression/tests/SparkExpressionFuzzerTest.cpp @@ -62,7 +62,11 @@ int main(int argc, char** argv) { "chr", "replace", "might_contain", - "unix_timestamp"}; + "unix_timestamp", + // Skip concat_ws as it triggers a test failure due to an incorrect + // expression generation from fuzzer: + // https://github.com/facebookincubator/velox/issues/6590 + "concat_ws"}; return FuzzerRunner::run( FLAGS_only, FLAGS_seed, skipFunctions, FLAGS_special_forms); } diff --git a/velox/functions/lib/DateTimeFormatter.cpp b/velox/functions/lib/DateTimeFormatter.cpp index 3a104a6f1d98..10f030735c10 100644 --- a/velox/functions/lib/DateTimeFormatter.cpp +++ b/velox/functions/lib/DateTimeFormatter.cpp @@ -214,27 +214,6 @@ std::string padContent( } } -void validateTimePoint(const std::chrono::time_point< - std::chrono::system_clock, - std::chrono::milliseconds>& timePoint) { - // Due to the limit of std::chrono we can only represent time in - // [-32767-01-01, 32767-12-31] date range - const auto minTimePoint = date::sys_days{ - date::year_month_day(date::year::min(), date::month(1), date::day(1))}; - const auto maxTimePoint = date::sys_days{ - date::year_month_day(date::year::max(), date::month(12), date::day(31))}; - if (timePoint < minTimePoint || timePoint > 
maxTimePoint) { - VELOX_USER_FAIL( - "Cannot format time out of range of [{}-{}-{}, {}-{}-{}]", - (int)date::year::min(), - "01", - "01", - (int)date::year::max(), - "12", - "31"); - } -} - size_t countOccurence(const std::string_view& base, const std::string& target) { int occurrences = 0; std::string::size_type pos = 0; @@ -952,10 +931,11 @@ void parseFromPattern( std::string DateTimeFormatter::format( const Timestamp& timestamp, const date::time_zone* timezone) const { - const std::chrono:: - time_point - timePoint(std::chrono::milliseconds(timestamp.toMillis())); - validateTimePoint(timePoint); + Timestamp t = timestamp; + if (timezone != nullptr) { + t.toTimezone(*timezone); + } + const auto timePoint = t.toTimePoint(); const auto daysTimePoint = date::floor(timePoint); const auto durationInTheDay = date::make_time(timePoint - daysTimePoint); diff --git a/velox/functions/prestosql/tests/DateTimeFunctionsTest.cpp b/velox/functions/prestosql/tests/DateTimeFunctionsTest.cpp index 6a7677d007f2..e49294e94d4e 100644 --- a/velox/functions/prestosql/tests/DateTimeFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/DateTimeFunctionsTest.cpp @@ -696,7 +696,7 @@ TEST_F(DateTimeFunctionsTest, hour) { EXPECT_EQ(std::nullopt, hour(std::nullopt)); EXPECT_EQ(13, hour(Timestamp(0, 0))); - EXPECT_EQ(12, hour(Timestamp(-1, Timestamp::kMaxNanos))); + EXPECT_EQ(13, hour(Timestamp(-1, Timestamp::kMaxNanos))); // Disabled for now because the TZ for Pacific/Apia in 2096 varies between // systems. 
// EXPECT_EQ(21, hour(Timestamp(4000000000, 0))); @@ -2529,12 +2529,12 @@ TEST_F(DateTimeFunctionsTest, formatDateTime) { // Multi-specifier and literal formats EXPECT_EQ( - "AD 19 1970 4 Thu 1970 1 1 1 AM 2 2 2 2 33 11 5 Asia/Kolkata", + "AD 19 1970 4 Thu 1970 1 1 1 AM 8 8 8 8 3 11 5 Asia/Kolkata", formatDatetime( fromTimestampString("1970-01-01 02:33:11.5"), "G C Y e E y D M d a K h H k m s S zzzz")); EXPECT_EQ( - "AD 19 1970 4 asdfghjklzxcvbnmqwertyuiop Thu ' 1970 1 1 1 AM 2 2 2 2 33 11 5 1234567890\\\"!@#$%^&*()-+`~{}[];:,./ Asia/Kolkata", + "AD 19 1970 4 asdfghjklzxcvbnmqwertyuiop Thu ' 1970 1 1 1 AM 8 8 8 8 3 11 5 1234567890\\\"!@#$%^&*()-+`~{}[];:,./ Asia/Kolkata", formatDatetime( fromTimestampString("1970-01-01 02:33:11.5"), "G C Y e 'asdfghjklzxcvbnmqwertyuiop' E '' y D M d a K h H k m s S 1234567890\\\"!@#$%^&*()-+`~{}[];:,./ zzzz")); @@ -2787,21 +2787,43 @@ TEST_F(DateTimeFunctionsTest, dateFormat) { EXPECT_EQ("z", dateFormat(fromTimestampString("1970-01-01"), "%z")); EXPECT_EQ("g", dateFormat(fromTimestampString("1970-01-01"), "%g")); - // With timezone + // With timezone. Indian Standard Time (IST) UTC+5:30. setQueryTimeZone("Asia/Kolkata"); + EXPECT_EQ( "1970-01-01", dateFormat(fromTimestampString("1970-01-01"), "%Y-%m-%d")); EXPECT_EQ( - "2000-02-29 12:00:00 AM", + "2000-02-29 05:30:00 AM", dateFormat( fromTimestampString("2000-02-29 00:00:00.987"), "%Y-%m-%d %r")); EXPECT_EQ( - "2000-02-29 00:00:00.987000", + "2000-02-29 05:30:00.987000", dateFormat( fromTimestampString("2000-02-29 00:00:00.987"), "%Y-%m-%d %H:%i:%s.%f")); EXPECT_EQ( - "-2000-02-29 00:00:00.987000", + "-2000-02-29 05:53:29.987000", + dateFormat( + fromTimestampString("-2000-02-29 00:00:00.987"), + "%Y-%m-%d %H:%i:%s.%f")); + + // Same timestamps with a different timezone. Pacific Daylight Time (North + // America) PDT UTC-8:00. 
+ setQueryTimeZone("America/Los_Angeles"); + + EXPECT_EQ( + "1969-12-31", dateFormat(fromTimestampString("1970-01-01"), "%Y-%m-%d")); + EXPECT_EQ( + "2000-02-28 04:00:00 PM", + dateFormat( + fromTimestampString("2000-02-29 00:00:00.987"), "%Y-%m-%d %r")); + EXPECT_EQ( + "2000-02-28 16:00:00.987000", + dateFormat( + fromTimestampString("2000-02-29 00:00:00.987"), + "%Y-%m-%d %H:%i:%s.%f")); + EXPECT_EQ( + "-2000-02-28 16:07:03.987000", dateFormat( fromTimestampString("-2000-02-29 00:00:00.987"), "%Y-%m-%d %H:%i:%s.%f")); diff --git a/velox/functions/sparksql/DateTimeFunctions.h b/velox/functions/sparksql/DateTimeFunctions.h index d68246709818..603b2b2293f0 100644 --- a/velox/functions/sparksql/DateTimeFunctions.h +++ b/velox/functions/sparksql/DateTimeFunctions.h @@ -152,6 +152,90 @@ struct UnixTimestampParseWithFormatFunction bool invalidFormat_{false}; }; +/// Parse unix time in seconds to a string in given time format. +template +struct FromUnixtimeFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + const date::time_zone* sessionTimeZone_ = nullptr; + std::shared_ptr mysqlDateTime_; + bool isConstantTimeFormat = false; + + FOLLY_ALWAYS_INLINE void initialize( + const core::QueryConfig& config, + const arg_type* /*unixtime*/, + const arg_type* timeFormat) { + sessionTimeZone_ = getTimeZoneFromConfig(config); + if (timeFormat != nullptr) { + isConstantTimeFormat = true; + mysqlDateTime_ = buildJodaDateTimeFormatter( + std::string_view(timeFormat->data(), timeFormat->size())); + } + } + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type second, + const arg_type timeFormat) { + if (!isConstantTimeFormat) { + mysqlDateTime_ = buildJodaDateTimeFormatter( + std::string_view(timeFormat.data(), timeFormat.size())); + } + Timestamp timestamp = Timestamp::fromMillis(1000 * second); + auto formattedResult = mysqlDateTime_->format(timestamp, sessionTimeZone_); + auto resultSize = formattedResult.size(); + result.resize(resultSize); + if (resultSize != 0) { + 
std::memcpy(result.data(), formattedResult.data(), resultSize); + } + } +}; + +template +struct GetTimestampFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + std::shared_ptr formatter_; + bool isConstantTimeFormat_ = false; + std::optional sessionTzID_; + + int16_t getTimezoneId(const DateTimeResult& result) { + // If timezone was not parsed, fallback to the session timezone. If there's + // no session timezone, fallback to 0 (GMT). + return result.timezoneId != -1 ? result.timezoneId + : sessionTzID_.value_or(0); + } + + FOLLY_ALWAYS_INLINE void initialize( + const core::QueryConfig& config, + const arg_type* /*input*/, + const arg_type* format) { + auto sessionTzName = config.sessionTimezone(); + if (!sessionTzName.empty()) { + sessionTzID_ = util::getTimeZoneID(sessionTzName); + } + if (format != nullptr) { + this->formatter_ = buildJodaDateTimeFormatter( + std::string_view(format->data(), format->size())); + isConstantTimeFormat_ = true; + } + } + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type& input, + const arg_type& format) { + if (!isConstantTimeFormat_) { + formatter_ = buildJodaDateTimeFormatter( + std::string_view(format.data(), format.size())); + } + auto dateTimeResult = + this->formatter_->parse(std::string_view(input.data(), input.size())); + dateTimeResult.timestamp.toGMT(getTimezoneId(dateTimeResult)); + result = dateTimeResult.timestamp; + } +}; + template struct MakeDateFunction { VELOX_DEFINE_FUNCTION_TYPES(T); @@ -253,4 +337,46 @@ struct DateDiffFunction { } }; +template +struct DateFormatFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + const date::time_zone* sessionTimeZone_ = nullptr; + std::shared_ptr formatter_; + bool isConstFormat_ = false; + + FOLLY_ALWAYS_INLINE void setFormatter(const arg_type* formatString) { + if (formatString != nullptr) { + formatter_ = buildJodaDateTimeFormatter( + std::string_view(formatString->data(), formatString->size())); + isConstFormat_ = true; + } + } + + FOLLY_ALWAYS_INLINE void 
initialize( + const core::QueryConfig& config, + const arg_type* /*timestamp*/, + const arg_type* formatString) { + sessionTimeZone_ = getTimeZoneFromConfig(config); + setFormatter(formatString); + } + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type& timestamp, + const arg_type& formatString) { + if (!isConstFormat_) { + formatter_ = buildJodaDateTimeFormatter( + std::string_view(formatString.data(), formatString.size())); + } + + auto formattedResult = formatter_->format(timestamp, sessionTimeZone_); + auto resultSize = formattedResult.size(); + result.resize(resultSize); + if (resultSize != 0) { + std::memcpy(result.data(), formattedResult.data(), resultSize); + } + } +}; + } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Rand.h b/velox/functions/sparksql/Rand.h new file mode 100644 index 000000000000..6af302970475 --- /dev/null +++ b/velox/functions/sparksql/Rand.h @@ -0,0 +1,85 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "velox/functions/Macros.h" + +namespace facebook::velox::functions::sparksql { + +template +struct RandFunction { + static constexpr bool is_deterministic = false; + + std::optional generator; + + FOLLY_ALWAYS_INLINE void call(double& result) { + result = folly::Random::randDouble01(); + } + + FOLLY_ALWAYS_INLINE void callNullable( + double& result, + const int32_t* seed, + const int32_t* partitionIndex) { + VELOX_USER_CHECK_NOT_NULL(partitionIndex, "partitionIndex cannot be null."); + if (!generator.has_value()) { + generator = std::mt19937{}; + if (seed) { + generator->seed((uint64_t)*seed + *partitionIndex); + } else { + // For null input, 0 plus partitionIndex is the seed, consistent with + // Spark. + generator->seed(*partitionIndex); + } + } + result = folly::Random::randDouble01(*generator); + } + + /// To differentiate generator for each thread, seed plus partitionIndex is + /// the actual seed used for generator. + FOLLY_ALWAYS_INLINE void callNullable( + double& result, + const int64_t* seed, + const int32_t* partitionIndex) { + VELOX_USER_CHECK_NOT_NULL(partitionIndex, "partitionIndex cannot be null."); + if (!generator.has_value()) { + generator = std::mt19937{}; + if (seed) { + generator->seed((uint64_t)*seed + *partitionIndex); + } else { + // For null input, 0 plus partitionIndex is the seed, consistent with + // Spark. + generator->seed(*partitionIndex); + } + } + result = folly::Random::randDouble01(*generator); + } + + // For NULL constant input of unknown type. + FOLLY_ALWAYS_INLINE void callNullable( + double& result, + const UnknownValue* seed, + const int32_t* partitionIndex) { + VELOX_USER_CHECK_NOT_NULL(partitionIndex, "partitionIndex cannot be null."); + if (!generator.has_value()) { + generator = std::mt19937{}; + // For null input, 0 plus partitionIndex is the seed, consistent with + // Spark. 
+ generator->seed(*partitionIndex); + } + result = folly::Random::randDouble01(*generator); + } +}; +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 3a6025a903b0..f4e43fdd45f6 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -19,7 +19,6 @@ #include "velox/functions/lib/Re2Functions.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/prestosql/JsonFunctions.h" -#include "velox/functions/prestosql/Rand.h" #include "velox/functions/prestosql/StringFunctions.h" #include "velox/functions/sparksql/Arithmetic.h" #include "velox/functions/sparksql/ArraySort.h" @@ -73,8 +72,6 @@ static void workAroundRegistrationMacro(const std::string& prefix) { namespace sparksql { void registerFunctions(const std::string& prefix) { - registerFunction({prefix + "rand"}); - // Register size functions registerSize(prefix + "size"); @@ -115,6 +112,9 @@ void registerFunctions(const std::string& prefix) { prefix + "instr", instrSignatures(), makeInstr); exec::registerStatefulVectorFunction( prefix + "length", lengthSignatures(), makeLength); + VELOX_REGISTER_VECTOR_FUNCTION(udf_str_to_map, prefix + "str_to_map"); + exec::registerStatefulVectorFunction( + prefix + "concat_ws", concatWsSignatures(), makeConcatWs); registerFunction({prefix + "md5"}); registerFunction( @@ -207,6 +207,11 @@ void registerFunctions(const std::string& prefix) { {prefix + "make_date"}); registerFunction({prefix + "last_day"}); + registerFunction( + {prefix + "get_timestamp"}); + + registerFunction( + {prefix + "from_unixtime"}); // Register bloom filter function registerFunction( @@ -266,6 +271,8 @@ void registerFunctions(const std::string& prefix) { registerFunction({"date_add"}); registerFunction({"date_add"}); registerFunction({"date_diff"}); + registerFunction( + {"date_format"}); registerFunction( {prefix + "atan2"}); 
registerFunction({prefix + "log2"}); diff --git a/velox/functions/sparksql/RegisterArithmetic.cpp b/velox/functions/sparksql/RegisterArithmetic.cpp index 80f28da52294..14c031189ca8 100644 --- a/velox/functions/sparksql/RegisterArithmetic.cpp +++ b/velox/functions/sparksql/RegisterArithmetic.cpp @@ -18,9 +18,32 @@ #include "velox/functions/prestosql/Arithmetic.h" #include "velox/functions/prestosql/CheckedArithmetic.h" #include "velox/functions/sparksql/Arithmetic.h" +#include "velox/functions/sparksql/Rand.h" namespace facebook::velox::functions::sparksql { +void registerRandFunctions(const std::string& prefix) { + registerFunction({prefix + "rand", prefix + "random"}); + // Has seed & partition index as input. + registerFunction< + RandFunction, + double, + int32_t /*seed*/, + int32_t /*partition index*/>({prefix + "rand", prefix + "random"}); + // Has seed & partition index as input. + registerFunction< + RandFunction, + double, + int64_t /*seed*/, + int32_t /*partition index*/>({prefix + "rand", prefix + "random"}); + // NULL constant as seed of unknown type. + registerFunction< + RandFunction, + double, + UnknownValue /*seed*/, + int32_t /*partition index*/>({prefix + "rand", prefix + "random"}); +} + void registerArithmeticFunctions(const std::string& prefix) { // Operators. 
registerBinaryNumeric({prefix + "add"}); @@ -63,6 +86,7 @@ void registerArithmeticFunctions(const std::string& prefix) { VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_sub, prefix + "decimal_subtract"); VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_mul, prefix + "decimal_multiply"); VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_div, prefix + "decimal_divide"); + registerRandFunctions(prefix); } } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/SplitFunctions.cpp b/velox/functions/sparksql/SplitFunctions.cpp index 1fa68613c363..80c3b80c699b 100644 --- a/velox/functions/sparksql/SplitFunctions.cpp +++ b/velox/functions/sparksql/SplitFunctions.cpp @@ -23,6 +23,116 @@ namespace facebook::velox::functions::sparksql { namespace { +// str_to_map(expr [, pairDelim [, keyValueDelim] ] ) +class StrToMap final : public exec::VectorFunction { + public: + StrToMap() = default; + + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& /* outputType */, + exec::EvalCtx& context, + VectorPtr& result) const override { + VELOX_CHECK( + !args.empty(), + "StrToMap function should provide at least one argument"); + exec::DecodedArgs decodedArgs(rows, args, context); + DecodedVector* strings = decodedArgs.at(0); + char pairDelim = ','; + char kvDelim = ':'; + if (args.size() > 1) { + pairDelim = args[1]->as>()->valueAt(0).data()[0]; + if (args.size() > 2) { + kvDelim = args[2]->as>()->valueAt(0).data()[0]; + } + } + + BaseVector::ensureWritable( + rows, MAP(VARCHAR(), VARCHAR()), context.pool(), result); + exec::VectorWriter> resultWriter; + resultWriter.init(*result->as()); + rows.applyToSelected([&](vector_size_t row) { + // Per-row key index; keys from earlier rows must not match here. + std::unordered_map keyToIdx; + resultWriter.setOffset(row); + auto& mapWriter = resultWriter.current(); + + const StringView& current = strings->valueAt(row); + const char* pos = current.begin(); + const char* end = pos + current.size(); + const char* pair; + const char* kv; + do { + pair = std::find(pos,
end, pairDelim); + kv = std::find(pos, pair, kvDelim); + auto key = StringView(pos, kv - pos); + auto iter = keyToIdx.find(key); + if (iter == keyToIdx.end()) { + keyToIdx.emplace(key, mapWriter.size()); + if (kv == pair) { + mapWriter.add_null().append(key); + } else { + auto [keyWriter, valueWriter] = mapWriter.add_item(); + keyWriter.append(key); + valueWriter.append(StringView(kv + 1, pair - kv - 1)); + } + } else { + auto valueWriter = std::get<1>(mapWriter[iter->second]); + if (kv == pair) { + valueWriter = std::nullopt; + } else { + valueWriter = StringView(kv + 1, pair - kv - 1); + } + } + + pos = pair + 1; // Skip past delim. + } while (pair != end); + + resultWriter.commit(); + }); + + resultWriter.finish(); + + // Ensure that our result elements vector uses the same string buffer as + // the input vector of strings. + result->as() + ->mapKeys() + ->as>() + ->acquireSharedStringBuffers(strings->base()); + result->as() + ->mapValues() + ->as>() + ->acquireSharedStringBuffers(strings->base()); + } +}; + +std::vector> strToMapSignatures() { + // varchar [, varchar [, varchar]] -> map(varchar, varchar) + return { + exec::FunctionSignatureBuilder() + .returnType("map(varchar, varchar)") + .argumentType("varchar") + .build(), + exec::FunctionSignatureBuilder() + .returnType("map(varchar, varchar)") + .argumentType("varchar") + .argumentType("varchar") + .build(), + exec::FunctionSignatureBuilder() + .returnType("map(varchar, varchar)") + .argumentType("varchar") + .argumentType("varchar") + .argumentType("varchar") + .build()}; +} + +std::shared_ptr createStrToMap( + const std::string& name, + const std::vector& inputArgs) { + return std::make_shared(); +} + /// The function returns specialized version of split based on the constant /// inputs.
/// \param inputArgs the inputs types (VARCHAR, VARCHAR, int64) and constant @@ -39,6 +149,11 @@ std::vector> signatures() { } } // namespace +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_str_to_map, + strToMapSignatures(), + createStrToMap); + VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_regexp_split, signatures(), diff --git a/velox/functions/sparksql/String.cpp b/velox/functions/sparksql/String.cpp index b3d04dd31c60..aafbf021fb74 100644 --- a/velox/functions/sparksql/String.cpp +++ b/velox/functions/sparksql/String.cpp @@ -103,6 +103,230 @@ class Length : public exec::VectorFunction { } }; +void concatWsVariableParameters( + const SelectivityVector& rows, + std::vector& args, + exec::EvalCtx& context, + const std::string& connector, + FlatVector& flatResult) { + std::vector argMapping; + std::vector constantStrings; + std::vector constantStringViews; + auto numArgs = args.size(); + + // Save constant values to constantStrings_. + // Identify and combine consecutive constant inputs. + argMapping.reserve(numArgs - 1); + constantStrings.reserve(numArgs - 1); + for (auto i = 1; i < numArgs; ++i) { + argMapping.push_back(i); + if (args[i] && args[i]->as>() && + !args[i]->as>()->isNullAt(0)) { + std::string value = + args[i]->as>()->valueAt(0).str(); + column_index_t j = i + 1; + for (; j < args.size(); ++j) { + if (!args[j] || !args[j]->as>() || + args[j]->as>()->isNullAt(0)) { + break; + } + + value += connector + + args[j]->as>()->valueAt(0).str(); + } + constantStrings.push_back(std::string(value.data(), value.size())); + i = j - 1; + } else { + constantStrings.push_back(std::string()); + } + } + + // Create StringViews for constant strings. 
+ constantStringViews.reserve(numArgs - 1); + for (const auto& constantString : constantStrings) { + constantStringViews.push_back( + StringView(constantString.data(), constantString.size())); + } + + auto numCols = argMapping.size(); + std::vector decodedArgs; + decodedArgs.reserve(numCols); + + for (auto i = 0; i < numCols; ++i) { + auto index = argMapping[i]; + if (constantStringViews[i].empty()) { + decodedArgs.emplace_back(context, *args[index], rows); + } else { + // Do not decode constant inputs. + decodedArgs.emplace_back(context); + } + } + + size_t totalResultBytes = 0; + rows.applyToSelected([&](auto row) { + auto isFirst = true; + for (int i = 0; i < numCols; i++) { + auto value = constantStringViews[i].empty() + ? decodedArgs[i]->valueAt(row) + : constantStringViews[i]; + if (!value.empty()) { + if (isFirst) { + isFirst = false; + } else { + totalResultBytes += connector.size(); + } + totalResultBytes += value.size(); + } + } + }); + + // Allocate a string buffer. + auto rawBuffer = flatResult.getRawStringBufferWithSpace(totalResultBytes); + size_t offset = 0; + rows.applyToSelected([&](int row) { + const char* start = rawBuffer + offset; + size_t combinedSize = 0; + auto isFirst = true; + for (int i = 0; i < numCols; i++) { + StringView value; + if (constantStringViews[i].empty()) { + value = decodedArgs[i]->valueAt(row); + } else { + value = constantStringViews[i]; + } + auto size = value.size(); + if (size > 0) { + if (isFirst) { + isFirst = false; + } else { + memcpy(rawBuffer + offset, connector.data(), connector.size()); + offset += connector.size(); + combinedSize += connector.size(); + } + memcpy(rawBuffer + offset, value.data(), size); + combinedSize += size; + offset += size; + } + } + flatResult.setNoCopy(row, StringView(start, combinedSize)); + }); +} + +void concatWsArray( + const SelectivityVector& rows, + std::vector& args, + exec::EvalCtx& context, + const std::string& connector, + FlatVector& flatResult) { + exec::LocalDecodedVector 
arrayHolder(context, *args[1], rows); + auto& arrayDecoded = *arrayHolder.get(); + auto baseArray = arrayDecoded.base()->as(); + auto rawSizes = baseArray->rawSizes(); + auto rawOffsets = baseArray->rawOffsets(); + auto indices = arrayDecoded.indices(); + + auto elements = arrayHolder.get()->base()->as()->elements(); + exec::LocalSelectivityVector nestedRows(context, elements->size()); + nestedRows.get()->setAll(); + exec::LocalDecodedVector elementsHolder( + context, *elements, *nestedRows.get()); + auto& elementsDecoded = *elementsHolder.get(); + auto elementsBase = elementsDecoded.base(); + + size_t totalResultBytes = 0; + rows.applyToSelected([&](auto row) { + auto size = rawSizes[indices[row]]; + auto offset = rawOffsets[indices[row]]; + + auto isFirst = true; + for (auto i = 0; i < size; ++i) { + if (!elementsBase->isNullAt(offset + i)) { + auto element = elementsDecoded.valueAt(offset + i); + if (!element.empty()) { + if (isFirst) { + isFirst = false; + } else { + totalResultBytes += connector.size(); + } + totalResultBytes += element.size(); + } + } + } + }); + + // Allocate a string buffer. 
+ auto rawBuffer = flatResult.getRawStringBufferWithSpace(totalResultBytes); + size_t bufferOffset = 0; + rows.applyToSelected([&](int row) { + auto size = rawSizes[indices[row]]; + auto offset = rawOffsets[indices[row]]; + + const char* start = rawBuffer + bufferOffset; + size_t combinedSize = 0; + auto isFirst = true; + for (auto i = 0; i < size; ++i) { + if (!elementsBase->isNullAt(offset + i)) { + auto element = elementsDecoded.valueAt(offset + i); + if (!element.empty()) { + if (isFirst) { + isFirst = false; + } else { + memcpy( + rawBuffer + bufferOffset, connector.data(), connector.size()); + bufferOffset += connector.size(); + combinedSize += connector.size(); + } + memcpy(rawBuffer + bufferOffset, element.data(), element.size()); + bufferOffset += element.size(); + combinedSize += element.size(); + } + } + } + flatResult.setNoCopy(row, StringView(start, combinedSize)); + }); +} + +class ConcatWs : public exec::VectorFunction { + public: + explicit ConcatWs(const std::string& connector) : connector_(connector) {} + + bool isDefaultNullBehavior() const override { + return false; + } + + void apply( + const SelectivityVector& selected, + std::vector& args, + const TypePtr& /* outputType */, + exec::EvalCtx& context, + VectorPtr& result) const override { + context.ensureWritable(selected, VARCHAR(), result); + auto flatResult = result->asFlatVector(); + auto numArgs = args.size(); + if (numArgs == 1) { + selected.applyToSelected( + [&](int row) { flatResult->setNoCopy(row, StringView("")); }); + return; + } + + if (args[0]->isNullAt(0)) { + selected.applyToSelected([&](int row) { result->setNull(row, true); }); + return; + } + + auto arrayArgs = args[1]->typeKind() == TypeKind::ARRAY; + if (arrayArgs) { + concatWsArray(selected, args, context, connector_, *flatResult); + } else { + concatWsVariableParameters( + selected, args, context, connector_, *flatResult); + } + } + + private: + const std::string connector_; +}; + +} // namespace std::vector>
instrSignatures() { @@ -142,6 +366,43 @@ std::shared_ptr makeLength( return kLengthFunction; } +std::vector> concatWsSignatures() { + return { + // varchar, varchar,... -> varchar. + exec::FunctionSignatureBuilder() + .returnType("varchar") + .constantArgumentType("varchar") + .argumentType("varchar") + .variableArity() + .build(), + // varchar, array(varchar) -> varchar. + exec::FunctionSignatureBuilder() + .returnType("varchar") + .constantArgumentType("varchar") + .argumentType("array(varchar)") + .build(), + }; +} + +std::shared_ptr makeConcatWs( + const std::string& name, + const std::vector& inputArgs) { + auto numArgs = inputArgs.size(); + VELOX_USER_CHECK( + numArgs >= 1, + "concat_ws requires one arguments at least, but got {}.", + numArgs); + + BaseVector* constantPattern = inputArgs[0].constantValue.get(); + VELOX_USER_CHECK( + nullptr != constantPattern, + "concat_ws requires constant connector arguments."); + + auto connector = + constantPattern->as>()->valueAt(0).str(); + return std::make_shared(connector); +} + void encodeDigestToBase16(uint8_t* output, int digestSize) { static unsigned char const kHexCodes[] = "0123456789abcdef"; for (int i = digestSize - 1; i >= 0; --i) { diff --git a/velox/functions/sparksql/String.h b/velox/functions/sparksql/String.h index 121e3cc54605..03940221f00f 100644 --- a/velox/functions/sparksql/String.h +++ b/velox/functions/sparksql/String.h @@ -93,6 +93,12 @@ std::shared_ptr makeLength( const std::string& name, const std::vector& inputArgs); +std::vector> concatWsSignatures(); + +std::shared_ptr makeConcatWs( + const std::string& name, + const std::vector& inputArgs); + /// Expands each char of the digest data to two chars, /// representing the hex value of each digest char, in order. /// Note: digestSize must be one-half of outputSize. 
diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index 4314153b0c03..8f037ae5c151 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -27,6 +27,7 @@ add_executable( LeastGreatestTest.cpp MapTest.cpp MightContainTest.cpp + RandTest.cpp RegexFunctionsTest.cpp SizeTest.cpp SortArrayTest.cpp diff --git a/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp b/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp index 03bed76d4816..3367f9e32f33 100644 --- a/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp +++ b/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp @@ -17,6 +17,7 @@ #include #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" +#include "velox/type/TimestampConversion.h" #include "velox/type/tz/TimeZoneMap.h" namespace facebook::velox::functions::sparksql::test { @@ -271,5 +272,28 @@ TEST_F(DateTimeFunctionsTest, dateDiff) { EXPECT_EQ(-366, dateDiff(parseDate("2020-02-29"), parseDate("2019-02-28"))); } +TEST_F(DateTimeFunctionsTest, fromUnixTime) { + const auto fromUnixTime = [&](std::optional unixTime, + std::optional timeFormat) { + return evaluateOnce( + "from_unixtime(c0, c1)", unixTime, timeFormat); + }; + + EXPECT_EQ(fromUnixTime(100, "yyyy-MM-dd"), "1970-01-01"); + EXPECT_EQ(fromUnixTime(120, "yyyy-MM-dd HH:mm"), "1970-01-01 00:02"); + EXPECT_EQ(fromUnixTime(100, "yyyy-MM-dd HH:mm:ss"), "1970-01-01 00:01:40"); +} + +TEST_F(DateTimeFunctionsTest, dateFormat) { + const auto dateFormat = [&](std::optional timestamp, + const std::string& formatString) { + return evaluateOnce( + fmt::format("date_format(c0, '{}')", formatString), timestamp); + }; + using util::fromTimestampString; + + EXPECT_EQ("1970", dateFormat(fromTimestampString("1970-01-01"), "y")); +} + } // namespace } // namespace facebook::velox::functions::sparksql::test diff --git 
a/velox/functions/sparksql/tests/RandTest.cpp b/velox/functions/sparksql/tests/RandTest.cpp new file mode 100644 index 000000000000..cbb49476c8ff --- /dev/null +++ b/velox/functions/sparksql/tests/RandTest.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class RandTest : public SparkFunctionBaseTest { + public: + RandTest() { + // Allow for parsing literal integers as INTEGER, not BIGINT. 
+ options_.parseIntegerAsBigint = false; + } + + protected: + std::optional rand(int32_t seed, int32_t partitionIndex = 0) { + return evaluateOnce( + fmt::format("rand({}, {})", seed, partitionIndex), + makeRowVector(ROW({}), 1)); + } + + std::optional randWithNullSeed(int32_t partitionIndex = 0) { + return evaluateOnce( + fmt::format("rand(NULL, {})", partitionIndex), + makeRowVector(ROW({}), 1)); + } + + std::optional randWithNoSeed() { + return evaluateOnce("rand()", makeRowVector(ROW({}), 1)); + } + + VectorPtr randWithBatchInput(int32_t seed, int32_t partitionIndex = 0) { + auto exprSet = compileExpression( + fmt::format("rand({}, {})", seed, partitionIndex), ROW({})); + return evaluate(*exprSet, makeRowVector(ROW({}), 20)); + } + + void checkResult(const std::optional& result) { + EXPECT_NE(result, std::nullopt); + EXPECT_GE(result.value(), 0.0); + EXPECT_LT(result.value(), 1.0); + } + + // Check whether two vectors that have same size & type, but not all same + // values. + void assertNotEqualVectors(const VectorPtr& left, const VectorPtr& right) { + ASSERT_EQ(left->size(), right->size()); + ASSERT_TRUE(left->type()->equivalent(*right->type())); + for (auto i = 0; i < left->size(); i++) { + if (!left->equalValueAt(right.get(), i, i)) { + return; + } + } + FAIL() << "Expect two different vectors are produced."; + } +}; + +TEST_F(RandTest, withSeed) { + checkResult(rand(0)); + // With same default partitionIndex used, same seed always produces same + // result. + EXPECT_EQ(rand(0), rand(0)); + + checkResult(rand(1)); + EXPECT_EQ(rand(1), rand(1)); + + checkResult(rand(20000)); + EXPECT_EQ(rand(20000), rand(20000)); + + // Test with same seed, but different partitionIndex. + EXPECT_NE(rand(0, 0), rand(0, 1)); + EXPECT_NE(rand(1000, 0), rand(1000, 1)); + + checkResult(randWithNullSeed()); + // Null as seed is identical to 0 as seed. + EXPECT_EQ(randWithNullSeed(), rand(0)); + // Same null as seed but different partition index. 
+ EXPECT_NE(randWithNullSeed(0), randWithNullSeed(1)); + + // Test with batch input. + auto batchResult1 = randWithBatchInput(100); + auto batchResult2 = randWithBatchInput(100); + // Same seed & partition index produce same results. + velox::test::assertEqualVectors(batchResult1, batchResult2); + batchResult1 = randWithBatchInput(100, 0 /*partitionIndex*/); + batchResult2 = randWithBatchInput(100, 1 /*partitionIndex*/); + // Same seed but different partition index cannot produce absolutely same + // result. + assertNotEqualVectors(batchResult1, batchResult2); +} + +TEST_F(RandTest, withoutSeed) { + auto result1 = randWithNoSeed(); + auto result2 = randWithNoSeed(); + auto result3 = randWithNoSeed(); + checkResult(result1); + checkResult(result2); + checkResult(result3); + // It is impossible to get three same results by three separate callings. + EXPECT_FALSE( + (result1.value() == result2.value()) && + (result1.value() == result3.value())); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test diff --git a/velox/functions/sparksql/tests/StringTest.cpp b/velox/functions/sparksql/tests/StringTest.cpp index 05294ffad9c9..0f27721468a0 100644 --- a/velox/functions/sparksql/tests/StringTest.cpp +++ b/velox/functions/sparksql/tests/StringTest.cpp @@ -173,6 +173,79 @@ class StringTest : public SparkFunctionBaseTest { return evaluateOnce( "replace(c0, c1, c2)", str, replaced, replacement); } + + std::string generateRandomString(size_t length) { + const std::string chars = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + std::string randomString; + for (std::size_t i = 0; i < length; ++i) { + randomString += chars[folly::Random::rand32() % chars.size()]; + } + return randomString; + } + + void testConcatWsFlatVector( + const std::vector>& inputTable, + const size_t argsCount, + const std::string& separator) { + std::vector inputVectors; + + for (int i = 0; i < argsCount; i++) { + inputVectors.emplace_back( + 
BaseVector::create(VARCHAR(), inputTable.size(), execCtx_.pool())); + } + + for (int row = 0; row < inputTable.size(); row++) { + for (int col = 0; col < argsCount; col++) { + std::static_pointer_cast>(inputVectors[col]) + ->set(row, StringView(inputTable[row][col])); + } + } + + auto buildConcatQuery = [&]() { + std::string output = "concat_ws('" + separator + "'"; + + for (int i = 0; i < argsCount; i++) { + output += ",c" + std::to_string(i); + } + output += ")"; + return output; + }; + + // Evaluate 'concat_ws' expression and verify no excessive memory + // allocation. We expect 2 allocations: one for the values buffer and + // another for the strings buffer. I.e. FlatVector::values and + // FlatVector::stringBuffers. + auto numAllocsBefore = pool()->stats().numAllocs; + + auto result = evaluate>( + buildConcatQuery(), makeRowVector(inputVectors)); + + auto numAllocsAfter = pool()->stats().numAllocs; + ASSERT_EQ(numAllocsAfter - numAllocsBefore, 2); + + auto concatStd = [&](const std::vector& inputs) { + auto isFirst = true; + std::string output; + for (int i = 0; i < inputs.size(); i++) { + auto value = inputs[i]; + if (!value.empty()) { + if (isFirst) { + isFirst = false; + } else { + output += separator; + } + output += value; + } + } + return output; + }; + + for (int i = 0; i < inputTable.size(); ++i) { + EXPECT_EQ(result->valueAt(i), concatStd(inputTable[i])) << "at " << i; + } + } }; TEST_F(StringTest, Ascii) { @@ -538,5 +611,131 @@ TEST_F(StringTest, replace) { "123\u6570data"); } +// Test concat_ws vector function +TEST_F(StringTest, concat_ws) { + // test concat_ws variable arguments + size_t maxArgsCount = 10; // cols + size_t rowCount = 100; + size_t maxStringLength = 100; + + std::vector> inputTable; + for (int argsCount = 1; argsCount <= maxArgsCount; argsCount++) { + inputTable.clear(); + + // Create table with argsCount columns + inputTable.resize(rowCount, std::vector(argsCount)); + + // Fill the table + for (int row = 0; row < rowCount; row++) 
{ + for (int col = 0; col < argsCount; col++) { + inputTable[row][col] = + generateRandomString(folly::Random::rand32() % maxStringLength); + } + } + + SCOPED_TRACE(fmt::format("Number of arguments: {}", argsCount)); + testConcatWsFlatVector(inputTable, argsCount, "--testSep--"); + } + + // Test constant input vector with 2 args + { + auto rows = makeRowVector(makeRowType({VARCHAR(), VARCHAR()}), 10); + auto c0 = generateRandomString(20); + auto c1 = generateRandomString(20); + auto result = evaluate>( + fmt::format("concat_ws('-', '{}', '{}')", c0, c1), rows); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(result->valueAt(i), c0 + "-" + c1); + } + } + + // Multiple consecutive constant inputs. + { + std::string value; + auto data = makeRowVector({ + makeFlatVector( + 1'000, + [&](auto /* row */) { + value = generateRandomString( + folly::Random::rand32() % maxStringLength); + return StringView(value); + }), + makeFlatVector( + 1'000, + [&](auto /* row */) { + value = generateRandomString( + folly::Random::rand32() % maxStringLength); + return StringView(value); + }), + }); + + auto c0 = data->childAt(0)->as>()->rawValues(); + auto c1 = data->childAt(1)->as>()->rawValues(); + + auto result = evaluate>( + "concat_ws('--', c0, c1, 'foo', 'bar')", data); + + auto expected = makeFlatVector(1'000, [&](auto row) { + value = ""; + const std::string& s0 = c0[row].str(); + const std::string& s1 = c1[row].str(); + + if (s0.empty() && s1.empty()) { + value = "foo--bar"; + } else if (!s0.empty() && !s1.empty()) { + value = s0 + "--" + s1 + "--foo--bar"; + } else { + value = s0 + s1 + "--foo--bar"; + } + return StringView(value); + }); + + velox::test::assertEqualVectors(expected, result); + + result = evaluate>( + "concat_ws('$*@', 'aaa', '测试', c0, 'eee', 'ddd', c1, '\u82f9\u679c', 'fff')", + data); + + expected = makeFlatVector(1'000, [&](auto row) { + value = ""; + std::string delim = "$*@"; + const std::string& s0 = + c0[row].str().empty() ? 
c0[row].str() : delim + c0[row].str(); + const std::string& s1 = + c1[row].str().empty() ? c1[row].str() : delim + c1[row].str(); + + value = "aaa" + delim + "测试" + s0 + delim + "eee" + delim + "ddd" + s1 + + delim + "\u82f9\u679c" + delim + "fff"; + return StringView(value); + }); + velox::test::assertEqualVectors(expected, result); + } + + // test concat_ws array + { + using S = StringView; + auto arrayVector = makeNullableArrayVector({ + {S("red"), S("blue")}, + {S("blue"), std::nullopt, S("yellow"), std::nullopt, S("orange")}, + {}, + {std::nullopt}, + {S("red"), S("purple"), S("green")}, + }); + + auto result = evaluate>( + "concat_ws('----', c0)", makeRowVector({arrayVector})); + + auto expected = { + S("red----blue"), + S("blue----yellow----orange"), + S(""), + S(""), + S("red----purple----green"), + }; + + velox::test::assertEqualVectors( + makeFlatVector(expected), result); + } +} } // namespace } // namespace facebook::velox::functions::sparksql::test diff --git a/velox/type/Conversions.h b/velox/type/Conversions.h index 3dc5db429391..f05aa3d68e8a 100644 --- a/velox/type/Conversions.h +++ b/velox/type/Conversions.h @@ -21,6 +21,7 @@ #include #include #include "velox/common/base/Exceptions.h" +#include "velox/external/date/tz.h" #include "velox/type/DecimalUtil.h" #include "velox/type/TimestampConversion.h" #include "velox/type/Type.h" @@ -603,6 +604,23 @@ struct Converter { // -1 day. return Date(seconds / kSecsPerDay - 1); } + + static T cast(const Timestamp& t, const std::string& sessionTzName) { + static const int32_t kSecsPerDay{86'400}; + auto ts = t; + if (!sessionTzName.empty()) { + auto* timeZone = date::locate_zone(sessionTzName); + ts.toTimezone(*timeZone); + } + auto seconds = ts.getSeconds(); + if (seconds >= 0 || seconds % kSecsPerDay == 0) { + return Date(seconds / kSecsPerDay); + } + // For division with negatives, minus 1 to compensate the discarded + // fractional part. e.g. -1/86'400 yields 0, yet it should be considered as + // -1 day. 
+ return Date(seconds / kSecsPerDay - 1); + } }; } // namespace facebook::velox::util diff --git a/velox/type/Timestamp.cpp b/velox/type/Timestamp.cpp index 563289bc68aa..7758d36b8ad0 100644 --- a/velox/type/Timestamp.cpp +++ b/velox/type/Timestamp.cpp @@ -69,9 +69,40 @@ void Timestamp::toGMT(int16_t tzID) { } } +namespace { +void validateTimePoint(const std::chrono::time_point< + std::chrono::system_clock, + std::chrono::milliseconds>& timePoint) { + // Due to the limit of std::chrono we can only represent time in + // [-32767-01-01, 32767-12-31] date range + const auto minTimePoint = date::sys_days{ + date::year_month_day(date::year::min(), date::month(1), date::day(1))}; + const auto maxTimePoint = date::sys_days{ + date::year_month_day(date::year::max(), date::month(12), date::day(31))}; + if (timePoint < minTimePoint || timePoint > maxTimePoint) { + VELOX_USER_FAIL( + "Timestamp is outside of supported range of [{}-{}-{}, {}-{}-{}]", + (int)date::year::min(), + "01", + "01", + (int)date::year::max(), + "12", + "31"); + } +} +} // namespace + +std::chrono::time_point +Timestamp::toTimePoint() const { + auto tp = std::chrono:: + time_point( + std::chrono::milliseconds(toMillis())); + validateTimePoint(tp); + return tp; +} + void Timestamp::toTimezone(const date::time_zone& zone) { - auto tp = std::chrono::time_point( - std::chrono::seconds(seconds_)); + auto tp = toTimePoint(); auto epoch = zone.to_local(tp).time_since_epoch(); seconds_ = std::chrono::duration_cast(epoch).count(); } diff --git a/velox/type/Timestamp.h b/velox/type/Timestamp.h index c3b7c04ff4ee..0dbf889b6236 100644 --- a/velox/type/Timestamp.h +++ b/velox/type/Timestamp.h @@ -117,6 +117,11 @@ struct Timestamp { } } + /// Due to the limit of std::chrono, throws if timestamp is outside of + /// [-32767-01-01, 32767-12-31] range. 
+ std::chrono::time_point + toTimePoint() const; + static Timestamp fromMillis(int64_t millis) { if (millis >= 0 || millis % 1'000 == 0) { return Timestamp(millis / 1'000, (millis % 1'000) * 1'000'000); diff --git a/velox/type/tests/TimestampTest.cpp b/velox/type/tests/TimestampTest.cpp index c157da6f5ff5..08287dad3972 100644 --- a/velox/type/tests/TimestampTest.cpp +++ b/velox/type/tests/TimestampTest.cpp @@ -17,6 +17,7 @@ #include #include "velox/common/base/tests/GTestUtils.h" +#include "velox/external/date/tz.h" #include "velox/type/Timestamp.h" namespace facebook::velox { @@ -158,5 +159,15 @@ TEST(TimestampTest, toString) { EXPECT_EQ("-292275055-05-16T16:47:04.000000000", kMin.toString()); EXPECT_EQ("292278994-08-17T07:12:55.999999999", kMax.toString()); } + +TEST(TimestampTest, outOfRange) { + auto* timezone = date::locate_zone("GMT"); + Timestamp t(-3217830796800, 0); + + VELOX_ASSERT_THROW( + t.toTimePoint(), "Timestamp is outside of supported range"); + VELOX_ASSERT_THROW( + t.toTimezone(*timezone), "Timestamp is outside of supported range"); +} } // namespace } // namespace facebook::velox