Skip to content

Commit

Permalink
Adding more arithmetic dyn kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
psvri committed Nov 19, 2023
1 parent bf30060 commit 1cdc209
Show file tree
Hide file tree
Showing 8 changed files with 761 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/deploy-pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Set up mdbook
uses: taiki-e/install-action@v2
with:
tool: mdbook
tool: mdbook,mdbook-pagetoc
- name: Deploy GitHub Pages
run: |
cd docs
Expand Down
2 changes: 2 additions & 0 deletions crates/arithmetic/src/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ mod tests {
Float32ArrayGPU,
Float32ArrayGPU,
sub,
sub_dyn,
vec![Some(0.0), Some(1.0), None, None, Some(4.0), Some(10.0)],
vec![Some(1.0), Some(2.0), None, Some(4.0), None, Some(0.0)],
vec![Some(-1.0), Some(-1.0), None, None, None, Some(10.0)]
Expand All @@ -234,6 +235,7 @@ mod tests {
Float32ArrayGPU,
Float32ArrayGPU,
div,
div_dyn,
vec![Some(0.0), Some(1.0), None, None, Some(4.0)],
vec![Some(1.0), Some(2.0), None, Some(4.0), None],
vec![Some(0.0), Some(0.5), None, None, None]
Expand Down
36 changes: 36 additions & 0 deletions crates/arithmetic/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,15 @@ pub async fn add_array_dyn(input1: &ArrowArrayGPU, input2: &ArrowArrayGPU) -> Ar
}
}

pub async fn sub_array_dyn(input1: &ArrowArrayGPU, input2: &ArrowArrayGPU) -> ArrowArrayGPU {
match (input1, input2) {
(ArrowArrayGPU::Float32ArrayGPU(arr1), ArrowArrayGPU::Float32ArrayGPU(arr2)) => {
arr1.sub(arr2).await.into()
}
_ => panic!("Operation not supported"),
}
}

pub async fn mul_array_dyn(input1: &ArrowArrayGPU, input2: &ArrowArrayGPU) -> ArrowArrayGPU {
match (input1, input2) {
(ArrowArrayGPU::Float32ArrayGPU(arr1), ArrowArrayGPU::Float32ArrayGPU(arr2)) => {
Expand All @@ -190,6 +199,15 @@ pub async fn mul_array_dyn(input1: &ArrowArrayGPU, input2: &ArrowArrayGPU) -> Ar
}
}

pub async fn div_array_dyn(input1: &ArrowArrayGPU, input2: &ArrowArrayGPU) -> ArrowArrayGPU {
match (input1, input2) {
(ArrowArrayGPU::Float32ArrayGPU(arr1), ArrowArrayGPU::Float32ArrayGPU(arr2)) => {
arr1.div(arr2).await.into()
}
_ => panic!("Operation not supported"),
}
}

pub async fn add_dyn(input1: &ArrowArrayGPU, input2: &ArrowArrayGPU) -> ArrowArrayGPU {
match (input1.len(), input2.len()) {
(x, y) if (x == 1 && y == 1) || (x != 1 && y != 1) => add_array_dyn(input1, input2).await,
Expand All @@ -199,6 +217,15 @@ pub async fn add_dyn(input1: &ArrowArrayGPU, input2: &ArrowArrayGPU) -> ArrowArr
}
}

pub async fn sub_dyn(input1: &ArrowArrayGPU, input2: &ArrowArrayGPU) -> ArrowArrayGPU {
match (input1.len(), input2.len()) {
(x, y) if (x == 1 && y == 1) || (x != 1 && y != 1) => sub_array_dyn(input1, input2).await,
(_, 1) => sub_scalar_dyn(input1, input2).await,
(1, _) => sub_scalar_dyn(input2, input1).await,
_ => unreachable!(),
}
}

pub async fn mul_dyn(input1: &ArrowArrayGPU, input2: &ArrowArrayGPU) -> ArrowArrayGPU {
match (input1.len(), input2.len()) {
(x, y) if (x == 1 && y == 1) || (x != 1 && y != 1) => mul_array_dyn(input1, input2).await,
Expand All @@ -207,3 +234,12 @@ pub async fn mul_dyn(input1: &ArrowArrayGPU, input2: &ArrowArrayGPU) -> ArrowArr
_ => unreachable!(),
}
}

pub async fn div_dyn(input1: &ArrowArrayGPU, input2: &ArrowArrayGPU) -> ArrowArrayGPU {
match (input1.len(), input2.len()) {
(x, y) if (x == 1 && y == 1) || (x != 1 && y != 1) => div_array_dyn(input1, input2).await,
(_, 1) => div_scalar_dyn(input1, input2).await,
(1, _) => div_scalar_dyn(input2, input1).await,
_ => unreachable!(),
}
}
6 changes: 6 additions & 0 deletions docs/book.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ language = "en"
multilingual = false
src = "src"
title = "Arrow GPU Documentation"


[preprocessor.pagetoc]
[output.html]
additional-css = ["theme/pagetoc.css"]
additional-js = ["theme/pagetoc.js"]
266 changes: 266 additions & 0 deletions docs/src/kernels/arithmetic.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,272 @@

### Floats

| | Float32 |
|-|-|
| Int8 | |
| Int16 | |
| Int32 | |
| UInt8 | |
| UInt16 | |
| UInt32 | |
| Float32 ||

## Sub Scalar

### Signed Integers

| | Int8 | Int16 | Int32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | ||
| UInt8 | | |
| UInt16 | | |
| UInt32 | | |
| Float32 | | |

### Unsigned Integers

| | UInt8 | UInt16 | UInt32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | | |
| UInt8 | | |
| UInt16 | | |
| UInt32 | ||
| Float32 | | |

### Floats

| | Float32 |
|-|-|
| Int8 | |
| Int16 | |
| Int32 | |
| UInt8 | |
| UInt16 | |
| UInt32 | |
| Float32 ||

## Multiply Scalar

### Signed Integers

| | Int8 | Int16 | Int32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | ||
| UInt8 | | |
| UInt16 | | |
| UInt32 | | |
| Float32 | | |

### Unsigned Integers

| | UInt8 | UInt16 | UInt32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | | |
| UInt8 | | |
| UInt16 | | |
| UInt32 | ||
| Float32 | | |

### Floats

| | Float32 |
|-|-|
| Int8 | |
| Int16 | |
| Int32 | |
| UInt8 | |
| UInt16 | |
| UInt32 | |
| Float32 ||

## Div Scalar

### Signed Integers

| | Int8 | Int16 | Int32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | ||
| UInt8 | | |
| UInt16 | | |
| UInt32 | | |
| Float32 | | |

### Unsigned Integers

| | UInt8 | UInt16 | UInt32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | | |
| UInt8 | | |
| UInt16 | | |
| UInt32 | ||
| Float32 | | |

### Floats

| | Float32 |
|-|-|
| Int8 | |
| Int16 | |
| Int32 | |
| UInt8 | |
| UInt16 | |
| UInt32 | |
| Float32 ||

## Add Array

### Signed Integers

| | Int8 | Int16 | Int32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | | |
| UInt8 | | |
| UInt16 | | |
| UInt32 | | |
| Float32 | | |

### Unsigned Integers

| | UInt8 | UInt16 | UInt32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | | |
| UInt8 | | |
| UInt16 | | |
| UInt32 | | |
| Float32 | | |

### Floats

| | Float32 |
|-|-|
| Int8 | |
| Int16 | |
| Int32 | |
| UInt8 | |
| UInt16 | |
| UInt32 | |
| Float32 ||

## Sub Array

### Signed Integers

| | Int8 | Int16 | Int32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | | |
| UInt8 | | |
| UInt16 | | |
| UInt32 | | |
| Float32 | | |

### Unsigned Integers

| | UInt8 | UInt16 | UInt32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | | |
| UInt8 | | |
| UInt16 | | |
| UInt32 | | |
| Float32 | | |

### Floats

| | Float32 |
|-|-|
| Int8 | |
| Int16 | |
| Int32 | |
| UInt8 | |
| UInt16 | |
| UInt32 | |
| Float32 ||

## Multiply Array

### Signed Integers

| | Int8 | Int16 | Int32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | | |
| UInt8 | | |
| UInt16 | | |
| UInt32 | | |
| Float32 | | |

### Unsigned Integers

| | UInt8 | UInt16 | UInt32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | | |
| UInt8 | | |
| UInt16 | | |
| UInt32 | | |
| Float32 | | |

### Floats

| | Float32 |
|-|-|
| Int8 | |
| Int16 | |
| Int32 | |
| UInt8 | |
| UInt16 | |
| UInt32 | |
| Float32 ||

## Div Array

### Signed Integers

| | Int8 | Int16 | Int32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | | |
| UInt8 | | |
| UInt16 | | |
| UInt32 | | |
| Float32 | | |

### Unsigned Integers

| | UInt8 | UInt16 | UInt32 |
|-|-|-|-|
| Int8 | | | |
| Int16 | | | |
| Int32 | | | |
| UInt8 | | |
| UInt16 | | |
| UInt32 | | |
| Float32 | | |

### Floats

| | Float32 |
|-|-|
| Int8 | |
Expand Down
Loading

0 comments on commit 1cdc209

Please sign in to comment.