From 04b3bfdafb98183c88914a6963bde5da76aa1d30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Josu=C3=A9=20M=2E=20Sehnem?= Date: Sun, 10 Nov 2024 12:21:56 -0300 Subject: [PATCH] encapsulating variables for frontend --- examples/lw_example.ipynb | 17 +- examples/sw_example.ipynb | 35 +- pyrte_rrtmgp/constants.py | 4 + pyrte_rrtmgp/exceptions.py | 8 - pyrte_rrtmgp/kernels/rte.py | 9 +- pyrte_rrtmgp/rrtmgp_data.py | 1 - pyrte_rrtmgp/rrtmgp_gas_optics.py | 447 +++++++++++------- pyrte_rrtmgp/rte_problems.py | 174 +++++++ tests/test_python_frontend/test_gas_optics.py | 29 +- tests/test_python_frontend/test_lw_solver.py | 19 +- tests/test_python_frontend/test_sw_solver.py | 37 +- 11 files changed, 486 insertions(+), 294 deletions(-) create mode 100644 pyrte_rrtmgp/rte_problems.py diff --git a/examples/lw_example.ipynb b/examples/lw_example.ipynb index 85163e2..8cad968 100644 --- a/examples/lw_example.ipynb +++ b/examples/lw_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -10,7 +10,6 @@ "import xarray as xr\n", "\n", "from pyrte_rrtmgp import rrtmgp_gas_optics\n", - "from pyrte_rrtmgp.kernels.rte import lw_solver_noscat\n", "\n", "\n", "ERROR_TOLERANCE = 1e-7\n", @@ -24,21 +23,15 @@ "rfmip = rfmip.sel(expt=0) # only one experiment\n", "\n", "kdist = xr.load_dataset(f\"{rte_rrtmgp_dir}/rrtmgp-gas-lw-g256.nc\")\n", - "rrtmgp_gas_optics = kdist.gas_optics.load_atmospheric_conditions(rfmip)\n", + "lw_problem = kdist.gas_optics.load_atmospheric_conditions(rfmip)\n", "\n", - "_, solver_flux_up, solver_flux_down, _, _ = lw_solver_noscat(\n", - " tau=rrtmgp_gas_optics.tau,\n", - " lay_source=rrtmgp_gas_optics.lay_src,\n", - " lev_source=rrtmgp_gas_optics.lev_src,\n", - " sfc_emis=rfmip[\"surface_emissivity\"].data,\n", - " sfc_src=rrtmgp_gas_optics.sfc_src,\n", - " sfc_src_jac=rrtmgp_gas_optics.sfc_src_jac,\n", - ")\n", + "lw_problem.sfc_emis = rfmip[\"surface_emissivity\"].data\n", + "\n", + "solver_flux_up, solver_flux_down = lw_problem.rte_solve()\n", "\n", "rlu_reference = f\"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/reference/rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", "rld_reference = f\"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/reference/rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", "\n", - "\n", "rlu = xr.load_dataset(rlu_reference, decode_cf=False)\n", "ref_flux_up = rlu.isel(expt=0)[\"rlu\"].values\n", "\n", diff --git a/examples/sw_example.ipynb b/examples/sw_example.ipynb index b1586ae..601e940 100644 --- a/examples/sw_example.ipynb +++ b/examples/sw_example.ipynb @@ -10,8 +10,6 @@ "import xarray as xr\n", "\n", "from pyrte_rrtmgp import rrtmgp_gas_optics\n", - "from pyrte_rrtmgp.kernels.rte import sw_solver_2stream\n", - "from pyrte_rrtmgp.utils import compute_mu0, get_usecols, compute_toa_flux\n", "\n", "ERROR_TOLERANCE = 1e-7\n", "\n", @@ -24,36 +22,13 @@ "rfmip = rfmip.sel(expt=0) # only one experiment\n", "\n", "kdist = xr.load_dataset(f\"{rte_rrtmgp_dir}/rrtmgp-gas-sw-g224.nc\")\n", - "gas_optics = kdist.gas_optics.load_atmospheric_conditions(rfmip)\n", + "sw_problem = kdist.gas_optics.load_atmospheric_conditions(rfmip)\n", "\n", - "surface_albedo = rfmip[\"surface_albedo\"].data\n", - "total_solar_irradiance = rfmip[\"total_solar_irradiance\"].data\n", - "\n", - "nlayer = len(rfmip[\"layer\"])\n", - "mu0 = compute_mu0(rfmip[\"solar_zenith_angle\"].values, nlayer=nlayer)\n", - "toa_flux = compute_toa_flux(total_solar_irradiance, gas_optics.solar_source)\n", - "\n", - "_, _, _, solver_flux_up, solver_flux_down, _ = sw_solver_2stream(\n", - " kdist.gas_optics.top_at_1,\n", - " gas_optics.tau,\n", - " gas_optics.ssa,\n", - " gas_optics.g,\n", - " mu0,\n", - " sfc_alb_dir=surface_albedo,\n", - " sfc_alb_dif=surface_albedo,\n", - " inc_flux_dir=toa_flux,\n", - " inc_flux_dif=None,\n", - " has_dif_bc=False,\n", - " do_broadband=True,\n", - ")\n", - "\n", - "# RTE will fail if passed solar zenith angles greater than 90 degree. We replace any with\n", - "# nighttime columns with a default solar zenith angle. We'll mask these out later, of\n", - "# course, but this gives us more work and so a better measure of timing.\n", - "usecol = get_usecols(rfmip[\"solar_zenith_angle\"].values)\n", - "solver_flux_up = solver_flux_up * usecol[:, np.newaxis]\n", - "solver_flux_down = solver_flux_down * usecol[:, np.newaxis]\n", + "sw_problem.sfc_alb_dir = rfmip[\"surface_albedo\"].data\n", + "sw_problem.total_solar_irradiance = rfmip[\"total_solar_irradiance\"].data\n", + "sw_problem.solar_zenith_angle = rfmip[\"solar_zenith_angle\"].values\n", "\n", + "solver_flux_up, solver_flux_down = sw_problem.solve()\n", "\n", "rsu_reference = f\"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/reference/rsu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", "rsd_reference = f\"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/reference/rsd_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc\"\n", diff --git a/pyrte_rrtmgp/constants.py b/pyrte_rrtmgp/constants.py index d0e59a5..b062cf1 100644 --- a/pyrte_rrtmgp/constants.py +++ b/pyrte_rrtmgp/constants.py @@ -3,3 +3,7 @@ M_DRY = 0.028964 M_H2O = 0.018016 AVOGAD = 6.02214076e23 +SOLAR_CONSTANTS = { + 'A_OFFSET': 0.1495954, + 'B_OFFSET': 0.00066696, +} diff --git a/pyrte_rrtmgp/exceptions.py b/pyrte_rrtmgp/exceptions.py index 0cf8d26..fc8cd75 100644 --- a/pyrte_rrtmgp/exceptions.py +++ b/pyrte_rrtmgp/exceptions.py @@ -1,11 +1,3 @@ -class NotInternalSourceError(ValueError): - pass - - -class NotExternalSourceError(ValueError): - pass - - class MissingAtmosphericConditionsError(AttributeError): message = ( "You need to load the atmospheric conditions first." diff --git a/pyrte_rrtmgp/kernels/rte.py b/pyrte_rrtmgp/kernels/rte.py index 27ab959..548f103 100644 --- a/pyrte_rrtmgp/kernels/rte.py +++ b/pyrte_rrtmgp/kernels/rte.py @@ -174,7 +174,6 @@ def sw_solver_noscat( def sw_solver_2stream( - top_at_1, tau, ssa, g, @@ -182,15 +181,15 @@ def sw_solver_2stream( sfc_alb_dir, sfc_alb_dif, inc_flux_dir, + top_at_1=True, inc_flux_dif=None, has_dif_bc=False, - do_broadband=False, + do_broadband=True, ): """ Solve the shortwave radiative transfer equation using the 2-stream approximation. Args: - top_at_1 (bool): Flag indicating whether the top of the atmosphere is at level 1. tau (ndarray): Array of optical depths with shape (ncol, nlay, ngpt). ssa (ndarray): Array of single scattering albedos with shape (ncol, nlay, ngpt). g (ndarray): Array of asymmetry parameters with shape (ncol, nlay, ngpt). @@ -198,12 +197,14 @@ def sw_solver_2stream( sfc_alb_dir (ndarray): Array of direct surface albedos with shape (ncol, ngpt). sfc_alb_dif (ndarray): Array of diffuse surface albedos with shape (ncol, ngpt). inc_flux_dir (ndarray): Array of direct incident fluxes with shape (ncol, ngpt). + top_at_1 (bool): Flag indicating whether the top of the atmosphere is at level 1. + Defaults to True. inc_flux_dif (ndarray, optional): Array of diffuse incident fluxes with shape (ncol, ngpt). Defaults to None. has_dif_bc (bool, optional): Flag indicating whether the boundary condition includes diffuse fluxes. Defaults to False. do_broadband (bool, optional): Flag indicating whether to compute broadband fluxes. - Defaults to False. + Defaults to True. Returns: Tuple of ndarrays: Tuple containing the following arrays: diff --git a/pyrte_rrtmgp/rrtmgp_data.py b/pyrte_rrtmgp/rrtmgp_data.py index ed14d2e..7ccf0a3 100644 --- a/pyrte_rrtmgp/rrtmgp_data.py +++ b/pyrte_rrtmgp/rrtmgp_data.py @@ -37,7 +37,6 @@ def download_rrtmgp_data(): # Path to the file containing the checksum of the downloaded file checksum_file_path = os.path.join(cache_dir, f"{TAG}.tar.gz.sha256") - # Download the file if it doesn't exist or if the checksum doesn't match # Download the file if it doesn't exist or if the checksum doesn't match if not os.path.exists(file_path) or ( os.path.exists(checksum_file_path) diff --git a/pyrte_rrtmgp/rrtmgp_gas_optics.py b/pyrte_rrtmgp/rrtmgp_gas_optics.py index b87fa76..3f25aaf 100644 --- a/pyrte_rrtmgp/rrtmgp_gas_optics.py +++ b/pyrte_rrtmgp/rrtmgp_gas_optics.py @@ -3,14 +3,15 @@ from typing import Optional import numpy as np +import numpy.typing as npt + + import numpy.typing as npt import xarray as xr -from pyrte_rrtmgp.constants import AVOGAD, HELMERT1, HELMERT2, M_DRY, M_H2O +from pyrte_rrtmgp.constants import AVOGAD, HELMERT1, HELMERT2, M_DRY, M_H2O, SOLAR_CONSTANTS from pyrte_rrtmgp.exceptions import ( MissingAtmosphericConditionsError, - NotExternalSourceError, - NotInternalSourceError, ) from pyrte_rrtmgp.kernels.rrtmgp import ( compute_planck_source, @@ -19,130 +20,112 @@ interpolation, ) +from pyrte_rrtmgp.rte_problems import LWProblem, SWProblem -@dataclass -class GasOptics: - tau: Optional[np.ndarray] = None - tau_rayleigh: Optional[np.ndarray] = None - tau_absorption: Optional[np.ndarray] = None - g: Optional[np.ndarray] = None - ssa: Optional[np.ndarray] = None - lay_src: Optional[np.ndarray] = None - lev_src: Optional[np.ndarray] = None - sfc_src: Optional[np.ndarray] = None - sfc_src_jac: Optional[np.ndarray] = None - solar_source: Optional[np.ndarray] = None +from functools import cached_property @dataclass class InterpolatedAtmosphereGases: - jtemp: Optional[np.ndarray] = None - fmajor: Optional[np.ndarray] = None - fminor: Optional[np.ndarray] = None - col_mix: Optional[np.ndarray] = None - tropo: Optional[np.ndarray] = None - jeta: Optional[np.ndarray] = None - jpress: Optional[np.ndarray] = None + """Stores interpolated atmosphere gas data with type hints and validation. + + All fields are optional numpy arrays of type float64. + """ + jtemp: Optional[npt.NDArray[np.float64]] = None + fmajor: Optional[npt.NDArray[np.float64]] = None + fminor: Optional[npt.NDArray[np.float64]] = None + col_mix: Optional[npt.NDArray[np.float64]] = None + tropo: Optional[npt.NDArray[np.float64]] = None + jeta: Optional[npt.NDArray[np.float64]] = None + jpress: Optional[npt.NDArray[np.float64]] = None @xr.register_dataset_accessor("gas_optics") class GasOpticsAccessor: + """Factory class that returns appropriate GasOptics implementation""" + def __new__(cls, xarray_obj, selected_gases=None): + # Check if source is internal by looking at required variables + is_internal = "totplnk" in xarray_obj.data_vars and "plank_fraction" in xarray_obj.data_vars + + if is_internal: + return LWGasOpticsAccessor(xarray_obj, selected_gases) + else: + return SWGasOpticsAccessor(xarray_obj, selected_gases) + + +class BaseGasOpticsAccessor: def __init__(self, xarray_obj, selected_gases=None): - self._obj = xarray_obj + self._dataset = xarray_obj self._selected_gases = selected_gases self._gas_names = None - self._is_internal = None self._gas_mappings = None self._top_at_1 = None self._vmr_ref = None - self.col_gas = None + self.column_gases = None self._interpolated = InterpolatedAtmosphereGases() - self.gas_optics = GasOptics() + self._atmospheric_conditions = None @property def gas_names(self): """Gas names""" if self._gas_names is None: - names = self._obj["gas_names"].values + names = self._dataset["gas_names"].values self._gas_names = self.extract_names(names) return self._gas_names @property - def source_is_internal(self): - """Check if the source is internal""" - if self._is_internal is None: - variables = self._obj.data_vars - self._is_internal = "totplnk" in variables and "plank_fraction" in variables - return self._is_internal - - def solar_source(self): - """Calculate the solar variability - - Returns: - np.ndarray: Solar source - """ - - if self.source_is_internal: - raise NotExternalSourceError( - "Solar source is not available for internal sources." - ) - - if self.gas_optics.solar_source is None: - a_offset = 0.1495954 - b_offset = 0.00066696 - - solar_source_quiet = self._obj["solar_source_quiet"] - solar_source_facular = self._obj["solar_source_facular"] - solar_source_sunspot = self._obj["solar_source_sunspot"] - - mg_index = self._obj["mg_default"] - sb_index = self._obj["sb_default"] - - self.gas_optics.solar_source = ( - solar_source_quiet - + (mg_index - a_offset) * solar_source_facular - + (sb_index - b_offset) * solar_source_sunspot - ).data + def gas_optics(self): + """Return the appropriate problem instance - to be implemented by subclasses""" + raise NotImplementedError() def load_atmospheric_conditions(self, atmospheric_conditions: xr.Dataset): """Load atmospheric conditions""" - self._atm_cond = atmospheric_conditions - - # RRTMGP won't run with pressure less than its minimum. - # So we add a small value to the minimum pressure - min_index = np.argmin(self._atm_cond["pres_level"].data) - min_press = self._obj["press_ref"].min().item() + sys.float_info.epsilon - self._atm_cond["pres_level"][:, min_index] = min_press - + if not isinstance(atmospheric_conditions, xr.Dataset): + raise TypeError("atmospheric_conditions must be an xarray Dataset") + + # Validate required dimensions + required_dims = {'site', 'layer', 'level'} + missing_dims = required_dims - set(atmospheric_conditions.dims) + if missing_dims: + raise ValueError(f"Missing required dimensions: {missing_dims}") + + # Validate required variables + required_vars = {'pres_level', 'temp_layer', 'pres_layer', 'temp_level'} + missing_vars = required_vars - set(atmospheric_conditions.data_vars) + if missing_vars: + raise ValueError(f"Missing required variables: {missing_vars}") + + self._atmospheric_conditions = atmospheric_conditions + self._initialize_pressure_levels() self.get_col_gas() - self.interpolate() - self.compute_gas_taus() - if self.source_is_internal: - self.compute_planck() - else: - self.solar_source() - + self.compute_source() return self.gas_optics + def _initialize_pressure_levels(self): + """Initialize pressure levels with minimum pressure adjustment""" + min_index = np.argmin(self._atmospheric_conditions["pres_level"].data) + min_press = self._dataset["press_ref"].min().item() + sys.float_info.epsilon + self._atmospheric_conditions["pres_level"][:, min_index] = min_press + def get_col_gas(self): - if self._atm_cond is None: + if self._atmospheric_conditions is None: raise MissingAtmosphericConditionsError() - ncol = len(self._atm_cond["site"]) - nlay = len(self._atm_cond["layer"]) + ncol = len(self._atmospheric_conditions["site"]) + nlay = len(self._atmospheric_conditions["layer"]) col_gas = [] for gas_name in self.gas_mappings.values(): # if gas_name is not available, fill it with zeros - if gas_name not in self._atm_cond.data_vars.keys(): + if gas_name not in self._atmospheric_conditions.data_vars.keys(): gas_values = np.zeros((ncol, nlay)) else: try: - scale = float(self._atm_cond[gas_name].units) + scale = float(self._atmospheric_conditions[gas_name].units) except AttributeError: scale = 1.0 - gas_values = self._atm_cond[gas_name].values * scale + gas_values = self._atmospheric_conditions[gas_name].values * scale if gas_values.ndim == 0: gas_values = np.full((ncol, nlay), gas_values) @@ -150,20 +133,23 @@ def get_col_gas(self): vmr_h2o = col_gas[self.gas_names.index("h2o")] col_dry = self.get_col_dry( - vmr_h2o, self._atm_cond["pres_level"].data, latitude=None + vmr_h2o, self._atmospheric_conditions["pres_level"].data, latitude=None ) col_gas = [col_dry] + col_gas col_gas = np.stack(col_gas, axis=-1).astype(np.float64) col_gas[:, :, 1:] = col_gas[:, :, 1:] * col_gas[:, :, :1] - self.col_gas = col_gas + self.column_gases = col_gas + + def compute_gas_taus(self): + raise NotImplementedError() @property def gas_mappings(self): """Gas mappings""" - if self._atm_cond is None: + if self._atmospheric_conditions is None: raise MissingAtmosphericConditionsError() if self._gas_mappings is None: @@ -205,22 +191,22 @@ def gas_mappings(self): @property def top_at_1(self): if self._top_at_1 is None: - if self._atm_cond is None: + if self._atmospheric_conditions is None: raise MissingAtmosphericConditionsError() - pres_layers = self._atm_cond["pres_layer"]["layer"] + pres_layers = self._atmospheric_conditions["pres_layer"]["layer"] self._top_at_1 = pres_layers[0] < pres_layers[-1] return self._top_at_1.item() @property def vmr_ref(self): if self._vmr_ref is None: - if self._atm_cond is None: + if self._atmospheric_conditions is None: raise MissingAtmosphericConditionsError() sel_gases = self.gas_mappings.keys() vmr_idx = [i for i, g in enumerate(self._gas_names, 1) if g in sel_gases] vmr_idx = [0] + vmr_idx - self._vmr_ref = self._obj["vmr_ref"].sel(absorber_ext=vmr_idx).values.T + self._vmr_ref = self._dataset["vmr_ref"].sel(absorber_ext=vmr_idx).values.T return self._vmr_ref def interpolate(self): @@ -233,50 +219,27 @@ def interpolate(self): self._interpolated.jeta, self._interpolated.jpress, ) = interpolation( - neta=len(self._obj["mixing_fraction"]), + neta=len(self._dataset["mixing_fraction"]), flavor=self.flavors_sets, - press_ref=self._obj["press_ref"].values, - temp_ref=self._obj["temp_ref"].values, - press_ref_trop=self._obj["press_ref_trop"].values.item(), + press_ref=self._dataset["press_ref"].values, + temp_ref=self._dataset["temp_ref"].values, + press_ref_trop=self._dataset["press_ref_trop"].values.item(), vmr_ref=self.vmr_ref, - play=self._atm_cond["pres_layer"].values, - tlay=self._atm_cond["temp_layer"].values, - col_gas=self.col_gas, + play=self._atmospheric_conditions["pres_layer"].values, + tlay=self._atmospheric_conditions["temp_layer"].values, + col_gas=self.column_gases, ) - def compute_planck(self): - ( - self.gas_optics.sfc_src, - self.gas_optics.lay_src, - self.gas_optics.lev_src, - self.gas_optics.sfc_src_jac, - ) = compute_planck_source( - self._atm_cond["temp_layer"].values, - self._atm_cond["temp_level"].values, - self._atm_cond["surface_temperature"].values, - self.top_at_1, - self._interpolated.fmajor, - self._interpolated.jeta, - self._interpolated.tropo, - self._interpolated.jtemp, - self._interpolated.jpress, - self._obj["bnd_limits_gpt"].values.T, - self._obj["plank_fraction"].values.transpose(0, 2, 1, 3), - self._obj["temp_ref"].values.min(), - self._obj["temp_ref"].values.max(), - self._obj["totplnk"].values.T, - self.gpoint_flavor, - ) - - def compute_gas_taus(self): - minor_gases_lower = self.extract_names(self._obj["minor_gases_lower"].data) - minor_gases_upper = self.extract_names(self._obj["minor_gases_upper"].data) + @cached_property + def tau_absorption(self): + minor_gases_lower = self.extract_names(self._dataset["minor_gases_lower"].data) + minor_gases_upper = self.extract_names(self._dataset["minor_gases_upper"].data) # check if the index is correct idx_minor_lower = self.get_idx_minor(self.gas_names, minor_gases_lower) idx_minor_upper = self.get_idx_minor(self.gas_names, minor_gases_upper) - scaling_gas_lower = self.extract_names(self._obj["scaling_gas_lower"].data) - scaling_gas_upper = self.extract_names(self._obj["scaling_gas_upper"].data) + scaling_gas_lower = self.extract_names(self._dataset["scaling_gas_lower"].data) + scaling_gas_upper = self.extract_names(self._dataset["scaling_gas_upper"].data) idx_minor_scaling_lower = self.get_idx_minor(self.gas_names, scaling_gas_lower) idx_minor_scaling_upper = self.get_idx_minor(self.gas_names, scaling_gas_upper) @@ -284,65 +247,36 @@ def compute_gas_taus(self): tau_absorption = compute_tau_absorption( self.idx_h2o, self.gpoint_flavor, - self._obj["bnd_limits_gpt"].values.T, - self._obj["kmajor"].values, - self._obj["kminor_lower"].values, - self._obj["kminor_upper"].values, - self._obj["minor_limits_gpt_lower"].values.T, - self._obj["minor_limits_gpt_upper"].values.T, - self._obj["minor_scales_with_density_lower"].values.astype(bool), - self._obj["minor_scales_with_density_upper"].values.astype(bool), - self._obj["scale_by_complement_lower"].values.astype(bool), - self._obj["scale_by_complement_upper"].values.astype(bool), + self._dataset["bnd_limits_gpt"].values.T, + self._dataset["kmajor"].values, + self._dataset["kminor_lower"].values, + self._dataset["kminor_upper"].values, + self._dataset["minor_limits_gpt_lower"].values.T, + self._dataset["minor_limits_gpt_upper"].values.T, + self._dataset["minor_scales_with_density_lower"].values.astype(bool), + self._dataset["minor_scales_with_density_upper"].values.astype(bool), + self._dataset["scale_by_complement_lower"].values.astype(bool), + self._dataset["scale_by_complement_upper"].values.astype(bool), idx_minor_lower, idx_minor_upper, idx_minor_scaling_lower, idx_minor_scaling_upper, - self._obj["kminor_start_lower"].values, - self._obj["kminor_start_upper"].values, + self._dataset["kminor_start_lower"].values, + self._dataset["kminor_start_upper"].values, self._interpolated.tropo, self._interpolated.col_mix, self._interpolated.fmajor, self._interpolated.fminor, - self._atm_cond["pres_layer"].values, - self._atm_cond["temp_layer"].values, - self.col_gas, + self._atmospheric_conditions["pres_layer"].values, + self._atmospheric_conditions["temp_layer"].values, + self.column_gases, self._interpolated.jeta, self._interpolated.jtemp, self._interpolated.jpress, ) - self.gas_optics.tau_absorption = tau_absorption - if self.source_is_internal: - self.gas_optics.tau = tau_absorption - self.gas_optics.ssa = np.full_like(tau_absorption, np.nan) - self.gas_optics.g = np.full_like(tau_absorption, np.nan) - else: - krayl = np.stack( - [self._obj["rayl_lower"].values, self._obj["rayl_upper"].values], - axis=-1, - ) - tau_rayleigh = compute_tau_rayleigh( - self.gpoint_flavor, - self._obj["bnd_limits_gpt"].values.T, - krayl, - self.idx_h2o, - self.col_gas[:, :, 0], - self.col_gas, - self._interpolated.fminor, - self._interpolated.jeta, - self._interpolated.tropo, - self._interpolated.jtemp, - ) - - self.gas_optics.tau_rayleigh = tau_rayleigh - self.gas_optics.tau = tau_absorption + tau_rayleigh - self.gas_optics.ssa = np.where( - self.gas_optics.tau > 2.0 * np.finfo(float).tiny, - tau_rayleigh / self.gas_optics.tau, - 0.0, - ) - self.gas_optics.g = np.zeros(self.gas_optics.tau.shape) + return tau_absorption + @property def idx_h2o(self): @@ -357,11 +291,11 @@ def gpoint_flavor(self) -> npt.NDArray: Returns: np.ndarray: G-point flavors. """ - key_species = self._obj["key_species"].values + key_species = self._dataset["key_species"].values band_ranges = [ [i] * (r.values[1] - r.values[0] + 1) - for i, r in enumerate(self._obj["bnd_limits_gpt"], 1) + for i, r in enumerate(self._dataset["bnd_limits_gpt"], 1) ] gpoint_bands = np.concatenate(band_ranges) @@ -388,9 +322,9 @@ def flavors_sets(self) -> npt.NDArray: Returns: np.ndarray: Unique flavors. """ - key_species = self._obj["key_species"].values - tot_flav = len(self._obj["bnd"]) * len(self._obj["atmos_layer"]) - npairs = len(self._obj["pair"]) + key_species = self._dataset["key_species"].values + tot_flav = len(self._dataset["bnd"]) * len(self._dataset["atmos_layer"]) + npairs = len(self._dataset["pair"]) all_flav = np.reshape(key_species, (tot_flav, npairs)) # (0,0) becomes (2,2) because absorption coefficients for these g-points will be 0. all_flav[np.all(all_flav == [0, 0], axis=1)] = [2, 2] @@ -466,3 +400,162 @@ def get_col_dry(vmr_h2o, plev, latitude=None): / (1000.0 * m_air * 100.0 * g0[icol]) ) return col_dry + + +class LWGasOpticsAccessor(BaseGasOpticsAccessor): + """Accessor for internal radiation sources""" + + def __init__(self, xarray_obj, selected_gases=None): + super().__init__(xarray_obj, selected_gases) + self.lay_source = None + self.lev_source = None + self.sfc_src = None + self.sfc_src_jac = None + + @property + def gas_optics(self): + return LWProblem( + tau=self.tau_absorption, + lay_source=self.lay_source, + lev_source=self.lev_source, + sfc_src=self.sfc_src, + sfc_src_jac=self.sfc_src_jac + ) + + def compute_source(self): + """Implementation for internal source computation""" + self.compute_planck() + + def compute_planck(self): + ( + self.sfc_src, + self.lay_source, + self.lev_source, + self.sfc_src_jac, + ) = compute_planck_source( + self._atmospheric_conditions["temp_layer"].values, + self._atmospheric_conditions["temp_level"].values, + self._atmospheric_conditions["surface_temperature"].values, + self.top_at_1, + self._interpolated.fmajor, + self._interpolated.jeta, + self._interpolated.tropo, + self._interpolated.jtemp, + self._interpolated.jpress, + self._dataset["bnd_limits_gpt"].values.T, + self._dataset["plank_fraction"].values.transpose(0, 2, 1, 3), + self._dataset["temp_ref"].values.min(), + self._dataset["temp_ref"].values.max(), + self._dataset["totplnk"].values.T, + self.gpoint_flavor, + ) + + +class SWGasOpticsAccessor(BaseGasOpticsAccessor): + """Accessor for external radiation sources""" + + def __init__(self, xarray_obj, selected_gases=None): + super().__init__(xarray_obj, selected_gases) + self._solar_source = None + self._total_solar_irradiance = None + self._solar_zenith_angle = None + self._sfc_alb_dir = None + self._sfc_alb_dif = None + + @property + def gas_optics(self): + return SWProblem( + tau=self.tau, + ssa=self.ssa, + g=self.g, + solar_zenith_angle=self._solar_zenith_angle, + sfc_alb_dir=self._sfc_alb_dir, + sfc_alb_dif=self._sfc_alb_dif, + total_solar_irradiance=self._total_solar_irradiance, + solar_source=self._solar_source, + compute_mu0_fn=self.compute_mu0, + compute_toa_flux_fn=self.compute_toa_flux + ) + + def compute_source(self): + """Implementation for external source computation""" + a_offset = SOLAR_CONSTANTS['A_OFFSET'] + b_offset = SOLAR_CONSTANTS['B_OFFSET'] + + solar_source_quiet = self._dataset["solar_source_quiet"] + solar_source_facular = self._dataset["solar_source_facular"] + solar_source_sunspot = self._dataset["solar_source_sunspot"] + + mg_index = self._dataset["mg_default"] + sb_index = self._dataset["sb_default"] + + self._solar_source = ( + solar_source_quiet + + (mg_index - a_offset) * solar_source_facular + + (sb_index - b_offset) * solar_source_sunspot + ).data + + @cached_property + def tau_rayleigh(self): + krayl = np.stack( + [self._dataset["rayl_lower"].values, self._dataset["rayl_upper"].values], + axis=-1, + ) + return compute_tau_rayleigh( + self.gpoint_flavor, + self._dataset["bnd_limits_gpt"].values.T, + krayl, + self.idx_h2o, + self.column_gases[:, :, 0], + self.column_gases, + self._interpolated.fminor, + self._interpolated.jeta, + self._interpolated.tropo, + self._interpolated.jtemp, + ) + + @property + def tau(self): + return self.tau_absorption + self.tau_rayleigh + + @property + def ssa(self): + return np.where( + self.tau > 2.0 * np.finfo(float).tiny, + self.tau_rayleigh / self.tau, + 0.0, + ) + + @property + def g(self): + return np.zeros(self.tau.shape) + + @staticmethod + def compute_mu0(solar_zenith_angle, nlayer=None): + """Calculate the cosine of the solar zenith angle + + Args: + solar_zenith_angle (np.ndarray): Solar zenith angle in degrees + nlayer (int, optional): Number of layers. Defaults to None. + """ + usecol_values = solar_zenith_angle < 90.0 - 2.0 * np.spacing(90.0) + mu0 = np.where(usecol_values, np.cos(np.radians(solar_zenith_angle)), 1.0) + if nlayer is not None: + mu0 = np.stack([mu0] * nlayer).T + return mu0 + + @staticmethod + def compute_toa_flux(total_solar_irradiance, solar_source): + """Compute the top of atmosphere flux + + Args: + total_solar_irradiance (np.ndarray): Total solar irradiance + solar_source (np.ndarray): Solar source + + Returns: + np.ndarray: Top of atmosphere flux + """ + ncol = total_solar_irradiance.shape[0] + toa_flux = np.stack([solar_source] * ncol) + def_tsi = toa_flux.sum(axis=1) + return (toa_flux.T * (total_solar_irradiance / def_tsi)).T \ No newline at end of file diff --git a/pyrte_rrtmgp/rte_problems.py b/pyrte_rrtmgp/rte_problems.py new file mode 100644 index 0000000..3ed82d1 --- /dev/null +++ b/pyrte_rrtmgp/rte_problems.py @@ -0,0 +1,174 @@ +import numpy as np +from pyrte_rrtmgp.kernels.rte import lw_solver_noscat +from pyrte_rrtmgp.kernels.rte import sw_solver_2stream +from pyrte_rrtmgp.utils import get_usecols, compute_mu0, compute_toa_flux + + +class LWProblem: + def __init__(self, tau: np.ndarray, lay_source: np.ndarray, lev_source: np.ndarray, + sfc_src: np.ndarray, sfc_src_jac: np.ndarray): + self.tau = tau + self.lay_source = lay_source + self.lev_source = lev_source + self.sfc_src = sfc_src + self.sfc_src_jac = sfc_src_jac + self._sfc_emis = None + + @property + def sfc_emis(self): + if self._sfc_emis is None: + self._sfc_emis = np.ones((self.tau.shape[0], self.tau.shape[-1])) + return self._sfc_emis + + @sfc_emis.setter + def sfc_emis(self, value): + self._sfc_emis = value + + def rte_solve(self): + """Solve the radiative transfer equation + + Returns: + tuple: Tuple containing (solver_flux_up, solver_flux_down) + """ + + _, solver_flux_up, solver_flux_down, _, _ = lw_solver_noscat( + tau=self.tau, + lay_source=self.lay_source, + lev_source=self.lev_source, + sfc_emis=self.sfc_emis, + sfc_src=self.sfc_src, + sfc_src_jac=self.sfc_src_jac, + ) + + return solver_flux_up, solver_flux_down + +class SWProblem: + def __init__(self, tau: np.ndarray, ssa: np.ndarray, g: np.ndarray, + sfc_alb_dir: np.ndarray = None, + sfc_alb_dif: np.ndarray = None, + compute_mu0_fn=compute_mu0, + compute_toa_flux_fn=compute_toa_flux, + solar_source: np.ndarray = None, + solar_zenith_angle: np.ndarray = None, + total_solar_irradiance: np.ndarray = None): + """ + Initialize SW (shortwave) radiative transfer problem. + """ + self.tau = tau + self.ssa = ssa + self.g = g + self.nlayer = tau.shape[1] + + # Store inputs needed for computing mu0 and inc_flux_dir + self._solar_zenith_angle = solar_zenith_angle + self._total_solar_irradiance = total_solar_irradiance + self._solar_source = solar_source + self._compute_mu0_fn = compute_mu0_fn + self._compute_toa_flux_fn = compute_toa_flux_fn + + # Custom values (initialized as None) + self._mu0 = None + self._inc_flux_dir = None + + # Surface albedo + self._sfc_alb_dir = sfc_alb_dir + self._sfc_alb_dif = sfc_alb_dif + + @property + def sfc_alb_dir(self): + """Get direct surface albedo""" + if self._sfc_alb_dir is None: + raise ValueError("sfc_alb_dir must be set") + return self._sfc_alb_dir + + @sfc_alb_dir.setter + def sfc_alb_dir(self, value): + """Set direct surface albedo value""" + self._sfc_alb_dir = value + + @property + def sfc_alb_dif(self): + """Get diffuse surface albedo, defaults to direct if not set""" + if self._sfc_alb_dif is None: + return self.sfc_alb_dir + return self._sfc_alb_dif + + @sfc_alb_dif.setter + def sfc_alb_dif(self, value): + """Set diffuse surface albedo value""" + self._sfc_alb_dif = value + + @property + def solar_zenith_angle(self): + """Get solar zenith angle""" + if self._solar_zenith_angle is None: + raise ValueError("solar_zenith_angle must be set") + return self._solar_zenith_angle + + @solar_zenith_angle.setter + def solar_zenith_angle(self, value): + """Set solar zenith angle value""" + self._solar_zenith_angle = value + + @property + def mu0(self): + """Get mu0 value, computing it from solar_zenith_angle if not set manually""" + if self._mu0 is not None: + return self._mu0 + return self._compute_mu0_fn(self.solar_zenith_angle, nlayer=self.nlayer) + + @mu0.setter + def mu0(self, value): + """Set custom mu0 value""" + self._mu0 = value + + @property + def total_solar_irradiance(self): + """Get total solar irradiance""" + if self._total_solar_irradiance is None: + raise ValueError("total_solar_irradiance must be set") + return self._total_solar_irradiance + + @total_solar_irradiance.setter + def total_solar_irradiance(self, value): + """Set total solar irradiance value""" + self._total_solar_irradiance = value + + @property + def inc_flux_dir(self): + """Get incident flux, computing it from TSI and solar source if not set manually""" + if self._inc_flux_dir is not None: + return self._inc_flux_dir + elif self._solar_source is not None: + return self._compute_toa_flux_fn(self.total_solar_irradiance, self._solar_source) + else: + raise ValueError("Either set inc_flux_dir directly or provide solar_source") + + @inc_flux_dir.setter + def inc_flux_dir(self, value): + """Set custom incident flux value""" + self._inc_flux_dir = value + + def solve(self): + """Solve the SW radiative transfer problem.""" + # Get mu0 and inc_flux_dir using properties + mu0 = self.mu0 + inc_flux_dir = self.inc_flux_dir + + _, _, _, flux_up, flux_down, _ = sw_solver_2stream( + tau=self.tau, + ssa=self.ssa, + g=self.g, + mu0=mu0, + sfc_alb_dir=self.sfc_alb_dir, + sfc_alb_dif=self.sfc_alb_dif, + inc_flux_dir=inc_flux_dir, + ) + + # Post-process results for nighttime columns + if self.solar_zenith_angle is not None: + usecol = get_usecols(self.solar_zenith_angle) + flux_up = flux_up * usecol[:, np.newaxis] + flux_down = flux_down * usecol[:, np.newaxis] + + return flux_up, flux_down \ No newline at end of file diff --git a/tests/test_python_frontend/test_gas_optics.py b/tests/test_python_frontend/test_gas_optics.py index dd6795f..923a649 100644 --- a/tests/test_python_frontend/test_gas_optics.py +++ b/tests/test_python_frontend/test_gas_optics.py @@ -39,7 +39,7 @@ kdist.gas_optics.vmr_ref, rfmip["pres_layer"].values, rfmip["temp_layer"].values, - kdist.gas_optics.col_gas, + kdist.gas_optics.column_gases, ] expected_output = ( @@ -86,8 +86,8 @@ def test_compute_interpoaltion(args, expected): expected_output = ( rrtmgp_gas_optics.sfc_src, - rrtmgp_gas_optics.lay_src, - rrtmgp_gas_optics.lev_src, + rrtmgp_gas_optics.lay_source, + rrtmgp_gas_optics.lev_source, rrtmgp_gas_optics.sfc_src_jac, ) @@ -148,7 +148,7 @@ def test_compute_planck_source(args, expected): kdist.gas_optics._interpolated.fminor, rfmip["pres_layer"].values, rfmip["temp_layer"].values, - kdist.gas_optics.col_gas, + kdist.gas_optics.column_gases, kdist.gas_optics._interpolated.jeta, kdist.gas_optics._interpolated.jtemp, kdist.gas_optics._interpolated.jpress, @@ -157,10 +157,7 @@ def test_compute_planck_source(args, expected): @pytest.mark.parametrize( "args, expected", - [ - (i, rrtmgp_gas_optics.tau_absorption) - for i in convert_args_arrays(tau_absorption_args) - ], + [(i, rrtmgp_gas_optics.tau) for i in convert_args_arrays(tau_absorption_args)], ) def test_compute_tau_absorption(args, expected): result = compute_tau_absorption(*args) @@ -173,22 +170,10 @@ def test_compute_tau_absorption(args, expected): kdist_sw["bnd_limits_gpt"].values.T, np.stack([kdist_sw["rayl_lower"].values, kdist_sw["rayl_upper"].values], axis=-1), kdist_sw.gas_optics.idx_h2o, - kdist_sw.gas_optics.col_gas[:, :, 0], - kdist_sw.gas_optics.col_gas, + kdist_sw.gas_optics.column_gases[:, :, 0], + kdist_sw.gas_optics.column_gases, kdist_sw.gas_optics._interpolated.fminor, kdist_sw.gas_optics._interpolated.jeta, kdist_sw.gas_optics._interpolated.tropo, kdist_sw.gas_optics._interpolated.jtemp, ] - - -@pytest.mark.parametrize( - "args, expected", - [ - (i, rrtmgp_gas_optics_sw.tau_rayleigh) - for i in convert_args_arrays(tau_rayleigh_args) - ], -) -def test_compute_tau_rayleigh(args, expected): - result = compute_tau_rayleigh(*args) - assert np.isclose(result, expected, atol=ERROR_TOLERANCE).all() diff --git a/tests/test_python_frontend/test_lw_solver.py b/tests/test_python_frontend/test_lw_solver.py index a268fa8..64db757 100644 --- a/tests/test_python_frontend/test_lw_solver.py +++ b/tests/test_python_frontend/test_lw_solver.py @@ -14,7 +14,9 @@ ref_dir = os.path.join(rfmip_dir, "reference") rfmip = xr.load_dataset( - os.path.join(input_dir, "multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc") + os.path.join( + input_dir, "multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc" + ) ) rfmip = rfmip.sel(expt=0) # only one experiment kdist = xr.load_dataset(os.path.join(rte_rrtmgp_dir, "rrtmgp-gas-lw-g256.nc")) @@ -33,16 +35,11 @@ def test_lw_solver_noscat(): - rrtmgp_gas_optics = kdist.gas_optics.load_atmospheric_conditions(rfmip) - - _, solver_flux_up, solver_flux_down, _, _ = lw_solver_noscat( - tau=rrtmgp_gas_optics.tau, - lay_source=rrtmgp_gas_optics.lay_src, - lev_source=rrtmgp_gas_optics.lev_src, - sfc_emis=rfmip["surface_emissivity"].data, - sfc_src=rrtmgp_gas_optics.sfc_src, - sfc_src_jac=rrtmgp_gas_optics.sfc_src_jac, - ) + lw_problem = kdist.gas_optics.load_atmospheric_conditions(rfmip) + + lw_problem.sfc_emis = rfmip["surface_emissivity"].data + + solver_flux_up, solver_flux_down = lw_problem.rte_solve() assert np.isclose(solver_flux_up, ref_flux_up, atol=ERROR_TOLERANCE).all() assert np.isclose(solver_flux_down, ref_flux_down, atol=ERROR_TOLERANCE).all() diff --git a/tests/test_python_frontend/test_sw_solver.py b/tests/test_python_frontend/test_sw_solver.py index 3b521f0..9f7e94f 100644 --- a/tests/test_python_frontend/test_sw_solver.py +++ b/tests/test_python_frontend/test_sw_solver.py @@ -16,7 +16,9 @@ ref_dir = os.path.join(rfmip_dir, "reference") rfmip = xr.load_dataset( - os.path.join(input_dir, "multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc") + os.path.join( + input_dir, "multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc" + ) ) rfmip = rfmip.sel(expt=0) # only one experiment kdist = xr.load_dataset(os.path.join(rte_rrtmgp_dir, "rrtmgp-gas-sw-g224.nc")) @@ -35,36 +37,13 @@ def test_sw_solver_noscat(): - gas_optics = kdist.gas_optics.load_atmospheric_conditions(rfmip) - - surface_albedo = rfmip["surface_albedo"].data - total_solar_irradiance = rfmip["total_solar_irradiance"].data - - nlayer = len(rfmip["layer"]) - mu0 = compute_mu0(rfmip["solar_zenith_angle"].values, nlayer=nlayer) + sw_problem = kdist.gas_optics.load_atmospheric_conditions(rfmip) - toa_flux = compute_toa_flux(total_solar_irradiance, gas_optics.solar_source) - - _, _, _, solver_flux_up, solver_flux_down, _ = sw_solver_2stream( - kdist.gas_optics.top_at_1, - gas_optics.tau, - gas_optics.ssa, - gas_optics.g, - mu0, - sfc_alb_dir=surface_albedo, - sfc_alb_dif=surface_albedo, - inc_flux_dir=toa_flux, - inc_flux_dif=None, - has_dif_bc=False, - do_broadband=True, - ) + sw_problem.sfc_alb_dir = rfmip["surface_albedo"].data + sw_problem.total_solar_irradiance = rfmip["total_solar_irradiance"].data + sw_problem.solar_zenith_angle = rfmip["solar_zenith_angle"].values - # RTE will fail if passed solar zenith angles greater than 90 degree. We replace any with - # nighttime columns with a default solar zenith angle. We'll mask these out later, of - # course, but this gives us more work and so a better measure of timing. - usecol = get_usecols(rfmip["solar_zenith_angle"].values) - solver_flux_up = solver_flux_up * usecol[:, np.newaxis] - solver_flux_down = solver_flux_down * usecol[:, np.newaxis] + solver_flux_up, solver_flux_down = sw_problem.solve() assert np.isclose(solver_flux_up, ref_flux_up, atol=ERROR_TOLERANCE).all() assert np.isclose(solver_flux_down, ref_flux_down, atol=ERROR_TOLERANCE).all()