
Commit b9c0ee4

StevePny and kysolvik authored
Pip packaging (#44)
* Add dependencies to pyproject.toml
* Remove environment.yml
* Replace from jax.config import config (deprecated) with just jax.config
* Streamline resources and xarray manipulation to match new versions
* Call as_numpy on xarray ds when opening to avoid ScipyArrayWrapper attribute error
* Replace jax.numpy in1d (no longer exists) with isin
* Update getargspec (deprecated) to inspect.signature
* Add qgs to pyproject.toml
* Fix mistake, signature() must be called on h
* Add zarr, dask, and others needed for google cloud downloads
* Need to precompute indexer when using ds.where on remote zarr array
* Add qgs and cloud to dev install (will add pyqg later)
* Move from conda to pip for CI testing
* Rename CI config yml from conda to pip
* Add flake8 to dev build
* Use workaround for installing pyqg with older Cython
* Manually install wheel before pyqg install
* Add dask for cloud build
* Remove pyqg from dependencies for now, include tests in full install
* ETKF fix to exclude end-of-time-window exact match observations
* Add extras to example build (ray, pandas, matplotlib)
* Add hyperopt to examples build
* Add cftime to examples (needed for loading netcdf)
* Update README with new install instructions
* Updated README with new install instructions

---------

Co-authored-by: Kylen Solvik <[email protected]>
1 parent 0e3cbc5 commit b9c0ee4

File tree

10 files changed: +94 -318 lines changed


.github/workflows/python-ci-conda.yml renamed to .github/workflows/python-ci-pip.yml

Lines changed: 6 additions & 19 deletions
@@ -13,32 +13,19 @@ jobs:
     steps:

     - uses: actions/checkout@v3
-    - name: Set up Python 3.10.8
+    - name: Set up Python 3.11.9
       uses: actions/setup-python@v3
       with:
-        python-version: '3.10.8'
+        python-version: '3.11.9'

-    - name: Add conda to system path
-      run: |
-        # $CONDA is an environment variable pointing to the root of the miniconda directory
-        echo $CONDA/bin >> $GITHUB_PATH
     - name: Set cache date
       run: echo "DATE=$(date +'%Y%m%d')" >> $GITHUB_ENV

-    - uses: actions/cache@v2
-      with:
-        path: /usr/share/miniconda3/envs/base/
-        key: conda-${{ hashFiles('environment.yml') }}-${{ env.DATE }}-${{ env.CACHE_NUMBER }}
-      id: cache
-
-    - name: Update environment
+    - name: Install package and requirements
       run: |
-        conda install conda-libmamba-solver flake8 pytest
-        # Slightly awkward way of adding an optional package to our environment
-        echo " - pyqg" >> environment.yml
-        conda env update --name base --file environment.yml --solver=libmamba
-        pip install qgs
-      if: steps.cache.outputs.cache-hit != 'true'
+        python3 -m pip install -e ".[dev]"
+        python3 -m pip install Cython==0.29.37 wheel
+        python3 -m pip install pyqg==0.7.2 --no-build-isolation --use-deprecated=legacy-resolver

     - name: Lint with flake8
       run: |

README.md

Lines changed: 21 additions & 26 deletions
@@ -20,50 +20,45 @@ Neural Networks,Volume 153, 530-552, ISSN 0893-6080, https://doi.org/10.1016/j.n

 ## Installation

+We recommend setting up a virtual environment using either [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) or [virtualenv](https://virtualenv.pypa.io/en/latest/user_guide.html).
+
 #### Clone Repo:

 ```bash
 git clone [email protected]:StevePny/DataAssimBench.git
 ```

-#### Set Up Conda Environment
-
-```bash
-cd DataAssimBench
-conda env create -f environment.yml
-conda activate dab
-```
-
 #### Install dabench
 ```bash
-pip install .
+cd ./DataAssimBench
+pip install -e ".[full]"
 ```

-#### Install dependencies (optional)
-The user may have to manually install:
+This will create a full installation including the ability to access cloud data or interface with other packages such as qgs. Alternatively, for a minimal installation, run:
+
 ```bash
-conda install -c conda-forge jax
-conda install -c conda-forge pyqg
+pip install -e .
 ```

+
 ## Quick Start

 For more detailed examples, go to the [DataAssimBench-Examples](https://github.com/StevePny/DataAssimBench-Examples) repo.

 #### Importing data generators

 ```python
-from dabench import data
-help(data) # View data classes, e.g. data.Lorenz96
-help(data.Lorenz96) # Get more info about Lorenz96 class
+import dabench as dab
+help(dab.data) # View data classes, e.g. data.Lorenz96
+help(dab.data.Lorenz96) # Get more info about Lorenz96 class
 ```

 #### Generating data

 All of the data objects are set up with reasonable defaults. Generating data is as easy as:

 ```python
-l96_obj = data.Lorenz96() # Create data generator object
+l96_obj = dab.data.Lorenz96() # Create data generator object
 l96_obj.generate(n_steps=1000) # Generate Lorenz96 simulation data
 l96_obj.values # View the output values
 ```
@@ -76,7 +71,7 @@ All data objects are customizable.

 For data-generators (e.g. numerical models such as Lorenz63, Lorenz96, SQGTurb), this means you can change initial conditions, model parameters, timestep size, number of timesteps, etc.

-For data-downloaders (e.g. ENSOIDX, AWS, GCP), this means changing which variables you download, the lat/lon bounding box, the time period, etc.
+For data-downloaders (e.g. ENSOIDX, GCP), this means changing which variables you download, the lat/lon bounding box, the time period, etc.

 The recommended way of specifying options is to pass a keyword argument (kwargs) dictionary. The exact options vary between the different types of data objects, so be sure to check the specific documentation for your chosen generator/downloader more info.

@@ -86,18 +81,18 @@ The recommended way of specifying options is to pass a keyword argument (kwargs)
 l96_options = {'forcing_term': 7.5,
                'system_dim': 5,
                'delta_t': 0.05}
-l96_obj = data.Lorenz96(**l96_options) # Create data generator object
+l96_obj = dab.data.Lorenz96(**l96_options) # Create data generator object
 l96_obj.generate(n_steps=1000) # Generate Lorenz96 simulation data
 l96_obj.values # View the output values
 ```

-- For example, for the Amazon Web Services (AWS) ERA5 data-downloader, we can select our variables and time period like this:
+- For example, for the Google Cloud (GCP) ERA5 data-downloader, we can select our variables and time period like this:

 ```python
-aws_options = {'variables': ['air_pressure_at_mean_sea_level', 'sea_surface_temperature'],
-               'years': [1984, 1985]}
-aws_obj = data.AWS(**aws_options) # Create data generator object
-aws_obj.load() # Loads data. Can also use aws_obj.generate()
-aws_obj.values # View the output values
+gcp_options = {'variables': ['2m_temperature', 'sea_surface_temperature'],
+               'date_start': '2020-06-01',
+               'date_end': '2020-06-07'}
+gcp_obj = dab.data.GCP(**gcp_options) # Create data generator object
+gcp_obj.load() # Loads data. Can also use gcp_obj.generate()
+gcp_obj.values # View the output values
 ```
-
dabench/dacycler/_etkf.py

Lines changed: 3 additions & 0 deletions
@@ -265,6 +265,9 @@ def cycle(self,
             (obs_vector.times > cur_time - analysis_window/2)
             # AND Less than end of window
             * (obs_vector.times < cur_time + analysis_window/2)
+            # AND not equal to end of window
+            * (1-jnp.isclose(obs_vector.times, cur_time + analysis_window/2,
+                             rtol=0))
             # OR Equal to start of window
             + jnp.isclose(obs_vector.times, cur_time - analysis_window/2,
                           rtol=0)
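
The added term makes the analysis window half-open: an observation falling exactly on the end of one window is left for the following window instead of being counted twice. A minimal sketch of the selection logic (toy times, outside dabench):

```python
import jax.numpy as jnp

# Toy observation times; the window is centered on cur_time.
obs_times = jnp.array([0.0, 0.5, 1.0, 1.5, 2.0])
cur_time, analysis_window = 1.0, 2.0
window_start = cur_time - analysis_window / 2   # 0.0
window_end = cur_time + analysis_window / 2     # 2.0

mask = (
    (obs_times > window_start)
    # AND less than end of window
    * (obs_times < window_end)
    # AND not (numerically) equal to end of window
    * (1 - jnp.isclose(obs_times, window_end, rtol=0))
    # OR equal to start of window
    + jnp.isclose(obs_times, window_start, rtol=0)
)
print(jnp.where(mask)[0])   # [0 1 2 3] -- the observation at t=2.0 is excluded
```

With `rtol=0`, `jnp.isclose` compares against an absolute tolerance only, so only times effectively equal to the window boundary are affected.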

dabench/data/_data.py

Lines changed: 11 additions & 18 deletions
@@ -275,6 +275,7 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None,
     def _import_xarray_ds(self, ds, include_vars=None, exclude_vars=None,
                           years_select=None, dates_select=None,
                           lat_sorting=None):
+        ds = ds.as_numpy()
         if dates_select is not None:
             dates_filter_indices = ds.time.dt.date.isin(dates_select)
             # First check to make sure the dates exist in the object
@@ -313,8 +314,8 @@ def _import_xarray_ds(self, ds, include_vars=None, exclude_vars=None,
                 'stall, or crash.'.format(size_gb))

         # Get dims
-        dims = ds.dims
-        dims_names = list(ds.dims)
+        dims = ds.sizes
+        dims_names = list(ds.sizes)

         # Set times
         time_key = None
@@ -402,7 +403,7 @@ def _import_xarray_ds(self, ds, include_vars=None, exclude_vars=None,
             )

         # Gather values and set dimensions
-        temp_values = np.moveaxis(np.array(ds.to_array()), 0, -1)
+        temp_values = np.moveaxis(ds.to_dataarray().to_numpy(), 0, -1)
         self.original_dim = temp_values.shape[1:]
         if self.original_dim[-1] == 1 and len(self.original_dim) > 2:
             self.original_dim = self.original_dim[:-1]
@@ -449,21 +450,13 @@ def load_netcdf(self, filepath=None, include_vars=None, exclude_vars=None,
         """
         if filepath is None:
             # Use importlib.resources to get the default netCDF from dabench
-            with resources.open_binary(
-                    _suppl_data, 'era5_japan_slp.nc') as nc_file:
-                with xr.open_dataset(nc_file, decode_coords='all') as ds:
-                    self._import_xarray_ds(
-                        ds, include_vars=include_vars,
-                        exclude_vars=exclude_vars,
-                        years_select=years_select, dates_select=dates_select,
-                        lat_sorting=lat_sorting)
-        else:
-            with xr.open_dataset(filepath, decode_coords='all') as ds:
-                self._import_xarray_ds(
-                    ds, include_vars=include_vars,
-                    exclude_vars=exclude_vars,
-                    years_select=years_select, dates_select=dates_select,
-                    lat_sorting=lat_sorting)
+            filepath = resources.files(_suppl_data).joinpath('era5_japan_slp.nc')
+        with xr.open_dataset(filepath, decode_coords='all') as ds:
+            self._import_xarray_ds(
+                ds, include_vars=include_vars,
+                exclude_vars=exclude_vars,
+                years_select=years_select, dates_select=dates_select,
+                lat_sorting=lat_sorting)

     def save_netcdf(self, filename):
         """Saves values in values attribute to netCDF file

dabench/data/enso_indices.py

Lines changed: 1 addition & 1 deletion
@@ -275,7 +275,7 @@ def _combine_vals_years(self, all_vals, all_years):
             all_vals[v] = all_vals[v][np.sort(indices)]
             all_years[v] = all_years[v][np.sort(indices)]
             # Append common_vals
-            common_vals.append(all_vals[v][jnp.in1d(all_years[v],
+            common_vals.append(all_vals[v][jnp.isin(all_years[v],
                                                     common_years)])
         for f in common_vals:
             print(f.shape)
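
As a quick illustration (toy arrays, not the ENSOIDX data), `jnp.isin` is the drop-in replacement for the removed `jnp.in1d`:

```python
import jax.numpy as jnp

all_years = jnp.array([1980, 1981, 1982, 1983])
common_years = jnp.array([1981, 1983])

mask = jnp.isin(all_years, common_years)   # element-wise membership test
print(mask)               # [False  True False  True]
print(all_years[mask])    # [1981 1983]
```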

dabench/data/gcp.py

Lines changed: 4 additions & 4 deletions
@@ -129,8 +129,8 @@ def _load_gcp_era5(self):
         if self.min_lat is not None and self.max_lat is not None:
             # Subset by lat boundaries
             ds = ds.where(
-                (ds.latitude < self.max_lat) &
-                (ds.latitude > self.min_lat),
+                ((ds.latitude < self.max_lat) &
+                 (ds.latitude > self.min_lat)).compute(),
                 drop=True)
         if self.min_lon is not None and self.max_lon is not None:
             # Convert west longs to degrees east
@@ -142,8 +142,8 @@ def _load_gcp_era5(self):
                 subset_max_lon += 360
             # Subset by lon boundaries
             ds = ds.where(
-                (ds.longitude < subset_max_lon) &
-                (ds.longitude > subset_min_lon),
+                ((ds.longitude < subset_max_lon) &
+                 (ds.longitude > subset_min_lon)).compute(),
                 drop=True)

         self._import_xarray_ds(ds)
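
The change wraps each boolean condition and calls `.compute()` on it, so the indexer passed to `where(..., drop=True)` is a concrete array rather than a lazy expression backed by the remote zarr store. A self-contained sketch of the pattern (small in-memory dataset standing in for the chunked ERA5 data):

```python
import numpy as np
import xarray as xr

# Toy grid; in dabench this is an ERA5 dataset opened lazily from a remote zarr store.
ds = xr.Dataset(
    {"t2m": (("latitude", "longitude"), np.arange(12.0).reshape(3, 4))},
    coords={"latitude": [-10.0, 0.0, 10.0], "longitude": [0.0, 90.0, 180.0, 270.0]},
)

min_lat, max_lat = -5.0, 15.0

# Precompute the mask: for dask-backed data, .compute() materializes the boolean
# indexer before where() uses it (it is a no-op for in-memory data like this toy case).
lat_mask = ((ds.latitude < max_lat) & (ds.latitude > min_lat)).compute()
ds = ds.where(lat_mask, drop=True)
print(ds.latitude.values)   # [ 0. 10.]
```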

dabench/data/sqgturb.py

Lines changed: 1 addition & 2 deletions
@@ -36,15 +36,14 @@
 import jax
 import jax.numpy as jnp
 from jax.numpy.fft import rfft2, irfft2
-from jax.config import config
 from functools import partial
 from importlib import resources

 from dabench.data import _data
 from dabench import _suppl_data

 # Set to enable 64bit floats in Jax
-config.update('jax_enable_x64', True)
+jax.config.update('jax_enable_x64', True)


 class SQGTurb(_data.Data):
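
A quick standalone check (not part of the module) that the replacement call enables 64-bit floats; `jax.config` is accessed directly now that the `from jax.config import config` import path has been removed:

```python
import jax
import jax.numpy as jnp

# Enable double precision before creating any arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.zeros(3)
print(x.dtype)   # float64 (would be float32 without the x64 flag)
```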

dabench/obsop/_obsop.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def __init__(self,
             self.h = self._index_state_vec
         else:
             # Check custom h
-            custom_args = inspect.getargspec(h).args
+            custom_args = inspect.signature(h).parameters
             if 'state_vec' not in custom_args or 'obs_vec' not in custom_args:
                 raise ValueError('User-specified h does not accept the '
                                  'required args for h: "state_vec" and '
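
For illustration (a toy observation operator, not the dabench one), `inspect.signature(h).parameters` supports the same membership check that the deprecated `inspect.getargspec(h).args` used to provide:

```python
import inspect

def h(state_vec, obs_vec, scale=1.0):
    """Toy observation operator: observe the first len(obs_vec) state elements."""
    return state_vec[:len(obs_vec)] * scale

custom_args = inspect.signature(h).parameters   # ordered mapping of parameter names
print('state_vec' in custom_args, 'obs_vec' in custom_args)   # True True
```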
