Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/gandiva/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ set(SRC_FILES
context_helper.cc
decimal_ir.cc
decimal_type_util.cc
timestamp_ir.cc
decimal_xlarge.cc
engine.cc
date_utils.cc
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/gandiva/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
#include "gandiva/decimal_ir.h"
#include "gandiva/exported_funcs.h"
#include "gandiva/exported_funcs_registry.h"
#include "gandiva/timestamp_ir.h"

namespace gandiva {

Expand Down Expand Up @@ -228,7 +229,11 @@ Result<std::unique_ptr<llvm::orc::LLJIT>> BuildJIT(
#endif

jit_builder.setJITTargetMachineBuilder(std::move(jtmb));
#if LLVM_VERSION_MAJOR >= 17
jit_builder.setDataLayout(std::make_optional(data_layout));
#else
jit_builder.setDataLayout(llvm::Optional<llvm::DataLayout>(data_layout));
#endif

if (object_cache.has_value()) {
jit_builder.setCompileFunctionCreator(
Expand Down Expand Up @@ -325,6 +330,7 @@ Status Engine::LoadFunctionIRs() {
if (!functions_loaded_) {
ARROW_RETURN_NOT_OK(LoadPreCompiledIR());
ARROW_RETURN_NOT_OK(DecimalIR::AddFunctions(this));
ARROW_RETURN_NOT_OK(TimestampIR::AddFunctions(this));
ARROW_RETURN_NOT_OK(LoadExternalPreCompiledIR());
functions_loaded_ = true;
}
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/gandiva/function_signature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ bool DataTypeEquals(const DataTypePtr& left, const DataTypePtr& right) {
return (dleft != NULL) && (dright != NULL) &&
(dleft->byte_width() == dright->byte_width());
}
case arrow::Type::TIMESTAMP: {
// For timestamp types, the TimeUnit isn't part of the signature
// (conversion is handled at codegen time by TimestampIR).
// However, timezone IS significant — a function registered for
// timestamp(null tz) should not match timestamp("America/New_York").
auto tleft = checked_cast<arrow::TimestampType*>(left.get());
auto tright = checked_cast<arrow::TimestampType*>(right.get());
return tleft->timezone() == tright->timezone();
}
default:
return left->Equals(right);
}
Expand Down
67 changes: 58 additions & 9 deletions cpp/src/gandiva/llvm_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
#include <utility>
#include <vector>

#include "arrow/type.h"
#include "gandiva/bitmap_accumulator.h"
#include "gandiva/decimal_ir.h"
#include "gandiva/dex.h"
#include "gandiva/expr_decomposer.h"
#include "gandiva/expression.h"
#include "gandiva/llvm_types.h"
#include "gandiva/lvalue.h"
#include "gandiva/timestamp_ir.h"

namespace gandiva {

Expand Down Expand Up @@ -384,6 +386,7 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count,
Visitor visitor(this, fn, loop_entry, arg_addrs, arg_local_bitmaps, arg_holder_ptrs,
slice_offsets, arg_context_ptr, position_var);
value_expr->Accept(visitor);
ARROW_RETURN_NOT_OK(visitor.status());
LValuePtr output_value = visitor.result();

// The "current" block may have changed due to code generation in the visitor.
Expand Down Expand Up @@ -813,7 +816,8 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) {
auto then_lambda = [&] {
ADD_VISITOR_TRACE("fn " + function_name +
" can return errors : all args valid, invoke fn");
return BuildFunctionCall(native_function, arrow_return_type, &params);
return BuildFunctionCall(native_function, arrow_return_type, &params,
dex.func_descriptor());
};

// else block
Expand All @@ -831,7 +835,9 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) {
result_ = BuildIfElse(is_valid, then_lambda, else_lambda, arrow_return_type);
} else {
// fast path : invoke function without computing validities.
result_ = BuildFunctionCall(native_function, arrow_return_type, &params);
result_ = BuildFunctionCall(native_function, arrow_return_type, &params,
dex.func_descriptor());
if (!status_.ok()) return;
}
}

Expand All @@ -844,7 +850,8 @@ void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex& dex) {
native_function->NeedsContext());

auto arrow_return_type = dex.func_descriptor()->return_type();
result_ = BuildFunctionCall(native_function, arrow_return_type, &params);
result_ = BuildFunctionCall(native_function, arrow_return_type, &params,
dex.func_descriptor());
}

void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) {
Expand Down Expand Up @@ -1084,6 +1091,9 @@ void LLVMGenerator::Visitor::VisitInExpression(const InExprDexBase<Type>& dex) {
for (auto& pair : dex.args()) {
DexPtr value_expr = pair->value_expr();
value_expr->Accept(*this);
if (!status_.ok()) {
return;
}
LValue& result_ref = *result();
params.push_back(result_ref.data());

Expand Down Expand Up @@ -1235,6 +1245,9 @@ LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair&
// generate code for value
auto value_expr = pair.value_expr();
value_expr->Accept(*this);
if (!status_.ok()) {
return nullptr;
}
auto value = result()->data();
auto length = result()->length();

Expand All @@ -1246,12 +1259,44 @@ LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair&

LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func,
DataTypePtr arrow_return_type,
std::vector<llvm::Value*>* params) {
std::vector<llvm::Value*>* params,
const FuncDescriptorPtr& descriptor) {
auto types = generator_->types();
auto arrow_return_type_id = arrow_return_type->id();
auto llvm_return_type = types->IRType(arrow_return_type_id);
DecimalIR decimalIR(generator_->engine_.get());

// Resolve the function name — may remap to a TimestampIR-built variant
// based on the actual TimeUnit from the expression tree.
std::string pc_name = func->pc_name();
if (descriptor != nullptr) {
arrow::TimeUnit::type ts_unit = arrow::TimeUnit::MILLI;
bool found_ts = false;
for (auto& param : descriptor->params()) {
if (param->id() == arrow::Type::TIMESTAMP) {
auto unit =
arrow::internal::checked_cast<const arrow::TimestampType&>(*param).unit();
if (!found_ts) {
ts_unit = unit;
found_ts = true;
} else if (unit != ts_unit) {
status_ = Status::Invalid(
"Gandiva cannot compile expression: mixed timestamp units in function '",
pc_name, "'. All timestamp arguments must have the same TimeUnit.");
return nullptr;
}
}
}
// Remap to a unit-specific variant only for units that have an explicit
// suffix. The previous ternary sent every non-MICRO unit, including
// TimeUnit::SECOND, to the "_ns" variant.
// NOTE(review): SECOND inputs now fall through to the base (un-suffixed)
// function, exactly as when no remapped variant is registered — confirm
// whether they should be rejected instead.
if (found_ts && ts_unit != arrow::TimeUnit::MILLI) {
  std::string suffix;
  if (ts_unit == arrow::TimeUnit::MICRO) {
    suffix = "_us";
  } else if (ts_unit == arrow::TimeUnit::NANO) {
    suffix = "_ns";
  }
  if (!suffix.empty()) {
    const std::string remapped = pc_name + suffix;
    if (TimestampIR::IsTimestampIRFunction(remapped)) {
      // Log only when the remap is actually applied.
      ARROW_LOG(DEBUG) << "TimestampIR remap: " << pc_name << " -> " << remapped;
      pc_name = remapped;
    }
  }
}
}

if (arrow_return_type_id == arrow::Type::DECIMAL) {
// For decimal fns, the output precision/scale are passed along as parameters.
//
Expand All @@ -1266,7 +1311,7 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func,
params->push_back(ret_lvalue->scale());

// Make the function call
auto out = decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params);
auto out = decimalIR.CallDecimalFunction(pc_name, llvm_return_type, *params);
ret_lvalue->set_data(out);
return ret_lvalue;
} else {
Expand All @@ -1287,10 +1332,14 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func,

// Make the function call
llvm::IRBuilder<>* builder = ir_builder();
auto value =
isDecimalFunction
? decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params)
: generator_->AddFunctionCall(func->pc_name(), llvm_return_type, *params);
llvm::Value* value;
if (isDecimalFunction) {
value = decimalIR.CallDecimalFunction(pc_name, llvm_return_type, *params);
} else if (auto* ir_fn = generator_->engine_->module()->getFunction(pc_name)) {
value = ir_builder()->CreateCall(ir_fn, *params);
} else {
value = generator_->AddFunctionCall(pc_name, llvm_return_type, *params);
}
auto value_len =
(result_len_ptr == nullptr)
? nullptr
Expand Down
6 changes: 5 additions & 1 deletion cpp/src/gandiva/llvm_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ class GANDIVA_EXPORT LLVMGenerator {

bool has_arena_allocs() { return has_arena_allocs_; }

const Status& status() const { return status_; }

private:
enum BufferType { kBufferTypeValidity = 0, kBufferTypeData, kBufferTypeOffsets };

Expand All @@ -158,7 +160,8 @@ class GANDIVA_EXPORT LLVMGenerator {

// Generate code to invoke a function call.
LValuePtr BuildFunctionCall(const NativeFunction* func, DataTypePtr arrow_return_type,
std::vector<llvm::Value*>* params);
std::vector<llvm::Value*>* params,
const FuncDescriptorPtr& descriptor = nullptr);

// Generate code for an if-else condition.
LValuePtr BuildIfElse(llvm::Value* condition, std::function<LValuePtr()> then_func,
Expand All @@ -179,6 +182,7 @@ class GANDIVA_EXPORT LLVMGenerator {

LLVMGenerator* generator_;
LValuePtr result_;
Status status_;
llvm::Function* function_;
llvm::BasicBlock* entry_block_;
llvm::Value* arg_addrs_;
Expand Down
17 changes: 13 additions & 4 deletions cpp/src/gandiva/precompiled/time.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,10 +442,17 @@ EXTRACT_MINUTE_TIME(time32)

EXTRACT_HOUR_TIME(time32)

// Defines NAME##_##TYPE(millis): rounds `millis` down to a multiple of
// NMILLIS_IN_UNIT.
//
// Floor division is used instead of plain integer division: C++ `/`
// truncates toward zero, which would round negative (pre-epoch) values up
// to the next unit boundary instead of down to the previous one.
// The unit argument is parenthesized at every use so that an expression
// argument keeps its intended precedence.
#define DATE_TRUNC_FIXED_UNIT(NAME, TYPE, NMILLIS_IN_UNIT)                   \
  FORCE_INLINE                                                               \
  gdv_##TYPE NAME##_##TYPE(gdv_##TYPE millis) {                              \
    gdv_##TYPE q = millis / (NMILLIS_IN_UNIT);                               \
    gdv_##TYPE r = millis % (NMILLIS_IN_UNIT);                               \
    /* Non-zero remainder with operands of opposite sign means the      */  \
    /* truncated quotient is one above the floor; step it down.         */  \
    if (r != 0 && (millis ^ (NMILLIS_IN_UNIT)) < 0) {                        \
      --q;                                                                   \
    }                                                                        \
    return q * (NMILLIS_IN_UNIT);                                            \
  }

#define DATE_TRUNC_WEEK(TYPE) \
Expand Down Expand Up @@ -927,7 +934,9 @@ const char* castVARCHAR_timestamp_int64(gdv_int64 context, gdv_timestamp in,
gdv_int64 hour = extractHour_timestamp(in);
gdv_int64 minute = extractMinute_timestamp(in);
gdv_int64 second = extractSecond_timestamp(in);
// Use non-negative remainder for sub-second millis (pre-epoch safe).
gdv_int64 millis = in % MILLIS_IN_SEC;
if (millis < 0) millis += MILLIS_IN_SEC;

static const int kTimeStampStringLen = 23;
const int char_buffer_length = kTimeStampStringLen + 1; // snprintf adds \0
Expand Down
Loading
Loading