Skip to content

Commit

Permalink
Thread stream and mr through empty output construction
Browse files Browse the repository at this point in the history
Since we call make_lists_column in some cases, we need a stream and mr
around.
  • Loading branch information
wence- committed Dec 17, 2024
1 parent 4591523 commit 3bb751d
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 10 deletions.
21 changes: 15 additions & 6 deletions cpp/src/rolling/detail/rolling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda/std/climits>
#include <cuda/std/limits>
Expand Down Expand Up @@ -453,7 +454,10 @@ struct DeviceRollingRowNumber {

struct agg_specific_empty_output {
template <typename InputType, aggregation::Kind op>
std::unique_ptr<column> operator()(column_view const& input, rolling_aggregation const&) const
std::unique_ptr<column> operator()(column_view const& input,
rolling_aggregation const&,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr) const
{
using target_type = cudf::detail::target_type_t<InputType, op>;

Expand All @@ -467,15 +471,18 @@ struct agg_specific_empty_output {

if constexpr (op == aggregation::COLLECT_LIST) {
return cudf::make_lists_column(
0, make_empty_column(type_to_id<size_type>()), empty_like(input), 0, {});
0, make_empty_column(type_to_id<size_type>()), empty_like(input), 0, {}, stream, mr);
}

return empty_like(input);
}
};

static std::unique_ptr<column> empty_output_for_rolling_aggregation(column_view const& input,
rolling_aggregation const& agg)
static std::unique_ptr<column> empty_output_for_rolling_aggregation(
column_view const& input,
rolling_aggregation const& agg,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
// TODO:
// Ideally, for UDF aggregations, the returned column would match
Expand All @@ -490,7 +497,7 @@ static std::unique_ptr<column> empty_output_for_rolling_aggregation(column_view
return agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX
? empty_like(input)
: cudf::detail::dispatch_type_and_aggregation(
input.type(), agg.kind, agg_specific_empty_output{}, input, agg);
input.type(), agg.kind, agg_specific_empty_output{}, input, agg, stream, mr);
}

/**
Expand Down Expand Up @@ -1326,7 +1333,9 @@ std::unique_ptr<column> rolling_window(column_view const& input,
static_assert(warp_size == cudf::detail::size_in_bits<cudf::bitmask_type>(),
"bitmask_type size does not match CUDA warp size");

if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); }
if (input.is_empty()) {
return cudf::detail::empty_output_for_rolling_aggregation(input, agg, stream, mr);
}

if (cudf::is_dictionary(input.type())) {
CUDF_EXPECTS(agg.kind == aggregation::COUNT_ALL || agg.kind == aggregation::COUNT_VALID ||
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/rolling/detail/rolling_fixed_window.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ std::unique_ptr<column> rolling_window(column_view const& input,
{
CUDF_FUNC_RANGE();

if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); }
if (input.is_empty()) {
return cudf::detail::empty_output_for_rolling_aggregation(input, agg, stream, mr);
}

CUDF_EXPECTS((min_periods >= 0), "min_periods must be non-negative");

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/rolling/detail/rolling_variable_window.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ std::unique_ptr<column> rolling_window(column_view const& input,
CUDF_FUNC_RANGE();

if (preceding_window.is_empty() || following_window.is_empty() || input.is_empty()) {
return cudf::detail::empty_output_for_rolling_aggregation(input, agg);
return cudf::detail::empty_output_for_rolling_aggregation(input, agg, stream, mr);
}

CUDF_EXPECTS(preceding_window.type().id() == type_id::INT32 &&
Expand Down
8 changes: 6 additions & 2 deletions cpp/src/rolling/grouped_rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
{
CUDF_FUNC_RANGE();

if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, aggr); }
if (input.is_empty()) {
return cudf::detail::empty_output_for_rolling_aggregation(input, aggr, stream, mr);
}

CUDF_EXPECTS((group_keys.num_columns() == 0 || group_keys.num_rows() == input.size()),
"Size mismatch between group_keys and input vector.");
Expand Down Expand Up @@ -1152,7 +1154,9 @@ std::unique_ptr<column> grouped_range_rolling_window(table_view const& group_key
{
CUDF_FUNC_RANGE();

if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, aggr); }
if (input.is_empty()) {
return cudf::detail::empty_output_for_rolling_aggregation(input, aggr, stream, mr);
}

CUDF_EXPECTS((group_keys.num_columns() == 0 || group_keys.num_rows() == input.size()),
"Size mismatch between group_keys and input vector.");
Expand Down

0 comments on commit 3bb751d

Please sign in to comment.