Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions kagen/edge_range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,27 @@ EdgeRange::iterator EdgeRange::iterator::operator++(int) {
return tmp;
}

EdgeRange::iterator& EdgeRange::iterator::operator--() {
assert(parent_ != nullptr);

if (parent_->representation_ == GraphRepresentation::EDGE_LIST) {
assert(idx_ > 0 && "Cannot decrement begin iterator");
--idx_;
load_current();
return *this;
}

retreat_to_prev_valid_csr();
load_current();
return *this;
}

EdgeRange::iterator EdgeRange::iterator::operator--(int) {
iterator tmp = *this;
--(*this);
return tmp;
}

bool EdgeRange::iterator::operator==(const iterator& other) const noexcept {
// must belong to same parent to compare reliably
if (parent_ != other.parent_)
Expand Down Expand Up @@ -178,6 +199,51 @@ void EdgeRange::iterator::advance_to_next_valid_csr() noexcept {
}
}

void EdgeRange::iterator::retreat_to_prev_valid_csr() noexcept {
const auto& xadj = *parent_->xadj_ptr_;
const auto& adjncy = *parent_->adjncy_ptr_;
const std::size_t n_local = xadj.empty() ? 0 : (xadj.size() - 1);

// If we're at end(), move to the last valid edge
if (off_ >= adjncy.size()) {
assert(adjncy.size() > 0 && "Cannot decrement begin iterator");
off_ = adjncy.size() - 1;
// Find which vertex this edge belongs to
for (std::size_t v = 0; v < n_local; ++v) {
if (xadj[v] <= off_ && off_ < xadj[v + 1]) {
u_ = v;
return;
}
}
// Should not reach here if data is consistent
assert(false && "Invalid state in retreat_to_prev_valid_csr");
return;
}

// Check if we can move back within the current vertex's adjacency list
if (off_ > 0 && xadj[u_] < off_) {
--off_;
return;
}

// Move to the previous vertex
assert(u_ > 0 && "Cannot decrement begin iterator");
--u_;

// Find the last edge of the previous vertex (skip empty vertices going backward)
while (u_ < n_local && xadj[u_] >= xadj[u_ + 1]) {
if (u_ == 0) {
// No valid edges before this point
assert(false && "Cannot decrement begin iterator");
return;
}
--u_;
}

// Set off_ to the last edge of vertex u_
off_ = xadj[u_ + 1] - 1;
}

EdgeRange::iterator EdgeRange::begin() const noexcept {
if (representation_ == GraphRepresentation::EDGE_LIST)
return iterator::edgelist_begin(this);
Expand Down
7 changes: 6 additions & 1 deletion kagen/edge_range.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class EdgeRange {

class iterator {
public:
using iterator_category = std::forward_iterator_tag;
using iterator_category = std::bidirectional_iterator_tag;
using value_type = Edge;
using difference_type = std::ptrdiff_t;

Expand All @@ -42,6 +42,10 @@ class EdgeRange {

iterator operator++(int);

iterator& operator--();

iterator operator--(int);

bool operator==(const iterator& other) const noexcept;
bool operator!=(const iterator& other) const noexcept;

Expand All @@ -60,6 +64,7 @@ class EdgeRange {
void load_current() noexcept;
void init_csr_begin() noexcept;
void advance_to_next_valid_csr() noexcept;
void retreat_to_prev_valid_csr() noexcept;
};

iterator begin() const noexcept;
Expand Down
82 changes: 82 additions & 0 deletions tests/edge_range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,85 @@ TEST_P(EdgeRangeTestFixture, iterate_sparse_csr_representation) {
std::iota(expected_indices.begin(), expected_indices.end(), 0);
EXPECT_THAT(iterators, Pointwise(EdgeIndexMatches(), expected_indices));
}

TEST_P(EdgeRangeTestFixture, bidirectional_iteration_edgelist) {
using ::testing::ElementsAreArray;

auto [name, generate] = GetParam();
const SInt n = 100;
const SInt m = 10 * n;

kagen::KaGen generator(MPI_COMM_WORLD);
generator.UseEdgeListRepresentation();
Graph graph = generate(generator, n, m);

Edgelist expected = graph.edges;
EdgeRange edge_range(graph);

if (expected.empty()) {
return;
}

// Forward then backward iteration - edges should match when reversed
std::vector<EdgeRange::Edge> forward_edges(edge_range.begin(), edge_range.end());
std::vector<EdgeRange::Edge> backward_edges;
auto it = edge_range.end();
while (it != edge_range.begin()) {
--it;
backward_edges.push_back(*it);
}
std::reverse(backward_edges.begin(), backward_edges.end());
EXPECT_THAT(backward_edges, ElementsAreArray(forward_edges));

// Mixed forward and backward iteration
auto it2 = edge_range.begin();
++it2;
++it2;
auto edge_at_2 = *it2;
--it2;
auto edge_at_1 = *it2;
++it2;
EXPECT_EQ(*it2, edge_at_2);
EXPECT_EQ(edge_at_1, expected[1]);
}

TEST_P(EdgeRangeTestFixture, bidirectional_iteration_csr) {
using ::testing::ElementsAreArray;

auto [name, generate] = GetParam();
const SInt n = 100;
const SInt m = 10 * n;

kagen::KaGen generator(MPI_COMM_WORLD);
generator.UseCSRRepresentation();
Graph graph = generate(generator, n, m);

Edgelist expected = BuildEdgeListFromCSR(graph.vertex_range, graph.xadj, graph.adjncy);
EdgeRange edge_range(graph);

if (expected.empty()) {
return;
}

// Forward then backward iteration - edges should match when reversed
std::vector<EdgeRange::Edge> forward_edges(edge_range.begin(), edge_range.end());
std::vector<EdgeRange::Edge> backward_edges;
auto it = edge_range.end();
while (it != edge_range.begin()) {
--it;
backward_edges.push_back(*it);
}
std::reverse(backward_edges.begin(), backward_edges.end());
EXPECT_THAT(backward_edges, ElementsAreArray(forward_edges));

// Mixed forward and backward iteration
auto it2 = edge_range.begin();
++it2;
++it2;
auto edge_at_2 = *it2;
--it2;
auto edge_at_1 = *it2;
++it2;
EXPECT_EQ(*it2, edge_at_2);
EXPECT_EQ(edge_at_1, expected[1]);
}