From f21f29e74a3fd7bec5d7f5e295f33ec0318e8e78 Mon Sep 17 00:00:00 2001
From: angelip2303 <angel.iglesias.prestamo@gmail.com>
Date: Fri, 14 Apr 2023 19:47:52 +0200
Subject: [PATCH] Use &'static str for Custom columns and fix Src message routing

---
 Cargo.toml                |   2 +-
 examples/maximum_value.rs |  13 ++--
 examples/pagerank.rs      |   4 +-
 src/graph_frame.rs        |   9 +--
 src/pregel.rs             | 149 ++++++++++++++++++++++----------------
 5 files changed, 96 insertions(+), 81 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 6e2819e..6827460 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "pregel-rs"
-version = "0.0.3"
+version = "0.0.4"
 authors = [ "Ángel Iglesias Préstamo <angel.iglesias.prestamo@gmail.com>" ]
 description = "A Graph library written in Rust for implementing your own algorithms in a Pregel fashion"
 documentation = "https://docs.rs/crate/pregel-rs/latest"
diff --git a/examples/maximum_value.rs b/examples/maximum_value.rs
index 68df3c7..3d945ba 100644
--- a/examples/maximum_value.rs
+++ b/examples/maximum_value.rs
@@ -20,20 +20,17 @@ fn main() -> Result<(), Box<dyn Error>> {
 
     let vertices = df![
         Id.as_ref() => [0, 1, 2, 3],
-        Custom("value".to_owned()).as_ref() => [3, 6, 2, 1],
+        Custom("value").as_ref() => [3, 6, 2, 1],
     ]?;
 
     let pregel = PregelBuilder::new(GraphFrame::new(vertices, edges)?)
         .max_iterations(4)
-        .with_vertex_column(Custom("max_value".to_owned()))
-        .initial_message(col(Custom("value".to_owned()).as_ref()))
-        .send_messages(
-            MessageReceiver::Dst,
-            Pregel::src(Custom("max_value".to_owned())),
-        )
+        .with_vertex_column(Custom("max_value"))
+        .initial_message(col(Custom("value").as_ref()))
+        .send_messages(MessageReceiver::Dst, Pregel::src(Custom("max_value")))
         .aggregate_messages(Pregel::msg(None).max())
         .v_prog(max_exprs([
-            col(Custom("max_value".to_owned()).as_ref()),
+            col(Custom("max_value").as_ref()),
             Pregel::msg(None),
         ]))
         .build();
diff --git a/examples/pagerank.rs b/examples/pagerank.rs
index 7358197..d2e803a 100644
--- a/examples/pagerank.rs
+++ b/examples/pagerank.rs
@@ -24,11 +24,11 @@ fn main() -> Result<(), Box<dyn Error>> {
 
     let pregel = PregelBuilder::new(GraphFrame::new(vertices, edges)?)
         .max_iterations(4)
-        .with_vertex_column(Custom("rank".to_owned()))
+        .with_vertex_column(Custom("rank"))
         .initial_message(lit(1.0 / num_vertices))
         .send_messages(
             MessageReceiver::Dst,
-            Pregel::src(Custom("rank".to_owned())) / Pregel::src(Custom("out_degree".to_owned())),
+            Pregel::src(Custom("rank")) / Pregel::src(Custom("out_degree")),
         )
         .aggregate_messages(Pregel::msg(None).sum())
         .v_prog(
diff --git a/src/graph_frame.rs b/src/graph_frame.rs
index df44889..3b11536 100644
--- a/src/graph_frame.rs
+++ b/src/graph_frame.rs
@@ -27,9 +27,6 @@ type Result<T> = std::result::Result<T, GraphFrameError>;
 /// `FromPolars` and `MissingColumn`.
 #[derive(Debug)]
 pub enum GraphFrameError {
-    /// `DuckDbError` is a variant of `GraphFrameError` that represents errors that
-    /// occur when working with the DuckDB database.
-    DuckDbError(&'static str),
     /// `FromPolars` is a variant of `GraphFrameError` that represents errors that
     /// occur when converting from a `PolarsError`.
     FromPolars(PolarsError),
@@ -41,7 +38,6 @@ pub enum GraphFrameError {
 impl Display for GraphFrameError {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         match self {
-            GraphFrameError::DuckDbError(error) => Display::fmt(error, f),
             GraphFrameError::FromPolars(error) => Display::fmt(error, f),
             GraphFrameError::MissingColumn(error) => Display::fmt(error, f),
         }
@@ -51,7 +47,6 @@ impl Display for GraphFrameError {
 impl error::Error for GraphFrameError {
     fn source(&self) -> Option<&(dyn error::Error + 'static)> {
         match *self {
-            GraphFrameError::DuckDbError(_) => None,
             GraphFrameError::FromPolars(ref e) => Some(e),
             GraphFrameError::MissingColumn(_) => None,
         }
@@ -173,7 +168,7 @@ impl GraphFrame {
         self.edges
             .lazy()
             .groupby([col(Src.as_ref()).alias(Id.as_ref())])
-            .agg([count().alias(Custom("out_degree".to_owned()).as_ref())])
+            .agg([count().alias(Custom("out_degree").as_ref())])
             .collect()
     }
 
@@ -192,7 +187,7 @@ impl GraphFrame {
         self.edges
             .lazy()
             .groupby([col(Dst.as_ref())])
-            .agg([count().alias(Custom("in_degree".to_owned()).as_ref())])
+            .agg([count().alias(Custom("in_degree").as_ref())])
             .collect()
     }
 }
diff --git a/src/pregel.rs b/src/pregel.rs
index 4edf17d..6b6f299 100644
--- a/src/pregel.rs
+++ b/src/pregel.rs
@@ -10,7 +10,7 @@ pub enum ColumnIdentifier {
     /// The `Id` variant represents the column that contains the vertex IDs.
     Id,
     /// The `Src` variant represents the column that contains the source vertex IDs.
-    Src,
+    Src, // TODO: check if the repeated naming is just fine or not
     /// The `Dst` variant represents the column that contains the destination vertex IDs.
     Dst,
     /// The `Edge` variant represents the column that contains the edge IDs.
@@ -20,21 +20,7 @@ pub enum ColumnIdentifier {
     /// The `Pregel` variant represents the column that contains the messages sent to a vertex.
     Pregel,
     /// The `Custom` variant represents a column that is not one of the predefined columns.
-    Custom(String),
-}
-
-impl From<String> for ColumnIdentifier {
-    fn from(value: String) -> Self {
-        match &*value {
-            "id" => ColumnIdentifier::Id,
-            "src" => ColumnIdentifier::Src,
-            "dst" => ColumnIdentifier::Dst,
-            "edge" => ColumnIdentifier::Edge,
-            "msg" => ColumnIdentifier::Msg,
-            "_pregel_msg_" => ColumnIdentifier::Pregel,
-            _ => ColumnIdentifier::Custom(value),
-        }
-    }
+    Custom(&'static str),
 }
 
 impl AsRef<str> for ColumnIdentifier {
@@ -231,7 +217,7 @@ impl PregelBuilder {
         PregelBuilder {
             graph,
             max_iterations: 10,
-            vertex_column: ColumnIdentifier::Custom("aux".to_owned()),
+            vertex_column: ColumnIdentifier::Custom("aux"),
             initial_message: Default::default(),
             send_messages: (Default::default(), Default::default()),
             aggregate_messages: Default::default(),
@@ -322,7 +308,15 @@ impl PregelBuilder {
     /// multiple methods can be called on the same struct instance in a single
     /// expression.
     pub fn send_messages(mut self, to: MessageReceiver, send_messages: Expr) -> Self {
-        self.send_messages = (Pregel::edge(MessageReceiver::into(to)), send_messages);
+        // The receiver must be named through a column that still exists in the triplets built
+        // by `Pregel::run`. Polars keeps only the left-hand key of each join, so among the
+        // vertex-ID columns only `src.id` and `edge.dst` survive in the joined triplets.
+        // Hence `Src` is mapped to `src.id` and `Dst` is mapped to `edge.dst`.
+        let to = match to {
+            MessageReceiver::Src => Pregel::src(ColumnIdentifier::Id),
+            MessageReceiver::Dst => Pregel::edge(ColumnIdentifier::Dst),
+        };
+        self.send_messages = (to, send_messages);
         self
     }
 
@@ -395,16 +389,16 @@ impl PregelBuilder {
     ///
     ///     let vertices = df![
     ///         Id.as_ref() => [0, 1, 2, 3],
-    ///         Custom("value".to_owned()).as_ref() => [3, 6, 2, 1],
+    ///         Custom("value").as_ref() => [3, 6, 2, 1],
     ///     ]?;
     ///
     ///     let pregel = PregelBuilder::new(GraphFrame::new(vertices, edges)?)
     ///         .max_iterations(4)
-    ///         .with_vertex_column(Custom("max_value".to_owned()))
-    ///         .initial_message(col(Custom("value".to_owned()).as_ref()))
-    ///         .send_messages(MessageReceiver::Dst, Pregel::src(Custom("max_value".to_owned())))
+    ///         .with_vertex_column(Custom("max_value"))
+    ///         .initial_message(col(Custom("value").as_ref()))
+    ///         .send_messages(MessageReceiver::Dst, Pregel::src(Custom("max_value")))
     ///         .aggregate_messages(Pregel::msg(None).max())
-    ///         .v_prog(max_exprs([col(Custom("max_value".to_owned()).as_ref()), Pregel::msg(None)]))
+    ///         .v_prog(max_exprs([col(Custom("max_value").as_ref()), Pregel::msg(None)]))
     ///         .build();
     ///
     ///     Ok(println!("{}", pregel.run()?))
@@ -434,13 +428,6 @@ impl Pregel {
         format!("{}.{}", prefix.as_ref(), column_name.as_ref())
     }
 
-    fn prefix_columns(expr: Expr, prefix: &'static ColumnIdentifier) -> Expr {
-        expr.map_alias(|column_name| {
-            let column_identifier = ColumnIdentifier::from(column_name.to_string());
-            Ok(Self::alias(prefix, column_identifier))
-        })
-    }
-
     /// This function returns an expression for a column identifier representing
     ///  the source vertex in a Pregel graph.
     ///
@@ -573,12 +560,17 @@ impl Pregel {
     /// the resulting `DataFrame` or an error of type `PolarsError`.
     pub fn run(self) -> PolarsResult<DataFrame> {
         // We create a tuple where we store the column names of the `send_messages` DataFrame. We use
-        // the `alias` method to ensure that the column names are properly qualified.
+        // the `alias` method to ensure that the column names are properly qualified. The
+        // `aggregate_messages` and `v_prog` expressions are aliased here as well, once, up front.
         let (send_messages_ids, send_messages_msg) = self.send_messages;
         let (send_messages_ids, send_messages_msg) = (
             send_messages_ids.alias(&Self::alias(&ColumnIdentifier::Msg, ColumnIdentifier::Id)),
             send_messages_msg.alias(ColumnIdentifier::Pregel.as_ref()),
         );
+        let aggregate_messages = self
+            .aggregate_messages
+            .alias(ColumnIdentifier::Pregel.as_ref());
+        let v_prog = self.v_prog.alias(self.vertex_column.as_ref());
         // We create a DataFrame that contains the edges of the graph. This DataFrame is used to
         // compute the triplets of the graph, which are used to send messages to the neighboring
         // vertices of each vertex in the graph. For us to do so, we select all the columns of the
@@ -587,7 +579,7 @@ impl Pregel {
             .graph
             .edges
             .lazy()
-            .select([Self::prefix_columns(all(), &ColumnIdentifier::Edge)]);
+            .select([all().prefix(&format!("{}.", ColumnIdentifier::Edge.as_ref()))]);
         // We create a DataFrame that contains the vertices of the graph
         let vertices = &self.graph.vertices.lazy();
         // We start the execution of the algorithm from the super-step 0; that is, all the nodes
@@ -607,7 +599,7 @@ impl Pregel {
         // greater than the maximum number of iterations set by the user at the initialization of
         // the model (see the `Pregel::new` method). We start by setting the number of iterations to 1.
         let mut iteration = 1;
-        // TODO: check that nodes are not halted :D
+        // TODO: check whether any nodes have voted to halt; if so, remove them from the `current_vertices` DataFrame.
         while iteration <= self.max_iterations {
             // We create a DataFrame that contains the triplets of the graph. Those triplets are
             // computed by joining the `current_vertices` DataFrame with the `edges` DataFrame
@@ -618,16 +610,16 @@ impl Pregel {
             let current_vertices_df = &current_vertices.lazy();
             let triplets_df = current_vertices_df
                 .to_owned()
-                .select([Self::prefix_columns(all(), &ColumnIdentifier::Src)])
+                .select([all().prefix(&format!("{}.", ColumnIdentifier::Src.as_ref()))])
                 .inner_join(
-                    edges.clone(),
+                    edges.to_owned(),
                     Self::src(ColumnIdentifier::Id), // src column of the current_vertices DataFrame
                     Self::edge(ColumnIdentifier::Src), // src column of the edges DataFrame
                 )
                 .inner_join(
                     current_vertices_df
                         .to_owned()
-                        .select([Self::prefix_columns(all(), &ColumnIdentifier::Dst)]),
+                        .select([all().prefix(&format!("{}.", ColumnIdentifier::Dst.as_ref()))]),
                     Self::edge(ColumnIdentifier::Dst), // dst column of the resulting DataFrame
                     Self::dst(ColumnIdentifier::Id), // id column of the current_vertices DataFrame
                 );
@@ -637,7 +629,7 @@ impl Pregel {
             // function is the one set by the user at the initialization of the model.
             let sends_messages_ids_df = &send_messages_ids;
             let send_messages_msg_df = &send_messages_msg;
-            let aggregate_messages_df = &self.aggregate_messages;
+            let aggregate_messages_df = &aggregate_messages;
             let message_df = triplets_df
                 .select(vec![
                     sends_messages_ids_df.to_owned(),
@@ -650,7 +643,7 @@ impl Pregel {
             // not exist in the source DataFrame. In case we find any; for example, given a certain
             // node having no incoming edges, we have to replace the null value by 0 for the aggregation
             // to work properly.
-            let v_prog_df = &self.v_prog;
+            let v_prog_df = &v_prog;
             let vertex_columns = current_vertices_df
                 .to_owned()
                 .outer_join(
@@ -665,18 +658,14 @@ impl Pregel {
                         .otherwise(Self::msg(None))
                         .alias(ColumnIdentifier::Pregel.as_ref()),
                 )
-                .select(
-                    // TODO: fix this move: previous iteration of the loop. Improve?
-                    vec![
-                        col(ColumnIdentifier::Id.as_ref()),
-                        v_prog_df.to_owned().alias(self.vertex_column.as_ref()),
-                    ],
-                );
+                .select(vec![
+                    col(ColumnIdentifier::Id.as_ref()),
+                    v_prog_df.to_owned(),
+                ]);
             // We update the `current_vertices` DataFrame with the new values for the vertices. We
             // do so by performing an inner join between the `current_vertices` DataFrame and the
             // `vertex_columns` DataFrame. The join is performed on the `id` column of the
             // `current_vertices` DataFrame and the `id` column of the `vertex_columns` DataFrame.
-            // TODO: We also check if the nodes have voted to halt. If so, we remove them from the `current_vertices` DataFrame.
             current_vertices = vertices
                 .to_owned()
                 .inner_join(
@@ -714,12 +703,11 @@ mod tests {
 
         Ok(PregelBuilder::new(GraphFrame::new(vertices, edges)?)
             .max_iterations(iterations)
-            .with_vertex_column(Custom("rank".to_owned()))
+            .with_vertex_column(Custom("rank"))
             .initial_message(lit(1.0 / num_vertices))
             .send_messages(
                 MessageReceiver::Dst,
-                Pregel::src(Custom("rank".to_owned()))
-                    / Pregel::src(Custom("out_degree".to_owned())),
+                Pregel::src(Custom("rank")) / Pregel::src(Custom("out_degree")),
             )
             .aggregate_messages(Pregel::msg(None).sum())
             .v_prog(
@@ -787,31 +775,38 @@ mod tests {
         Ok(())
     }
 
-    fn max_value_builder(iterations: u8) -> Result<Pregel, Box<dyn Error>> {
+    // Builds the four-vertex graph shared by the MaxValue tests and `test_literals`. Any
+    // failure is reported as a `String` that keeps the underlying error's message.
+    fn max_value_graph() -> Result<GraphFrame, String> {
         let edges = df![
             Src.as_ref() => [0, 1, 1, 2, 2, 3],
             Dst.as_ref() => [1, 0, 3, 1, 3, 2],
-        ]?;
+        ]
+        .map_err(|error| format!("Error creating the edges DataFrame: {}", error))?;
 
         let vertices = df![
             Id.as_ref() => [0, 1, 2, 3],
-            Custom("value".to_owned()).as_ref() => [3, 6, 2, 1],
-        ]?;
+            Custom("value").as_ref() => [3, 6, 2, 1],
+        ]
+        .map_err(|error| format!("Error creating the vertices DataFrame: {}", error))?;
+
+        GraphFrame::new(vertices, edges)
+            .map_err(|error| format!("Error creating the graph: {}", error))
+    }
 
+    // Assembles the MaxValue `Pregel` instance by hand, i.e. without going through `PregelBuilder`.
+    fn max_value_builder(iterations: u8) -> Result<Pregel, String> {
         Ok(Pregel {
-            graph: GraphFrame::new(vertices, edges)?,
+            graph: max_value_graph()?,
             max_iterations: iterations,
-            vertex_column: Custom("max_value".to_owned()),
-            initial_message: col(Custom("value".to_owned()).as_ref()),
+            vertex_column: Custom("max_value"),
+            initial_message: col(Custom("value").as_ref()),
             send_messages: (
                 Pregel::edge(MessageReceiver::into(MessageReceiver::Dst)),
-                Pregel::src(Custom("max_value".to_owned())),
+                Pregel::src(Custom("max_value")),
             ),
             aggregate_messages: Pregel::msg(None).max(),
-            v_prog: max_exprs([
-                col(Custom("max_value".to_owned()).as_ref()),
-                Pregel::msg(None),
-            ]),
+            v_prog: max_exprs([col(Custom("max_value").as_ref()), Pregel::msg(None)]),
         })
     }
 
@@ -863,4 +861,29 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_literals() -> Result<(), String> {
+        // We create a graph using the exact same vertices and edges as the one used in the
+        // MaxValue algorithm. The graph itself is not important, we just need to test the
+        // Pregel model.
+        let graph = max_value_graph()?;
+
+        // We create a Pregel algorithm that computes nothing: every expression is the literal
+        // `0`, so no vertex or edge column is ever referenced. This checks that the model runs
+        // to completion even when the user-supplied expressions are plain literals.
+        PregelBuilder::new(graph)
+            .max_iterations(4)
+            .with_vertex_column(Custom("does_not_matter"))
+            .initial_message(lit(0))
+            .send_messages(MessageReceiver::Src, lit(0))
+            .aggregate_messages(lit(0))
+            .v_prog(lit(0))
+            .build()
+            .run()
+            // The result itself is irrelevant; on failure we keep the underlying error in the
+            // message so that a broken run can actually be diagnosed.
+            .map(|_| ())
+            .map_err(|error| format!("Error running the algorithm: {}", error))
+    }
 }