From 775b72ebaf45de1aa9dc34b0f4a2ceefef1c5592 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 30 Dec 2024 09:48:29 +0200 Subject: [PATCH] TST: make torch default dtype configurable --- .github/workflows/array_api.yml | 7 ++++++- scipy/conftest.py | 8 ++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/.github/workflows/array_api.yml b/.github/workflows/array_api.yml index 4a562b2bd876..0ac4b04a1ca0 100644 --- a/.github/workflows/array_api.yml +++ b/.github/workflows/array_api.yml @@ -108,9 +108,14 @@ jobs: run: | python dev.py build - - name: Test SciPy + run_tests: + name: Test SciPy + strategy: + matrix: + default_dtype: ['float32', 'float64'] run: | export OMP_NUM_THREADS=2 + export SCIPY_DEFAULT_DTYPE={{ matrix.default_dtype }} # expand as more modules are supported by adding to `XP_TESTS` above python dev.py --no-build test -b all $XP_TESTS -- --durations 3 --timeout=60 diff --git a/scipy/conftest.py b/scipy/conftest.py index ca2cf0b3187e..889394be9f16 100644 --- a/scipy/conftest.py +++ b/scipy/conftest.py @@ -161,6 +161,14 @@ def num_parallel_threads(): xp_available_backends.update({'torch': torch}) # can use `mps` or `cpu` torch.set_default_device(SCIPY_DEVICE) + + # default dtype: XXX flip the default to float64 + default = os.getenv('SCIPY_DEFAULT_DTYPE', default='float32') + if default == 'float64': + torch.set_default_dtype(torch.float64) + + print(f"default dtype set for {torch.get_default_dtype()}") + except ImportError: pass