Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for code format at commit and pull request #4

Merged
merged 7 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: 'Lint codebase'

on:
pull_request:
branches: [ "main" ]
push:
branches: [ "main" ]

permissions:
contents: read

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.x'
- uses: pre-commit/[email protected]
with:
extra_args: --verbose
24 changes: 24 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Config to automatically run checks and fixes of code format on git commit.
#
# To use, run `pre-commit install` in the root of the repository.
# See https://pre-commit.com for more information.

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-yaml
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.9.2
hooks:
# Run the linter.
- id: ruff
args: [ --fix ] # enable lint fixes
# Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
- id: codespell
args: ["--skip=*.ipynb", "--ignore-words-list=hist"]
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

## Contribute to geoarches

You can make changes on your own `dev` branch(s). This way you are not blocked by development on the `main` branch, but can still contribute to the `main` branch if you want to and can still incoroporate updates from other team members.
You can make changes on your own `dev` branch(s). This way you are not blocked by development on the `main` branch, but can still contribute to the `main` branch if you want to and can still incorporate updates from other team members.

1. Create a `dev` branch from the `main` branch of geoarches to start making changes.
```sh
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# geoarches

geoarches is a machine learning package for training, running and evaluating ML models on weather and climate data, developped by Guillaume Couairon and Renu Singh in the ARCHES team at INRIA (Paris, France).
geoarches is a machine learning package for training, running and evaluating ML models on weather and climate data, developed by Guillaume Couairon and Renu Singh in the ARCHES team at INRIA (Paris, France).

geoarches's building blocks can be easily integrated into research ML pipelines.
It can also be used to run the ArchesWeather and ArchesWeatherGen weather models.
Expand All @@ -20,15 +20,15 @@ Data:

Model training:
- `backbones/`: network architecture that can be plugged into lightning modules.
- `lightning_modules/`: wrapper around backbone modules to handle loss computation, optimizer, etc for training and inferrence (agnostic to backbone but specific to ML task).
- `lightning_modules/`: wrapper around backbone modules to handle loss computation, optimizer, etc for training and inference (agnostic to backbone but specific to ML task).

Evaluation:
- `metrics/`: tested suite of iterative metrics (memory efficient) for deterministic and generative models.
- `evaluation/`: scripts for running metrics over model predictions and plotting.

Pipeline:
- `main_hydra.py`: script to run training or inferrence with hydra configuration.
- `documentation/`: quickstart code for training and inferrence from a notebook.
- `main_hydra.py`: script to run training or inference with hydra configuration.
- `documentation/`: quickstart code for training and inference from a notebook.

## Installation

Expand Down Expand Up @@ -85,7 +85,7 @@ done
You can follow instructions in [`documentation/archesweather-tutorial.ipynb`](documentation/archesweather-tutorial.ipynb) to load the models and run inference with them. See [`documentation/archesweathergen_pipeline.md`](documentation/archesweathergen_pipeline.md) to run training.

### Downloading ERA5 statistics
To compute brier score on ERA5 (needed to instantiate ArchesWeather models for inferrence or training), you will need to download ERA5 quantiles:
To compute brier score on ERA5 (needed to instantiate ArchesWeather models for inference or training), you will need to download ERA5 quantiles:
```sh
src="https://huggingface.co/gcouairon/ArchesWeather/resolve/main"
wget -O geoarches/stats/era5-quantiles-2016_2022.nc $src/era5-quantiles-2016_2022.nc
Expand Down
86 changes: 48 additions & 38 deletions documentation/archesweather-tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"metadata": {},
"outputs": [],
"source": [
"autoreload 2"
"%autoreload 2"
]
},
{
Expand Down Expand Up @@ -81,7 +81,7 @@
],
"source": [
"# load_module will look in modelstore/\n",
"model, config = load_module('archesweather-m-seed0')"
"model, config = load_module(\"archesweather-m-seed0\")"
]
},
{
Expand All @@ -104,10 +104,14 @@
"source": [
"# we can also load the 4-members deterministic ensemble ArchesWeather-Mx4 like so:\n",
"\n",
"model, config = load_module('archesweather-m-seed0',\n",
" avg_with_modules=['archesweather-m-seed1',\n",
" 'archesweather-m-skip-seed0',\n",
" 'archesweather-m-skip-seed1'])"
"model, config = load_module(\n",
" \"archesweather-m-seed0\",\n",
" avg_with_modules=[\n",
" \"archesweather-m-seed1\",\n",
" \"archesweather-m-skip-seed0\",\n",
" \"archesweather-m-skip-seed1\",\n",
" ],\n",
")"
]
},
{
Expand Down Expand Up @@ -142,11 +146,12 @@
"# load sample from dataloader\n",
"from geoarches.dataloaders.era5 import Era5Forecast\n",
"\n",
"ds = Era5Forecast(path='data/era5_240/full', # default path\n",
" load_prev=True, # whether to load previous state\n",
" norm_scheme='pangu', # default normalization scheme\n",
" domain='test', # domain to consider. domain = 'test' loads the 2020 period\n",
" )"
"ds = Era5Forecast(\n",
" path=\"data/era5_240/full\", # default path\n",
" load_prev=True, # whether to load previous state\n",
" norm_scheme=\"pangu\", # default normalization scheme\n",
" domain=\"test\", # domain to consider. domain = 'test' loads the 2020 period\n",
")"
]
},
{
Expand All @@ -171,10 +176,10 @@
}
],
"source": [
"#The dataset returns a dict of TensorDict:\n",
"print('keys in a sample:', ds[0].keys())\n",
"# The dataset returns a dict of TensorDict:\n",
"print(\"keys in a sample:\", ds[0].keys())\n",
"# a state contains level and surface variables in a TensorDict, which is a specialized structure for dict of tensors\n",
"print('sample state', ds[0]['state'])"
"print(\"sample state\", ds[0][\"state\"])"
]
},
{
Expand Down Expand Up @@ -208,7 +213,8 @@
"# we can visualize a sample, e.g. Z500, with the following\n",
"\n",
"import matplotlib.pyplot as plt\n",
"plt.imshow(ds[0]['state']['level'][0, 7], cmap='terrain')"
"\n",
"plt.imshow(ds[0][\"state\"][\"level\"][0, 7], cmap=\"terrain\")"
]
},
{
Expand All @@ -220,13 +226,14 @@
"source": [
"# now we can run inference with the deterministic model ArchesWeather:\n",
"import torch\n",
"\n",
"torch.set_grad_enabled(False)\n",
"\n",
"device = 'cuda:0'\n",
"device = \"cuda:0\"\n",
"\n",
"model = model.to(device)\n",
"\n",
"batch = {k:v[None].to(device) for k, v in ds[0].items()}\n",
"batch = {k: v[None].to(device) for k, v in ds[0].items()}\n",
"\n",
"pred = model(batch).cpu()"
]
Expand Down Expand Up @@ -284,15 +291,15 @@
}
],
"source": [
"# we can visualize predictions, compared to \n",
"# we can visualize predictions, compared to ground truth\n",
"\n",
"print('24h Z500 prediction')\n",
"plt.imshow(pred['level'][0, 0, 7], cmap='terrain')\n",
"print(\"24h Z500 prediction\")\n",
"plt.imshow(pred[\"level\"][0, 0, 7], cmap=\"terrain\")\n",
"plt.show()\n",
"\n",
"print('24h ground truth')\n",
"plt.imshow(ds[0]['next_state']['level'][0, 7], cmap='terrain')\n",
"plt.show()\n"
"print(\"24h ground truth\")\n",
"plt.imshow(ds[0][\"next_state\"][\"level\"][0, 7], cmap=\"terrain\")\n",
"plt.show()"
]
},
{
Expand Down Expand Up @@ -320,7 +327,7 @@
],
"source": [
"# we can also do multistep rollouts with the deterministic model:\n",
"pred_multistep = model.forward_multistep(batch, iters=10) # this does a 10-day rollout\n",
"pred_multistep = model.forward_multistep(batch, iters=10) # this does a 10-day rollout\n",
"# the rollout dimension is the second dimension in the predicted tensors\n",
"pred_multistep"
]
Expand Down Expand Up @@ -353,10 +360,10 @@
],
"source": [
"# loading ArchesWeatherFlow\n",
"device = 'cuda:0'\n",
"device = \"cuda:0\"\n",
"\n",
"# load_module will look in modelstore/\n",
"gen_model, gen_config = load_module('archesweathergen')\n",
"gen_model, gen_config = load_module(\"archesweathergen\")\n",
"\n",
"gen_model = gen_model.to(device)"
]
Expand All @@ -378,16 +385,15 @@
"source": [
"# run model on a sample\n",
"seed = 0\n",
"num_steps = 25 # if not provided to model.sample, model will use the default value (25)\n",
"num_steps = 25 # if not provided to model.sample, model will use the default value (25)\n",
"scale_input_noise = 1.05\n",
"\n",
"batch = {k:v[None].to(device) for k, v in ds[0].items()}\n",
"batch = {k: v[None].to(device) for k, v in ds[0].items()}\n",
"\n",
"\n",
"sample = gen_model.sample(batch, \n",
" seed=seed, \n",
" num_steps=num_steps,\n",
" scale_input_noise=scale_input_noise).cpu()"
"sample = gen_model.sample(\n",
" batch, seed=seed, num_steps=num_steps, scale_input_noise=scale_input_noise\n",
").cpu()"
]
},
{
Expand All @@ -397,16 +403,20 @@
"metadata": {},
"outputs": [],
"source": [
"# run a model auto-regressively \n",
"# run a model auto-regressively\n",
"\n",
"rollout_iterations = 10\n",
"n_members = 10\n",
"\n",
"sample_multistep = [gen_model.sample_rollout(batch,\n",
" batch_nb=0, # should be different for each input\n",
" member=i,\n",
" iterations=rollout_iterations)\n",
" for i in range(n_members)]"
"sample_multistep = [\n",
" gen_model.sample_rollout(\n",
" batch,\n",
" batch_nb=0, # should be different for each input\n",
" member=i,\n",
" iterations=rollout_iterations,\n",
" )\n",
" for i in range(n_members)\n",
"]"
]
},
{
Expand Down
5 changes: 3 additions & 2 deletions geoarches/backbones/archesweather.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import importlib

import geoarches.stats as geoarches_stats
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa N812
import torch.utils.checkpoint as gradient_checkpoint
from geoarches.backbones.archesweather_layers import ICNR_init
from tensordict.tensordict import TensorDict
from timm.layers.mlp import SwiGLU

import geoarches.stats as geoarches_stats
from geoarches.backbones.archesweather_layers import ICNR_init

from .archesweather_layers import (
CondBasicLayer,
DownSample,
Expand Down
2 changes: 1 addition & 1 deletion geoarches/backbones/archesweather_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ class EarthSpecificBlock(nn.Module):
3D Transformer Block
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (tuple[int]): Window size [pressure levels, latitude, longitude].
shift_size (tuple[int]): Shift size for SW-MSA [pressure levels, latitude, longitude].
Expand Down
12 changes: 6 additions & 6 deletions geoarches/backbones/weatherlearn_utils/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None):

def forward(self, x: torch.Tensor):
B, C, H, W = x.shape
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
assert H == self.img_size[0] and W == self.img_size[1], (
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
)
x = self.pad(x)
x = self.proj(x)
if self.norm is not None:
Expand Down Expand Up @@ -102,9 +102,9 @@ def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None):

def forward(self, x: torch.Tensor):
B, C, L, H, W = x.shape
assert (
L == self.img_size[0] and H == self.img_size[1] and W == self.img_size[2]
), f"Input image size ({L}*{H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}*{self.img_size[2]})."
assert L == self.img_size[0] and H == self.img_size[1] and W == self.img_size[2], (
f"Input image size ({L}*{H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}*{self.img_size[2]})."
)
x = self.pad(x)
x = self.proj(x)
if self.norm:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def window_reverse(windows, window_size, Pl, Lat, Lon):
def get_shift_window_mask(input_resolution, window_size, shift_size):
"""
Along the longitude dimension, the leftmost and rightmost indices are actually close to each other.
If half windows apper at both leftmost and rightmost positions, they are dircetly merged into one window.
If half windows apper at both leftmost and rightmost positions, they are directly merged into one window.
Args:
input_resolution (tuple[int]): [pressure levels, latitude, longitude]
window_size (tuple[int]): Window size [pressure levels, latitude, longitude].
Expand Down
2 changes: 1 addition & 1 deletion geoarches/configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ limit_val_batches: null
accumulate_grad_batches: 1
debug: False # set to True to debug

mode: train # Specify "train" or "test" to run training or inferrence.
mode: train # Specify "train" or "test" to run training or inference.
4 changes: 2 additions & 2 deletions geoarches/dataloaders/era5.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def __init__(
load_clim: Whether to load climatology.
norm_scheme: Normalization scheme to use. Can be None to perform no normalization.
timedelta_hours: Time difference (hours) between 2 consecutive timestamps. If not expecified,
default is 6 or 12, depeding on domain.
default is 6 or 12, depending on domain.
variables: Variables to load from dataset. Dict holding variable lists mapped by their keys to be processed into tensordict.
e.g. {surface:[...], level:[...] By default uses standard 6 level and 4 surface vars.
dimension_indexers: Dict of dimensions to select using Dataset.sel(dimension_indexers).
Expand All @@ -282,7 +282,7 @@ def __init__(
if self.load_prev:
start_time = start_time - self.lead_time_hours * np.timedelta64(1, "h")
end_time = np.datetime64(
f"{year+1}-01-01T00:00:00"
f"{year + 1}-01-01T00:00:00"
) + self.multistep * self.lead_time_hours * np.timedelta64(1, "h")
print("start time", start_time)
super().set_timestamp_bounds(start_time, end_time)
Expand Down
Loading
Loading