diff --git a/optd-core/src/cascades/optimizer.rs b/optd-core/src/cascades/optimizer.rs
index 2af46e12..001442b8 100644
--- a/optd-core/src/cascades/optimizer.rs
+++ b/optd-core/src/cascades/optimizer.rs
@@ -5,9 +5,13 @@

 use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
 use std::fmt::Display;
+use std::hash::{DefaultHasher, Hash, Hasher};
+use std::io::Write;
+use std::path::Path;
 use std::sync::Arc;

 use anyhow::Result;
+use itertools::Itertools;
 use tracing::trace;

 use super::memo::{ArcMemoPlanNode, GroupInfo, Memo};
@@ -40,6 +44,8 @@ pub struct OptimizerProperties {
     pub partial_explore_space: Option<usize>,
     /// Disable pruning during optimization.
     pub disable_pruning: bool,
+    /// Dump the memo as 0000.dot, 0001.dot, ... files to this path.
+    pub dot_file_path: Option<String>,
 }

 pub struct CascadesOptimizer<T: NodeType, M: Memo<T> = NaiveMemo<T>> {
@@ -54,6 +60,7 @@ pub struct CascadesOptimizer<T: NodeType, M: Memo<T> = NaiveMemo<T>> {
     property_builders: Arc<[Box<dyn LogicalPropertyBuilderAny<T>>]>,
     pub ctx: OptimizerContext,
     pub prop: OptimizerProperties,
+    pub next_dot_file: usize,
 }

 /// `RelNode` only contains the representation of the plan nodes. Sometimes, we need more context,
@@ -123,6 +130,7 @@ impl<T: NodeType> CascadesOptimizer<T, NaiveMemo<T>> {
             property_builders,
             prop,
             disabled_rules: HashSet::new(),
+            next_dot_file: 0,
         }
     }

@@ -170,6 +178,93 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
         self.disabled_rules.contains(&rule_id)
     }

+    pub fn dump_dot(&self, writer: &mut dyn Write) -> Result<(), std::io::Error> {
+        let memo = self.memo();
+
+        // Collect all groups in a predictable iteration order.
+        let groups: Vec<GroupId> = memo.get_all_group_ids().iter().sorted().cloned().collect();
+
+        writeln!(writer, "digraph Memo {{")?;
+        writeln!(
+            writer,
+            "compound=true; ranksep=1.0; node [colorscheme=set312];"
+        )?;
+        for group_id in groups.iter() {
+            let group = memo.get_group(*group_id);
+            writeln!(writer, "subgraph cluster_{} {{", group_id.0)?;
+            writeln!(writer, "rank=source;")?;
+            writeln!(writer, "edge [style=invis];")?;
+            writeln!(
+                writer,
+                "g{} [shape=plaintext,label=\"group_id=!{}\"];",
+                group_id.0, group_id.0
+            )?;
+            for expr_id in memo.get_all_exprs_in_group(*group_id).iter() {
+                let expr = memo.get_expr_memoed(*expr_id);
+                let mut s = DefaultHasher::new();
+                expr.typ.hash(&mut s);
+                let color = (s.finish() % 11) + 1; // %11 looks better than %12! :-)
+                let shape = if expr.typ.is_logical() { "oval" } else { "box" };
+                let rules = match self.fired_rules.get(expr_id) {
+                    None => 0,
+                    Some(v) => v.len(),
+                };
+                writeln!(
+                    writer,
+                    "e{} [shape={},label=\"{}: {:?} ({})\",style=filled,color={}]",
+                    expr_id.0, shape, expr_id.0, expr.typ, rules, color
+                )?;
+                writeln!(writer, "g{} -> e{};", group_id.0, expr_id.0)?;
+            }
+            writeln!(writer, "}}")?;
+        }
+        for group_id in groups.iter() {
+            for expr_id in memo.get_all_exprs_in_group(*group_id).iter() {
+                let expr = memo.get_expr_memoed(*expr_id);
+                for child in expr.children.iter() {
+                    writeln!(
+                        writer,
+                        "e{} -> g{} [lhead=\"cluster_{}\"];",
+                        expr_id.0, child.0, child.0
+                    )?;
+                }
+                let mut next_pred: usize = 0;
+                for pred_id in expr.predicates.iter() {
+                    let pred = memo.get_pred(*pred_id);
+                    let id = next_pred;
+                    self.dump_dot_pred(writer, &pred, expr_id.0, &mut next_pred)?;
+                    writeln!(writer, "e{} -> p{}_{};", expr_id.0, expr_id.0, id)?;
+                }
+            }
+        }
+        writeln!(writer, "}}")
+    }
+
+    fn dump_dot_pred(
+        &self,
+        writer: &mut dyn Write,
+        pred: &ArcPredNode<T>,
+        base: usize,
+        next_pred: &mut usize,
+    ) -> Result<(), std::io::Error> {
+        let mut s = DefaultHasher::new();
+        pred.typ.hash(&mut s);
+        let color = (s.finish() % 11) + 1;
+        let id = *next_pred;
+        *next_pred += 1;
+        writeln!(
+            writer,
+            "p{}_{} [shape=diamond,label=\"{:?}\",penwidth=3,color={}]",
+            base, id, pred.typ, color
+        )?;
+        for child in pred.children.iter() {
+            let child_id = *next_pred;
+            self.dump_dot_pred(writer, child, base, next_pred)?;
+            writeln!(writer, "p{}_{} -> p{}_{};", base, id, base, child_id)?;
+        }
+        Ok(())
+    }
+
     pub fn dump(&self) {
         for group_id in self.memo.get_all_group_ids() {
             let winner_str = match &self.memo.get_group_info(group_id).winner {
diff --git a/optd-core/src/cascades/tasks/apply_rule.rs b/optd-core/src/cascades/tasks/apply_rule.rs
index b6e62efb..d5932a9c 100644
--- a/optd-core/src/cascades/tasks/apply_rule.rs
+++ b/optd-core/src/cascades/tasks/apply_rule.rs
@@ -3,11 +3,12 @@
 // Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
 // https://opensource.org/licenses/MIT.

+use std::path::Path;
 use std::sync::Arc;

 use anyhow::Result;
 use itertools::Itertools;
-use tracing::trace;
+use tracing::{debug, trace};

 use super::Task;
 use crate::cascades::memo::ArcMemoPlanNode;
@@ -164,6 +165,14 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for ApplyRuleTask {
             return Ok(vec![]);
         }

+        if let Some(pathname) = &optimizer.prop.dot_file_path {
+            let path = Path::new(pathname).join(format!("{:08}.dot", optimizer.next_dot_file));
+            debug!("dumping memo to {:?}", path);
+            optimizer.next_dot_file += 1;
+            let mut writer = std::fs::File::create(path)?;
+            optimizer.dump_dot(&mut writer)?;
+        }
+
         let rule = optimizer.rules()[self.rule_id].clone();
         trace!(event = "task_begin", task = "apply_rule", expr_id = %self.expr_id, rule_id = %self.rule_id, rule = %rule.name());
diff --git a/optd-datafusion-repr/src/lib.rs b/optd-datafusion-repr/src/lib.rs
index 5a491d1d..eb224db5 100644
--- a/optd-datafusion-repr/src/lib.rs
+++ b/optd-datafusion-repr/src/lib.rs
@@ -152,6 +152,7 @@ impl DatafusionOptimizer {
                     partial_explore_iter: Some(1 << 20),
                     partial_explore_space: Some(1 << 10),
                     disable_pruning: false,
+                    dot_file_path: None,
                 },
             ),
             heuristic_optimizer: HeuristicsOptimizer::new_with_rules(
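
Usage note (not part of the diff): a minimal sketch of how the new knob might be enabled. The field names (`prop.dot_file_path`, `next_dot_file`) and the `dump_dot` method come from the diff above; the import paths, generic bounds, and the helper function itself are assumptions and may differ from the actual crate layout. The target directory must already exist, since `std::fs::File::create` does not create parent directories.

use optd_core::cascades::{CascadesOptimizer, Memo}; // assumed re-export paths
use optd_core::nodes::NodeType;                     // assumed re-export path

/// Point an already-constructed optimizer at a directory so that each
/// ApplyRuleTask first writes the current memo as 00000000.dot, 00000001.dot, ...
fn enable_memo_dumps<T: NodeType, M: Memo<T>>(optimizer: &mut CascadesOptimizer<T, M>, dir: &str) {
    optimizer.prop.dot_file_path = Some(dir.to_string());
    optimizer.next_dot_file = 0;
}

The resulting files can be rendered with Graphviz, e.g. `dot -Tsvg 00000000.dot -o memo-0.svg`. In the output, oval nodes are logical expressions, boxes are physical expressions, and diamonds are predicate trees hanging off their owning expression.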