Skip to content

Commit

Permalink
fixing some issues
Browse files Browse the repository at this point in the history
  • Loading branch information
angelip2303 committed Apr 14, 2023
1 parent 56f0dcc commit f21f29e
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 81 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pregel-rs"
version = "0.0.3"
version = "0.0.4"
authors = [ "Ángel Iglesias Préstamo <[email protected]>" ]
description = "A Graph library written in Rust for implementing your own algorithms in a Pregel fashion"
documentation = "https://docs.rs/crate/pregel-rs/latest"
Expand Down
13 changes: 5 additions & 8 deletions examples/maximum_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,17 @@ fn main() -> Result<(), Box<dyn Error>> {

let vertices = df![
Id.as_ref() => [0, 1, 2, 3],
Custom("value".to_owned()).as_ref() => [3, 6, 2, 1],
Custom("value").as_ref() => [3, 6, 2, 1],
]?;

let pregel = PregelBuilder::new(GraphFrame::new(vertices, edges)?)
.max_iterations(4)
.with_vertex_column(Custom("max_value".to_owned()))
.initial_message(col(Custom("value".to_owned()).as_ref()))
.send_messages(
MessageReceiver::Dst,
Pregel::src(Custom("max_value".to_owned())),
)
.with_vertex_column(Custom("max_value"))
.initial_message(col(Custom("value").as_ref()))
.send_messages(MessageReceiver::Dst, Pregel::src(Custom("max_value")))
.aggregate_messages(Pregel::msg(None).max())
.v_prog(max_exprs([
col(Custom("max_value".to_owned()).as_ref()),
col(Custom("max_value").as_ref()),
Pregel::msg(None),
]))
.build();
Expand Down
4 changes: 2 additions & 2 deletions examples/pagerank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ fn main() -> Result<(), Box<dyn Error>> {

let pregel = PregelBuilder::new(GraphFrame::new(vertices, edges)?)
.max_iterations(4)
.with_vertex_column(Custom("rank".to_owned()))
.with_vertex_column(Custom("rank"))
.initial_message(lit(1.0 / num_vertices))
.send_messages(
MessageReceiver::Dst,
Pregel::src(Custom("rank".to_owned())) / Pregel::src(Custom("out_degree".to_owned())),
Pregel::src(Custom("rank")) / Pregel::src(Custom("out_degree")),
)
.aggregate_messages(Pregel::msg(None).sum())
.v_prog(
Expand Down
9 changes: 2 additions & 7 deletions src/graph_frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ type Result<T> = std::result::Result<T, GraphFrameError>;
/// `FromPolars` and `MissingColumn`.
#[derive(Debug)]
pub enum GraphFrameError {
/// `DuckDbError` is a variant of `GraphFrameError` that represents errors that
/// occur when working with the DuckDB database.
DuckDbError(&'static str),
/// `FromPolars` is a variant of `GraphFrameError` that represents errors that
/// occur when converting from a `PolarsError`.
FromPolars(PolarsError),
Expand All @@ -41,7 +38,6 @@ pub enum GraphFrameError {
impl Display for GraphFrameError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
GraphFrameError::DuckDbError(error) => Display::fmt(error, f),
GraphFrameError::FromPolars(error) => Display::fmt(error, f),
GraphFrameError::MissingColumn(error) => Display::fmt(error, f),
}
Expand All @@ -51,7 +47,6 @@ impl Display for GraphFrameError {
impl error::Error for GraphFrameError {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match *self {
GraphFrameError::DuckDbError(_) => None,
GraphFrameError::FromPolars(ref e) => Some(e),
GraphFrameError::MissingColumn(_) => None,
}
Expand Down Expand Up @@ -173,7 +168,7 @@ impl GraphFrame {
self.edges
.lazy()
.groupby([col(Src.as_ref()).alias(Id.as_ref())])
.agg([count().alias(Custom("out_degree".to_owned()).as_ref())])
.agg([count().alias(Custom("out_degree").as_ref())])
.collect()
}

Expand All @@ -192,7 +187,7 @@ impl GraphFrame {
self.edges
.lazy()
.groupby([col(Dst.as_ref())])
.agg([count().alias(Custom("in_degree".to_owned()).as_ref())])
.agg([count().alias(Custom("in_degree").as_ref())])
.collect()
}
}
Expand Down
149 changes: 86 additions & 63 deletions src/pregel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub enum ColumnIdentifier {
/// The `Id` variant represents the column that contains the vertex IDs.
Id,
/// The `Src` variant represents the column that contains the source vertex IDs.
Src,
Src, // TODO: check if the repeated naming is just fine or not
/// The `Dst` variant represents the column that contains the destination vertex IDs.
Dst,
/// The `Edge` variant represents the column that contains the edge IDs.
Expand All @@ -20,21 +20,7 @@ pub enum ColumnIdentifier {
/// The `Pregel` variant represents the column that contains the messages sent to a vertex.
Pregel,
/// The `Custom` variant represents a column that is not one of the predefined columns.
Custom(String),
}

impl From<String> for ColumnIdentifier {
fn from(value: String) -> Self {
match &*value {
"id" => ColumnIdentifier::Id,
"src" => ColumnIdentifier::Src,
"dst" => ColumnIdentifier::Dst,
"edge" => ColumnIdentifier::Edge,
"msg" => ColumnIdentifier::Msg,
"_pregel_msg_" => ColumnIdentifier::Pregel,
_ => ColumnIdentifier::Custom(value),
}
}
Custom(&'static str),
}

impl AsRef<str> for ColumnIdentifier {
Expand Down Expand Up @@ -231,7 +217,7 @@ impl PregelBuilder {
PregelBuilder {
graph,
max_iterations: 10,
vertex_column: ColumnIdentifier::Custom("aux".to_owned()),
vertex_column: ColumnIdentifier::Custom("aux"),
initial_message: Default::default(),
send_messages: (Default::default(), Default::default()),
aggregate_messages: Default::default(),
Expand Down Expand Up @@ -322,7 +308,15 @@ impl PregelBuilder {
/// multiple methods can be called on the same struct instance in a single
/// expression.
pub fn send_messages(mut self, to: MessageReceiver, send_messages: Expr) -> Self {
self.send_messages = (Pregel::edge(MessageReceiver::into(to)), send_messages);
// We resolve the receiver explicitly because the target vertex ID has to be read from
// either the `src.id` or the `edge.dst` column. This is due to how polars works: when
// joining DataFrames it keeps only the left-hand side key columns of each join. Thus,
// `src.id` and `edge.dst` are the columns that hold the correct vertex IDs.
let to = match to {
MessageReceiver::Src => Pregel::src(ColumnIdentifier::Id),
MessageReceiver::Dst => Pregel::edge(ColumnIdentifier::Dst),
};
self.send_messages = (to, send_messages);
self
}

Expand Down Expand Up @@ -395,16 +389,16 @@ impl PregelBuilder {
///
/// let vertices = df![
/// Id.as_ref() => [0, 1, 2, 3],
/// Custom("value".to_owned()).as_ref() => [3, 6, 2, 1],
/// Custom("value").as_ref() => [3, 6, 2, 1],
/// ]?;
///
/// let pregel = PregelBuilder::new(GraphFrame::new(vertices, edges)?)
/// .max_iterations(4)
/// .with_vertex_column(Custom("max_value".to_owned()))
/// .initial_message(col(Custom("value".to_owned()).as_ref()))
/// .send_messages(MessageReceiver::Dst, Pregel::src(Custom("max_value".to_owned())))
/// .with_vertex_column(Custom("max_value"))
/// .initial_message(col(Custom("value").as_ref()))
/// .send_messages(MessageReceiver::Dst, Pregel::src(Custom("max_value")))
/// .aggregate_messages(Pregel::msg(None).max())
/// .v_prog(max_exprs([col(Custom("max_value".to_owned()).as_ref()), Pregel::msg(None)]))
/// .v_prog(max_exprs([col(Custom("max_value").as_ref()), Pregel::msg(None)]))
/// .build();
///
/// Ok(println!("{}", pregel.run()?))
Expand Down Expand Up @@ -434,13 +428,6 @@ impl Pregel {
format!("{}.{}", prefix.as_ref(), column_name.as_ref())
}

/// Re-aliases every column produced by `expr` through `Self::alias`, using
/// `prefix` as the qualifier.
///
/// Each incoming column name is first converted into a `ColumnIdentifier`
/// (so reserved names resolve to their dedicated variants) and then handed to
/// `Self::alias` together with `prefix`.
fn prefix_columns(expr: Expr, prefix: &'static ColumnIdentifier) -> Expr {
    expr.map_alias(|name| {
        let identifier = ColumnIdentifier::from(name.to_string());
        let qualified = Self::alias(prefix, identifier);
        Ok(qualified)
    })
}

/// This function returns an expression for a column identifier representing
/// the source vertex in a Pregel graph.
///
Expand Down Expand Up @@ -573,12 +560,17 @@ impl Pregel {
/// the resulting `DataFrame` or an error of type `PolarsError`.
pub fn run(self) -> PolarsResult<DataFrame> {
// We create a tuple where we store the column names of the `send_messages` DataFrame. We use
// the `alias` method to ensure that the column names are properly qualified.
// the `alias` method to ensure that the column names are properly qualified. We do the
// same for the `aggregate_messages` and the `v_prog` expressions.
let (send_messages_ids, send_messages_msg) = self.send_messages;
let (send_messages_ids, send_messages_msg) = (
send_messages_ids.alias(&Self::alias(&ColumnIdentifier::Msg, ColumnIdentifier::Id)),
send_messages_msg.alias(ColumnIdentifier::Pregel.as_ref()),
);
let aggregate_messages = self
.aggregate_messages
.alias(ColumnIdentifier::Pregel.as_ref());
let v_prog = self.v_prog.alias(self.vertex_column.as_ref());
// We create a DataFrame that contains the edges of the graph. This DataFrame is used to
// compute the triplets of the graph, which are used to send messages to the neighboring
// vertices of each vertex in the graph. For us to do so, we select all the columns of the
Expand All @@ -587,7 +579,7 @@ impl Pregel {
.graph
.edges
.lazy()
.select([Self::prefix_columns(all(), &ColumnIdentifier::Edge)]);
.select([all().prefix(&format!("{}.", ColumnIdentifier::Edge.as_ref()))]);
// We create a DataFrame that contains the vertices of the graph
let vertices = &self.graph.vertices.lazy();
// We start the execution of the algorithm from the super-step 0; that is, all the nodes
Expand All @@ -607,7 +599,7 @@ impl Pregel {
// greater than the maximum number of iterations set by the user at the initialization of
// the model (see the `Pregel::new` method). We start by setting the number of iterations to 1.
let mut iteration = 1;
// TODO: check that nodes are not halted :D
// TODO: check that nodes are not halted. If so, we remove them from the `current_vertices` DataFrame.
while iteration <= self.max_iterations {
// We create a DataFrame that contains the triplets of the graph. Those triplets are
// computed by joining the `current_vertices` DataFrame with the `edges` DataFrame
Expand All @@ -618,16 +610,16 @@ impl Pregel {
let current_vertices_df = &current_vertices.lazy();
let triplets_df = current_vertices_df
.to_owned()
.select([Self::prefix_columns(all(), &ColumnIdentifier::Src)])
.select([all().prefix(&format!("{}.", ColumnIdentifier::Src.as_ref()))])
.inner_join(
edges.clone(),
edges.to_owned(),
Self::src(ColumnIdentifier::Id), // src column of the current_vertices DataFrame
Self::edge(ColumnIdentifier::Src), // src column of the edges DataFrame
)
.inner_join(
current_vertices_df
.to_owned()
.select([Self::prefix_columns(all(), &ColumnIdentifier::Dst)]),
.select([all().prefix(&format!("{}.", ColumnIdentifier::Dst.as_ref()))]),
Self::edge(ColumnIdentifier::Dst), // dst column of the resulting DataFrame
Self::dst(ColumnIdentifier::Id), // id column of the current_vertices DataFrame
);
Expand All @@ -637,7 +629,8 @@ impl Pregel {
// function is the one set by the user at the initialization of the model.
let sends_messages_ids_df = &send_messages_ids;
let send_messages_msg_df = &send_messages_msg;
let aggregate_messages_df = &self.aggregate_messages;
let aggregate_messages_df = &aggregate_messages;
println!("{:?}", triplets_df.clone().collect());
let message_df = triplets_df
.select(vec![
sends_messages_ids_df.to_owned(),
Expand All @@ -650,7 +643,7 @@ impl Pregel {
// not exist in the source DataFrame. In case we find any; for example, given a certain
// node having no incoming edges, we have to replace the null value by 0 for the aggregation
// to work properly.
let v_prog_df = &self.v_prog;
let v_prog_df = &v_prog;
let vertex_columns = current_vertices_df
.to_owned()
.outer_join(
Expand All @@ -665,18 +658,14 @@ impl Pregel {
.otherwise(Self::msg(None))
.alias(ColumnIdentifier::Pregel.as_ref()),
)
.select(
// TODO: fix this move: previous iteration of the loop. Improve?
vec![
col(ColumnIdentifier::Id.as_ref()),
v_prog_df.to_owned().alias(self.vertex_column.as_ref()),
],
);
.select(vec![
col(ColumnIdentifier::Id.as_ref()),
v_prog_df.to_owned(),
]);
// We update the `current_vertices` DataFrame with the new values for the vertices. We
// do so by performing an inner join between the `current_vertices` DataFrame and the
// `vertex_columns` DataFrame. The join is performed on the `id` column of the
// `current_vertices` DataFrame and the `id` column of the `vertex_columns` DataFrame.
// TODO: We also check if the nodes have voted to halt. If so, we remove them from the `current_vertices` DataFrame.
current_vertices = vertices
.to_owned()
.inner_join(
Expand Down Expand Up @@ -714,12 +703,11 @@ mod tests {

Ok(PregelBuilder::new(GraphFrame::new(vertices, edges)?)
.max_iterations(iterations)
.with_vertex_column(Custom("rank".to_owned()))
.with_vertex_column(Custom("rank"))
.initial_message(lit(1.0 / num_vertices))
.send_messages(
MessageReceiver::Dst,
Pregel::src(Custom("rank".to_owned()))
/ Pregel::src(Custom("out_degree".to_owned())),
Pregel::src(Custom("rank")) / Pregel::src(Custom("out_degree")),
)
.aggregate_messages(Pregel::msg(None).sum())
.v_prog(
Expand Down Expand Up @@ -787,31 +775,41 @@ mod tests {
Ok(())
}

fn max_value_builder(iterations: u8) -> Result<Pregel, Box<dyn Error>> {
let edges = df![
fn max_value_graph() -> Result<GraphFrame, String> {
let edges = match df![
Src.as_ref() => [0, 1, 1, 2, 2, 3],
Dst.as_ref() => [1, 0, 3, 1, 3, 2],
]?;
] {
Ok(edges) => edges,
Err(_) => return Err(String::from("Error creating the edges DataFrame")),
};

let vertices = df![
let vertices = match df![
Id.as_ref() => [0, 1, 2, 3],
Custom("value".to_owned()).as_ref() => [3, 6, 2, 1],
]?;
Custom("value").as_ref() => [3, 6, 2, 1],
] {
Ok(vertices) => vertices,
Err(_) => return Err(String::from("Error creating the vertices DataFrame")),
};

match GraphFrame::new(vertices, edges) {
Ok(graph) => Ok(graph),
Err(_) => Err(String::from("Error creating the graph")),
}
}

fn max_value_builder(iterations: u8) -> Result<Pregel, String> {
Ok(Pregel {
graph: GraphFrame::new(vertices, edges)?,
graph: max_value_graph()?,
max_iterations: iterations,
vertex_column: Custom("max_value".to_owned()),
initial_message: col(Custom("value".to_owned()).as_ref()),
vertex_column: Custom("max_value"),
initial_message: col(Custom("value").as_ref()),
send_messages: (
Pregel::edge(MessageReceiver::into(MessageReceiver::Dst)),
Pregel::src(Custom("max_value".to_owned())),
Pregel::src(Custom("max_value")),
),
aggregate_messages: Pregel::msg(None).max(),
v_prog: max_exprs([
col(Custom("max_value".to_owned()).as_ref()),
Pregel::msg(None),
]),
v_prog: max_exprs([col(Custom("max_value").as_ref()), Pregel::msg(None)]),
})
}

Expand Down Expand Up @@ -863,4 +861,29 @@ mod tests {

Ok(())
}

#[test]
fn test_literals() -> Result<(), String> {
    // The graph is the same fixture used by the MaxValue tests. Its contents are
    // irrelevant here; it only has to be a valid graph for the model to run on.
    let graph = max_value_graph()?;

    // Every stage of the algorithm (initial message, sent message, aggregation and
    // vertex program) is the constant literal `0`. Nothing meaningful is computed:
    // the test only checks that the Pregel model can run with plain literals and
    // with `MessageReceiver::Src` as the receiver.
    PregelBuilder::new(graph)
        .max_iterations(4)
        .with_vertex_column(Custom("does_not_matter"))
        .initial_message(lit(0)) // every vertex starts from the literal 0
        .send_messages(MessageReceiver::Src, lit(0))
        .aggregate_messages(lit(0))
        .v_prog(lit(0))
        .build()
        .run()
        // The resulting DataFrame is not inspected; only success matters.
        .map(|_| ())
        // Keep the underlying error in the message so that a failure is diagnosable.
        .map_err(|error| format!("Error running the algorithm: {}", error))
}
}

0 comments on commit f21f29e

Please sign in to comment.