Skip to content

Commit

Permalink
TST: make torch default dtype configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Dec 30, 2024
1 parent 31cbb91 commit 775b72e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
7 changes: 6 additions & 1 deletion .github/workflows/array_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,14 @@ jobs:
run: |
python dev.py build
- name: Test SciPy
run_tests:
name: Test SciPy
strategy:
matrix:
default_dtype: ['float32', 'float64']
run: |
export OMP_NUM_THREADS=2
export SCIPY_DEFAULT_DTYPE={{ matrix.default_dtype }}
# expand as more modules are supported by adding to `XP_TESTS` above
python dev.py --no-build test -b all $XP_TESTS -- --durations 3 --timeout=60
8 changes: 8 additions & 0 deletions scipy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ def num_parallel_threads():
xp_available_backends.update({'torch': torch})
# can use `mps` or `cpu`
torch.set_default_device(SCIPY_DEVICE)

# default dtype: XXX flip the default to float64
default = os.getenv('SCIPY_DEFAULT_DTYPE', default='float32')
if default == 'float64':
torch.set_default_dtype(torch.float64)

print(f"default dtype set for {torch.get_default_dtype()}")

except ImportError:
pass

Expand Down

0 comments on commit 775b72e

Please sign in to comment.