From 1c887bce4ff7102c6e5a9864a05f591f8bc0190d Mon Sep 17 00:00:00 2001 From: mitchmindtree Date: Thu, 13 Jun 2019 17:38:57 +0200 Subject: [PATCH] Add support to allow for optionally typed inputs and outputs This allows for optionally specifying the full types of the inputs and outputs of a `Node` during implementation by allowing to specify a full, freestanding function, rather than only an expression. The function's arguments and return type will be parsed to produce the number of inputs and outputs for the node, where the number of arguments is the number of inputs, and the number of tuple arguments in the output is the number of outputs (1 if no tuple output type). Some of this may have to be re-written when addressing a few follow-up issues including #29, #19, #21 and #22, but I think it's helpful to break up progress into achievable steps! Closes #27 and makes #20 much more feasible. --- examples/project.rs | 59 +++----- src/graph.rs | 348 +++++++++++++++++++++++++++++++++++--------- src/node/expr.rs | 26 ++-- src/node/mod.rs | 151 ++++++++++++++----- src/node/push.rs | 12 +- src/project.rs | 70 +++++---- 6 files changed, 465 insertions(+), 201 deletions(-) diff --git a/examples/project.rs b/examples/project.rs index 8e0dc32..09c5bbc 100644 --- a/examples/project.rs +++ b/examples/project.rs @@ -11,17 +11,11 @@ struct Add; struct Debug; impl gantz::Node for One { - fn n_inputs(&self) -> u32 { - 0 - } - - fn n_outputs(&self) -> u32 { - 1 - } - - fn expr(&self, args: Vec) -> syn::Expr { - assert!(args.is_empty()); - syn::parse_quote! { 1 } + fn evaluator(&self) -> gantz::node::Evaluator { + let n_inputs = 0; + let n_outputs = 1; + let gen_expr = Box::new(|_| syn::parse_quote! { 1 }); + gantz::node::Evaluator::Expr { n_inputs, n_outputs, gen_expr } } fn push_eval(&self) -> Option { @@ -31,35 +25,28 @@ impl gantz::Node for One { } impl gantz::Node for Add { - fn n_inputs(&self) -> u32 { - 2 - } - - fn n_outputs(&self) -> u32 { - 1 - } - - fn expr(&self, args: Vec) -> syn::Expr { - assert_eq!(args.len(), 2); - let l = &args[0]; - let r = &args[1]; - syn::parse_quote! { #l + #r } + fn evaluator(&self) -> gantz::node::Evaluator { + let n_inputs = 2; + let n_outputs = 1; + let gen_expr = Box::new(move |args: Vec| { + let l = &args[0]; + let r = &args[1]; + syn::parse_quote! { #l + #r } + }); + gantz::node::Evaluator::Expr { n_inputs, n_outputs, gen_expr } } } impl gantz::Node for Debug { - fn n_inputs(&self) -> u32 { - 1 - } - - fn n_outputs(&self) -> u32 { - 0 - } - - fn expr(&self, args: Vec) -> syn::Expr { - assert_eq!(args.len(), 1); - let input = &args[0]; - syn::parse_quote! { println!("{:?}", #input) } + fn evaluator(&self) -> gantz::node::Evaluator { + let n_inputs = 1; + let n_outputs = 0; + let gen_expr = Box::new(move |args: Vec| { + assert_eq!(args.len(), 1); + let input = &args[0]; + syn::parse_quote! { println!("{:?}", #input) } + }); + gantz::node::Evaluator::Expr { n_inputs, n_outputs, gen_expr } } } diff --git a/src/graph.rs b/src/graph.rs index 405d783..6c374c2 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1,5 +1,7 @@ -use super::{Deserialize, Serialize}; use crate::node::{self, Node}; +use petgraph::visit::GraphBase; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::ops::{Deref, DerefMut}; /// The type used to represent node and edge indices. pub type Index = usize; @@ -16,6 +18,31 @@ pub struct Edge { pub input: node::Input, } +/// A node that itself is implemented in terms of a graph of nodes. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct GraphNode +where + G: GraphBase, +{ + /// The graph used to evaluate the node. + pub graph: G, + /// The types of each of the inputs into the graph node. + /// + /// TODO: Inlets and outlets should possibly use normal `Node`s and these should be their + /// indices. This way we can retrieve the type from the graph, cast it to `Inlet`/`Outlet` and + /// check for types while also allowing inlets and outlets to partake in the graph evaluation + /// process. + pub inlets: Vec, + /// The types of each of the outputs into the graph node. + pub outlets: Vec, +} + +/// A node that may act as an inlet into a graph node. +pub struct Inlet { + /// The type of value that can be input into the graph. + pub ty: syn::Type, +} + /// The petgraph type used to represent a gantz graph. pub type Graph = petgraph::Graph; @@ -30,20 +57,162 @@ impl Edge { } } -impl Node for StableGraph +impl Node for GraphNode +where + G: petgraph::visit::Data, + G::NodeWeight: Node, +{ + fn evaluator(&self) -> node::Evaluator { + let fn_token = syn::token::Fn { span: proc_macro2::Span::call_site() }; + let generics = unimplemented!("if any inlets/outlets use generics, this should too"); + let paren_token = syn::token::Paren { span: proc_macro2::Span::call_site() }; + let variadic = None; + let inputs = unimplemented!("to be inferred from inlet nodes"); + let output = unimplemented!("to be inferred from outlet nodes"); + let fn_decl = syn::FnDecl { fn_token, generics, paren_token, inputs, variadic, output }; + let fn_item = unimplemented!(); + node::Evaluator::Fn { fn_item } + } +} + +// Manual implementation of `Deserialize` as it cannot be derived for a struct with associated +// types without unnecessary trait bounds on the struct itself. +impl<'de, G> Deserialize<'de> for GraphNode +where + G: GraphBase + Deserialize<'de>, + G::NodeId: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use serde::de::{self, MapAccess, SeqAccess, Visitor}; + + #[derive(Deserialize)] + #[serde(field_identifier, rename_all = "lowercase")] + enum Field { Graph, Inlets, Outlets } + + struct GraphNodeVisitor(std::marker::PhantomData); + + impl<'de, G> Visitor<'de> for GraphNodeVisitor + where + G: GraphBase + Deserialize<'de>, + G::NodeId: Deserialize<'de>, + { + type Value = GraphNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct GraphNode") + } + + fn visit_seq(self, mut seq: V) -> Result, V::Error> + where + V: SeqAccess<'de>, + { + let graph = seq.next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let inlets = seq.next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + let outlets = seq.next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + Ok(GraphNode { graph, inlets, outlets }) + } + + fn visit_map(self, mut map: V) -> Result, V::Error> + where + V: MapAccess<'de>, + { + let mut graph = None; + let mut inlets = None; + let mut outlets = None; + while let Some(key) = map.next_key()? { + match key { + Field::Graph => { + if graph.is_some() { + return Err(de::Error::duplicate_field("graph")); + } + graph = Some(map.next_value()?); + } + Field::Inlets => { + if inlets.is_some() { + return Err(de::Error::duplicate_field("inlets")); + } + inlets = Some(map.next_value()?); + } + Field::Outlets => { + if outlets.is_some() { + return Err(de::Error::duplicate_field("outlets")); + } + outlets = Some(map.next_value()?); + } + } + } + let graph = graph.ok_or_else(|| de::Error::missing_field("graph"))?; + let inlets = inlets.ok_or_else(|| de::Error::missing_field("inlets"))?; + let outlets = outlets.ok_or_else(|| de::Error::missing_field("outlets"))?; + Ok(GraphNode { graph, inlets, outlets }) + } + } + + const FIELDS: &[&str] = &["graph", "inlets", "outlets"]; + let visitor: GraphNodeVisitor = GraphNodeVisitor(std::marker::PhantomData); + deserializer.deserialize_struct("GraphNode", FIELDS, visitor) + } +} + +// Manual implementation of `Serialize` as it cannot be derived for a struct with associated +// types without unnecessary trait bounds on the struct itself. +impl Serialize for GraphNode where - N: Node, + G: GraphBase + Serialize, + G::NodeId: Serialize, { - fn n_inputs(&self) -> u32 { - unimplemented!("requires implementing graph inlet nodes") + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("GraphNode", 3)?; + state.serialize_field("graph", &self.graph)?; + state.serialize_field("inlets", &self.inlets)?; + state.serialize_field("outlets", &self.outlets)?; + state.end() } +} - fn n_outputs(&self) -> u32 { - unimplemented!("requires implementing graph outlet nodes") +// impl Node for Inlet { +// fn evaluator(&self) -> node::Evaluator { +// let n_inputs = 1; +// let n_outputs = 1; +// let ty = self.ty.clone(); +// let gen_expr = Box::new(move |mut args: Vec| { +// assert_eq!(args.len(), 1, "must be a single input (from the calling fn) for an inlet"); +// let in_expr = args.remove(0); +// syn::parse_quote! { +// let in_expr_checked: #ty = in_expr; +// in_expr_checked +// } +// }); +// node::Evaluator::Expr { n_inputs, n_outputs, gen_expr } +// } +// } + +impl Deref for GraphNode +where + G: GraphBase, +{ + type Target = G; + fn deref(&self) -> &Self::Target { + &self.graph } +} - fn expr(&self, _args: Vec) -> syn::Expr { - unimplemented!("requires implementing graph inlet and outlet nodes") +impl DerefMut for GraphNode +where + G: GraphBase, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.graph } } @@ -90,7 +259,39 @@ pub mod codegen { pub requires_clone: bool, } - /// Given a graph with of gantz nodes, return `NodeId`s of those that require push evaluation. + /// Shorthand for the node evaluator map passed between codegen stages. + pub type NodeEvaluatorMap = HashMap; + + /// Given a graph of gantz nodes, produce the `Evaluator` associated with each. + pub fn node_evaluators(g: G) -> NodeEvaluatorMap + where + G: IntoNodeReferences, + ::Weight: Node, + G::NodeId: Eq + Hash, + { + g.node_references() + .map(|n| (n.id(), n.weight().evaluator())) + .collect() + } + + /// Given a set of node evaluators, return only those that have function definitions. + pub fn node_evaluator_fns( + evaluators: &NodeEvaluatorMap, + ) -> impl Iterator + where + Id: Eq + Hash, + { + evaluators + .iter() + .filter_map(|(id, eval)| { + match eval { + node::Evaluator::Fn { ref fn_item } => Some((id, fn_item)), + node::Evaluator::Expr { .. } => None, + } + }) + } + + /// Given a graph of gantz nodes, return `NodeId`s of those that require push evaluation. /// /// Expects any graph type whose nodes implement `Node`. pub fn push_nodes(g: G) -> Vec<(G::NodeId, node::PushEval)> @@ -122,7 +323,11 @@ pub mod codegen { /// /// Expects any directed graph whose edges are of type `Edge` and whose nodes implement `Node`. /// Direction of edges indicate the flow of data through the graph. - pub fn push_eval_steps(g: G, n: G::NodeId) -> Vec> + pub fn push_eval_steps( + g: G, + node_evaluators: &NodeEvaluatorMap, + n: G::NodeId, + ) -> Vec> where G: GraphRef + IntoEdgesDirected + IntoNodeReferences + NodeIndexable + Visitable, G: Data, @@ -144,15 +349,11 @@ pub mod codegen { continue; } - // Fetch the node reference. - let child = g.node_references() - .nth(g.to_index(node)) - .expect("no node for index"); - // Initialise the arguments to `None` for each input. - let mut args: Vec<_> = (0..child.weight().n_inputs()).map(|_| None).collect(); + let child_evaluator = &node_evaluators[&node]; + let mut args: Vec<_> = (0..child_evaluator.n_inputs()).map(|_| None).collect(); - // Create an argument for each input to this child. + // Create an argument for each input to this child. for e_ref in g.edges_directed(node, petgraph::Incoming) { let w = e_ref.weight(); @@ -160,7 +361,7 @@ pub mod codegen { // value will need to be cloned when passed to this input. let requires_clone = { let parent = e_ref.source(); - // TODO: Connection order should match + // TODO: Connection order should match let mut connection_ix = 0; let mut total_connections_from_output = 0; for (i, pe_ref) in g.edges_directed(parent, petgraph::Outgoing).enumerate() { @@ -191,21 +392,21 @@ pub mod codegen { eval_steps } - /// Pull evaluation from the specified node. - /// - /// Evaluation order is equivalent to depth-first-search order, ending with the specified node. - /// - /// Expects any directed graph whose edges are of type `Edge` and whose nodes implement `Node`. - /// Direction of edges indicate the flow of data through the graph. - pub fn pull_eval_steps(g: G, n: G::NodeId) -> Vec> - where - G: GraphRef + IntoEdgesDirected + IntoNodeReferences + NodeIndexable + Visitable, - G: Data, - ::Weight: Node, - { - // TODO - unimplemented!() - } + // /// Pull evaluation from the specified node. + // /// + // /// Evaluation order is equivalent to depth-first-search order, ending with the specified node. + // /// + // /// Expects any directed graph whose edges are of type `Edge` and whose nodes implement `Node`. + // /// Direction of edges indicate the flow of data through the graph. + // pub fn pull_eval_steps(g: G, n: G::NodeId) -> Vec> + // where + // G: GraphRef + IntoEdgesDirected + IntoNodeReferences + NodeIndexable + Visitable, + // G: Data, + // ::Weight: Node, + // { + // // TODO + // unimplemented!() + // } /// Given a function argument, return its type if known. pub fn ty_from_fn_arg(arg: &syn::FnArg) -> Option { @@ -222,19 +423,13 @@ pub mod codegen { g: G, push_eval: node::PushEval, steps: &[EvalStep], + node_evaluators: &NodeEvaluatorMap, ) -> syn::ItemFn where G: GraphRef + IntoNodeReferences + NodeIndexable, G::NodeId: Eq + Hash, ::Weight: Node, { - // For each evaluation step, generate a statement where the expression for the node at that - // evaluation step is evaluated and the outputs are destructured from a tuple. - let mut stmts: Vec = vec![]; - - // Keep track of each of the lvalues for each of the statements. These are used to pass - let mut lvalues: HashMap<(G::NodeId, node::Output), syn::Ident> = Default::default(); - type LValues = HashMap<(NI, node::Output), syn::Ident>; // A function for constructing a variable name. @@ -292,9 +487,14 @@ pub mod codegen { } } - for (si, step) in steps.iter().enumerate() { - let n_ref = g.node_references().nth(g.to_index(step.node)).expect("no node for index"); + // For each evaluation step, generate a statement where the expression for the node at that + // evaluation step is evaluated and the outputs are destructured from a tuple. + let mut stmts: Vec = vec![]; + // Keep track of each of the lvalues for each of the statements. These are used to pass + let mut lvalues: HashMap<(G::NodeId, node::Output), syn::Ident> = Default::default(); + + for (si, step) in steps.iter().enumerate() { // Retrieve an expression for each argument to the current node's expression. // // E.g. `_n1_v0`, `_n3_v1.clone()` or `Default::default()`. @@ -302,9 +502,9 @@ pub mod codegen { .map(|arg| input_expr(g, arg.as_ref(), &lvalues)) .collect(); - let nw = n_ref.weight(); - let n_outputs = nw.n_outputs(); - let expr: syn::Expr = nw.expr(args); + let ne = &node_evaluators[&step.node]; + let n_outputs = ne.n_outputs(); + let expr: syn::Expr = ne.expr(args); // Create the lvals pattern, either `PatWild` for no outputs, `Ident` for single output // or `Tuple` for multiple. Keep track of each the lvalue ident for each output of the @@ -372,7 +572,11 @@ pub mod codegen { /// Given a list of push evaluation nodes and their evaluation steps, generate a function for /// performing push evaluation for each node. - pub fn push_eval_fns<'a, G, I>( g: G, push_eval_nodes: I,) -> Vec + pub fn push_eval_fns<'a, G, I>( + g: G, + push_eval_nodes: I, + node_evaluators: &NodeEvaluatorMap, + ) -> Vec where G: GraphRef + IntoNodeReferences + NodeIndexable, G::NodeId: 'a + Eq + Hash, @@ -381,44 +585,52 @@ pub mod codegen { { push_eval_nodes .into_iter() - .map(|(_n, eval, steps)| push_eval_fn(g, eval, steps)) + .map(|(_n, eval, steps)| push_eval_fn(g, eval, steps, node_evaluators)) .collect() } - /// Generate a function for performing pull evaluation from the given node with the given - /// evaluation steps. - pub fn pull_eval_fn( - g: G, - pull_eval: node::PullEval, - steps: &[EvalStep], - ) -> syn::ItemFn - where - G: GraphRef + IntoNodeReferences + NodeIndexable, - G::NodeId: Eq + Hash, - ::Weight: Node, - { - // TODO - unimplemented!(); - } + // /// Generate a function for performing pull evaluation from the given node with the given + // /// evaluation steps. + // pub fn pull_eval_fn( + // g: G, + // pull_eval: node::PullEval, + // steps: &[EvalStep], + // ) -> syn::ItemFn + // where + // G: GraphRef + IntoNodeReferences + NodeIndexable, + // G::NodeId: Eq + Hash, + // ::Weight: Node, + // { + // // TODO + // unimplemented!(); + // } /// Given a gantz graph, generate the rust code src file with all the necessary functions for /// executing it. - pub fn file(g: G) -> syn::File + pub fn file(g: G, _inlets: &[G::NodeId], _outlets: &[G::NodeId]) -> syn::File where G: GraphRef + IntoEdgesDirected + IntoNodeReferences + NodeIndexable + Visitable, G: Data, G::NodeId: Eq + Hash, ::Weight: Node, { + let node_evaluators = node_evaluators(g); + let node_evaluator_fn_items = node_evaluator_fns(&node_evaluators); let push_nodes = push_nodes(g); - let items = push_nodes + + let push_node_fn_items = push_nodes .into_iter() .map(|(n, eval)| { - let steps = push_eval_steps(g, n); - let item_fn = push_eval_fn(g, eval, &steps); + let steps = push_eval_steps(g, &node_evaluators, n); + let item_fn = push_eval_fn(g, eval, &steps, &node_evaluators); syn::Item::Fn(item_fn) - }) + }); + + let items = node_evaluator_fn_items + .map(|(_, item_fn)| syn::Item::Fn(item_fn.clone())) + .chain(push_node_fn_items) .collect(); + let file = syn::File { shebang: None, attrs: vec![], items }; file } diff --git a/src/node/expr.rs b/src/node/expr.rs index ff7299f..84c82f0 100644 --- a/src/node/expr.rs +++ b/src/node/expr.rs @@ -1,9 +1,9 @@ -use super::{Deserialize, Fail, From, Serialize}; -use crate::node::Node; +use crate::node::{self, Node}; use proc_macro2::{TokenStream, TokenTree}; use quote::{ToTokens, TokenStreamExt}; use std::fmt; use std::str::FromStr; +use super::{Deserialize, Fail, From, Serialize}; /// A simple node that allows for representing rust expressions as nodes within a gantz graph. /// @@ -71,18 +71,16 @@ impl Expr { } impl Node for Expr { - fn n_inputs(&self) -> u32 { - count_hashes(&self.tokens) - } - - fn n_outputs(&self) -> u32 { - 1 - } - - fn expr(&self, args: Vec) -> syn::Expr { - let args_tokens = args.into_iter().map(|expr| expr.into_token_stream()); - let expr_tokens = interpolate_tokens(&self.tokens, args_tokens); - syn::parse_quote! { #expr_tokens } + fn evaluator(&self) -> node::Evaluator { + let n_inputs = count_hashes(&self.tokens); + let n_outputs = 1; + let tokens = self.tokens.clone(); + let gen_expr = Box::new(move |args: Vec| { + let args_tokens = args.into_iter().map(|expr| expr.into_token_stream()); + let expr_tokens = interpolate_tokens(&tokens, args_tokens); + syn::parse_quote! { #expr_tokens } + }); + node::Evaluator::Expr { n_inputs, n_outputs, gen_expr } } } diff --git a/src/node/mod.rs b/src/node/mod.rs index 51de7e0..657bf27 100644 --- a/src/node/mod.rs +++ b/src/node/mod.rs @@ -9,31 +9,25 @@ pub use self::push::{Push, WithPushEval}; pub use self::serde::SerdeNode; /// Gantz allows for constructing executable directed graphs by composing together **Node**s. -/// +/// /// **Node**s are a way to allow users to abstract and encapsulate logic into smaller, re-usable /// components, similar to a function in a coded programming language. -/// +/// /// Every Node is made up of the following: -/// +/// /// - Any number of inputs, where each input is of some rust type or generic type. /// - Any number of outputs, where each output is of some rust type or generic type. /// - A function that takes the inputs as arguments and returns an Outputs struct containing a /// field for each of the outputs. pub trait Node { - /// The number of inputs to the node. - fn n_inputs(&self) -> u32; - - /// The number of outputs to the node. - fn n_outputs(&self) -> u32; - - /// Tokens representing the rust code that will evaluate to a tuple containing all outputs. + /// The approach taken for evaluating a nodes inputs to its outputs. /// - /// TODO: Consider making `args` a `Vec` of `Option`s and returning an `Option` expr to allow - /// for generating execution paths where only a certain set of inputs have been triggered. - /// Returning `None` could indicate that there is no valid `Expr` for the current set of - /// triggered inputs. This would probably be better than than using `default` as is currently - /// the case. Would also allow for - fn expr(&self, args: Vec) -> syn::Expr; + /// This can either be an expression or a function - the key difference being that the types of + /// a function's inputs and outputs are known before compilation begins. As a result, functions + /// can lead to gantz generating more modular, compiler-friendly code, while raw expressions + /// have the benefit of being more ergonomic for the implementer as types aren't resolved until + /// the compilation process begins. + fn evaluator(&self) -> Evaluator; /// Specifies whether or not code should be generated to allow for push evaluation from /// instances of this node. Enabling push evaluation allows applications to call into @@ -58,6 +52,36 @@ pub trait Node { } } +/// The method of evaluation used for a node. +/// +/// The key distinction between the `Fn` and `Expr` variants is whether or not types of the inputs +/// and outputs are known before a node is connected to a graph or if instead these types should be +/// inferred. +pub enum Evaluator { + /// Functions have the benefit of knowing the types of their inputs and outputs. + /// + /// Knowing the types of a node's inputs and outputs allow us to: + /// + /// - Generate more modular code for a node. + /// - Create better user feedback and error messages. + /// - Implement `Node` for `Graph`. + Fn { + /// A free-standing function, including its name, declaration, the block and other + /// attributes. + fn_item: syn::ItemFn, + }, + /// Expressions have the benefit of not needing to know the exact types of a node's inputs and + /// outputs. This simplifies the implementation of the `Node` trait for users. + Expr { + /// The function for producing an expression given the input expressions. + gen_expr: Box) -> syn::Expr>, + /// The number of inputs to the expression. + n_inputs: u32, + /// The number of outputs to the expression. + n_outputs: u32, + }, +} + /// Items that need to be known in order to generate a push evaluation function for a node. #[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)] pub struct PushEval { @@ -83,20 +107,40 @@ pub struct Input(pub u32); #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord, Deserialize, Serialize)] pub struct Output(pub u32); -impl<'a, N> Node for &'a N -where - N: Node, -{ - fn n_inputs(&self) -> u32 { - (**self).n_inputs() +impl Evaluator { + /// The number of inputs to the node. + pub fn n_inputs(&self) -> u32 { + match *self { + Evaluator::Fn { ref fn_item } => count_fn_inputs(&fn_item.decl) as _, + Evaluator::Expr { n_inputs, .. } => n_inputs as _, + } } - fn n_outputs(&self) -> u32 { - (**self).n_outputs() + /// The number of outputs to the node. + pub fn n_outputs(&self) -> u32 { + match *self { + Evaluator::Fn { ref fn_item } => count_fn_outputs(&fn_item.decl) as _, + Evaluator::Expr { n_outputs, .. } => n_outputs as _, + } } - fn expr(&self, args: Vec) -> syn::Expr { - (**self).expr(args) + /// Tokens representing the rust code that will evaluate to a tuple containing all outputs. + /// + /// TODO: Handle case where only a subset of inputs are connected. See issue #17. + pub fn expr(&self, args: Vec) -> syn::Expr { + match *self { + Evaluator::Fn { ref fn_item } => fn_call_expr(fn_item, args), + Evaluator::Expr { ref gen_expr, .. } => (*gen_expr)(args), + } + } +} + +impl<'a, N> Node for &'a N +where + N: Node, +{ + fn evaluator(&self) -> Evaluator { + (**self).evaluator() } fn push_eval(&self) -> Option { @@ -111,16 +155,8 @@ where macro_rules! impl_node_for_ptr { ($($Ty:ident)::*) => { impl Node for $($Ty)::* { - fn n_inputs(&self) -> u32 { - (**self).n_inputs() - } - - fn n_outputs(&self) -> u32 { - (**self).n_outputs() - } - - fn expr(&self, args: Vec) -> syn::Expr { - (**self).expr(args) + fn evaluator(&self) -> Evaluator { + (**self).evaluator() } fn push_eval(&self) -> Option { @@ -165,3 +201,46 @@ impl From for Output { pub fn expr(expr: &str) -> Result { Expr::new(expr) } + +// Count the number of arguments to the given function. +// +// This is used to determine the number of inputs to the function. +fn count_fn_inputs(fn_decl: &syn::FnDecl) -> usize { + fn_decl.inputs.len() +} + +// Count the number of arguments to the given function. +// +// This is used to determine the number of inputs to the function. +fn count_fn_outputs(fn_decl: &syn::FnDecl) -> usize { + match fn_decl.output { + syn::ReturnType::Default => 0, + syn::ReturnType::Type(ref _r_arrow, ref ty) => match **ty { + syn::Type::Tuple(ref tuple) => tuple.elems.len(), + _ => 1, + } + } +} + +// Create a rust expression that calls the given `fn_decl` function with the given `args` +// expressions as its inputs. +fn fn_call_expr(fn_item: &syn::ItemFn, args: Vec) -> syn::Expr { + let n_inputs = count_fn_inputs(&fn_item.decl); + assert_eq!(n_inputs, args.len(), "the number of args to a function node must match n_inputs"); + let ident = fn_item.ident.clone(); + let arguments = syn::PathArguments::None; + let segment = syn::PathSegment { ident, arguments }; + let segments = std::iter::once(segment).collect(); + let leading_colon = None; + let path = syn::Path { leading_colon, segments }; + let attrs = vec![]; + let qself = None; + let func_path = syn::ExprPath { attrs, qself, path }; + let attrs = vec![]; + let func = Box::new(syn::Expr::Path(func_path)); + let paren_token = syn::token::Paren { span: proc_macro2::Span::call_site() }; + let args = args.into_iter().collect(); + let expr_call = syn::ExprCall { attrs, func, paren_token, args }; + let expr = syn::Expr::Call(expr_call); + expr +} diff --git a/src/node/push.rs b/src/node/push.rs index 48a9094..f63bd01 100644 --- a/src/node/push.rs +++ b/src/node/push.rs @@ -59,16 +59,8 @@ impl Node for Push where N: Node, { - fn n_inputs(&self) -> u32 { - self.node.n_inputs() - } - - fn n_outputs(&self) -> u32 { - self.node.n_outputs() - } - - fn expr(&self, args: Vec) -> syn::Expr { - self.node.expr(args) + fn evaluator(&self) -> node::Evaluator { + self.node.evaluator() } fn push_eval(&self) -> Option { diff --git a/src/project.rs b/src/project.rs index bfb2693..469e770 100644 --- a/src/project.rs +++ b/src/project.rs @@ -1,5 +1,5 @@ use super::{Deserialize, Fail, From, Serialize}; -use crate::graph; +use crate::graph::{self, GraphNode}; use crate::node::{self, Node, SerdeNode}; use quote::ToTokens; use std::collections::{BTreeMap, HashMap}; @@ -55,25 +55,29 @@ pub struct NodeCollection { /// A graph composed of IDs into the `NodeCollection`. pub type NodeIdGraph = graph::StableGraph; +pub type NodeIdGraphNode = GraphNode; + /// Whether the node is a **Core** node (has no other internal **Node** dependencies) or is a /// **Graph** node, composed entirely of other gantz **Node**s. #[derive(Deserialize, Serialize)] pub enum NodeKind { Core(Box), - Graph(GraphNode), + Graph(ProjectGraph), } -/// A node composed of a graph of other nodes. +/// A gantz node graph useful within gantz `Project`s. +/// +/// This can be thought of as a node that is a graph composed of other nodes. #[derive(Deserialize, Serialize)] -pub struct GraphNode { - pub graph: NodeIdGraph, +pub struct ProjectGraph { + pub graph: NodeIdGraphNode, pub package_id: cargo::core::PackageId, } // A **Node** type constructed as a reference to some other node. enum NodeRef<'a> { Core(&'a Node), - Graph(graph::StableGraph>), + Graph(GraphNode>>), } /// Errors that may occur while creating a node crate. @@ -245,7 +249,7 @@ pub enum GraphNodeCompileError { NoMatchingPackageId, } -/// Errors that might occur while updating a `GraphNode`'s graph. +/// Errors that might occur while updating a `ProjectGraph`'s graph. #[derive(Debug, Fail, From)] pub enum UpdateGraphError { #[fail(display = "failed to replace graph node src: {}", err)] @@ -300,10 +304,13 @@ impl Project { Err(_err) => { let mut nodes = NodeCollection::default(); let graph = NodeIdGraph::default(); + let inlets = vec![]; + let outlets = vec![]; + let graph_node = GraphNode { graph, inlets, outlets }; let ws_dir = workspace_dir(&directory); let proj_name = project_name(&directory); let node_id = - add_graph_node_to_collection(&ws_dir, proj_name, &cargo_config, graph, &mut nodes)?; + add_graph_node_to_collection(&ws_dir, proj_name, &cargo_config, graph_node, &mut nodes)?; if let Some(NodeKind::Graph(ref node)) = nodes.get(&node_id) { graph_node_compile(&ws_dir, &cargo_config, node)?; } @@ -329,7 +336,7 @@ impl Project { /// Add the given node to the collection and return its unique identifier. pub fn add_graph_node( &mut self, - graph: NodeIdGraph, + graph: NodeIdGraphNode, node_name: &str, ) -> Result { let ws_dir = self.workspace_dir(); @@ -363,7 +370,7 @@ impl Project { /// /// Returns `None` if there are no nodes for the given **NodeId** or if a node exists but it is /// not a **Graph** node. - pub fn graph_node(&self, id: &NodeId) -> Option<&GraphNode> { + pub fn graph_node(&self, id: &NodeId) -> Option<&ProjectGraph> { self.nodes.get(id).and_then(|kind| match kind { NodeKind::Graph(ref graph) => Some(graph), _ => None, @@ -373,7 +380,7 @@ impl Project { /// Update the graph associated with the graph node at the given **NodeId**. pub fn update_graph(&mut self, id: &NodeId, update: F) -> Result<(), UpdateGraphError> where - F: FnOnce(&mut NodeIdGraph), + F: FnOnce(&mut NodeIdGraphNode), { match self.nodes.map.get_mut(id) { Some(NodeKind::Graph(ref mut node)) => update(&mut node.graph), @@ -481,24 +488,10 @@ impl NodeCollection { } impl<'a> Node for NodeRef<'a> { - fn n_inputs(&self) -> u32 { - match self { - NodeRef::Core(node) => node.n_inputs(), - NodeRef::Graph(graph) => graph.n_inputs(), - } - } - - fn n_outputs(&self) -> u32 { - match self { - NodeRef::Core(node) => node.n_outputs(), - NodeRef::Graph(graph) => graph.n_outputs(), - } - } - - fn expr(&self, args: Vec) -> syn::Expr { + fn evaluator(&self) -> node::Evaluator { match self { - NodeRef::Core(node) => node.expr(args), - NodeRef::Graph(graph) => graph.expr(args), + NodeRef::Core(node) => node.evaluator(), + NodeRef::Graph(graph) => graph.evaluator(), } } @@ -797,14 +790,14 @@ fn add_graph_node_to_collection

( workspace_dir: P, node_name: &str, cargo_config: &cargo::Config, - graph: NodeIdGraph, + graph: NodeIdGraphNode, nodes: &mut NodeCollection, ) -> Result where P: AsRef, { let package_id = open_node_package(&workspace_dir, node_name, cargo_config)?; - let kind = NodeKind::Graph(GraphNode { graph, package_id }); + let kind = NodeKind::Graph(ProjectGraph { graph, package_id }); let node_id = nodes.insert(kind); let file = graph_node_src(&node_id, nodes).expect("no graph node for NodeId"); graph_node_replace_src(&workspace_dir, cargo_config, &node_id, nodes, file)?; @@ -839,7 +832,7 @@ where fn graph_node_compile<'conf, P>( workspace_dir: P, cargo_config: &'conf cargo::Config, - node: &GraphNode, + node: &ProjectGraph, ) -> Result, GraphNodeCompileError> where P: AsRef, @@ -859,12 +852,14 @@ where Ok(compilation) } -// Given a `NodeIdGraph` and `NodeCollection`, return a graph capable of evaluation. +// Given a `NodeIdGraphNode` and `NodeCollection`, return a graph capable of evaluation. fn id_graph_to_node_graph<'a>( - g: &NodeIdGraph, + g: &NodeIdGraphNode, ns: &'a NodeCollection, -) -> graph::StableGraph> { - g.map( +) -> GraphNode>> { + let inlets = g.inlets.clone(); + let outlets = g.outlets.clone(); + let graph = g.graph.map( |_, n_id| { match ns[n_id] { NodeKind::Core(ref node) => NodeRef::Core(node.node()), @@ -876,7 +871,8 @@ fn id_graph_to_node_graph<'a>( |_, edge| { edge.clone() }, - ) + ); + GraphNode { graph, inlets, outlets } } // Generate a src file for the graph node associated with the given `NodeId`. @@ -885,7 +881,7 @@ fn id_graph_to_node_graph<'a>( fn graph_node_src(id: &NodeId, nodes: &NodeCollection) -> Option { if let Some(NodeKind::Graph(ref node)) = nodes.get(id) { let graph = id_graph_to_node_graph(&node.graph, nodes); - return Some(graph::codegen::file(&graph)); + return Some(graph::codegen::file(&graph.graph, &graph.inlets, &graph.outlets)); } None }