diff --git a/datafusion/spark/src/function/bitwise/bit_count.rs b/datafusion/spark/src/function/bitwise/bit_count.rs index ba44d3bc0a95..4b414b57cb77 100644 --- a/datafusion/spark/src/function/bitwise/bit_count.rs +++ b/datafusion/spark/src/function/bitwise/bit_count.rs @@ -23,6 +23,7 @@ use arrow::datatypes::{ DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; +use datafusion_common::cast::as_boolean_array; use datafusion_common::{plan_err, Result}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, @@ -46,6 +47,7 @@ impl SparkBitCount { Self { signature: Signature::one_of( vec![ + TypeSignature::Exact(vec![DataType::Boolean]), TypeSignature::Exact(vec![DataType::Int8]), TypeSignature::Exact(vec![DataType::Int16]), TypeSignature::Exact(vec![DataType::Int32]), @@ -90,28 +92,34 @@ impl ScalarUDFImpl for SparkBitCount { fn spark_bit_count(value_array: &[ArrayRef]) -> Result { let value_array = value_array[0].as_ref(); match value_array.data_type() { + DataType::Boolean => { + let result: Int32Array = as_boolean_array(value_array)? 
+ .iter() + .map(|x| x.map(|y| y as i32)) + .collect(); + Ok(Arc::new(result)) + } DataType::Int8 => { let result: Int32Array = value_array .as_primitive::() - .unary(|v| v.count_ones() as i32); + .unary(|v| bit_count(v.into())); Ok(Arc::new(result)) } DataType::Int16 => { let result: Int32Array = value_array .as_primitive::() - .unary(|v| v.count_ones() as i32); + .unary(|v| bit_count(v.into())); Ok(Arc::new(result)) } DataType::Int32 => { let result: Int32Array = value_array .as_primitive::() - .unary(|v| v.count_ones() as i32); + .unary(|v| bit_count(v.into())); Ok(Arc::new(result)) } DataType::Int64 => { - let result: Int32Array = value_array - .as_primitive::() - .unary(|v| v.count_ones() as i32); + let result: Int32Array = + value_array.as_primitive::().unary(bit_count); Ok(Arc::new(result)) } DataType::UInt8 => { @@ -147,12 +155,26 @@ fn spark_bit_count(value_array: &[ArrayRef]) -> Result { } } +// Counts the set bits of a 64-bit value: a Rust port of Java's `Long.bitCount` (links below), matching Apache Spark's bitCount for LongType; the result equals `i.count_ones() as i32`. +// Spark: https://github.com/apache/spark/blob/ac717dd7aec665de578d7c6b0070e8fcdde3cea9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala#L243 +// Java impl: https://github.com/openjdk/jdk/blob/d226023643f90027a8980d161ec6d423887ae3ce/src/java.base/share/classes/java/lang/Long.java#L1584 +fn bit_count(i: i64) -> i32 { + let mut u = i as u64; + u = u - ((u >> 1) & 0x5555555555555555); + u = (u & 0x3333333333333333) + ((u >> 2) & 0x3333333333333333); + u = (u + (u >> 4)) & 0x0f0f0f0f0f0f0f0f; + u = u + (u >> 8); + u = u + (u >> 16); + u = u + (u >> 32); + (u as i32) & 0x7f +} + #[cfg(test)] mod tests { use super::*; use arrow::array::{ - Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + Array, BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, }; use arrow::datatypes::Int32Type; @@ 
-192,7 +214,18 @@ mod tests { assert_eq!(arr.value(2), 2); assert_eq!(arr.value(3), 3); assert_eq!(arr.value(4), 4); - assert_eq!(arr.value(5), 8); + assert_eq!(arr.value(5), 64); + } + + #[test] + fn test_bit_count_boolean() { + // Test bit_count on BooleanArray + let result = + spark_bit_count(&[Arc::new(BooleanArray::from(vec![true, false]))]).unwrap(); + + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), 0); } #[test] @@ -207,7 +240,7 @@ mod tests { assert_eq!(arr.value(1), 1); assert_eq!(arr.value(2), 8); assert_eq!(arr.value(3), 10); - assert_eq!(arr.value(4), 16); + assert_eq!(arr.value(4), 64); } #[test] @@ -222,7 +255,7 @@ mod tests { assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1 assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8 assert_eq!(arr.value(3), 10); // 0b00000000000000000000001111111111 = 10 - assert_eq!(arr.value(4), 32); // -1 in two's complement = all 32 bits set + assert_eq!(arr.value(4), 64); // -1 sign-extended to i64 = all 64 bits set } #[test] diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt b/datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt index 2a75c7648d40..216d99025171 100644 --- a/datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt +++ b/datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt @@ -59,17 +59,17 @@ SELECT bit_count(1023::int); query I SELECT bit_count(-1::int); ---- -32 +64 query I SELECT bit_count(-2::int); ---- -31 +63 query I SELECT bit_count(-3::int); ---- -31 +63 # Tests with different integer types query I @@ -85,7 +85,7 @@ SELECT bit_count(arrow_cast(15, 'Int8')); query I SELECT bit_count(arrow_cast(-1, 'Int8')); ---- -8 +64 query I SELECT bit_count(arrow_cast(0, 'Int16')); @@ -100,7 +100,7 @@ SELECT bit_count(arrow_cast(255, 'Int16')); query I SELECT bit_count(arrow_cast(-1, 'Int16')); ---- -16 +64 query I SELECT bit_count(arrow_cast(0, 'Int64')); @@ -214,7 +214,7 @@ 
SELECT bit_count(arrow_cast(2147483647, 'Int32')); query I SELECT bit_count(arrow_cast(-2147483648, 'Int32')); ---- -1 +33 query I SELECT bit_count(arrow_cast(9223372036854775807, 'Int64'));