diff --git a/src/pregel.rs b/src/pregel.rs index 85fed18..a07ffd2 100644 --- a/src/pregel.rs +++ b/src/pregel.rs @@ -37,6 +37,51 @@ impl AsRef for ColumnIdentifier { } } +/// This defines a struct `SendMessage` in Rust. It has two properties: +/// `message_direction` and `send_message`. The `message_direction` property +/// is the identifier for the direction of the message. The `send_message` +/// property is the function that determines which messages to send from a +/// vertex to its neighbors. +pub struct SendMessage { + /// `message_direction` is the identifier for the direction of the message. + pub message_direction: Expr, + /// `send_message` is the function that determines which messages to send from a + /// vertex to its neighbors. + pub send_message: Expr, +} + +impl SendMessage { + /// The function creates a new instance of the `SendMessage` struct with the + /// specified message direction and send message expression. + /// + /// Arguments: + /// + /// * `message_direction`: An enum that specifies whether the message should be sent + /// to the source vertex or the destination vertex of an edge. + /// * `send_message`: `send_message` is an expression that represents the message + /// that will be sent from a vertex to its neighbors during the Pregel computation. + /// It can be any valid Rust expression that evaluates to a DataFrame. + /// + /// Returns: + /// + /// A new instance of the `SendMessage` struct. + pub fn new(message_direction: MessageReceiver, send_message: Expr) -> Self { + // We make this in this manner because we want to use the `src.id` and `edge.dst` columns + // in the send_messages function. This is because how polars works, when joining DataFrames, + // it will keep only the left-hand side of the joins, thus, we need to use the `src.id` and + // `edge.dst` columns to get the correct vertex IDs. + let message_direction = match message_direction { + MessageReceiver::Src => Pregel::src(ColumnIdentifier::Id), + MessageReceiver::Dst => Pregel::edge(ColumnIdentifier::Dst), + }; + // Now we create the `SendMessage` struct with everything set up. + SendMessage { + message_direction, + send_message, + } + } +} + /// The Pregel struct represents a Pregel computation with various parameters and /// expressions. /// @@ -90,11 +135,13 @@ pub struct Pregel { /// `initial_message` is an expression that defines the initial message that /// each vertex in the graph will receive before the computation starts. initial_message: Expr, - /// `send_messages` is a tuple containing two expressions. The first expression - /// determines whether the message will go from Src to Dst or vice-versa. The - /// second expression represents the message sending function that determines - /// which messages to send from a vertex to its neighbors. - send_messages: (Expr, Expr), + /// The `send_messages` property is a vector of `SendMessage` structs that represent + /// the message sending functions. The `SendMessage` struct contains two expressions. + /// The first expression represents the message sending function that determines whether + /// the message will go from Src to Dst or vice-versa. The second expression represents + /// the message sending function that determines which messages to send from a + /// vertex to its neighbors. + send_messages: Vec, /// `aggregate_messages` is an expression that defines how messages sent to a /// vertex should be aggregated. In Pregel, messages are sent from one vertex /// to another and can be aggregated before being processed by the receiving @@ -162,11 +209,13 @@ pub struct PregelBuilder { /// `initial_message` is an expression that defines the initial message that /// each vertex in the graph will receive before the computation starts. initial_message: Expr, - /// `send_messages` is a tuple containing two expressions. The first expression - /// determines whether the message will go from Src to Dst or vice-versa. The - /// second expression represents the message sending function that determines - /// which messages to send from a vertex to its neighbors. - send_messages: (Expr, Expr), + /// The `send_messages` property is a vector of `SendMessage` structs that represent + /// the message sending functions. The `SendMessage` struct contains two expressions. + /// The first expression represents the message sending function that determines whether + /// the message will go from Src to Dst or vice-versa. The second expression represents + /// the message sending function that determines which messages to send from a + /// vertex to its neighbors. + send_messages: Vec, /// `aggregate_messages` is an expression that defines how messages sent to a /// vertex should be aggregated. In Pregel, messages are sent from one vertex /// to another and can be aggregated before being processed by the receiving @@ -219,7 +268,7 @@ impl PregelBuilder { max_iterations: 10, vertex_column: ColumnIdentifier::Custom("aux"), initial_message: Default::default(), - send_messages: (Default::default(), Default::default()), + send_messages: Default::default(), aggregate_messages: Default::default(), v_prog: Default::default(), } @@ -289,7 +338,45 @@ impl PregelBuilder { } /// This function sets the message sending behavior for a Pregel computation in - /// Rust. + /// Rust. Chaining this method allows for multiple message sending behaviors to be + /// specified for a single Pregel computation. + /// + /// # Examples + /// + /// ```rust + /// use polars::prelude::*; + /// use pregel_rs::graph_frame::GraphFrame; + /// use pregel_rs::pregel::ColumnIdentifier::{Custom, Dst, Id, Src}; + /// use pregel_rs::pregel::{MessageReceiver, Pregel, PregelBuilder}; + /// use std::error::Error; + /// + /// // Simple example of a Pregel algorithm where we chain several `send_messages` calls. In + /// // this example, we send a message to the source of an edge and then to the destination of + /// // the same edge. It has no real use case, but it demonstrates how to chain multiple calls. + /// fn main() -> Result<(), Box> { + /// let edges = df![ + /// Src.as_ref() => [0, 1, 1, 2, 2, 3], + /// Dst.as_ref() => [1, 0, 3, 1, 3, 2], + /// ]?; + /// + /// let vertices = df![ + /// Id.as_ref() => [0, 1, 2, 3], + /// Custom("value").as_ref() => [3, 6, 2, 1], + /// ]?; + /// + /// let pregel = PregelBuilder::new(GraphFrame::new(vertices, edges)?) + /// .max_iterations(4) + /// .with_vertex_column(Custom("aux")) + /// .initial_message(lit(0)) + /// .send_messages(MessageReceiver::Src, lit(1)) + /// .send_messages(MessageReceiver::Dst, lit(-1)) + /// .aggregate_messages(Pregel::msg(None).sum()) + /// .v_prog(Pregel::msg(None) + lit(1)) + /// .build(); + /// + /// Ok(println!("{:?}", pregel.run())) + /// } + /// ``` /// /// Arguments: /// @@ -298,8 +385,7 @@ impl PregelBuilder { /// computation. /// * `send_messages`: `send_messages` is a parameter of type `Expr`. It is used to /// specify the function that will be applied to each vertex in the graph to send - /// messages to its neighboring vertices. The `send_messages` function takes two - /// arguments: the first argument is the vertex ID of the current vertex, and + /// messages to its neighboring vertices. /// /// Returns: /// @@ -308,16 +394,7 @@ impl PregelBuilder { /// multiple methods can be called on the same struct instance in a single /// expression. pub fn send_messages(mut self, to: MessageReceiver, send_messages: Expr) -> Self { - // We make this in this manner because we want to use the `src.id` and `edge.dst` columns - // in the send_messages function. This is because how polars works, when joining dataframes, - // it will keep only the left-hand side of the joins, thus, we need to use the `src.id` and - // `edge.dst` columns to get the correct vertex IDs. - let to = match to { - MessageReceiver::Src => Pregel::src(ColumnIdentifier::Id), - MessageReceiver::Dst => Pregel::edge(ColumnIdentifier::Dst), - }; - // Now we can set the send_messages field of the struct to the provided expression. - self.send_messages = (to, send_messages); + self.send_messages.push(SendMessage::new(to, send_messages)); self } @@ -563,11 +640,22 @@ impl Pregel { // We create a tuple where we store the column names of the `send_messages` DataFrame. We use // the `alias` method to ensure that the column names are properly qualified. We also // do the same for the `aggregate_messages` Expr. And the same with the `v_prog` Expr. - let (send_messages_ids, send_messages_msg) = self.send_messages; - let (send_messages_ids, send_messages_msg) = ( - send_messages_ids.alias(&Self::alias(&ColumnIdentifier::Msg, ColumnIdentifier::Id)), - send_messages_msg.alias(ColumnIdentifier::Pregel.as_ref()), - ); + let (mut send_messages_ids, mut send_messages_msg): (Vec, Vec) = self + .send_messages + .iter() + .map(|send_message| { + let message_direction = &send_message.message_direction; + let send_message_expr = &send_message.send_message; + ( + message_direction + .to_owned() + .alias(&Self::alias(&ColumnIdentifier::Msg, ColumnIdentifier::Id)), + send_message_expr + .to_owned() + .alias(ColumnIdentifier::Pregel.as_ref()), + ) + }) + .unzip(); let aggregate_messages = self .aggregate_messages .alias(ColumnIdentifier::Pregel.as_ref()); @@ -628,14 +716,12 @@ impl Pregel { // are computed by performing an aggregation on the `triplets_df` DataFrame. The aggregation // is performed on the `msg` column of the `triplets_df` DataFrame, and the aggregation // function is the one set by the user at the initialization of the model. - let sends_messages_ids_df = &send_messages_ids; - let send_messages_msg_df = &send_messages_msg; + let send_messages = &mut send_messages_ids; // we create a mutable reference to the `send_messages_ids` Vector + let send_messages_msg_df = &mut send_messages_msg; // we create a mutable reference to the `send_messages_msg` Vector + send_messages.append(send_messages_msg_df); // we append the `send_messages_msg` Vector to the `send_messages` Vector let aggregate_messages_df = &aggregate_messages; let message_df = triplets_df - .select(vec![ - sends_messages_ids_df.to_owned(), - send_messages_msg_df.to_owned(), - ]) + .select(send_messages) .groupby([Self::msg(Some(ColumnIdentifier::Id))]) .agg([aggregate_messages_df.to_owned()]); // We Compute the new values for the vertices. Note that we have to check for possibly @@ -685,29 +771,52 @@ impl Pregel { #[cfg(test)] mod tests { use crate::graph_frame::GraphFrame; - use crate::pregel::ColumnIdentifier::{Custom, Dst, Id, Src}; - use crate::pregel::{MessageReceiver, Pregel, PregelBuilder}; + use crate::pregel::{ColumnIdentifier, MessageReceiver, Pregel, PregelBuilder, SendMessage}; use polars::prelude::*; use std::error::Error; - fn pagerank_builder(iterations: u8) -> Result> { - let edges = df![ - Src.as_ref() => [0, 0, 1, 2, 3, 4, 4, 4], - Dst.as_ref() => [1, 2, 2, 3, 3, 1, 2, 3], - ]?; + fn pagerank_graph() -> Result { + let edges = match df![ + ColumnIdentifier::Src.as_ref() => [0, 0, 1, 2, 3, 4, 4, 4], + ColumnIdentifier::Dst.as_ref() => [1, 2, 2, 3, 3, 1, 2, 3], + ] { + Ok(edges) => edges, + Err(_) => return Err(String::from("Error creating the edges DataFrame")), + }; - let vertices = GraphFrame::from_edges(edges.clone())?.out_degrees()?; + let graph = match GraphFrame::from_edges(edges.clone()) { + Ok(graph) => graph, + Err(_) => return Err(String::from("Error creating the vertices DataFrame")), + }; + + let vertices = match graph.out_degrees() { + Ok(vertices) => vertices, + Err(_) => { + return Err(String::from( + "Error creating the vertices out degree DataFrame", + )) + } + }; + match GraphFrame::new(vertices, edges) { + Ok(graph) => Ok(graph), + Err(_) => Err(String::from("Error creating the graph")), + } + } + + fn pagerank_builder(iterations: u8) -> Result> { + let graph = pagerank_graph()?; let damping_factor = 0.85; - let num_vertices: f64 = vertices.column(Id.as_ref())?.len() as f64; + let num_vertices: f64 = graph.vertices.column(ColumnIdentifier::Id.as_ref())?.len() as f64; - Ok(PregelBuilder::new(GraphFrame::new(vertices, edges)?) + Ok(PregelBuilder::new(graph) .max_iterations(iterations) - .with_vertex_column(Custom("rank")) + .with_vertex_column(ColumnIdentifier::Custom("rank")) .initial_message(lit(1.0 / num_vertices)) .send_messages( MessageReceiver::Dst, - Pregel::src(Custom("rank")) / Pregel::src(Custom("out_degree")), + Pregel::src(ColumnIdentifier::Custom("rank")) + / Pregel::src(ColumnIdentifier::Custom("out_degree")), ) .aggregate_messages(Pregel::msg(None).sum()) .v_prog( @@ -777,16 +886,16 @@ mod tests { fn max_value_graph() -> Result { let edges = match df![ - Src.as_ref() => [0, 1, 1, 2, 2, 3], - Dst.as_ref() => [1, 0, 3, 1, 3, 2], + ColumnIdentifier::Src.as_ref() => [0, 1, 1, 2, 2, 3], + ColumnIdentifier::Dst.as_ref() => [1, 0, 3, 1, 3, 2], ] { Ok(edges) => edges, Err(_) => return Err(String::from("Error creating the edges DataFrame")), }; let vertices = match df![ - Id.as_ref() => [0, 1, 2, 3], - Custom("value").as_ref() => [3, 6, 2, 1], + ColumnIdentifier::Id.as_ref() => [0, 1, 2, 3], + ColumnIdentifier::Custom("value").as_ref() => [3, 6, 2, 1], ] { Ok(vertices) => vertices, Err(_) => return Err(String::from("Error creating the vertices DataFrame")), @@ -802,14 +911,17 @@ mod tests { Ok(Pregel { graph: max_value_graph()?, max_iterations: iterations, - vertex_column: Custom("max_value"), - initial_message: col(Custom("value").as_ref()), - send_messages: ( - Pregel::edge(MessageReceiver::into(MessageReceiver::Dst)), - Pregel::src(Custom("max_value")), - ), + vertex_column: ColumnIdentifier::Custom("max_value"), + initial_message: col(ColumnIdentifier::Custom("value").as_ref()), + send_messages: vec![SendMessage::new( + MessageReceiver::Dst, + Pregel::src(ColumnIdentifier::Custom("value")), + )], aggregate_messages: Pregel::msg(None).max(), - v_prog: max_exprs([col(Custom("max_value").as_ref()), Pregel::msg(None)]), + v_prog: max_exprs([ + col(ColumnIdentifier::Custom("max_value").as_ref()), + Pregel::msg(None), + ]), }) } @@ -874,7 +986,7 @@ mod tests { // useful to test the Pregel model. match PregelBuilder::new(graph) .max_iterations(4) - .with_vertex_column(Custom("does_not_matter")) + .with_vertex_column(ColumnIdentifier::Custom("aux")) .initial_message(lit(0)) // we pass the Undefined state to all vertices .send_messages(MessageReceiver::Src, lit(0)) .aggregate_messages(lit(0)) @@ -886,4 +998,42 @@ mod tests { Err(_) => Err(String::from("Error running the algorithm")), } } + + #[test] + fn test_send_messages_src_dst() -> Result<(), String> { + let graph = pagerank_graph()?; + + let pregel = match PregelBuilder::new(graph) + .max_iterations(4) + .with_vertex_column(ColumnIdentifier::Custom("aux")) + .initial_message(lit(0)) + .send_messages(MessageReceiver::Src, lit(1)) + .send_messages(MessageReceiver::Dst, lit(-1)) + .aggregate_messages(Pregel::msg(None).sum()) + .v_prog(Pregel::msg(None) + lit(1)) + .build() + .run() + { + Ok(pregel) => pregel, + Err(_) => return Err(String::from("Error running pregel")), + }; + + let sorted_pregel = match pregel.sort(&["id"], false) { + Ok(sorted_pregel) => sorted_pregel, + Err(_) => return Err(String::from("Error sorting the DataFrame")), + }; + + let ans = match sorted_pregel.column("aux") { + Ok(ans) => ans, + Err(_) => return Err(String::from("Error retrieving the column")), + }; + + let expected = Series::new("aux", [3, 2, 2, 2, 4]); + + if ans.eq(&expected) { + Ok(()) + } else { + Err(String::from("The resulting DataFrame is not correct")) + } + } }