Skip to content

Commit 129f23b

Browse files
committed
switched nufft_cached tests to use torch.double test images.
1 parent f96c637 commit 129f23b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/fourier_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_predict_vis_nufft_cached(coords, baselines_1D):
165165
layer = fourier.NuFFTCached(coords=coords, nchan=nchan, uu=uu, vv=vv)
166166

167167
# predict the values of the cube at the u,v locations
168-
blank_packed_img = torch.zeros((nchan, coords.npix, coords.npix))
168+
blank_packed_img = torch.zeros((nchan, coords.npix, coords.npix), dtype=torch.double)
169169
output = layer(blank_packed_img)
170170

171171
# make sure we got back the number of visibilities we expected
@@ -287,7 +287,7 @@ def test_nufft_cached_accuracy_single_chan(coords, baselines_1D, tmp_path):
287287
img_packed = utils.sky_gaussian_arcsec(
288288
coords.packed_x_centers_2D, coords.packed_y_centers_2D, **gauss_kw
289289
)
290-
img_packed_tensor = torch.tensor(img_packed[np.newaxis, :, :], requires_grad=True, dtype=torch.float32)
290+
img_packed_tensor = torch.tensor(img_packed[np.newaxis, :, :], requires_grad=True, dtype=torch.double)
291291

292292
# use the NuFFT to predict the values of the cube at the u,v locations
293293
num_output = layer(img_packed_tensor)[0] # take the channel dim out
@@ -324,7 +324,7 @@ def test_nufft_cached_accuracy_coil_broadcast(coords, baselines_1D, tmp_path):
324324
# broadcast to 5 channels -- the image will be the same for each
325325
img_packed_tensor = torch.tensor(
326326
img_packed[np.newaxis, :, :] * np.ones((nchan, coords.npix, coords.npix)),
327-
requires_grad=True, dtype=torch.float32
327+
requires_grad=True, dtype=torch.double
328328
)
329329

330330
# use the NuFFT to predict the values of the cube at the u,v locations

0 commit comments

Comments
 (0)