From c5bd7bc08f0824b7717ed1a4be0567026298419e Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Mon, 26 Aug 2024 12:42:21 +0300 Subject: [PATCH] Serialize and deserialize `NullTreatment` (#261) --- datafusion/proto/proto/datafusion.proto | 2 ++ datafusion/proto/src/generated/pbjson.rs | 36 +++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 4 +++ .../proto/src/logical_plan/from_proto.rs | 9 +++-- datafusion/proto/src/logical_plan/to_proto.rs | 16 +++++---- 5 files changed, 56 insertions(+), 11 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 2a9fa368eb71..8497c21a3568 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -515,6 +515,7 @@ message AggregateExprNode { bool distinct = 3; LogicalExprNode filter = 4; repeated LogicalExprNode order_by = 5; + bool ignore_nulls = 6; } message AggregateUDFExprNode { @@ -524,6 +525,7 @@ message AggregateUDFExprNode { repeated LogicalExprNode order_by = 4; bool distinct = 5; optional bytes fun_definition = 6; + bool ignore_nulls = 7; } message ScalarUDFExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 5e98d2ab6984..15fe00201e6f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -385,6 +385,9 @@ impl serde::Serialize for AggregateExprNode { if !self.order_by.is_empty() { len += 1; } + if self.ignore_nulls { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExprNode", len)?; if self.aggr_function != 0 { let v = AggregateFunction::try_from(self.aggr_function) @@ -403,6 +406,9 @@ impl serde::Serialize for AggregateExprNode { if !self.order_by.is_empty() { struct_ser.serialize_field("orderBy", &self.order_by)?; } + if self.ignore_nulls { + struct_ser.serialize_field("ignoreNulls", &self.ignore_nulls)?; + } struct_ser.end() } } @@ -420,6 +426,8 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { "filter", "order_by", "orderBy", + "ignore_nulls", + "ignoreNulls", ]; #[allow(clippy::enum_variant_names)] @@ -429,6 +437,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { Distinct, Filter, OrderBy, + IgnoreNulls, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -455,6 +464,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { "distinct" => Ok(GeneratedField::Distinct), "filter" => Ok(GeneratedField::Filter), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), + "ignoreNulls" | "ignore_nulls" => Ok(GeneratedField::IgnoreNulls), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -479,6 +489,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { let mut distinct__ = None; let mut filter__ = None; let mut order_by__ = None; + let mut ignore_nulls__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::AggrFunction => { @@ -511,6 +522,12 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { } order_by__ = Some(map_.next_value()?); } + GeneratedField::IgnoreNulls => { + if ignore_nulls__.is_some() { + return Err(serde::de::Error::duplicate_field("ignoreNulls")); + } + ignore_nulls__ = Some(map_.next_value()?); + } } } Ok(AggregateExprNode { @@ -519,6 +536,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { distinct: distinct__.unwrap_or_default(), filter: filter__, order_by: order_by__.unwrap_or_default(), + ignore_nulls: ignore_nulls__.unwrap_or_default(), }) } } @@ -916,6 +934,9 @@ impl serde::Serialize for AggregateUdfExprNode { if self.fun_definition.is_some() { len += 1; } + if self.ignore_nulls { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateUDFExprNode", len)?; if !self.fun_name.is_empty() { struct_ser.serialize_field("funName", &self.fun_name)?; @@ -936,6 +957,9 @@ impl serde::Serialize for AggregateUdfExprNode { #[allow(clippy::needless_borrow)] struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; } + if self.ignore_nulls { + struct_ser.serialize_field("ignoreNulls", &self.ignore_nulls)?; + } struct_ser.end() } } @@ -955,6 +979,8 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "distinct", "fun_definition", "funDefinition", + "ignore_nulls", + "ignoreNulls", ]; #[allow(clippy::enum_variant_names)] @@ -965,6 +991,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { OrderBy, Distinct, FunDefinition, + IgnoreNulls, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -992,6 +1019,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), "distinct" => Ok(GeneratedField::Distinct), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), + "ignoreNulls" | "ignore_nulls" => Ok(GeneratedField::IgnoreNulls), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1017,6 +1045,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { let mut order_by__ = None; let mut distinct__ = None; let mut fun_definition__ = None; + let mut ignore_nulls__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::FunName => { @@ -1057,6 +1086,12 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) ; } + GeneratedField::IgnoreNulls => { + if ignore_nulls__.is_some() { + return Err(serde::de::Error::duplicate_field("ignoreNulls")); + } + ignore_nulls__ = Some(map_.next_value()?); + } } } Ok(AggregateUdfExprNode { @@ -1066,6 +1101,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { order_by: order_by__.unwrap_or_default(), distinct: distinct__.unwrap_or_default(), fun_definition: fun_definition__, + ignore_nulls: ignore_nulls__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index b62e9ff34810..7a608b24a4ac 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -759,6 +759,8 @@ pub struct AggregateExprNode { pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "5")] pub order_by: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "6")] + pub ignore_nulls: bool, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -775,6 +777,8 @@ pub struct AggregateUdfExprNode { pub distinct: bool, #[prost(bytes = "vec", optional, tag = "6")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, + #[prost(bool, tag = "7")] + pub ignore_nulls: bool, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b467c60001b6..057cc6b852c5 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; +use datafusion::sql::sqlparser::ast::NullTreatment; use datafusion_common::{ internal_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, TableReference, UnnestOptions, @@ -410,16 +411,14 @@ pub fn parse_expr( } } ExprType::AggregateExpr(expr) => { - let fun = parse_i32_to_aggregate_function(&expr.aggr_function)?; - Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, + parse_i32_to_aggregate_function(&expr.aggr_function)?, parse_exprs(&expr.expr, registry, codec)?, expr.distinct, parse_optional_expr(expr.filter.as_deref(), registry, codec)? .map(Box::new), parse_vec_expr(&expr.order_by, registry, codec)?, - None, + expr.ignore_nulls.then_some(NullTreatment::IgnoreNulls), ))) } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( @@ -661,7 +660,7 @@ pub fn parse_expr( pb.distinct, parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), parse_vec_expr(&pb.order_by, registry, codec)?, - None, + pb.ignore_nulls.then_some(NullTreatment::IgnoreNulls), ))) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 06f1aea7dfa2..12025eb99e5b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -19,6 +19,7 @@ //! DataFusion logical plans to be serialized and transmitted between //! processes. +use datafusion::sql::sqlparser::ast::NullTreatment; use datafusion_common::{TableReference, UnnestOptions}; use datafusion_expr::expr::{ self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GroupingSet, @@ -401,12 +402,12 @@ pub fn serialize_expr( } } Expr::AggregateFunction(expr::AggregateFunction { - ref func_def, - ref args, - ref distinct, - ref filter, - ref order_by, - null_treatment: _, + func_def, + args, + distinct, + filter, + order_by, + null_treatment, }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { @@ -479,6 +480,7 @@ pub fn serialize_expr( Some(e) => serialize_exprs(e, codec)?, None => vec![], }, + ignore_nulls: null_treatment == &Some(NullTreatment::IgnoreNulls), }; protobuf::LogicalExprNode { expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), @@ -504,6 +506,8 @@ pub fn serialize_expr( }, distinct: *distinct, fun_definition: (!buf.is_empty()).then_some(buf), + ignore_nulls: null_treatment + == &Some(NullTreatment::IgnoreNulls), }, ))), }