Merge pull request #295 from roocs/adapt-xarray
Address typing hints for xarray 2023.08 compatibility, Update ReadTheDocs
Zeitsperre authored Aug 21, 2023
2 parents 2c1fb1d + a0cf4b6 commit 4eb3e1e
Showing 8 changed files with 100 additions and 62 deletions.
13 changes: 8 additions & 5 deletions .readthedocs.yml
@@ -14,15 +14,18 @@ sphinx:
formats:
- pdf

+ build:
+   os: ubuntu-22.04
+   tools:
+     python: "mambaforge-22.9"

conda:
  environment: environment.yml

# Optionally set the version of Python and requirements required to build your docs
python:
  install:
    - method: pip
      path: .
      extra_requirements:
        - docs

- build:
-   os: ubuntu-20.04
-   tools:
-     python: "3.8"
8 changes: 8 additions & 0 deletions HISTORY.rst
@@ -1,6 +1,14 @@
Version History
===============

+ v0.10.1 (2023-08-21)
+ --------------------
+
+ Bug Fixes
+ ^^^^^^^^^
+ * Fixed type hints for subset functions that were broken by changes in `xarray` (2023.08). (#295).
+ * Updated the ReadTheDocs configuration to use `Mambaforge` (22.9) as the build engine for the documentation. (#295).

v0.10.0 (2023-06-28)
--------------------

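The substance of the typing fix shows up in the `clisops/core/subset.py` hunks further down: annotations that referenced `xarray.Coordinate` (no longer available in recent `xarray`) move to `xarray.DataArray`, and the private indexing helpers are imported as a module. A minimal before/after sketch of the pattern; the function body here is illustrative, not the project's exact implementation:

```python
from typing import Optional, Tuple

import xarray


# Before: `coord: xarray.Coordinate` -- broken under xarray 2023.08.
# After: annotate with the public DataArray type instead.
def assign_bounds(
    bounds: Tuple[Optional[float], Optional[float]], coord: xarray.DataArray
) -> Tuple[Optional[float], Optional[float]]:
    """Replace unset boundaries by the minimum and maximum coordinates."""
    bn, bx = bounds
    if bn is None:
        bn = float(coord.min())  # illustrative body
    if bx is None:
        bx = float(coord.max())
    return bn, bx
```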
11 changes: 4 additions & 7 deletions clisops/core/average.py
@@ -31,8 +31,7 @@ def average_shape(
) -> Union[xr.DataArray, xr.Dataset]:
"""Average a DataArray or Dataset spatially using vector shapes.
- Return a DataArray or Dataset averaged over each Polygon given.
- Requires xESMF >= 0.5.0.
+ Return a DataArray or Dataset averaged over each Polygon given. Requires xESMF >= 0.5.0.
Parameters
----------
@@ -146,8 +145,7 @@ def average_over_dims(
dims: Sequence[str] = None,
ignore_undetected_dims: bool = False,
) -> Union[xr.DataArray, xr.Dataset]:
"""
Average a DataArray or Dataset over the dimensions specified.
"""Average a DataArray or Dataset over the dimensions specified.
Parameters
----------
@@ -240,14 +238,13 @@ def average_time(
ds: Union[xr.DataArray, xr.Dataset],
freq: str,
) -> Union[xr.DataArray, xr.Dataset]:
"""
Average a DataArray or Dataset over the time frequency specified.
"""Average a DataArray or Dataset over the time frequency specified.
Parameters
----------
ds : Union[xr.DataArray, xr.Dataset]
Input values.
- freq: str
+ freq : str
The frequency to average over. One of "month", "year".
Returns
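For orientation, a short usage sketch of the two helpers whose docstrings are touched above. The call signatures come from this file; the data is synthetic:

```python
import numpy as np
import pandas as pd
import xarray as xr

from clisops.core.average import average_over_dims, average_time

times = pd.date_range("2000-01-01", periods=730, freq="D")
ds = xr.Dataset(
    {"tas": (("time", "lat", "lon"), np.random.rand(730, 2, 3))},
    coords={"time": times, "lat": [10.0, 20.0], "lon": [0.0, 10.0, 20.0]},
)

# Collapse the time dimension entirely; other dimensions are preserved.
ds_time_mean = average_over_dims(ds, dims=["time"])

# Average within each calendar period; freq is "month" or "year".
ds_monthly = average_time(ds, freq="month")
```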
83 changes: 53 additions & 30 deletions clisops/core/subset.py
@@ -4,7 +4,7 @@
import warnings
from functools import wraps
from pathlib import Path
- from typing import Dict, Optional, Sequence, Tuple, Union
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import cf_xarray # noqa
import geopandas as gpd
@@ -21,6 +21,7 @@
from shapely import vectorized
from shapely.geometry import LineString, MultiPolygon, Point, Polygon
from shapely.ops import split, unary_union
+ from xarray.core import indexing
from xarray.core.utils import get_temp_dimname

from clisops.utils.dataset_utils import adjust_date_to_calendar
@@ -41,21 +41,21 @@
]


- def get_lat(ds):
+ def get_lat(ds: Union[xarray.Dataset, xarray.DataArray]) -> xarray.DataArray:
try:
return ds.cf["latitude"]
except KeyError:
return ds.lat


- def get_lon(ds):
+ def get_lon(ds: Union[xarray.Dataset, xarray.DataArray]) -> xarray.DataArray:
try:
return ds.cf["longitude"]
except KeyError:
return ds.lon

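The two accessors now declare that they take a `Dataset` or `DataArray` and return a `DataArray`; behavior is unchanged: try the `cf_xarray` lookup first, fall back to the literal attribute. A tiny illustration with a synthetic coordinate:

```python
import numpy as np
import xarray as xr

from clisops.core.subset import get_lat

ds = xr.Dataset(coords={"lat": ("lat", np.linspace(-90.0, 90.0, 5))})

# Without CF metadata, ds.cf["latitude"] raises KeyError, so the helper
# falls back to the literal ds.lat coordinate.
print(get_lat(ds).name)  # "lat"

# With a CF standard_name, the cf_xarray accessor resolves the same coordinate.
ds["lat"].attrs["standard_name"] = "latitude"
print(get_lat(ds).name)  # "lat"
```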

- def check_start_end_dates(func):
+ def check_start_end_dates(func: Callable) -> Callable:
@wraps(func)
def func_checker(*args, **kwargs):
"""Verify that start and end dates are valid in a time subsetting function."""
@@ -138,7 +139,7 @@ def func_checker(*args, **kwargs):
return func_checker


- def check_start_end_levels(func):
+ def check_start_end_levels(func: Callable) -> Callable:
@wraps(func)
def func_checker(*args, **kwargs):
"""Verify that first and last levels are valid in a level subsetting function."""
@@ -223,7 +224,7 @@ def func_checker(*args, **kwargs):
return func_checker


- def check_lons(func):
+ def check_lons(func: Callable) -> Callable:
@wraps(func)
def func_checker(*args, **kwargs):
"""Reformat user-specified "lon" or "lon_bnds" values based on the lon dimensions of a supplied Dataset or DataArray.
@@ -267,7 +268,7 @@ def func_checker(*args, **kwargs):
return func_checker


- def check_levels_exist(func):
+ def check_levels_exist(func: Callable) -> Callable:
@wraps(func)
def func_checker(*args, **kwargs):
"""Check the requested levels exist in the input Dataset/DataArray and, if not, raise an Exception.
@@ -304,7 +305,7 @@ def func_checker(*args, **kwargs):
return func_checker


- def check_datetimes_exist(func):
+ def check_datetimes_exist(func: Callable) -> Callable:
@wraps(func)
def func_checker(*args, **kwargs):
"""Check the requested datetimes exist in the input Dataset/DataArray and, if not, raise an Exception.
@@ -345,7 +346,7 @@ def func_checker(*args, **kwargs):
return func_checker


- def convert_lat_lon_to_da(func):
+ def convert_lat_lon_to_da(func: Callable) -> Callable:
@wraps(func)
def func_checker(*args, **kwargs):
"""Transform input lat, lon to DataArrays.
@@ -380,7 +381,7 @@ def func_checker(*args, **kwargs):
return func_checker


- def wrap_lons_and_split_at_greenwich(func):
+ def wrap_lons_and_split_at_greenwich(func: Callable) -> Callable:
@wraps(func)
def func_checker(*args, **kwargs):
"""Split and reproject polygon vectors in a GeoDataFrame whose values cross the Greenwich Meridian.
@@ -480,7 +481,7 @@ def create_mask(
poly: gpd.GeoDataFrame,
wrap_lons: bool = False,
check_overlap: bool = False,
- ):
+ ) -> xarray.DataArray:
"""Create a mask with values corresponding to the features in a GeoDataFrame using vectorize methods.
The returned mask's points have the value of the first geometry of `poly` they fall in.
@@ -563,7 +564,7 @@ def create_mask(
return mask

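The leading arguments of `create_mask` are elided by the diff; assuming the keyword-only `x_dim`/`y_dim` grid coordinates of the released `clisops` API, usage looks roughly like this:

```python
import geopandas as gpd
import numpy as np
import xarray as xr
from shapely.geometry import Polygon

from clisops.core.subset import create_mask

ds = xr.Dataset(
    coords={
        "lat": ("lat", np.arange(40.0, 50.0)),
        "lon": ("lon", np.arange(-80.0, -70.0)),
    }
)
poly = gpd.GeoDataFrame(
    {"geometry": [Polygon([(-78, 42), (-78, 46), (-74, 46), (-74, 42)])]},
    crs=4326,
)

# Grid cells inside the first (and only) geometry take index value 0;
# cells outside all geometries are NaN.
mask = create_mask(x_dim=ds.lon, y_dim=ds.lat, poly=poly)
```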

- def _rectilinear_grid_exterior_polygon(ds):
+ def _rectilinear_grid_exterior_polygon(ds: xarray.Dataset) -> Polygon:
"""Return a polygon tracing a rectilinear grid's exterior.
Parameters
@@ -609,7 +610,7 @@ def _rectilinear_grid_exterior_polygon(ds):
return Polygon(pts)


- def _curvilinear_grid_exterior_polygon(ds, mode="bbox"):
+ def _curvilinear_grid_exterior_polygon(
+     ds: xarray.Dataset, mode: str = "bbox"
+ ) -> Polygon:
"""Return a polygon tracing a curvilinear grid's exterior.
Parameters
@@ -693,7 +696,7 @@ def round_down(x, decimal=1):
return Polygon(pts)


- def grid_exterior_polygon(ds):
+ def grid_exterior_polygon(ds: xarray.Dataset) -> Polygon:
"""Return a polygon tracing the grid's exterior.
This function is only accurate for a geographic lat/lon projection. For projected grids, it's a rough approximation.
@@ -721,13 +724,16 @@ def grid_exterior_polygon(ds):
return _curvilinear_grid_exterior_polygon(ds, mode="bbox")


- def is_rectilinear(ds):
+ def is_rectilinear(ds: Union[xarray.Dataset, xarray.DataArray]) -> bool:
"""Return whether the grid is rectilinear or not."""
sdims = {ds.cf["longitude"].name, ds.cf["latitude"].name}
return sdims.issubset(ds.dims)

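`is_rectilinear` simply asks whether the CF latitude/longitude coordinates are themselves dimensions (a 1-D rectilinear grid) rather than 2-D arrays over other dimensions (a curvilinear grid). For example, with synthetic grids carrying CF `standard_name` attributes so the `.cf` lookup resolves:

```python
import numpy as np
import xarray as xr

from clisops.core.subset import is_rectilinear

# Rectilinear: "lat" and "lon" are dimension coordinates.
rect = xr.Dataset(
    coords={"lat": ("lat", np.arange(3.0)), "lon": ("lon", np.arange(4.0))}
)
rect["lat"].attrs["standard_name"] = "latitude"
rect["lon"].attrs["standard_name"] = "longitude"
print(is_rectilinear(rect))  # True

# Curvilinear: 2-D lat/lon defined over (y, x) dimensions.
yy, xx = np.meshgrid(np.arange(3.0), np.arange(4.0), indexing="ij")
curv = xr.Dataset(coords={"lat": (("y", "x"), yy), "lon": (("y", "x"), xx)})
curv["lat"].attrs["standard_name"] = "latitude"
curv["lon"].attrs["standard_name"] = "longitude"
print(is_rectilinear(curv))  # False
```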

- def shape_bbox_indexer(ds, poly):
+ def shape_bbox_indexer(
+     ds: xarray.Dataset,
+     poly: Union[gpd.GeoDataFrame, gpd.GeoSeries, gpd.array.GeometryArray],
+ ):
"""Return a spatial indexer that selects the indices of the grid cells covering the given geometries.
Parameters
@@ -811,7 +817,7 @@ def shape_bbox_indexer(ds, poly):
ds, ind, method="nearest"
)
else:
-         native_ind = xarray.core.indexing.map_index_queries(
+         native_ind = indexing.map_index_queries(
ds, ind, method="nearest"
).dim_indexers
else:
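The functional change above is where the "nearest" label query is resolved: newer `xarray` exposes it as `indexing.map_index_queries(...).dim_indexers`, while the other branch serves older releases. A sketch of this kind of version shim; the diff's exact branch condition is elided here, and the older helper named below is an assumption based on pre-2022.06 `xarray`:

```python
from packaging.version import Version

import xarray


def nearest_indexers(ds, ind):
    """Resolve label-based indexers to positional ones across xarray versions."""
    if Version(xarray.__version__) >= Version("2022.6.0"):
        from xarray.core import indexing

        return indexing.map_index_queries(ds, ind, method="nearest").dim_indexers
    # Assumed older spelling; removed in later xarray releases.
    from xarray.core.indexing import remap_label_indexers

    pos_indexers, _ = remap_label_indexers(ds, ind, method="nearest")
    return pos_indexers
```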
@@ -842,7 +848,7 @@ def shape_bbox_indexer(ds, poly):
def create_weight_masks(
ds_in: Union[xarray.DataArray, xarray.Dataset],
poly: gpd.GeoDataFrame,
- ):
+ ) -> xarray.DataArray:
"""Create weight masks corresponding to the features in a GeoDataFrame using xESMF.
The returned masks values are the fraction of the corresponding polygon's area
@@ -1326,20 +1332,20 @@ def subset_bbox(


def assign_bounds(
-     bounds: Tuple[Optional[float], Optional[float]], coord: xarray.Coordinate
- ) -> tuple:
+     bounds: Tuple[Optional[float], Optional[float]], coord: xarray.DataArray
+ ) -> Tuple[Optional[float], Optional[float]]:
"""Replace unset boundaries by the minimum and maximum coordinates.
Parameters
----------
bounds : Tuple[Optional[float], Optional[float]]
Boundaries.
- coord : xarray.Coordinate
+ coord : xarray.DataArray
Grid coordinates.
Returns
-------
- tuple
+ Tuple[Optional[float], Optional[float]]
Lower and upper grid boundaries.
"""
@@ -1351,20 +1357,37 @@ def assign_bounds(
return bn, bx


- def in_bounds(bounds: Tuple[float, float], coord: xarray.Coordinate) -> bool:
-     """Check which coordinates are within the boundaries."""
+ def in_bounds(bounds: Tuple[float, float], coord: xarray.DataArray) -> xarray.DataArray:
+     """Check which coordinates are within the boundaries.
+
+     Parameters
+     ----------
+     bounds : Tuple[float, float]
+         Boundaries.
+     coord : xarray.DataArray
+         Grid coordinates.
+
+     Returns
+     -------
+     xarray.DataArray
+     """
bn, bx = bounds
return (coord >= bn) & (coord <= bx)


- def _check_desc_coords(coord, bounds, dim):
+ def _check_desc_coords(
+     coord: xarray.Dataset,
+     bounds: Union[Tuple[float, float], List[np.ndarray]],
+     dim: str,
+ ) -> Tuple[float, float]:
"""If Dataset coordinates are descending, and bounds are ascending, reverse bounds."""
if np.all(coord.diff(dim=dim) < 0) and len(coord) > 1 and bounds[1] > bounds[0]:
bounds = np.flip(bounds)
return bounds

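Taken together, `assign_bounds`, `in_bounds`, and `_check_desc_coords` normalize user-supplied bounds before any label-based slicing: fill unset ends from the coordinate, test membership, and reverse ascending bounds against descending coordinates. A composed sketch on a synthetic axis:

```python
import numpy as np
import xarray as xr

from clisops.core.subset import assign_bounds, in_bounds

# A descending latitude axis, common in model output.
lat = xr.DataArray(np.arange(60.0, 39.0, -1.0), dims="lat", name="lat")

# The unset upper bound falls back to the coordinate maximum (60.0).
bn, bx = assign_bounds((45.0, None), lat)

# Boolean DataArray marking the coordinates inside [bn, bx].
mask = in_bounds((bn, bx), lat)
print(int(mask.sum()))  # 16 latitudes between 45 and 60 inclusive

# On this descending axis, _check_desc_coords would flip (45.0, 60.0)
# to (60.0, 45.0) so a .sel slice runs in the coordinate's own order.
```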

- def _check_has_overlaps(polygons: gpd.GeoDataFrame):
+ def _check_has_overlaps(polygons: gpd.GeoDataFrame) -> None:
non_overlapping = []
for n, p in enumerate(polygons["geometry"][:-1], 1):
if not any(p.overlaps(g) for g in polygons["geometry"][n:]):
@@ -1377,7 +1400,7 @@ def _check_has_overlaps(polygons: gpd.GeoDataFrame):
)


- def _check_has_overlaps_old(polygons: gpd.GeoDataFrame):
+ def _check_has_overlaps_old(polygons: gpd.GeoDataFrame) -> None:
for i, (inda, pola) in enumerate(polygons.iterrows()):
for indb, polb in polygons.iloc[i + 1 :].iterrows():
if pola.geometry.intersects(polb.geometry):
@@ -1388,7 +1411,7 @@ def _check_has_overlaps_old(polygons: gpd.GeoDataFrame):
)


- def _check_crs_compatibility(shape_crs: CRS, raster_crs: CRS):
+ def _check_crs_compatibility(shape_crs: CRS, raster_crs: CRS) -> None:
"""If CRS definitions are not WGS84 or incompatible, raise operation warnings."""
wgs84 = CRS(4326)
if not shape_crs.equals(raster_crs):
@@ -1422,7 +1445,7 @@ def subset_gridpoint(
tolerance: Optional[float] = None,
add_distance: bool = False,
) -> Union[xarray.DataArray, xarray.Dataset]:
"""Extract one or more nearest gridpoint(s) from datarray based on lat lon coordinate(s).
"""Extract one or more of the nearest gridpoint(s) from datarray based on lat lon coordinate(s).
Return a subsetted data array (or Dataset) for the grid point(s) falling nearest the input longitude and latitude
coordinates. Optionally subset the data array for years falling within provided date bounds.
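A usage sketch for the gridpoint extraction described above; the parameters come from the signature in this hunk, the data is synthetic:

```python
import numpy as np
import pandas as pd
import xarray as xr

from clisops.core.subset import subset_gridpoint

times = pd.date_range("2000-01-01", periods=365, freq="D")
da = xr.DataArray(
    np.random.rand(365, 5, 5),
    dims=("time", "lat", "lon"),
    coords={
        "time": times,
        "lat": np.linspace(40.0, 48.0, 5),
        "lon": np.linspace(-80.0, -72.0, 5),
    },
    name="tas",
)

# Nearest grid point to one lon/lat pair, windowed in time.
pt = subset_gridpoint(da, lon=-75.0, lat=45.0, start_date="2000-06-01", end_date="2000-08-31")

# tolerance (meters) raises if the nearest cell is too far away;
# add_distance attaches the computed distance to the result.
pt2 = subset_gridpoint(da, lon=-75.0, lat=45.0, tolerance=500_000.0, add_distance=True)
```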
@@ -1804,7 +1827,7 @@ def distance(
*,
lon: Union[float, Sequence[float], xarray.DataArray],
lat: Union[float, Sequence[float], xarray.DataArray],
- ):
+ ) -> Union[xarray.DataArray, xarray.Dataset]:
"""Return distance to a point in meters.
Parameters
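`distance`, in the last hunk of this file, now declares its return type. With keyword-only `lon`/`lat` it yields the distance in meters from every grid cell to a query point, the building block behind `subset_gridpoint`'s nearest-cell search. A sketch:

```python
import numpy as np
import xarray as xr

from clisops.core.subset import distance

da = xr.DataArray(
    np.zeros((5, 5)),
    dims=("lat", "lon"),
    coords={"lat": np.linspace(40.0, 48.0, 5), "lon": np.linspace(-80.0, -72.0, 5)},
)

# Distance (in meters) from each grid cell to the query point.
d = distance(da, lon=-75.0, lat=45.0)

# The nearest cell is then just the argmin of the distance field.
iy, ix = np.unravel_index(int(d.argmin()), d.shape)
```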
22 changes: 11 additions & 11 deletions clisops/ops/average.py
@@ -46,9 +46,9 @@ def average_over_dims(
dims: Optional[Union[Sequence[str], DimensionParameter]] = None,
ignore_undetected_dims: bool = False,
output_dir: Optional[Union[str, Path]] = None,
output_type="netcdf",
split_method="time:auto",
file_namer="standard",
output_type: str = "netcdf",
split_method: str = "time:auto",
file_namer: str = "standard",
) -> List[Union[xr.Dataset, str]]:
"""
@@ -62,10 +62,10 @@
ignore_undetected_dims : bool
If the dimensions specified are not found in the dataset, an Exception will be raised if set to True.
If False, an exception will not be raised and the other dimensions will be averaged over. Default = False
-     output_dir: Optional[Union[str, Path]]
-     output_type: {"netcdf", "nc", "zarr", "xarray"}
-     split_method: {"time:auto"}
-     file_namer: {"standard", "simple"}
+     output_dir : Optional[Union[str, Path]]
+     output_type : {"netcdf", "nc", "zarr", "xarray"}
+     split_method : {"time:auto"}
+     file_namer : {"standard", "simple"}
Returns
-------
@@ -120,12 +120,12 @@ def _calculate(self):


def average_time(
-     ds,
+     ds: Union[xr.Dataset, str],
freq: str,
output_dir: Optional[Union[str, Path]] = None,
output_type="netcdf",
split_method="time:auto",
file_namer="standard",
output_type: str = "netcdf",
split_method: str = "time:auto",
file_namer: str = "standard",
) -> List[Union[xr.Dataset, str]]:
"""
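At the `clisops.ops` level the same operations gain explicit annotations on their output-handling keywords; the braces in the docstring list the accepted values. A sketch using the in-memory output path:

```python
import numpy as np
import pandas as pd
import xarray as xr

from clisops.ops.average import average_time

times = pd.date_range("2000-01-01", periods=365, freq="D")
ds = xr.Dataset({"tas": ("time", np.random.rand(365))}, coords={"time": times})

# output_type="xarray" returns datasets directly instead of writing
# netCDF or zarr files under output_dir.
result = average_time(ds, freq="month", output_type="xarray")
monthly = result[0]  # the List[Union[xr.Dataset, str]] return holds datasets here
```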
