From 3bb751de6d85e6d784300e8d17ed23357a3469c4 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 17 Dec 2024 10:19:34 +0000 Subject: [PATCH] Thread stream and mr through empty output construction Since the empty-output path calls make_lists_column for COLLECT_LIST aggregations, it needs a stream and memory resource (mr) passed in. --- cpp/src/rolling/detail/rolling.cuh | 21 +++++++++++++------ .../rolling/detail/rolling_fixed_window.cu | 4 +++- .../rolling/detail/rolling_variable_window.cu | 2 +- cpp/src/rolling/grouped_rolling.cu | 8 +++++-- 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/cpp/src/rolling/detail/rolling.cuh b/cpp/src/rolling/detail/rolling.cuh index bc0ee2eb519..331d36dd4af 100644 --- a/cpp/src/rolling/detail/rolling.cuh +++ b/cpp/src/rolling/detail/rolling.cuh @@ -51,6 +51,7 @@ #include #include +#include #include #include @@ -453,7 +454,10 @@ struct DeviceRollingRowNumber { struct agg_specific_empty_output { template - std::unique_ptr operator()(column_view const& input, rolling_aggregation const&) const + std::unique_ptr operator()(column_view const& input, + rolling_aggregation const&, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const { using target_type = cudf::detail::target_type_t; @@ -467,15 +471,18 @@ struct agg_specific_empty_output { if constexpr (op == aggregation::COLLECT_LIST) { return cudf::make_lists_column( - 0, make_empty_column(type_to_id()), empty_like(input), 0, {}); + 0, make_empty_column(type_to_id()), empty_like(input), 0, {}, stream, mr); } return empty_like(input); } }; -static std::unique_ptr empty_output_for_rolling_aggregation(column_view const& input, - rolling_aggregation const& agg) +static std::unique_ptr empty_output_for_rolling_aggregation( + column_view const& input, + rolling_aggregation const& agg, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) { // TODO: // Ideally, for UDF aggregations, the returned column would match @@ -490,7 +497,7 @@ static std::unique_ptr 
empty_output_for_rolling_aggregation(column_view return agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX ? empty_like(input) : cudf::detail::dispatch_type_and_aggregation( - input.type(), agg.kind, agg_specific_empty_output{}, input, agg); + input.type(), agg.kind, agg_specific_empty_output{}, input, agg, stream, mr); } /** @@ -1326,7 +1333,9 @@ std::unique_ptr rolling_window(column_view const& input, static_assert(warp_size == cudf::detail::size_in_bits(), "bitmask_type size does not match CUDA warp size"); - if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); } + if (input.is_empty()) { + return cudf::detail::empty_output_for_rolling_aggregation(input, agg, stream, mr); + } if (cudf::is_dictionary(input.type())) { CUDF_EXPECTS(agg.kind == aggregation::COUNT_ALL || agg.kind == aggregation::COUNT_VALID || diff --git a/cpp/src/rolling/detail/rolling_fixed_window.cu b/cpp/src/rolling/detail/rolling_fixed_window.cu index 0603f27852a..a19705b14a9 100644 --- a/cpp/src/rolling/detail/rolling_fixed_window.cu +++ b/cpp/src/rolling/detail/rolling_fixed_window.cu @@ -40,7 +40,9 @@ std::unique_ptr rolling_window(column_view const& input, { CUDF_FUNC_RANGE(); - if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); } + if (input.is_empty()) { + return cudf::detail::empty_output_for_rolling_aggregation(input, agg, stream, mr); + } CUDF_EXPECTS((min_periods >= 0), "min_periods must be non-negative"); diff --git a/cpp/src/rolling/detail/rolling_variable_window.cu b/cpp/src/rolling/detail/rolling_variable_window.cu index d4851df740b..3afb75a66f3 100644 --- a/cpp/src/rolling/detail/rolling_variable_window.cu +++ b/cpp/src/rolling/detail/rolling_variable_window.cu @@ -39,7 +39,7 @@ std::unique_ptr rolling_window(column_view const& input, CUDF_FUNC_RANGE(); if (preceding_window.is_empty() || following_window.is_empty() || input.is_empty()) { - return 
cudf::detail::empty_output_for_rolling_aggregation(input, agg); + return cudf::detail::empty_output_for_rolling_aggregation(input, agg, stream, mr); } CUDF_EXPECTS(preceding_window.type().id() == type_id::INT32 && diff --git a/cpp/src/rolling/grouped_rolling.cu b/cpp/src/rolling/grouped_rolling.cu index 3cf292f5abb..3d256236b37 100644 --- a/cpp/src/rolling/grouped_rolling.cu +++ b/cpp/src/rolling/grouped_rolling.cu @@ -158,7 +158,9 @@ std::unique_ptr grouped_rolling_window(table_view const& group_keys, { CUDF_FUNC_RANGE(); - if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, aggr); } + if (input.is_empty()) { + return cudf::detail::empty_output_for_rolling_aggregation(input, aggr, stream, mr); + } CUDF_EXPECTS((group_keys.num_columns() == 0 || group_keys.num_rows() == input.size()), "Size mismatch between group_keys and input vector."); @@ -1152,7 +1154,9 @@ std::unique_ptr grouped_range_rolling_window(table_view const& group_key { CUDF_FUNC_RANGE(); - if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, aggr); } + if (input.is_empty()) { + return cudf::detail::empty_output_for_rolling_aggregation(input, aggr, stream, mr); + } CUDF_EXPECTS((group_keys.num_columns() == 0 || group_keys.num_rows() == input.size()), "Size mismatch between group_keys and input vector.");