From 24866f4b2f906279d0703bc13f63e6278d80628a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Josu=C3=A9=20Sehnem?= Date: Sat, 7 Dec 2024 13:27:26 -0300 Subject: [PATCH] fix tests --- tests/test_python_frontend/test_lw_solver.py | 13 +++++++------ tests/test_python_frontend/test_sw_solver.py | 15 ++++++++------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/test_python_frontend/test_lw_solver.py b/tests/test_python_frontend/test_lw_solver.py index 645372c..f9fce52 100644 --- a/tests/test_python_frontend/test_lw_solver.py +++ b/tests/test_python_frontend/test_lw_solver.py @@ -6,7 +6,7 @@ from pyrte_rrtmgp import rrtmgp_gas_optics from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data -from pyrte_rrtmgp.rte_solver import rte_solve +from pyrte_rrtmgp.rte_solver import RTESolver ERROR_TOLERANCE = 1e-7 @@ -26,13 +26,13 @@ os.path.join(ref_dir, "rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc"), decode_cf=False, ) -ref_flux_up = rlu.isel(expt=0)["rlu"].values +ref_flux_up = rlu.isel(expt=0)["rlu"] rld = xr.load_dataset( os.path.join(ref_dir, "rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc"), decode_cf=False, ) -ref_flux_down = rld.isel(expt=0)["rld"].values +ref_flux_down = rld.isel(expt=0)["rld"] def test_lw_solver_noscat(): @@ -43,8 +43,9 @@ def test_lw_solver_noscat(): gas_optics_lw.gas_optics.compute(atmosphere, problem_type="absorption") # Solve RTE with the new API - fluxes = rte_solve(atmosphere, add_to_input=False) + solver = RTESolver() + fluxes = solver.solve(atmosphere, add_to_input=False) # Compare results with reference data - assert np.isclose(fluxes["lw_flux_up"].values, ref_flux_up, atol=ERROR_TOLERANCE).all() - assert np.isclose(fluxes["lw_flux_down"].values, ref_flux_down, atol=ERROR_TOLERANCE).all() + assert np.isclose(fluxes["lw_flux_up_broadband"], ref_flux_up, atol=ERROR_TOLERANCE).all() + assert np.isclose(fluxes["lw_flux_down_broadband"], ref_flux_down, atol=ERROR_TOLERANCE).all() diff --git a/tests/test_python_frontend/test_sw_solver.py b/tests/test_python_frontend/test_sw_solver.py index fe321a3..8d2fbb4 100644 --- a/tests/test_python_frontend/test_sw_solver.py +++ b/tests/test_python_frontend/test_sw_solver.py @@ -6,7 +6,7 @@ from pyrte_rrtmgp import rrtmgp_gas_optics from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data -from pyrte_rrtmgp.rte_solver import rte_solve +from pyrte_rrtmgp.rte_solver import RTESolver ERROR_TOLERANCE = 1e-7 @@ -20,19 +20,19 @@ input_dir, "multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc" ) ) -atmosphere = atmosphere.sel(expt=0) # only one experiment +atmosphere = atmosphere.sel(expt=0) rsu = xr.load_dataset( os.path.join(ref_dir, "rsu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc"), decode_cf=False, ) -ref_flux_up = rsu.isel(expt=0)["rsu"].values +ref_flux_up = rsu.isel(expt=0)["rsu"] rsd = xr.load_dataset( os.path.join(ref_dir, "rsd_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc"), decode_cf=False, ) -ref_flux_down = rsd.isel(expt=0)["rsd"].values +ref_flux_down = rsd.isel(expt=0)["rsd"] def test_sw_solver_noscat(): @@ -43,8 +43,9 @@ def test_sw_solver_noscat(): gas_optics_sw.gas_optics.compute(atmosphere, problem_type="two-stream") # Solve using new rte_solve function - fluxes = rte_solve(atmosphere, add_to_input=False) + solver = RTESolver() + fluxes = solver.solve(atmosphere, add_to_input=False) # Compare results - assert np.isclose(fluxes["sw_flux_up"].values, ref_flux_up, atol=ERROR_TOLERANCE).all() - assert np.isclose(fluxes["sw_flux_down"].values, ref_flux_down, atol=ERROR_TOLERANCE).all() + assert np.isclose(fluxes["sw_flux_up"], ref_flux_up, atol=ERROR_TOLERANCE).all() + assert np.isclose(fluxes["sw_flux_down"], ref_flux_down, atol=ERROR_TOLERANCE).all()