Skip to content

Commit 43ae2dc

Browse files
authored
chore: AggregateFnRef [de]serialize (#7070)
Signed-off-by: blaginin <github@blaginin.me>
1 parent ea106e4 commit 43ae2dc

File tree

12 files changed

+241
-10
lines changed

12 files changed

+241
-10
lines changed

vortex-array/public-api.lock

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::reset(&self, pa
8282

8383
pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
8484

85-
pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
85+
pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
8686

8787
pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
8888

@@ -120,7 +120,7 @@ pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::coerce_args(&self,
120120

121121
pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
122122

123-
pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
123+
pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
124124

125125
pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
126126

@@ -228,7 +228,7 @@ pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::reset(&self, partial: &
228228

229229
pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
230230

231-
pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
231+
pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
232232

233233
pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
234234

@@ -306,7 +306,7 @@ pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::reset(&self, partia
306306

307307
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
308308

309-
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
309+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
310310

311311
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
312312

@@ -368,7 +368,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::reset(&self, partial: &mut Sel
368368

369369
pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
370370

371-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
371+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
372372

373373
pub fn vortex_array::aggregate_fn::fns::sum::Sum::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
374374

@@ -388,6 +388,8 @@ pub fn vortex_array::aggregate_fn::kernels::DynGroupedAggregateKernel::grouped_a
388388

389389
pub fn vortex_array::aggregate_fn::kernels::DynGroupedAggregateKernel::grouped_aggregate_fixed_size(&self, aggregate_fn: &vortex_array::aggregate_fn::AggregateFnRef, groups: &vortex_array::arrays::FixedSizeListArray) -> vortex_error::VortexResult<core::option::Option<vortex_array::ArrayRef>>
390390

391+
pub mod vortex_array::aggregate_fn::proto
392+
391393
pub mod vortex_array::aggregate_fn::session
392394

393395
pub struct vortex_array::aggregate_fn::session::AggregateFnSession
@@ -500,6 +502,12 @@ pub fn vortex_array::aggregate_fn::AggregateFnRef::state_dtype(&self, input_dtyp
500502

501503
pub fn vortex_array::aggregate_fn::AggregateFnRef::vtable_ref<V: vortex_array::aggregate_fn::AggregateFnVTable>(&self) -> core::option::Option<&V>
502504

505+
impl vortex_array::aggregate_fn::AggregateFnRef
506+
507+
pub fn vortex_array::aggregate_fn::AggregateFnRef::from_proto(proto: &vortex_proto::expr::AggregateFn, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self>
508+
509+
pub fn vortex_array::aggregate_fn::AggregateFnRef::serialize_proto(&self) -> vortex_error::VortexResult<vortex_proto::expr::AggregateFn>
510+
503511
impl core::clone::Clone for vortex_array::aggregate_fn::AggregateFnRef
504512

505513
pub fn vortex_array::aggregate_fn::AggregateFnRef::clone(&self) -> vortex_array::aggregate_fn::AggregateFnRef
@@ -638,7 +646,7 @@ pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::reset(&self, pa
638646

639647
pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
640648

641-
pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
649+
pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
642650

643651
pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
644652

@@ -654,7 +662,7 @@ pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::coerce_args(&self,
654662

655663
pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
656664

657-
pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
665+
pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
658666

659667
pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
660668

@@ -706,7 +714,7 @@ pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::reset(&self, partial: &
706714

707715
pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
708716

709-
pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
717+
pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
710718

711719
pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
712720

@@ -740,7 +748,7 @@ pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::reset(&self, partia
740748

741749
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
742750

743-
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
751+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
744752

745753
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
746754

@@ -774,7 +782,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::reset(&self, partial: &mut Sel
774782

775783
pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
776784

777-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
785+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
778786

779787
pub fn vortex_array::aggregate_fn::fns::sum::Sum::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
780788

vortex-array/src/aggregate_fn/fns/is_constant/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,18 @@ impl AggregateFnVTable for IsConstant {
259259
AggregateFnId::new_ref("vortex.is_constant")
260260
}
261261

262+
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
263+
Ok(Some(vec![]))
264+
}
265+
266+
fn deserialize(
267+
&self,
268+
_metadata: &[u8],
269+
_session: &vortex_session::VortexSession,
270+
) -> VortexResult<Self::Options> {
271+
Ok(EmptyOptions)
272+
}
273+
262274
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
263275
match input_dtype {
264276
DType::Null => None,

vortex-array/src/aggregate_fn/fns/is_sorted/mod.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use std::fmt::Formatter;
1313

1414
use vortex_error::VortexExpect;
1515
use vortex_error::VortexResult;
16+
use vortex_error::vortex_bail;
1617

1718
use self::bool::check_bool_sorted;
1819
use self::decimal::check_decimal_sorted;
@@ -231,6 +232,29 @@ impl AggregateFnVTable for IsSorted {
231232
AggregateFnId::new_ref("vortex.is_sorted")
232233
}
233234

235+
fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
236+
Ok(Some(vec![u8::from(options.strict)]))
237+
}
238+
239+
fn deserialize(
240+
&self,
241+
metadata: &[u8],
242+
_session: &vortex_session::VortexSession,
243+
) -> VortexResult<Self::Options> {
244+
let &[strict_byte] = metadata else {
245+
vortex_bail!(
246+
"IsSorted: expected 1 byte of metadata, got {}",
247+
metadata.len()
248+
);
249+
};
250+
let strict = match strict_byte {
251+
0 => false,
252+
1 => true,
253+
_ => vortex_bail!("IsSorted: expected 0 or 1 for strict, got {}", strict_byte),
254+
};
255+
Ok(IsSortedOptions { strict })
256+
}
257+
234258
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
235259
match input_dtype {
236260
DType::Null | DType::Struct(..) | DType::List(..) | DType::FixedSizeList(..) => None,

vortex-array/src/aggregate_fn/fns/min_max/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,18 @@ impl AggregateFnVTable for MinMax {
176176
AggregateFnId::new_ref("vortex.min_max")
177177
}
178178

179+
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
180+
Ok(Some(vec![]))
181+
}
182+
183+
fn deserialize(
184+
&self,
185+
_metadata: &[u8],
186+
_session: &vortex_session::VortexSession,
187+
) -> VortexResult<Self::Options> {
188+
Ok(EmptyOptions)
189+
}
190+
179191
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
180192
match input_dtype {
181193
DType::Bool(_)

vortex-array/src/aggregate_fn/fns/nan_count/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,18 @@ impl AggregateFnVTable for NanCount {
8484
AggregateFnId::new_ref("vortex.nan_count")
8585
}
8686

87+
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
88+
Ok(Some(vec![]))
89+
}
90+
91+
fn deserialize(
92+
&self,
93+
_metadata: &[u8],
94+
_session: &vortex_session::VortexSession,
95+
) -> VortexResult<Self::Options> {
96+
Ok(EmptyOptions)
97+
}
98+
8799
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
88100
if let DType::Primitive(ptype, ..) = input_dtype
89101
&& ptype.is_float()

vortex-array/src/aggregate_fn/fns/sum/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,18 @@ impl AggregateFnVTable for Sum {
7474
AggregateFnId::new_ref("vortex.sum")
7575
}
7676

77+
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
78+
Ok(Some(vec![]))
79+
}
80+
81+
fn deserialize(
82+
&self,
83+
_metadata: &[u8],
84+
_session: &vortex_session::VortexSession,
85+
) -> VortexResult<Self::Options> {
86+
Ok(EmptyOptions)
87+
}
88+
7789
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
7890
// When a sum overflows, we return a sum _value_ of null. Therefore, all return dtypes
7991
// are nullable.

vortex-array/src/aggregate_fn/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub use options::*;
3131

3232
pub mod fns;
3333
pub mod kernels;
34+
pub mod proto;
3435
pub mod session;
3536

3637
/// A unique identifier for an aggregate function.
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::sync::Arc;
5+
6+
use arcref::ArcRef;
7+
use vortex_error::VortexResult;
8+
use vortex_error::vortex_bail;
9+
use vortex_error::vortex_err;
10+
use vortex_proto::expr as pb;
11+
use vortex_session::VortexSession;
12+
13+
use crate::aggregate_fn::AggregateFnId;
14+
use crate::aggregate_fn::AggregateFnRef;
15+
use crate::aggregate_fn::session::AggregateFnSessionExt;
16+
17+
impl AggregateFnRef {
18+
/// Serialize this aggregate function to its protobuf representation.
19+
///
20+
/// Note: the serialization format is not stable and may change between versions.
21+
pub fn serialize_proto(&self) -> VortexResult<pb::AggregateFn> {
22+
let metadata = self
23+
.options()
24+
.serialize()?
25+
.ok_or_else(|| vortex_err!("Aggregate function '{}' is not serializable", self.id()))?;
26+
27+
Ok(pb::AggregateFn {
28+
id: self.id().to_string(),
29+
metadata: Some(metadata),
30+
})
31+
}
32+
33+
/// Deserialize an aggregate function from its protobuf representation.
34+
///
35+
/// Looks up the aggregate function plugin by ID in the session's registry
36+
/// and delegates deserialization to it.
37+
///
38+
/// Note: the serialization format is not stable and may change between versions.
39+
pub fn from_proto(proto: &pb::AggregateFn, session: &VortexSession) -> VortexResult<Self> {
40+
let agg_fn_id: AggregateFnId = ArcRef::new_arc(Arc::from(proto.id.as_str()));
41+
let plugin = session
42+
.aggregate_fns()
43+
.registry()
44+
.find(&agg_fn_id)
45+
.ok_or_else(|| vortex_err!("unknown aggregate function id: {}", proto.id))?;
46+
let agg_fn = plugin.deserialize(proto.metadata(), session)?;
47+
48+
if agg_fn.id() != agg_fn_id {
49+
vortex_bail!(
50+
"Aggregate function ID mismatch: expected {}, got {}",
51+
agg_fn_id,
52+
agg_fn.id()
53+
);
54+
}
55+
56+
Ok(agg_fn)
57+
}
58+
}
59+
60+
#[cfg(test)]
61+
mod tests {
62+
use prost::Message;
63+
use vortex_proto::expr as pb;
64+
use vortex_session::VortexSession;
65+
66+
use crate::aggregate_fn::AggregateFnRef;
67+
use crate::aggregate_fn::AggregateFnVTableExt;
68+
use crate::aggregate_fn::EmptyOptions;
69+
use crate::aggregate_fn::fns::sum::Sum;
70+
use crate::aggregate_fn::session::AggregateFnSession;
71+
use crate::aggregate_fn::session::AggregateFnSessionExt;
72+
73+
#[test]
74+
fn aggregate_fn_serde() {
75+
let session = VortexSession::empty().with::<AggregateFnSession>();
76+
session.aggregate_fns().register(Sum);
77+
78+
let agg_fn = Sum.bind(EmptyOptions);
79+
80+
let serialized = agg_fn.serialize_proto().unwrap();
81+
let buf = serialized.encode_to_vec();
82+
let deserialized_proto = pb::AggregateFn::decode(buf.as_slice()).unwrap();
83+
let deserialized = AggregateFnRef::from_proto(&deserialized_proto, &session).unwrap();
84+
85+
assert_eq!(deserialized, agg_fn);
86+
}
87+
}

vortex-array/src/aggregate_fn/session.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use crate::aggregate_fn::AggregateFnVTable;
1515
use crate::aggregate_fn::fns::is_constant::IsConstant;
1616
use crate::aggregate_fn::fns::is_sorted::IsSorted;
1717
use crate::aggregate_fn::fns::min_max::MinMax;
18+
use crate::aggregate_fn::fns::nan_count::NanCount;
19+
use crate::aggregate_fn::fns::sum::Sum;
1820
use crate::aggregate_fn::kernels::DynAggregateKernel;
1921
use crate::aggregate_fn::kernels::DynGroupedAggregateKernel;
2022
use crate::arrays::Chunked;
@@ -47,6 +49,13 @@ impl Default for AggregateFnSession {
4749
grouped_kernels: RwLock::new(HashMap::default()),
4850
};
4951

52+
// Register the built-in aggregate functions
53+
this.register(IsConstant);
54+
this.register(IsSorted);
55+
this.register(MinMax);
56+
this.register(NanCount);
57+
this.register(Sum);
58+
5059
// Register the built-in aggregate kernels.
5160
this.register_aggregate_kernel(Chunked::ID, None, &ChunkedArrayAggregate);
5261
this.register_aggregate_kernel(Dict::ID, Some(MinMax.id()), &DictMinMaxKernel);

vortex-proto/proto/expr.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ message Expr {
2020
optional bytes metadata = 3;
2121
}
2222

23+
// Captures a serialized aggregate function with its ID and options metadata.
24+
message AggregateFn {
25+
string id = 1;
26+
optional bytes metadata = 2;
27+
}
28+
2329
// Options for `vortex.literal`
2430
message LiteralOpts {
2531
vortex.scalar.Scalar value = 1;

0 commit comments

Comments
 (0)