Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions datafusion/spark/src/function/bitwise/bit_count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,28 +90,28 @@ impl ScalarUDFImpl for SparkBitCount {
fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
let value_array = value_array[0].as_ref();
match value_array.data_type() {
DataType::Int8 => {
DataType::Int8 | DataType::Boolean => {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spark supports only signed int types

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment seems misplaced as the code adds support for boolean 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think this code will now panic if you pass in a Boolean array

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to note that Spark only supports signed integer and boolean types as input.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an additional test for BooleanArray.

let result: Int32Array = value_array
.as_primitive::<Int8Type>()
.unary(|v| v.count_ones() as i32);
.unary(|v| bit_count(v.into()));
Ok(Arc::new(result))
}
DataType::Int16 => {
let result: Int32Array = value_array
.as_primitive::<Int16Type>()
.unary(|v| v.count_ones() as i32);
.unary(|v| bit_count(v.into()));
Ok(Arc::new(result))
}
DataType::Int32 => {
let result: Int32Array = value_array
.as_primitive::<Int32Type>()
.unary(|v| v.count_ones() as i32);
.unary(|v| bit_count(v.into()));
Ok(Arc::new(result))
}
DataType::Int64 => {
let result: Int32Array = value_array
.as_primitive::<Int64Type>()
.unary(|v| v.count_ones() as i32);
.unary(|v| bit_count(v.into()));
Ok(Arc::new(result))
}
DataType::UInt8 => {
Expand Down Expand Up @@ -147,6 +147,18 @@ fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// Returns the number of set bits in the 64-bit two's-complement
/// representation of `i`.
///
/// Mirrors the semantics of Apache Spark's `bitCount` for `LongType`: the
/// count is always taken over all 64 bits. Callers widen narrower signed
/// integers to `i64` before calling, so negative inputs are sign-extended
/// (for example `-1` yields 64 regardless of the original width).
fn bit_count(i: i64) -> i32 {
    // `count_ones` is at most 64 for an `i64`, so the cast to `i32` is lossless.
    i.count_ones() as i32
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -192,7 +204,7 @@ mod tests {
assert_eq!(arr.value(2), 2);
assert_eq!(arr.value(3), 3);
assert_eq!(arr.value(4), 4);
assert_eq!(arr.value(5), 8);
assert_eq!(arr.value(5), 64);
}

#[test]
Expand All @@ -207,7 +219,7 @@ mod tests {
assert_eq!(arr.value(1), 1);
assert_eq!(arr.value(2), 8);
assert_eq!(arr.value(3), 10);
assert_eq!(arr.value(4), 16);
assert_eq!(arr.value(4), 64);
}

#[test]
Expand All @@ -222,7 +234,7 @@ mod tests {
assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1
assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8
assert_eq!(arr.value(3), 10); // 0b00000000000000000000001111111111 = 10
assert_eq!(arr.value(4), 32); // -1 in two's complement = all 32 bits set
assert_eq!(arr.value(4), 64); // -1 sign-extended to i64 = all 64 bits set
}

#[test]
Expand Down
12 changes: 6 additions & 6 deletions datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,17 @@ SELECT bit_count(1023::int);
query I
SELECT bit_count(-1::int);
----
32
64

query I
SELECT bit_count(-2::int);
----
31
63

query I
SELECT bit_count(-3::int);
----
31
63

# Tests with different integer types
query I
Expand All @@ -85,7 +85,7 @@ SELECT bit_count(arrow_cast(15, 'Int8'));
query I
SELECT bit_count(arrow_cast(-1, 'Int8'));
----
8
64

query I
SELECT bit_count(arrow_cast(0, 'Int16'));
Expand All @@ -100,7 +100,7 @@ SELECT bit_count(arrow_cast(255, 'Int16'));
query I
SELECT bit_count(arrow_cast(-1, 'Int16'));
----
16
64

query I
SELECT bit_count(arrow_cast(0, 'Int64'));
Expand Down Expand Up @@ -214,7 +214,7 @@ SELECT bit_count(arrow_cast(2147483647, 'Int32'));
query I
SELECT bit_count(arrow_cast(-2147483648, 'Int32'));
----
1
33

query I
SELECT bit_count(arrow_cast(9223372036854775807, 'Int64'));
Expand Down
Loading