Skip to content

Commit

Permalink
send_messages can now be chained
Browse files Browse the repository at this point in the history
  • Loading branch information
angelip2303 committed Apr 15, 2023
1 parent 2a383cc commit acb63d1
Showing 1 changed file with 209 additions and 59 deletions.
268 changes: 209 additions & 59 deletions src/pregel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,51 @@ impl AsRef<str> for ColumnIdentifier {
}
}

/// A single message-sending rule of a Pregel computation.
///
/// It pairs the expression that selects which vertex receives a message with the
/// expression that computes the message itself. Instances are normally created
/// through `SendMessage::new`, which derives `message_direction` from a
/// `MessageReceiver`.
pub struct SendMessage {
/// Expression that selects the ID of the vertex receiving the message.
/// `SendMessage::new` sets it to `Pregel::src(ColumnIdentifier::Id)` for
/// `MessageReceiver::Src` and to `Pregel::edge(ColumnIdentifier::Dst)` for
/// `MessageReceiver::Dst`.
pub message_direction: Expr,
/// Expression that computes the message sent from a vertex to its
/// neighbors.
pub send_message: Expr,
}

impl SendMessage {
    /// Builds a `SendMessage` from the receiver of the message and the expression
    /// that computes it.
    ///
    /// Arguments:
    ///
    /// * `message_direction`: A `MessageReceiver` stating whether the message is
    /// addressed to the source vertex (`Src`) or to the destination vertex (`Dst`)
    /// of an edge.
    /// * `send_message`: The `Expr` that computes the message sent from a vertex to
    /// its neighbors during the Pregel computation.
    ///
    /// Returns:
    ///
    /// A new `SendMessage` whose `message_direction` field holds the column
    /// expression matching the requested receiver.
    pub fn new(message_direction: MessageReceiver, send_message: Expr) -> Self {
        Self {
            // The receiver is resolved to the `src.id` or `edge.dst` column rather than
            // to a generic ID column: when polars joins DataFrames it keeps only the
            // left-hand side of the join, so these are the columns that carry the
            // correct vertex IDs.
            message_direction: match message_direction {
                MessageReceiver::Src => Pregel::src(ColumnIdentifier::Id),
                MessageReceiver::Dst => Pregel::edge(ColumnIdentifier::Dst),
            },
            send_message,
        }
    }
}

/// The Pregel struct represents a Pregel computation with various parameters and
/// expressions.
///
Expand Down Expand Up @@ -90,11 +135,13 @@ pub struct Pregel {
/// `initial_message` is an expression that defines the initial message that
/// each vertex in the graph will receive before the computation starts.
initial_message: Expr,
/// `send_messages` is a tuple containing two expressions. The first expression
/// determines whether the message will go from Src to Dst or vice-versa. The
/// second expression represents the message sending function that determines
/// which messages to send from a vertex to its neighbors.
send_messages: (Expr, Expr),
/// `send_messages` is a vector of `SendMessage` structs, one per message-sending
/// rule. Each `SendMessage` holds two expressions: the first selects the vertex
/// that receives the message (the source or the destination of an edge), and the
/// second computes the message that is sent from a vertex to its neighbors.
send_messages: Vec<SendMessage>,
/// `aggregate_messages` is an expression that defines how messages sent to a
/// vertex should be aggregated. In Pregel, messages are sent from one vertex
/// to another and can be aggregated before being processed by the receiving
Expand Down Expand Up @@ -162,11 +209,13 @@ pub struct PregelBuilder {
/// `initial_message` is an expression that defines the initial message that
/// each vertex in the graph will receive before the computation starts.
initial_message: Expr,
/// `send_messages` is a tuple containing two expressions. The first expression
/// determines whether the message will go from Src to Dst or vice-versa. The
/// second expression represents the message sending function that determines
/// which messages to send from a vertex to its neighbors.
send_messages: (Expr, Expr),
/// `send_messages` is a vector of `SendMessage` structs, one per message-sending
/// rule. Each `SendMessage` holds two expressions: the first selects the vertex
/// that receives the message (the source or the destination of an edge), and the
/// second computes the message that is sent from a vertex to its neighbors.
send_messages: Vec<SendMessage>,
/// `aggregate_messages` is an expression that defines how messages sent to a
/// vertex should be aggregated. In Pregel, messages are sent from one vertex
/// to another and can be aggregated before being processed by the receiving
Expand Down Expand Up @@ -219,7 +268,7 @@ impl PregelBuilder {
max_iterations: 10,
vertex_column: ColumnIdentifier::Custom("aux"),
initial_message: Default::default(),
send_messages: (Default::default(), Default::default()),
send_messages: Default::default(),
aggregate_messages: Default::default(),
v_prog: Default::default(),
}
Expand Down Expand Up @@ -289,7 +338,45 @@ impl PregelBuilder {
}

/// This function sets the message sending behavior for a Pregel computation in
/// Rust.
/// Rust. Chaining this method allows for multiple message sending behaviors to be
/// specified for a single Pregel computation.
///
/// # Examples
///
/// ```rust
/// use polars::prelude::*;
/// use pregel_rs::graph_frame::GraphFrame;
/// use pregel_rs::pregel::ColumnIdentifier::{Custom, Dst, Id, Src};
/// use pregel_rs::pregel::{MessageReceiver, Pregel, PregelBuilder};
/// use std::error::Error;
///
/// // Simple example of a Pregel algorithm where we chain several `send_messages` calls. In
/// // this example, we send a message to the source of an edge and then to the destination of
/// // the same edge. It has no real use case, but it demonstrates how to chain multiple calls.
/// fn main() -> Result<(), Box<dyn Error>> {
/// let edges = df![
/// Src.as_ref() => [0, 1, 1, 2, 2, 3],
/// Dst.as_ref() => [1, 0, 3, 1, 3, 2],
/// ]?;
///
/// let vertices = df![
/// Id.as_ref() => [0, 1, 2, 3],
/// Custom("value").as_ref() => [3, 6, 2, 1],
/// ]?;
///
/// let pregel = PregelBuilder::new(GraphFrame::new(vertices, edges)?)
/// .max_iterations(4)
/// .with_vertex_column(Custom("aux"))
/// .initial_message(lit(0))
/// .send_messages(MessageReceiver::Src, lit(1))
/// .send_messages(MessageReceiver::Dst, lit(-1))
/// .aggregate_messages(Pregel::msg(None).sum())
/// .v_prog(Pregel::msg(None) + lit(1))
/// .build();
///
/// Ok(println!("{:?}", pregel.run()))
/// }
/// ```
///
/// Arguments:
///
Expand All @@ -298,8 +385,7 @@ impl PregelBuilder {
/// computation.
/// * `send_messages`: `send_messages` is a parameter of type `Expr`. It is used to
/// specify the function that will be applied to each vertex in the graph to send
/// messages to its neighboring vertices. The `send_messages` function takes two
/// arguments: the first argument is the vertex ID of the current vertex, and
/// messages to its neighboring vertices.
///
/// Returns:
///
Expand All @@ -308,16 +394,7 @@ impl PregelBuilder {
/// multiple methods can be called on the same struct instance in a single
/// expression.
pub fn send_messages(mut self, to: MessageReceiver, send_messages: Expr) -> Self {
// We make this in this manner because we want to use the `src.id` and `edge.dst` columns
// in the send_messages function. This is because how polars works, when joining dataframes,
// it will keep only the left-hand side of the joins, thus, we need to use the `src.id` and
// `edge.dst` columns to get the correct vertex IDs.
let to = match to {
MessageReceiver::Src => Pregel::src(ColumnIdentifier::Id),
MessageReceiver::Dst => Pregel::edge(ColumnIdentifier::Dst),
};
// Now we can set the send_messages field of the struct to the provided expression.
self.send_messages = (to, send_messages);
self.send_messages.push(SendMessage::new(to, send_messages));
self
}

Expand Down Expand Up @@ -563,11 +640,22 @@ impl Pregel {
// We create a tuple where we store the column names of the `send_messages` DataFrame. We use
// the `alias` method to ensure that the column names are properly qualified. We also
// do the same for the `aggregate_messages` Expr. And the same with the `v_prog` Expr.
let (send_messages_ids, send_messages_msg) = self.send_messages;
let (send_messages_ids, send_messages_msg) = (
send_messages_ids.alias(&Self::alias(&ColumnIdentifier::Msg, ColumnIdentifier::Id)),
send_messages_msg.alias(ColumnIdentifier::Pregel.as_ref()),
);
let (mut send_messages_ids, mut send_messages_msg): (Vec<Expr>, Vec<Expr>) = self
.send_messages
.iter()
.map(|send_message| {
let message_direction = &send_message.message_direction;
let send_message_expr = &send_message.send_message;
(
message_direction
.to_owned()
.alias(&Self::alias(&ColumnIdentifier::Msg, ColumnIdentifier::Id)),
send_message_expr
.to_owned()
.alias(ColumnIdentifier::Pregel.as_ref()),
)
})
.unzip();
let aggregate_messages = self
.aggregate_messages
.alias(ColumnIdentifier::Pregel.as_ref());
Expand Down Expand Up @@ -628,14 +716,12 @@ impl Pregel {
// are computed by performing an aggregation on the `triplets_df` DataFrame. The aggregation
// is performed on the `msg` column of the `triplets_df` DataFrame, and the aggregation
// function is the one set by the user at the initialization of the model.
let sends_messages_ids_df = &send_messages_ids;
let send_messages_msg_df = &send_messages_msg;
let send_messages = &mut send_messages_ids; // we create a mutable reference to the `send_messages_ids` Vector
let send_messages_msg_df = &mut send_messages_msg; // we create a mutable reference to the `send_messages_msg` Vector
send_messages.append(send_messages_msg_df); // we append the `send_messages_msg` Vector to the `send_messages` Vector
let aggregate_messages_df = &aggregate_messages;
let message_df = triplets_df
.select(vec![
sends_messages_ids_df.to_owned(),
send_messages_msg_df.to_owned(),
])
.select(send_messages)
.groupby([Self::msg(Some(ColumnIdentifier::Id))])
.agg([aggregate_messages_df.to_owned()]);
// We Compute the new values for the vertices. Note that we have to check for possibly
Expand Down Expand Up @@ -685,29 +771,52 @@ impl Pregel {
#[cfg(test)]
mod tests {
use crate::graph_frame::GraphFrame;
use crate::pregel::ColumnIdentifier::{Custom, Dst, Id, Src};
use crate::pregel::{MessageReceiver, Pregel, PregelBuilder};
use crate::pregel::{ColumnIdentifier, MessageReceiver, Pregel, PregelBuilder, SendMessage};
use polars::prelude::*;
use std::error::Error;

fn pagerank_builder(iterations: u8) -> Result<Pregel, Box<dyn Error>> {
let edges = df![
Src.as_ref() => [0, 0, 1, 2, 3, 4, 4, 4],
Dst.as_ref() => [1, 2, 2, 3, 3, 1, 2, 3],
]?;
fn pagerank_graph() -> Result<GraphFrame, String> {
let edges = match df![
ColumnIdentifier::Src.as_ref() => [0, 0, 1, 2, 3, 4, 4, 4],
ColumnIdentifier::Dst.as_ref() => [1, 2, 2, 3, 3, 1, 2, 3],
] {
Ok(edges) => edges,
Err(_) => return Err(String::from("Error creating the edges DataFrame")),
};

let vertices = GraphFrame::from_edges(edges.clone())?.out_degrees()?;
let graph = match GraphFrame::from_edges(edges.clone()) {
Ok(graph) => graph,
Err(_) => return Err(String::from("Error creating the vertices DataFrame")),
};

let vertices = match graph.out_degrees() {
Ok(vertices) => vertices,
Err(_) => {
return Err(String::from(
"Error creating the vertices out degree DataFrame",
))
}
};

match GraphFrame::new(vertices, edges) {
Ok(graph) => Ok(graph),
Err(_) => Err(String::from("Error creating the graph")),
}
}

fn pagerank_builder(iterations: u8) -> Result<Pregel, Box<dyn Error>> {
let graph = pagerank_graph()?;
let damping_factor = 0.85;
let num_vertices: f64 = vertices.column(Id.as_ref())?.len() as f64;
let num_vertices: f64 = graph.vertices.column(ColumnIdentifier::Id.as_ref())?.len() as f64;

Ok(PregelBuilder::new(GraphFrame::new(vertices, edges)?)
Ok(PregelBuilder::new(graph)
.max_iterations(iterations)
.with_vertex_column(Custom("rank"))
.with_vertex_column(ColumnIdentifier::Custom("rank"))
.initial_message(lit(1.0 / num_vertices))
.send_messages(
MessageReceiver::Dst,
Pregel::src(Custom("rank")) / Pregel::src(Custom("out_degree")),
Pregel::src(ColumnIdentifier::Custom("rank"))
/ Pregel::src(ColumnIdentifier::Custom("out_degree")),
)
.aggregate_messages(Pregel::msg(None).sum())
.v_prog(
Expand Down Expand Up @@ -777,16 +886,16 @@ mod tests {

fn max_value_graph() -> Result<GraphFrame, String> {
let edges = match df![
Src.as_ref() => [0, 1, 1, 2, 2, 3],
Dst.as_ref() => [1, 0, 3, 1, 3, 2],
ColumnIdentifier::Src.as_ref() => [0, 1, 1, 2, 2, 3],
ColumnIdentifier::Dst.as_ref() => [1, 0, 3, 1, 3, 2],
] {
Ok(edges) => edges,
Err(_) => return Err(String::from("Error creating the edges DataFrame")),
};

let vertices = match df![
Id.as_ref() => [0, 1, 2, 3],
Custom("value").as_ref() => [3, 6, 2, 1],
ColumnIdentifier::Id.as_ref() => [0, 1, 2, 3],
ColumnIdentifier::Custom("value").as_ref() => [3, 6, 2, 1],
] {
Ok(vertices) => vertices,
Err(_) => return Err(String::from("Error creating the vertices DataFrame")),
Expand All @@ -802,14 +911,17 @@ mod tests {
Ok(Pregel {
graph: max_value_graph()?,
max_iterations: iterations,
vertex_column: Custom("max_value"),
initial_message: col(Custom("value").as_ref()),
send_messages: (
Pregel::edge(MessageReceiver::into(MessageReceiver::Dst)),
Pregel::src(Custom("max_value")),
),
vertex_column: ColumnIdentifier::Custom("max_value"),
initial_message: col(ColumnIdentifier::Custom("value").as_ref()),
send_messages: vec![SendMessage::new(
MessageReceiver::Dst,
Pregel::src(ColumnIdentifier::Custom("value")),
)],
aggregate_messages: Pregel::msg(None).max(),
v_prog: max_exprs([col(Custom("max_value").as_ref()), Pregel::msg(None)]),
v_prog: max_exprs([
col(ColumnIdentifier::Custom("max_value").as_ref()),
Pregel::msg(None),
]),
})
}

Expand Down Expand Up @@ -874,7 +986,7 @@ mod tests {
// useful to test the Pregel model.
match PregelBuilder::new(graph)
.max_iterations(4)
.with_vertex_column(Custom("does_not_matter"))
.with_vertex_column(ColumnIdentifier::Custom("aux"))
.initial_message(lit(0)) // we pass the Undefined state to all vertices
.send_messages(MessageReceiver::Src, lit(0))
.aggregate_messages(lit(0))
Expand All @@ -886,4 +998,42 @@ mod tests {
Err(_) => Err(String::from("Error running the algorithm")),
}
}

#[test]
fn test_send_messages_src_dst() -> Result<(), String> {
    // Register two chained rules on the PageRank test graph: the literal `1` addressed
    // to the source of every edge and the literal `-1` addressed to its destination.
    let result = PregelBuilder::new(pagerank_graph()?)
        .max_iterations(4)
        .with_vertex_column(ColumnIdentifier::Custom("aux"))
        .initial_message(lit(0))
        .send_messages(MessageReceiver::Src, lit(1))
        .send_messages(MessageReceiver::Dst, lit(-1))
        .aggregate_messages(Pregel::msg(None).sum())
        .v_prog(Pregel::msg(None) + lit(1))
        .build()
        .run()
        .map_err(|_| String::from("Error running pregel"))?;

    // Sort by vertex ID so the column can be compared positionally.
    let ordered = result
        .sort(&["id"], false)
        .map_err(|_| String::from("Error sorting the DataFrame"))?;

    let actual = ordered
        .column("aux")
        .map_err(|_| String::from("Error retrieving the column"))?;

    let expected = Series::new("aux", [3, 2, 2, 2, 4]);

    if !actual.eq(&expected) {
        return Err(String::from("The resulting DataFrame is not correct"));
    }
    Ok(())
}
}

0 comments on commit acb63d1

Please sign in to comment.