From 39c4e875f58ce27547c45cea9ae7e0fa99df415d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Josu=C3=A9=20Sehnem?= Date: Sat, 7 Dec 2024 14:20:15 -0300 Subject: [PATCH] add docstrings and types --- examples/lw_example.ipynb | 10 +- pyrte_rrtmgp/config.py | 35 +- pyrte_rrtmgp/constants.py | 55 ++- pyrte_rrtmgp/data_types.py | 33 +- pyrte_rrtmgp/data_validation.py | 133 ++++-- pyrte_rrtmgp/kernels/rrtmgp.py | 440 ++++++++++--------- pyrte_rrtmgp/kernels/rte.py | 75 ++-- pyrte_rrtmgp/rrtmgp_data.py | 35 +- pyrte_rrtmgp/rrtmgp_gas_optics.py | 336 +++++++++++--- pyrte_rrtmgp/rte_solver.py | 127 ++++-- pyrte_rrtmgp/utils.py | 67 --- tests/test_python_frontend/test_lw_solver.py | 16 +- tests/test_python_frontend/test_sw_solver.py | 8 +- 13 files changed, 896 insertions(+), 474 deletions(-) delete mode 100644 pyrte_rrtmgp/utils.py diff --git a/examples/lw_example.ipynb b/examples/lw_example.ipynb index 1c61484..03ac7cb 100644 --- a/examples/lw_example.ipynb +++ b/examples/lw_example.ipynb @@ -27,7 +27,7 @@ "\n", "atmosphere_file = \"multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc\"\n", "atmosphere_path = os.path.join(input_dir, atmosphere_file)\n", - "atmosphere = xr.load_dataset(atmosphere_path)#.sel(expt=0)\n", + "atmosphere = xr.load_dataset(atmosphere_path)\n", "\n", "gas_optics_lw.gas_optics.compute(atmosphere, problem_type=\"absorption\")\n", "\n", @@ -39,8 +39,12 @@ "rlu = xr.load_dataset(rlu_reference, decode_cf=False)\n", "rld = xr.load_dataset(rld_reference, decode_cf=False)\n", "\n", - "assert np.isclose(fluxes[\"lw_flux_up_broadband\"], rlu[\"rlu\"], atol=ERROR_TOLERANCE).all()\n", - "assert np.isclose(fluxes[\"lw_flux_down_broadband\"], rld[\"rld\"], atol=ERROR_TOLERANCE).all()\n" + "assert np.isclose(\n", + " fluxes[\"lw_flux_up_broadband\"], rlu[\"rlu\"], atol=ERROR_TOLERANCE\n", + ").all()\n", + "assert np.isclose(\n", + " fluxes[\"lw_flux_down_broadband\"], rld[\"rld\"], atol=ERROR_TOLERANCE\n", + ").all()" ] } ], diff --git a/pyrte_rrtmgp/config.py b/pyrte_rrtmgp/config.py index ec16611..8373aae 100644 --- a/pyrte_rrtmgp/config.py +++ b/pyrte_rrtmgp/config.py @@ -1,9 +1,16 @@ -from typing import Dict +"""Default mappings for gas names, dimensions and variables used in RRTMGP. -DEFAULT_GAS_MAPPING: Dict[str, str] = { +This module contains dictionaries that map standard names to dataset-specific names +for gases, dimensions and variables used in radiative transfer calculations. +""" + +from typing import Dict, Final + +# Mapping of standard gas names to RRTMGP-specific names +DEFAULT_GAS_MAPPING: Final[Dict[str, str]] = { "h2o": "water_vapor", "co2": "carbon_dioxide_GM", - "o3": "ozone", + "o3": "ozone", "n2o": "nitrous_oxide_GM", "co": "carbon_monoxide_GM", "ch4": "methane_GM", @@ -21,3 +28,25 @@ "cf4": "cf4_GM", "no2": "no2", } + +# Mapping of standard dimension names to dataset-specific names +DEFAULT_DIM_MAPPING: Final[Dict[str, str]] = { + "site": "site", + "layer": "layer", + "level": "level", +} + +# Mapping of standard variable names to dataset-specific names +DEFAULT_VAR_MAPPING: Final[Dict[str, str]] = { + "pres_layer": "pres_layer", + "pres_level": "pres_level", + "temp_layer": "temp_layer", + "temp_level": "temp_level", + "surface_temperature": "surface_temperature", + "solar_zenith_angle": "solar_zenith_angle", + "surface_albedo": "surface_albedo", + "surface_albedo_dir": "surface_albedo_dir", + "surface_albedo_dif": "surface_albedo_dif", + "surface_emissivity": "surface_emissivity", + "surface_emissivity_jacobian": "surface_emissivity_jacobian", +} diff --git a/pyrte_rrtmgp/constants.py b/pyrte_rrtmgp/constants.py index 65a20b1..2129e8c 100644 --- a/pyrte_rrtmgp/constants.py +++ b/pyrte_rrtmgp/constants.py @@ -1,9 +1,48 @@ -HELMERT1 = 9.80665 -HELMERT2 = 0.02586 -M_DRY = 0.028964 -M_H2O = 0.018016 -AVOGAD = 6.02214076e23 -SOLAR_CONSTANTS = { - "A_OFFSET": 0.1495954, - "B_OFFSET": 0.00066696, +"""Physical and mathematical constants used in radiative transfer calculations. + +This module contains various physical and mathematical constants needed for +radiative transfer calculations, including gravitational parameters, molecular +masses, and Gaussian quadrature weights and points. +""" + +from typing import Dict, Final +import numpy as np +from numpy.typing import NDArray + +# Gravitational parameters from Helmert's equation (m/s^2) +HELMERT1: Final[float] = 9.80665 # Standard gravity at sea level +HELMERT2: Final[float] = 0.02586 # Gravity variation with latitude + +# Molecular masses (kg/mol) +M_DRY: Final[float] = 0.028964 # Dry air +M_H2O: Final[float] = 0.018016 # Water vapor + +# Avogadro's number (molecules/mol) +AVOGAD: Final[float] = 6.02214076e23 + +# Solar constants for orbit calculations +SOLAR_CONSTANTS: Final[Dict[str, float]] = { + "A_OFFSET": 0.1495954, # Semi-major axis offset (AU) + "B_OFFSET": 0.00066696, # Orbital eccentricity factor } + +# Gaussian quadrature constants for radiative transfer +GAUSS_DS: NDArray[np.float64] = np.reciprocal( + np.array( + [ + [0.6096748751, np.inf, np.inf, np.inf], + [0.2509907356, 0.7908473988, np.inf, np.inf], + [0.1024922169, 0.4417960320, 0.8633751621, np.inf], + [0.0454586727, 0.2322334416, 0.5740198775, 0.9030775973], + ] + ) +) + +GAUSS_WTS: NDArray[np.float64] = np.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.2300253764, 0.7699746236, 0.0, 0.0], + [0.0437820218, 0.3875796738, 0.5686383044, 0.0], + [0.0092068785, 0.1285704278, 0.4323381850, 0.4298845087], + ] +) diff --git a/pyrte_rrtmgp/data_types.py b/pyrte_rrtmgp/data_types.py index b707704..921e2cd 100644 --- a/pyrte_rrtmgp/data_types.py +++ b/pyrte_rrtmgp/data_types.py @@ -1,8 +1,19 @@ -from enum import Enum +from enum import Enum, StrEnum -class GasOpticsFiles(Enum): - """Enumeration of default RRTMGP gas optics data files.""" +class GasOpticsFiles(StrEnum): + """Enumeration of default RRTMGP gas optics data files. + + This enum defines the available pre-configured gas optics data files that can be used + with RRTMGP. The files contain absorption coefficients and other optical properties + needed for radiative transfer calculations. + + Attributes: + LW_G128: Longwave gas optics file with 128 g-points + LW_G256: Longwave gas optics file with 256 g-points + SW_G112: Shortwave gas optics file with 112 g-points + SW_G224: Shortwave gas optics file with 224 g-points + """ LW_G128 = "rrtmgp-gas-lw-g128.nc" LW_G256 = "rrtmgp-gas-lw-g256.nc" @@ -10,8 +21,20 @@ class GasOpticsFiles(Enum): SW_G224 = "rrtmgp-gas-sw-g224.nc" -class ProblemTypes(Enum): +class ProblemTypes(StrEnum): + """Enumeration of available radiation calculation types. + + This enum defines the different types of radiation calculations that can be performed, + including both longwave and shortwave calculations with different solution methods. + + Attributes: + LW_ABSORPTION: Longwave absorption-only calculation + LW_2STREAM: Longwave two-stream approximation calculation + SW_DIRECT: Shortwave direct beam calculation + SW_2STREAM: Shortwave two-stream approximation calculation + """ + LW_ABSORPTION = "Longwave absorption" - LW_2STREAM = "Longwave 2-stream" + LW_2STREAM = "Longwave 2-stream" SW_DIRECT = "Shortwave direct" SW_2STREAM = "Shortwave 2-stream" diff --git a/pyrte_rrtmgp/data_validation.py b/pyrte_rrtmgp/data_validation.py index c2d0530..0d62cd3 100644 --- a/pyrte_rrtmgp/data_validation.py +++ b/pyrte_rrtmgp/data_validation.py @@ -1,20 +1,39 @@ from dataclasses import asdict, dataclass -from typing import Dict, Optional, Set +from typing import Dict, Optional, Set, Union import xarray as xr -from pyrte_rrtmgp.config import DEFAULT_GAS_MAPPING +from pyrte_rrtmgp.config import ( + DEFAULT_DIM_MAPPING, + DEFAULT_GAS_MAPPING, + DEFAULT_VAR_MAPPING, +) @dataclass class GasMapping: + """Class for managing gas name mappings between standard and dataset-specific names. + + Attributes: + _mapping: Dictionary mapping standard gas names to dataset-specific names + _required_gases: Set of required gas names that must be present + """ _mapping: Dict[str, str] _required_gases: Set[str] @classmethod def create( - cls, gas_names: Set[str], custom_mapping: Dict[str, str] | None = None + cls, gas_names: Set[str], custom_mapping: Optional[Dict[str, str]] = None ) -> "GasMapping": + """Create a new GasMapping instance with default and custom mappings. + + Args: + gas_names: Set of required gas names + custom_mapping: Optional custom mapping to override defaults + + Returns: + New GasMapping instance + """ mapping = DEFAULT_GAS_MAPPING.copy() if custom_mapping: mapping.update(custom_mapping) @@ -22,7 +41,14 @@ def create( return cls(mapping, gas_names) def validate(self) -> Dict[str, str]: - """Validates and returns the final mapping.""" + """Validate and return the final gas name mapping. + + Returns: + Dictionary mapping standard gas names to dataset-specific names + + Raises: + ValueError: If a required gas is not found in any mapping + """ validated_mapping = {} for gas in self._required_gases: @@ -38,57 +64,91 @@ def validate(self) -> Dict[str, str]: @dataclass class DatasetMapping: - """Container for dimension and variable mappings""" + """Container for dimension and variable name mappings. + Attributes: + dim_mapping: Dictionary mapping standard dimension names to dataset-specific names + var_mapping: Dictionary mapping standard variable names to dataset-specific names + """ dim_mapping: Dict[str, str] var_mapping: Dict[str, str] - def __post_init__(self): - """Validate mappings upon initialization""" + def __post_init__(self) -> None: + """Validate mappings upon initialization.""" pass @classmethod - def from_dict(cls, d: Dict) -> "DatasetMapping": - """Create mapping from dictionary""" + def from_dict(cls, d: Dict[str, Dict[str, str]]) -> "DatasetMapping": + """Create mapping from dictionary representation. + + Args: + d: Dictionary containing dim_mapping and var_mapping + + Returns: + New DatasetMapping instance + """ return cls(dim_mapping=d["dim_mapping"], var_mapping=d["var_mapping"]) @xr.register_dataset_accessor("mapping") class DatasetMappingAccessor: - """ - An accessor for xarray datasets that provides information about variable mappings. - The mapping is stored in the dataset's attributes. + """Accessor for xarray datasets that provides variable mapping functionality. + + The mapping is stored in the dataset's attributes to maintain persistence. """ - def __init__(self, xarray_obj): + def __init__(self, xarray_obj: xr.Dataset) -> None: self._obj = xarray_obj def set_mapping(self, mapping: DatasetMapping) -> None: - """Set the mapping in dataset attributes""" - # Validate that mapped variables exist in dataset + """Set the mapping in dataset attributes. + + Args: + mapping: DatasetMapping instance to store + + Raises: + ValueError: If mapped dimensions don't exist in dataset + """ missing_dims = set(mapping.dim_mapping.values()) - set(self._obj.dims) if missing_dims: raise ValueError(f"Dataset missing required dimensions: {missing_dims}") - # Store mapping in attributes self._obj.attrs["dataset_mapping"] = asdict(mapping) @property def mapping(self) -> Optional[DatasetMapping]: - """Get the mapping from dataset attributes""" + """Get the mapping from dataset attributes. + + Returns: + DatasetMapping if exists, None otherwise + """ if "dataset_mapping" not in self._obj.attrs: return None return DatasetMapping.from_dict(self._obj.attrs["dataset_mapping"]) def get_var(self, standard_name: str) -> Optional[str]: - """Get the actual variable name in the dataset for a standard name""" + """Get the dataset-specific variable name for a standard name. + + Args: + standard_name: Standard variable name + + Returns: + Dataset-specific variable name if found, None otherwise + """ mapping = self.mapping if mapping is None: return None return mapping.var_mapping.get(standard_name) def get_dim(self, standard_name: str) -> Optional[str]: - """Get the actual dimension name in the dataset for a standard name""" + """Get the dataset-specific dimension name for a standard name. + + Args: + standard_name: Standard dimension name + + Returns: + Dataset-specific dimension name if found, None otherwise + """ mapping = self.mapping if mapping is None: return None @@ -97,10 +157,17 @@ def get_dim(self, standard_name: str) -> Optional[str]: @dataclass class AtmosphericMapping(DatasetMapping): - """Specific mapping for atmospheric data""" + """Specific mapping for atmospheric data with required dimensions and variables. - def __post_init__(self): - """Validate atmospheric-specific mappings""" + Inherits from DatasetMapping and adds validation for required atmospheric fields. + """ + + def __post_init__(self) -> None: + """Validate atmospheric-specific mappings. + + Raises: + ValueError: If required dimensions or variables are missing + """ required_dims = {"site", "layer", "level"} missing_dims = required_dims - set(self.dim_mapping.keys()) if missing_dims: @@ -113,20 +180,12 @@ def __post_init__(self): def create_default_mapping() -> AtmosphericMapping: - """Create a default mapping configuration""" + """Create a default atmospheric mapping configuration. + + Returns: + AtmosphericMapping instance with default dimension and variable mappings + """ return AtmosphericMapping( - dim_mapping={"site": "site", "layer": "layer", "level": "level"}, - var_mapping={ - "pres_layer": "pres_layer", - "pres_level": "pres_level", - "temp_layer": "temp_layer", - "temp_level": "temp_level", - "surface_temperature": "surface_temperature", - "solar_zenith_angle": "solar_zenith_angle", - "surface_albedo": "surface_albedo", - "surface_albedo_dir": "surface_albedo_dir", - "surface_albedo_dif": "surface_albedo_dif", - "surface_emissivity": "surface_emissivity", - "surface_emissivity_jacobian": "surface_emissivity_jacobian", - }, + dim_mapping=DEFAULT_DIM_MAPPING, + var_mapping=DEFAULT_VAR_MAPPING, ) diff --git a/pyrte_rrtmgp/kernels/rrtmgp.py b/pyrte_rrtmgp/kernels/rrtmgp.py index e4a4e8c..42faa3d 100644 --- a/pyrte_rrtmgp/kernels/rrtmgp.py +++ b/pyrte_rrtmgp/kernels/rrtmgp.py @@ -9,7 +9,6 @@ rrtmgp_compute_tau_rayleigh, rrtmgp_interpolation, ) -from pyrte_rrtmgp.utils import convert_xarray_args def interpolation( @@ -20,51 +19,54 @@ def interpolation( neta: int, npres: int, ntemp: int, - flavor: npt.NDArray, - press_ref: npt.NDArray, - temp_ref: npt.NDArray, + flavor: npt.NDArray[np.int32], + press_ref: npt.NDArray[np.float64], + temp_ref: npt.NDArray[np.float64], press_ref_trop: float, - vmr_ref: npt.NDArray, - play: npt.NDArray, - tlay: npt.NDArray, - col_gas: npt.NDArray, + vmr_ref: npt.NDArray[np.float64], + play: npt.NDArray[np.float64], + tlay: npt.NDArray[np.float64], + col_gas: npt.NDArray[np.float64], ) -> Tuple[ - npt.NDArray, - npt.NDArray, - npt.NDArray, - npt.NDArray, - npt.NDArray, - npt.NDArray, - npt.NDArray, + npt.NDArray[np.int32], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.bool_], + npt.NDArray[np.int32], + npt.NDArray[np.int32], ]: - """Interpolate the RRTMGP coefficients. + """Interpolate the RRTMGP coefficients to the current atmospheric state. + + This function performs interpolation of gas optics coefficients based on the current + atmospheric temperature and pressure profiles. Args: - ncol (int): Number of atmospheric columns. - nlay (int): Number of atmospheric layers. - ngas (int): Number of gases. - nflav (int): Number of gas flavors. - neta (int): Number of mixing_fraction. - npres (int): Number of reference pressure grid points. - ntemp (int): Number of reference temperature grid points. - flavor (np.ndarray): Index into vmr_ref of major gases for each flavor. - press_ref (np.ndarray): Reference pressure grid. - temp_ref (np.ndarray): Reference temperature grid. - press_ref_trop (float): Reference pressure at the tropopause. - vmr_ref (np.ndarray): Reference volume mixing ratio. - play (np.ndarray): Pressure layers. - tlay (np.ndarray): Temperature layers. - col_gas (np.ndarray): Gas concentrations. + ncol: Number of atmospheric columns + nlay: Number of atmospheric layers + ngas: Number of gases + nflav: Number of gas flavors + neta: Number of mixing fraction points + npres: Number of reference pressure grid points + ntemp: Number of reference temperature grid points + flavor: Index into vmr_ref of major gases for each flavor with shape (nflav,) + press_ref: Reference pressure grid with shape (npres,) + temp_ref: Reference temperature grid with shape (ntemp,) + press_ref_trop: Reference pressure at the tropopause + vmr_ref: Reference volume mixing ratios with shape (ngas,) + play: Layer pressures with shape (ncol, nlay) + tlay: Layer temperatures with shape (ncol, nlay) + col_gas: Gas concentrations with shape (ncol, nlay, ngas) Returns: - Tuple: A tuple containing the following arrays: - - jtemp (np.ndarray): Temperature interpolation index. - - fmajor (np.ndarray): Major gas interpolation fraction. - - fminor (np.ndarray): Minor gas interpolation fraction. - - col_mix (np.ndarray): Mixing fractions. - - tropo (np.ndarray): Use lower (or upper) atmosphere tables. - - jeta (np.ndarray): Index for binary species interpolation. - - jpress (np.ndarray): Pressure interpolation index. + Tuple containing: + - jtemp: Temperature interpolation indices with shape (ncol, nlay) + - fmajor: Major gas interpolation fractions with shape (2, 2, 2, ncol, nlay, nflav) + - fminor: Minor gas interpolation fractions with shape (2, 2, ncol, nlay, nflav) + - col_mix: Mixing fractions with shape (2, ncol, nlay, nflav) + - tropo: Boolean mask for troposphere with shape (ncol, nlay) + - jeta: Binary species interpolation indices with shape (2, ncol, nlay, nflav) + - jpress: Pressure interpolation indices with shape (ncol, nlay) """ press_ref_log = np.log(press_ref) press_ref_log_delta = (press_ref_log.min() - press_ref_log.max()) / ( @@ -77,7 +79,7 @@ def interpolation( ngas = ngas - 1 # Fortran uses index 0 here - # outputs + # Initialize output arrays jtemp = np.ndarray([ncol, nlay], dtype=np.int32, order="F") fmajor = np.ndarray([2, 2, 2, ncol, nlay, nflav], dtype=np.float64, order="F") fminor = np.ndarray([2, 2, ncol, nlay, nflav], dtype=np.float64, order="F") @@ -120,71 +122,80 @@ def interpolation( return jtemp, fmajor, fminor, col_mix, tropo, jeta, jpress -@convert_xarray_args def compute_planck_source( - ncol, - nlay, - nbnd, - ngpt, - nflav, - neta, - npres, - ntemp, - nPlanckTemp, - tlay, - tlev, - tsfc, - top_at_1, - fmajor, - jeta, - tropo, - jtemp, - jpress, - band_lims_gpt, - pfracin, - temp_ref_min, - temp_ref_max, - totplnk, - gpoint_flavor, -): - """Compute the Planck source function for a radiative transfer calculation. + ncol: int, + nlay: int, + nbnd: int, + ngpt: int, + nflav: int, + neta: int, + npres: int, + ntemp: int, + nPlanckTemp: int, + tlay: npt.NDArray[np.float64], + tlev: npt.NDArray[np.float64], + tsfc: npt.NDArray[np.float64], + top_at_1: bool, + fmajor: npt.NDArray[np.float64], + jeta: npt.NDArray[np.int32], + tropo: npt.NDArray[np.bool_], + jtemp: npt.NDArray[np.int32], + jpress: npt.NDArray[np.int32], + band_lims_gpt: npt.NDArray[np.int32], + pfracin: npt.NDArray[np.float64], + temp_ref_min: float, + temp_ref_max: float, + totplnk: npt.NDArray[np.float64], + gpoint_flavor: npt.NDArray[np.int32], +) -> Tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], +]: + """Compute the Planck source function for radiative transfer calculations. + + This function calculates the Planck blackbody emission source terms needed for + longwave radiative transfer calculations. Args: - tlay (numpy.ndarray): Temperature at layer centers (K), shape (ncol, nlay). - tlev (numpy.ndarray): Temperature at layer interfaces (K), shape (ncol, nlay+1). - tsfc (numpy.ndarray): Surface temperature, shape (ncol,). - top_at_1 (bool): Flag indicating if the top layer is at index 0. - sfc_lay (int): Index of the surface layer. - fmajor (numpy.ndarray): Interpolation weights for major gases, shape (2, 2, 2, ncol, nlay, nflav). - jeta (numpy.ndarray): Interpolation indexes in eta, shape (2, ncol, nlay, nflav). - tropo (numpy.ndarray): Use upper- or lower-atmospheric tables, shape (ncol, nlay). - jtemp (numpy.ndarray): Interpolation indexes in temperature, shape (ncol, nlay). - jpress (numpy.ndarray): Interpolation indexes in pressure, shape (ncol, nlay). - band_lims_gpt (numpy.ndarray): Start and end g-point for each band, shape (2, nbnd). - pfracin (numpy.ndarray): Fraction of the Planck function in each g-point, shape (ntemp, neta, npres+1, ngpt). - temp_ref_min (float): Minimum reference temperature for Planck function interpolation. - totplnk (numpy.ndarray): Total Planck function by band at each temperature, shape (nPlanckTemp, nbnd). - gpoint_flavor (numpy.ndarray): Major gas flavor (pair) by upper/lower, g-point, shape (2, ngpt). + ncol: Number of atmospheric columns + nlay: Number of atmospheric layers + nbnd: Number of spectral bands + ngpt: Number of g-points + nflav: Number of gas flavors + neta: Number of eta points + npres: Number of pressure points + ntemp: Number of temperature points + nPlanckTemp: Number of temperatures for Planck function + tlay: Layer temperatures with shape (ncol, nlay) + tlev: Level temperatures with shape (ncol, nlay+1) + tsfc: Surface temperatures with shape (ncol,) + top_at_1: Whether the top of the atmosphere is at index 1 + fmajor: Major gas interpolation weights with shape (2, 2, 2, ncol, nlay, nflav) + jeta: Eta interpolation indices with shape (2, ncol, nlay, nflav) + tropo: Troposphere mask with shape (ncol, nlay) + jtemp: Temperature interpolation indices with shape (ncol, nlay) + jpress: Pressure interpolation indices with shape (ncol, nlay) + band_lims_gpt: Band limits in g-point space with shape (2, nbnd) + pfracin: Planck fractions with shape (ntemp, neta, npres+1, ngpt) + temp_ref_min: Minimum reference temperature + temp_ref_max: Maximum reference temperature + totplnk: Total Planck function by band with shape (nPlanckTemp, nbnd) + gpoint_flavor: G-point flavors with shape (2, ngpt) Returns: - sfc_src (numpy.ndarray): Planck emission from the surface, shape (ncol, ngpt). - lay_src (numpy.ndarray): Planck emission from layer centers, shape (ncol, nlay, ngpt). - lev_src (numpy.ndarray): Planck emission from layer boundaries, shape (ncol, nlay+1, ngpt). - sfc_source_Jac (numpy.ndarray): Jacobian (derivative) of the surface Planck source with respect to surface temperature, shape (ncol, ngpt). + Tuple containing: + - sfc_src: Surface emission with shape (ncol, ngpt) + - lay_src: Layer emission with shape (ncol, nlay, ngpt) + - lev_src: Level emission with shape (ncol, nlay+1, ngpt) + - sfc_src_jac: Surface emission Jacobian with shape (ncol, ngpt) """ - - # _, ncol, nlay, nflav = jeta.shape - # nPlanckTemp, nbnd = totplnk.shape - # ntemp, neta, npres_e, ngpt = pfracin.shape - # npres = npres_e - 1 - sfc_lay = nlay if top_at_1 else 1 - gpoint_bands = [] - totplnk_delta = (temp_ref_max - temp_ref_min) / (nPlanckTemp - 1) - # outputs + # Initialize output arrays sfc_src = np.ndarray((ncol, ngpt), dtype=np.float64, order="F") lay_src = np.ndarray((ncol, nlay, ngpt), dtype=np.float64, order="F") lev_src = np.ndarray((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") @@ -228,87 +239,102 @@ def compute_planck_source( def compute_tau_absorption( - ncol, - nlay, - nbnd, - ngpt, - ngas, - nflav, - neta, - npres, - ntemp, - nminorlower, - nminorklower, - nminorupper, - nminorkupper, - idx_h2o, - gpoint_flavor, - band_lims_gpt, - kmajor, - kminor_lower, - kminor_upper, - minor_limits_gpt_lower, - minor_limits_gpt_upper, - minor_scales_with_density_lower, - minor_scales_with_density_upper, - scale_by_complement_lower, - scale_by_complement_upper, - idx_minor_lower, - idx_minor_upper, - idx_minor_scaling_lower, - idx_minor_scaling_upper, - kminor_start_lower, - kminor_start_upper, - tropo, - col_mix, - fmajor, - fminor, - play, - tlay, - col_gas, - jeta, - jtemp, - jpress, -): - """Compute the absorption optical depth for a set of atmospheric profiles. + ncol: int, + nlay: int, + nbnd: int, + ngpt: int, + ngas: int, + nflav: int, + neta: int, + npres: int, + ntemp: int, + nminorlower: int, + nminorklower: int, + nminorupper: int, + nminorkupper: int, + idx_h2o: int, + gpoint_flavor: npt.NDArray[np.int32], + band_lims_gpt: npt.NDArray[np.int32], + kmajor: npt.NDArray[np.float64], + kminor_lower: npt.NDArray[np.float64], + kminor_upper: npt.NDArray[np.float64], + minor_limits_gpt_lower: npt.NDArray[np.int32], + minor_limits_gpt_upper: npt.NDArray[np.int32], + minor_scales_with_density_lower: npt.NDArray[np.bool_], + minor_scales_with_density_upper: npt.NDArray[np.bool_], + scale_by_complement_lower: npt.NDArray[np.bool_], + scale_by_complement_upper: npt.NDArray[np.bool_], + idx_minor_lower: npt.NDArray[np.int32], + idx_minor_upper: npt.NDArray[np.int32], + idx_minor_scaling_lower: npt.NDArray[np.int32], + idx_minor_scaling_upper: npt.NDArray[np.int32], + kminor_start_lower: npt.NDArray[np.int32], + kminor_start_upper: npt.NDArray[np.int32], + tropo: npt.NDArray[np.bool_], + col_mix: npt.NDArray[np.float64], + fmajor: npt.NDArray[np.float64], + fminor: npt.NDArray[np.float64], + play: npt.NDArray[np.float64], + tlay: npt.NDArray[np.float64], + col_gas: npt.NDArray[np.float64], + jeta: npt.NDArray[np.int32], + jtemp: npt.NDArray[np.int32], + jpress: npt.NDArray[np.int32], +) -> npt.NDArray[np.float64]: + """Compute the absorption optical depth for atmospheric profiles. + + This function calculates the total absorption optical depth by combining contributions + from major and minor gas species in both the upper and lower atmosphere. Args: - idx_h2o (int): Index of the water vapor gas species. - gpoint_flavor (np.ndarray): Spectral g-point flavor indices. - band_lims_gpt (np.ndarray): Spectral band limits in g-point space. - kmajor (np.ndarray): Major gas absorption coefficients. - kminor_lower (np.ndarray): Minor gas absorption coefficients for the lower atmosphere. - kminor_upper (np.ndarray): Minor gas absorption coefficients for the upper atmosphere. - minor_limits_gpt_lower (np.ndarray): Spectral g-point limits for minor contributors in the lower atmosphere. - minor_limits_gpt_upper (np.ndarray): Spectral g-point limits for minor contributors in the upper atmosphere. - minor_scales_with_density_lower (np.ndarray): Flags indicating if minor contributors in the lower atmosphere scale with density. - minor_scales_with_density_upper (np.ndarray): Flags indicating if minor contributors in the upper atmosphere scale with density. - scale_by_complement_lower (np.ndarray): Flags indicating if minor contributors in the lower atmosphere should be scaled by the complement. - scale_by_complement_upper (np.ndarray): Flags indicating if minor contributors in the upper atmosphere should be scaled by the complement. - idx_minor_lower (np.ndarray): Indices of minor contributors in the lower atmosphere. - idx_minor_upper (np.ndarray): Indices of minor contributors in the upper atmosphere. - idx_minor_scaling_lower (np.ndarray): Indices of minor contributors in the lower atmosphere that require scaling. - idx_minor_scaling_upper (np.ndarray): Indices of minor contributors in the upper atmosphere that require scaling. - kminor_start_lower (np.ndarray): Starting indices of minor absorption coefficients in the lower atmosphere. - kminor_start_upper (np.ndarray): Starting indices of minor absorption coefficients in the upper atmosphere. - tropo (np.ndarray): Flags indicating if a layer is in the troposphere. - col_mix (np.ndarray): Column-dependent gas mixing ratios. - fmajor (np.ndarray): Major gas absorption coefficient scaling factors. - fminor (np.ndarray): Minor gas absorption coefficient scaling factors. - play (np.ndarray): Pressure in each layer. - tlay (np.ndarray): Temperature in each layer. - col_gas (np.ndarray): Column-dependent gas concentrations. - jeta (np.ndarray): Indices of temperature/pressure levels. - jtemp (np.ndarray): Indices of temperature levels. - jpress (np.ndarray): Indices of pressure levels. + ncol: Number of atmospheric columns + nlay: Number of atmospheric layers + nbnd: Number of spectral bands + ngpt: Number of g-points + ngas: Number of gases + nflav: Number of gas flavors + neta: Number of eta points + npres: Number of pressure points + ntemp: Number of temperature points + nminorlower: Number of minor species in lower atmosphere + nminorklower: Number of minor absorption coefficients in lower atmosphere + nminorupper: Number of minor species in upper atmosphere + nminorkupper: Number of minor absorption coefficients in upper atmosphere + idx_h2o: Index of water vapor + gpoint_flavor: G-point flavors with shape (2, ngpt) + band_lims_gpt: Band limits in g-point space with shape (2, nbnd) + kmajor: Major gas absorption coefficients + kminor_lower: Minor gas absorption coefficients for lower atmosphere + kminor_upper: Minor gas absorption coefficients for upper atmosphere + minor_limits_gpt_lower: G-point limits for minor gases in lower atmosphere + minor_limits_gpt_upper: G-point limits for minor gases in upper atmosphere + minor_scales_with_density_lower: Density scaling flags for lower atmosphere + minor_scales_with_density_upper: Density scaling flags for upper atmosphere + scale_by_complement_lower: Complement scaling flags for lower atmosphere + scale_by_complement_upper: Complement scaling flags for upper atmosphere + idx_minor_lower: Minor gas indices for lower atmosphere + idx_minor_upper: Minor gas indices for upper atmosphere + idx_minor_scaling_lower: Minor gas scaling indices for lower atmosphere + idx_minor_scaling_upper: Minor gas scaling indices for upper atmosphere + kminor_start_lower: Starting indices for minor gases in lower atmosphere + kminor_start_upper: Starting indices for minor gases in upper atmosphere + tropo: Troposphere mask with shape (ncol, nlay) + col_mix: Gas mixing ratios with shape (2, ncol, nlay, nflav) + fmajor: Major gas interpolation weights + fminor: Minor gas interpolation weights + play: Layer pressures with shape (ncol, nlay) + tlay: Layer temperatures with shape (ncol, nlay) + col_gas: Gas concentrations with shape (ncol, nlay, ngas) + jeta: Eta interpolation indices + jtemp: Temperature interpolation indices + jpress: Pressure interpolation indices Returns: - np.ndarray): tau Absorption optical depth. + Absorption optical depth with shape (ncol, nlay, ngpt) """ - ngas = ngas - 1 # Fortran uses index 0 here - # outputs + # Initialize output array tau = np.zeros((ncol, nlay, ngpt), dtype=np.float64, order="F") args = [ @@ -361,46 +387,54 @@ def compute_tau_absorption( return tau -@convert_xarray_args def compute_tau_rayleigh( - ncol, - nlay, - nbnd, - ngpt, - ngas, - nflav, - neta, - ntemp, - gpoint_flavor, - band_lims_gpt, - krayl, - idx_h2o, - col_dry, - col_gas, - fminor, - jeta, - tropo, - jtemp, -): - """Compute Rayleigh optical depth. + ncol: int, + nlay: int, + nbnd: int, + ngpt: int, + ngas: int, + nflav: int, + neta: int, + ntemp: int, + gpoint_flavor: npt.NDArray[np.int32], + band_lims_gpt: npt.NDArray[np.int32], + krayl: npt.NDArray[np.float64], + idx_h2o: int, + col_dry: npt.NDArray[np.float64], + col_gas: npt.NDArray[np.float64], + fminor: npt.NDArray[np.float64], + jeta: npt.NDArray[np.int32], + tropo: npt.NDArray[np.bool_], + jtemp: npt.NDArray[np.int32], +) -> npt.NDArray[np.float64]: + """Compute Rayleigh scattering optical depth. + + This function calculates the optical depth due to Rayleigh scattering by air molecules. Args: - gpoint_flavor (numpy.ndarray): Major gas flavor (pair) by upper/lower, g-point (shape: (2, ngpt)). - band_lims_gpt (numpy.ndarray): Start and end g-point for each band (shape: (2, nbnd)). - krayl (numpy.ndarray): Rayleigh scattering coefficients (shape: (ntemp, neta, ngpt, 2)). - idx_h2o (int): Index of water vapor in col_gas. - col_dry (numpy.ndarray): Column amount of dry air (shape: (ncol, nlay)). - col_gas (numpy.ndarray): Input column gas amount (molecules/cm^2) (shape: (ncol, nlay, 0:ngas)). - fminor (numpy.ndarray): Interpolation weights for major gases - computed in interpolation() (shape: (2, 2, ncol, nlay, nflav)). - jeta (numpy.ndarray): Interpolation indexes in eta - computed in interpolation() (shape: (2, ncol, nlay, nflav)). - tropo (numpy.ndarray): Use upper- or lower-atmospheric tables? (shape: (ncol, nlay)). - jtemp (numpy.ndarray): Interpolation indexes in temperature - computed in interpolation() (shape: (ncol, nlay)). + ncol: Number of atmospheric columns + nlay: Number of atmospheric layers + nbnd: Number of spectral bands + ngpt: Number of g-points + ngas: Number of gases + nflav: Number of gas flavors + neta: Number of eta points + ntemp: Number of temperature points + gpoint_flavor: G-point flavors with shape (2, ngpt) + band_lims_gpt: Band limits in g-point space with shape (2, nbnd) + krayl: Rayleigh scattering coefficients with shape (ntemp, neta, ngpt, 2) + idx_h2o: Index of water vapor + col_dry: Dry air column amounts with shape (ncol, nlay) + col_gas: Gas concentrations with shape (ncol, nlay, ngas) + fminor: Minor gas interpolation weights + jeta: Eta interpolation indices + tropo: Troposphere mask with shape (ncol, nlay) + jtemp: Temperature interpolation indices Returns: - numpy.ndarray: Rayleigh optical depth (shape: (ncol, nlay, ngpt)). + Rayleigh scattering optical depth with shape (ncol, nlay, ngpt) """ - - # outputs + # Initialize output array tau_rayleigh = np.ndarray((ncol, nlay, ngpt), dtype=np.float64, order="F") args = [ diff --git a/pyrte_rrtmgp/kernels/rte.py b/pyrte_rrtmgp/kernels/rte.py index 2f3dd27..4d8dfdb 100644 --- a/pyrte_rrtmgp/kernels/rte.py +++ b/pyrte_rrtmgp/kernels/rte.py @@ -16,7 +16,7 @@ def lw_solver_noscat( nlay: int, ngpt: int, ds: npt.NDArray[np.float64], - weights: npt.NDArray[np.float64], + weights: npt.NDArray[np.float64], tau: npt.NDArray[np.float64], ssa: npt.NDArray[np.float64], g: npt.NDArray[np.float64], @@ -38,41 +38,46 @@ def lw_solver_noscat( npt.NDArray[np.float64], npt.NDArray[np.float64], ]: - """ - Perform longwave radiation transfer calculations without scattering. + """Perform longwave radiation transfer calculations without scattering. + + This function solves the longwave radiative transfer equation in the absence of scattering, + computing fluxes and optionally their Jacobians. Args: + ncol: Number of columns + nlay: Number of layers + ngpt: Number of g-points + ds: Integration weights with shape (ncol, ngpt, n_quad_angs) + weights: Gaussian quadrature weights with shape (n_quad_angs,) tau: Optical depths with shape (ncol, nlay, ngpt) + ssa: Single scattering albedos with shape (ncol, nlay, ngpt) + g: Asymmetry parameters with shape (ncol, nlay, ngpt) lay_source: Layer source terms with shape (ncol, nlay, ngpt) lev_source: Level source terms with shape (ncol, nlay+1, ngpt) sfc_emis: Surface emissivities with shape (ncol, ngpt) or (ncol,) sfc_src: Surface source terms with shape (ncol, ngpt) + sfc_src_jac: Surface source Jacobians with shape (ncol, nlay+1) + inc_flux: Incident fluxes with shape (ncol, ngpt) top_at_1: Whether the top of the atmosphere is at index 1 nmus: Number of quadrature points (1-4) - inc_flux: Incident fluxes with shape (ncol, ngpt) - ds: Integration weights with shape (ncol, ngpt, n_quad_angs) - weights: Gaussian quadrature weights with shape (n_quad_angs,) do_broadband: Whether to compute broadband fluxes do_Jacobians: Whether to compute Jacobians - sfc_src_jac: Surface source Jacobians with shape (ncol, nlay+1) do_rescaling: Whether to perform flux rescaling - ssa: Single scattering albedos with shape (ncol, nlay, ngpt) - g: Asymmetry parameters with shape (ncol, nlay, ngpt) Returns: Tuple containing: - flux_up_jac: Upward flux Jacobians (ncol, nlay+1) - broadband_up: Upward broadband fluxes (ncol, nlay+1) - broadband_dn: Downward broadband fluxes (ncol, nlay+1) - flux_up: Upward fluxes (ncol, nlay+1, ngpt) - flux_dn: Downward fluxes (ncol, nlay+1, ngpt) + flux_up_jac: Upward flux Jacobians with shape (ncol, nlay+1) + broadband_up: Upward broadband fluxes with shape (ncol, nlay+1) + broadband_dn: Downward broadband fluxes with shape (ncol, nlay+1) + flux_up: Upward fluxes with shape (ncol, nlay+1, ngpt) + flux_dn: Downward fluxes with shape (ncol, nlay+1, ngpt) """ - # outputs - flux_up_jac = np.full([ncol, nlay + 1], np.nan, dtype=np.float64, order="F") - broadband_up = np.full([ncol, nlay + 1], np.nan, dtype=np.float64, order="F") - broadband_dn = np.full([ncol, nlay + 1], np.nan, dtype=np.float64, order="F") - flux_up = np.full([ncol, nlay + 1, ngpt], np.nan, dtype=np.float64, order="F") - flux_dn = np.full([ncol, nlay + 1, ngpt], np.nan, dtype=np.float64, order="F") + # Initialize output arrays + flux_up_jac = np.full((ncol, nlay + 1), np.nan, dtype=np.float64, order="F") + broadband_up = np.full((ncol, nlay + 1), np.nan, dtype=np.float64, order="F") + broadband_dn = np.full((ncol, nlay + 1), np.nan, dtype=np.float64, order="F") + flux_up = np.full((ncol, nlay + 1, ngpt), np.nan, dtype=np.float64, order="F") + flux_dn = np.full((ncol, nlay + 1, ngpt), np.nan, dtype=np.float64, order="F") args = [ ncol, @@ -117,8 +122,10 @@ def lw_solver_2stream( inc_flux: npt.NDArray[np.float64], top_at_1: bool = True, ) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: - """ - Solve the longwave radiative transfer equation using the 2-stream approximation. + """Solve the longwave radiative transfer equation using the 2-stream approximation. + + This function implements the two-stream approximation for longwave radiative transfer, + accounting for both absorption and scattering processes. Args: tau: Optical depths with shape (ncol, nlay, ngpt) @@ -141,7 +148,7 @@ def lw_solver_2stream( if len(sfc_emis.shape) == 1: sfc_emis = np.stack([sfc_emis] * ngpt).T - # outputs + # Initialize output arrays flux_up = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") flux_dn = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") @@ -173,8 +180,10 @@ def sw_solver_noscat( inc_flux_dir: npt.NDArray[np.float64], top_at_1: bool = True, ) -> npt.NDArray[np.float64]: - """ - Perform shortwave radiation transfer calculations without scattering. + """Perform shortwave radiation transfer calculations without scattering. + + This function solves the shortwave radiative transfer equation in the absence of + scattering, computing direct beam fluxes only. Args: tau: Optical depths with shape (ncol, nlay, ngpt) @@ -187,7 +196,7 @@ def sw_solver_noscat( """ ncol, nlay, ngpt = tau.shape - # outputs + # Initialize output array flux_dir = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") args = [ @@ -229,10 +238,15 @@ def sw_solver_2stream( npt.NDArray[np.float64], npt.NDArray[np.float64], ]: - """ - Perform shortwave radiation transfer calculations using the 2-stream approximation. + """Perform shortwave radiation transfer calculations using the 2-stream approximation. + + This function implements the two-stream approximation for shortwave radiative transfer, + computing direct, diffuse upward and downward fluxes, as well as optional broadband fluxes. Args: + ncol: Number of columns + nlay: Number of layers + ngpt: Number of g-points tau: Optical depths with shape (ncol, nlay, ngpt) ssa: Single scattering albedos with shape (ncol, nlay, ngpt) g: Asymmetry parameters with shape (ncol, nlay, ngpt) @@ -240,8 +254,8 @@ def sw_solver_2stream( sfc_alb_dir: Direct surface albedos with shape (ncol, ngpt) or (ncol,) sfc_alb_dif: Diffuse surface albedos with shape (ncol, ngpt) or (ncol,) inc_flux_dir: Direct incident fluxes with shape (ncol, ngpt) - top_at_1: Whether the top of the atmosphere is at index 1 inc_flux_dif: Diffuse incident fluxes with shape (ncol, ngpt) + top_at_1: Whether the top of the atmosphere is at index 1 has_dif_bc: Whether the boundary condition includes diffuse fluxes do_broadband: Whether to compute broadband fluxes @@ -254,8 +268,7 @@ def sw_solver_2stream( broadband_dn: Broadband downward fluxes with shape (ncol, nlay+1) broadband_dir: Broadband direct fluxes with shape (ncol, nlay+1) """ - - # outputs + # Initialize output arrays flux_up = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") flux_dn = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") flux_dir = np.zeros((ncol, nlay + 1, ngpt), dtype=np.float64, order="F") diff --git a/pyrte_rrtmgp/rrtmgp_data.py b/pyrte_rrtmgp/rrtmgp_data.py index 3b3b2a3..6084d3d 100644 --- a/pyrte_rrtmgp/rrtmgp_data.py +++ b/pyrte_rrtmgp/rrtmgp_data.py @@ -2,6 +2,8 @@ import os import platform import tarfile +from pathlib import Path +from typing import Union import requests @@ -10,7 +12,12 @@ DATA_URL = f"https://github.com/earth-system-radiation/rrtmgp-data/archive/refs/tags/{TAG}.tar.gz" -def get_cache_dir(): +def get_cache_dir() -> str: + """Get the system-specific cache directory for pyrte_rrtmgp data. + + Returns: + str: Path to the cache directory + """ # Determine the system cache folder if platform.system() == "Windows": cache_path = os.getenv("LOCALAPPDATA") @@ -27,7 +34,19 @@ def get_cache_dir(): return cache_path -def download_rrtmgp_data(): +def download_rrtmgp_data() -> str: + """Download and extract RRTMGP data files. + + Downloads the RRTMGP data files from GitHub if not already present in the cache, + verifies the checksum, and extracts the contents. + + Returns: + str: Path to the extracted data directory + + Raises: + requests.exceptions.RequestException: If download fails + tarfile.TarError: If extraction fails + """ # Directory where the data will be stored cache_dir = get_cache_dir() @@ -61,8 +80,16 @@ def download_rrtmgp_data(): return os.path.join(cache_dir, f"rrtmgp-data-{TAG[1:]}") -def _get_file_checksum(filepath, mode="r"): - """Helper function to safely read file and get checksum if needed""" +def _get_file_checksum(filepath: Union[str, Path], mode: str = "r") -> str: + """Calculate SHA256 checksum of a file or read existing checksum. + + Args: + filepath: Path to the file + mode: File open mode, "r" for text or "rb" for binary + + Returns: + str: File content if mode="r", or SHA256 hex digest if mode="rb" + """ with open(filepath, mode) as f: content = f.read() return hashlib.sha256(content).hexdigest() if mode == "rb" else content diff --git a/pyrte_rrtmgp/rrtmgp_gas_optics.py b/pyrte_rrtmgp/rrtmgp_gas_optics.py index 70990f8..3ce4d33 100644 --- a/pyrte_rrtmgp/rrtmgp_gas_optics.py +++ b/pyrte_rrtmgp/rrtmgp_gas_optics.py @@ -1,5 +1,7 @@ +import logging import os import sys +from typing import Union import numpy as np import numpy.typing as npt @@ -28,23 +30,36 @@ interpolation, ) from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data -from pyrte_rrtmgp.utils import logger + +logger = logging.getLogger(__name__) def load_gas_optics( file_path: str | None = None, gas_optics_file: GasOpticsFiles | None = None, - selected_gases=None, + selected_gases: list[str] | None = None, ) -> xr.Dataset: """Load gas optics data from a file or predefined gas optics file. + This function loads gas optics data either from a custom netCDF file or from + a predefined gas optics file included in the RRTMGP data package. The data + contains absorption coefficients and other optical properties needed for + radiative transfer calculations. + Args: - file_path: Path to custom gas optics netCDF file - gas_optics_file: Predefined gas optics file enum - selected_gases: List of gases to include + file_path: Path to a custom gas optics netCDF file. If provided, this takes + precedence over gas_optics_file. + gas_optics_file: Enum specifying a predefined gas optics file from the RRTMGP + data package. Only used if file_path is None. + selected_gases: Optional list of gas names to include in calculations. + If None, all gases in the file will be used. Returns: - xarray Dataset containing the gas optics data + xr.Dataset: Dataset containing the gas optics data with selected_gases + stored in the attributes. + + Raises: + ValueError: If neither file_path nor gas_optics_file is provided. """ if file_path is not None: dataset = xr.load_dataset(file_path) @@ -58,36 +73,32 @@ def load_gas_optics( return dataset -@xr.register_dataset_accessor("gas_optics") -class GasOpticsAccessor: - """Factory class that returns appropriate GasOptics implementation""" - - def __new__(cls, xarray_obj, selected_gases=None): - # Check if source is internal by looking at required variables - is_internal = ( - "totplnk" in xarray_obj.data_vars - and "plank_fraction" in xarray_obj.data_vars - ) - - if is_internal: - return LWGasOpticsAccessor(xarray_obj, is_internal, selected_gases) - else: - return SWGasOpticsAccessor(xarray_obj, is_internal, selected_gases) - - class BaseGasOpticsAccessor: + """Base class for gas optics calculations. + + This class provides common functionality for both longwave and shortwave gas optics + calculations, including gas interpolation, optical depth computation, and handling of + atmospheric conditions. + + Args: + xarray_obj (xr.Dataset): Dataset containing gas optics data + is_internal (bool): Whether this is for internal (longwave) radiation + selected_gases (list[str] | None): List of gases to include in calculations + + Raises: + ValueError: If 'h2o' is not included in the gas mapping + """ def __init__( self, - xarray_obj, - is_internal, + xarray_obj: xr.Dataset, + is_internal: bool, selected_gases: list[str] | None = None, - ): + ) -> None: self._dataset = xarray_obj - self.is_internal = is_internal # Get the gas names from the dataset - self._gas_names = self.extract_names(self._dataset["gas_names"].values) + self._gas_names: tuple[str, ...] = self.extract_names(self._dataset["gas_names"].values) if selected_gases is not None: # Filter gas names to only include those that exist in the dataset @@ -108,8 +119,16 @@ def __init__( # Set the gas names as coordinate in the dataset self._dataset.coords["absorber_ext"] = np.array(("dry_air",) + self._gas_names) - def _initialize_pressure_levels(self, atmosphere, inplace=True): - """Initialize pressure levels with minimum pressure adjustment""" + def _initialize_pressure_levels(self, atmosphere: xr.Dataset, inplace: bool = True) -> xr.Dataset | None: + """Initialize pressure levels with minimum pressure adjustment. + + Args: + atmosphere: Dataset containing atmospheric conditions + inplace: Whether to modify atmosphere in-place or return a copy + + Returns: + Modified atmosphere dataset if inplace=False, otherwise None + """ pres_level_var = atmosphere.mapping.get_var("pres_level") min_index = np.argmin(atmosphere[pres_level_var].data) @@ -120,14 +139,25 @@ def _initialize_pressure_levels(self, atmosphere, inplace=True): return atmosphere @property - def _selected_gas_names(self): + def _selected_gas_names(self) -> list[str]: + """List of selected gas names.""" return list(self._gas_names) - @property - def _selected_gas_names_ext(self): + @property + def _selected_gas_names_ext(self) -> list[str]: + """List of selected gas names including dry air.""" return ["dry_air"] + self._selected_gas_names - def get_gases_columns(self, atmosphere, gas_name_map): + def get_gases_columns(self, atmosphere: xr.Dataset, gas_name_map: dict[str, str]) -> xr.DataArray: + """Get gas columns from atmospheric conditions. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_name_map: Mapping between gas names and variable names + + Returns: + DataArray containing gas columns including dry air + """ pres_level_var = atmosphere.mapping.get_var("pres_level") gas_values = [] @@ -160,16 +190,51 @@ def get_gases_columns(self, atmosphere, gas_name_map): return gas_values - def compute_problem(self, atmosphere, gas_interpolation_data): + def compute_problem(self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset) -> xr.Dataset: + """Compute optical properties for radiative transfer problem. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas data + + Raises: + NotImplementedError: Must be implemented by subclasses + """ raise NotImplementedError() - def compute_sources(self, atmosphere, gas_interpolation_data): + def compute_sources(self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset) -> xr.Dataset: + """Compute radiation sources. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas data + + Raises: + NotImplementedError: Must be implemented by subclasses + """ raise NotImplementedError() - def compute_boundary_conditions(self, atmosphere): + def compute_boundary_conditions(self, atmosphere: xr.Dataset) -> xr.Dataset: + """Compute boundary conditions. + + Args: + atmosphere: Dataset containing atmospheric conditions + + Raises: + NotImplementedError: Must be implemented by subclasses + """ raise NotImplementedError() - def interpolate(self, atmosphere, gas_name_map) -> xr.Dataset: + def interpolate(self, atmosphere: xr.Dataset, gas_name_map: dict[str, str]) -> xr.Dataset: + """Interpolate gas optics data to atmospheric conditions. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_name_map: Mapping between gas names and variable names + + Returns: + Dataset containing interpolated gas optics data + """ # Get the gas columns from atmospheric conditions gas_order = self._selected_gas_names_ext gases_columns = self.get_gases_columns(atmosphere, gas_name_map).sel( @@ -256,7 +321,16 @@ def interpolate(self, atmosphere, gas_name_map) -> xr.Dataset: return interpolation_results - def tau_absorption(self, atmosphere, gas_interpolation_data): + def tau_absorption(self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset) -> xr.Dataset: + """Compute absorption optical depth. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas data + + Returns: + Dataset containing absorption optical depth + """ site_dim = atmosphere.mapping.get_dim("site") layer_dim = atmosphere.mapping.get_dim("layer") @@ -405,7 +479,7 @@ def gpoint_flavor(self) -> xr.DataArray: Each g-point is associated with a flavor, which is a pair of key species. Returns: - np.ndarray: G-point flavors. + DataArray containing g-point flavors """ band_sizes = ( self._dataset["bnd_limits_gpt"].values[:, 1] @@ -435,11 +509,11 @@ def gpoint_flavor(self) -> xr.DataArray: return band_to_flavor.sel(bnd=gpoint_bands - 1) @property - def flavors_sets(self) -> npt.NDArray: + def flavors_sets(self) -> xr.DataArray: """Get the unique flavors from the k-distribution file. Returns: - np.ndarray: Unique flavors. + DataArray containing unique flavors """ # Calculate total number of flavors and pairs n_bands = self._dataset["bnd"].size @@ -470,15 +544,14 @@ def flavors_sets(self) -> npt.NDArray: }, ) - def get_idx_minor(self, minor_gases): - """Index of each minor gas in col_gas + def get_idx_minor(self, minor_gases: list[str]) -> npt.NDArray[np.int32]: + """Get index of each minor gas in col_gas. Args: - gas_names (list): Gas names - minor_gases (list): List of minor gases + minor_gases: List of minor gases Returns: - list: Index of each minor gas in col_gas + Array containing indices of minor gases """ idx_minor_gas = [] for gas in minor_gases: @@ -490,29 +563,29 @@ def get_idx_minor(self, minor_gases): return np.array(idx_minor_gas, dtype=np.int32) @staticmethod - def extract_names(names): - """Extract names from arrays, decoding and removing the suffix + def extract_names(names: npt.NDArray) -> tuple[str, ...]: + """Extract names from arrays, decoding and removing the suffix. Args: - names (np.ndarray): Names + names: Array of encoded names Returns: - tuple: tuple of names + Tuple of decoded and cleaned names """ output = tuple(gas.tobytes().decode().strip().split("_")[0] for gas in names) return output @staticmethod - def get_col_dry(vmr_h2o, plev, latitude=None): - """Calculate the dry column of the atmosphere + def get_col_dry(vmr_h2o: xr.DataArray, plev: xr.DataArray, latitude: xr.DataArray | None = None) -> xr.DataArray: + """Calculate the dry column of the atmosphere. Args: - vmr_h2o (np.ndarray): Water vapor volume mixing ratio - plev (np.ndarray): Pressure levels - latitude (np.ndarray): Latitude of the location + vmr_h2o: Water vapor volume mixing ratio + plev: Pressure levels + latitude: Latitude of the location Returns: - np.ndarray: Dry column of the atmosphere + DataArray containing dry column of the atmosphere """ # Convert latitude to g0 DataArray if latitude is not None: @@ -543,7 +616,23 @@ def compute( gas_name_map: dict[str, str] | None = None, variable_mapping: AtmosphericMapping | None = None, add_to_input: bool = True, - ): + ) -> xr.Dataset | None: + """Compute gas optics for given atmospheric conditions. + + Args: + atmosphere: Dataset containing atmospheric conditions + problem_type: Type of radiative transfer problem to solve + gas_name_map: Optional mapping between gas names and variable names + variable_mapping: Optional mapping for atmospheric variables + add_to_input: Whether to add results to input dataset + + Returns: + Dataset containing gas optics results if add_to_input=False, + otherwise None + + Raises: + ValueError: If problem_type is invalid + """ # Create and validate gas mapping gas_mapping = GasMapping.create(self._gas_names, gas_name_map).validate() @@ -586,15 +675,45 @@ def compute( class LWGasOpticsAccessor(BaseGasOpticsAccessor): - """Accessor for internal radiation sources""" + """Accessor for internal (longwave) radiation sources. + + This class handles gas optics calculations specific to longwave radiation, including + computing absorption optical depths, Planck sources, and boundary conditions. + """ + + def compute_problem(self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset) -> xr.Dataset: + """Compute absorption optical depths for longwave radiation. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas properties - def compute_problem(self, atmosphere, gas_interpolation_data): + Returns: + Dataset containing absorption optical depths + """ return self.tau_absorption(atmosphere, gas_interpolation_data) - def compute_sources(self, atmosphere, gas_interpolation_data): + def compute_sources(self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset) -> xr.Dataset: + """Compute Planck source terms for longwave radiation. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas properties + + Returns: + Dataset containing Planck source terms + """ return self.compute_planck(atmosphere, gas_interpolation_data) - def compute_boundary_conditions(self, atmosphere): + def compute_boundary_conditions(self, atmosphere: xr.Dataset) -> xr.DataArray: + """Compute surface emissivity boundary conditions. + + Args: + atmosphere: Dataset containing atmospheric conditions + + Returns: + DataArray containing surface emissivity values + """ if "surface_emissivity" not in atmosphere.data_vars: # Add surface emissivity directly to atmospheric conditions return xr.DataArray( @@ -613,7 +732,16 @@ def compute_boundary_conditions(self, atmosphere): else: return atmosphere["surface_emissivity"] - def compute_planck(self, atmosphere, gas_interpolation_data): + def compute_planck(self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset) -> xr.Dataset: + """Compute Planck source terms for longwave radiation. + + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas properties + + Returns: + Dataset containing Planck source terms including surface, layer and level sources + """ site_dim = atmosphere.mapping.get_dim("site") layer_dim = atmosphere.mapping.get_dim("layer") level_dim = atmosphere.mapping.get_dim("level") @@ -715,9 +843,22 @@ def compute_planck(self, atmosphere, gas_interpolation_data): class SWGasOpticsAccessor(BaseGasOpticsAccessor): - """Accessor for external radiation sources""" + """Accessor for external (shortwave) radiation sources. + + This class handles gas optics calculations specific to shortwave radiation, including + computing absorption and Rayleigh scattering optical depths, solar sources, and boundary conditions. + """ + + def compute_problem(self, atmosphere: xr.Dataset, gas_interpolation_data: xr.Dataset) -> xr.Dataset: + """Compute optical properties for shortwave radiation. - def compute_problem(self, atmosphere, gas_interpolation_data): + Args: + atmosphere: Dataset containing atmospheric conditions + gas_interpolation_data: Dataset containing interpolated gas properties + + Returns: + Dataset containing optical properties (tau, ssa, g) + """ # Calculate absorption optical depth tau_abs = self.tau_absorption(atmosphere, gas_interpolation_data) @@ -732,8 +873,17 @@ def compute_problem(self, atmosphere, gas_interpolation_data): g = xr.zeros_like(tau["tau"]).rename("g") return xr.merge([tau, ssa, g]) - def compute_sources(self, atmosphere, *args, **kwargs): - """Implementation for external source computation""" + def compute_sources(self, atmosphere: xr.Dataset, *args, **kwargs) -> xr.DataArray: + """Compute solar source terms. + + Args: + atmosphere: Dataset containing atmospheric conditions + *args: Variable length argument list + **kwargs: Arbitrary keyword arguments + + Returns: + DataArray containing top-of-atmosphere solar source + """ a_offset = SOLAR_CONSTANTS["A_OFFSET"] b_offset = SOLAR_CONSTANTS["B_OFFSET"] @@ -756,7 +906,15 @@ def compute_sources(self, atmosphere, *args, **kwargs): def_tsi = toa_flux.sum(dim="gpt") return (toa_flux * (total_solar_irradiance / def_tsi)).rename("toa_source") - def compute_boundary_conditions(self, atmosphere): + def compute_boundary_conditions(self, atmosphere: xr.Dataset) -> xr.Dataset: + """Compute surface and solar boundary conditions. + + Args: + atmosphere: Dataset containing atmospheric conditions + + Returns: + Dataset containing solar zenith angles, surface albedos and solar angle mask + """ solar_zenith_angle_var = atmosphere.mapping.get_var("solar_zenith_angle") surface_albedo_var = atmosphere.mapping.get_var("surface_albedo") surface_albedo_dir_var = atmosphere.mapping.get_var("surface_albedo_dir") @@ -803,7 +961,15 @@ def compute_boundary_conditions(self, atmosphere): ] ) - def tau_rayleigh(self, gas_interpolation_data): + def tau_rayleigh(self, gas_interpolation_data: xr.Dataset) -> xr.Dataset: + """Compute Rayleigh scattering optical depth. + + Args: + gas_interpolation_data: Dataset containing interpolated gas properties + + Returns: + Dataset containing Rayleigh scattering optical depth + """ # Combine upper and lower Rayleigh coefficients krayl = xr.concat( [self._dataset["rayl_lower"], self._dataset["rayl_upper"]], @@ -858,3 +1024,35 @@ def tau_rayleigh(self, gas_interpolation_data): ) return tau_rayleigh.rename("tau").to_dataset() + + +@xr.register_dataset_accessor("gas_optics") +class GasOpticsAccessor: + """Factory class that returns appropriate GasOptics implementation based on dataset contents. + + This class determines whether to return a longwave (LW) or shortwave (SW) gas optics + accessor by checking for the presence of internal source variables in the dataset. + + Args: + xarray_obj (xr.Dataset): The xarray Dataset containing gas optics data + selected_gases (list[str] | None): Optional list of gas names to include. + If None, all gases in the dataset will be used. + + Returns: + Union[LWGasOpticsAccessor, SWGasOpticsAccessor]: The appropriate gas optics accessor + based on whether internal source terms are present. + """ + + def __new__( + cls, xarray_obj: xr.Dataset, selected_gases: list[str] | None = None + ) -> Union[LWGasOpticsAccessor, SWGasOpticsAccessor]: + # Check if source is internal by looking for required LW variables + is_internal: bool = ( + "totplnk" in xarray_obj.data_vars + and "plank_fraction" in xarray_obj.data_vars + ) + + if is_internal: + return LWGasOpticsAccessor(xarray_obj, is_internal, selected_gases) + else: + return SWGasOpticsAccessor(xarray_obj, is_internal, selected_gases) diff --git a/pyrte_rrtmgp/rte_solver.py b/pyrte_rrtmgp/rte_solver.py index a1fd803..290c6e2 100644 --- a/pyrte_rrtmgp/rte_solver.py +++ b/pyrte_rrtmgp/rte_solver.py @@ -1,42 +1,36 @@ from typing import Optional -import numpy as np import xarray as xr +from pyrte_rrtmgp.constants import GAUSS_DS, GAUSS_WTS from pyrte_rrtmgp.data_types import ProblemTypes from pyrte_rrtmgp.kernels.rte import lw_solver_noscat, sw_solver_2stream -from pyrte_rrtmgp.utils import logger class RTESolver: - GAUSS_DS = np.reciprocal( - np.array( - [ - [0.6096748751, np.inf, np.inf, np.inf], - [0.2509907356, 0.7908473988, np.inf, np.inf], - [0.1024922169, 0.4417960320, 0.8633751621, np.inf], - [0.0454586727, 0.2322334416, 0.5740198775, 0.9030775973], - ] - ) - ) - - GAUSS_WTS = np.array( - [ - [1.0, 0.0, 0.0, 0.0], - [0.2300253764, 0.7699746236, 0.0, 0.0], - [0.0437820218, 0.3875796738, 0.5686383044, 0.0], - [0.0092068785, 0.1285704278, 0.4323381850, 0.4298845087], - ] - ) + GAUSS_DS = GAUSS_DS + GAUSS_WTS = GAUSS_WTS def _compute_quadrature( self, ncol: int, ngpt: int, nmus: int ) -> tuple[xr.DataArray, xr.DataArray]: - """Compute quadrature weights and secants.""" - n_quad_angs = nmus + """Compute quadrature weights and secants for radiative transfer calculations. + + Args: + ncol: Number of atmospheric columns. + ngpt: Number of g-points (spectral points). + nmus: Number of quadrature angles. + + Returns: + tuple containing: + ds (xr.DataArray): Quadrature secants (directional cosines) with dimensions + [site, gpt, n_quad_angs]. + weights (xr.DataArray): Quadrature weights with dimension [n_quad_angs]. + """ + n_quad_angs: int = nmus # Create DataArray for ds with proper dimensions and coordinates - ds = xr.DataArray( + ds: xr.DataArray = xr.DataArray( self.GAUSS_DS[0:n_quad_angs, n_quad_angs - 1], dims=["n_quad_angs"], coords={"n_quad_angs": range(n_quad_angs)}, @@ -45,7 +39,7 @@ def _compute_quadrature( ds = ds.expand_dims({"site": ncol, "gpt": ngpt}) # Create DataArray for weights - weights = xr.DataArray( + weights: xr.DataArray = xr.DataArray( self.GAUSS_WTS[0:n_quad_angs, n_quad_angs - 1], dims=["n_quad_angs"], coords={"n_quad_angs": range(n_quad_angs)}, @@ -56,11 +50,36 @@ def _compute_quadrature( def _compute_lw_fluxes_absorption( self, problem_ds: xr.Dataset, spectrally_resolved: bool = False ) -> xr.Dataset: - nmus = 1 - top_at_1 = problem_ds["layer"][0] < problem_ds["layer"][-1] + """Compute longwave fluxes for absorption-only radiative transfer. + + Args: + problem_ds: Dataset containing the problem specification with required variables: + - tau: Optical depth + - layer_source: Layer source function + - level_source: Level source function + - surface_emissivity: Surface emissivity + - surface_source: Surface source function + - surface_source_jacobian: Surface source Jacobian + Optional variables: + - incident_flux: Incident flux at top of atmosphere + - ssa: Single scattering albedo + - g: Asymmetry parameter + spectrally_resolved: If True, return spectrally resolved fluxes. + If False, return broadband fluxes. Defaults to False. + + Returns: + Dataset containing the computed fluxes: + - lw_flux_up_jacobian: Upward flux Jacobian + - lw_flux_up_broadband: Broadband upward flux + - lw_flux_down_broadband: Broadband downward flux + - lw_flux_up: Spectrally resolved upward flux + - lw_flux_down: Spectrally resolved downward flux + """ + nmus: int = 1 + top_at_1: bool = problem_ds["layer"][0] < problem_ds["layer"][-1] if "incident_flux" not in problem_ds: - incident_flux = xr.zeros_like(problem_ds["surface_source"]) + incident_flux: xr.DataArray = xr.zeros_like(problem_ds["surface_source"]) else: incident_flux = problem_ds["incident_flux"] @@ -72,8 +91,12 @@ def _compute_lw_fluxes_absorption( ds, weights = self._compute_quadrature( problem_ds.sizes["site"], problem_ds.sizes["gpt"], nmus ) - ssa = problem_ds["ssa"] if "ssa" in problem_ds else problem_ds["tau"].copy() - g = problem_ds["g"] if "g" in problem_ds else problem_ds["tau"].copy() + ssa: xr.DataArray = ( + problem_ds["ssa"] if "ssa" in problem_ds else problem_ds["tau"].copy() + ) + g: xr.DataArray = ( + problem_ds["g"] if "g" in problem_ds else problem_ds["tau"].copy() + ) ( solver_flux_up_jacobian, @@ -138,6 +161,24 @@ def _compute_lw_fluxes_absorption( def _compute_sw_fluxes( self, problem_ds: xr.Dataset, spectrally_resolved: bool = False ) -> xr.Dataset: + """Compute shortwave fluxes using two-stream solver. + + Args: + problem_ds: Dataset containing problem definition including optical properties, + surface properties and boundary conditions. + spectrally_resolved: If True, return spectrally resolved fluxes. + If False, return broadband fluxes. + + Returns: + Dataset containing computed shortwave fluxes: + - sw_flux_up_broadband: Upward broadband flux + - sw_flux_down_broadband: Downward broadband flux + - sw_flux_dir_broadband: Direct broadband flux + - sw_flux_up: Upward spectral flux + - sw_flux_down: Downward spectral flux + - sw_flux_dir: Direct spectral flux + """ + # Expand surface albedo dimensions if needed if "gpt" not in problem_ds["surface_albedo_direct"].dims: problem_ds["surface_albedo_direct"] = problem_ds[ "surface_albedo_direct" @@ -147,13 +188,16 @@ def _compute_sw_fluxes( "surface_albedo_diffuse" ].expand_dims({"gpt": problem_ds.sizes["gpt"]}, axis=1) + # Set diffuse incident flux if "incident_flux_dif" not in problem_ds: incident_flux_dif = xr.zeros_like(problem_ds["toa_source"]) else: incident_flux_dif = problem_ds["incident_flux_dif"] + # Determine vertical orientation top_at_1 = problem_ds["layer"][0] < problem_ds["layer"][-1] + # Call solver ( solver_flux_up_broadband, solver_flux_down_broadband, @@ -200,6 +244,7 @@ def _compute_sw_fluxes( dask="allowed", ) + # Construct output dataset fluxes = xr.Dataset( { "sw_flux_up_broadband": solver_flux_up_broadband, @@ -217,14 +262,28 @@ def solve( self, problem_ds: xr.Dataset, add_to_input: bool = True, - spectrally_resolved: Optional[bool] = False, - ): + spectrally_resolved: bool = False, + ) -> Optional[xr.Dataset]: + """Solve radiative transfer problem based on problem type. + + Args: + problem_ds: Dataset containing problem definition and inputs + add_to_input: If True, add computed fluxes to input dataset. If False, return fluxes separately + spectrally_resolved: If True, return spectrally resolved fluxes. If False, return broadband fluxes + + Returns: + Dataset containing computed fluxes if add_to_input is False, None otherwise + """ if problem_ds.attrs["problem_type"] == ProblemTypes.LW_ABSORPTION.value: fluxes = self._compute_lw_fluxes_absorption(problem_ds, spectrally_resolved) elif problem_ds.attrs["problem_type"] == ProblemTypes.SW_2STREAM.value: fluxes = self._compute_sw_fluxes(problem_ds, spectrally_resolved) + else: + raise ValueError( + f"Unknown problem type: {problem_ds.attrs['problem_type']}" + ) if add_to_input: problem_ds.assign_coords(fluxes.coords) - else: - return fluxes + return None + return fluxes diff --git a/pyrte_rrtmgp/utils.py b/pyrte_rrtmgp/utils.py deleted file mode 100644 index 77c7dcf..0000000 --- a/pyrte_rrtmgp/utils.py +++ /dev/null @@ -1,67 +0,0 @@ -import logging - -import numpy as np -import xarray as xr - - -def get_usecols(solar_zenith_angle): - """Get the usecols values - - Args: - solar_zenith_angle (np.ndarray): Solar zenith angle in degrees - - Returns: - np.ndarray: Usecols values - """ - return solar_zenith_angle < 90.0 - 2.0 * np.spacing(90.0) - - -def compute_mu0(solar_zenith_angle, nlayer=None): - """Calculate the cosine of the solar zenith angle - - Args: - solar_zenith_angle (np.ndarray): Solar zenith angle in degrees - nlayer (int, optional): Number of layers. Defaults to None. - """ - usecol_values = get_usecols(solar_zenith_angle) - mu0 = np.where(usecol_values, np.cos(np.radians(solar_zenith_angle)), 1.0) - if nlayer is not None: - mu0 = np.stack([mu0] * nlayer).T - return mu0 - - -def compute_toa_flux(total_solar_irradiance, solar_source): - """Compute the top of atmosphere flux - - Args: - total_solar_irradiance (np.ndarray): Total solar irradiance - solar_source (np.ndarray): Solar source - - Returns: - np.ndarray: Top of atmosphere flux - """ - ncol = total_solar_irradiance.shape[0] - toa_flux = np.stack([solar_source] * ncol) - def_tsi = toa_flux.sum(axis=1) - return (toa_flux.T * (total_solar_irradiance / def_tsi)).T - - -def convert_xarray_args(func): - """Decorator to convert xarray DataArrays to numpy arrays efficiently""" - - def wrapper(*args, **kwargs): - new_args = [] - for arg in args: - if hasattr(arg, "values"): - # Get direct reference to underlying numpy array without copy - new_args.append(arg.values) - else: - new_args.append(arg) - return func(*new_args, **kwargs) - - return wrapper - - -def logger(): - """Get the logger""" - return logging.getLogger(__name__) diff --git a/tests/test_python_frontend/test_lw_solver.py b/tests/test_python_frontend/test_lw_solver.py index f9fce52..9aa08cd 100644 --- a/tests/test_python_frontend/test_lw_solver.py +++ b/tests/test_python_frontend/test_lw_solver.py @@ -4,8 +4,8 @@ import xarray as xr from pyrte_rrtmgp import rrtmgp_gas_optics -from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data +from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics from pyrte_rrtmgp.rte_solver import RTESolver ERROR_TOLERANCE = 1e-7 @@ -38,14 +38,18 @@ def test_lw_solver_noscat(): # Load gas optics with the new API gas_optics_lw = load_gas_optics(gas_optics_file=GasOpticsFiles.LW_G256) - + # Compute gas optics for the atmosphere gas_optics_lw.gas_optics.compute(atmosphere, problem_type="absorption") - + # Solve RTE with the new API solver = RTESolver() fluxes = solver.solve(atmosphere, add_to_input=False) - + # Compare results with reference data - assert np.isclose(fluxes["lw_flux_up_broadband"], ref_flux_up, atol=ERROR_TOLERANCE).all() - assert np.isclose(fluxes["lw_flux_down_broadband"], ref_flux_down, atol=ERROR_TOLERANCE).all() + assert np.isclose( + fluxes["lw_flux_up_broadband"], ref_flux_up, atol=ERROR_TOLERANCE + ).all() + assert np.isclose( + fluxes["lw_flux_down_broadband"], ref_flux_down, atol=ERROR_TOLERANCE + ).all() diff --git a/tests/test_python_frontend/test_sw_solver.py b/tests/test_python_frontend/test_sw_solver.py index 8d2fbb4..4db358b 100644 --- a/tests/test_python_frontend/test_sw_solver.py +++ b/tests/test_python_frontend/test_sw_solver.py @@ -4,8 +4,8 @@ import xarray as xr from pyrte_rrtmgp import rrtmgp_gas_optics -from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data +from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics from pyrte_rrtmgp.rte_solver import RTESolver ERROR_TOLERANCE = 1e-7 @@ -38,14 +38,14 @@ def test_sw_solver_noscat(): # Load gas optics with new API gas_optics_sw = load_gas_optics(gas_optics_file=GasOpticsFiles.SW_G224) - + # Load and compute gas optics with atmosphere data gas_optics_sw.gas_optics.compute(atmosphere, problem_type="two-stream") - + # Solve using new rte_solve function solver = RTESolver() fluxes = solver.solve(atmosphere, add_to_input=False) - + # Compare results assert np.isclose(fluxes["sw_flux_up"], ref_flux_up, atol=ERROR_TOLERANCE).all() assert np.isclose(fluxes["sw_flux_down"], ref_flux_down, atol=ERROR_TOLERANCE).all()