From 8e65c6f0f399368949af0069874a7a6a7988912d Mon Sep 17 00:00:00 2001 From: Nick Wogan Date: Thu, 28 Mar 2024 14:44:47 -0700 Subject: [PATCH] 64 bit --- src/forwarddiff_const.f90 | 2 +- test/test_jax.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/forwarddiff_const.f90 b/src/forwarddiff_const.f90 index ad8104e..1b196f5 100644 --- a/src/forwarddiff_const.f90 +++ b/src/forwarddiff_const.f90 @@ -1,3 +1,3 @@ module forwarddiff_const - use iso_fortran_env, only: wp => real32 + use iso_fortran_env, only: wp => real64 end module \ No newline at end of file diff --git a/test/test_jax.py b/test/test_jax.py index 106cba5..f661324 100644 --- a/test/test_jax.py +++ b/test/test_jax.py @@ -47,28 +47,28 @@ def test(): x = np.array(10.0,dtype=np.float32) f = func_operators(x) dfdx = jax.grad(func_operators)(x) - f1, dfdx1 = fil.read_record(np.float32) + f1, dfdx1 = fil.read_record(np.float64) print(f/f1,dfdx/dfdx1) assert np.isclose(f,f1) and np.isclose(dfdx,dfdx1) x = np.array(10.0,dtype=np.float32) f = func_intrinsics1(x) dfdx = jax.grad(func_intrinsics1)(x) - f1, dfdx1 = fil.read_record(np.float32) + f1, dfdx1 = fil.read_record(np.float64) print(f/f1,dfdx/dfdx1) assert np.isclose(f,f1) and np.isclose(dfdx,dfdx1) x = np.array(0.1,dtype=np.float32) f = func_intrinsics2(x) dfdx = jax.grad(func_intrinsics2)(x) - f1, dfdx1 = fil.read_record(np.float32) + f1, dfdx1 = fil.read_record(np.float64) print(f/f1,dfdx/dfdx1) assert np.isclose(f,f1) and np.isclose(dfdx,dfdx1) x = np.array([1, 2],dtype=np.float32) f = func_grad1(x) dfdx = jax.grad(func_grad1)(x) - tmp = fil.read_record(np.float32) + tmp = fil.read_record(np.float64) f1, dfdx1 = tmp[0],tmp[1:] print(f/f1,dfdx/dfdx1) assert np.isclose(f,f1) and np.all(np.isclose(dfdx,dfdx1)) @@ -76,7 +76,7 @@ def test(): x = np.array([3, 4],dtype=np.float32) f = func_grad2(x) dfdx = jax.grad(func_grad2)(x) - tmp = fil.read_record(np.float32) + tmp = fil.read_record(np.float64) f1, dfdx1 = tmp[0],tmp[1:] print(f/f1,dfdx/dfdx1) assert np.isclose(f,f1) and np.all(np.isclose(dfdx,dfdx1))