diff --git a/be/src/core/string_buffer.hpp b/be/src/core/string_buffer.hpp index c29b024d0c12bf..81223f99b0d99f 100644 --- a/be/src/core/string_buffer.hpp +++ b/be/src/core/string_buffer.hpp @@ -53,6 +53,12 @@ class BufferWritable final { _now_offset = 0; } + char* data() { return reinterpret_cast(_data.data() + _now_offset + _offsets.back()); } + + void add_offset(size_t len) { _now_offset += len; } + + void resize(size_t size) { _data.resize(size + _now_offset + _offsets.back()); } + template void write_number(T data) { fmt::memory_buffer buffer; @@ -236,6 +242,10 @@ class BufferReadable { _data += len; } + const char* data() { return _data; } + + void add_offset(size_t len) { _data += len; } + void read_var_uint(UInt64& x) { x = 0; // get length from first byte firstly diff --git a/be/src/exec/operator/hashjoin_build_sink.cpp b/be/src/exec/operator/hashjoin_build_sink.cpp index 0d5f742692e622..8caae7b2364282 100644 --- a/be/src/exec/operator/hashjoin_build_sink.cpp +++ b/be/src/exec/operator/hashjoin_build_sink.cpp @@ -597,22 +597,19 @@ Status HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state, Blo RETURN_IF_ERROR(_hash_table_init(state, raw_ptrs)); Status st = std::visit( - Overload {[&](std::monostate& arg, auto join_op, - auto short_circuit_for_null_in_build_side, - auto with_other_conjuncts) -> Status { + Overload {[&](std::monostate& arg, auto join_op) -> Status { throw Exception(Status::FatalError("FATAL: uninited hash table")); }, - [&](auto&& arg, auto&& join_op, auto short_circuit_for_null_in_build_side, - auto with_other_conjuncts) -> Status { + [&](auto&& arg, auto&& join_op) -> Status { using HashTableCtxType = std::decay_t; using JoinOpType = std::decay_t; ProcessHashTableBuild hash_table_build_process( rows, raw_ptrs, this, state->batch_size(), state); - auto st = hash_table_build_process.template run< - JoinOpType::value, short_circuit_for_null_in_build_side, - with_other_conjuncts>( + auto st = hash_table_build_process.template run( arg, null_map_val ? &null_map_val->get_data() : nullptr, - &_shared_state->_has_null_in_build_side); + &_shared_state->_has_null_in_build_side, + p._short_circuit_for_null_in_build_side, + p._have_other_join_conjunct); COUNTER_SET(_memory_used_counter, _build_blocks_memory_usage->value() + (int64_t)(arg.hash_table->get_byte_size() + @@ -620,9 +617,7 @@ Status HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state, Blo return st; }}, _shared_state->hash_table_variant_vector.front()->method_variant, - _shared_state->join_op_variants, - make_bool_variant(p._short_circuit_for_null_in_build_side), - make_bool_variant((p._have_other_join_conjunct))); + _shared_state->join_op_variants); return st; } diff --git a/be/src/exec/operator/hashjoin_build_sink.h b/be/src/exec/operator/hashjoin_build_sink.h index bb7c16207d5480..c6f492e8df743c 100644 --- a/be/src/exec/operator/hashjoin_build_sink.h +++ b/be/src/exec/operator/hashjoin_build_sink.h @@ -212,8 +212,9 @@ struct ProcessHashTableBuild { _batch_size(batch_size), _state(state) {} - template - Status run(HashTableContext& hash_table_ctx, ConstNullMapPtr null_map, bool* has_null_key) { + template + Status run(HashTableContext& hash_table_ctx, ConstNullMapPtr null_map, bool* has_null_key, + bool short_circuit_for_null, bool with_other_conjuncts) { if (null_map) { // first row is mocked and is null if (simd::contain_one(null_map->data() + 1, _rows - 1)) { diff --git a/be/src/exec/operator/scan_operator.cpp b/be/src/exec/operator/scan_operator.cpp index cdfad66dde6e67..3d34d32ceb07ef 100644 --- a/be/src/exec/operator/scan_operator.cpp +++ b/be/src/exec/operator/scan_operator.cpp @@ -141,6 +141,7 @@ Status ScanLocalState::init(RuntimeState* state, LocalStateInfo& info) SCOPED_TIMER(exec_time_counter()); SCOPED_TIMER(_init_timer); auto& p = _parent->cast(); + _max_pushdown_conditions_per_column = p._max_pushdown_conditions_per_column; RETURN_IF_ERROR(_helper.init(state, p.is_serial_operator(), p.node_id(), p.operator_id(), _filter_dependencies, p.get_name() + "_FILTER_DEPENDENCY")); RETURN_IF_ERROR(_init_profile()); @@ -480,8 +481,7 @@ Status ScanLocalState::_normalize_predicate(VExprContext* context, cons return Status::OK(); } -template -Status ScanLocalState::_normalize_bloom_filter( +Status ScanLocalStateBase::_normalize_bloom_filter( VExprContext* expr_ctx, const VExprSPtr& root, SlotDescriptor* slot, std::vector>& predicates, PushDownType* pdt) { std::shared_ptr pred = nullptr; @@ -509,8 +509,7 @@ Status ScanLocalState::_normalize_bloom_filter( return Status::OK(); } -template -Status ScanLocalState::_normalize_topn_filter( +Status ScanLocalStateBase::_normalize_topn_filter( VExprContext* expr_ctx, const VExprSPtr& root, SlotDescriptor* slot, std::vector>& predicates, PushDownType* pdt) { std::shared_ptr pred = nullptr; @@ -526,18 +525,16 @@ Status ScanLocalState::_normalize_topn_filter( DCHECK(root->is_topn_filter()); *pdt = _should_push_down_topn_filter(); if (*pdt != PushDownType::UNACCEPTABLE) { - auto& p = _parent->cast(); auto& tmp = _state->get_query_ctx()->get_runtime_predicate( assert_cast(root.get())->source_node_id()); if (_push_down_topn(tmp)) { - pred = tmp.get_predicate(p.node_id()); + pred = tmp.get_predicate(_parent->node_id()); } } return Status::OK(); } -template -Status ScanLocalState::_normalize_bitmap_filter( +Status ScanLocalStateBase::_normalize_bitmap_filter( VExprContext* expr_ctx, const VExprSPtr& root, SlotDescriptor* slot, std::vector>& predicates, PushDownType* pdt) { std::shared_ptr pred = nullptr; @@ -565,10 +562,8 @@ Status ScanLocalState::_normalize_bitmap_filter( return Status::OK(); } -template -Status ScanLocalState::_normalize_function_filters(VExprContext* expr_ctx, - SlotDescriptor* slot, - PushDownType* pdt) { +Status ScanLocalStateBase::_normalize_function_filters(VExprContext* expr_ctx, SlotDescriptor* slot, + PushDownType* pdt) { auto expr = expr_ctx->root()->is_rf_wrapper() ? expr_ctx->root()->get_impl() : expr_ctx->root(); bool opposite = false; VExpr* fn_expr = expr.get(); @@ -648,8 +643,7 @@ std::string ScanLocalState::debug_string(int indentation_level) const { return fmt::to_string(debug_string_buffer); } -template -Status ScanLocalState::_eval_const_conjuncts(VExprContext* expr_ctx, PushDownType* pdt) { +Status ScanLocalStateBase::_eval_const_conjuncts(VExprContext* expr_ctx, PushDownType* pdt) { auto vexpr = expr_ctx->root()->is_rf_wrapper() ? expr_ctx->root()->get_impl() : expr_ctx->root(); // Used to handle constant expressions, such as '1 = 1' _eval_const_conjuncts does not handle cases like 'colA = 1' @@ -697,9 +691,8 @@ Status ScanLocalState::_eval_const_conjuncts(VExprContext* expr_ctx, Pu return Status::OK(); } -template template -Status ScanLocalState::_normalize_in_predicate( +Status ScanLocalStateBase::_normalize_in_predicate( VExprContext* expr_ctx, const VExprSPtr& root, SlotDescriptor* slot, std::vector>& predicates, ColumnValueRange& range, PushDownType* pdt) { @@ -731,8 +724,7 @@ Status ScanLocalState::_normalize_in_predicate( auto is_in = false; if (hybrid_set != nullptr) { // runtime filter produce VDirectInPredicate - if (hybrid_set->size() <= - _parent->cast()._max_pushdown_conditions_per_column) { + if (hybrid_set->size() <= static_cast(_max_pushdown_conditions_per_column)) { iter = hybrid_set->begin(); } is_in = true; @@ -810,9 +802,8 @@ Status ScanLocalState::_normalize_in_predicate( return Status::OK(); } -template template -Status ScanLocalState::_normalize_binary_predicate( +Status ScanLocalStateBase::_normalize_binary_predicate( VExprContext* expr_ctx, const VExprSPtr& root, SlotDescriptor* slot, std::vector>& predicates, ColumnValueRange& range, PushDownType* pdt) { @@ -921,13 +912,12 @@ Status ScanLocalState::_normalize_binary_predicate( return Status::OK(); } -template template -Status ScanLocalState::_change_value_range(bool is_equal_op, - ColumnValueRange& temp_range, - const Field& value, - const ChangeFixedValueRangeFunc& func, - const std::string& fn_name) { +Status ScanLocalStateBase::_change_value_range(bool is_equal_op, + ColumnValueRange& temp_range, + const Field& value, + const ChangeFixedValueRangeFunc& func, + const std::string& fn_name) { if constexpr (PrimitiveType == TYPE_DATE) { auto tmp_value = value.template get(); if (is_equal_op) { @@ -959,9 +949,8 @@ Status ScanLocalState::_change_value_range(bool is_equal_op, return Status::OK(); } -template template -Status ScanLocalState::_normalize_is_null_predicate( +Status ScanLocalStateBase::_normalize_is_null_predicate( VExprContext* expr_ctx, const VExprSPtr& root, SlotDescriptor* slot, std::vector>& predicates, ColumnValueRange& range, PushDownType* pdt) { diff --git a/be/src/exec/operator/scan_operator.h b/be/src/exec/operator/scan_operator.h index 46297a7154777d..635e3c8d593582 100644 --- a/be/src/exec/operator/scan_operator.h +++ b/be/src/exec/operator/scan_operator.h @@ -19,9 +19,11 @@ #include #include +#include #include #include "common/status.h" +#include "core/field.h" #include "exec/common/util.hpp" #include "exec/operator/operator.h" #include "exec/pipeline/dependency.h" @@ -128,6 +130,83 @@ class ScanLocalStateBase : public PipelineXLocalState<> { RuntimeFilterConsumerHelper _helper; // magic number as seed to generate hash value for condition cache uint64_t _condition_cache_digest = 0; + + // Moved from ScanLocalState to avoid re-instantiation for each Derived type. + std::atomic _eos = false; + int _max_pushdown_conditions_per_column = 1024; + // Save all function predicates which may be pushed down to data source. + std::vector _push_down_functions; + + // Virtual methods with default implementations; overridden by subclasses when supported. + // Declared here so that the normalize methods below (non-Derived-template) can call them. + virtual bool _push_down_topn(const RuntimePredicate& predicate) { return false; } + virtual PushDownType _should_push_down_bloom_filter() const { + return PushDownType::UNACCEPTABLE; + } + virtual PushDownType _should_push_down_topn_filter() const { + return PushDownType::UNACCEPTABLE; + } + virtual PushDownType _should_push_down_bitmap_filter() const { + return PushDownType::UNACCEPTABLE; + } + virtual PushDownType _should_push_down_is_null_predicate(VectorizedFnCall* fn_call) const { + return PushDownType::UNACCEPTABLE; + } + virtual PushDownType _should_push_down_in_predicate() const { + return PushDownType::UNACCEPTABLE; + } + virtual PushDownType _should_push_down_binary_predicate( + VectorizedFnCall* fn_call, VExprContext* expr_ctx, Field& constant_val, + const std::set fn_name) const { + return PushDownType::UNACCEPTABLE; + } + virtual Status _should_push_down_function_filter(VectorizedFnCall* fn_call, + VExprContext* expr_ctx, + StringRef* constant_str, + doris::FunctionContext** fn_ctx, + PushDownType& pdt) { + pdt = PushDownType::UNACCEPTABLE; + return Status::OK(); + } + + // Non-templated normalize methods, moved here to avoid re-compilation per Derived type. + Status _eval_const_conjuncts(VExprContext* expr_ctx, PushDownType* pdt); + Status _normalize_bloom_filter(VExprContext* expr_ctx, const VExprSPtr& root, + SlotDescriptor* slot, + std::vector>& predicates, + PushDownType* pdt); + Status _normalize_topn_filter(VExprContext* expr_ctx, const VExprSPtr& root, + SlotDescriptor* slot, + std::vector>& predicates, + PushDownType* pdt); + Status _normalize_bitmap_filter(VExprContext* expr_ctx, const VExprSPtr& root, + SlotDescriptor* slot, + std::vector>& predicates, + PushDownType* pdt); + Status _normalize_function_filters(VExprContext* expr_ctx, SlotDescriptor* slot, + PushDownType* pdt); + + // Inner PrimitiveType-template methods. Moved to base to avoid N(Derived)×M(PrimitiveType) + // instantiation blowup: now instantiated M times total instead of N×M times. + template + Status _normalize_in_predicate(VExprContext* expr_ctx, const VExprSPtr& root, + SlotDescriptor* slot, + std::vector>& predicates, + ColumnValueRange& range, PushDownType* pdt); + template + Status _normalize_binary_predicate(VExprContext* expr_ctx, const VExprSPtr& root, + SlotDescriptor* slot, + std::vector>& predicates, + ColumnValueRange& range, PushDownType* pdt); + template + Status _normalize_is_null_predicate(VExprContext* expr_ctx, const VExprSPtr& root, + SlotDescriptor* slot, + std::vector>& predicates, + ColumnValueRange& range, PushDownType* pdt); + template + Status _change_value_range(bool is_equal_op, ColumnValueRange& range, + const Field& value, const ChangeFixedValueRangeFunc& func, + const std::string& fn_name); }; template @@ -202,37 +281,7 @@ class ScanLocalState : public ScanLocalStateBase { virtual bool _should_push_down_common_expr() { return false; } virtual bool _storage_no_merge() { return false; } - virtual bool _push_down_topn(const RuntimePredicate& predicate) { return false; } virtual bool _is_key_column(const std::string& col_name) { return false; } - virtual PushDownType _should_push_down_bloom_filter() const { - return PushDownType::UNACCEPTABLE; - } - virtual PushDownType _should_push_down_topn_filter() const { - return PushDownType::UNACCEPTABLE; - } - virtual PushDownType _should_push_down_bitmap_filter() const { - return PushDownType::UNACCEPTABLE; - } - virtual PushDownType _should_push_down_is_null_predicate(VectorizedFnCall* fn_call) const { - return PushDownType::UNACCEPTABLE; - } - virtual PushDownType _should_push_down_in_predicate() const { - return PushDownType::UNACCEPTABLE; - } - virtual PushDownType _should_push_down_binary_predicate( - VectorizedFnCall* fn_call, VExprContext* expr_ctx, Field& constant_val, - const std::set fn_name) const { - return PushDownType::UNACCEPTABLE; - } - - virtual Status _should_push_down_function_filter(VectorizedFnCall* fn_call, - VExprContext* expr_ctx, - StringRef* constant_str, - doris::FunctionContext** fn_ctx, - PushDownType& pdt) { - pdt = PushDownType::UNACCEPTABLE; - return Status::OK(); - } // Create a list of scanners. // The number of scanners is related to the implementation of the data source, @@ -247,46 +296,6 @@ class ScanLocalState : public ScanLocalStateBase { VExprSPtr& output_expr); bool _is_predicate_acting_on_slot(const VExprSPtrs& children, SlotDescriptor** slot_desc, ColumnValueRangeType** range); - Status _eval_const_conjuncts(VExprContext* expr_ctx, PushDownType* pdt); - - template - Status _normalize_in_predicate(VExprContext* expr_ctx, const VExprSPtr& root, - SlotDescriptor* slot, - std::vector>& predicates, - ColumnValueRange& range, PushDownType* pdt); - template - Status _normalize_binary_predicate(VExprContext* expr_ctx, const VExprSPtr& root, - SlotDescriptor* slot, - std::vector>& predicates, - ColumnValueRange& range, PushDownType* pdt); - Status _normalize_bloom_filter(VExprContext* expr_ctx, const VExprSPtr& root, - SlotDescriptor* slot, - std::vector>& predicates, - PushDownType* pdt); - Status _normalize_topn_filter(VExprContext* expr_ctx, const VExprSPtr& root, - SlotDescriptor* slot, - std::vector>& predicates, - PushDownType* pdt); - - Status _normalize_bitmap_filter(VExprContext* expr_ctx, const VExprSPtr& root, - SlotDescriptor* slot, - std::vector>& predicates, - PushDownType* pdt); - - Status _normalize_function_filters(VExprContext* expr_ctx, SlotDescriptor* slot, - PushDownType* pdt); - - template - Status _normalize_is_null_predicate(VExprContext* expr_ctx, const VExprSPtr& root, - SlotDescriptor* slot, - std::vector>& predicates, - ColumnValueRange& range, PushDownType* pdt); - - template - Status _change_value_range(bool is_equal_op, ColumnValueRange& range, - const Field& value, const ChangeFixedValueRangeFunc& func, - const std::string& fn_name); - Status _prepare_scanners(); // Submit the scanner to the thread pool and start execution @@ -310,9 +319,6 @@ class ScanLocalState : public ScanLocalStateBase { atomic_shared_ptr _scanner_ctx = nullptr; - // Save all function predicates which may be pushed down to data source. - std::vector _push_down_functions; - // colname -> cast dst type std::map _cast_types_for_variants; @@ -322,8 +328,6 @@ class ScanLocalState : public ScanLocalStateBase { phmap::flat_hash_map>> _slot_id_to_predicates; std::vector> _or_predicates; - std::atomic _eos = false; - std::vector> _filter_dependencies; // ScanLocalState owns the ownership of scanner, scanner context only has its weakptr diff --git a/be/src/exprs/aggregate/aggregate_function.h b/be/src/exprs/aggregate/aggregate_function.h index 9186c9ad39bf51..0e07f74c1aeab1 100644 --- a/be/src/exprs/aggregate/aggregate_function.h +++ b/be/src/exprs/aggregate/aggregate_function.h @@ -302,12 +302,42 @@ class IAggregateFunction { int version {}; }; +/// Marker base for aggregate function classes that intentionally form an inheritance +/// hierarchy (e.g. AggregateStateUnion -> AggregateStateMerge, or +/// AggregateFunctionForEach -> AggregateFunctionForEachV2) and therefore cannot be +/// marked 'final'. Classes that inherit this marker are exempt from the +/// static_assert in IAggregateFunctionHelper. +struct AggregateFunctionNonFinalBase {}; + /// Implement method to obtain an address of 'add' function. template class IAggregateFunctionHelper : public IAggregateFunction { public: IAggregateFunctionHelper(const DataTypes& argument_types_) - : IAggregateFunction(argument_types_) {} + : IAggregateFunction(argument_types_) { + // NOTE: This static_assert is placed in the constructor body (not at class scope) + // because at class-scope instantiation time Derived is still an incomplete type, + // whereas the constructor body is instantiated lazily when a concrete object is + // constructed, at which point Derived is fully defined. + // + // Marking Derived as 'final' is an *optimization hint*, not a correctness + // requirement. add() is virtual in IAggregateFunction, so subclasses always + // dispatch correctly through the vtable regardless. However, when Derived is + // final, the compiler can see that assert_cast(this)->add(...) + // inside add_batch() / add_batch_single_place() etc. has no further overrides, + // allowing it to devirtualize (and potentially inline) the add() call -- which + // is critical for hot aggregation paths. + // + // Classes that intentionally form inheritance hierarchies (and therefore accept + // the vtable overhead) must inherit AggregateFunctionNonFinalBase to opt out. + static_assert( + std::is_final_v || + std::is_base_of_v, + "Derived should be marked 'final' to allow the compiler to devirtualize " + "add() calls inside add_batch() and related hot paths. " + "If the class intentionally has subclasses, inherit AggregateFunctionNonFinalBase " + "to opt out of this check."); + } void destroy_vec(AggregateDataPtr __restrict place, const size_t num_rows) const noexcept override { diff --git a/be/src/exprs/aggregate/aggregate_function_array_agg.h b/be/src/exprs/aggregate/aggregate_function_array_agg.h index a998dd8fdbd553..eda644add88e94 100644 --- a/be/src/exprs/aggregate/aggregate_function_array_agg.h +++ b/be/src/exprs/aggregate/aggregate_function_array_agg.h @@ -279,7 +279,7 @@ struct AggregateFunctionArrayAggData { //ShowNull is just used to support array_agg because array_agg needs to display NULL //todo: Supports order by sorting for array_agg template -class AggregateFunctionArrayAgg +class AggregateFunctionArrayAgg final : public IAggregateFunctionDataHelper, true>, UnaryExpression, NotNullableAggregateFunction { diff --git a/be/src/exprs/aggregate/aggregate_function_binary.h b/be/src/exprs/aggregate/aggregate_function_binary.h index 96a8d9c3ad3b8f..5f8565a89bba82 100644 --- a/be/src/exprs/aggregate/aggregate_function_binary.h +++ b/be/src/exprs/aggregate/aggregate_function_binary.h @@ -42,7 +42,7 @@ struct StatFunc { }; template -struct AggregateFunctionBinary +struct AggregateFunctionBinary final : public IAggregateFunctionDataHelper>, MultiExpression, diff --git a/be/src/exprs/aggregate/aggregate_function_collect.cpp b/be/src/exprs/aggregate/aggregate_function_collect.cpp index 7fb9e0f6e543bb..4f8d4b83cac2e6 100644 --- a/be/src/exprs/aggregate/aggregate_function_collect.cpp +++ b/be/src/exprs/aggregate/aggregate_function_collect.cpp @@ -17,9 +17,6 @@ #include "exprs/aggregate/aggregate_function_collect.h" -#include "common/exception.h" -#include "common/status.h" -#include "core/call_on_type_index.h" #include "exprs/aggregate/aggregate_function_simple_factory.h" #include "exprs/aggregate/factory_helpers.h" #include "exprs/aggregate/helpers.h" @@ -27,104 +24,14 @@ namespace doris { #include "common/compile_check_begin.h" -template -AggregateFunctionPtr do_create_agg_function_collect(bool distinct, const DataTypes& argument_types, - const bool result_is_nullable, - - const AggregateFunctionAttr& attr) { - if (distinct) { - if constexpr (T == INVALID_TYPE) { - throw Exception(ErrorCode::INTERNAL_ERROR, - "unexpected type for collect, please check the input"); - } else { - return creator_without_type::create, HasLimit>>( - argument_types, result_is_nullable, attr); - } - } else { - return creator_without_type::create< - AggregateFunctionCollect, HasLimit>>( - argument_types, result_is_nullable, attr); - } -} - -template -AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable, - - const AggregateFunctionAttr& attr) { - bool distinct = name == "collect_set"; - - switch (argument_types[0]->get_primitive_type()) { - case PrimitiveType::TYPE_BOOLEAN: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_TINYINT: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_SMALLINT: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_INT: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_BIGINT: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_LARGEINT: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_FLOAT: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_DOUBLE: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_DECIMAL32: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_DECIMAL64: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_DECIMALV2: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_DECIMAL128I: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_DECIMAL256: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_DATE: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_DATETIME: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_DATEV2: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_DATETIMEV2: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_IPV6: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_IPV4: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - case PrimitiveType::TYPE_STRING: - case PrimitiveType::TYPE_CHAR: - case PrimitiveType::TYPE_VARCHAR: - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - default: - // We do not care what the real type is. - return do_create_agg_function_collect(distinct, argument_types, - result_is_nullable, attr); - } -} +// Forward declarations — template instantiations live in separate TUs +AggregateFunctionPtr create_aggregate_function_collect_no_limit(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable, + const AggregateFunctionAttr& attr); +AggregateFunctionPtr create_aggregate_function_collect_with_limit( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr); AggregateFunctionPtr create_aggregate_function_collect(const std::string& name, const DataTypes& argument_types, @@ -133,11 +40,11 @@ AggregateFunctionPtr create_aggregate_function_collect(const std::string& name, const AggregateFunctionAttr& attr) { assert_arity_range(name, argument_types, 1, 2); if (argument_types.size() == 1) { - return create_aggregate_function_collect_impl(name, argument_types, - result_is_nullable, attr); + return create_aggregate_function_collect_no_limit(name, argument_types, result_is_nullable, + attr); } if (argument_types.size() == 2) { - return create_aggregate_function_collect_impl(name, argument_types, + return create_aggregate_function_collect_with_limit(name, argument_types, result_is_nullable, attr); } return nullptr; diff --git a/be/src/exprs/aggregate/aggregate_function_collect.h b/be/src/exprs/aggregate/aggregate_function_collect.h index c5d4e169f76202..b93815e3880aa1 100644 --- a/be/src/exprs/aggregate/aggregate_function_collect.h +++ b/be/src/exprs/aggregate/aggregate_function_collect.h @@ -396,7 +396,7 @@ struct AggregateFunctionCollectListData { }; template -class AggregateFunctionCollect +class AggregateFunctionCollect final : public IAggregateFunctionDataHelper, true>, VarargsExpression, NotNullableAggregateFunction { diff --git a/be/src/exprs/aggregate/aggregate_function_collect_impl.h b/be/src/exprs/aggregate/aggregate_function_collect_impl.h new file mode 100644 index 00000000000000..7c52b20b495e33 --- /dev/null +++ b/be/src/exprs/aggregate/aggregate_function_collect_impl.h @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "common/exception.h" +#include "common/status.h" +#include "core/call_on_type_index.h" +#include "exprs/aggregate/aggregate_function_collect.h" +#include "exprs/aggregate/factory_helpers.h" +#include "exprs/aggregate/helpers.h" + +namespace doris { +#include "common/compile_check_begin.h" + +template +AggregateFunctionPtr do_create_agg_function_collect(bool distinct, const DataTypes& argument_types, + const bool result_is_nullable, + + const AggregateFunctionAttr& attr) { + if (distinct) { + if constexpr (T == INVALID_TYPE) { + throw Exception(ErrorCode::INTERNAL_ERROR, + "unexpected type for collect, please check the input"); + } else { + return creator_without_type::create, HasLimit>>( + argument_types, result_is_nullable, attr); + } + } else { + return creator_without_type::create< + AggregateFunctionCollect, HasLimit>>( + argument_types, result_is_nullable, attr); + } +} + +template +AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable, + + const AggregateFunctionAttr& attr) { + bool distinct = name == "collect_set"; + + AggregateFunctionPtr agg_fn; + auto call = [&](const auto& type) -> bool { + using DispatcType = std::decay_t; + agg_fn = do_create_agg_function_collect( + distinct, argument_types, result_is_nullable, attr); + return true; + }; + + if (!dispatch_switch_all(argument_types[0]->get_primitive_type(), call)) { + // We do not care what the real type is. + agg_fn = do_create_agg_function_collect(distinct, argument_types, + result_is_nullable, attr); + } + return agg_fn; +} + +} // namespace doris +#include "common/compile_check_end.h" diff --git a/be/src/exprs/aggregate/aggregate_function_collect_limit.cpp b/be/src/exprs/aggregate/aggregate_function_collect_limit.cpp new file mode 100644 index 00000000000000..318529689dc895 --- /dev/null +++ b/be/src/exprs/aggregate/aggregate_function_collect_limit.cpp @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "exprs/aggregate/aggregate_function_collect_impl.h" + +namespace doris { +#include "common/compile_check_begin.h" + +AggregateFunctionPtr create_aggregate_function_collect_with_limit( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { + return create_aggregate_function_collect_impl(name, argument_types, result_is_nullable, + attr); +} + +} // namespace doris diff --git a/be/src/exprs/aggregate/aggregate_function_collect_no_limit.cpp b/be/src/exprs/aggregate/aggregate_function_collect_no_limit.cpp new file mode 100644 index 00000000000000..ba81be14068bdf --- /dev/null +++ b/be/src/exprs/aggregate/aggregate_function_collect_no_limit.cpp @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "exprs/aggregate/aggregate_function_collect_impl.h" + +namespace doris { +#include "common/compile_check_begin.h" + +AggregateFunctionPtr create_aggregate_function_collect_no_limit(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { + return create_aggregate_function_collect_impl(name, argument_types, result_is_nullable, + attr); +} + +} // namespace doris diff --git a/be/src/exprs/aggregate/aggregate_function_covar.h b/be/src/exprs/aggregate/aggregate_function_covar.h index 7607ff656fe877..fe6840719dbbc5 100644 --- a/be/src/exprs/aggregate/aggregate_function_covar.h +++ b/be/src/exprs/aggregate/aggregate_function_covar.h @@ -138,7 +138,7 @@ struct SampData : BaseData { }; template -class AggregateFunctionSampCovariance +class AggregateFunctionSampCovariance final : public IAggregateFunctionDataHelper>, MultiExpression, NullableAggregateFunction { diff --git a/be/src/exprs/aggregate/aggregate_function_distinct.h b/be/src/exprs/aggregate/aggregate_function_distinct.h index 188e34369f984f..618d9b46f41996 100644 --- a/be/src/exprs/aggregate/aggregate_function_distinct.h +++ b/be/src/exprs/aggregate/aggregate_function_distinct.h @@ -261,7 +261,7 @@ struct AggregateFunctionDistinctMultipleGenericData * Adding -Distinct suffix to aggregate function **/ template