diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index ec850047e586..e52aeb1aee12 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -15,13 +15,21 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, Int32Array, Int32Builder}; -use arrow::datatypes::{Field, Schema}; +use arrow::array::{Array, ArrayRef, Int32Array, Int32Builder, StringArray}; +use arrow::datatypes::{ArrowNativeTypeOp, Field, Schema}; use arrow::record_batch::RecordBatch; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use arrow::util::test_util::seedable_rng; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{case, col, lit, BinaryExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use itertools::Itertools; +use rand::distr::uniform::SampleUniform; +use rand::distr::Alphanumeric; +use rand::rngs::StdRng; +use rand::{Rng, RngCore}; +use std::fmt::{Display, Formatter}; +use std::ops::Range; use std::sync::Arc; fn make_x_cmp_y( @@ -82,6 +90,8 @@ fn criterion_benchmark(c: &mut Criterion) { run_benchmarks(c, &make_batch(8192, 3)); run_benchmarks(c, &make_batch(8192, 50)); run_benchmarks(c, &make_batch(8192, 100)); + + benchmark_lookup_table_case_when(c, 8192); } fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { @@ -230,5 +240,281 @@ fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { }); } +struct Options { + number_of_rows: usize, + range_of_values: Vec, + in_range_probability: f32, + null_probability: f32, +} + +fn generate_other_primitive_value( + rng: &mut impl RngCore, + exclude: &[T], +) -> T { + let mut value; + let retry_limit = 100; + for _ in 0..retry_limit { + value = rng.random_range(T::MIN_TOTAL_ORDER..=T::MAX_TOTAL_ORDER); + if !exclude.contains(&value) { + return value; + } + } + + panic!("Could not generate out of range value after {retry_limit} attempts"); +} + +fn create_random_string_generator( + length: Range, +) -> impl Fn(&mut dyn RngCore, &[String]) -> String { + assert!(length.end > length.start); + + move |rng, exclude| { + let retry_limit = 100; + for _ in 0..retry_limit { + let length = rng.random_range(length.clone()); + let value: String = rng + .sample_iter(Alphanumeric) + .take(length) + .map(char::from) + .collect(); + + if !exclude.contains(&value) { + return value; + } + } + + panic!("Could not generate out of range value after {retry_limit} attempts"); + } +} + +/// Create column with the provided number of rows +/// `in_range_percentage` is the percentage of values that should be inside the specified range +/// `null_percentage` is the percentage of null values +/// The rest of the values will be outside the specified range +fn generate_values_for_lookup( + options: Options, + generate_other_value: impl Fn(&mut StdRng, &[T]) -> T, +) -> A +where + T: Clone, + A: FromIterator>, +{ + // Create a value with specified range most of the time, but also some nulls and the rest is generic + + assert!( + options.in_range_probability + options.null_probability <= 1.0, + "Percentages must sum to 1.0 or less" + ); + + let rng = &mut seedable_rng(); + + let in_range_probability = 0.0..options.in_range_probability; + let null_range_probability = + in_range_probability.start..in_range_probability.start + options.null_probability; + let out_range_probability = null_range_probability.end..1.0; + + (0..options.number_of_rows) + .map(|_| { + let roll: f32 = rng.random(); + + match roll { + v if out_range_probability.contains(&v) => { + let index = rng.random_range(0..options.range_of_values.len()); + // Generate value in range + Some(options.range_of_values[index].clone()) + } + v if null_range_probability.contains(&v) => None, + _ => { + // Generate value out of range + Some(generate_other_value(rng, &options.range_of_values)) + } + } + }) + .collect::() +} + +fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { + #[derive(Clone, Copy, Debug)] + struct CaseWhenLookupInput { + batch_size: usize, + + in_range_probability: f32, + null_probability: f32, + } + + impl Display for CaseWhenLookupInput { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "case_when {} rows: in_range: {}, nulls: {}", + self.batch_size, self.in_range_probability, self.null_probability, + ) + } + } + + let mut case_when_lookup = c.benchmark_group("lookup_table_case_when"); + + for in_range_probability in [0.1, 0.5, 0.9, 1.0] { + for null_probability in [0.0, 0.1, 0.5] { + if in_range_probability + null_probability > 1.0 { + continue; + } + + let input = CaseWhenLookupInput { + batch_size, + in_range_probability, + null_probability, + }; + + let when_thens_primitive_to_string = vec![ + (1, "something"), + (2, "very"), + (3, "interesting"), + (4, "is"), + (5, "going"), + (6, "to"), + (7, "happen"), + (30, "in"), + (31, "datafusion"), + (90, "when"), + (91, "you"), + (92, "find"), + (93, "it"), + (120, "let"), + (240, "me"), + (241, "know"), + (244, "please"), + (246, "thank"), + (250, "you"), + (252, "!"), + ]; + let when_thens_string_to_primitive = when_thens_primitive_to_string + .iter() + .map(|&(key, value)| (value, key)) + .collect_vec(); + + for num_entries in [5, 10, 20] { + for (name, values_range) in [ + ("all equally true", 0..num_entries), + // Test when early termination is beneficial + ("only first 2 are true", 0..2), + ] { + let when_thens_primitive_to_string = + when_thens_primitive_to_string[values_range.clone()].to_vec(); + + let when_thens_string_to_primitive = + when_thens_string_to_primitive[values_range].to_vec(); + + case_when_lookup.bench_with_input( + BenchmarkId::new( + format!( + "case when i32 -> utf8, {num_entries} entries, {name}" + ), + input, + ), + &input, + |b, input| { + let array: Int32Array = generate_values_for_lookup( + Options:: { + number_of_rows: batch_size, + range_of_values: when_thens_primitive_to_string + .iter() + .map(|(key, _)| *key) + .collect(), + in_range_probability: input.in_range_probability, + null_probability: input.null_probability, + }, + |rng, exclude| { + generate_other_primitive_value::(rng, exclude) + }, + ); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "col1", + array.data_type().clone(), + true, + )])), + vec![Arc::new(array)], + ) + .unwrap(); + + let when_thens = when_thens_primitive_to_string + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); + + let expr = Arc::new( + case( + Some(col("col1", batch.schema_ref()).unwrap()), + when_thens, + Some(lit("whatever")), + ) + .unwrap(), + ); + + b.iter(|| { + black_box(expr.evaluate(black_box(&batch)).unwrap()) + }) + }, + ); + + case_when_lookup.bench_with_input( + BenchmarkId::new( + format!( + "case when utf8 -> i32, {num_entries} entries, {name}" + ), + input, + ), + &input, + |b, input| { + let array: StringArray = generate_values_for_lookup( + Options:: { + number_of_rows: batch_size, + range_of_values: when_thens_string_to_primitive + .iter() + .map(|(key, _)| (*key).to_string()) + .collect(), + in_range_probability: input.in_range_probability, + null_probability: input.null_probability, + }, + |rng, exclude| { + create_random_string_generator(3..10)(rng, exclude) + }, + ); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "col1", + array.data_type().clone(), + true, + )])), + vec![Arc::new(array)], + ) + .unwrap(); + + let when_thens = when_thens_string_to_primitive + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); + + let expr = Arc::new( + case( + Some(col("col1", batch.schema_ref()).unwrap()), + when_thens, + Some(lit(1000)), + ) + .unwrap(), + ); + + b.iter(|| { + black_box(expr.evaluate(black_box(&batch)).unwrap()) + }) + }, + ); + } + } + } + } +} + criterion_group!(benches, criterion_benchmark); criterion_main!(benches);