|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -use arrow::array::{Array, ArrayRef, Int32Array, Int32Builder}; |
19 | | -use arrow::datatypes::{Field, Schema}; |
| 18 | +use arrow::array::{Array, ArrayRef, Int32Array, Int32Builder, StringArray}; |
| 19 | +use arrow::datatypes::{ArrowNativeTypeOp, Field, Schema}; |
20 | 20 | use arrow::record_batch::RecordBatch; |
21 | | -use criterion::{black_box, criterion_group, criterion_main, Criterion}; |
| 21 | +use arrow::util::test_util::seedable_rng; |
| 22 | +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; |
22 | 23 | use datafusion_expr::Operator; |
23 | 24 | use datafusion_physical_expr::expressions::{case, col, lit, BinaryExpr}; |
24 | 25 | use datafusion_physical_expr_common::physical_expr::PhysicalExpr; |
| 26 | +use itertools::Itertools; |
| 27 | +use rand::distr::uniform::SampleUniform; |
| 28 | +use rand::distr::Alphanumeric; |
| 29 | +use rand::rngs::StdRng; |
| 30 | +use rand::{Rng, RngCore}; |
| 31 | +use std::fmt::{Display, Formatter}; |
| 32 | +use std::ops::Range; |
25 | 33 | use std::sync::Arc; |
26 | 34 |
|
27 | 35 | fn make_x_cmp_y( |
@@ -82,6 +90,8 @@ fn criterion_benchmark(c: &mut Criterion) { |
82 | 90 | run_benchmarks(c, &make_batch(8192, 3)); |
83 | 91 | run_benchmarks(c, &make_batch(8192, 50)); |
84 | 92 | run_benchmarks(c, &make_batch(8192, 100)); |
| 93 | + |
| 94 | + benchmark_lookup_table_case_when(c, 8192); |
85 | 95 | } |
86 | 96 |
|
87 | 97 | fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { |
@@ -230,5 +240,281 @@ fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { |
230 | 240 | }); |
231 | 241 | } |
232 | 242 |
|
/// Settings describing the column produced by `generate_values_for_lookup`.
struct Options<T> {
    /// Number of rows (elements) to generate.
    number_of_rows: usize,
    /// Candidate "in range" values; in-range rows hold a clone of one of these.
    /// The same slice is handed to the out-of-range generator as its exclusion list.
    range_of_values: Vec<T>,
    /// Target fraction of rows (`0.0..=1.0`) whose value comes from `range_of_values`.
    in_range_probability: f32,
    /// Target fraction of rows (`0.0..=1.0`) that are null.
    /// Together with `in_range_probability` it must not exceed `1.0`.
    null_probability: f32,
}
| 249 | + |
| 250 | +fn generate_other_primitive_value<T: ArrowNativeTypeOp + SampleUniform>( |
| 251 | + rng: &mut impl RngCore, |
| 252 | + exclude: &[T], |
| 253 | +) -> T { |
| 254 | + let mut value; |
| 255 | + let retry_limit = 100; |
| 256 | + for _ in 0..retry_limit { |
| 257 | + value = rng.random_range(T::MIN_TOTAL_ORDER..=T::MAX_TOTAL_ORDER); |
| 258 | + if !exclude.contains(&value) { |
| 259 | + return value; |
| 260 | + } |
| 261 | + } |
| 262 | + |
| 263 | + panic!("Could not generate out of range value after {retry_limit} attempts"); |
| 264 | +} |
| 265 | + |
| 266 | +fn create_random_string_generator( |
| 267 | + length: Range<usize>, |
| 268 | +) -> impl Fn(&mut dyn RngCore, &[String]) -> String { |
| 269 | + assert!(length.end > length.start); |
| 270 | + |
| 271 | + move |rng, exclude| { |
| 272 | + let retry_limit = 100; |
| 273 | + for _ in 0..retry_limit { |
| 274 | + let length = rng.random_range(length.clone()); |
| 275 | + let value: String = rng |
| 276 | + .sample_iter(Alphanumeric) |
| 277 | + .take(length) |
| 278 | + .map(char::from) |
| 279 | + .collect(); |
| 280 | + |
| 281 | + if !exclude.contains(&value) { |
| 282 | + return value; |
| 283 | + } |
| 284 | + } |
| 285 | + |
| 286 | + panic!("Could not generate out of range value after {retry_limit} attempts"); |
| 287 | + } |
| 288 | +} |
| 289 | + |
| 290 | +/// Create column with the provided number of rows |
| 291 | +/// `in_range_percentage` is the percentage of values that should be inside the specified range |
| 292 | +/// `null_percentage` is the percentage of null values |
| 293 | +/// The rest of the values will be outside the specified range |
| 294 | +fn generate_values_for_lookup<T, A>( |
| 295 | + options: Options<T>, |
| 296 | + generate_other_value: impl Fn(&mut StdRng, &[T]) -> T, |
| 297 | +) -> A |
| 298 | +where |
| 299 | + T: Clone, |
| 300 | + A: FromIterator<Option<T>>, |
| 301 | +{ |
| 302 | + // Create a value with specified range most of the time, but also some nulls and the rest is generic |
| 303 | + |
| 304 | + assert!( |
| 305 | + options.in_range_probability + options.null_probability <= 1.0, |
| 306 | + "Percentages must sum to 1.0 or less" |
| 307 | + ); |
| 308 | + |
| 309 | + let rng = &mut seedable_rng(); |
| 310 | + |
| 311 | + let in_range_probability = 0.0..options.in_range_probability; |
| 312 | + let null_range_probability = |
| 313 | + in_range_probability.start..in_range_probability.start + options.null_probability; |
| 314 | + let out_range_probability = null_range_probability.end..1.0; |
| 315 | + |
| 316 | + (0..options.number_of_rows) |
| 317 | + .map(|_| { |
| 318 | + let roll: f32 = rng.random(); |
| 319 | + |
| 320 | + match roll { |
| 321 | + v if out_range_probability.contains(&v) => { |
| 322 | + let index = rng.random_range(0..options.range_of_values.len()); |
| 323 | + // Generate value in range |
| 324 | + Some(options.range_of_values[index].clone()) |
| 325 | + } |
| 326 | + v if null_range_probability.contains(&v) => None, |
| 327 | + _ => { |
| 328 | + // Generate value out of range |
| 329 | + Some(generate_other_value(rng, &options.range_of_values)) |
| 330 | + } |
| 331 | + } |
| 332 | + }) |
| 333 | + .collect::<A>() |
| 334 | +} |
| 335 | + |
| 336 | +fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { |
| 337 | + #[derive(Clone, Copy, Debug)] |
| 338 | + struct CaseWhenLookupInput { |
| 339 | + batch_size: usize, |
| 340 | + |
| 341 | + in_range_probability: f32, |
| 342 | + null_probability: f32, |
| 343 | + } |
| 344 | + |
| 345 | + impl Display for CaseWhenLookupInput { |
| 346 | + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
| 347 | + write!( |
| 348 | + f, |
| 349 | + "case_when {} rows: in_range: {}, nulls: {}", |
| 350 | + self.batch_size, self.in_range_probability, self.null_probability, |
| 351 | + ) |
| 352 | + } |
| 353 | + } |
| 354 | + |
| 355 | + let mut case_when_lookup = c.benchmark_group("lookup_table_case_when"); |
| 356 | + |
| 357 | + for in_range_probability in [0.1, 0.5, 0.9, 1.0] { |
| 358 | + for null_probability in [0.0, 0.1, 0.5] { |
| 359 | + if in_range_probability + null_probability > 1.0 { |
| 360 | + continue; |
| 361 | + } |
| 362 | + |
| 363 | + let input = CaseWhenLookupInput { |
| 364 | + batch_size, |
| 365 | + in_range_probability, |
| 366 | + null_probability, |
| 367 | + }; |
| 368 | + |
| 369 | + let when_thens_primitive_to_string = vec![ |
| 370 | + (1, "something"), |
| 371 | + (2, "very"), |
| 372 | + (3, "interesting"), |
| 373 | + (4, "is"), |
| 374 | + (5, "going"), |
| 375 | + (6, "to"), |
| 376 | + (7, "happen"), |
| 377 | + (30, "in"), |
| 378 | + (31, "datafusion"), |
| 379 | + (90, "when"), |
| 380 | + (91, "you"), |
| 381 | + (92, "find"), |
| 382 | + (93, "it"), |
| 383 | + (120, "let"), |
| 384 | + (240, "me"), |
| 385 | + (241, "know"), |
| 386 | + (244, "please"), |
| 387 | + (246, "thank"), |
| 388 | + (250, "you"), |
| 389 | + (252, "!"), |
| 390 | + ]; |
| 391 | + let when_thens_string_to_primitive = when_thens_primitive_to_string |
| 392 | + .iter() |
| 393 | + .map(|&(key, value)| (value, key)) |
| 394 | + .collect_vec(); |
| 395 | + |
| 396 | + for num_entries in [5, 10, 20] { |
| 397 | + for (name, values_range) in [ |
| 398 | + ("all equally true", 0..num_entries), |
| 399 | + // Test when early termination is beneficial |
| 400 | + ("only first 2 are true", 0..2), |
| 401 | + ] { |
| 402 | + let when_thens_primitive_to_string = |
| 403 | + when_thens_primitive_to_string[values_range.clone()].to_vec(); |
| 404 | + |
| 405 | + let when_thens_string_to_primitive = |
| 406 | + when_thens_string_to_primitive[values_range].to_vec(); |
| 407 | + |
| 408 | + case_when_lookup.bench_with_input( |
| 409 | + BenchmarkId::new( |
| 410 | + format!( |
| 411 | + "case when i32 -> utf8, {num_entries} entries, {name}" |
| 412 | + ), |
| 413 | + input, |
| 414 | + ), |
| 415 | + &input, |
| 416 | + |b, input| { |
| 417 | + let array: Int32Array = generate_values_for_lookup( |
| 418 | + Options::<i32> { |
| 419 | + number_of_rows: batch_size, |
| 420 | + range_of_values: when_thens_primitive_to_string |
| 421 | + .iter() |
| 422 | + .map(|(key, _)| *key) |
| 423 | + .collect(), |
| 424 | + in_range_probability: input.in_range_probability, |
| 425 | + null_probability: input.null_probability, |
| 426 | + }, |
| 427 | + |rng, exclude| { |
| 428 | + generate_other_primitive_value::<i32>(rng, exclude) |
| 429 | + }, |
| 430 | + ); |
| 431 | + let batch = RecordBatch::try_new( |
| 432 | + Arc::new(Schema::new(vec![Field::new( |
| 433 | + "col1", |
| 434 | + array.data_type().clone(), |
| 435 | + true, |
| 436 | + )])), |
| 437 | + vec![Arc::new(array)], |
| 438 | + ) |
| 439 | + .unwrap(); |
| 440 | + |
| 441 | + let when_thens = when_thens_primitive_to_string |
| 442 | + .iter() |
| 443 | + .map(|&(key, value)| (lit(key), lit(value))) |
| 444 | + .collect(); |
| 445 | + |
| 446 | + let expr = Arc::new( |
| 447 | + case( |
| 448 | + Some(col("col1", batch.schema_ref()).unwrap()), |
| 449 | + when_thens, |
| 450 | + Some(lit("whatever")), |
| 451 | + ) |
| 452 | + .unwrap(), |
| 453 | + ); |
| 454 | + |
| 455 | + b.iter(|| { |
| 456 | + black_box(expr.evaluate(black_box(&batch)).unwrap()) |
| 457 | + }) |
| 458 | + }, |
| 459 | + ); |
| 460 | + |
| 461 | + case_when_lookup.bench_with_input( |
| 462 | + BenchmarkId::new( |
| 463 | + format!( |
| 464 | + "case when utf8 -> i32, {num_entries} entries, {name}" |
| 465 | + ), |
| 466 | + input, |
| 467 | + ), |
| 468 | + &input, |
| 469 | + |b, input| { |
| 470 | + let array: StringArray = generate_values_for_lookup( |
| 471 | + Options::<String> { |
| 472 | + number_of_rows: batch_size, |
| 473 | + range_of_values: when_thens_string_to_primitive |
| 474 | + .iter() |
| 475 | + .map(|(key, _)| (*key).to_string()) |
| 476 | + .collect(), |
| 477 | + in_range_probability: input.in_range_probability, |
| 478 | + null_probability: input.null_probability, |
| 479 | + }, |
| 480 | + |rng, exclude| { |
| 481 | + create_random_string_generator(3..10)(rng, exclude) |
| 482 | + }, |
| 483 | + ); |
| 484 | + let batch = RecordBatch::try_new( |
| 485 | + Arc::new(Schema::new(vec![Field::new( |
| 486 | + "col1", |
| 487 | + array.data_type().clone(), |
| 488 | + true, |
| 489 | + )])), |
| 490 | + vec![Arc::new(array)], |
| 491 | + ) |
| 492 | + .unwrap(); |
| 493 | + |
| 494 | + let when_thens = when_thens_string_to_primitive |
| 495 | + .iter() |
| 496 | + .map(|&(key, value)| (lit(key), lit(value))) |
| 497 | + .collect(); |
| 498 | + |
| 499 | + let expr = Arc::new( |
| 500 | + case( |
| 501 | + Some(col("col1", batch.schema_ref()).unwrap()), |
| 502 | + when_thens, |
| 503 | + Some(lit(1000)), |
| 504 | + ) |
| 505 | + .unwrap(), |
| 506 | + ); |
| 507 | + |
| 508 | + b.iter(|| { |
| 509 | + black_box(expr.evaluate(black_box(&batch)).unwrap()) |
| 510 | + }) |
| 511 | + }, |
| 512 | + ); |
| 513 | + } |
| 514 | + } |
| 515 | + } |
| 516 | + } |
| 517 | +} |
| 518 | + |
233 | 519 | criterion_group!(benches, criterion_benchmark); |
234 | 520 | criterion_main!(benches); |
0 commit comments