diff --git a/include/seastar/coroutine/generator.hh b/include/seastar/coroutine/generator.hh index 24598584617..c995af4f015 100644 --- a/include/seastar/coroutine/generator.hh +++ b/include/seastar/coroutine/generator.hh @@ -94,7 +94,6 @@ public: generator_type get_return_object() noexcept; void set_generator(generator_type* g) noexcept { - assert(!_generator); _generator = g; } @@ -190,7 +189,6 @@ public: auto get_return_object() noexcept -> generator_type; void set_generator(generator_type* g) noexcept { - assert(!_generator); _generator = g; } @@ -335,13 +333,27 @@ public: generator(const generator&) = delete; generator(generator&& other) noexcept : _coro{std::exchange(other._coro, {})} - , _buffer_capacity{other._buffer_capacity} {} + , _promise(std::exchange(other._promise, nullptr)) + , _values(std::move(other._values)) + , _buffer_capacity{other._buffer_capacity} + , _exception(std::exchange(other._exception, nullptr)) { + if (_promise) { + _promise->set_generator(this); + } + } generator& operator=(generator&& other) noexcept { if (std::addressof(other) != this) { auto old_coro = std::exchange(_coro, std::exchange(other._coro, {})); if (old_coro) { old_coro.destroy(); } + _promise = std::exchange(other._promise, nullptr); + if (_promise) { + _promise->set_generator(this); + } + _values = std::move(other._values); + const_cast(_buffer_capacity) = other._buffer_capacity; + _exception = std::exchange(other._exception, nullptr); } return *this; } @@ -353,6 +365,16 @@ public: void swap(generator& other) noexcept { std::swap(_coro, other._coro); + std::swap(_promise, other._promise); + if (_promise) { + _promise->set_generator(this); + } + if (other._promise) { + other._promise->set_generator(&other); + } + std::swap(_values, other._values); + std::swap(const_cast(_buffer_capacity), const_cast(other._buffer_capacity)); + std::swap(_exception, other._exception); } internal::next_awaiter operator()() noexcept { @@ -425,13 +447,26 @@ public: } generator(const generator&) = delete; generator(generator&& other) noexcept - : _coro{std::exchange(other._coro, {})} {} + : _coro{std::exchange(other._coro, {})} + , _promise(std::exchange(other._promise, nullptr)) + , _maybe_value(std::exchange(other._maybe_value, std::nullopt)) + , _exception(std::exchange(other._exception, nullptr)) { + if (_promise) { + _promise->set_generator(this); + } + } generator& operator=(generator&& other) noexcept { if (std::addressof(other) != this) { auto old_coro = std::exchange(_coro, std::exchange(other._coro, {})); if (old_coro) { old_coro.destroy(); } + _promise = std::exchange(other._promise, nullptr); + if (_promise) { + _promise->set_generator(this); + } + _maybe_value = std::exchange(other._maybe_value, std::nullopt); + _exception = std::exchange(other._exception, nullptr); } return *this; } @@ -443,6 +478,15 @@ public: void swap(generator& other) noexcept { std::swap(_coro, other._coro); + std::swap(_promise, other._promise); + if (_promise) { + _promise->set_generator(this); + } + if (other._promise) { + other._promise->set_generator(&other); + } + std::swap(_maybe_value, other._maybe_value); + std::swap(_exception, other._exception); } internal::next_awaiter operator()() noexcept { diff --git a/tests/unit/coroutines_test.cc b/tests/unit/coroutines_test.cc index 8e2b8a470c6..6edc186593d 100644 --- a/tests/unit/coroutines_test.cc +++ b/tests/unit/coroutines_test.cc @@ -741,6 +741,14 @@ SEASTAR_TEST_CASE(test_as_future_preemption) { BOOST_REQUIRE_THROW(f0.get(), std::runtime_error); } +std::vector gen_expected_fibs(unsigned count) { + std::vector expected_fibs = {0, 1}; + for (unsigned i = 2; i < count; ++i) { + expected_fibs.emplace_back(expected_fibs[i-2] + expected_fibs[i-1]); + } + return expected_fibs; +} + template class Container> coroutine::experimental::generator fibonacci_sequence(coroutine::experimental::buffer_size_t size, unsigned count) { @@ -755,10 +763,8 @@ fibonacci_sequence(coroutine::experimental::buffer_size_t size, unsigned count) } template class Container> -seastar::future<> test_async_generator_drained() { - auto expected_fibs = {0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55}; - auto fib = fibonacci_sequence(coroutine::experimental::buffer_size_t{2}, - std::size(expected_fibs)); +seastar::future<> test_async_generator_drained(coroutine::experimental::generator fib, unsigned count) { + auto expected_fibs = gen_expected_fibs(count); for (auto expected_fib : expected_fibs) { auto actual_fib = co_await fib(); BOOST_REQUIRE(actual_fib.has_value()); @@ -768,6 +774,37 @@ seastar::future<> test_async_generator_drained() { BOOST_REQUIRE(!sentinel.has_value()); } +template class Container> +seastar::future<> test_async_generator_drained() { + unsigned count = 11; + co_return co_await test_async_generator_drained(fibonacci_sequence(coroutine::experimental::buffer_size_t{2}, count), count); +} + +template class Container> +seastar::future<> test_move_async_generator_drained() { + unsigned count = 11; + auto fib0 = fibonacci_sequence(coroutine::experimental::buffer_size_t{2}, count); + co_await test_async_generator_drained(std::move(fib0), count); + auto fib1 = fibonacci_sequence(coroutine::experimental::buffer_size_t{2}, ++count); + fib0 = std::move(fib1); + co_await test_async_generator_drained(std::move(fib0), count); + fib0 = fibonacci_sequence(coroutine::experimental::buffer_size_t{2}, ++count); + fib1 = fibonacci_sequence(coroutine::experimental::buffer_size_t{2}, ++count); + fib0 = std::move(fib1); + co_await test_async_generator_drained(std::move(fib0), count); +} + +template class Container> +seastar::future<> test_swap_async_generator_drained() { + unsigned count[2] = {11, 17}; + auto fib0 = fibonacci_sequence(coroutine::experimental::buffer_size_t{2}, count[0]); + auto fib1 = fibonacci_sequence(coroutine::experimental::buffer_size_t{2}, count[1]); + std::swap(fib0, fib1); + std::swap(count[0], count[1]); + co_await test_async_generator_drained(std::move(fib0), count[0]); + co_await test_async_generator_drained(std::move(fib1), count[1]); +} + template using buffered_container = circular_buffer; @@ -779,6 +816,55 @@ SEASTAR_TEST_CASE(test_async_generator_drained_unbuffered) { return test_async_generator_drained(); } +SEASTAR_TEST_CASE(test_move_async_generator_drained_buffered) { + return test_move_async_generator_drained(); +} + +SEASTAR_TEST_CASE(test_move_async_generator_drained_unbuffered) { + return test_move_async_generator_drained(); +} + +SEASTAR_TEST_CASE(test_swap_async_generator_drained_buffered) { + return test_swap_async_generator_drained(); +} + +SEASTAR_TEST_CASE(test_swap_async_generator_drained_unbuffered) { + return test_swap_async_generator_drained(); +} + +template class Container> +seastar::future> test_async_generator_drained_incrementally(coroutine::experimental::generator fib, std::optional expected_value) { + auto actual_fib = co_await fib(); + if (expected_value) { + BOOST_REQUIRE(actual_fib.has_value()); + BOOST_REQUIRE_EQUAL(actual_fib.value(), *expected_value); + } else { + BOOST_REQUIRE(!actual_fib.has_value()); + } + co_return fib; +} + +template class Container> +seastar::future<> test_async_generator_drained_incrementally() { + unsigned count = 17; + auto expected_fibs = gen_expected_fibs(count); + auto fib = fibonacci_sequence(coroutine::experimental::buffer_size_t{2}, count); + for (auto it = expected_fibs.begin(); it != expected_fibs.end(); ++it) { + fib = co_await test_async_generator_drained_incrementally(std::move(fib), *it); + } + fib = co_await test_async_generator_drained_incrementally(std::move(fib), std::nullopt); + // once drained generator return std::nullopt + fib = co_await test_async_generator_drained_incrementally(std::move(fib), std::nullopt); +} + +SEASTAR_TEST_CASE(test_async_generator_drained_incrementally_buffered) { + return test_async_generator_drained_incrementally(); +} + +SEASTAR_TEST_CASE(test_async_generator_drained_incrementally_unbuffered) { + return test_async_generator_drained_incrementally(); +} + template class Container> seastar::future<> test_async_generator_not_drained() { auto fib = fibonacci_sequence(coroutine::experimental::buffer_size_t{2},