diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index fd86bf2f21b47..bb997d9905915 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -17883,4 +17883,72 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), } ) } + + #[tokio::test] + async fn test_tableau_trunc_extract_year_and_month_rev() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let logical_plan = convert_select_to_query_plan( + r#" + SELECT SUM("KibanaSampleDataEcommerce"."sumPrice") AS "sum:sumPrice:ok" + FROM "public"."KibanaSampleDataEcommerce" "KibanaSampleDataEcommerce" + WHERE ( + "KibanaSampleDataEcommerce"."id" != 0 + AND CAST(TRUNC(EXTRACT(YEAR FROM "KibanaSampleDataEcommerce"."order_date")) AS INTEGER) = 2024 + AND CAST(TRUNC(EXTRACT(MONTH FROM "KibanaSampleDataEcommerce"."order_date")) AS INTEGER) = 2 + AND "KibanaSampleDataEcommerce"."customer_gender" IS NOT NULL + ) + HAVING COUNT(1) > 0 + "# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await + .as_logical_plan(); + + assert_eq!( + logical_plan.find_cube_scan().request, + V1LoadRequestQuery { + measures: Some(vec!["KibanaSampleDataEcommerce.sumPrice".to_string(),]), + dimensions: Some(vec![]), + segments: Some(vec![]), + time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension { + dimension: "KibanaSampleDataEcommerce.order_date".to_string(), + granularity: None, + date_range: Some(json!(vec![ + "2024-02-01".to_string(), + "2024-02-29".to_string(), + ])), + }]), + order: Some(vec![]), + filters: Some(vec![ + V1LoadRequestQueryFilterItem { + member: Some("KibanaSampleDataEcommerce.id".to_string()), + operator: Some("notEquals".to_string()), + values: Some(vec!["0".to_string()]), + or: None, + and: None, + }, + V1LoadRequestQueryFilterItem { + member: Some("KibanaSampleDataEcommerce.customer_gender".to_string()), + operator: Some("set".to_string()), + values: None, + or: None, + and: None, + }, + V1LoadRequestQueryFilterItem { + member: Some("KibanaSampleDataEcommerce.count".to_string()), + operator: Some("gt".to_string()), + values: Some(vec!["0".to_string()]), + or: None, + and: None, + }, + ]), + ..Default::default() + } + ) + } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs index 2174ad74da60c..6980b200f1a34 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs @@ -50,6 +50,7 @@ use datafusion::{ }; use egg::{Subst, Var}; use std::{ + cmp::{max, min}, collections::HashSet, fmt::Display, ops::{Index, IndexMut}, @@ -1685,9 +1686,8 @@ impl RewriteRules for FilterRules { "?filter_aliases", ), filter_member("?member", "FilterMemberOp:inDateRange", "?values"), - self.transform_filter_extract_year_month_equals( + self.transform_filter_extract_year_equals( "?year", - None, "?column", "?alias_to_cube", "?members", @@ -1716,9 +1716,8 @@ impl RewriteRules for FilterRules { "?filter_aliases", ), filter_member("?member", "FilterMemberOp:inDateRange", "?values"), - self.transform_filter_extract_year_month_equals( + self.transform_filter_extract_year_equals( "?year", - None, "?column", "?alias_to_cube", "?members", @@ -1727,123 +1726,217 @@ impl RewriteRules for FilterRules { "?filter_aliases", ), ), - // TRUNC(EXTRACT(MONTH FROM "KibanaSampleDataEcommerce"."order_date")) = 3 - // AND TRUNC(EXTRACT(YEAR FROM "KibanaSampleDataEcommerce"."order_date")) = 2019 + // TRUNC(EXTRACT(YEAR FROM "KibanaSampleDataEcommerce"."order_date")) = 2019 + // AND TRUNC(EXTRACT(MONTH FROM "KibanaSampleDataEcommerce"."order_date")) = 3 transforming_rewrite( - "extract-trunc-year-and-month-equals", - filter_replacer( - binary_expr( - binary_expr( - self.fun_expr( - "Trunc", - vec![self.fun_expr( - "DatePart", - vec![literal_string("month"), column_expr("?column")], - )], - ), - "=", - literal_expr("?month"), - ), - "AND", - binary_expr( - self.fun_expr( - "Trunc", - vec![self.fun_expr( - "DatePart", - vec![literal_string("year"), column_expr("?column")], - )], + "extract-date-range-and-trunc-gran-equals", + filter_op( + filter_op_filters( + filter_member("?member", "FilterMemberOp:inDateRange", "?values"), + filter_replacer( + binary_expr( + self.fun_expr( + "Trunc", + vec![self.fun_expr( + "DatePart", + vec![literal_expr("?granularity"), column_expr("?column")], + )], + ), + "=", + literal_expr("?value"), ), - "=", - literal_expr("?year"), + "?alias_to_cube", + "?members", + "?filter_aliases", ), ), - "?alias_to_cube", - "?members", - "?filter_aliases", + "FilterOpOp:and", ), - filter_member("?member", "FilterMemberOp:inDateRange", "?values"), - self.transform_filter_extract_year_month_equals( - "?year", - Some("?month"), + filter_member("?member", "FilterMemberOp:inDateRange", "?new_values"), + self.transform_filter_extract_date_range_and_trunc_gran_equals( + "?member", + "?values", + "?granularity", "?column", + "?value", "?alias_to_cube", "?members", - "?member", - "?values", "?filter_aliases", + "?new_values", ), ), // When the filter set above is paired with other filters, it needs to be // regrouped for the above rewrite rule to match rewrite( - "extract-trunc-year-and-month-equals-regroup-binary", - filter_replacer( - binary_expr( - binary_expr( - "?expr", - "AND", + "extract-date-range-and-trunc-regroup-and", + filter_op( + filter_op_filters( + filter_op( + filter_op_filters( + "?expr", + filter_member("?member", "FilterMemberOp:inDateRange", "?values"), + ), + "FilterOpOp:and", + ), + filter_replacer( binary_expr( self.fun_expr( "Trunc", vec![self.fun_expr( "DatePart", - vec![literal_string("month"), column_expr("?column")], + vec![literal_expr("?granularity"), column_expr("?column")], )], ), "=", - literal_expr("?month"), + literal_expr("?value"), ), + "?alias_to_cube", + "?members", + "?filter_aliases", ), - "AND", - binary_expr( - self.fun_expr( - "Trunc", - vec![self.fun_expr( - "DatePart", - vec![literal_string("year"), column_expr("?column")], - )], + ), + "FilterOpOp:and", + ), + filter_op( + filter_op_filters( + "?expr", + filter_op( + filter_op_filters( + filter_member("?member", "FilterMemberOp:inDateRange", "?values"), + filter_replacer( + binary_expr( + self.fun_expr( + "Trunc", + vec![self.fun_expr( + "DatePart", + vec![ + literal_expr("?granularity"), + column_expr("?column"), + ], + )], + ), + "=", + literal_expr("?value"), + ), + "?alias_to_cube", + "?members", + "?filter_aliases", + ), ), - "=", - literal_expr("?year"), + "FilterOpOp:and", ), ), - "?alias_to_cube", - "?members", - "?filter_aliases", + "FilterOpOp:and", ), - filter_replacer( - binary_expr( - "?expr", - "AND", - binary_expr( + ), + // The filter set above may be inverted, let's account for that as well + rewrite( + "extract-date-range-and-trunc-reverse", + filter_op( + filter_op_filters( + filter_replacer( binary_expr( self.fun_expr( "Trunc", vec![self.fun_expr( "DatePart", - vec![literal_string("month"), column_expr("?column")], + vec![literal_expr("?granularity"), column_expr("?column")], )], ), "=", - literal_expr("?month"), + literal_expr("?value"), ), - "AND", + "?alias_to_cube", + "?members", + "?filter_aliases", + ), + filter_member("?member", "FilterMemberOp:inDateRange", "?values"), + ), + "FilterOpOp:and", + ), + filter_op( + filter_op_filters( + filter_member("?member", "FilterMemberOp:inDateRange", "?values"), + filter_replacer( binary_expr( self.fun_expr( "Trunc", vec![self.fun_expr( "DatePart", - vec![literal_string("year"), column_expr("?column")], + vec![literal_expr("?granularity"), column_expr("?column")], )], ), "=", - literal_expr("?year"), + literal_expr("?value"), ), + "?alias_to_cube", + "?members", + "?filter_aliases", ), ), - "?alias_to_cube", - "?members", - "?filter_aliases", + "FilterOpOp:and", + ), + ), + rewrite( + "extract-date-range-and-trunc-reverse-nested", + filter_op( + filter_op_filters( + filter_op( + filter_op_filters( + "?expr", + filter_replacer( + binary_expr( + self.fun_expr( + "Trunc", + vec![self.fun_expr( + "DatePart", + vec![ + literal_expr("?granularity"), + column_expr("?column"), + ], + )], + ), + "=", + literal_expr("?value"), + ), + "?alias_to_cube", + "?members", + "?filter_aliases", + ), + ), + "FilterOpOp:and", + ), + filter_member("?member", "FilterMemberOp:inDateRange", "?values"), + ), + "FilterOpOp:and", + ), + filter_op( + filter_op_filters( + filter_op( + filter_op_filters( + "?expr", + filter_member("?member", "FilterMemberOp:inDateRange", "?values"), + ), + "FilterOpOp:and", + ), + filter_replacer( + binary_expr( + self.fun_expr( + "Trunc", + vec![self.fun_expr( + "DatePart", + vec![literal_expr("?granularity"), column_expr("?column")], + )], + ), + "=", + literal_expr("?value"), + ), + "?alias_to_cube", + "?members", + "?filter_aliases", + ), + ), + "FilterOpOp:and", ), ), transforming_rewrite( @@ -3697,10 +3790,9 @@ impl FilterRules { } } - fn transform_filter_extract_year_month_equals( + fn transform_filter_extract_year_equals( &self, year_var: &'static str, - month_var: Option<&'static str>, column_var: &'static str, alias_to_cube_var: &'static str, members_var: &'static str, @@ -3709,7 +3801,6 @@ impl FilterRules { filter_aliases_var: &'static str, ) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool { let year_var = var!(year_var); - let month_var = month_var.map(|var| var!(var)); let column_var = var!(column_var); let alias_to_cube_var = var!(alias_to_cube_var); let members_var = var!(members_var); @@ -3718,121 +3809,230 @@ impl FilterRules { let filter_aliases_var = var!(filter_aliases_var); let meta_context = self.meta_context.clone(); move |egraph, subst| { - let Some(year) = - var_iter!(egraph[subst[year_var]], LiteralExprValue).find_map(|year| { + let years: Vec = var_iter!(egraph[subst[year_var]], LiteralExprValue) + .cloned() + .collect(); + if years.is_empty() { + return false; + } + let aliases_es: Vec> = + var_iter!(egraph[subst[filter_aliases_var]], FilterReplacerAliases) + .cloned() + .collect(); + for year in years { + for aliases in aliases_es.iter() { let year = match year { - ScalarValue::Int64(Some(year)) => *year, - ScalarValue::Int32(Some(year)) => *year as i64, - ScalarValue::Float64(Some(year)) if (1000.0..=9999.0).contains(year) => { + ScalarValue::Int64(Some(year)) => year, + ScalarValue::Int32(Some(year)) => year as i64, + ScalarValue::Float64(Some(year)) if (1000.0..=9999.0).contains(&year) => { year.round() as i64 } ScalarValue::Utf8(Some(ref year_str)) if year_str.len() == 4 => { if let Ok(year) = year_str.parse::() { year } else { - return None; + continue; } } - _ => return None, + _ => continue, }; + if !(1000..=9999).contains(&year) { + continue; + } + + if let Some((member_name, cube)) = Self::filter_member_name( + egraph, + subst, + &meta_context, + alias_to_cube_var, + column_var, + members_var, + &aliases, + ) { + if !cube.contains_member(&member_name) { + continue; + } + + subst.insert( + member_var, + egraph.add(LogicalPlanLanguage::FilterMemberMember( + FilterMemberMember(member_name.to_string()), + )), + ); + + subst.insert( + values_var, + egraph.add(LogicalPlanLanguage::FilterMemberValues( + FilterMemberValues(vec![ + format!("{}-01-01", year), + format!("{}-12-31", year), + ]), + )), + ); + + return true; + } + } + } + + false + } + } + + fn transform_filter_extract_date_range_and_trunc_gran_equals( + &self, + member_var: &'static str, + values_var: &'static str, + granularity_var: &'static str, + column_var: &'static str, + value_var: &'static str, + alias_to_cube_var: &'static str, + members_var: &'static str, + filter_aliases_var: &'static str, + new_values_var: &'static str, + ) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool { + let member_var = var!(member_var); + let values_var = var!(values_var); + let granularity_var = var!(granularity_var); + let column_var = var!(column_var); + let value_var = var!(value_var); + let alias_to_cube_var = var!(alias_to_cube_var); + let members_var = var!(members_var); + let filter_aliases_var = var!(filter_aliases_var); + let new_values_var = var!(new_values_var); + let meta_context = self.meta_context.clone(); + move |egraph, subst| { + // Validate that the member name is the same as the passed column + let member_names = var_iter!(egraph[subst[member_var]], FilterMemberMember) + .cloned() + .collect::>(); + let aliases_es = var_iter!(egraph[subst[filter_aliases_var]], FilterReplacerAliases) + .cloned() + .collect::>(); + let mut equal_member = false; + 'member: for member in member_names { + for aliases in &aliases_es { + let Some((member_name, cube)) = Self::filter_member_name( + egraph, + subst, + &meta_context, + alias_to_cube_var, + column_var, + members_var, + &aliases, + ) else { + continue; + }; + + if !cube.contains_member(&member_name) { + continue; + } + + if member_name != member { + continue; + } + + equal_member = true; + break 'member; + } + } + if !equal_member { + return false; + } + + // Get the original dates + let Some((start_date, end_date)) = + var_iter!(egraph[subst[values_var]], FilterMemberValues).find_map(|values| { + if values.len() != 2 { return None; } - Some(year as i32) + // Only date formats are supported for now, no timestamps + let start_date = NaiveDate::parse_from_str(&values[0], "%Y-%m-%d").ok()?; + let end_date = NaiveDate::parse_from_str(&values[1], "%Y-%m-%d").ok()?; + Some((start_date, end_date)) }) else { return false; }; - let month = if let Some(month_var) = month_var { - let month = - var_iter!(egraph[subst[month_var]], LiteralExprValue).find_map(|month| { - let month = match month { - ScalarValue::Int64(Some(month)) => *month, - ScalarValue::Int32(Some(month)) => *month as i64, - ScalarValue::Float64(Some(month)) if (1.0..=12.0).contains(month) => { - month.round() as i64 - } - ScalarValue::Utf8(Some(ref month_str)) - if (1..=2).contains(&month_str.len()) => - { - if let Ok(month) = month_str.parse::() { - month - } else { - return None; - } - } - _ => return None, - }; - if !(1..=12).contains(&month) { - return None; - } - Some(month as u32) - }); - if month.is_none() { - return false; - } - month - } else { - None + // Get the new granularity + let Some(granularity) = var_iter!(egraph[subst[granularity_var]], LiteralExprValue) + .find_map(|granularity| { + if let ScalarValue::Utf8(Some(granularity)) = granularity { + Some(granularity.clone()) + } else { + None + } + }) + else { + return false; }; - let last_day = { - let month = month.unwrap_or(12); - let next_month = if month == 12 { 1 } else { month + 1 }; - let next_month_year = if month == 12 { year + 1 } else { year }; - let Some(next_month_first_date) = - NaiveDate::from_ymd_opt(next_month_year, next_month, 1) - else { - return false; - }; - let Some(last_day_date) = next_month_first_date.checked_sub_days(Days::new(1)) - else { - return false; - }; - last_day_date.day() + // Get the value for that granularity + let Some(value) = var_iter!(egraph[subst[value_var]], LiteralExprValue).find_map( + |value| match value { + ScalarValue::Int64(Some(value)) => Some(*value), + ScalarValue::Int32(Some(value)) => Some(*value as i64), + ScalarValue::Float64(Some(value)) if (0.0..=9999.0).contains(value) => { + Some(value.round() as i64) + } + ScalarValue::Utf8(Some(value_str)) => value_str.parse::().ok(), + _ => None, + }, + ) else { + return false; }; - let aliases_es: Vec> = - var_iter!(egraph[subst[filter_aliases_var]], FilterReplacerAliases) - .cloned() - .collect(); - for aliases in aliases_es.iter() { - if let Some((member_name, cube)) = Self::filter_member_name( - egraph, - subst, - &meta_context, - alias_to_cube_var, - column_var, - members_var, - &aliases, - ) { - if !cube.contains_member(&member_name) { - continue; + let new_values = match granularity.as_str() { + "month" => { + // Check that the range only covers one year + let start_date_year = start_date.year(); + if start_date_year != end_date.year() { + return false; + } + // Month value must be valid + if !(1..=12).contains(&value) { + return false; } - subst.insert( - member_var, - egraph.add(LogicalPlanLanguage::FilterMemberMember(FilterMemberMember( - member_name.to_string(), - ))), - ); + // Obtain the new range + let Some(new_start_date) = + NaiveDate::from_ymd_opt(start_date_year, value as u32, 1) + else { + return false; + }; + let Some(new_end_date) = new_start_date + .checked_add_months(Months::new(1)) + .and_then(|date| date.checked_sub_days(Days::new(1))) + else { + return false; + }; - let date_range_start = format!("{}-{:0>2}-01", year, month.unwrap_or(1)); - let date_range_end = - format!("{}-{:0>2}-{}", year, month.unwrap_or(12), last_day); - subst.insert( - values_var, - egraph.add(LogicalPlanLanguage::FilterMemberValues(FilterMemberValues( - vec![date_range_start, date_range_end], - ))), - ); + // If the resulting range is outside of the original range, we can't merge + // the filters + if new_start_date > end_date || new_end_date < start_date { + return false; + } - return true; + let new_start_date = max(new_start_date, start_date); + let new_end_date = min(new_end_date, end_date); + vec![ + new_start_date.format("%Y-%m-%d").to_string(), + new_end_date.format("%Y-%m-%d").to_string(), + ] } - } + // TODO: handle more granularities + _ => return false, + }; - false + subst.insert( + new_values_var, + egraph.add(LogicalPlanLanguage::FilterMemberValues(FilterMemberValues( + new_values, + ))), + ); + true } }