diff --git a/kagen/edge_range.cpp b/kagen/edge_range.cpp index 66a4f42..ed00079 100644 --- a/kagen/edge_range.cpp +++ b/kagen/edge_range.cpp @@ -96,6 +96,27 @@ EdgeRange::iterator EdgeRange::iterator::operator++(int) { return tmp; } +EdgeRange::iterator& EdgeRange::iterator::operator--() { + assert(parent_ != nullptr); + + if (parent_->representation_ == GraphRepresentation::EDGE_LIST) { + assert(idx_ > 0 && "Cannot decrement begin iterator"); + --idx_; + load_current(); + return *this; + } + + retreat_to_prev_valid_csr(); + load_current(); + return *this; +} + +EdgeRange::iterator EdgeRange::iterator::operator--(int) { + iterator tmp = *this; + --(*this); + return tmp; +} + bool EdgeRange::iterator::operator==(const iterator& other) const noexcept { // must belong to same parent to compare reliably if (parent_ != other.parent_) @@ -178,6 +199,51 @@ void EdgeRange::iterator::advance_to_next_valid_csr() noexcept { } } +void EdgeRange::iterator::retreat_to_prev_valid_csr() noexcept { + const auto& xadj = *parent_->xadj_ptr_; + const auto& adjncy = *parent_->adjncy_ptr_; + const std::size_t n_local = xadj.empty() ? 0 : (xadj.size() - 1); + + // If we're at end(), move to the last valid edge + if (off_ >= adjncy.size()) { + assert(adjncy.size() > 0 && "Cannot decrement begin iterator"); + off_ = adjncy.size() - 1; + // Find which vertex this edge belongs to + for (std::size_t v = 0; v < n_local; ++v) { + if (xadj[v] <= off_ && off_ < xadj[v + 1]) { + u_ = v; + return; + } + } + // Should not reach here if data is consistent + assert(false && "Invalid state in retreat_to_prev_valid_csr"); + return; + } + + // Check if we can move back within the current vertex's adjacency list + if (off_ > 0 && xadj[u_] < off_) { + --off_; + return; + } + + // Move to the previous vertex + assert(u_ > 0 && "Cannot decrement begin iterator"); + --u_; + + // Find the last edge of the previous vertex (skip empty vertices going backward) + while (u_ < n_local && xadj[u_] >= xadj[u_ + 1]) { + if (u_ == 0) { + // No valid edges before this point + assert(false && "Cannot decrement begin iterator"); + return; + } + --u_; + } + + // Set off_ to the last edge of vertex u_ + off_ = xadj[u_ + 1] - 1; +} + EdgeRange::iterator EdgeRange::begin() const noexcept { if (representation_ == GraphRepresentation::EDGE_LIST) return iterator::edgelist_begin(this); diff --git a/kagen/edge_range.h b/kagen/edge_range.h index 1b92507..7e965af 100644 --- a/kagen/edge_range.h +++ b/kagen/edge_range.h @@ -21,7 +21,7 @@ class EdgeRange { class iterator { public: - using iterator_category = std::forward_iterator_tag; + using iterator_category = std::bidirectional_iterator_tag; using value_type = Edge; using difference_type = std::ptrdiff_t; @@ -42,6 +42,10 @@ class EdgeRange { iterator operator++(int); + iterator& operator--(); + + iterator operator--(int); + bool operator==(const iterator& other) const noexcept; bool operator!=(const iterator& other) const noexcept; @@ -60,6 +64,7 @@ class EdgeRange { void load_current() noexcept; void init_csr_begin() noexcept; void advance_to_next_valid_csr() noexcept; + void retreat_to_prev_valid_csr() noexcept; }; iterator begin() const noexcept; diff --git a/tests/edge_range.cpp b/tests/edge_range.cpp index 4256f8e..d6d4c4a 100644 --- a/tests/edge_range.cpp +++ b/tests/edge_range.cpp @@ -146,3 +146,85 @@ TEST_P(EdgeRangeTestFixture, iterate_sparse_csr_representation) { std::iota(expected_indices.begin(), expected_indices.end(), 0); EXPECT_THAT(iterators, Pointwise(EdgeIndexMatches(), expected_indices)); } + +TEST_P(EdgeRangeTestFixture, bidirectional_iteration_edgelist) { + using ::testing::ElementsAreArray; + + auto [name, generate] = GetParam(); + const SInt n = 100; + const SInt m = 10 * n; + + kagen::KaGen generator(MPI_COMM_WORLD); + generator.UseEdgeListRepresentation(); + Graph graph = generate(generator, n, m); + + Edgelist expected = graph.edges; + EdgeRange edge_range(graph); + + if (expected.empty()) { + return; + } + + // Forward then backward iteration - edges should match when reversed + std::vector forward_edges(edge_range.begin(), edge_range.end()); + std::vector backward_edges; + auto it = edge_range.end(); + while (it != edge_range.begin()) { + --it; + backward_edges.push_back(*it); + } + std::reverse(backward_edges.begin(), backward_edges.end()); + EXPECT_THAT(backward_edges, ElementsAreArray(forward_edges)); + + // Mixed forward and backward iteration + auto it2 = edge_range.begin(); + ++it2; + ++it2; + auto edge_at_2 = *it2; + --it2; + auto edge_at_1 = *it2; + ++it2; + EXPECT_EQ(*it2, edge_at_2); + EXPECT_EQ(edge_at_1, expected[1]); +} + +TEST_P(EdgeRangeTestFixture, bidirectional_iteration_csr) { + using ::testing::ElementsAreArray; + + auto [name, generate] = GetParam(); + const SInt n = 100; + const SInt m = 10 * n; + + kagen::KaGen generator(MPI_COMM_WORLD); + generator.UseCSRRepresentation(); + Graph graph = generate(generator, n, m); + + Edgelist expected = BuildEdgeListFromCSR(graph.vertex_range, graph.xadj, graph.adjncy); + EdgeRange edge_range(graph); + + if (expected.empty()) { + return; + } + + // Forward then backward iteration - edges should match when reversed + std::vector forward_edges(edge_range.begin(), edge_range.end()); + std::vector backward_edges; + auto it = edge_range.end(); + while (it != edge_range.begin()) { + --it; + backward_edges.push_back(*it); + } + std::reverse(backward_edges.begin(), backward_edges.end()); + EXPECT_THAT(backward_edges, ElementsAreArray(forward_edges)); + + // Mixed forward and backward iteration + auto it2 = edge_range.begin(); + ++it2; + ++it2; + auto edge_at_2 = *it2; + --it2; + auto edge_at_1 = *it2; + ++it2; + EXPECT_EQ(*it2, edge_at_2); + EXPECT_EQ(edge_at_1, expected[1]); +}