Skip to content

Commit

Permalink
Test support for bfloat16 using ml_dtypes.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Jun 22, 2023
1 parent 2abafac commit 08510a3
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
shell: python
- name: Test
run: |
pip install numpy
pip install numpy ml_dtypes
cargo test --all-features
# Not on PyPy, because no embedding API
if: ${{ !startsWith(matrix.python-version, 'pypy') }}
Expand Down Expand Up @@ -101,7 +101,7 @@ jobs:
continue-on-error: true
- uses: taiki-e/install-action@valgrind
- run: |
pip install numpy
pip install numpy ml_dtypes
cargo test --all-features --release
env:
CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: valgrind --leak-check=no --error-exitcode=1
Expand All @@ -115,7 +115,7 @@ jobs:
- uses: Swatinem/rust-cache@v2
continue-on-error: true
- run: |
pip install numpy
pip install numpy ml_dtypes
cargo install --locked cargo-careful
cargo careful test --all-features
Expand Down Expand Up @@ -201,7 +201,7 @@ jobs:
python-version: 3.9
architecture: x64
- name: Install numpy
run: pip install numpy
run: pip install numpy ml_dtypes
- uses: Swatinem/rust-cache@v2
continue-on-error: true
- uses: dtolnay/rust-toolchain@stable
Expand Down
47 changes: 44 additions & 3 deletions tests/array.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::mem::size_of;

#[cfg(feature = "half")]
use half::f16;
use half::{bf16, f16};
use ndarray::{array, s, Array1, Dim};
use numpy::{
dtype, get_array_module, npyffi::NPY_ORDER, pyarray, PyArray, PyArray1, PyArray2, PyArrayDescr,
Expand Down Expand Up @@ -527,7 +527,7 @@ fn reshape() {

#[cfg(feature = "half")]
#[test]
fn half_works() {
fn half_f16_works() {
Python::with_gil(|py| {
let np = py.eval("__import__('numpy')", None, None).unwrap();
let locals = [("np", np)].into_py_dict(py);
Expand Down Expand Up @@ -558,7 +558,48 @@ fn half_works() {
py_run!(
py,
array np,
"np.testing.assert_array_almost_equal(array, np.array([[2, 4], [6, 8]], dtype='float16'))"
"assert np.all(array == np.array([[2, 4], [6, 8]], dtype='float16'))"
);
});
}

#[cfg(feature = "half")]
#[test]
fn half_bf16_works() {
Python::with_gil(|py| {
let np = py.eval("__import__('numpy')", None, None).unwrap();
// NumPy itself does not provide a `bfloat16` dtype itself,
// so we import ml_dtypes which does register such a dtype.
let mldt = py.eval("__import__('ml_dtypes')", None, None).unwrap();
let locals = [("np", np), ("mldt", mldt)].into_py_dict(py);

let array = py
.eval(
"np.array([[1, 2], [3, 4]], dtype='bfloat16')",
None,
Some(locals),
)
.unwrap()
.downcast::<PyArray2<bf16>>()
.unwrap();

assert_eq!(
array.readonly().as_array(),
array![
[bf16::from_f32(1.0), bf16::from_f32(2.0)],
[bf16::from_f32(3.0), bf16::from_f32(4.0)]
]
);

array
.readwrite()
.as_array_mut()
.map_inplace(|value| *value *= bf16::from_f32(2.0));

py_run!(
py,
array np,
"assert np.all(array == np.array([[2, 4], [6, 8]], dtype='bfloat16'))"
);
});
}
Expand Down

0 comments on commit 08510a3

Please sign in to comment.