diff --git a/docs/img/healda.png b/docs/img/healda.png
new file mode 100644
index 0000000000..0a54ff10f1
Binary files /dev/null and b/docs/img/healda.png differ
diff --git a/examples/weather/healda/README.md b/examples/weather/healda/README.md
new file mode 100644
index 0000000000..c2ace6d347
--- /dev/null
+++ b/examples/weather/healda/README.md
@@ -0,0 +1,203 @@
+
+# HealDA: Highlighting the importance of initial errors in end-to-end AI weather forecasts
+
+
+
+
+
+[๐ arXiv](https://arxiv.org/abs/2601.17636) ยท ๐ฆ Checkpoints (coming soon)
+
+---
+
+## Problem Overview
+
+Machine-learning (ML) weather models now rival leading numerical weather prediction (NWP) systems in medium-range skill. However, almost all still rely on NWP data assimilation (DA) to provide initial conditions, tying them to expensive infrastructure and limiting the practical speed and accuracy gains of ML.
+
+**HealDA** is a global ML-based data assimilation system that maps satellite and conventional observations (microwave sounders, aircraft, radiosondes, surface stations) to a 1ยฐ atmospheric state on the HEALPix grid. HealDA analyses can initialize off-the-shelf ML forecast models (e.g., FourCastNet3, Aurora, FengWu) without fine-tuning, enabling end-to-end ML weather forecasting with less than one day loss of skill compared to ERA5 initialization.
+
+---
+
+## Installation
+
+### Using uv
+
+```bash
+# 1. From PhysicsNeMo root directory
+cd /path/to/physicsnemo
+
+# 2. Create .venv and install PNM
+uv sync
+
+# 3. Activate the virtual environment
+source .venv/bin/activate
+
+# 4. Install earth2grid
+uv pip install setuptools hatchling
+uv pip install --no-build-isolation \
+ "earth2grid @ https://github.com/NVlabs/earth2grid/archive/main.tar.gz"
+
+# 5. Install healda dependencies
+uv pip install -r examples/weather/healda/requirements.txt
+```
+
+### Using pip
+
+```bash
+# 1. Install PhysicsNeMo
+pip install nvidia-physicsnemo
+
+# 2. Install earth2grid
+pip install setuptools hatchling
+pip install --no-build-isolation https://github.com/NVlabs/earth2grid/archive/main.tar.gz
+
+# 3. Install healda dependencies
+pip install -r requirements.txt
+```
+
+> **Warning:** Include `--no-build-isolation` when installing earth2grid to avoid building against the wrong PyTorch version.
+
+---
+
+## Configuration
+
+Create a `.env` file in the `examples/weather/healda/` directory with the following:
+
+```bash
+# Project paths
+PROJECT_ROOT=/path/to/project
+
+# Raw observation data (NC4 files downloaded from NOAA S3)
+UFS_RAW_OBS_DIR=/path/to/raw_obs
+
+# Processed observation data (parquet from ETL)
+UFS_OBS_PATH=/path/to/processed_obs
+# UFS_OBS_PROFILE=
+
+# ERA5 HEALPix zarr (training targets)
+V6_ERA5_ZARR=/path/to/era5_hpx.zarr
+# V6_ERA5_ZARR_PROFILE=
+
+# Land fraction mask
+UFS_LAND_DATA_ZARR=/path/to/land_frac.zarr
+# UFS_LAND_DATA_PROFILE=
+```
+
+> **Note:** The `*_PROFILE` variables configure [rclone](https://rclone.org/) S3 profiles for cloud storage access. Leave empty for local paths.
+
+---
+
+## Data Preparation
+
+HealDA requires preprocessed observation data and ERA5 target fields. We source observational data from the [NOAA Unified Forecast System (UFS) GEFSv13 Replay dataset](https://psl.noaa.gov/data/ufs_replay/) (NOAA, 2024).
+
+See [`datasets/etl/`](datasets/etl/) for ETL scripts to prepare observation data into a parquet data format.
+
+---
+
+## Training
+
+```bash
+python train.py --name era5-v2-dense-noInfill-10M-fusion512-lrObs1e-4
+```
+
+This uses the paper configuration defined in `train.py`. See `python train.py --help` for options.
+
+> **Resource Requirements:** Training takes approximately **8.3 days on 1 H100 node** (8 GPUs total) with batch size 1 per GPU.
+
+---
+
+## Inference
+
+### Step 1: Generate DA Analysis (Initial Conditions)
+
+The following produces analyses for all of 2022. `See inference_helpers.py` to configure inference. Inference only requires ~20GB of memory and can produce an analysis in under 1 second on a single H100.
+```bash
+python inference.py \
+ /path/to/checkpoint.pt \
+ --output_path /path/to/da_output.zarr \
+ --context_start -21 \
+ --context_end 3 \
+ --time_frequency 6h \
+ --num_samples -1 \
+ --batch_gpu 1
+```
+
+### Step 2: Forecast from HealDA initial conditions
+
+#### Installing FCN3 dependencies
+
+FCN3 requires `earth2studio`. Recommended to install torch-harmonics with CUDA extensions for best performance:
+
+```bash
+# Using uv
+export FORCE_CUDA_EXTENSION=1
+uv pip install torch-harmonics==0.8.0 --no-build-isolation
+uv pip install earth2studio[fcn3]
+
+# Or using pip
+export FORCE_CUDA_EXTENSION=1
+pip install torch-harmonics==0.8.0 --no-build-isolation
+pip install earth2studio[fcn3]
+```
+
+> **Note:** See [Earth2Studio docs](https://nvidia.github.io/earth2studio/userguide/about/install.html) for more information or installing other forecast models beyond FCN3.
+
+#### Running forecasts
+
+Use the DA output to initialize the FCN3 forecast model and create a 10-day forecast (40 6-hour steps):
+
+```bash
+python scripts/forecast.py \
+ --init_path /path/to/da_output.zarr \
+ --out_dir /path/to/forecast_output \
+ --model FCN3 \
+ --num_steps 40 \
+ --num_ensemble 1 \
+ --num_times 1
+```
+
+> **Note:** The forecast script:
+> - Regrids HealDA analysis (HPX64) โ 0.25ยฐ lat-lon for FCN3 input
+> - Regrids FCN3 output (0.25ยฐ lat-lon) โ HPX64 NEST format for storage
+
+> **ERA5-initialized forecasts:** To create forecasts from ERA5 instead of DA output, run `inference.py` with `--use_analysis` flag to create an ERA5 zarr in the same format, then use that as `--init_path`.
+
+
+### Step 3: Score Forecasts
+
+Score forecasts against a reference dataset (also on HPX64 grid):
+
+```bash
+python scripts/score_forecast.py \
+ --forecast_path /path/to/forecast.zarr \
+ --reference_path /path/to/era5.zarr \
+ --output_path /path/to/scores.nc
+```
+
+To plot the metrics:
+
+```bash
+python scripts/plot_panel.py \
+ --stats /path/to/scores.nc \
+ --labels "HealDA-initialized FCN3" \
+ --metric crps \
+ --output_path /path/to/plots/crps_comparison.pdf
+```
+
+See `python inference.py --help` and `python scripts/forecast.py --help` for full options.
+
+---
+
+## Citation
+
+```bibtex
+@misc{gupta2026healdahighlightingimportanceinitial,
+ title={HealDA: Highlighting the importance of initial errors in end-to-end AI weather forecasts},
+ author={Aayush Gupta and Akshay Subramaniam and Michael S. Pritchard and Karthik Kashinath and Sergey Frolov and Kelsey Lieberman and Christopher Miller and Nicholas Silverman and Noah D. Brenowitz},
+ year={2026},
+ eprint={2601.17636},
+ archivePrefix={arXiv},
+ primaryClass={physics.ao-ph},
+ url={https://arxiv.org/abs/2601.17636},
+}
+```
diff --git a/examples/weather/healda/config/__init__.py b/examples/weather/healda/config/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/weather/healda/config/environment.py b/examples/weather/healda/config/environment.py
new file mode 100644
index 0000000000..d6e6b64316
--- /dev/null
+++ b/examples/weather/healda/config/environment.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import dotenv
+
+dotenv.load_dotenv(dotenv.find_dotenv(usecwd=True))
+
+non_config = dir()
+
+CACHE_DIR = os.path.expanduser("~/.cache/healda")
+
+###############
+# ERA5 inputs #
+###############
+V6_ERA5_ZARR = os.getenv("V6_ERA5_ZARR", "")
+V6_ERA5_ZARR_PROFILE = os.getenv("V6_ERA5_ZARR_PROFILE", "")
+
+########
+# UFS #
+########
+UFS_HPX6_ZARR = os.getenv("UFS_HPX6_ZARR", "")
+UFS_LAND_DATA_ZARR = os.getenv("UFS_LAND_DATA_ZARR", "")
+UFS_LAND_DATA_PROFILE = os.getenv("UFS_LAND_DATA_PROFILE", "")
+UFS_ZARR_PROFILE = os.getenv("UFS_ZARR_PROFILE", "")
+UFS_OBS_PATH = os.getenv("UFS_OBS_PATH", "")
+UFS_OBS_PROFILE = os.getenv("UFS_OBS_PROFILE", "")
+# project file
+PROJECT_ROOT = os.getenv("PROJECT_ROOT", "")
+DATA_ROOT = os.getenv("DATA_ROOT", os.path.join(PROJECT_ROOT, "datasets"))
+CHECKPOINT_ROOT = os.getenv(
+ "CHECKPOINT_ROOT", os.path.join(PROJECT_ROOT, "training-runs")
+)
+
+
+_config_vars = dict(vars())
+
+
+def print_config():
+ print("Environment settings:")
+ print("-" * 80)
+ for v in _config_vars:
+ if v == "non_config":
+ continue
+
+ if v in non_config:
+ continue
+
+ value = _config_vars[v]
+ print(f"{v}={value}")
+ print("-" * 80)
diff --git a/examples/weather/healda/config/model_config.py b/examples/weather/healda/config/model_config.py
new file mode 100644
index 0000000000..7318000321
--- /dev/null
+++ b/examples/weather/healda/config/model_config.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Training configuration dataclasses for HealDA.
+
+These are training-specific configs with serialization for checkpointing.
+Model constructor parameter types (SensorEmbedderConfig, ModelSensorConfig)
+are in physicsnemo.experimental.models.healda.config.
+"""
+
+import dataclasses
+import json
+
+from physicsnemo.experimental.models.healda.config import (
+ ModelSensorConfig,
+ SensorEmbedderConfig,
+)
+
+
+def _filter_to_dataclass_fields(d: dict, cls) -> dict:
+ """Filter dict to only include fields defined in the dataclass."""
+ valid_fields = {f.name for f in dataclasses.fields(cls)}
+ return {k: v for k, v in d.items() if k in valid_fields}
+
+
+@dataclasses.dataclass(frozen=True)
+class ObsConfig:
+ """Observation dataset configuration for training."""
+
+ use_obs: bool = False
+ innovation_type: str = "none"
+ context_start: int = -21 # start/end in hours
+ context_end: int = 3
+ randomize_interval_times: bool = False
+ out_channels: int = -1
+ embed_dim: int = 0
+ use_infrared: bool = False
+ use_conv: bool = False
+ use_density: bool = False
+ conv_uv_in_situ_only: bool = False
+ conv_gps_level1_only: bool = False
+ dropout: float = 0.0
+ # Optional list of global observation channel IDs to drop
+ drop_obs_channel_ids: list[int] | None = None
+
+
+@dataclasses.dataclass
+class ModelConfigV1:
+ """Training configuration for HealDA."""
+
+ architecture: str = "dit-l_reg_hpx6_per_sensor"
+ label_dim: int = 0
+ out_channels: int = 1
+ condition_channels: int = 0
+ time_length: int = 1
+ label_dropout: float = 0.0
+ legacy_label_bias: bool = False
+
+ obs_config: ObsConfig = dataclasses.field(default_factory=ObsConfig)
+
+ p_dropout: float = 0.0
+ drop_path: float = 0.0
+ group_norm_eps: float = 1e-6
+ pos_emb_gains: bool = False
+
+ # DiT settings
+ dit_temporal_attention: bool = False
+ compile_dit: bool = False
+ qk_rms_norm: bool = False
+ embed_v2: bool = False
+ allow_nans_condition: bool = False
+ emb_channels: int | None = None
+ noise_channels: int | None = None
+ as_vit: bool = False # run DiT without noise/label conditioning
+
+ # Obs encoder settings
+ sensor_embedder_config: SensorEmbedderConfig | None = dataclasses.field(
+ default_factory=SensorEmbedderConfig
+ )
+ sensors: dict[str, ModelSensorConfig] | None = None
+
+ def dumps(self) -> str:
+ """Serialize config to JSON string for checkpointing."""
+ return json.dumps(dataclasses.asdict(self))
+
+ @classmethod
+ def loads(cls, s: str) -> "ModelConfigV1":
+ """Deserialize config from JSON string."""
+ d = json.loads(s)
+
+ # Filter out fields that aren't in the current model config definition
+ d = _filter_to_dataclass_fields(d, cls)
+
+ if isinstance(d.get("obs_config"), dict):
+ d["obs_config"] = ObsConfig(
+ **_filter_to_dataclass_fields(d["obs_config"], ObsConfig)
+ )
+
+ if isinstance(d.get("sensor_embedder_config"), dict):
+ embed_cfg = d["sensor_embedder_config"]
+ # Backwards compat: old checkpoints had sensors nested inside
+ nested_sensors = embed_cfg.pop("sensors", None) or embed_cfg.pop(
+ "sensor_config", None
+ )
+ if nested_sensors and "sensors" not in d:
+ d["sensors"] = nested_sensors
+ d["sensor_embedder_config"] = SensorEmbedderConfig(
+ **_filter_to_dataclass_fields(embed_cfg, SensorEmbedderConfig)
+ )
+
+ if isinstance(d.get("sensors"), dict):
+ d["sensors"] = {
+ k: ModelSensorConfig(**v) if isinstance(v, dict) else v
+ for k, v in d["sensors"].items()
+ }
+
+ return cls(**d)
diff --git a/examples/weather/healda/config/training/__init__.py b/examples/weather/healda/config/training/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/weather/healda/config/training/loop.py b/examples/weather/healda/config/training/loop.py
new file mode 100644
index 0000000000..8887bd879f
--- /dev/null
+++ b/examples/weather/healda/config/training/loop.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import dataclasses
+from typing import Optional
+
+
+@dataclasses.dataclass
+class TrainingLoopBase:
+ """Base training config"""
+
+ run_dir: str = "." # Output directory.
+ seed: int = 0 # Global random seed.
+ batch_size: int = 512 # Total batch size for one training iteration.
+ batch_gpu: Optional[int] = None # Limit batch size per GPU, None = no limit.
+ enable_ema: bool = False
+ ema_halflife_kimg: int = (
+ 500 # Half-life of the exponential moving average (EMA) of model weights.
+ )
+ ema_rampup_ratio: float = 0.05 # EMA ramp-up coefficient, None = no rampup.
+ lr_rampup_img: int = 10_000 # Learning rate ramp-up duration.
+ flat_imgs: int = 1_500_000 - 10_000
+ decay_imgs: int = 1_500_000
+ lr_min: float = 1e-6
+ lr: float = 1e-4
+
+ loss_reduction: str = "v1"
+ """
+ Controls how the [b c t x] shaped loss is reduced, where 'b' is the
+
+ Options:
+ - v1 (default) - sum over c x, mean over b c
+ - mean - mean over all dimensions
+ """
+
+ loss_scaling: float = 1.0 # Loss scaling factor for reducing FP16 under/overflows.
+ gradient_clip_max_norm: Optional[float] = None
+ total_ticks: int = 10
+ print_steps: int = 50
+ steps_per_tick: int = 1024
+ snapshot_ticks: int | None = (
+ 50 # How often to save network snapshots, None = disable.
+ )
+ state_dump_ticks: int | None = (
+ 500 # How often to dump training state, None = disable.
+ )
+
+ test_with_single_batch: bool = False
+ """Only load a single batch of data for testing and profiling purposes"""
+
+ # Performance optimizations
+ # Mixed precision and performance options
+ cudnn_benchmark: bool = True # Enable torch.backends.cudnn.benchmark?
+ tf32: bool = True
+ bf16: bool = True
+ compile_optimizer: bool = False # if true wrap the optimizer with torch compile
+
+ # wandb
+ wandb_id: str | None = None # will be read from checkpoint if not provided
+
+ # logging
+ log_parameter_norm: bool = False
+ log_parameter_grad_norm: bool = False
diff --git a/examples/weather/healda/datasets/__init__.py b/examples/weather/healda/datasets/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/weather/healda/datasets/analysis_loaders.py b/examples/weather/healda/datasets/analysis_loaders.py
new file mode 100644
index 0000000000..8a8e178ba9
--- /dev/null
+++ b/examples/weather/healda/datasets/analysis_loaders.py
@@ -0,0 +1,257 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pathlib
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+import zarr.api.asynchronous
+
+from datasets import catalog
+from datasets.base import BatchInfo, TimeUnit, VariableConfig
+from datasets.zarr_loader import NO_LEVEL, ZarrLoader
+
+__all__ = ["ERA5Loader", "get_batch_info"]
+
+SST_LAND_FILL_VALUE = 290
+MONTHLY_SST = "monthly_sst"
+
+HPX_LEVEL = 6
+
+LABELS = [
+ "icon",
+ "era5",
+ "ufs",
+]
+
+
+class PackedZarrLoader:
+ """Async loader for packed Zarr arrays with time-based indexing."""
+
+ def __init__(self, entry, array_name, label: int):
+ self.time = entry.to_xarray(chunks=None).indexes["time"]
+ self._store = entry.to_zarr().store
+ self._array = None
+ self._name = array_name
+ self._label = label
+
+ async def sel_time(self, times):
+ if self._array is None:
+ group = await zarr.api.asynchronous.open_group(
+ self._store,
+ use_consolidated=True,
+ mode="r",
+ )
+ self._array = await group.get(self._name)
+
+ t = self.time.get_indexer(times)
+ if any(t == -1):
+ raise KeyError(t[t == -1])
+
+ state = await self._array.get_orthogonal_selection(t)
+ return {
+ "state": state,
+ "label": [self._label] * len(times),
+ }
+
+
+class ERA5Loader:
+ """Async loader for ERA5 reanalysis data on HEALPix grid."""
+
+ def __init__(self, variable_config: VariableConfig):
+ self.variable_config = variable_config
+ variables_2d = [
+ "sstk",
+ "ci",
+ "msl",
+ "10u",
+ "10v",
+ "2t",
+ "tcwv",
+ "100u",
+ "100v",
+ ]
+ # add tp, handling missing data from 2022-2023, deal with surface pressure
+ entry = catalog.era5_hpx6()
+ self._loader = ZarrLoader(
+ path=entry.to_store(),
+ variables_3d=["u", "v", "t", "z", "q"],
+ variables_2d=variables_2d,
+ level_coord_name="levels",
+ levels=variable_config.levels,
+ )
+
+ async def sel_time(self, times):
+ data = await self._loader.sel_time(times)
+ self._convert_to_standard(data)
+ shape = (len(times), 4**HPX_LEVEL * 12)
+
+ state = _collect_fields(
+ _get_index(self.variable_config), data, shape=shape
+ ) # c t x
+ state = np.moveaxis(state, 0, 1) # t c x
+ return {
+ "state": state,
+ "label": [LABELS.index("era5")] * len(times),
+ }
+
+ def _convert_to_standard(self, data):
+ if ("sstk", NO_LEVEL) in data:
+ sstk = data[("sstk", NO_LEVEL)]
+
+ if not np.ma.isMaskedArray(sstk):
+ sstk = np.ma.masked_invalid(sstk)
+
+ data[("sstk", NO_LEVEL)] = sstk.filled(SST_LAND_FILL_VALUE)
+
+ if ("ci", NO_LEVEL) in data:
+ ci = data[("ci", NO_LEVEL)]
+
+ if not np.ma.isMaskedArray(ci):
+ ci = np.ma.masked_invalid(ci)
+
+ data[("ci", NO_LEVEL)] = ci.filled(0)
+
+ # era5 precip is in liquid water equivalent accumulated over 1 hour (m)
+ # icon is in mass flux units (kg / s / m^2)
+ # unit conversion: tp / 3600 * density water = tp / 3600 * 1000
+ if ("tp", NO_LEVEL) in data:
+ water_density = 1000
+ seconds_per_hour = 3600
+ data[("tp", NO_LEVEL)] = (
+ data[("tp", NO_LEVEL)] * water_density / seconds_per_hour
+ )
+
+ fields_out_map = {
+ # mapping of ecmwf name to icon name
+ "tclw": "cllvi",
+ "tciw": "clivi",
+ "2t": "tas",
+ "10u": "uas",
+ "10v": "vas",
+ "100u": "100u",
+ "100v": "100v",
+ "msl": "pres_msl",
+ "tp": "pr",
+ "sstk": "sst",
+ "ci": "sic",
+ "tcwv": "prw",
+ "u": "U",
+ "v": "V",
+ "t": "T",
+ "z": "Z",
+ "q": "Q",
+ "tosbcs": MONTHLY_SST,
+ }
+ for key, value in list(data.items()):
+ match key:
+ case (name, level):
+ if name in fields_out_map:
+ data[(fields_out_map[name], level)] = value
+
+
+def get_batch_info(
+ config: VariableConfig,
+ time_step: int = 1,
+ time_unit: TimeUnit = TimeUnit.HOUR,
+) -> BatchInfo:
+ """Returns BatchInfo for the given variable config"""
+ return BatchInfo(
+ channels=[_encode_channel(tup) for tup in _get_index(config).tolist()],
+ scales=_get_std(config),
+ center=_get_mean(config),
+ time_step=time_step,
+ time_unit=time_unit,
+ )
+
+
+def _get_index(config: VariableConfig):
+ return pd.MultiIndex.from_tuples(
+ [(v, level) for v in config.variables_3d for level in config.levels]
+ + [(v, NO_LEVEL) for v in config.variables_2d],
+ names=["variable", "level"],
+ )
+
+
+def _collect_fields(
+ index,
+ data: dict[tuple[str, int | None], np.ndarray],
+ shape,
+ prefix: Optional[str] = None,
+) -> np.ndarray:
+ out = np.full(
+ shape=(index.size,) + shape,
+ dtype=np.float32,
+ fill_value=np.nan,
+ )
+ for i, (var, lev) in enumerate(index):
+ key = (prefix, var, lev) if prefix is not None else (var, lev)
+ if key in data:
+ out[i] = data[key]
+ return out
+
+
+def _get_mean(config: VariableConfig) -> np.ndarray:
+ mean = _get_nearest_stats(config)["mean"].values
+ return mean
+
+
+def _get_std(config: VariableConfig) -> np.ndarray:
+ std = _get_nearest_stats(config)["std"].values
+ return std
+
+
+def _encode_channel(channel) -> str:
+ name, level = channel
+ if level != NO_LEVEL:
+ return f"{name}{level}"
+ else:
+ return name
+
+
+def _load_raw_stats(config: VariableConfig) -> pd.DataFrame:
+ if config.name == "ufs":
+ file_name = "ufs_v0_stats.csv"
+ elif config.name == "era5":
+ file_name = "era5_13_levels_stats.csv"
+ else:
+ raise ValueError(f"Unknown dataset: {config.name}")
+ path = pathlib.Path(__file__).parent / file_name
+ return pd.read_csv(path).set_index(["variable", "level"])
+
+
+# def get_sst_stats(config: VariableConfig = _default_config):
+# df = _load_raw_stats(config)
+# row = df.loc[("sst", NO_LEVEL)]
+# return row["mean"].item(), row["std"].item()
+
+
+def _get_nearest_stats(config: VariableConfig):
+ # To handle float levels, gets nearest level
+ raw = _load_raw_stats(config)
+ idx = _get_index(config)
+
+ mapped_idx = []
+ for var, level in idx:
+ if level != NO_LEVEL:
+ available = raw.loc[var].index.values
+ nearest = available[np.abs(available - level).argmin()]
+ mapped_idx.append((var, nearest))
+ else:
+ mapped_idx.append((var, level))
+
+ mapped_idx = pd.MultiIndex.from_tuples(mapped_idx, names=["variable", "level"])
+ return raw.loc[mapped_idx]
diff --git a/examples/weather/healda/datasets/base.py b/examples/weather/healda/datasets/base.py
new file mode 100644
index 0000000000..9f01edb9c2
--- /dev/null
+++ b/examples/weather/healda/datasets/base.py
@@ -0,0 +1,200 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import dataclasses
+import json
+from datetime import timedelta
+from enum import Enum
+from typing import Any, Protocol
+
+import numpy as np
+import torch
+
+from physicsnemo.experimental.models.healda import Domain
+
+
+class TimeUnit(Enum):
+ """Time units supported by the dataset.
+ Values are the pandas frequency strings (offset aliases)"""
+
+ HOUR = "h"
+ DAY = "D"
+ MINUTE = "min"
+ SECOND = "s"
+
+ def to_timedelta(self, steps: float) -> timedelta:
+ return {
+ TimeUnit.HOUR: timedelta(hours=steps),
+ TimeUnit.DAY: timedelta(days=steps),
+ TimeUnit.MINUTE: timedelta(minutes=steps),
+ TimeUnit.SECOND: timedelta(seconds=steps),
+ }[self]
+
+
+@dataclasses.dataclass
+class BatchInfo:
+ """Metadata describing model output"""
+
+ channels: list[str]
+ time_step: int = 1 # Time (in units `time_unit`) between consecutive frames
+ time_unit: TimeUnit = TimeUnit.HOUR
+ scales: Any | None = None
+ center: Any | None = None
+
+ def __post_init__(self):
+ if isinstance(self.time_unit, str):
+ raise ValueError("Time unit is an str. Should be a TimeUnit.")
+
+ @staticmethod
+ def loads(s):
+ kw = json.loads(s)
+
+ if "time_unit" in kw:
+ kw["time_unit"] = TimeUnit(kw["time_unit"])
+
+ # Ignore deprecated residual normalization field if present
+ kw.pop("residual_normalization", None)
+
+ return BatchInfo(**kw)
+
+ def asdict(self):
+ """Return a dictionary representation of the BatchInfo, suitable for JSON serialization."""
+ out = {}
+ out["channels"] = self.channels
+ out["time_step"] = self.time_step
+ out["time_unit"] = self.time_unit.value # TimeUnit is always a TimeUnit enum
+ # Convert numpy arrays to lists for JSON compatibility
+ if self.scales is not None:
+ out["scales"] = np.asarray(self.scales).tolist()
+ else:
+ out["scales"] = None
+ if self.center is not None:
+ out["center"] = np.asarray(self.center).tolist()
+ else:
+ out["center"] = None
+ return out
+
+ def sel_channels(self, channels: list[str]):
+ channels = list(channels)
+ index = np.array([self.channels.index(ch) for ch in channels])
+ scales = None
+ if self.scales is not None:
+ scales = np.asarray(self.scales)[index]
+
+ center = None
+ if self.center is not None:
+ center = np.asarray(self.center)[index]
+
+ return BatchInfo(
+ time_step=self.time_step,
+ time_unit=self.time_unit,
+ channels=channels,
+ scales=scales,
+ center=center,
+ )
+
+ def denormalize(self, x):
+ scales = torch.as_tensor(self.scales).to(x)
+ scales = scales.view(-1, 1, 1)
+
+ center = torch.as_tensor(self.center).to(x)
+ center = center.view(-1, 1, 1)
+ return x * scales + center
+
+ def get_time_delta(self, t: int) -> timedelta:
+ """Gets time offset of the t-th frame in a frame sequence."""
+ total_steps = t * self.time_step
+ return self.time_unit.to_timedelta(total_steps)
+
+
+@dataclasses.dataclass
+class DatasetMetadata:
+ name: str
+ start: str
+ end: str
+ time_step: int # time between successive data points in `time_unit`
+ time_unit: TimeUnit
+
+ @property
+ def freq(self) -> str:
+ return f"{self.time_step}{self.time_unit.value}"
+
+
+@dataclasses.dataclass(frozen=True)
+class VariableConfig:
+ """Input variable set"""
+
+ name: str
+ variables_2d: list[str]
+ variables_3d: list[str]
+ levels: list[int]
+ variables_static: list[str] = dataclasses.field(default_factory=list)
+
+
+class SpatioTemporalDataset(Protocol):
+ """Protocol for time-indexed gridded datasets."""
+
+ @property
+ def domain(self) -> Domain:
+ pass
+
+ def __len__(self) -> int:
+ pass
+
+ @property
+ def num_channels(self) -> int:
+ pass
+
+ @property
+ def condition_channels(self) -> int:
+ pass
+
+ @property
+ def augment_channels(self) -> int:
+ return 0
+
+ @property
+ def label_dim(self) -> int:
+ return 0
+
+ @property
+ def time_length(self) -> int:
+ pass
+
+ @property
+ def batch_info(self) -> BatchInfo:
+ return BatchInfo(
+ channels=[str(i) for i in range(self.num_channels)],
+ )
+
+ def metadata(self) -> Any:
+ """Unstructured metadata about the dataset and the values it yields
+
+ Can be used to save normalization constants, timestamps, channel names,
+ config values, etc. The training code will avoid looking into this, but
+ could be useful for inference.
+
+ """
+ return {}
+
+ def __getitem__(self, idx) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+
+ Returns:
+ image: shaped (num_channels, time_length, x)
+ labels: shaped (label_dim,)
+ condition: shaped (condition_channels, time_length, x)
+ """
+ pass
diff --git a/examples/weather/healda/datasets/catalog.py b/examples/weather/healda/datasets/catalog.py
new file mode 100644
index 0000000000..d8f4e82ae3
--- /dev/null
+++ b/examples/weather/healda/datasets/catalog.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import urllib.parse
+from dataclasses import dataclass
+
+import xarray
+import zarr
+import zarr.storage
+from config import environment
+from utils import storage
+
+
+@dataclass
+class _Zarr:
+ path: str
+ profile: str
+
+ @property
+ def storage_options(self):
+ return storage.get_storage_options(self.profile)
+
+ def to_store(self, obstore=True) -> zarr.storage.StoreLike:
+ if self.profile == "":
+ return self.path
+
+ url = urllib.parse.urlparse(self.path)
+ if obstore:
+ bucket = url.netloc
+ store = storage.get_obstore(self.profile, bucket=bucket, prefix=url.path)
+ zarr_store = zarr.storage.ObjectStore(store)
+ else:
+ import fsspec
+
+ fs = fsspec.filesystem(
+ url.scheme, storage_options=self.storage_options, asyn=True
+ )
+ zarr_store = zarr.storage.FsspecStore(fs)
+
+ return zarr_store
+
+ def to_zarr(self, obstore=True, use_consolidated=True) -> zarr.Group:
+ store = self.to_store(obstore=True)
+ return zarr.open_group(store, use_consolidated=use_consolidated)
+
+ def to_xarray(self, obstore: bool = True, **kwargs) -> xarray.Dataset:
+ return xarray.open_zarr(
+ self.to_store(obstore=obstore),
+ **kwargs,
+ )
+
+ def consolidate_metadata(self):
+ store = zarr.storage.FsspecStore.from_url(
+ self.path, storage_options=storage.get_storage_options(self.profile)
+ )
+ zarr.consolidate_metadata(store)
+
+
+@dataclass
+class _Parquet:
+ path: str
+ profile: str
+
+ @property
+ def storage_options(self):
+ return storage.get_storage_options(self.profile)
+
+ @property
+ def polars_storage_options(self):
+ return storage.get_polars_storage_options(self.profile)
+
+ def files(self):
+ import fsspec
+
+ fs = fsspec.filesystem("s3", **self.storage_options)
+ return ["s3://" + f for f in fs.glob(self.path)]
+
+ def to_pandas(self, year, month, day):
+ import pandas as pd
+
+ path = f"{self.path}/{year:04d}{month:02d}{day:02d}.parquet"
+ return pd.read_parquet(path, storage_options=self.storage_options)
+
+ def to_polars(self):
+ import polars
+
+ return polars.scan_parquet(
+ self.path + "/*.parquet",
+ storage_options=storage.get_polars_storage_options(self.profile),
+ )
+
+
+def era5_hpx6():
+ return _Zarr(environment.V6_ERA5_ZARR, environment.V6_ERA5_ZARR_PROFILE)
+
+
+def ufs_obs():
+ """UFS obs parquet dataset."""
+ return _Parquet(environment.UFS_OBS_PATH + "amsua", "pbss")
+
+
+def ufs():
+ """UFS analysis dataset"""
+ return _Zarr(path=environment.UFS_HPX6_ZARR, profile=environment.UFS_ZARR_PROFILE)
diff --git a/examples/weather/healda/datasets/dataset.py b/examples/weather/healda/datasets/dataset.py
new file mode 100644
index 0000000000..0e33c6801f
--- /dev/null
+++ b/examples/weather/healda/datasets/dataset.py
@@ -0,0 +1,292 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import datetime
+from typing import Literal, Optional
+
+import config.environment as config
+import numpy as np
+import pandas as pd
+import torch
+
+from datasets.analysis_loaders import (
+ ERA5Loader,
+ get_batch_info,
+)
+from datasets.base import (
+ DatasetMetadata,
+ TimeUnit,
+ VariableConfig,
+)
+from datasets.filter_times import get_chunk_aligned_times
+from datasets.merged_dataset import TimeMergedDataset, TimeMergedMapStyle
+from datasets.obs_loader import UFSUnifiedLoader
+from datasets.variable_configs import VARIABLE_CONFIGS
+from config.model_config import ObsConfig
+
+OBS_INTERVALS = [[-48, -24], [-24, 0], [-21, 3], [-18, 6], [-15, 9]]
+
+_default_config = VARIABLE_CONFIGS["default"]
+
+DATASET_METADATA: dict[str, DatasetMetadata] = {
+ "ufs": DatasetMetadata(
+ name="ufs",
+ start="1994-01-01 00:00:00", # actual is 1993-12-31 18:00:00, but align to 00/12
+ end="2023-10-13 03:00:00",
+ time_step=6,
+ time_unit=TimeUnit.HOUR,
+ ),
+ "ufs_obs": DatasetMetadata(
+ name="ufs_obs",
+ start="2000-01-01 00:00:00", # actual is 1994-01-01-00:00:00
+ end="2023-12-31 18:00:00",
+ time_step=6, # can get 3hr spacing
+ time_unit=TimeUnit.HOUR,
+ ),
+ "era5": DatasetMetadata(
+ name="era5",
+ start="2000",
+ end="2023-10-31 23:00:00",
+ time_step=6,
+ time_unit=TimeUnit.HOUR,
+ ),
+}
+
+
+def frame_dropout(x, p_dropout):
+ gate = torch.rand_like(x[..., :1, :1]) < p_dropout
+ return x * gate.to(x)
+
+
+def shift_time(x, shift):
+ if shift < 0:
+ raise ValueError(f"shift must be non-negative, got {shift}")
+ t_dim = -2
+ x = x.roll(shift, dims=t_dim)
+ x[..., :shift, :] = 0.0
+ return x
+
+
+def _compute_frame_step(
+ dataset_spacing: datetime.timedelta, time_step: int, time_length: int
+) -> int:
+ if time_length == 1:
+ return 1
+
+ model_resolution_timedelta = datetime.timedelta(hours=time_step)
+ return model_resolution_timedelta // dataset_spacing
+
+
+def get_label_from_obs_context_hours(obs_context_hours):
+ """Map observation context window to a label index for conditioning."""
+ if isinstance(obs_context_hours, np.ndarray):
+ obs_interval = obs_context_hours.tolist()
+ else:
+ obs_interval = list(obs_context_hours)
+
+ if obs_interval in OBS_INTERVALS:
+ label = OBS_INTERVALS.index(obs_interval)
+ else:
+ label = 0
+ return label
+
+
+@dataclasses.dataclass
+class NullTransform:
+ """Placeholder transform."""
+
+ dataset: Literal["era5", "ufs"] = "era5"
+ variable_config: VariableConfig = _default_config
+
+
+def get_sensors_for_config(config: ObsConfig):
+ """Return list of sensor names enabled by the ObsConfig."""
+ sensors = ["atms", "mhs", "amsua", "amsub"]
+ if config.use_infrared:
+ sensors.append("iasi")
+
+ if config.use_conv:
+ sensors.append("conv")
+ return sensors
+
+
+def _get_ufs_obs_loaders(
+ obs_config: ObsConfig,
+):
+ if obs_config.innovation_type != "none":
+ raise ValueError(
+ f"innovation_type must be 'none' for UFS obs loaders, "
+ f"got '{obs_config.innovation_type}'"
+ )
+
+ return [
+ UFSUnifiedLoader(
+ config.UFS_OBS_PATH,
+ sensors=get_sensors_for_config(obs_config),
+ obs_context_hours=(obs_config.context_start, obs_config.context_end),
+ normalization="zscore",
+ filesystem_type="s3"
+ if config.UFS_OBS_PATH.startswith("s3://")
+ else "local",
+ remote_name=config.UFS_OBS_PROFILE,
+ drop_obs_channel_ids=obs_config.drop_obs_channel_ids,
+ conv_uv_in_situ_only=obs_config.conv_uv_in_situ_only,
+ conv_gps_level1_only=obs_config.conv_gps_level1_only,
+ )
+ ]
+
+
+def _get_splits(
+ dataset: str,
+ obs_config: Optional[ObsConfig] = None,
+ start_year: Optional[int] = None,
+):
+ metadata = DATASET_METADATA[dataset]
+ valid_times = pd.date_range(metadata.start, metadata.end, freq=metadata.freq)
+
+ if obs_config is not None and obs_config.use_obs:
+ obs_metadata = DATASET_METADATA["ufs_obs"]
+
+ if obs_config.innovation_type != "none":
+ # ufs obs anl files are missing for these dates
+ dropouts = [
+ ("2018-12-19", "2020-07-10"),
+ ("2022-05-05", "2022-10-01"),
+ ]
+ else:
+ dropouts = []
+
+ aligned_times = get_chunk_aligned_times(
+ base_metadata=metadata,
+ obs_metadata=obs_metadata,
+ dropouts=dropouts,
+ chunk_size=24,
+ )
+
+ valid_times = aligned_times
+
+ train_times = valid_times[valid_times.year < 2022]
+ test_times = valid_times[valid_times.year >= 2022]
+
+ if start_year is not None:
+ train_times = train_times[train_times.year >= start_year]
+
+ return {"train": train_times, "test": test_times, "": valid_times}
+
+
+def get_dataset(
+ *,
+ dataset: Literal["era5", "ufs"] = "era5",
+ split: str = "",
+ transform=None,
+ variable_config=None,
+ rank: int = 0,
+ world_size: int = 1,
+ model_rank: int = 0,
+ model_world_size: int = 1,
+ obs_config: Optional[ObsConfig] = None,
+ infinite: bool = False,
+ shuffle: bool = True,
+ chunk_size: int = 8,
+ time_step: int = 1, # in hours
+ time_length: int = 1,
+ window_stride: int = 1,
+ map_style: bool = False,
+ batch_transform=None,
+ start_year: Optional[int] = None,
+) -> torch.utils.data.Dataset:
+ """Build dataset for DA training or inference"""
+ variable_config = variable_config or VARIABLE_CONFIGS[dataset]
+
+ obs_input = obs_config is not None and obs_config.use_obs
+
+ loaders = []
+
+ if dataset == "era5":
+ loaders.append(ERA5Loader(variable_config))
+ elif dataset == "ufs":
+ raise ValueError(
+ "Training with ufs analysis as a target is no longer supported."
+ )
+
+ if obs_input:
+ loaders.extend(_get_ufs_obs_loaders(obs_config))
+
+ # if transform.background_source == "da":
+ # loaders.append(
+ # PthFileDataset(
+ # config.ERA5_DA_BACKGROUND_PATH,
+ # )
+ # )
+ # elif transform.background_source is not None:
+ # raise ValueError(f"Invalid background source: {transform.background_source}")
+
+ # Get the appropriate loaders for the dataset
+ times = _get_splits(dataset, obs_config, start_year=start_year)[split]
+ if times.size == 0:
+ raise RuntimeError("No times are selected.")
+
+ # Compute frame step
+ dataset_key = "ufs_obs" if obs_input else dataset
+ metadata = DATASET_METADATA[dataset_key]
+ meta_time_step = metadata.time_step
+ dataset_spacing = metadata.time_unit.to_timedelta(meta_time_step)
+ frame_step = _compute_frame_step(dataset_spacing, time_step, time_length)
+
+ # Force map_style for multi-frame validation/inference
+ map_style = map_style or (time_length > 1 and split != "train")
+
+ # Create and return the dataset
+ if map_style:
+ # Used for video validation/inference
+ ds = TimeMergedMapStyle(
+ times,
+ time_loaders=loaders,
+ frame_step=frame_step,
+ time_length=time_length,
+ cache_chunk_size=chunk_size,
+ batch_transform=batch_transform,
+ transform=transform,
+ model_rank=model_rank,
+ model_world_size=model_world_size,
+ )
+ else:
+ ds = TimeMergedDataset(
+ times,
+ time_loaders=loaders,
+ # transform=transform.transform,
+ transform=transform,
+ rank=rank,
+ world_size=world_size,
+ infinite=infinite,
+ shuffle=shuffle,
+ chunk_size=chunk_size,
+ frame_step=frame_step,
+ time_length=time_length,
+ window_stride=window_stride,
+ )
+
+ ds.batch_info = get_batch_info(
+ # config=transform.variable_config,
+ config=variable_config,
+ time_step=time_step,
+ time_unit=TimeUnit.HOUR,
+ # background_source=transform.background_source,
+ )
+ ds.calendar = "standard"
+ ds.time_units = "seconds since 1970-1-1 0:0:0"
+ return ds
diff --git a/examples/weather/healda/datasets/datetime_utils.py b/examples/weather/healda/datasets/datetime_utils.py
new file mode 100644
index 0000000000..522316341b
--- /dev/null
+++ b/examples/weather/healda/datasets/datetime_utils.py
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import datetime
+
+import cftime
+import numpy as np
+import pandas as pd
+
+
+def as_pydatetime(time) -> datetime.datetime:
+ if isinstance(time, cftime.datetime):
+ # very important to set the timezone to UTC for example when using timestamps
+ return datetime.datetime(*cftime.to_tuple(time), tzinfo=datetime.timezone.utc)
+ elif isinstance(time, datetime.datetime):
+ return time
+ else:
+ raise NotImplementedError(type(time))
+
+
+def as_numpy(time) -> np.ndarray:
+ # Standardize time to np.ndarray of np.datetime64
+ if hasattr(time, "values"): # Handle pandas Index
+ time = time.values
+ elif isinstance(time, (pd.Timestamp, datetime.datetime)):
+ time = np.array([np.datetime64(time)])
+ elif isinstance(time, cftime.datetime):
+ return as_numpy(as_pydatetime(time))
+ elif isinstance(time, np.datetime64):
+ time = np.array([time])
+ else:
+ time = np.array([np.datetime64(t) for t in time])
+ return time
+
+
+def as_timestamp(time) -> np.ndarray:
+ """return time as an int unix timestamp"""
+ return as_numpy(time).astype("datetime64[s]").astype(int)
+
+
+def second_of_day(time):
+ begin_of_day = time.replace(hour=0, second=0, minute=0)
+ return (time - begin_of_day).total_seconds()
+
+
+def as_cftime(timestamp):
+ return cftime.DatetimeGregorian(
+ timestamp.year,
+ timestamp.month,
+ timestamp.day,
+ timestamp.hour,
+ timestamp.minute,
+ timestamp.second,
+ )
diff --git a/examples/weather/healda/datasets/era5_13_levels_stats.csv b/examples/weather/healda/datasets/era5_13_levels_stats.csv
new file mode 100644
index 0000000000..d57746565e
--- /dev/null
+++ b/examples/weather/healda/datasets/era5_13_levels_stats.csv
@@ -0,0 +1,77 @@
+variable,level,std,mean
+U,1000,6.0761532856545495,-0.4012036280949351
+U,925,7.8438841318137085,0.1875586020534487
+U,850,8.117545798838897,1.0676392264524672
+U,700,9.161262910521476,3.1993547283822656
+U,600,10.351920350301654,4.753506500293617
+U,500,12.020863918100371,6.643045555863143
+U,400,14.37786455089647,9.166420770646372
+U,300,17.368910160479544,12.604205075841238
+U,250,18.63126418939197,14.458098693848147
+U,200,18.826071570402007,15.603421939654016
+U,150,17.338743800703853,14.694208759254527
+U,100,14.207755780676433,10.084541123485184
+U,50,14.828288062604061,3.447298300210448
+V,1000,5.0562302721142185,0.1881812905956076
+V,925,6.087634459436962,0.18351119177675437
+V,850,5.788139666651214,0.08865528649740444
+V,700,6.326620906882278,-0.01624326876106696
+V,600,7.18152443362136,-0.046446792220198006
+V,500,8.45848562350414,-0.032289933410975316
+V,400,10.38343861775161,-0.02010503440455113
+V,300,12.609499025107015,-0.025803734238538843
+V,250,13.057155983469311,-0.04324580333432856
+V,200,12.04089626815829,-0.07330144652625546
+V,150,9.733874174308893,-0.06046714286886432
+V,100,7.052043591267376,0.016139828470052277
+V,50,5.626839932669169,0.000506756767366129
+T,1000,13.3926808838397,288.34811963631944
+T,925,12.75887207878649,284.14180529926887
+T,850,12.337912690900826,281.1236718383152
+T,700,11.485912313471955,273.6522781145025
+T,600,10.914242703518575,266.8225997851199
+T,500,10.947319271596404,258.48344241794354
+T,400,10.889317302118906,247.5078738667765
+T,300,9.38121872205424,233.2984452545507
+T,250,7.28726013236191,225.6329557585312
+T,200,5.29271246649905,218.53493343877926
+T,150,7.447585393440247,211.46310701251957
+T,100,11.427625612914374,204.83933151080333
+T,50,7.494050550499991,211.44287440973292
+Z,1000,893.4657098478391,935.5049140418249
+Z,925,1008.7658942548571,7360.676041141024
+Z,850,1197.2156800245555,14248.78134896371
+Z,700,1736.7489061933863,29767.144398835746
+Z,600,2188.213530821762,41747.04921153613
+Z,500,2729.9217253050874,55509.29599258445
+Z,400,3402.601245989482,71726.5567290568
+Z,300,4213.390725518206,91575.34112031905
+Z,250,4602.533249864132,103579.02469155286
+Z,200,4831.571016444258,117791.77588622355
+Z,150,4711.8643515216945,135539.4191093247
+Z,100,4105.159434190763,159702.46332065153
+Z,50,3851.9190680022352,200924.18906200194
+Q,1000,0.005781177709317301,0.00936470198136523
+Q,925,0.004981825818460944,0.008008882445730241
+Q,850,0.004179579792520903,0.006031608541363992
+Q,700,0.002737859580661181,0.0031682692015671905
+Q,600,0.0019479582011688481,0.0019997242562778714
+Q,500,0.0012087878889463929,0.0011044648648525284
+Q,400,0.0005684813340060475,0.0005001939892631466
+Q,300,0.00018788822542209214,0.00016860593486309685
+Q,250,8.253217825122686e-05,7.728339864456503e-05
+Q,200,2.476106074180645e-05,2.5977126545155267e-05
+Q,150,4.08396736011409e-06,6.446599195006216e-06
+Q,100,6.154373024833254e-07,2.6868721151493317e-06
+Q,50,2.5960505606920645e-07,2.6752383145132987e-06
+tcwv,-1,16.707112756025314,24.224098621015028
+tas,-1,15.355112655773071,287.40642670567786
+uas,-1,5.436672873915895,-0.3751276131251324
+vas,-1,4.491587780564064,0.18441198768158903
+100u,-1,6.684378709883516,-0.36058662354821425
+100v,-1,5.613643415909631,0.1893235299805761
+pres_msl,-1,1109.2809461275167,101138.98195665042
+sst,-1,8.851771853018453,290.8944586515412
+sic,-1,0.18702627733432053,0.04226433445734063
+orog,-1.0,627.3885284872,232.56013904090733
+lfrac,-1.0,0.4695501683565522,0.3410480857539571
\ No newline at end of file
diff --git a/examples/weather/healda/datasets/etl/README.md b/examples/weather/healda/datasets/etl/README.md
new file mode 100644
index 0000000000..be05797af9
--- /dev/null
+++ b/examples/weather/healda/datasets/etl/README.md
@@ -0,0 +1,101 @@
+# Observation Data ETL
+
+Pipeline to prepare UFS observation data for training.
+
+## Full Pipeline
+
+```text
+1. pull_from_noaa_s3.sh Download raw NC4 from NOAA S3
+2. etl_unified.py Convert NC4 โ parquet
+3. compute_normalizations Compute stats from processed parquet
+4. Re-run etl_unified.py Regenerate channel_table with actual normalization stats
+```
+
+## Configuration
+
+Configure paths in your `.env` file (see main [README](../../README.md)):
+
+- `UFS_RAW_OBS_DIR` โ where raw NC4 files are downloaded
+- `UFS_OBS_PATH` โ where processed parquet files are stored
+
+## Step 1: Download
+
+`pull_from_noaa_s3.sh` downloads GSI diagnostic files from the public NOAA
+GEFS-v13 replay archive.
+
+```bash
+./pull_from_noaa_s3.sh
+```
+
+- Downloads to `UFS_RAW_OBS_DIR` from `.env`
+- Edit `YEARS`, `SENSORS`, `KIND` in the script as needed
+- Requires [`s5cmd`](https://github.com/peak/s5cmd#installation)
+
+## Step 2: Process
+
+`etl_unified.py` converts NC4 files to parquet with a unified schema.
+
+```bash
+python3 etl_unified.py --sensor amsua,conv,atms,amsub,mhs --num-workers 32
+```
+
+Defaults to `$UFS_RAW_OBS_DIR` as input and `$UFS_OBS_PATH` output dir using `.env`.
+
+## Normalization Stats
+
+Normalization stats (mean/std/min/max per channel) are stored in `etl/normalizations/`.
+
+**If CSVs are missing:** ETL defaults to mean=0, std=1 (with a warning).
+The observation parquet files are still valid โ only `channel_table.parquet`
+needs regeneration after computing proper stats. If using our pretrained
+checkpoint, use the provided CSVs instead of recomputing to prevent any differences.
+
+**To recompute stats for new sensors:**
+
+1. Run ETL to produce parquet (stats will default to mean=0, std=1)
+1. Compute normalizations: `python3 compute_normalizations.py --sensors conv,amsua,atms,mhs,amsub`
+1. Regenerate channel table: `python3 etl_unified.py --channel-table-only`
+
+## Output Structure
+
+```text
+processed_obs_v7_ges/
+โโโ channel_table.parquet # Channel metadata (IDs, normalization stats)
+โโโ amsua/
+โ โโโ 20220101/0.parquet
+โโโ conv/
+โ โโโ 20220101/0.parquet
+โโโ ...
+```
+
+## Schema
+
+Defined in `combined_schema.py`. One row per observation.
+
+**Common fields** (all sensors):
+
+- `Latitude`, `Longitude` โ observation location
+- `Absolute_Obs_Time` โ timestamp (ns precision)
+- `DA_window` โ 3-hourly assimilation window (used for row grouping)
+- `Global_Channel_ID` โ unique ID across all sensors/channels
+- `Platform_ID` โ satellite platform or conventional type
+- `Observation` โ the measurement value
+
+**Satellite-specific** (nullable for conv):
+
+- `Sat_Zenith_Angle`, `Sol_Zenith_Angle`, `Scan_Angle`
+
+**Conventional-specific** (nullable for satellite):
+
+- `Pressure`, `Height`, `Observation_Type`
+
+**Channel table** (`channel_table.parquet`):
+
+- `Global_Channel_ID` โ joins to observation data
+- `min_valid`, `max_valid` โ valid range for QC
+- `mean`, `stddev` โ normalization statistics
+- `is_conv` โ conventional vs satellite flag
+
+Conventional observations with multiple components (e.g., uv winds) are
+flattened into separate rows sharing the same location/time but different
+`Global_Channel_ID`.
diff --git a/examples/weather/healda/datasets/etl/__init__.py b/examples/weather/healda/datasets/etl/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/weather/healda/datasets/etl/combined_schema.py b/examples/weather/healda/datasets/etl/combined_schema.py
new file mode 100644
index 0000000000..021854edc8
--- /dev/null
+++ b/examples/weather/healda/datasets/etl/combined_schema.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Combined PyArrow schema for data produced by etl_unified.py
+
+This schema handles both satellite observations (atms, mhs, amsua, etc.)
+and conventional observations (gps, ps, q, t, uv), as well as both 'ges'
+(guess) and 'anl' (analysis) data types with optional fields that are only present
+in certain contexts.
+
+Key design decision: Conventional observations (u, v, T, q) are flattened
+into a single 'Observation' column with multiple rows per location for
+multi-component observations.
+"""
+
+import pyarrow as pa
+
+GLOBAL_CHANNEL_ID = pa.field("Global_Channel_ID", pa.uint16(), nullable=False)
+SENSOR_ID = pa.field("sensor_id", pa.uint16())
+
+
+def get_combined_observation_schema() -> pa.Schema:
+ """
+ Create a combined PyArrow schema for both satellite and conventional observations.
+
+ This schema accommodates:
+ 1. Common fields present in both data types
+ 2. Satellite-specific fields (angles, channel info)
+ 3. Conventional-specific fields (pressure, height, observation types)
+ 4. Flattened observation structure (all obs types in single Observation column)
+ 5. Optional analysis fields (present only in 'anl' data)
+
+ Note: Conventional observations (u, v, T, q) are flattened into a single
+ 'Observation' column with multiple rows per location for multi-component obs.
+
+ Returns:
+ pa.Schema: Combined schema for all observation data
+ """
+
+ # Common fields present in both satellite and conventional data
+ common_fields = [
+ # Spatial and temporal information
+ pa.field("Latitude", pa.float32()),
+ pa.field("Longitude", pa.float32()),
+ pa.field(
+ "Absolute_Obs_Time", pa.timestamp("ns")
+ ), # nanosecond is excessively precise, but is valid from 1678 --2262, so good enough for our pruposes
+ pa.field("DA_window", pa.timestamp("ns")),
+ # Platform identification
+ pa.field("Platform_ID", pa.uint16()), # Maps to PLATFORM_NAME_TO_ID
+ # Observation data - flattened structure
+ pa.field("Observation", pa.float32()), # Main observation value (required)
+ GLOBAL_CHANNEL_ID,
+ ]
+
+ # Satellite-specific fields (from etl.py)
+ satellite_fields = [
+ # Angular information (satellite observations only)
+ pa.field("Sat_Zenith_Angle", pa.float32(), nullable=True),
+ pa.field("Sol_Zenith_Angle", pa.float32(), nullable=True),
+ pa.field("Scan_Angle", pa.float32(), nullable=True),
+ ]
+
+ # Conventional observation specific fields (from etl_conv.py)
+ conventional_fields = [
+ # Metadata fields
+ pa.field("Pressure", pa.float32(), nullable=True),
+ pa.field("Height", pa.float32(), nullable=True),
+ pa.field("Observation_Type", pa.uint16(), nullable=True),
+ ]
+
+ # Analysis fields (present only in 'anl' data)
+ analysis_fields = [
+ # Quality control and forecast differences
+ pa.field("QC_Flag", pa.int32(), nullable=True), # Satellite QC flag
+ pa.field(
+ "Analysis_Use_Flag", pa.int8(), nullable=True
+ ), # Conventional analysis flag
+ # Forecast differences (flattened to single column)
+ pa.field("Obs_Minus_Forecast_adjusted", pa.float32(), nullable=True),
+ pa.field("Obs_Minus_Forecast_unadjusted", pa.float32(), nullable=True),
+ ]
+
+ # Combine all fields
+ all_fields = (
+ common_fields + satellite_fields + conventional_fields + analysis_fields
+ )
+
+ return pa.schema(all_fields)
+
+
+def get_channel_table_schema():
+ """Return PyArrow schema for channel metadata table."""
+ return pa.schema(
+ [
+ GLOBAL_CHANNEL_ID,
+ pa.field("min_valid", pa.float32()),
+ pa.field("max_valid", pa.float32()),
+ SENSOR_ID,
+ pa.field("is_conv", pa.bool_()),
+ pa.field("name", pa.string()),
+ pa.field("mean", pa.float32()),
+ pa.field("stddev", pa.float32()),
+ ]
+ )
diff --git a/examples/weather/healda/datasets/etl/compute_normalizations.py b/examples/weather/healda/datasets/etl/compute_normalizations.py
new file mode 100644
index 0000000000..675a6e8e0e
--- /dev/null
+++ b/examples/weather/healda/datasets/etl/compute_normalizations.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S608 # SQL built from internal config, not user input
+"""
+Compute normalization statistics for observation data.
+
+Uses channel_table.parquet for min/max valid ranges per channel,
+ensuring consistent filtering with training code.
+"""
+
+import argparse
+import os
+import time
+
+import duckdb
+from dotenv import load_dotenv
+
+from datasets.sensors import CONV_GPS_GLOBAL_IDS, SENSOR_OFFSET, QCLimits
+
+load_dotenv()
+
+DEFAULT_SENSORS = ["conv", "mhs", "amsua", "atms", "amsub"]
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Compute normalizations for UFS observation data (satellite and conventional)"
+ )
+ parser.add_argument(
+ "--sensors",
+ type=str,
+ nargs="+",
+ default=DEFAULT_SENSORS,
+ help=f"Sensors to process (default: {DEFAULT_SENSORS})",
+ )
+ parser.add_argument(
+ "--data-root",
+ type=str,
+ default=os.getenv("UFS_OBS_PATH"),
+ help="Root directory with processed obs (default: $UFS_OBS_PATH)",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default=None,
+ help="Output directory (default: etl/normalizations/)",
+ )
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ if not args.data_root:
+ raise ValueError("--data-root required (or set UFS_OBS_PATH in .env)")
+
+ sensors = args.sensors
+ data_root = args.data_root
+ # Default: store normalizations in code directory (etl/normalizations/)
+ output_dir = args.output_dir or os.path.join(
+ os.path.dirname(__file__), "normalizations"
+ )
+ channel_table = os.path.join(data_root, "channel_table.parquet")
+
+ print(f"Processing sensors: {sensors}")
+ print(f"Data root: {data_root}")
+ print(f"Channel table: {channel_table}")
+ print(f"Output directory: {output_dir}")
+
+ if not os.path.exists(channel_table):
+ raise FileNotFoundError(
+ f"Channel table not found: {channel_table}. Please run etl_unified.py first."
+ )
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ conn = duckdb.connect()
+ conn.execute("PRAGMA threads=32;")
+ conn.execute("SET preserve_insertion_order=false;")
+
+ # Pre-load small channel_table into memory (only ~330 rows)
+ conn.execute(
+ f"CREATE TABLE channels AS SELECT Global_Channel_ID, min_valid, max_valid FROM read_parquet('{channel_table}')"
+ )
+
+ for sensor in sensors:
+ # Use glob pattern directly - DuckDB handles this efficiently
+ parquet_glob = os.path.join(data_root, sensor, "*", "*.parquet")
+ csv_path = os.path.join(output_dir, f"{sensor}_normalizations.csv")
+
+ # Check if sensor directory exists
+ sensor_dir = os.path.join(data_root, sensor)
+ if not os.path.exists(sensor_dir):
+ print(f"\nSkipping {sensor}: directory not found")
+ continue
+
+ print(f"\nProcessing {sensor}...")
+
+ start = time.time()
+
+ # Conv needs additional height/pressure filtering (from QCLimits)
+ # GPS channels use 0.5 hPa min pressure, others use 200 hPa
+ if sensor == "conv":
+ gps_ids = ", ".join(str(x) for x in CONV_GPS_GLOBAL_IDS)
+ extra_where = f"""
+ AND o.Height BETWEEN {QCLimits.HEIGHT_MIN} AND {QCLimits.HEIGHT_MAX}
+ AND o.Pressure <= {QCLimits.PRESSURE_MAX}
+ AND o.Pressure >= CASE
+ WHEN o.Global_Channel_ID IN ({gps_ids}) THEN {QCLimits.PRESSURE_MIN_GPS}
+ ELSE {QCLimits.PRESSURE_MIN_DEFAULT}
+ END
+ """
+ else:
+ extra_where = ""
+
+ # Raw_Channel_ID is 1-indexed
+ sensor_offset = SENSOR_OFFSET[sensor]
+
+ # Stream directly using DuckDB glob - efficient for many files
+ sql = f"""
+ COPY (
+ -- per-platform stats
+ SELECT
+ o.Global_Channel_ID - {sensor_offset} + 1 AS Raw_Channel_ID,
+ o.Platform_ID,
+ STDDEV(o.Observation) AS obs_std,
+ AVG(o.Observation) AS obs_mean,
+ MIN(o.Observation) AS obs_min,
+ MAX(o.Observation) AS obs_max
+ FROM read_parquet('{parquet_glob}') o
+ JOIN channels c ON o.Global_Channel_ID = c.Global_Channel_ID
+ WHERE o.Observation BETWEEN c.min_valid AND c.max_valid
+ AND o.Observation IS NOT NULL
+ {extra_where}
+ GROUP BY o.Global_Channel_ID, o.Platform_ID
+
+ UNION ALL
+
+ -- overall stats (Platform_ID = -1)
+ SELECT
+ o.Global_Channel_ID - {sensor_offset} + 1 AS Raw_Channel_ID,
+ -1 AS Platform_ID,
+ STDDEV(o.Observation) AS obs_std,
+ AVG(o.Observation) AS obs_mean,
+ MIN(o.Observation) AS obs_min,
+ MAX(o.Observation) AS obs_max
+ FROM read_parquet('{parquet_glob}') o
+ JOIN channels c ON o.Global_Channel_ID = c.Global_Channel_ID
+ WHERE o.Observation BETWEEN c.min_valid AND c.max_valid
+ AND o.Observation IS NOT NULL
+ {extra_where}
+ GROUP BY o.Global_Channel_ID
+
+ ORDER BY Raw_Channel_ID, Platform_ID
+ )
+ TO '{csv_path}'
+ (HEADER TRUE, DELIMITER ',');
+ """
+
+ try:
+ conn.execute(sql)
+ print(f" Wrote to {csv_path} ({time.time() - start:.1f}s)")
+ except Exception as e:
+ print(f" Error: {e}")
+
+ conn.close()
+ print("\nDone!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/weather/healda/datasets/etl/etl_unified.py b/examples/weather/healda/datasets/etl/etl_unified.py
new file mode 100644
index 0000000000..98c62d24ca
--- /dev/null
+++ b/examples/weather/healda/datasets/etl/etl_unified.py
@@ -0,0 +1,693 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+import random
+import re
+import tempfile
+from collections import defaultdict
+from concurrent.futures import ProcessPoolExecutor
+from typing import Dict, List, Literal, Optional, Tuple
+
+import h5py
+import joblib
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pyarrow.compute as pc
+import pyarrow.parquet as pq
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from datasets.etl.combined_schema import (
+ get_channel_table_schema,
+ get_combined_observation_schema,
+)
+from datasets.sensors import (
+ CONV_CHANNEL_MAP,
+ CONV_CHANNELS,
+ CONV_PLATFORMS,
+ PLATFORM_NAME_TO_ID,
+ SENSOR_CONFIGS,
+ get_global_channel_id,
+)
+
+load_dotenv()
+
+memory = joblib.Memory(".cache")
+
+TEST = False
+
+# Set single-threaded Arrow for better parallelization
+os.environ["ARROW_NUM_THREADS"] = "1"
+
+
+# Satellite sensor columns
+SATELLITE_COLUMNS = [
+ "Latitude",
+ "Longitude",
+ "Observation",
+ "Channel_Index",
+ "Obs_Time",
+ "Sat_Zenith_Angle",
+ "Sol_Zenith_Angle",
+ "Scan_Angle",
+]
+
+# Conventional sensor metadata columns
+CONV_METADATA_COLUMNS = [
+ "Latitude",
+ "Longitude",
+ "Time",
+ "Pressure",
+ "Height",
+ "Observation_Type",
+]
+
+# Analysis columns (common to both)
+ANALYSIS_COLUMNS = [
+ "Obs_Minus_Forecast_adjusted",
+ "Obs_Minus_Forecast_unadjusted",
+ "QC_Flag", # For satellite
+ "Analysis_Use_Flag", # For conventional
+]
+
+
+def _get_conv_obs_columns_for_platform(platform: str) -> list[str]:
+ return [ch.nc_column for ch in CONV_CHANNELS if ch.platform == platform]
+
+
+@memory.cache
+def list_nc_files(path):
+ """List all .nc4 files in the given directory tree."""
+ return [
+ os.path.relpath(os.path.join(root, fname), path)
+ for root, _, fnames in os.walk(path)
+ for fname in fnames
+ if fname.endswith(".nc4")
+ ]
+
+
+def get_channel_table():
+ """Build channel metadata table for all sensors."""
+ nchan = [cfg.channels for cfg in SENSOR_CONFIGS.values()]
+ sensor_id = np.arange(len(nchan)).repeat(nchan)
+ id = np.arange(sensor_id.size).astype(np.uint16)
+
+ conv_names = [c.name for c in CONV_CHANNELS]
+ conv_min = [c.min_valid for c in CONV_CHANNELS]
+ conv_max = [c.max_valid for c in CONV_CHANNELS]
+
+ min_valid_list = []
+ max_valid_list = []
+ names = []
+ means_list = []
+ stds_list = []
+ for name, cfg in SENSOR_CONFIGS.items():
+ if name == "conv":
+ min_valid_list.extend(conv_min)
+ max_valid_list.extend(conv_max)
+ names.extend(conv_names)
+ else:
+ min_valid_list.extend([cfg.min_valid] * cfg.channels)
+ max_valid_list.extend([cfg.max_valid] * cfg.channels)
+ names.extend([f"{name}_{i:03d}" for i in range(cfg.channels)])
+
+ means_list.extend(cfg.means)
+ stds_list.extend(cfg.stds)
+
+ min_valid = np.array(min_valid_list, dtype=np.float32)
+ max_valid = np.array(max_valid_list, dtype=np.float32)
+ means = np.array(means_list, dtype=np.float32)
+ stds = np.array(stds_list, dtype=np.float32)
+
+ is_conv = sensor_id == list(SENSOR_CONFIGS).index("conv")
+
+ return pa.table(
+ [id, min_valid, max_valid, sensor_id, is_conv, names, means, stds],
+ schema=get_channel_table_schema(),
+ )
+
+
+def parse_args():
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Unified ETL script for processing UFS observation data (satellite and conventional)"
+ )
+ parser.add_argument(
+ "--sensor",
+ type=str,
+ default="all",
+ help="Sensor(s) to process: 'all' for all sensors, single sensor (e.g., 'atms'), or comma-separated list (e.g., 'conv,amsua,atms')",
+ )
+ parser.add_argument(
+ "--input-dir",
+ type=str,
+ default=os.getenv("UFS_RAW_OBS_DIR"),
+ help="Input directory containing raw observation files (default: $UFS_RAW_OBS_DIR from .env)",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default=os.getenv("UFS_OBS_PATH"),
+ help="Output directory for processed parquet data (default: $UFS_OBS_PATH from .env)",
+ )
+ parser.add_argument(
+ "--type",
+ type=str,
+ default="ges",
+ choices=["ges", "anl"],
+ help="Ges or anl data to process (default: ges)",
+ )
+ parser.add_argument(
+ "--num-workers",
+ type=int,
+ default=96,
+ help="Number of workers to use (default: 96)",
+ )
+ parser.add_argument(
+ "--channel-table-only",
+ action="store_true",
+ help="Only generate channel_table.parquet (skip processing observation files)",
+ )
+ return parser.parse_args()
+
+
+def is_satellite_sensor(sensor: str) -> bool:
+ """Check if sensor is a satellite sensor (not conventional)."""
+ return sensor not in CONV_PLATFORMS and sensor != "conv" and sensor != "all"
+
+
+def is_conventional_sensor(sensor: str) -> bool:
+ """Check if sensor is conventional."""
+ return sensor == "conv"
+
+
+def read_satellite_variables_from_file(
+ full_file_path: str, sensor: str, obs_type: Literal["ges", "anl"] = "ges"
+) -> Optional[pa.Table]:
+ """Read variables from a satellite sensor NetCDF file."""
+ try:
+ with h5py.File(full_file_path, "r") as ds:
+ data = {}
+ cols_to_read = SATELLITE_COLUMNS
+ if obs_type == "anl":
+ cols_to_read += ANALYSIS_COLUMNS[:2] # Exclude QC_Flag for now
+
+ for col in cols_to_read:
+ if col == "Channel_Index":
+ channel_idx_in_sensor_chan = ds[col][:].astype(np.uint16)
+ sensor_chan = ds["sensor_chan"][:]
+ data["Raw_Channel_ID"] = sensor_chan[
+ channel_idx_in_sensor_chan - 1
+ ].astype(np.uint16)
+ elif col == "QC_Flag":
+ data[col] = ds[col][:].astype(np.int32)
+ else:
+ data[col] = ds[col][:]
+
+ data["Global_Channel_ID"] = get_global_channel_id(
+ sensor, data["Raw_Channel_ID"]
+ )
+
+ n = len(data[SATELLITE_COLUMNS[0]])
+ filename = os.path.basename(full_file_path)
+
+ # Map to global platform idx
+ platform = filename.split("_")[2]
+ platform_id = PLATFORM_NAME_TO_ID[platform]
+ data["Platform_ID"] = np.full(n, platform_id).astype(np.uint8)
+
+ # Process time information
+ match = re.search(r"\.(\d{10})_", filename)
+ if match:
+ date_str = match.group(1)
+ NS_3H = np.int64(3) * 3_600_000_000_000
+ base_time = pd.to_datetime(date_str, format="%Y%m%d%H")
+ base_ns = base_time.value
+ hours = data["Obs_Time"].astype(np.float64)
+ hours_to_ns = np.rint(hours * 3600.0 * 1e9).astype(np.int64)
+
+ # clip to be > -3h and <= 3h
+ hours_to_ns = np.clip(hours_to_ns, -NS_3H + 1, NS_3H)
+ abs_time_ns = base_ns + hours_to_ns
+ abs_time = abs_time_ns.astype("datetime64[ns]")
+ data["Absolute_Obs_Time"] = abs_time
+
+ # assign exactly two labels: t (left half) or t+3h (right half)
+ da_ns = np.where(abs_time_ns <= base_ns, base_ns, base_ns + NS_3H)
+ data["DA_window"] = da_ns.astype("datetime64[ns]")
+ else:
+ raise RuntimeError(f"No date match found for {filename}")
+
+ del data["Obs_Time"]
+ output_schema = get_combined_observation_schema()
+ df = pd.DataFrame(data)
+ for field in output_schema:
+ if field.name not in df.columns:
+ df[field.name] = pd.NA
+ return pa.table(df, schema=output_schema)
+
+ except Exception as e:
+ print(f"Error reading satellite file {full_file_path}: {e}")
+ return None
+
+
+def read_conventional_variables_from_file(
+ full_file_path: str, obs_type: Literal["ges", "anl"] = "ges"
+) -> Optional[pa.Table]:
+ """Read variables from a conventional sensor NetCDF file."""
+ try:
+ with h5py.File(full_file_path, "r") as ds:
+ data = {}
+
+ # Read common metadata columns
+ for col in CONV_METADATA_COLUMNS:
+ if col == "Observation_Type":
+ data["Observation_Type"] = ds[col][:].astype(np.uint16)
+ elif col == "Time":
+ pass # Handle separately
+ else:
+ data[col] = ds[col][:]
+
+ time = ds["Time"][:]
+ n = len(data[CONV_METADATA_COLUMNS[0]])
+ filename = os.path.basename(full_file_path)
+
+ # Extract platform from filename
+ platform = filename.split("_")[2]
+ platform_id = PLATFORM_NAME_TO_ID[platform]
+ data["Platform_ID"] = np.full(n, platform_id).astype(np.uint8)
+
+ # Get observation columns based on platform
+ observation_columns = _get_conv_obs_columns_for_platform(platform)
+
+ # Handle analysis columns
+ if obs_type == "anl":
+ data["Analysis_Use_Flag"] = ds["Analysis_Use_Flag"][:].astype(np.int8)
+ if platform != "uv":
+ data["Obs_Minus_Forecast_adjusted"] = ds[
+ "Obs_Minus_Forecast_adjusted"
+ ][:]
+ data["Obs_Minus_Forecast_unadjusted"] = ds[
+ "Obs_Minus_Forecast_unadjusted"
+ ][:]
+ else:
+ data["v_Obs_Minus_Forecast_adjusted"] = ds[
+ "v_Obs_Minus_Forecast_adjusted"
+ ][:]
+ data["v_Obs_Minus_Forecast_unadjusted"] = ds[
+ "v_Obs_Minus_Forecast_unadjusted"
+ ][:]
+ data["u_Obs_Minus_Forecast_adjusted"] = ds[
+ "u_Obs_Minus_Forecast_adjusted"
+ ][:]
+ data["u_Obs_Minus_Forecast_unadjusted"] = ds[
+ "u_Obs_Minus_Forecast_unadjusted"
+ ][:]
+
+ # Create absolute observation time with 3-hourly DA window splitting
+ match = re.search(r"\.(\d{10})_", filename)
+ if match:
+ date_str = match.group(1)
+ NS_3H = np.int64(3) * 3_600_000_000_000
+ base_time = pd.to_datetime(date_str, format="%Y%m%d%H")
+ base_ns = base_time.value
+ hours = time.astype(np.float64)
+ hours_to_ns = np.rint(hours * 3600.0 * 1e9).astype(np.int64)
+
+ # clip to be > -3h and <= 3h
+ hours_to_ns = np.clip(hours_to_ns, -NS_3H + 1, NS_3H)
+ abs_time_ns = base_ns + hours_to_ns
+ abs_time = abs_time_ns.astype("datetime64[ns]")
+ data["Absolute_Obs_Time"] = abs_time
+
+ # assign exactly two labels: t (left half) or t+3h (right half)
+ da_ns = np.where(abs_time_ns <= base_ns, base_ns, base_ns + NS_3H)
+ data["DA_window"] = da_ns.astype("datetime64[ns]")
+ else:
+ raise RuntimeError(f"No date match found for {filename}")
+
+ # Flatten the data to have a single Observation column
+ meta_df = pd.DataFrame(data)
+ dfs = []
+ for k, column in enumerate(observation_columns):
+ raw_channel_id = k + CONV_CHANNEL_MAP[platform]
+
+ this_df = meta_df.assign(
+ Observation=ds[column][:],
+ Global_Channel_ID=get_global_channel_id("conv", raw_channel_id),
+ )
+ dfs.append(this_df)
+ df = pd.concat(dfs)
+
+ output_schema = get_combined_observation_schema()
+ for field in output_schema:
+ if field.name not in df.columns:
+ df[field.name] = pd.NA
+ return pa.table(df, schema=output_schema)
+
+ except Exception as e:
+ print(f"Error reading conventional file {full_file_path}: {e}")
+ return None
+
+
+def read_variables_from_file(
+ full_file_path: str, sensor: str, obs_type: Literal["ges", "anl"] = "ges"
+) -> Optional[pa.Table]:
+ """Unified function to read variables from NetCDF files."""
+ if is_satellite_sensor(sensor):
+ return read_satellite_variables_from_file(full_file_path, sensor, obs_type)
+ elif is_conventional_sensor(sensor):
+ return read_conventional_variables_from_file(full_file_path, obs_type)
+ else:
+ raise ValueError(f"Unknown sensor type: {sensor}")
+
+
+def extract_info_from_filename(
+ filename: str, obs_type: Literal["ges", "anl"] = "ges"
+) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]:
+ """Extract sensor, platform, type, and date from filename."""
+ if obs_type == "ges":
+ pattern = re.compile(r"diag_([\w-]+)_([\w-]+)_(ges)\.(\d{10})_")
+ elif obs_type == "anl":
+ pattern = re.compile(r"diag_([\w-]+)_([\w-]+)_(anl)\.(\d{10})_")
+ else:
+ raise ValueError(f"Invalid obs type: {obs_type}")
+
+ match = pattern.match(filename)
+ if match:
+ sensor_name, platform, file_type, full_date = match.groups()
+ day_date = full_date[:8] # YYYYMMDD
+ return sensor_name, platform, file_type, day_date, full_date
+ return None, None, None, None, None
+
+
+def parse_sensor_list(sensor_str: str) -> List[str]:
+ """Parse comma-separated sensor list and validate sensors."""
+ if sensor_str == "all":
+ return ["all"]
+
+ sensors = [s.strip() for s in sensor_str.split(",")]
+
+ # Validate each sensor
+ for sensor in sensors:
+ if sensor not in ["all", "conv"] and not is_satellite_sensor(sensor):
+ raise ValueError(
+ f"Invalid sensor: {sensor}. Must be 'all', 'conv', or a satellite sensor name."
+ )
+
+ return sensors
+
+
+def filter_files_by_sensor(
+ files: List[str], target_sensors: List[str], obs_type: Literal["ges", "anl"] = "ges"
+) -> Dict[str, List[str]]:
+ """Filter files based on target sensors."""
+ if "all" in target_sensors:
+ # Process all sensors
+ filtered_files = defaultdict(list)
+ seen_keys = set()
+
+ for file_path in files:
+ if "spinup/" in file_path or "overlap/" in file_path:
+ continue
+
+ filename = os.path.basename(file_path)
+ sensor_name, platform, file_type, day_date, full_date = (
+ extract_info_from_filename(filename, obs_type)
+ )
+
+ if not all([sensor_name, platform, day_date]):
+ continue
+
+ # Determine the sensor key for output directory
+ if sensor_name == "conv":
+ sensor_key = "conv"
+ else:
+ sensor_key = sensor_name
+
+ # Avoid duplicates
+ key = (full_date, sensor_name, platform)
+ if key in seen_keys:
+ continue
+ seen_keys.add(key)
+
+ filtered_files[sensor_key].append(file_path)
+
+ return dict(filtered_files)
+
+ else:
+ # Process specific sensors
+ filtered_files = defaultdict(list)
+ seen_keys = set()
+
+ for file_path in files:
+ if "spinup/" in file_path or "overlap/" in file_path:
+ continue
+
+ filename = os.path.basename(file_path)
+ sensor_name, platform, file_type, day_date, full_date = (
+ extract_info_from_filename(filename, obs_type)
+ )
+
+ if not all([sensor_name, platform, day_date]):
+ continue
+
+ # Check if this file matches any of our target sensors
+ file_matches = False
+ target_sensor_key = None
+
+ if sensor_name == "conv" and "conv" in target_sensors:
+ file_matches = True
+ target_sensor_key = "conv"
+ elif sensor_name in target_sensors:
+ file_matches = True
+ target_sensor_key = sensor_name
+
+ if not file_matches:
+ continue
+
+ # Avoid duplicates
+ key = (full_date, sensor_name, platform)
+ if key in seen_keys:
+ continue
+ seen_keys.add(key)
+
+ filtered_files[target_sensor_key].append(file_path)
+
+ return dict(filtered_files)
+
+
+def process_day(
+ full_file_paths: List[str],
+ output_path: str,
+ sensor: str,
+ obs_type: Literal["ges", "anl"] = "ges",
+):
+ """Process a single day of data for a sensor."""
+ if os.path.exists(output_path):
+ return
+
+ tables = []
+ for full_file_path in full_file_paths:
+ data = read_variables_from_file(full_file_path, sensor, obs_type)
+ if data is not None:
+ tables.append(data)
+
+ if len(tables) == 0:
+ return
+
+ table = pa.concat_tables(tables)
+
+ # Sort by DA_window for better compression and reading performance
+ sort_idx = pc.sort_indices(table, sort_keys=[("DA_window", "ascending")])
+ table = pc.take(table, sort_idx)
+
+ # Group by DA_window for chunked reading (each DA window in separate row group)
+ col = table.column("DA_window").combine_chunks()
+ vals = col.to_numpy() # numpy datetime64[ns]
+ change = np.empty(len(vals), dtype=bool)
+ change[0] = True
+ change[1:] = vals[1:] != vals[:-1]
+
+ starts = np.flatnonzero(change)
+ ends = np.concatenate([starts[1:], [len(vals)]])
+
+ # Write to temporary file first, then atomically move to final location
+ output_dir = os.path.dirname(output_path)
+ os.makedirs(output_dir, exist_ok=True)
+ tmp_path = None
+ try:
+ with tempfile.NamedTemporaryFile(
+ mode="wb",
+ dir=output_dir,
+ suffix=".tmp",
+ delete=False,
+ ) as tmp_file:
+ tmp_path = tmp_file.name
+
+ with pq.ParquetWriter(tmp_path, table.schema) as writer:
+ for start, end in zip(starts, ends):
+ slice_ = table.slice(start, end - start)
+ # Force Arrow to put each DA_window in its own row-group
+ writer.write_table(slice_, row_group_size=len(slice_))
+
+ # Atomically move the temporary file to the final location
+ os.rename(tmp_path, output_path)
+ tmp_path = None # Success, don't clean up
+
+ except Exception as e:
+ # Clean up temporary file if something went wrong
+ if tmp_path and os.path.exists(tmp_path):
+ try:
+ os.unlink(tmp_path)
+ except OSError:
+ pass # Ignore cleanup errors
+ raise e
+
+
+def main():
+ """Main processing function."""
+ args = parse_args()
+
+ # Parse sensor list
+ target_sensors = parse_sensor_list(args.sensor)
+ ufs_raw_obs_dir = args.input_dir
+ base_dir = args.output_dir
+ obs_type = args.type
+ output_base_dir = f"{base_dir}_{obs_type}"
+
+ print(f"Processing sensors: {', '.join(target_sensors)}")
+ print(f"Input directory: {ufs_raw_obs_dir}")
+ print(f"Output base directory: {output_base_dir}")
+ print(f"Processing {obs_type} data")
+
+ # Save channel table at root level of output directory
+ print("\nSaving channel table...")
+ channel_table = get_channel_table()
+ channel_table_path = os.path.join(output_base_dir, "channel_table.parquet")
+ os.makedirs(output_base_dir, exist_ok=True)
+ pq.write_table(channel_table, channel_table_path)
+ print(f"Channel table saved to: {channel_table_path}")
+
+ if args.channel_table_only:
+ print("--channel-table-only specified, skipping observation processing.")
+ return
+
+ # Read file list
+ files = list_nc_files(args.input_dir)
+ print(f"Total files detected: {len(files):,}")
+
+ # Filter files by sensors
+ sensor_files = filter_files_by_sensor(files, target_sensors, obs_type)
+
+ if not sensor_files:
+ print("No files found to process!")
+ return
+
+ # Report what we found
+ total_days = 0
+ for sensor_name, file_list in sensor_files.items():
+ # Group files by date
+ date_to_files = defaultdict(list)
+ for file_path in file_list:
+ filename = os.path.basename(file_path)
+ _, _, _, day_date, _ = extract_info_from_filename(filename, obs_type)
+ if day_date:
+ date_to_files[day_date].append(file_path)
+
+ print(f"\n{sensor_name.upper()} sensor: {len(date_to_files)} days")
+ total_days += len(date_to_files)
+
+ # Show sample dates
+ sample_dates = sorted(date_to_files.keys())[:5]
+ print(f" Sample dates: {sample_dates}")
+ if len(date_to_files) > 5:
+ print(f" ... and {len(date_to_files) - 5} more")
+
+ print(f"\nTotal processing jobs: {total_days}")
+
+ # Create output directories and prepare all jobs
+ all_jobs = []
+ job_metadata = []
+
+ for sensor_name, file_list in sensor_files.items():
+ # Create output directory for this sensor
+ sensor_output_dir = os.path.join(output_base_dir, sensor_name)
+
+ # Group files by date
+ date_to_files = defaultdict(list)
+ for file_path in file_list:
+ filename = os.path.basename(file_path)
+ _, _, _, day_date, _ = extract_info_from_filename(filename, obs_type)
+ if day_date:
+ date_to_files[day_date].append(file_path)
+
+ # Prepare processing jobs
+ for date, day_files in date_to_files.items():
+ # Pre-compute all paths
+ full_file_paths = [
+ os.path.join(ufs_raw_obs_dir, file_path) for file_path in day_files
+ ]
+ output_path = os.path.join(sensor_output_dir, f"{date}", "0.parquet")
+
+ job_args = (
+ full_file_paths,
+ output_path,
+ sensor_name,
+ obs_type,
+ )
+ all_jobs.append(job_args)
+ job_metadata.append((sensor_name, date))
+
+ # Process files using map
+ print(f"\nPrepared {len(all_jobs)} jobs")
+ print(f"Starting parallel processing with {args.num_workers} workers...")
+
+ random.shuffle(all_jobs)
+
+ with ProcessPoolExecutor(args.num_workers) as executor:
+ # Use map to process all jobs
+ list(
+ tqdm(
+ executor.map(process_day, *zip(*all_jobs)),
+ total=len(all_jobs),
+ desc="Processing",
+ )
+ )
+
+ # Count results (map doesn't raise exceptions, so we need to check differently)
+ completed = len(all_jobs) # All jobs completed (map waits for all)
+ failed = 0 # We can't easily detect failures with map, but all jobs completed
+
+ print("\nETL Complete!")
+ print(f"Successfully processed: {completed} sensor-days")
+ print(f"Failed: {failed} sensor-days")
+
+ # Report output directories
+ print("\nOutput directories created:")
+ for sensor_name in sensor_files.keys():
+ sensor_dir = os.path.join(output_base_dir, sensor_name)
+ if os.path.exists(sensor_dir):
+ file_count = len(
+ [f for f in os.listdir(sensor_dir) if f.endswith(".parquet")]
+ )
+ print(f" {sensor_dir}: {file_count} parquet files")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/weather/healda/datasets/etl/normalizations/amsua_normalizations.csv b/examples/weather/healda/datasets/etl/normalizations/amsua_normalizations.csv
new file mode 100644
index 0000000000..44ccce3ac0
--- /dev/null
+++ b/examples/weather/healda/datasets/etl/normalizations/amsua_normalizations.csv
@@ -0,0 +1,113 @@
+Raw_Channel_ID,Platform_ID,obs_std,obs_mean
+1,-1,43.14827526898869,208.81245889545207
+1,14,43.245863004474046,208.5147247714152
+1,15,42.644794686368726,207.84771330773998
+1,16,43.32608185643857,208.88801777990457
+1,20,42.862800259005354,209.69306267322926
+1,21,41.00672569472209,206.3581961906114
+1,23,43.10684451429107,208.63778656625053
+1,24,43.477805956939,208.56640086972172
+2,-1,44.96814209305951,200.67487673075226
+2,14,45.0688240991038,200.64044574997308
+2,15,44.65034395075911,199.57948659408837
+2,16,45.10653156173914,200.78904240241823
+2,20,44.54634966722693,200.9030014910355
+2,21,42.16959824168713,197.74683727240892
+2,23,44.961453342136245,200.72370211205973
+2,24,45.38303896246749,200.65617535973996
+3,-1,20.60997229307027,241.33242782399282
+3,14,20.638403211245485,240.9346938961322
+3,15,20.56083967162193,241.01957817359767
+3,16,20.756384664385838,241.36646570607004
+3,20,20.40978773273391,242.06647213486357
+3,21,18.549164473665893,239.64565679658295
+3,23,20.594839602248793,241.05065270256063
+3,24,20.776392520670253,241.31193090983035
+4,-1,12.391091090191496,255.40195807265977
+4,14,12.39558220414599,255.2485698626346
+4,15,12.441051959212107,255.61576880929422
+4,16,12.39107026991132,255.39039410065763
+4,20,12.457001058263584,255.74958936586165
+4,21,11.702991216116937,254.80897097129633
+4,23,12.367115477186875,255.28127542166982
+4,24,12.324249047160235,255.26146633367378
+5,-1,9.72289107163208,248.30325035989532
+5,14,9.712855426440534,248.12249857181723
+5,15,9.782916048429138,248.4087687449923
+5,16,9.773101672516752,248.26028480979204
+5,20,9.71372617268573,248.84428383985446
+5,21,9.270351647297238,248.21563820415776
+5,23,9.699518172278252,248.0950168016271
+5,24,9.705036110488924,247.99733787940366
+6,-1,6.591734627516939,233.2546738888034
+6,14,6.626838235834015,233.21555828000677
+6,15,6.715863262633315,233.8078009560835
+6,16,6.751256745920496,233.18885362349366
+6,20,6.2770916623924276,233.12018885091325
+6,21,6.055423123419333,233.83423502032096
+6,23,6.615882647299575,233.14683849524496
+6,24,6.603232459352038,233.36377680690163
+7,-1,5.200585036541759,223.37161080931776
+7,14,5.328507768691345,222.63488189090833
+7,15,5.290039398907846,223.40604859527315
+7,16,5.286683193657975,222.87131600917934
+7,20,5.065071419902494,224.32254865454104
+7,21,4.660127715690419,223.87643319144212
+7,23,5.171385847225267,222.70343722755194
+7,24,5.07430384498421,223.1480783475126
+8,-1,6.0071988697567456,216.48895482206066
+8,14,5.993478811669694,216.18038748856023
+8,15,5.992285017757789,216.32746437053305
+8,16,5.9158778196382995,216.28546186411424
+8,20,6.002563219758394,217.0717997186405
+8,21,6.320274200173209,217.0598985070878
+8,23,6.017509215427473,215.95926857080985
+8,24,5.297709489024461,215.76674689563418
+9,-1,8.509114940634637,211.2537007874646
+9,14,8.546910890556669,211.04022344033265
+9,15,8.485593502788452,211.30400409997011
+9,16,8.31323746741205,210.86774331842972
+9,20,8.532774635936512,211.7353577068666
+9,23,8.602809925010613,210.69029206493593
+9,24,8.490408071518324,211.0450224915925
+10,-1,8.008770826557447,214.5290848362116
+10,14,8.031084627218561,214.49865927153715
+10,15,7.98725303457022,214.67147267452094
+10,16,7.901115635400929,214.31155514650166
+10,20,8.053008659834964,215.03146673742023
+10,23,8.016193430235884,214.09543446059095
+10,24,7.944071272830509,214.45702080831023
+11,-1,8.572729274070603,220.667324154054
+11,14,8.612188988458303,220.78406698291334
+11,15,8.578335740928825,220.91224176979185
+11,16,8.606801252318974,220.51350307302872
+11,23,8.590054337884801,220.36146839247078
+11,24,8.489296811083372,220.762503852709
+12,-1,9.827017164339123,229.2238827001046
+12,14,9.848325150318507,229.15015194738777
+12,15,9.787099403535226,229.28036940555563
+12,16,9.802699078194221,228.54596812574724
+12,20,9.876530004901452,229.60541801900766
+12,23,9.857903078183186,228.98018510273
+12,24,9.733668818886466,229.31037852193862
+13,-1,10.657981916161267,239.62808380868842
+13,14,10.676786662443906,239.41935483813367
+13,15,10.57700750393631,239.59523381548578
+13,16,10.591162650433894,239.05696101001703
+13,20,10.702002553441044,240.14076712717042
+13,23,10.708136300859543,239.4459129578857
+13,24,10.59151936202361,239.66792244377427
+14,-1,10.460706383630459,249.80052228607403
+14,14,10.518120440742436,249.5965305240693
+14,15,10.355629661230186,249.94929470130802
+14,16,10.340302426577065,249.42994207978592
+14,23,10.525085837365992,249.78271346792957
+14,24,10.436851910807803,250.01848408363801
+15,-1,27.206282747597683,241.08475017755075
+15,14,27.38726738210949,241.0812384969291
+15,15,27.297859237142635,241.426373404627
+15,16,27.312180489599005,241.39647412995765
+15,20,26.740265157879765,240.9054814678835
+15,21,26.25128791327429,237.4066499607801
+15,23,27.290915782286874,240.861508938203
+15,24,27.384349423192475,241.38266784296735
diff --git a/examples/weather/healda/datasets/etl/normalizations/amsub_normalizations.csv b/examples/weather/healda/datasets/etl/normalizations/amsub_normalizations.csv
new file mode 100644
index 0000000000..07d1fe95b3
--- /dev/null
+++ b/examples/weather/healda/datasets/etl/normalizations/amsub_normalizations.csv
@@ -0,0 +1,16 @@
+Raw_Channel_ID,Platform_ID,obs_std,obs_mean
+1,-1,27.2737718832671,236.66623997760456
+1,21,27.025331796802135,236.49292906247072
+1,22,27.46773606448012,236.80319681448057
+2,-1,27.203280859273686,259.53019990435104
+2,21,27.22256458977998,259.7678475989869
+2,22,27.186577013708682,259.34254894910356
+3,-1,9.489860500316203,247.52103646621188
+3,21,9.638507561000745,247.71051502953418
+3,22,9.36779125986415,247.37103182370328
+4,-1,12.906392940834248,258.5867625015679
+4,21,12.922531248194963,259.3452157730607
+4,22,12.862125887997312,257.98840743403474
+5,-1,18.451601367383574,264.40828198308105
+5,21,18.290864177699667,263.3042943725722
+5,22,18.53145681557096,265.2848912188969
diff --git a/examples/weather/healda/datasets/etl/normalizations/atms_normalizations.csv b/examples/weather/healda/datasets/etl/normalizations/atms_normalizations.csv
new file mode 100644
index 0000000000..e1b35d0f4d
--- /dev/null
+++ b/examples/weather/healda/datasets/etl/normalizations/atms_normalizations.csv
@@ -0,0 +1,67 @@
+Raw_Channel_ID,Platform_ID,obs_std,obs_mean
+1,-1,43.09279016587131,210.02112362010823
+1,25,43.22600229860692,209.8786665026229
+1,26,43.02804318886837,210.09002498629346
+2,-1,44.86846399304575,201.07331959768962
+2,25,45.01897932416165,200.9561896404016
+2,26,44.79537395581081,201.12997111896166
+3,-1,20.91136978667041,240.9698729566799
+3,25,21.021514548376974,240.87188064377247
+3,26,20.857843018367628,241.0171624463704
+4,-1,15.539263168345263,251.10140095634452
+4,25,15.666611148345122,250.9308978682481
+4,26,15.476750931187427,251.18369532111635
+5,-1,11.871982173921344,255.5391212721772
+5,25,11.960523646605074,255.4958955549361
+5,26,11.828949044784242,255.55998657237063
+6,-1,9.684743648063607,248.12555462660075
+6,25,9.764115534036241,248.04013477579554
+6,26,9.645898474886815,248.16681639990207
+7,-1,6.784248199795853,233.28375341081514
+7,25,6.856615715379073,232.9826081385835
+7,26,6.744210848770702,233.4291778163615
+8,-1,5.345625849992116,222.9450280954434
+8,25,5.436060756468369,222.7293455903235
+8,26,5.298259131893759,223.0491825905517
+9,-1,5.991791106218113,215.77753474885782
+9,25,6.008561080213766,215.47870864850285
+9,26,5.97833237796406,215.92181916545474
+10,-1,8.255314915635376,211.32904065225617
+10,25,8.188014023078008,210.92054661357824
+10,26,8.280406858326261,211.52630443706133
+11,-1,7.73402919628314,214.83510628159783
+11,25,7.7618820734836405,214.45412105711267
+11,26,7.713808772103638,215.01908359278923
+12,-1,8.383691337628381,221.3587919894676
+12,25,8.511318845815346,220.9317387710988
+12,26,8.313507519747509,221.56501849554937
+13,-1,9.592169458722168,229.7670172844918
+13,25,9.730511977369336,229.30138467444334
+13,26,9.516498564639026,229.99184850867996
+14,-1,10.36873946997381,240.4398718490441
+14,25,10.485895010432683,240.0398505723095
+14,26,10.3061359195915,240.63302697849377
+15,-1,10.10820595492633,250.57600766864383
+15,25,10.142481762137182,250.22218274721686
+15,26,10.087171228051647,250.74685867758924
+16,-1,27.16492755041575,241.56023841681622
+16,25,27.28324796338492,241.81119744782654
+16,26,27.106770027220723,241.43904225866117
+17,-1,25.07116867528301,263.5651818722436
+17,25,25.050729106218757,263.4870884050113
+17,26,25.080946292324455,263.6028943828347
+18,-1,18.129795132200364,264.1620960226355
+18,25,18.08394182285079,263.9588970537857
+18,26,18.151082646579844,264.2602246948714
+19,-1,14.859816103843695,261.3678407504765
+19,25,14.858920932630344,261.123239331056
+19,26,14.858806723767309,261.48596330493314
+20,-1,11.999951828546328,257.6685400346904
+20,25,12.028708960118182,257.37727666792404
+20,26,11.983504400438001,257.8092033704697
+21,-1,10.093610918549906,252.1632616023329
+21,25,10.136081891315454,251.87539123929696
+21,26,10.070091212448709,252.30227626712684
+22,-1,8.78830692800466,246.87701063895838
+22,25,8.83211644958051,246.57528075835782
+22,26,8.76335506458442,247.02271043043058
diff --git a/examples/weather/healda/datasets/etl/normalizations/conv_normalizations.csv b/examples/weather/healda/datasets/etl/normalizations/conv_normalizations.csv
new file mode 100644
index 0000000000..abe77dbbc8
--- /dev/null
+++ b/examples/weather/healda/datasets/etl/normalizations/conv_normalizations.csv
@@ -0,0 +1,9 @@
+Raw_Channel_ID,Platform_ID,obs_std,obs_mean,obs_min,obs_max
+1,-1,0.006021306479256463,0.004113416818592855,9.99999993922529e-09,0.0828777477145195
+2,-1,24.040581924661886,237.75966260053028,171.33290100097656,322.6221618652344
+3,-1,0.001796639343013556,0.0004967400106678111,1.0000000116860974e-07,0.026085082441568375
+4,-1,51.69925682101712,983.1479345318987,200.0,1100.0
+5,-1,0.005461665852510159,0.004603054301095736,0.0,0.06551100313663483
+6,-1,27.571154926597,258.28742986188126,173.14999389648438,349.95001220703125
+7,-1,14.89412603357212,7.152579155956843,-100.0,100.0
+8,-1,10.19761780393707,0.21634765886519516,-100.0,100.0
diff --git a/examples/weather/healda/datasets/etl/normalizations/mhs_normalizations.csv b/examples/weather/healda/datasets/etl/normalizations/mhs_normalizations.csv
new file mode 100644
index 0000000000..49df36963d
--- /dev/null
+++ b/examples/weather/healda/datasets/etl/normalizations/mhs_normalizations.csv
@@ -0,0 +1,31 @@
+Raw_Channel_ID,Platform_ID,obs_std,obs_mean
+1,-1,27.314051741583636,237.9566862096014
+1,14,27.237688807568933,237.82756663196426
+1,15,27.25539242564268,238.15684577291427
+1,16,27.104623367838013,238.3287148131825
+1,23,27.342755373663014,237.84289789226148
+1,24,27.452495102189413,237.89947810831046
+2,-1,26.618813103955723,261.915233839472
+2,14,26.714805283241724,261.70902595462064
+2,15,26.566889256684032,262.0244276665933
+2,16,26.501021796581963,262.0089071790408
+2,23,26.635057018276363,261.8615537641429
+2,24,26.598255721079855,262.0179211819286
+3,-1,9.442343271781501,247.38245044616173
+3,14,9.360823103781955,247.3196715240357
+3,15,9.358463680886302,247.33785121070034
+3,16,9.332064063561445,247.28928810997937
+3,23,9.357006869204472,247.3811298765863
+3,24,10.021122147008533,247.67779653526242
+4,-1,12.917529068672312,258.2913764089477
+4,14,12.894516788434766,258.1149392180758
+4,15,12.878421890127624,258.1926237313513
+4,16,12.800390334130613,258.1741921825457
+4,23,12.968961486043167,258.30936712616074
+4,24,12.952925642222597,258.52580165711
+5,-1,17.997020725294178,265.0566627761782
+5,14,18.121249551151013,265.00266964773476
+5,15,18.000085396832166,265.10250063093014
+5,16,17.85497341358148,265.1404196358589
+5,23,18.051736849065584,264.9905259829178
+5,24,17.88897715825186,265.09578591765796
diff --git a/examples/weather/healda/datasets/etl/pull_from_noaa_s3.sh b/examples/weather/healda/datasets/etl/pull_from_noaa_s3.sh
new file mode 100644
index 0000000000..4fbc1513d7
--- /dev/null
+++ b/examples/weather/healda/datasets/etl/pull_from_noaa_s3.sh
@@ -0,0 +1,66 @@
+#!/usr/bin/env bash
+# ---------------------------------------------------------------
+# Pull UFS GEFS-v13 replay observation from NOAA S3
+# to local lustre storage.
+#
+# Usage:
+# ./pull_from_noaa_s3.sh # real copy
+# ./pull_from_noaa_s3.sh --dry-run # preview only
+# ---------------------------------------------------------------
+set -euo pipefail
+
+# Set your destination directory (or set UFS_RAW_OBS_DIR env var)
+DST_DIR="${UFS_RAW_OBS_DIR:-/path/to/your/raw_obs}"
+
+
+export S5CMD_STAT_PERIOD=10s
+
+YEARS=(2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023)
+# Example file structure:
+# + **/diag_amsua_*_ges.2018*_control.nc4
+# + **/diag_atms_*_ges.2018*_control.nc4
+# + **/diag_iasi_*_ges.2018*_control.nc4
+# + **/diag_mhs_*_ges.2018*_control.nc4
+# + **/diag_cris-fsr_*_ges.2018*_control.nc4
+# + **/diag_conv_uv_ges.2018*_control.nc4
+
+# SENSORS=(conv_uv conv_t conv_q conv_gps conv_ps) # Conventional
+SENSORS=(amsua amsub atms mhs) # Microwave, can add iasi/cris-fsr for infrared
+
+KIND="ges" # "ges" or "anl" (contain different innovations but identical "Observation" values)
+NUM_WORKERS=512 # s5cmd worker pool
+# -----------------------------------------------------------------
+
+DRY=""
+[[ ${1:-} == "--dry-run" ]] && DRY="--dry-run"
+
+RUNFILE=$(mktemp)
+trap 'rm -f "$RUNFILE"' EXIT
+
+# Build s5cmd run-file: one cp line per year ร sensor
+# Conventional sensors (conv_*) have no satellite platform in filename
+# Satellite sensors have platform suffix (e.g., amsua_metop-b, atms_n20)
+for yr in "${YEARS[@]}"; do
+ for sensor in "${SENSORS[@]}"; do
+ if [[ $sensor == conv_* ]]; then
+ # Conventional: diag_conv_gps_ges.2022*_control.nc4
+ printf -- 'cp --sp "s3://noaa-ufs-gefsv13replay-pds/%s/*/*/gsi/diag_%s_%s.%s*_control.nc4" "%s/%s/"\n' \
+ "$yr" "$sensor" "$KIND" "$yr" "$DST_DIR" "$yr" >>"$RUNFILE"
+ else
+ # Satellite mw/ir: diag_amsua_*_ges.2022*_control.nc4 (wildcard for platform)
+ printf -- 'cp --sp "s3://noaa-ufs-gefsv13replay-pds/%s/*/*/gsi/diag_%s_*_%s.%s*_control.nc4" "%s/%s/"\n' \
+ "$yr" "$sensor" "$KIND" "$yr" "$DST_DIR" "$yr" >>"$RUNFILE"
+ fi
+ done
+done
+
+echo ">> built $(wc -l <"$RUNFILE") copy commands in $RUNFILE"
+
+echo ">> Generated commands:"
+cat "$RUNFILE"
+echo ">> End of generated commands"
+
+s5cmd --no-sign-request --stat \
+ $DRY \
+ --numworkers "$NUM_WORKERS" \
+ run "$RUNFILE"
diff --git a/examples/weather/healda/datasets/features.py b/examples/weather/healda/datasets/features.py
new file mode 100644
index 0000000000..a7390842d6
--- /dev/null
+++ b/examples/weather/healda/datasets/features.py
@@ -0,0 +1,194 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Literal
+
+import torch
+
+from utils import profiling
+
+
+@profiling.nvtx
+def compute_unified_metadata(
+ target_time_sec: torch.Tensor, # int64 seconds
+ lat: torch.Tensor,
+ lon: torch.Tensor,
+ time: torch.Tensor, # int64 nanoseconds
+ # Raw metadata fields
+ height: torch.Tensor | None = None,
+ pressure: torch.Tensor | None = None,
+ scan_angle: torch.Tensor | None = None,
+ sat_zenith_angle: torch.Tensor | None = None,
+ sol_zenith_angle: torch.Tensor | None = None,
+) -> torch.Tensor:
+ """
+ Compute unified metadata from raw fields.
+
+ Features are concatenated in the following order:
+ - Local solar time (4 features): Fourier encoding with 2 frequencies
+ - Relative time features (2 features): normalized time difference and its square
+ - Height features (8 features, NaN for satellite): Fourier encoding with 4 frequencies
+ - Pressure features (8 features, NaN for satellite): Fourier encoding with 4 frequencies
+ - Scan angle features (2 features, NaN for conventional): normalized scan angle and its square
+ - Satellite zenith features (2 features, NaN for conventional): cos(ฮธ_sat) and cos(ฮธ_sat)ยฒ
+ - Solar zenith features (2 features, NaN for conventional): cos(ฮธ_sun) and sin(ฮธ_sun)
+
+ Note: time inputs use int64 to preserve precision. Float conversion happens only
+ after magnitude reduction to avoid precision loss with large Unix timestamps.
+ """
+ device = lat.device
+ n_obs = lat.shape[0]
+
+ lst = local_solar_time(lon, time)
+
+ # Build metadata features as a list
+ metadata_features = []
+
+ # Local solar time features (4 features)
+ local_solar_time_feats = fourier_features(
+ lst / 24.0, 2
+ ) # 2 frequencies = 4 features
+ metadata_features.append(local_solar_time_feats)
+
+ # Relative time features (2 features)
+ target_time_ns = target_time_sec * 1_000_000_000
+ dt_sec = (time - target_time_ns).float() * 1e-9
+ relative_time_hours = dt_sec / 3600.0
+ dt_norm = relative_time_hours / 24.0 # Normalize
+ time_norm_feats = torch.stack([dt_norm, dt_norm**2], dim=-1)
+ metadata_features.append(time_norm_feats)
+
+ # Height features (16 features, NaN for satellite)
+ if height is not None:
+ height_norm = normalize(
+ height,
+ "linear",
+ 100.0, # height_min
+ 60000.0, # height_max
+ 0.5, # height_power
+ )
+ height_feats = fourier_features(height_norm, 4) # 4 frequencies = 8 features
+ metadata_features.append(height_feats)
+ else:
+ # Add NaN tensor for height features
+ metadata_features.append(
+ torch.full((n_obs, 8), float("nan"), device=device, dtype=torch.float32)
+ )
+
+ # Pressure features (16 features, NaN for satellite)
+ if pressure is not None:
+ pressure_norm = normalize(
+ pressure,
+ "linear",
+ 10.0, # pressure_min
+ 1100.0, # pressure_max
+ 3.0, # pressure_power
+ )
+ pressure_feats = fourier_features(
+ pressure_norm, 4
+ ) # 4 frequencies = 8 features
+ metadata_features.append(pressure_feats)
+ else:
+ # Add NaN tensor for pressure features
+ metadata_features.append(
+ torch.full((n_obs, 8), float("nan"), device=device, dtype=torch.float32)
+ )
+
+ # Scan angle features (2 features, NaN for conventional)
+ if scan_angle is not None:
+ xi_norm = scan_angle / 50.0 # ~[-1,1] as in existing code
+ scan_angle_feats = torch.stack([xi_norm, xi_norm**2], dim=-1)
+ metadata_features.append(scan_angle_feats)
+ else:
+ # Add NaN tensor for scan angle features
+ metadata_features.append(
+ torch.full((n_obs, 2), float("nan"), device=device, dtype=torch.float32)
+ )
+
+ # Satellite zenith features (2 features, NaN for conventional)
+ if sat_zenith_angle is not None:
+ cos_theta_sat = torch.cos(torch.deg2rad(sat_zenith_angle))
+ sat_zenith_feats = torch.stack([cos_theta_sat, cos_theta_sat**2], dim=-1)
+ metadata_features.append(sat_zenith_feats)
+ else:
+ # Add NaN tensor for satellite zenith features
+ metadata_features.append(
+ torch.full((n_obs, 2), float("nan"), device=device, dtype=torch.float32)
+ )
+
+ # Solar zenith features (2 features, NaN for conventional)
+ if sol_zenith_angle is not None:
+ cos_theta_sun = torch.cos(torch.deg2rad(sol_zenith_angle))
+ sin_theta_sun = torch.sin(torch.deg2rad(sol_zenith_angle))
+ sol_zenith_feats = torch.stack([cos_theta_sun, sin_theta_sun], dim=-1)
+ metadata_features.append(sol_zenith_feats)
+ else:
+ # Add NaN tensor for solar zenith features
+ metadata_features.append(
+ torch.full((n_obs, 2), float("nan"), device=device, dtype=torch.float32)
+ )
+
+ # Concatenate all features
+ metadata = torch.cat(metadata_features, dim=-1)
+ metadata = metadata.nan_to_num(0.0)
+
+ return metadata
+
+
+def normalize(
+ x: torch.Tensor,
+ scale: Literal["linear", "log", "power"],
+ x_min: float,
+ x_max: float,
+ power: float,
+) -> torch.Tensor:
+ # map x onto [0,1] using chosen scale
+ if scale == "linear":
+ return torch.clamp(x / x_max, 0.0, 1.0)
+ elif scale == "log":
+ # ensure positive
+ return (torch.log(x + x_min) - math.log(x_min)) / (
+ math.log(x_max + x_min) - math.log(x_min)
+ )
+ elif scale == "power":
+ x_lin = torch.clamp(x / x_max, 0.0, 1.0)
+ return x_lin.pow(power)
+ else:
+ raise ValueError(f"Unknown scale '{scale}'")
+
+
+def fourier_features(x_norm: torch.Tensor, num_freqs: int) -> torch.Tensor:
+ # x_norm: (N,) in [0,1]
+ # produce (N, 2*num_freqs) of sin/cos features
+ device = x_norm.device
+ freqs = torch.arange(1, num_freqs + 1, device=device, dtype=x_norm.dtype) * (
+ 2 * math.pi
+ )
+ x_expanded = x_norm.unsqueeze(-1) * freqs # (N, num_freqs)
+ sin_feats = torch.sin(x_expanded)
+ cos_feats = torch.cos(x_expanded)
+ return torch.cat([sin_feats, cos_feats], dim=-1)
+
+
+def local_solar_time(
+ lon_deg: torch.Tensor,
+ abs_time_ns: torch.Tensor,
+) -> torch.Tensor:
+ # Approximate without equation of time correction
+ sec_of_day = (abs_time_ns // 1_000_000_000) % 86400
+ utc_hours = sec_of_day.float() / 3600.0
+ lst = (utc_hours + lon_deg / 15.0) % 24.0
+ return lst
diff --git a/examples/weather/healda/datasets/filter_times.py b/examples/weather/healda/datasets/filter_times.py
new file mode 100644
index 0000000000..e737ab8946
--- /dev/null
+++ b/examples/weather/healda/datasets/filter_times.py
@@ -0,0 +1,178 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pandas as pd
+
+from datasets.base import DatasetMetadata
+
+
+def _get_coarsest_freq(freq1: str, freq2: str) -> str:
+ """Get the coarsest (largest time step) frequency between two pandas frequency strings."""
+ td1 = pd.Timedelta(freq1)
+ td2 = pd.Timedelta(freq2)
+ return freq1 if td1 >= td2 else freq2
+
+
+def get_chunk_aligned_times(
+ base_metadata: DatasetMetadata,
+ obs_metadata: DatasetMetadata,
+ chunk_size: int = 24,
+ dropouts: list[tuple[str, str]] = [
+ ("2019-01-01", "2020-06-30"),
+ ("2022-06-01", "2022-09-30"),
+ ],
+) -> pd.DatetimeIndex:
+ """
+ Get chunk-aligned times for a dataset filtered by dropouts. Ensures that starting from the first time
+ returned, times come in full chunks of `chunk_size` and are aligned to the chunk boundaries.
+
+
+ """
+
+ # get max of obs start and base start
+ desired_start = max(
+ pd.Timestamp(obs_metadata.start), pd.Timestamp(base_metadata.start)
+ )
+ desired_end = min(pd.Timestamp(obs_metadata.end), pd.Timestamp(base_metadata.end))
+
+ coarsest_freq = _get_coarsest_freq(base_metadata.freq, obs_metadata.freq)
+
+ # Get chunk-aligned start
+ base_start = pd.Timestamp(base_metadata.start)
+ aligned_start = _find_next_chunk_boundary(
+ desired_start, base_start, chunk_size, base_metadata.freq
+ )
+
+ # Filter out dropouts
+ available_times = _generate_available_times(
+ aligned_start, desired_end, coarsest_freq, dropouts
+ )
+
+ segments = _extract_chunk_aligned_segments(
+ available_times, base_start, chunk_size, base_metadata.freq, obs_metadata.freq
+ )
+ if not segments:
+ raise ValueError("No valid chunk-aligned segments found")
+
+ stacked_times = pd.DatetimeIndex(np.concatenate([seg.values for seg in segments]))
+
+ # validate present in base metadata and obs metadata
+ base_valid_times = pd.date_range(
+ base_metadata.start, base_metadata.end, freq=base_metadata.freq
+ )
+ obs_valid_times = pd.date_range(
+ obs_metadata.start, obs_metadata.end, freq=obs_metadata.freq
+ )
+
+ if not (
+ np.all(stacked_times.isin(base_valid_times))
+ and np.all(stacked_times.isin(obs_valid_times))
+ ):
+ raise RuntimeError(
+ "Some times in stacked_times are not in base_valid_times or obs_valid_times"
+ )
+
+ return stacked_times
+
+
+def _find_next_chunk_boundary(
+ desired_start: pd.Timestamp, dataset_start: pd.Timestamp, chunk_size: int, freq: str
+) -> pd.Timestamp:
+ """Find the next chunk boundary at or after desired_start."""
+ freq_td = pd.Timedelta(freq)
+
+ # Number of timestamps from dataset start
+ n_timestamps = int((desired_start - dataset_start) / freq_td)
+
+ # Round up to next chunk boundary if needed
+ if n_timestamps % chunk_size != 0:
+ n_timestamps = ((n_timestamps // chunk_size) + 1) * chunk_size
+
+ return dataset_start + (n_timestamps * freq_td)
+
+
+def _generate_available_times(
+ start: pd.Timestamp, end: pd.Timestamp, freq: str, dropouts: list[tuple[str, str]]
+) -> pd.DatetimeIndex:
+ """Generate times excluding dropout periods."""
+ # Generate all theoretical times
+ all_times = pd.date_range(start, end, freq=freq)
+
+ # Create availability mask
+ available_mask = np.ones(len(all_times), dtype=bool)
+
+ # Apply dropout masks
+ for dropout_start_str, dropout_end_str in dropouts:
+ dropout_start = pd.Timestamp(dropout_start_str)
+ dropout_end = pd.Timestamp(dropout_end_str)
+ dropout_mask = (all_times >= dropout_start) & (all_times <= dropout_end)
+ available_mask &= ~dropout_mask
+
+ return all_times[available_mask]
+
+
+def _extract_chunk_aligned_segments(
+ available_times: pd.DatetimeIndex,
+ dataset_start: pd.Timestamp,
+ chunk_size: int,
+ ds_freq: str,
+ obs_freq: str,
+) -> list[pd.DatetimeIndex]:
+ """Extract contiguous segments that are chunk-aligned."""
+ if len(available_times) == 0:
+ return []
+
+ segments = []
+ freq_td = pd.Timedelta(ds_freq)
+ freq_obs_td = pd.Timedelta(obs_freq)
+
+ # Find contiguous regions
+ time_series = available_times.to_series()
+ time_diffs = time_series.diff()
+ gap_threshold = freq_obs_td * 2 # Allow one missing timestamp
+
+ # Identify break points - need to get positional indices, not timestamps
+ break_indices = [0]
+ # Find where gaps occur and get their positional indices
+ gap_mask = time_diffs > gap_threshold
+ if gap_mask.any():
+ # Get positional indices where gaps occur
+ gap_positions = np.where(gap_mask)[0]
+ break_indices.extend(gap_positions.tolist())
+
+ break_indices.append(len(available_times))
+ # Process each contiguous region
+ for start_idx, end_idx in zip(break_indices[:-1], break_indices[1:]):
+ segment = available_times[start_idx:end_idx]
+
+ if len(segment) < chunk_size:
+ continue
+
+ # Align segment start to chunk boundary
+ segment_start = segment[0]
+ offset = int((segment_start - dataset_start) / freq_td)
+ skip_count = (chunk_size - (offset % chunk_size)) % chunk_size
+
+ if skip_count > 0:
+ segment = segment[skip_count:]
+
+ # Trim to chunk-aligned length
+ aligned_length = (len(segment) // chunk_size) * chunk_size
+ if aligned_length >= chunk_size:
+ segment = segment[:aligned_length]
+ segments.append(segment)
+
+ return segments
diff --git a/examples/weather/healda/datasets/gsi_codes/convdata_codes.csv b/examples/weather/healda/datasets/gsi_codes/convdata_codes.csv
new file mode 100644
index 0000000000..05c0c7be5c
--- /dev/null
+++ b/examples/weather/healda/datasets/gsi_codes/convdata_codes.csv
@@ -0,0 +1,92 @@
+,type,message,description,note,code
+0,111,SYNDAT,"SYNTHETIC (BOGUS) TROPICAL CYCLONE STORM CENTER (generated in SYNDAT_SYNDATA) - q, Pstn",Not created (switch not set in SYNDAT_SYNDATA parm cards),111
+1,112,,"PSEUDO MEAN SEA-LEVEL PRESSURE AT TROPICAL CYCLONE STORM CENTERย (generated in GSI, does not appear in pre-analysis PREPBUFR files) - Pstn",Pstn used by assimilation When implemented. may appear in post-analysis PREPBUFR file,112
+2,120,ADPUPA,"RAWINSONDE - Tv, q, Pstn, sst","Tv, q, Pstn used by assimilation sst monitored by assimilationย by switch in convinfo text file read by GBL-GSI",120
+3,122,ADPUPA,"CLASS SOUNDING - Tv, q, Pstn",Entire report tossed by switch in PREPOBS_PREPDATA parm cards Entire reportย neither monitored nor assimilated - not in convinfo text file read by GBL-GSI,122
+4,126,RASSDA,RASS [FROM NOAA PROFILERย NETWORK (NPN) OR MULTI-AGENCY PROFILER (MAP) NETWORK] - Tv,Tvย flagged for non-use by assimilation due to missing obs error Tv monitored by assimilation by switch in convinfo text file read by GBL-GSI NPNย data are no longer produced after 13 Septemberย 2017. Multiu-Agency data are still available.,126
+5,130,AIRCFT,AIREP AND PIREP AIRCRAFT - Ts,Ts used by assimilation,130
+6,131 R,AIRCFT,"AMDAR AIRCRAFT - Ts, q (E-AMDAR only)",q (E-AMDAR only) not considered by assimilation Ts used by assimilation,131
+7,132,ADPUPA,"FLIGHT-LEVEL RECONNAISSANCE AND PROFILE DROPSONDE - Tv, q, Pstn","q (non-U.S. drops, NASA global Hawk drops, all levels), Pstn (all drops) [andย surface level Tv (all drops), surface levelย q (all drops)] flagged for non-use by assimilation by switch in PREPOBS_PREPDATA parm cards Tv (all drops,ย above surface; reccos), q [U.S. (NOAA Gulf Stream and P-3, andย USAF)ย drops, above surface; reccos] used by assimilation q (non-U.S. drops,ย NASA global Hawk drops, all levels) and Tv (all drops, at surface) monitored by assimilation due to theirย being flagged byย switch in PREPOBS_PREPDATA parm cards",132
+8,133 R,AIRCAR,"MDCRS ACARS AIRCRAFT - Ts, q","Ts, q used by assimilation",133
+9,134 R,AIRCFT,"TAMDARย AIRCRAFT - Ts, q","Ts, q monitored by assimilation by switch in convinfo text file read by GBL-GSI",134
+10,135 R,AIRCFT,CANADIAN AMDAR AIRCRAFT - Ts,Tsย monitored by assimilation by switch in convinfo text file read by GBL-GSI,135
+11,150,SPSSMI,SSM/I SUPEROBED (1 DEGREE LAT/LON) FNMOC (OPERATIONAL) RAIN RATE (DMSP) - rr,rr monitoredย by assimilation by switch in pcpinfo text file read by GBL-GSI Currently the assimilation obtains this directlyย from the spssmi dump file rather thanย from PREPBUFR file. The SSM/I F-13 satelliteย went bad in November 2009 resulting in no data being processed upstream. ย The empty dumps were turned off and the PREPOBS_PREPDATA parm cards were set to no longer process these dataย inย October 2010.,150
+12,151,GOESND,"NESDIS 1x1 F-O-V CLOUD TOP PRESSURE, TEMPERATURE;ย CLOUD AMOUNT (GOES)",Not dumped,151
+13,152,SPSSMI,SSM/I SUPEROBED (1 DEGREE LAT/LON) NEURAL NET-3ย PRECIPITABLE WATER OVER OCEAN (DMSP) - PWt,PWtย flagged for non-use by assimilation due to missing obs error PWt monitored by assimilation by switch in convinfo text file read by GBL-GSI The SSM/I F-13 satelliteย went bad in November 2009 resulting in no data being processed upstream. ย The empty dumps were turned off and the PREPOBS_PREPDATA parm cards were set to no longer process these dataย inย October 2010.,152
+14,153,GPSIPW,GPS-INTEGRATED PRECIPITABLE WATER (GPS-IPW) - PWt,Currently only reportsย from ENI (mainly over U.S.) are encoded into PREPBUFR file. European GNSS reports are skipped. PWtย flagged for non-use by assimilation due to missing obs error PWt monitored by assimilation by switch in convinfo text file read by GBL-GSI,153
+15,154,,reserved for GOES IMAGER SKY-COVER DATA used only in RTMA/URMA,,154
+16,156,GOESND,NESDISย 1x1 F-O-Vย 4-LAYERย PRECIPITABLE WATER OVER LAND - CLEAR (GOES) - PWl,Not dumped,156
+17,157,GOESND,NESDISย 1x1 F-O-Vย 4-LAYER PRECIPITABLE WATERย OVER LAND - CLOUDY (GOES) - PWl,Not dumped,157
+18,158,GOESND,NESDISย 1x1 F-O-Vย 4-LAYERย PRECIPITABLE WATER OVER OCEAN - CLEAR (GOES) - PWl,Not dumped,158
+19,159,GOESND,NESDISย 1x1 F-O-Vย 4-LAYERย PRECIPITABLE WATER OVER OCEAN - CLOUDY (GOES) - PWl,Not dumped,159
+20,164,GOESND,NESDISย 1x1 F-O-V RADIANCESย OVER LAND - CLEAR (GOES) - Tb,Not dumped,164
+21,165,GOESND,NESDISย 1x1 F-O-V RADIANCES OVER LAND - CLOUDYย (GOES) - Tb,Not dumped,165
+22,170,,"NACELLE - Tb, q",Not dumped,170
+23,171,,"TALL TOWER - Tb, q",Not dumped,171
+24,174,GOESND,NESDISย 1x1 F-O-V RADIANCES OVER OCEAN - CLEAR (GOES) - Tb,Not dumped,174
+25,175,GOESND,NESDISย 1x1 F-O-V RADIANCES OVER OCEAN - CLOUDYย (GOES) - Tb,Not dumped,175
+26,180 (R - U.S. & JMA SHIPS),SFCSHP,"SURFACE MARINE WITH REPORTED STATION PRESSURE (SHIP, BUOY, C-MAN, TIDE GAUGE) - Tv, q, Pstn, sst","Tv, q, Pstn used by assimilation sst monitored by assimilation by switch in convinfo text file read by GBL-GSI",180
+27,181 (R - WMO Res 40 SYNOPS),ADPSFC,"SURFACE LAND [SYNOPTIC (fixed and mobile),ย METAR] WITH REPORTED STATION PRESSURE - Tv, q, Pstn, sst","Pstn used by assimilation Tv, qย flagged for non-use by assimilation due to missing obs error Tv, q, sst monitored by assimilation by switch in convinfo text file read by GBL-GSI",181
+28,182,SFCSHP,"SPLASH-LEVEL DROPSONDEย OVER OCEAN - Tv, q, Pstn","Tv, q, Pstn used by assimilation",182
+29,"183 (R - WMO Res 40 SYNOPS, U.S. & JMA SHIPS)","ADPSFC, SFCSHP","SURFACE MARINE (SHIP, BUOY, C-MAN, TIDE GAUGE) OR LAND [SYNOPTIC (fixed and mobile), METAR] WITH MISSING STATION PRESSURE - Tv, q, Pstn, sst","Tv, q, Pstn (entire report) flagged for non-use by assimilation due to missing obs error Tv, q, Pstn,ย sst monitored by assimilation by switch in convinfo text file read by GBL-GSI Altimeter setting is also missing (always the case for synoptic). Station pressure calculated from reported mean sea-level pressure and elevation via U.S. Standard Atmosphere approximation. Elevation is greater than 7.5 meters (if less than 7.5 meters, station pressure set equal to sea-level pressure and report type set to 181).",183
+30,187,ADPSFC,"SURFACE LAND (METAR) WITH MISSING STATION PRESSURE - Tv, q, Pstn, sst","Pstn used by assimilation Tv, qย flagged for non-use by assimilation due to missing obs error Tv, q, sst monitored by assimilation by switch in convinfo text file read by GBL-GSI Altimeter setting is reported (never the case for synoptic). Station pressure calculated from reported altimeter setting and elevation.",187
+31,188 R,MSONET,"SURFACE MESONET - Tv, q, Pstn",Not dumped,188
+32,191,SFCBOG,AUSTRALIAN PAOB MEAN SEA-LEVEL PRESSURE BOGUS OVER OCEAN - Pstn,sfcbog dump file not readย byย PREPOBS_PREPDATA This dataย no longer produced after 17 August 2010,191
+33,192,ADPSFC,"SURFACE LANDย SYNOPIC (fixed and mobile) WITH MISSING STATION PRESSURE AND MISSING SEA-LEVEL PRESSURE -ย Tv, q, Pstn",Entire report tossed by switch in PREPOBS_PREPDATA parm cards Station pressure estimated from U.S. Standard Atmosphere approximation Pmsl andย reported temperature and elevation.,192
+34,193,ADPSFC,"SURFACE LANDย METAR WITH MISSING STATION PRESSURE,ย MISSING SEA-LEVEL PRESSURE AND MISSING ALTIMETER SETTING -ย Tv, q, Pstn",Entire report tossed by switch in PREPOBS_PREPDATA parm cards Station pressure estimated from U.S. Standard Atmosphere approximation Pmsl andย reported temperature and elevation.,193
+35,194,SFCSHP,"SURFACE MARINE (SHIP, BUOY, C-MAN, TIDE GAUGE) OR LAND (SYNOPTIC, METAR) WITH MISSING STATION PRESSUREย AND MISSING SEA-LEVEL PRESSURE - Tv, q, Pstn",Entire report tossed by switch in PREPOBS_PREPDATA parm cards Station pressure estimated from U.S. Standard Atmosphere approximation Pmsl andย reported temperature and elevation.,194
+36,195,MSONET,"SURFACE MESONETย WITH MISSING STATION PRESSUREย AND MISSING ALTIMETER SETTING (SEA-LEVEL PRESSURE IS ALWAYS MISSING) -ย Tv, q, Pstn",Entire report tossed by switch in PREPOBS_PREPDATA parm cards Station pressure estimated from U.S. Standard Atmosphere approximation Pmsl andย reported temperature and elevation.,195
+37,210,SYNDAT,"SYNTHETIC (BOGUS) TROPICAL CYCLONE - u, v","u, v monitored by assimilation by switch in convinfo text file read by GBL-GSI Added to PREPBUFR file by later program SYNDAT_SYNDATA.",210
+38,220,ADPUPA,"RAWINSONDE - u, v (all levels), z (winds-by-height levels)","u, v used by assimilation p (vertical coordinate) calculated from z on winds-by-height levels.",220
+39,221,ADPUPA,"PIBAL - u,v,z","u, v used by assimilation p (vertical coordinate) calculated from z.",221
+40,222,ADPUPA,"CLASS SOUNDING - u, v",Entire report tossed by switch in PREPOBS_PREPDATA parm cards Entire reportย neither monitored nor assimilated - not in convinfo text file read by GBL-GSI,222
+41,223,PROFLR,"NOAA PROFILER NETWORK (NPN) WIND PROFILER - u, v, z","u, v used by assimilation p (vertical coordinate) calculated from z. These data are no longer produced after 13 Septemberย 2017.",223
+42,224,VADWND,"NEXRAD VERTICAL AZIMUTH DISPLAY (VAD) from Radar Coded Message (subtype 1) - u, v, z","u, v used by assimilation p (vertical coordinate) calculated from z.",224
+43,227,PROFLR,"MULTI-AGENCY PROFILER (MAP) ANDย ACOUSTIC SOUNDER (SODAR) - u, v, z",Entire report tossed by switch in PREPOBS_PREPDATA parm cards,227
+44,228,PROFLR,"JAPANESE METEOROLOGICAL AGENCY (JMA) WIND PROFILER - u, v, z","u, vย (entire report) flagged for non-use by assimilation due to missing obs error u, v monitored by assimilation by switch in convinfo text file read by GBL-GSI",228
+45,229,PROFLR,"WIND PROFILER DECODED FROM PILOT (PIBAL) BULLETINS - u, v, z","u, v used by assimilation p (vertical coordinate) calculated from z.",229
+46,230,AIRCFT,"AIREP AND PIREP AIRCRAFT - u, v","u, v used by assimilation",230
+47,231 R,AIRCFT,"AMDAR AIRCRAFT - u, v","u, v used by assimilation",231
+48,232,ADPUPA,"FLIGHT-LEVEL RECONNAISSANCE AND PROFILE DROPSONDE - u, v","u, v used by assimilation",232
+49,233 R,AIRCAR,"MDCRS ACARS AIRCRAFT - u, v","u, v used by assimilation",233
+50,234 R,AIRCFT,"TAMDARย AIRCRAFT - u, v","u, v monitored by assimilation by switch in convinfo text file read by GBL-GSI",234
+51,235 R,AIRCFT,"CANADIAN AMDAR AIRCRAFT - u, v","u, v monitored by assimilation by switch in convinfo text file read by GBL-GSI",235
+52,240,SATWND,"NESDIS IR (SHORT-WAVE)ย CLOUD DRIFT (ALL LEVELS)ย (GOES) - u, v",Theย assimilation will obtain this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This willย not be written into the PREPBUFR file.,240
+53,241,SATWND,"INDIA IR (LONG-WAVE) AND VISIBLE CLOUD DRIFT (ALL LEVELS)ย (INSAT, KALPANA) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI.",241
+54,242,SATWND,"JMA IR (LONG-WAVE) AND VISIBLE CLOUD DRIFT AT LEVELS BELOW 850 MB (GMS, MTSAT, HIMAWARI) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI. The GSI redefines report type 242 as JMA visible cloud drift at all levels.",242
+55,243,SATWND,"EUMETSAT IR (LONG-WAVE) AND VISIBLE CLOUD DRIFT AT LEVELS BELOW 850 MB (METEOSAT) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI. The GSI redefines report type 243 as EUMETSAT visible cloud drift at all levels.",243
+56,244,SATWND,"AVHRR/POES IR (LONG-WAVE) CLOUD DRIFT (ALL LEVELS) (NOAA,ย METOP) - u,v",Theย assimilation will obtain this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This willย not be written into the PREPBUFR file.,244
+57,245,SATWND,"NESDIS IR (LONG-WAVE) CLOUD DRIFT (ALL LEVELS) (GOES) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI.",245
+58,246,SATWND,"NESDIS IMAGER WATER VAPOR (ALL LEVELS) -ย CLOUD TOP (GOES) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI.",246
+59,247,SATWND,"NESDIS IMAGER WATER VAPOR (ALL LEVELS) - DEEP LAYER (GOES) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file (it was never in theย PREPBUFR file) (see Table 18). ย Since it has both a missing obs error and is set to monitor in the convinfo file, it is now monitored by theย GBL-GSI.",247
+60,248,SATWND,"NESDIS SOUNDER WATER VAPOR (ALL LEVELS) - CLOUD TOP (GOES) - u, v","If ever processed, theย assimilation will obtain this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This will never be writtenย into the PREPBUFR file.",248
+61,249,SATWND,"NESDIS SOUNDER WATER VAPOR (ALL LEVELS) - DEEP LAYER (GOES) - u, v","If ever processed, theย assimilation will obtain this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This will never be writtenย into the PREPBUFR file.",249
+62,250,SATWND,"JMA IMAGER WATER VAPOR (ALL LEVELS) - CLOUD TOP & DEEP LAYER (GMS, MTSAT, HIMAWARI) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI.",250
+63,251,SATWND,"NESDIS VISIBLE CLOUD DRIFT (ALL LEVELS) (GOES) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI.",251
+64,252,SATWND,"JMA IR (LONG-WAVE) AND VISIBLE CLOUD DRIFT AT LEVELS ABOVE 850 MB (GMS, MTSAT, HIMAWARI) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI. The GSI redefines report type 252 as JMA IR cloud drift at all levels.",252
+65,253,SATWND,"EUMETSAT IR (LONG-WAVE) AND VISIBLE CLOUD DRIFT AT LEVELS ABOVE 850 MB (METEOSAT) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI. The GSI redefines report type 253 as EUMETSAT IR cloud drift at all levels.",253
+66,254,SATWND,"EUMETSAT IMAGER WATER VAPOR (ALL LEVELS) - CLOUD TOP & DEEP LAYER (METEOSAT) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI.",254
+67,255,SATWND,NESDIS PICTURE TRIPLET CLOUD DRIFT (LOW LEVELS) (GOES),No longer produced by NESDIS,255
+68,256,SATWND,"INDIA IMAGER WATER VAPOR (ALL LEVELS) (INSAT, KALPANA) - u, v","If ever processed, theย assimilation will obtain this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This will never be writtenย into the PREPBUFR file.",256
+69,257,SATWND,"MODIS/POES IR (LONG-WAVE) CLOUD DRIFT (ALL LEVELS) (AQUA, TERRA) - u,v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI.",257
+70,258,SATWND,"MODIS/POES IMAGER WATER VAPOR (ALL LEVELS) - CLOUD TOP (AQUA, TERRA) - u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI.",258
+71,259,SATWND,"MODIS/POES IMAGER WATER VAPOR (ALL LEVELS) - DEEP LAYER (AQUA,ย TERRA) -ย u, v","Effective 5/22/2012, theย assimilation obtains this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This is still written into the PREPBUFR file but is ignored by the GBL-GSI.",259
+72,260,SATWND,"VIIRS/POES IR (LONG-WAVE) CLOUD DRIFT (ALL LEVELS) (NPP) - u,v",Theย assimilation will obtain this directlyย from the satwnd dump file rather thanย from the PREPBUFR file (see Table 18). ย This willย not beย written into the PREPBUFR file.,260
+73,270,,"NACELLE - u,v",Not dumped,270
+74,271,,"TALL TOWER -u,v",Not dumped,271
+75,280 (R - U.S. & JMA SHIPS),SFCSHP,"SURFACE MARINE WITH REPORTED STATION PRESSURE (SHIP, BUOY, C-MAN, TIDE GAUGE) - u, v","u, v used by assimilation",280
+76,281 (R - WMO Res 40 SYNOPS),ADPSFC,"SURFACE LAND [SYNOPTIC (fixed and mobile),ย METAR] WITH REPORTED STATION PRESSURE - u, v","u, vย (entire report) flagged for non-use by assimilation due to missing obs error u, v monitored by assimilation by switch in convinfo text file read by GBL-GSI",281
+77,282,SFCSHP,"ATLAS BUOY - u, v (see % below)","u, v used by assimilation Reported station pressure and mean sea-level pressure BOTH missing. Station pressure is set to 1013 mb. Elevation is less than or equal to 7.5 meters.",282
+78,283,SPSSMI,"SSM/I SUPEROBED (1 DEGREE LAT/LON) NEURAL NET 3 WIND SPEED OVER OCEAN - u, v","u, v monitored by assimilation by switch in convinfo text file read by GBL-GSI Only wspd available so direction initially set to ZERO in PREPBUFR file, direction calculated from analysisย and u, v recomputedย in program PREPOBS_OIQCBUFR. Reported station pressure and mean sea-level pressure BOTH missing. Station pressure is set to 1013 mb. Elevation is less than or equal to 7.5 meters. The SSM/I F-13 satelliteย went bad in November 2009 resulting in no data being processed upstream. ย The empty dumps were turned off and the PREPOBS_PREPDATA parm cards were set to no longer process these dataย inย October 2010.",283
+79,"284 (R - WMO Res 40 SYNOPS, U.S. & JMA SHIPS)","ADPSFC, SFCSHP","SURFACE MARINE (SHIP, BUOY, C-MAN, TIDE GAUGE) OR LAND [SYNOPTIC (fixed and mobile), METAR] WITH MISSING STATION PRESSURE - u, v","u, v (entire report) flagged for non-use by assimilation due to missing obs error u, v monitored by assimilation by switch in convinfo text file read by GBL-GSI Altimeter setting is also missing (always the case for synoptic). Station pressure calculated from reported mean sea-level pressure and elevation via U.S. Standard Atmosphere approximation. Elevation is greater than 7.5 meters (if less than 7.5 meters, station pressure set equal to sea-level pressure and report type set to 281).",284
+80,285,QKSWND,"SUPEROBED (0.5 DEGREE LAT/LON) SCATTEROMETER WINDS OVER OCEANย (QUIKSCAT) - u, v",No longer produced,285
+81,286,ERS1DA,"SCATTEROMETER WINDS OVER OCEANย (ERS) - u, v",No longer produced,286
+82,287,ADPSFC,"SURFACE LAND (METAR) WITH MISSING STATION PRESSURE - u, v","u, v (entire report) flagged for non-use by assimilation due to missing obs error u, v monitored by assimilation by switch in convinfo text file read by GBL-GSI Altimeter setting is reported (never the case for synoptic). Station pressure calculated from reported altimeter setting and elevation.",287
+83,288 R,MSONET,"SURFACE MESONET - u, v",Not dumped,288
+84,289,WDSATR,"SUPEROBEDย (1.0 DEGREE LAT/LON) SCATTEROMETER WINDS OVER OCEAN (WINDSAT) - u,v","u, v usedย by assimilation (currently not available due to format change in raw files)",289
+85,290,ASCATW,"NON-SUPEROBEDย SCATTEROMETER WINDS OVER OCEAN (ASCAT) - u,v (50 km resolution)","METOP-2(A) u, v usedย by assimilation METOP-1(B) u,v monitored by assimilation by switch in convinfo text file read by GBL-GSI",290
+86,291,,"NON-SUPEROBEDย SCATTEROMETER WINDS OVER OCEAN (OSCAT) - u,v","When available, theย assimilation will obtain this directlyย from the oscatwย dump file rather thanย from the PREPBUFR file (see Table 18). ย This will not beย written into the PREPBUFR file. This type will never be available. Instrument failed on 20 February 2014.",291
+87,292,ADPSFC,"SURFACE LANDย SYNOPIC (fixed and mobile) WITH MISSING STATION PRESSURE AND MISSING SEA-LEVEL PRESSURE -ย u,v",Entire report tossed by switch in PREPOBS_PREPDATA parm cards Station pressure estimated from U.S. Standard Atmosphere approximation Pmsl andย reported temperature and elevation.,292
+88,293,ADPSFC,"SURFACE LANDย METAR WITH MISSING STATION PRESSURE,ย MISSING SEA-LEVEL PRESSURE AND MISSING ALTIMETER SETTING -ย u,v",Entire report tossed by switch in PREPOBS_PREPDATA parm cards Station pressure estimated from U.S. Standard Atmosphere approximation Pmsl andย reported temperature and elevation.,293
+89,294,SFCSHP,"SURFACE MARINE (SHIP, BUOY, C-MAN, TIDE GAUGE) OR LAND (SYNOPTIC, METAR) WITH MISSING STATION PRESSUREย AND MISSING SEA-LEVEL PRESSURE - u,v",Entire report tossed by switch in PREPOBS_PREPDATA parm cards Station pressure estimated from U.S. Standard Atmosphere approximation Pmsl andย reported temperature and elevation.,294
+90,295,MSONET,"SURFACE MESONETย WITH MISSING STATION PRESSUREย AND MISSING ALTIMETER SETTING (SEA-LEVEL PRESSURE IS ALWAYS MISSING) - u,v",Entire report tossed by switch in PREPOBS_PREPDATA parm cards Station pressure estimated from U.S. Standard Atmosphere approximation Pmsl andย reported temperature and elevation.,295
diff --git a/examples/weather/healda/datasets/merged_dataset.py b/examples/weather/healda/datasets/merged_dataset.py
new file mode 100644
index 0000000000..75b9ddbd16
--- /dev/null
+++ b/examples/weather/healda/datasets/merged_dataset.py
@@ -0,0 +1,486 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import logging
+import math
+from typing import Callable
+
+import numpy as np
+import pandas as pd
+import torch
+from zarr.core.sync import sync
+
+from utils import profiling
+
+from . import datetime_utils
+
+logger = logging.getLogger(__name__)
+
+
+class _FrameIndexGenerator:
+ """Handles frame index generation with striding, permuting, and model rank slicing."""
+
+ def __init__(
+ self,
+ times,
+ time_length: int,
+ frame_step: int,
+ model_rank: int,
+ model_world_size: int,
+ ):
+ """
+ Args:
+ times: Time array to split into contiguous segments
+ time_length: Number of frames per window
+ frame_step: Step size between frames
+ model_rank: Model rank for distributed training
+ model_world_size: Total number of model ranks
+ """
+ self.time_length = time_length
+ self.frame_step = frame_step
+ self.model_rank = model_rank
+ self.model_world_size = model_world_size
+
+ # Split times into contiguous segments and compute sizes
+ self.segments = _split_array_contiguous(times)
+ self.sizes = [len(segment) for segment in self.segments]
+
+ self.total_samples = sum(self.sizes)
+
+ # Calculate valid lengths for each segment
+ frames_per_window = (time_length - 1) * frame_step + 1
+ self.segment_valid_lengths = []
+ for segment in self.segments:
+ segment_valid_length = len(segment) - frames_per_window + 1
+ if segment_valid_length > 0:
+ self.segment_valid_lengths.append(segment_valid_length)
+ else:
+ self.segment_valid_lengths.append(0)
+
+ # Precompute cumulative sizes for efficient mapping
+ self.cumulative_valid_sizes = [0] + list(np.cumsum(self.segment_valid_lengths))
+ self.cumulative_sizes = [0] + list(np.cumsum(self.sizes))
+ self.valid_length = sum(self.segment_valid_lengths)
+
+ def generate_frame_indices(self, sample_indices: torch.Tensor) -> list[int]:
+ """Generate frame indices from sample indices with striding and model rank slicing.
+
+ Args:
+ sample_indices: Tensor of logical sample indices for each sample in the batch
+
+ Returns:
+ List of frame indices to load
+ """
+ frame_idxs = []
+ for sample_idx in sample_indices:
+ # Map logical sample index to physical frame index
+ physical_idx = self._map_logical_to_physical(sample_idx)
+
+ # Create frame range with striding
+ frames = list(
+ range(
+ physical_idx,
+ physical_idx + self.time_length * self.frame_step,
+ self.frame_step,
+ )
+ )
+ # Apply model rank slicing
+ n = self.time_length // self.model_world_size
+ frames = frames[self.model_rank * n : (self.model_rank + 1) * n]
+ frame_idxs.append(frames)
+ return frame_idxs
+
+ def _map_logical_to_physical(self, logical_idx: int) -> int:
+ """Map logical sample index to physical frame index across segments."""
+ if logical_idx >= self.total_samples:
+ raise IndexError(
+ f"Sample index {logical_idx} out of bounds for {self.total_samples} samples"
+ )
+
+ # Find which segment this logical index belongs to
+ segment_idx = 0
+ for i, cum_size in enumerate(self.cumulative_valid_sizes[1:], 1):
+ if logical_idx < cum_size:
+ segment_idx = i - 1
+ break
+
+ # Calculate offset within the segment
+ segment_start = self.cumulative_sizes[segment_idx]
+ offset_within_segment = logical_idx - self.cumulative_valid_sizes[segment_idx]
+
+ # Return the physical frame index in the original times array
+ return segment_start + offset_within_segment
+
+ def get_valid_length(self) -> int:
+ """Get the total valid length across all segments."""
+ return self.valid_length
+
+
+def _validate_times(times, world_size, time_length, frame_step):
+ if len(times) < world_size:
+ raise ValueError(f"Not enough times provided. Received {len(times)=}.")
+
+ if time_length == 1 and frame_step != 1:
+ raise ValueError("Frame_step must be 1 for image setting")
+
+
+class _MergedLoader:
+ def __init__(self, loaders) -> None:
+ self._loaders = loaders
+
+ async def sel_time(self, time) -> dict[str, np.ndarray]:
+ # Standardize time to np.ndarray of np.datetime64
+ arrays = await asyncio.gather(
+ *[loader.sel_time(time) for loader in self._loaders]
+ )
+ data = {}
+ for d in arrays:
+ data.update(d)
+ return data
+
+
+def _split(x, rank, world_size, drop_extra=True):
+ n = len(x)
+ base = n // world_size
+ rem = n % world_size
+
+ if drop_extra:
+ samples_per_rank = base
+ x = x[: base * world_size]
+ start = rank * base
+ else:
+ # give the first rem ranks one extra sample
+ if rank < rem:
+ samples_per_rank = base + 1
+ start = rank * samples_per_rank
+ else:
+ samples_per_rank = base
+ start = rem * (base + 1) + (rank - rem) * base
+
+ return x[start : start + samples_per_rank]
+
+
+class TimeMergedDataset(torch.utils.data.IterableDataset):
+ """Merge several loader objects in time and apply transforms.
+
+ This is used to join several datasets along time, and grab data in a chunked manner.
+
+ ``time_loaders`` is a list of objects with this interface::
+
+ class Loader:
+
+ async def sel_time(self, times) -> dict[str, np.ndarray]:
+ pass
+
+ ``chunk_size`` should ideally be larger than the chunking of each dataset.
+
+ ``transform`` is a function that prepares the raw loaded data for the model::
+
+ def transform(
+ times: list[pd.Timestamp],
+ data: list[dict[str, np.ndarray]]
+ ) -> dict[str, Any]
+
+ When `time_length = 1` and `frame_step = 1`, this collapses to the image case.
+ """
+
+ def __init__(
+ self,
+ times,
+ # for performance times should be in sequence
+ *,
+ time_loaders,
+ rank: int = 0,
+ world_size: int = 1,
+ shuffle: bool = True,
+ chunk_size: int = 48,
+ transform: Callable,
+ infinite: bool = True,
+ time_length: int = 1,
+ frame_step: int = 1,
+ window_stride: int = 1,
+ ):
+ _validate_times(times, world_size, time_length, frame_step)
+
+ frames_per_window = (time_length - 1) * frame_step + 1
+ self._loader = _MergedLoader(time_loaders)
+ self.rank = rank
+ self.world_size = world_size
+ self.set_times(times) # Shard times across ranks
+
+ if len(self._times) < chunk_size:
+ raise ValueError(
+ f"Sharded times too small for chunk size. Need {chunk_size} "
+ f"frames but only got {len(self._times)}"
+ )
+
+ self.shuffle = shuffle
+ self.transform = transform
+ self.chunk_size = chunk_size
+ self.infinite = infinite
+
+ self.time_length = time_length
+ self.frame_step = frame_step
+ self.window_stride = window_stride
+
+ self._generator = None
+
+ max_valid_idx = len(times) - self.chunk_size
+ self.max_valid_chunk_idx = max_valid_idx // self.chunk_size
+
+ self.overlap = frames_per_window - 1
+
+ @property
+ def times(self) -> pd.DatetimeIndex:
+ return pd.DatetimeIndex(self._times)
+
+ def set_times(self, times):
+ self._times = _split(datetime_utils.as_numpy(times), self.rank, self.world_size)
+
+ def _load_chunk(self, chunk: int):
+ return sync(self._loader.sel_time(self._times_for_chunk(chunk)))
+
+ def _times_for_chunk(self, chunk: int) -> np.ndarray:
+ return self._times[
+ chunk * self.chunk_size : (chunk + 1) * self.chunk_size + self.overlap
+ ]
+
+ def __iter__(self):
+ if self.infinite:
+ while True:
+ yield from self._iter()
+ else:
+ yield from self._iter()
+
+ def __len__(self):
+ return len(self._times)
+
+ def _generator_shuffle(self, arr, worker_info=None):
+ if self._generator is None:
+ if worker_info:
+ seed = worker_info.seed
+ else:
+ seed = np.random.randint(0, 2**31) + self.rank
+
+ self._generator = np.random.default_rng(seed=(seed % 2**32))
+ self._generator.shuffle(arr)
+
+ def _iter(self):
+ num_chunks = math.ceil(len(self._times) / self.chunk_size)
+ chunk_idxs = np.arange(num_chunks)
+
+ info = torch.utils.data.get_worker_info()
+ num_workers = 1 if info is None else info.num_workers
+ worker_id = 0 if info is None else info.id
+
+ # Shard chunks across the data workers. Shard before shuffle so that all workers
+ # have the same sharding pattern and each is assigned a unique set of chunks
+ chunk_idxs = _split(chunk_idxs, worker_id, num_workers, drop_extra=False)
+
+ if self.shuffle:
+ self._generator_shuffle(chunk_idxs, info)
+
+ for chunk_idx in chunk_idxs:
+ if chunk_idx > self.max_valid_chunk_idx:
+ continue
+
+ arr = self._load_chunk(chunk_idx)
+ times_for_chunk = self._times_for_chunk(chunk_idx)
+
+ max_window_start = (
+ len(times_for_chunk) - (self.time_length - 1) * self.frame_step
+ )
+
+ window_starts = np.arange(0, max_window_start, self.window_stride)
+ if self.shuffle:
+ self._generator_shuffle(window_starts, info)
+
+ for i, start_idx in enumerate(window_starts):
+ frame_idxs = range(
+ start_idx,
+ start_idx + self.time_length * self.frame_step,
+ self.frame_step,
+ )
+
+ frames = []
+ timestamps = []
+ for idx in frame_idxs:
+ time = times_for_chunk[idx]
+ arr_i = {k: v[idx] for k, v in arr.items()}
+ timestamp = pd.Timestamp(time)
+ cftimestamp = datetime_utils.as_cftime(timestamp)
+ frames.append(arr_i)
+ timestamps.append(cftimestamp)
+
+ window_tensor = self.transform(timestamps, frames)
+
+ yield window_tensor
+
+
+class TimeMergedMapStyle(torch.utils.data.Dataset):
+ """Map-style dataset wrapping time-loaders with transform and caching support.
+
+ Applies transforms either across all batches together or on individual frames.
+ Supports model parallelism to shard data across the time axis for ranks within
+ the model group.
+ """
+
+ def __init__(
+ self,
+ times,
+ *,
+ time_loaders,
+ time_length: int = 1,
+ frame_step: int = 1,
+ transform: Callable,
+ cache_chunk_size: int = 0,
+ model_rank=0,
+ model_world_size=1,
+ batch_transform=None,
+ ):
+ """
+ Args:
+ cache_chunk_size: if nonzero, then cache data in this chunk size, so that
+ data cn be accessed efficiently in sequence.
+ batch_size: if provided
+
+ """
+ _validate_times(times, model_world_size, time_length, frame_step)
+ self.times = times
+ self.transform = transform
+ self.batch_transform = batch_transform
+ self.time_length = time_length
+ self.frame_step = frame_step
+ self.model_rank = model_rank
+ self.model_world_size = model_world_size
+ self._loader = _MergedLoader(time_loaders)
+
+ self._frame_indexer = _FrameIndexGenerator(
+ times, time_length, frame_step, model_rank, model_world_size
+ )
+
+ # Get valid length from frame indexer
+ self.valid_length = self._frame_indexer.get_valid_length()
+ if self.valid_length <= 0:
+ frames_per_window = (self.time_length - 1) * self.frame_step + 1
+ raise ValueError(
+ f"Dataset too small for window length. Need {frames_per_window} "
+ f"frames but segments have lengths {self._frame_indexer.sizes}"
+ )
+ self.cache_chunk_size = cache_chunk_size
+ self._cache_id = None # cache id is based on idx//chunk_size
+ self._cache_start = 0
+ self._cache_end = 0
+ self._cache_data = None
+
+ def __len__(self):
+ return self._frame_indexer.get_valid_length()
+
+ def _load(self, frame_idxs):
+ if not self.cache_chunk_size:
+ window_times = self.times[list(frame_idxs)]
+ window_data = sync(self._loader.sel_time(window_times))
+ return [
+ {k: v[i] for k, v in window_data.items()}
+ for i in range(len(window_times))
+ ]
+
+ frames = []
+ # TODO this caching logic only works well if frame_idxs is sequential.
+ # otherwise it will results in many cache misses
+ for i in frame_idxs:
+ cache_id = i // self.cache_chunk_size
+ if self._cache_id != cache_id:
+ self._cache_start = cache_id * self.cache_chunk_size
+ self._cache_end = min(
+ self._cache_start + self.cache_chunk_size, len(self.times)
+ )
+ window_times = self.times[self._cache_start : self._cache_end]
+ window_data = sync(self._loader.sel_time(window_times))
+ self._cache_data = [
+ {k: v[i] for k, v in window_data.items()}
+ for i in range(len(window_times))
+ ]
+
+ self._cache_id = cache_id
+ frames.append(self._cache_data[i - self._cache_start])
+ return frames
+
+ def __getitem__(self, idx):
+ result = self.__getitems__([idx])
+ if self.batch_transform:
+ return result
+ else:
+ return result[0]
+
+ def _batch_transform(self, times, frames):
+ if self.batch_transform:
+ return self.batch_transform(times, frames)
+ elif self.transform:
+ output = []
+ for sample_times, sample_frames in zip(times, frames):
+ output.append(self.transform(sample_times, sample_frames))
+ return output
+
+ @profiling.nvtx
+ def _get_times_and_frames(self, idx):
+ if min(idx) < 0 or max(idx) >= self.valid_length:
+ raise IndexError(
+ f"Index {idx} out of bounds for dataset of length {self.valid_length}"
+ )
+
+ batch_size = len(idx)
+
+ frame_idxs = self._frame_indexer.generate_frame_indices(idx)
+ flat_frame_idxs = sum(frame_idxs, start=[])
+
+ frames = self._load(flat_frame_idxs)
+ window_times = self.times[flat_frame_idxs]
+
+ timestamps = []
+ for i, time in enumerate(window_times):
+ timestamp = pd.Timestamp(time)
+ cftimestamp = datetime_utils.as_cftime(timestamp)
+ timestamps.append(cftimestamp)
+
+ def reshape(list):
+ n = len(list) // batch_size
+ return [[list[n * i + j] for j in range(n)] for i in range(batch_size)]
+
+ timestamps = reshape(timestamps)
+ frames = reshape(frames)
+
+ return timestamps, frames
+
+ @profiling.nvtx
+ def __getitems__(self, idx):
+ timestamps, frames = self._get_times_and_frames(idx)
+ return self._batch_transform(timestamps, frames)
+
+
+def _split_array_contiguous(x):
+ d = x[1] - x[0]
+ segments = []
+ start = 0
+ for i in range(1, x.size):
+ if (x[i] - x[i - 1]) != d:
+ segments.append(x[start:i])
+ start = i
+
+ if start < x.size:
+ segments.append(x[start:])
+
+ return segments
diff --git a/examples/weather/healda/datasets/obs_filtering_utils.py b/examples/weather/healda/datasets/obs_filtering_utils.py
new file mode 100644
index 0000000000..fbaf168511
--- /dev/null
+++ b/examples/weather/healda/datasets/obs_filtering_utils.py
@@ -0,0 +1,128 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Shared quality control filtering utilities for observation data.
+
+This module provides reusable filtering functions that can be used by both
+the original UFSDataset and the new UFSUnifiedLoader to ensure consistent
+quality control across different data loading approaches.
+"""
+
+import pyarrow as pa
+import pyarrow.compute as pc
+
+from datasets.sensors import (
+ CONV_GPS_LEVEL2_CHANNELS,
+ CONV_UV_CHANNELS,
+ CONV_UV_IN_SITU_TYPES,
+ SENSOR_CONFIGS,
+ SENSOR_OFFSET,
+ QCLimits,
+)
+
+
+def _get_index_range(sensor):
+ start = SENSOR_OFFSET[sensor]
+ end = start + SENSOR_CONFIGS[sensor].channels
+ return start, end
+
+
+# columns to use for filtering
+height = pc.field("Height")
+pressure = pc.field("Pressure")
+obs = pc.field("Observation")
+analysis_use = pc.field("Analysis_Use_Flag")
+qc_flag = pc.field("QC_Flag")
+min_valid = pc.field("min_valid")
+max_valid = pc.field("max_valid")
+local_id = pc.field("local_channel_id")
+is_conv = pc.field("is_conv")
+obs_type = pc.field("Observation_Type")
+
+
+def _get_conv_filter_expr(
+ table: pa.Table,
+ qc_filter: bool = False,
+ uv_in_situ_only: bool = False,
+ gps_level1_only: bool = False,
+):
+ """Get filter expression for conventional observations."""
+ is_gps = local_id <= 2
+
+ # Use QCLimits from sensors.py (single source of truth)
+ height_ok = pc.is_finite(height) & (
+ (height >= QCLimits.HEIGHT_MIN) & (height <= QCLimits.HEIGHT_MAX)
+ )
+
+ min_pressure = pc.if_else(
+ is_gps,
+ pa.scalar(QCLimits.PRESSURE_MIN_GPS),
+ pa.scalar(QCLimits.PRESSURE_MIN_DEFAULT),
+ )
+ pressure_ok = pc.is_finite(pressure)
+ pressure_ok &= (pressure >= min_pressure) & (pressure <= QCLimits.PRESSURE_MAX)
+
+ ok = pressure_ok & height_ok
+
+ if qc_filter:
+ ok &= analysis_use == pa.scalar(1)
+
+ if uv_in_situ_only:
+ is_uv_channel = pc.is_in(local_id, pa.array(CONV_UV_CHANNELS))
+ is_in_situ = pc.is_in(
+ obs_type,
+ pa.array(CONV_UV_IN_SITU_TYPES, type=table["Observation_Type"].type),
+ )
+ ok &= ~is_uv_channel | is_in_situ
+
+ if gps_level1_only:
+ ok &= ~pc.is_in(local_id, pa.array(CONV_GPS_LEVEL2_CHANNELS))
+
+ return ok
+
+
+def filter_observations(
+ table: pa.Table,
+ qc_filter: bool = False,
+ conv_uv_in_situ_only: bool = False,
+ conv_gps_level1_only: bool = False,
+) -> pa.Table:
+ """
+ Unified filtering function for observation data.
+
+ Args:
+ table: PyArrow table containing observation data
+ qc_filter: Whether to apply QC flag filtering
+ conv_uv_in_situ_only: Exclude satellite UV (keep in-situ only)
+ conv_gps_level1_only: Exclude GPS T/Q retrievals (keep bending angle)
+
+ Returns:
+ Filtered PyArrow table
+ """
+ ok = pc.is_finite(obs)
+ ok &= obs >= min_valid
+ ok &= obs <= max_valid
+
+ sat_ok = ok
+ if qc_filter:
+ sat_ok &= qc_flag == 0
+
+ conv_filter = _get_conv_filter_expr(
+ table, qc_filter, conv_uv_in_situ_only, conv_gps_level1_only
+ )
+ ok &= pc.if_else(is_conv, conv_filter, sat_ok)
+
+ return table.filter(ok)
diff --git a/examples/weather/healda/datasets/obs_loader.py b/examples/weather/healda/datasets/obs_loader.py
new file mode 100644
index 0000000000..5531db9b57
--- /dev/null
+++ b/examples/weather/healda/datasets/obs_loader.py
@@ -0,0 +1,369 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+UFS Unified Loader for the new combined schema data format.
+
+This loader handles both satellite and conventional observations using the
+unified schema produced by etl_unified.py. It provides an async interface
+compatible with TimeMergedDataset and includes quality control filtering,
+normalization, and innovation filtering.
+"""
+
+import functools
+import io
+import os
+from datetime import datetime
+from typing import List, Literal
+
+import fsspec
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pyarrow.compute as pc
+import pyarrow.parquet as pq
+from utils import storage
+
+from datasets.etl.combined_schema import (
+ GLOBAL_CHANNEL_ID,
+ SENSOR_ID,
+ get_combined_observation_schema,
+)
+from datasets.obs_filtering_utils import filter_observations
+from datasets.sensors import (
+ SENSOR_CONFIGS,
+)
+
+LOCAL_CHANNEL_ID = pa.field("local_channel_id", pa.uint16())
+
+
+def get_channel_table():
+ """Return PyArrow table mapping observation channel IDs to sensor metadata and normalization stats."""
+ import config.environment as config
+
+ return UFSUnifiedLoader(
+ config.UFS_OBS_PATH,
+ sensors=[],
+ obs_context_hours=(-3, 3),
+ normalization="zscore",
+ filesystem_type="s3" if config.UFS_OBS_PATH.startswith("s3://") else "local",
+ remote_name=config.UFS_OBS_PROFILE,
+ ).channel_table
+
+
+class UFSUnifiedLoader:
+ """
+ Unified loader for UFS observation data using the new combined schema.
+
+ This loader handles both satellite and conventional observations in a
+ unified format, providing async interface compatibility with TimeMergedDataset.
+ """
+
+ def __init__(
+ self,
+ data_path: str,
+ sensors: List[str],
+ filesystem_type: Literal["s3", "local"] = "local",
+ remote_name: str = "pdx",
+ normalization: Literal["minmax", "zscore"] = "minmax",
+ innovation_type: Literal["none", "adjusted", "unadjusted"] = "none",
+ qc_filter: bool = False,
+ filter_innovation: bool = False,
+ check_corrected: bool = True,
+ obs_context_hours: tuple[int, int] = (-24, 0),
+ data_spacing: int = 3, # hours
+ drop_obs_channel_ids: list[int] | None = None,
+ conv_uv_in_situ_only: bool = False,
+ conv_gps_level1_only: bool = False,
+ ):
+ """
+ Initialize the UFS Unified Loader.
+
+ Args:
+ data_path: Path to the processed observation data
+ sensors: List of sensors to load (e.g., ['atms', 'mhs', 'conv'])
+ filesystem_type: Type of filesystem ('s3' or 'local')
+ remote_name: Remote storage name for S3
+ normalization: Normalization method ('minmax' or 'zscore')
+ innovation_type: Innovation type to use ('none', 'adjusted', 'unadjusted')
+ qc_filter: Whether to apply quality control filtering
+ filter_innovation: Whether to filter based on innovation values
+ check_corrected: Whether to validate corrected observation values
+ obs_context_hours: Hours relative to target time for observation context
+ data_spacing: Hours between data points
+ drop_obs_channel_ids: Global channel IDs to drop
+ conv_uv_in_situ_only: Exclude satellite UV (keep in-situ only)
+ conv_gps_level1_only: Exclude GPS T/Q (keep bending angle)
+ """
+ self.data_path = data_path
+ self.sensors = sensors
+ self.filesystem_type = filesystem_type
+ self.remote_name = remote_name
+ self.normalization = normalization
+ self.innovation_type = innovation_type
+ self.qc_filter = qc_filter
+ self.filter_innovation = filter_innovation
+ self.check_corrected = check_corrected
+ self.obs_context_hours = obs_context_hours
+ self.data_spacing = data_spacing
+ # Optional list of global observation channel IDs (GLOBAL_CHANNEL_ID)
+ # to drop before normalization and further processing.
+ self.drop_obs_channel_ids = (
+ list(drop_obs_channel_ids) if drop_obs_channel_ids is not None else []
+ )
+ self.conv_uv_in_situ_only = conv_uv_in_situ_only
+ self.conv_gps_level1_only = conv_gps_level1_only
+
+ # Validate sensors
+ for sensor in self.sensors:
+ if sensor not in SENSOR_CONFIGS:
+ raise ValueError(
+ f"Unconfigured sensor: {sensor}. Available: {list(SENSOR_CONFIGS.keys())}"
+ )
+
+ # Setup filesystem
+ if self.filesystem_type == "s3":
+ self.fs = fsspec.filesystem(
+ "s3", **storage.get_storage_options(remote_name)
+ )
+ elif self.filesystem_type == "local":
+ self.fs = None
+ else:
+ raise ValueError(
+ f"Unsupported filesystem_type: {filesystem_type}. Use 's3' or 'local'"
+ )
+
+ # Load channel table for normalization
+ self._channel_table = None
+
+ @property
+ def output_schema(self) -> pa.Schema:
+ """Get the output schema including the sensor and platform columns."""
+ base_schema = get_combined_observation_schema()
+ return base_schema.append(LOCAL_CHANNEL_ID).append(SENSOR_ID)
+
+ @functools.cached_property
+ def channel_table(self) -> pa.Table:
+ """Load the channel table for normalization."""
+ channel_table_path = os.path.join(self.data_path, "channel_table.parquet")
+ if self.fs is not None:
+ file = io.BytesIO(self.fs.cat_file(channel_table_path))
+ else:
+ file = channel_table_path
+
+ table = pq.read_table(file)
+ sensor_id = np.asarray(table["sensor_id"])
+ local_channel_ids = []
+ offset = 0
+ for i in range(len(sensor_id)):
+ if sensor_id[i] != sensor_id[i - 1]:
+ offset = i
+ local_channel_ids.append(i - offset)
+ array = pa.array(local_channel_ids).cast(LOCAL_CHANNEL_ID.type)
+ return table.append_column(LOCAL_CHANNEL_ID, array)
+
+ def _get_interval_times(self, dt: datetime) -> pd.DatetimeIndex:
+ """Get times in the observation context interval."""
+ start, end = self.obs_context_hours
+ start += self.data_spacing # Window times are end-aligned
+
+ return pd.date_range(
+ dt + pd.Timedelta(hours=start),
+ dt + pd.Timedelta(hours=end),
+ freq=f"{self.data_spacing}h",
+ )
+
+ def _get_parquet_files_to_read(self, interval_times: pd.DatetimeIndex):
+ """Get parquet files to read for given time interval."""
+ required_dates = {t.strftime("%Y%m%d") for t in interval_times}
+
+ for sensor in self.sensors:
+ for date in required_dates:
+ file_path = os.path.join(self.data_path, sensor, f"{date}", "0.parquet")
+ yield (sensor, file_path)
+
+ def _iterate_parquet_da_windows(
+ self,
+ parquet_path: str,
+ target_windows: pd.DatetimeIndex,
+ ):
+ """
+ Stream Arrow Tables, one per DA_window row-group.
+
+ Args:
+ parquet_path: Path to parquet file
+ target_windows: Only yield these DA_window
+
+ Yields:
+ PyArrow tables, one per DA window
+
+ Note:
+ Silently skips files that don't exist or can't be read
+ """
+ try:
+ if self.fs is not None:
+ file = io.BytesIO(self.fs.cat_file(parquet_path))
+ else:
+ file = parquet_path
+
+ parquet = pq.ParquetFile(file)
+ schema = parquet.schema_arrow
+
+ # With uniform schema, just read all columns
+ da_idx = schema.get_field_index("DA_window")
+
+ for row_group_idx in range(parquet.num_row_groups):
+ stats = (
+ parquet.metadata.row_group(row_group_idx).column(da_idx).statistics
+ )
+ row_group_lo, row_group_hi = stats.min, stats.max
+
+ this_window = None
+ for w in target_windows:
+ if row_group_lo <= w <= row_group_hi:
+ this_window = w
+
+ if this_window is None:
+ continue
+
+ # Read all columns - no need for platform-specific selection
+ table = parquet.read_row_group(row_group_idx)
+
+ # Filter if row-group spans multiple windows
+ if row_group_lo != row_group_hi:
+ mask = pc.is_in(table["DA_window"], pa.array(list(target_windows)))
+ table = table.filter(mask)
+
+ if table.num_rows == 0:
+ continue
+
+ yield this_window, table
+ except (FileNotFoundError, OSError):
+ # File doesn't exist or can't be read - silently skip
+ return
+
+ def _filter_observations(self, table: pa.Table) -> pa.Table:
+ return filter_observations(
+ table,
+ self.qc_filter,
+ conv_uv_in_situ_only=self.conv_uv_in_situ_only,
+ conv_gps_level1_only=self.conv_gps_level1_only,
+ )
+
+ def _normalize_observations(
+ self,
+ table: pa.Table,
+ ) -> pa.Table:
+ """Normalize observation data using PyArrow compute functions."""
+ if self.normalization == "minmax":
+ # Simple minmax normalization (0-400 range)
+ normalized = pc.divide(pc.subtract(table["Observation"], 0), 400 - 0)
+ elif self.normalization == "zscore":
+ # Normalize using the joined mean and stddev columns
+ normalized = pc.divide(
+ pc.subtract(table["Observation"], table["mean"]), table["stddev"]
+ )
+ else:
+ raise ValueError(f"Unknown normalization type: {self.normalization}")
+ return table.set_column(
+ table.schema.get_field_index("Observation"),
+ "Observation",
+ normalized,
+ )
+
+ _extra_channel_fields = ["min_valid", "max_valid", "is_conv", "mean", "stddev"]
+
+ def _add_channel_metadata(self, table):
+ return table.join(
+ self.channel_table.select(
+ [
+ GLOBAL_CHANNEL_ID.name,
+ LOCAL_CHANNEL_ID.name,
+ SENSOR_ID.name,
+ *self._extra_channel_fields,
+ ]
+ ),
+ GLOBAL_CHANNEL_ID.name,
+ )
+
+ async def sel_time(self, times: pd.DatetimeIndex) -> pa.Table:
+ """
+ Load observation data for specified times.
+
+ Args:
+ times: Target times to load data for
+
+ Returns:
+ PyArrow table containing observation data (sorted by sensor and the obs_window)
+ """
+ # Get all times needed for the context window
+ all_times = set()
+ for t in times:
+ interval_times = self._get_interval_times(t)
+ all_times.update(interval_times)
+
+ interval_times = pd.DatetimeIndex(sorted(all_times))
+
+ # Get files to read
+ files_to_read = self._get_parquet_files_to_read(interval_times)
+
+ # Load data from all files
+
+ tables = {}
+ for sensor, file_path in files_to_read:
+ for interval_time, table in self._iterate_parquet_da_windows(
+ file_path, interval_times
+ ):
+ table = self._add_channel_metadata(table)
+ table = self._filter_observations(table)
+ # Drop specified global channels, if any
+ if self.drop_obs_channel_ids:
+ mask = pc.is_in(
+ table[GLOBAL_CHANNEL_ID.name],
+ pa.array(self.drop_obs_channel_ids).cast(
+ table[GLOBAL_CHANNEL_ID.name].type
+ ),
+ )
+ # Keep rows whose GLOBAL_CHANNEL_ID is NOT in drop list
+ table = table.filter(pc.invert(mask))
+ table = self._normalize_observations(table)
+ table = table.drop(self._extra_channel_fields)
+ # Apply normalization to observations using PyArrow
+ tables.setdefault(interval_time, []).append(table)
+
+ # Combine all observations
+ def process(t):
+ all_tables = [
+ table
+ for interval_time in self._get_interval_times(t)
+ for table in tables.get(interval_time, [])
+ ]
+
+ if not all_tables:
+ return empty
+
+ table = pa.concat_tables(all_tables)
+ # table = table.combine_chunks()
+ # Cast to ensure proper nullability and types
+ # it's 3x faster to filter the combined table
+ return table.cast(self.output_schema)
+
+ empty = self._get_empty_table()
+ return {"obs_v2": [process(t) for t in times]}
+
+ def _get_empty_table(self):
+ # Return empty table with proper schema
+ empty_arrays = [pa.array([], type=field.type) for field in self.output_schema]
+ return pa.table(empty_arrays, schema=self.output_schema)
diff --git a/examples/weather/healda/datasets/obs_time_range_loader.py b/examples/weather/healda/datasets/obs_time_range_loader.py
new file mode 100644
index 0000000000..39ee367123
--- /dev/null
+++ b/examples/weather/healda/datasets/obs_time_range_loader.py
@@ -0,0 +1,202 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pyarrow.compute as pc
+import pyarrow.parquet as pq
+from config.environment import UFS_OBS_PATH, UFS_OBS_PROFILE
+from utils import storage
+
+from datasets.etl.combined_schema import get_combined_observation_schema
+
+logger = logging.getLogger(__name__)
+
+TIME_COLUMN = "Absolute_Obs_Time"
+
+
+def _da_window_from_time(t: pd.Timestamp) -> pd.Timestamp:
+ # da window at 21 UTC is (21-3, 21]
+ t0 = np.datetime64("2000-01-01T00")
+ dt = np.timedelta64(3, "h")
+
+ delta = (t - t0) % dt
+ return t if delta == np.timedelta64(0, "h") else t + dt - delta
+
+
+def _get_file_names(
+ dir,
+ sensors,
+ time_min: pd.Timestamp,
+ time_max: pd.Timestamp,
+):
+ da_min = _da_window_from_time(time_min)
+ da_max = _da_window_from_time(time_max)
+
+ # get date range
+ dates = []
+ this_date = da_min.date()
+ day = pd.Timedelta(1, "d")
+ while this_date <= da_max.date():
+ dates.append(this_date)
+ this_date += day
+
+ file_names = [
+ os.path.join(dir, sensor, date.strftime("%Y%m%d"), "0.parquet")
+ for date in dates
+ for sensor in sensors
+ ]
+ return file_names
+
+
+def _scan_parquet_dir(dir, sensors, time_min, time_max, columns=(), filesystem=None):
+ if dir.startswith("s3://"):
+ dir = dir[len("s3://") :]
+ for file_name in _get_file_names(dir, sensors, time_min, time_max):
+ try:
+ with pq.ParquetFile(file_name, filesystem=filesystem) as f:
+ yield from _scan_parquet_file(
+ f,
+ column=TIME_COLUMN,
+ time_min=time_min,
+ time_max=time_max,
+ columns=columns,
+ )
+ except FileNotFoundError:
+ logging.getLogger(__name__).debug(f"{file_name} not found.")
+ pass
+
+
+def _open_channel_table(dir, filesystem=None) -> pa.Table:
+ if dir.startswith("s3://"):
+ dir = dir[len("s3://") :]
+ channel_table_path = os.path.join(dir, "channel_table.parquet")
+ with pq.ParquetFile(channel_table_path, filesystem=filesystem) as f:
+ return f.read()
+
+
+def _scan_parquet_file(
+ parquet,
+ column,
+ time_min,
+ time_max,
+ columns=(),
+):
+ """
+ Stream Arrow Tables, one per DA_window row-group.
+
+ Args:
+ parquet_path: Path to parquet file
+ target_windows: Only yield these DA_window
+
+ Yields:
+ PyArrow tables, one per DA window
+
+ Note:
+ Silently skips files that don't exist or can't be read
+ """
+ schema = parquet.schema_arrow
+
+ # With uniform schema, just read all columns
+ da_idx = schema.get_field_index(column)
+
+ for row_group_idx in range(parquet.num_row_groups):
+ stats = parquet.metadata.row_group(row_group_idx).column(da_idx).statistics
+
+ # print(stats.min < time_min, stats.max, time_min)
+ if stats.max < time_min:
+ continue
+
+ if stats.min > time_max:
+ continue
+
+ # Read all columns - no need for platform-specific selection
+ table = parquet.read_row_group(row_group_idx, list(columns) + [column])
+
+ time_col = pc.field(column)
+ filter = (time_col >= time_min) & (time_col <= time_max)
+ table = table.filter(filter)
+ table = table.select(columns)
+
+ if table.num_rows == 0:
+ continue
+
+ yield table
+
+
+class Loader:
+ """Parquet obs loader
+
+ Allows selecting data based on a time range. only supports scalar time
+ bounds at the moment.
+ """
+
+ def __init__(
+ self,
+ sensors=("atms", "conv"),
+ columns=(
+ "Latitude",
+ "Longitude",
+ "Global_Channel_ID",
+ "Height",
+ "Pressure",
+ TIME_COLUMN,
+ ),
+ join_channel_table: bool = True,
+ ):
+ self.sensors = sensors
+ self._filesystem = storage.get_pyarrow_filesystem(
+ UFS_OBS_PROFILE, connect_timeout=1_000, request_timeout=1_000
+ )
+ self.channel_table = _open_channel_table(UFS_OBS_PATH, self._filesystem)
+ self.columns = columns
+ self.join_channel_table = join_channel_table
+
+ @property
+ def schema(self):
+ schema = get_combined_observation_schema()
+ schema = pa.schema([f for f in schema if f.name in self.columns])
+ return schema
+
+ def _get_empty(self):
+ obs = pa.table([[]] * len(self.schema), schema=self.schema)
+ return obs
+
+ def sel_time_range(
+ self, time_min: pd.Timestamp, time_max: pd.Timestamp
+ ) -> pa.Table:
+ scanner = _scan_parquet_dir(
+ UFS_OBS_PATH,
+ sensors=self.sensors,
+ columns=self.columns,
+ time_min=time_min,
+ time_max=time_max,
+ filesystem=self._filesystem,
+ )
+ obs = list(scanner)
+ if len(obs) > 0:
+ obs = pa.concat_tables(obs)
+ else:
+ logger.warning(f"No observations loaded for {time_min} -- {time_max}")
+ obs = self._get_empty()
+
+ if self.join_channel_table:
+ obs = obs.join(self.channel_table, "Global_Channel_ID")
+
+ return obs
diff --git a/examples/weather/healda/datasets/prefetch_map.py b/examples/weather/healda/datasets/prefetch_map.py
new file mode 100644
index 0000000000..404142bc43
--- /dev/null
+++ b/examples/weather/healda/datasets/prefetch_map.py
@@ -0,0 +1,194 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Async data loader that processes batches in a background thread with a separate CUDA stream.
+"""
+
+import dataclasses
+import queue
+import threading
+from typing import Any, Callable, Iterable, Optional
+
+import torch
+from torch.utils.data import DataLoader
+
+
+class _Done:
+ pass
+
+
+class _PrefetchIterator:
+ """
+ Wraps a PyTorch DataLoader to process batches asynchronously in a background thread.
+
+ The background thread uses a separate CUDA stream for processing, which synchronizes
+ its stream before adding a sample to the queue.
+ """
+
+ def __init__(
+ self,
+ dataloader: Iterable,
+ transform: Callable[[Any], Any],
+ queue_size: int = 2,
+ cuda_stream: Optional[torch.cuda.Stream] = None,
+ ):
+ """
+ Args:
+ dataloader: The PyTorch DataLoader to wrap
+ transform: Function to apply to each batch (batch -> batch)
+ queue_size: Maximum size of the processing queue (default: 2)
+ cuda_stream: CUDA stream to use for background processing (creates new if None)
+ """
+ self.dataloader = dataloader
+ self.transform = transform
+ self.queue_size = queue_size
+ self.cuda_stream = cuda_stream or torch.cuda.Stream()
+
+ # Threading components
+ self.queue = queue.Queue(maxsize=queue_size)
+ self.thread = None
+ self.stop_event = threading.Event()
+
+ # Iterator state
+ self.dataloader_iter = None
+ self._started = False
+
+ def _worker(self):
+ """Background worker that processes batches."""
+ try:
+ while not self.stop_event.is_set():
+ try:
+ # Get next batch from dataloader
+ batch = next(self.dataloader_iter)
+ except StopIteration:
+ # No more data, put sentinel and break
+ self.queue.put((_Done, None))
+ break
+
+ # Process batch in background CUDA stream
+ with torch.cuda.stream(self.cuda_stream):
+ processed_batch = self.transform(batch)
+
+ # Synchronize this stream to ensure work is complete before sending to main thread
+ # alternatively, could use cuda events for synchronization
+ self.cuda_stream.synchronize()
+
+ # Put processed batch in queue
+ self.queue.put((processed_batch, None))
+
+ except Exception as e:
+ self.queue.put((None, e))
+
+ def _start(self):
+ """Start the background processing thread."""
+ if self._started:
+ return
+
+ self.dataloader_iter = iter(self.dataloader)
+ self.stop_event.clear()
+ self.thread = threading.Thread(target=self._worker, daemon=True)
+ self.thread.start()
+ self._started = True
+
+ def __len__(self):
+ return len(self.dataloader)
+
+ def _stop(self):
+ """Stop the background processing thread."""
+ if not self._started:
+ return
+
+ self.stop_event.set()
+ if self.thread and self.thread.is_alive():
+ self.thread.join(timeout=1.0)
+ self._started = False
+
+ def __iter__(self):
+ """Start background processing and return iterator."""
+ self._start()
+ return self
+
+ def _record_stream(self, x):
+ """Marks tensors as having been used by this stream"""
+ if isinstance(x, torch.Tensor):
+ x.record_stream(self.cuda_stream)
+ elif isinstance(x, list):
+ for item in x:
+ self._record_stream(item)
+ elif isinstance(x, dict):
+ for item in x.values():
+ self._record_stream(item)
+ elif dataclasses.is_dataclass(x):
+ x.record_stream(self.cuda_stream)
+
+ def __next__(self):
+ """Get next processed batch."""
+ if not self._started:
+ raise RuntimeError("Iterator not started. Call __iter__ first.")
+
+ # Get processed batch from queue
+ try:
+ batch, error = self.queue.get()
+ except queue.Empty:
+ raise RuntimeError("Timeout waiting for processed batch")
+
+ if error is not None:
+ raise error
+
+ # Check for end of data
+ if batch is _Done:
+ self._stop()
+ raise StopIteration
+
+ # Needed for safe garbage collection: ensures that we do not deallocate the batch before
+ # work on it has completed
+ self._record_stream(batch)
+
+ return batch
+
+ def __del__(self):
+ """Cleanup on deletion."""
+ self._stop()
+
+
+def prefetch_map(
+ dataloader: DataLoader,
+ transform: Callable[[Any], Any],
+ queue_size: int = 2,
+ cuda_stream: Optional[torch.cuda.Stream] = None,
+) -> _PrefetchIterator:
+ """
+ Create an async data loader that processes batches in a background thread.
+
+ Args:
+ dataloader: The PyTorch DataLoader to wrap
+ transform: Function to apply to each batch (batch -> batch)
+ queue_size: Maximum size of the processing queue (default: 2)
+ cuda_stream: CUDA stream to use for background processing (creates new if None)
+
+ Returns:
+ AsyncDataLoader that can be iterated over like a regular DataLoader
+
+ Example:
+ >>> def move_to_gpu(batch):
+ ... return {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+ >>>
+ >>> async_loader = async_map(dataloader, move_to_gpu)
+ >>> for batch in async_loader:
+ ... # batch is already on GPU and processed
+ ... pass
+ """
+ return _PrefetchIterator(dataloader, transform, queue_size, cuda_stream)
diff --git a/examples/weather/healda/datasets/round_robin.py b/examples/weather/healda/datasets/round_robin.py
new file mode 100644
index 0000000000..6b12ab9a2c
--- /dev/null
+++ b/examples/weather/healda/datasets/round_robin.py
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+class RoundRobinLoader(torch.utils.data.IterableDataset):
+ """Round-robin interleaving of multiple map-style dataloaders.
+
+ This allows converting map-style datasets to iterable-style. It loads data
+ from the dataloaders in a round-robin manner, removing exhausted dataloaders
+ from rotation until ALL dataloaders are exhausted.
+
+ Args:
+ dataloaders: List of DataLoader instances to interleave
+ """
+
+ def __init__(self, dataloaders: list[torch.utils.data.DataLoader]):
+ super().__init__()
+ self.dataloaders = dataloaders
+
+ def __len__(self):
+ return sum(len(dl) for dl in self.dataloaders)
+
+ def __iter__(self):
+ iterators = [iter(dl) for dl in self.dataloaders]
+ active_indices = list(range(len(self.dataloaders)))
+
+ while active_indices:
+ for idx in list(active_indices):
+ try:
+ yield next(iterators[idx])
+ except StopIteration:
+ # Remove exhausted dataloader from rotation
+ active_indices.remove(idx)
diff --git a/examples/weather/healda/datasets/samplers.py b/examples/weather/healda/datasets/samplers.py
new file mode 100644
index 0000000000..dbe7829a67
--- /dev/null
+++ b/examples/weather/healda/datasets/samplers.py
@@ -0,0 +1,255 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import random
+
+import torch
+import torch.utils.data
+from utils import distributed as dist
+
+
+def subsample(dataset, min_samples):
+ samples = min_samples % dist.get_world_size() + min_samples
+ golden_ratio = 1.618033988749
+ n = len(dataset)
+ sampler = [int((i * n * golden_ratio) % n) for i in range(samples)]
+ sampler = sorted(sampler)
+ return sampler
+
+
+def distributed_split(tasks, drop_last=True):
+ n = len(tasks)
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ chunk = math.ceil(len(tasks) / world_size)
+ start = rank * chunk
+ stop = n if drop_last and (rank == world_size - 1) else start + chunk
+ return [t for i, t in enumerate(tasks) if start <= i < stop]
+
+
+class InfiniteSequentialSampler(torch.utils.data.Sampler):
+ """An infinite sampler that iterates sequentially through a dataset
+ reshuffling every ``shuffle_every`` iterations
+ """
+
+ def __init__(
+ self,
+ dataset: torch.utils.data.Dataset,
+ shuffle: bool = True,
+ shuffle_every: int = 48,
+ ):
+ self.shuffle = shuffle
+ self.shuffle_every = shuffle_every
+ self.n = len(dataset)
+ self.rank = dist.get_rank()
+ self.replicas = dist.get_world_size()
+
+ def __iter__(self):
+ i = random.randint(0, self.n - 1)
+ k = 0
+ while True:
+ if (self.shuffle_every > 0) and (k % self.shuffle_every == 0):
+ i = random.randint(0, self.n - 1)
+
+ yield i
+
+ i = (i + 1) % self.n
+ k += 1
+
+
+class InfiniteChunkedIterable(torch.utils.data.IterableDataset):
+ """
+ Infinitely yields batches of contiguous samples from the dataset, reshuffling every
+ ``chunk_size // batch_size`` batches. As each worker runs __iter__ in its own process,
+ workers are assigned independent chunks. Data is yielded from workers round-robin, so
+ chunks will be interleaved across iterations.
+ """
+
+ def __init__(
+ self,
+ base_dataset: torch.utils.data.Dataset,
+ chunk_size: int = 48,
+ batch_size: int = 4,
+ ):
+ """
+ Args:
+ base_dataset: A map-style dataset (e.g. HealpixDatasetV5).
+ chunk_size: Number of consecutive samples in each chunk.
+ batch_size: Size of the mini-batches yielded to the main loop.
+ """
+ super().__init__()
+ self.dataset = base_dataset
+ self.n = len(base_dataset)
+ self.chunk_size = chunk_size
+ self.batch_size = batch_size
+
+ def __iter__(self):
+ while True:
+ start_idx = random.randint(0, self.n - 1)
+ indices = [(start_idx + j) % self.n for j in range(self.chunk_size)]
+
+ for i in range(0, len(indices), self.batch_size):
+ batch = [self.dataset[idx] for idx in indices[i : i + self.batch_size]]
+ yield torch.utils.data.default_collate(batch) # batch the list of dicts
+
+
+class ChunkedDistributedSampler(torch.utils.data.Sampler):
+ """A chunked random sampler. This allows accessing the dataset sequentially
+ within chunks, for better performance w/ chunked datasets that have caching
+ implemented.
+ """
+
+ def __init__(
+ self,
+ dataset: torch.utils.data.Dataset,
+ chunk_size: int = 1,
+ rank=0,
+ num_replicas=1,
+ shuffle=False,
+ shuffle_within_chunk=False,
+ drop_last=True,
+ seed=42,
+ sampler_fn=None,
+ ):
+ """
+ Args:
+ base_dataset: A map-style dataset (e.g. HealpixDatasetV5).
+ chunk_size: Number of consecutive samples in each chunk.
+ batch_size: Size of the mini-batches yielded to the main loop.
+ shuffle: Whether to shuffle order of chunks.
+ shuffle_within_chunk: Whether to shuffle indices within each chunk.
+ seed: random seed for the sampler, will be broadcasted from rank 0 to all other ranks
+ """
+ super().__init__()
+ self.n = len(dataset)
+ nchunks = self.n // chunk_size
+ chunks = list(range(nchunks))
+
+ if torch.distributed.is_initialized():
+ seed = torch.tensor(seed).cuda()
+ torch.distributed.broadcast(seed, src=0)
+ seed = seed.item()
+
+ self._chunk_sampler = (
+ sampler_fn(chunks)
+ if sampler_fn is not None
+ else torch.utils.data.DistributedSampler(
+ chunks,
+ num_replicas=num_replicas,
+ rank=rank,
+ shuffle=shuffle,
+ seed=seed,
+ drop_last=drop_last,
+ )
+ )
+ self.chunk_size = chunk_size
+ self.shuffle_within_chunk = shuffle_within_chunk
+ self.seed = seed
+ self.rank = rank
+ self.epoch = 0
+ self.index_within_chunk = 0
+ self._chunk_iter = iter(self._chunk_sampler)
+ self._current_chunk_indices = None
+
+ if self.shuffle_within_chunk:
+ self.rng = random.Random(seed + rank)
+
+ def set_epoch(self, epoch):
+ try:
+ self._chunk_sampler.set_epoch(epoch)
+ except AttributeError:
+ pass
+ self.epoch = epoch
+
+ def __len__(self):
+ return self.n
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.index_within_chunk == 0:
+ try:
+ self.active_chunk = next(self._chunk_iter)
+ except StopIteration:
+ self.set_epoch(self.epoch + 1) # reset sampler's rng
+ self._chunk_iter = iter(self._chunk_sampler)
+ raise StopIteration()
+
+ chunk_start = self.active_chunk * self.chunk_size
+ self._current_chunk_indices = list(
+ range(chunk_start, chunk_start + self.chunk_size)
+ )
+
+ if self.shuffle_within_chunk:
+ self.rng.shuffle(self._current_chunk_indices)
+
+ i = self._current_chunk_indices[self.index_within_chunk]
+ self.index_within_chunk = (self.index_within_chunk + 1) % self.chunk_size
+ return i
+
+
+class RestartableDistributedSampler(torch.utils.data.Sampler):
+ """A stateful distributed sampler that automatically loops over the dataset."""
+
+ def __init__(
+ self,
+ dataset: torch.utils.data.Dataset,
+ rank=0,
+ num_replicas=1,
+ shuffle=True,
+ drop_last=True,
+ seed=42,
+ ):
+ super().__init__()
+ self.iteration = 0
+ self.epoch = 0
+ self.len = len(dataset)
+ self.seed = seed
+ self.permutation = None
+ self.rank = rank
+ self.num_replicas = num_replicas
+
+ def __len__(self):
+ return self.len // self.num_replicas
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+ self.iteration = 0
+ rng = torch.Generator().manual_seed(self.seed + self.epoch + self.rank)
+ permutation = torch.randperm(self.len, generator=rng)
+
+ rem = self.len % self.num_replicas
+ if rem > 0:
+ permutation = permutation[:-rem]
+ self.permutation = permutation[self.rank :: self.num_replicas]
+
+ def restart(self, epoch, iteration, seed=None):
+ self.seed = seed or self.seed
+ self.set_epoch(epoch)
+ self.iteration = iteration
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.iteration >= len(self):
+ self.set_epoch(self.epoch + 1)
+ raise StopIteration()
+
+ idx = self.permutation[self.iteration]
+ self.iteration += 1
+ return idx
diff --git a/examples/weather/healda/datasets/sensors.py b/examples/weather/healda/datasets/sensors.py
new file mode 100644
index 0000000000..a2f716ec51
--- /dev/null
+++ b/examples/weather/healda/datasets/sensors.py
@@ -0,0 +1,260 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pathlib
+import warnings
+from dataclasses import dataclass, field
+
+import numpy as np
+import pandas as pd
+
+
+@dataclass
+class SensorConfig:
+ """
+ Sensor metadata that sets up data loading.
+ Defines the sensor name, platforms, channels, and normalization stats.
+ """
+
+ name: str
+ platforms: list[str]
+ channels: int
+ nc_file_template: str
+ means: np.ndarray = field(init=False)
+ stds: np.ndarray = field(init=False)
+ min_valid: float = 0.0
+ max_valid: float = 400.0
+ sensor_type: str = "microwave"
+ raw_to_local: np.ndarray = field(
+ init=False
+ ) # lookup-table: raw_id โ local channel
+
+ def __post_init__(self):
+ base = pathlib.Path(__file__).parent / "etl/normalizations"
+ norm_file = base / f"{self.name}_normalizations.csv"
+
+ if norm_file.exists():
+ df = pd.read_csv(norm_file)
+ # Col -1 is the avg across all platforms
+ channel_col = "Raw_Channel_ID"
+ df = df[df["Platform_ID"] == -1].sort_values(channel_col)
+
+ self.means = df["obs_mean"].to_numpy()
+ self.stds = df["obs_std"].to_numpy()
+
+ # build a rawโtoโlocal LUT
+ raw_ids = df[channel_col].to_numpy()
+ max_raw = raw_ids.max()
+ lookup_table = np.full(max_raw + 1, 0, dtype=int)
+ for local_idx, raw in enumerate(raw_ids, start=1):
+ lookup_table[raw] = local_idx
+ self.raw_to_local = lookup_table
+
+ else:
+ warnings.warn(
+ f"No normalization file for {self.name!r}. "
+ "Defaulting to means=0, stds=1, identity mapping."
+ )
+ self.means = np.zeros(self.channels, dtype=float)
+ self.stds = np.ones(self.channels, dtype=float)
+ self.raw_to_local = np.arange(self.channels + 1, dtype=int)
+
+
+def get_global_channel_id(sensor, raw_channel_ids):
+ """Map per-sensor raw channel IDs to unified global IDs (no overlap across sensors)."""
+ raw_to_local = SENSOR_CONFIGS[sensor].raw_to_local
+ channel_offset = SENSOR_OFFSET[sensor]
+ local_channels = raw_to_local[raw_channel_ids] - 1 # Convert to 0-based indexing
+ return (local_channels + channel_offset).astype(np.uint16)
+
+
+SENSOR_CONFIGS = {
+ "atms": SensorConfig(
+ name="atms",
+ platforms=["npp", "n20"],
+ channels=22,
+ nc_file_template="diag_atms_{platform}_ges.{date}_control.nc4",
+ min_valid=0.0,
+ max_valid=400.0,
+ sensor_type="microwave",
+ ),
+ "mhs": SensorConfig(
+ name="mhs",
+ platforms=["metop-a", "metop-b", "metop-c", "n18", "n19"],
+ channels=5,
+ nc_file_template="diag_mhs_{platform}_ges.{date}_control.nc4",
+ min_valid=0.0,
+ max_valid=400.0,
+ sensor_type="microwave",
+ ),
+ "amsua": SensorConfig(
+ name="amsua",
+ platforms=["metop-a", "metop-b", "metop-c", "n15", "n16", "n17", "n18", "n19"],
+ channels=15,
+ nc_file_template="diag_amsua_{platform}_ges.{date}_control.nc4",
+ min_valid=0.0,
+ max_valid=400.0,
+ sensor_type="microwave",
+ ),
+ "amsub": SensorConfig(
+ name="amsub",
+ platforms=["n15", "n16", "n17"],
+ channels=5,
+ nc_file_template="diag_amsub_{platform}_ges.{date}_control.nc4",
+ min_valid=0.0,
+ max_valid=400.0,
+ sensor_type="microwave",
+ ),
+ "iasi": SensorConfig(
+ name="iasi",
+ platforms=["metop-a", "metop-b", "metop-c"],
+ channels=175,
+ nc_file_template="diag_iasi_{platform}_ges.{date}_control.nc4",
+ min_valid=150.0,
+ max_valid=350.0,
+ sensor_type="infrared",
+ ),
+ "cris-fsr": SensorConfig(
+ name="cris-fsr",
+ platforms=["npp", "n20"],
+ channels=100,
+ nc_file_template="diag_cris_fsr_{platform}_ges.{date}_control.nc4",
+ min_valid=150.0,
+ max_valid=350.0,
+ sensor_type="infrared",
+ ),
+ "conv": SensorConfig(
+ name="conv",
+ platforms=[], # platform idea doesn't apply to conv
+ channels=8, # all conv sensors stacked (gps angle, gps temp, gps spfh, ps, q, t, u, v)
+ nc_file_template="conv_{platform}_ges.{date}_control.nc4",
+ sensor_type="conv",
+ ),
+}
+
+
+class QCLimits:
+ """Conventional Observation QC filtering limits."""
+
+ # Height limits (meters)
+ HEIGHT_MIN = 0
+ HEIGHT_MAX = 60000
+ # Pressure limits (hPa)
+ PRESSURE_MIN_GPS = 0.5
+ PRESSURE_MIN_DEFAULT = 200
+ PRESSURE_MAX = 1100
+
+
+# Concept of platform for conv is only used in etl, does not apply outside of etl. All conv obs have platform 0
+@dataclass(frozen=True)
+class ConvChannel:
+ """Conv sensor channel definition, used for ETL and creating channel table"""
+
+ name: str
+ platform: str
+ nc_column: str
+ min_valid: float
+ max_valid: float
+
+
+CONV_CHANNELS = [
+ ConvChannel("gps_angle", "gps", "Observation", float("-inf"), float("inf")),
+ ConvChannel("gps_t", "gps", "Temperature_at_Obs_Location", 150, 350),
+ ConvChannel("gps_q", "gps", "Specific_Humidity_at_Obs_Location", 0.0, 1.0),
+ ConvChannel("ps", "ps", "Observation", float("-inf"), float("inf")),
+ ConvChannel("q", "q", "Observation", 0, 1),
+ ConvChannel("t", "t", "Observation", 150, 350),
+ ConvChannel("u", "uv", "u_Observation", -100, 100),
+ ConvChannel("v", "uv", "v_Observation", -100, 100),
+]
+
+CONV_CHANNEL_NAMES = [c.name for c in CONV_CHANNELS]
+CONV_PLATFORMS = list(dict.fromkeys(c.platform for c in CONV_CHANNELS))
+CONV_GPS_CHANNELS = [i for i, c in enumerate(CONV_CHANNELS) if c.platform == "gps"]
+CONV_GPS_LEVEL2_CHANNELS = [
+ i for i, c in enumerate(CONV_CHANNELS) if c.name in ("gps_t", "gps_q")
+]
+CONV_UV_CHANNELS = [i for i, c in enumerate(CONV_CHANNELS) if c.platform == "uv"]
+CONV_UV_IN_SITU_TYPES = [220, 221, 229, 230, 231, 232, 233, 234, 235, 280, 282]
+
+
+def _build_conv_channel_map() -> dict[str, int]:
+ """Build map from platform name to first channel ID (1-indexed)."""
+ channel_map = {}
+ for i, channel in enumerate(CONV_CHANNELS, start=1):
+ if channel.platform not in channel_map:
+ channel_map[channel.platform] = i
+ return channel_map
+
+
+CONV_CHANNEL_MAP = _build_conv_channel_map()
+
+
+def _next_power_of_two(n: int) -> int:
+ return 1 << (n - 1).bit_length()
+
+
+PLATFORM_NAME_TO_ID = {
+ "aqua": 0,
+ "aura": 1,
+ "f10": 2,
+ "f11": 3,
+ "f13": 4,
+ "f14": 5,
+ "f15": 6,
+ "g08": 7,
+ "g10": 8,
+ "g11": 9,
+ "g12": 10,
+ "m08": 11,
+ "m09": 12,
+ "m10": 13,
+ "metop-a": 14,
+ "metop-b": 15,
+ "metop-c": 16,
+ "n11": 17,
+ "n12": 18,
+ "n14": 19,
+ "n15": 20,
+ "n16": 21,
+ "n17": 22,
+ "n18": 23,
+ "n19": 24,
+ "n20": 25,
+ "npp": 26,
+ "gps": 27,
+ "ps": 28,
+ "q": 29,
+ "t": 30,
+ "uv": 31,
+}
+
+PLATFORM_ID_TO_NAME = {v: k for k, v in PLATFORM_NAME_TO_ID.items()}
+
+NPLATFORMS = _next_power_of_two(max(len(PLATFORM_NAME_TO_ID), 64)) # 64
+
+SENSOR_OFFSET = {}
+offset = 0
+for name, cfg in SENSOR_CONFIGS.items():
+ SENSOR_OFFSET[name] = offset
+ offset += cfg.channels
+NCHANNEL = _next_power_of_two(max(offset, 1024)) # 1024
+
+# GPS channel Global_Channel_IDs (for use in SQL queries against parquet)
+CONV_GPS_GLOBAL_IDS = [SENSOR_OFFSET["conv"] + i for i in CONV_GPS_CHANNELS]
+
+
+SENSOR_NAME_TO_ID = {name: idx for idx, name in enumerate(SENSOR_CONFIGS.keys())}
+SENSOR_ID_TO_NAME = {idx: name for name, idx in SENSOR_NAME_TO_ID.items()}
diff --git a/examples/weather/healda/datasets/transform.py b/examples/weather/healda/datasets/transform.py
new file mode 100644
index 0000000000..2d185cee00
--- /dev/null
+++ b/examples/weather/healda/datasets/transform.py
@@ -0,0 +1,496 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import dataclasses
+import datetime
+import functools
+import warnings
+
+import cftime
+import config.environment as config
+import earth2grid
+import numpy as np
+import pyarrow as pa
+import pyarrow.compute as pc
+import torch
+import zarr
+from utils.storage import get_storage_options
+
+from datasets import catalog, features
+from datasets.analysis_loaders import (
+ get_batch_info,
+)
+from datasets.base import (
+ VariableConfig,
+)
+from datasets.variable_configs import VARIABLE_CONFIGS
+from utils import profiling
+from physicsnemo.experimental.models.healda import types
+
+warnings.filterwarnings(
+ "ignore",
+ message="The given NumPy array is not writable, and PyTorch does not support non-writable tensors",
+)
+
+# Column names required by the encode function
+ENCODE_REQUIRED_COLUMNS = [
+ "Latitude",
+ "Longitude",
+ "Absolute_Obs_Time",
+ "Platform_ID",
+ "Observation_Type",
+ "Observation",
+ "Global_Channel_ID",
+ "Sat_Zenith_Angle",
+ "Sol_Zenith_Angle",
+ "sensor_id",
+ "local_channel_id",
+]
+
+# Optional column names that are checked for existence in encode function
+ENCODE_OPTIONAL_COLUMNS = [
+ "Height",
+ "Pressure",
+ "Scan_Angle",
+]
+
+# All column names (required + optional)
+ENCODE_ALL_COLUMNS = ENCODE_REQUIRED_COLUMNS + ENCODE_OPTIONAL_COLUMNS
+
+
+# Static data loaders (moved from static_data.py)
+@functools.cache
+def load_lfrac(hpx_level) -> torch.Tensor:
+ src_grid = earth2grid.latlon.equiangular_lat_lon_grid(nlat=768, nlon=1536)
+ hpx_grid = earth2grid.healpix.Grid(
+ level=hpx_level, pixel_order=earth2grid.healpix.NEST
+ )
+ regridder = earth2grid.get_regridder(src_grid, hpx_grid)
+
+ # get static iputs
+ land_data = zarr.open_group(
+ config.UFS_LAND_DATA_ZARR,
+ storage_options=get_storage_options(config.UFS_LAND_DATA_PROFILE),
+ )
+ land_fraction = land_data["lfrac"][:]
+ land_fraction = regridder(torch.from_numpy(land_fraction).to(torch.float64))
+ return land_fraction
+
+
+@functools.cache
+def load_orography() -> np.ndarray:
+ entry = catalog.ufs()
+ group = entry.to_zarr()
+ return group["orog"][:]
+
+
+def _cftime_to_timestamp(time: cftime.DatetimeGregorian) -> float:
+ return datetime.datetime(
+ *cftime.to_tuple(time), tzinfo=datetime.timezone.utc
+ ).timestamp()
+
+
+def _reorder_nest_to_hpxpad(x):
+ x = torch.as_tensor(x)
+ src_order = earth2grid.healpix.NEST
+ dst_order = earth2grid.healpix.HEALPIX_PAD_XY
+ return earth2grid.healpix.reorder(x, src_order, dst_order)
+
+
+def _compute_second_of_day(time: cftime.datetime):
+ day_start = time.replace(hour=0, minute=0, second=0)
+ return (time - day_start) / datetime.timedelta(seconds=1)
+
+
+def _compute_day_of_year(time: cftime.datetime):
+ day_start = time.replace(hour=0, minute=0, second=0)
+ year_start = day_start.replace(month=1, day=1)
+ return (time - year_start) / datetime.timedelta(seconds=86400)
+
+
+def _compute_timestamp(time: cftime.datetime):
+ return int(_cftime_to_timestamp(time))
+
+
+def _get_static_condition(HPX_LEVEL, variable_config) -> torch.Tensor:
+ lfrac = load_lfrac(HPX_LEVEL)
+ orography = load_orography()
+ # insert land mask
+ orog_scale, orog_mean = 627.3885284872, 232.56013904090733
+ lfrac_scale, lfrac_mean = 0.4695501683565522, 0.3410480857539571
+ data = {
+ "orog": (orography - orog_mean) / orog_scale,
+ "lfrac": (lfrac - lfrac_mean) / lfrac_scale,
+ }
+ arrays = [torch.as_tensor(data[name]) for name in variable_config.variables_static]
+ array = torch.stack(arrays).float() # c x
+ return array.unsqueeze(1)
+
+
+@dataclasses.dataclass
+class TransformV2:
+ """Batch transform for normalizing state data and preparing observations for training.
+
+ Two-stage pipeline:
+ 1. ``transform(times, frames)`` - CPU preprocessing, returns intermediate dict
+ 2. ``device_transform(batch, device)`` - GPU transfer and featurization
+
+ Stage 1 - ``transform()`` returns dict with:
+ - ``target``: Normalized state tensor (B, C, T, X).
+ - ``unified_obs``: Tuple of (obs_tensors, offsets_3d, sensor_id_to_local).
+ - ``condition``: Static conditioning features (1, C_cond, X).
+ - ``second_of_day``, ``day_of_year``, ``timestamp``: Time encodings (B, T).
+
+ Intermediate ``unified_obs`` tuple structure:
+ - ``obs_tensors``: Dict of 1D tensors (N_obs,) - latitude, longitude,
+ observation, global_channel_id, sensor_id, platform_id, etc.
+ - ``offsets_3d``: Shape (S, B, T) cumulative end indices per sensor/batch/time.
+ - ``sensor_id_to_local``: Maps sensor_id to local index in offsets_3d.
+
+ Stage 2 - ``device_transform()`` converts ``unified_obs`` to ``UnifiedObservation``:
+ - Moves tensors to GPU
+ - Computes ``float_metadata`` via ``features.compute_unified_metadata()``
+ (encodes lat/lon, time deltas, zenith angles, etc.)
+ - Builds ``int_metadata`` tensor (N_obs, 6) with columns:
+ [sensor_id, hpx_pixel, local_channel, platform_id, obs_type, global_channel]
+ - Returns ``types.UnifiedObservation`` dataclass ready for model input.
+
+ Sensor grouping: Observations sorted by sensor_id with offsets_3d enabling
+ efficient (sensor, batch, time) slicing (see ``split_by_sensor`` in
+ ``physicsnemo.experimental.models.healda.types``).
+ """
+
+ variable_config: VariableConfig = VARIABLE_CONFIGS["era5"]
+ hpx_level: int = 10 # pixel level of the observations
+ hpx_level_condition: int = 6
+
+ def __post_init__(self):
+ batch_info = get_batch_info(self.variable_config)
+
+ self.mean = np.array(batch_info.center)[:, None]
+ self.std = np.array(batch_info.scales)[:, None]
+
+ @functools.cached_property
+ def _grid(self):
+ return earth2grid.healpix.Grid(
+ self.hpx_level, pixel_order=earth2grid.healpix.NEST
+ )
+
+ @staticmethod
+ def _sort_by_record_batch(table: pa.Table, column_name: str) -> pa.Table:
+ """
+ Sort PyArrow table by grouping record batches by a column value.
+ Assumes all rows in the batch have the same value for the column name.
+ """
+ record_batches_order = []
+ for batch in table.to_batches():
+ if batch.num_rows == 0:
+ continue
+ group_value = batch[column_name][0]
+ record_batches_order.append((group_value, batch))
+
+ # in empty case, from_batches will raise an error
+ if not record_batches_order:
+ return table
+
+ record_batches_order.sort(key=lambda x: x[0].as_py())
+ return pa.Table.from_batches([batch for _, batch in record_batches_order])
+
+ @staticmethod
+ def _append_batch_time_info_chunked(
+ table: pa.Table, b: int, t: int, timestamp: int
+ ) -> pa.Table:
+ """
+ Add batch/time indices and target time while maintaining original chunking.
+ """
+ b_idx_type = pa.int16()
+ t_idx_type = pa.int16()
+ time_type = pa.int64()
+
+ ref_col = table.column(0)
+
+ b_idx_chunks = []
+ t_idx_chunks = []
+ time_chunks = []
+
+ for chunk in ref_col.chunks:
+ L = len(chunk)
+ if L == 0:
+ b_idx_chunks.append(pa.array([], type=b_idx_type))
+ t_idx_chunks.append(pa.array([], type=t_idx_type))
+ time_chunks.append(pa.array([], type=time_type))
+ continue
+
+ # directly creating pa array of int16 not supported so use np first
+ b_idx_arr = np.full(L, b, dtype=np.int16)
+ t_idx_arr = np.full(L, t, dtype=np.int16)
+ times_arr = np.full(L, timestamp, dtype=np.int64)
+
+ b_idx_chunks.append(pa.array(b_idx_arr, type=b_idx_type))
+ t_idx_chunks.append(pa.array(t_idx_arr, type=t_idx_type))
+ time_chunks.append(pa.array(times_arr, type=time_type))
+
+ b_chunked = pa.chunked_array(b_idx_chunks, type=b_idx_type)
+ t_chunked = pa.chunked_array(t_idx_chunks, type=t_idx_type)
+ time_chunked = pa.chunked_array(time_chunks, type=time_type)
+
+ out = table.append_column("batch_idx", b_chunked)
+ out = out.append_column("time_idx", t_chunked)
+ out = out.append_column("target_time", time_chunked)
+ return out
+
+ @staticmethod
+ def _build_observation_offsets_3d(obs_table: pa.Table, frame_times):
+ B, T = len(frame_times), len(frame_times[0])
+
+ sensor_ids = set()
+ counts_map = {} # sensor_id -> (b, t) array of num obs in that sensor, batch, time
+
+ for batch in obs_table.to_batches():
+ if batch.num_rows == 0:
+ continue
+
+ s_id = int(batch["sensor_id"][0].as_py())
+ b_id = int(batch["batch_idx"][0].as_py())
+ t_id = int(batch["time_idx"][0].as_py())
+ n = batch.num_rows
+
+ sensor_ids.add(s_id)
+
+ if s_id not in counts_map:
+ counts_map[s_id] = torch.zeros((B, T), dtype=torch.int32)
+
+ counts_map[s_id][b_id, t_id] += n
+
+ # build per-sensor cumulative ends in row-major (b,t) order
+ # use only active sensors and maintain map, as all possible sensor ids unknown here
+ active_sensor_ids = sorted(sensor_ids)
+ S = len(active_sensor_ids)
+
+ # Handle empty case: no observations in entire batch
+ if not sensor_ids:
+ offsets_3d = torch.zeros((0, B, T), dtype=torch.int32)
+ sensor_id_to_local = torch.zeros((0,), dtype=torch.int32)
+ return offsets_3d, sensor_id_to_local
+
+ max_sensor_id = max(sensor_ids)
+
+ offsets_3d = torch.zeros((S, B, T), dtype=torch.int32)
+
+ prev_count = 0
+ for s_local, s_id in enumerate(active_sensor_ids):
+ counts_bt = counts_map[s_id] # [B,T]
+ flat_counts = counts_bt.reshape(-1) # len = B*T
+ flat_cumsum = torch.cumsum(flat_counts, dim=0) # cumulative ends
+ offsets_3d[s_local] = prev_count + flat_cumsum.reshape(B, T)
+ prev_count += flat_cumsum[-1].item()
+
+ # Create sensor_id -> local_idx map
+ sensor_id_to_local = torch.full((max_sensor_id + 1,), -1, dtype=torch.int32)
+ for local_idx, sensor_id in enumerate(active_sensor_ids):
+ sensor_id_to_local[sensor_id] = local_idx
+
+ return offsets_3d, sensor_id_to_local
+
+ @profiling.nvtx
+ def _process_obs(self, target_times: list[list[cftime.datetime]], frames):
+ # Add batch and time indices to each table before concatenation
+ all_obs_with_indices = []
+ for b_idx, sample_frames in enumerate(frames):
+ for t_idx, frame_dict in enumerate(sample_frames):
+ table = frame_dict["obs_v2"]
+ table_with_indices = self._append_batch_time_info_chunked(
+ table,
+ b_idx,
+ t_idx,
+ _compute_timestamp(target_times[b_idx][t_idx]),
+ )
+ all_obs_with_indices.append(table_with_indices)
+
+ obs = pa.concat_tables(all_obs_with_indices)
+
+ obs = self._sort_by_record_batch(obs, "sensor_id")
+
+ offsets_3d, sensor_id_to_local = self._build_observation_offsets_3d(
+ obs, target_times
+ )
+
+ # Extract columns into dictionary of torch tensors
+ obs_tensors = {}
+
+ # Required columns mapping
+ required_columns = {
+ "latitude": "Latitude",
+ "longitude": "Longitude",
+ "observation": "Observation",
+ "global_channel_id": "Global_Channel_ID",
+ "sat_zenith_angle": "Sat_Zenith_Angle",
+ "sol_zenith_angle": "Sol_Zenith_Angle",
+ "sensor_id": "sensor_id",
+ "local_channel_id": "local_channel_id",
+ "height": "Height",
+ "pressure": "Pressure",
+ "scan_angle": "Scan_Angle",
+ }
+
+ # Process required columns
+ for tensor_key, column_name in required_columns.items():
+ obs_tensors[tensor_key] = torch.from_numpy(obs[column_name].to_numpy())
+
+ arr = obs["Absolute_Obs_Time"].to_numpy().astype("datetime64[ns]", copy=False)
+ obs_tensors["absolute_obs_time"] = torch.from_numpy(arr.view(np.int64))
+ obs_tensors["target_time_sec"] = torch.from_numpy(obs["target_time"].to_numpy())
+
+ platform_id = pc.fill_null(obs["Platform_ID"], 0)
+ obs_tensors["platform_id"] = torch.from_numpy(platform_id.to_numpy())
+
+ obs_type = pc.fill_null(obs["Observation_Type"], 0)
+ obs_tensors["observation_type"] = torch.from_numpy(obs_type.to_numpy())
+
+ return (
+ obs_tensors,
+ offsets_3d,
+ sensor_id_to_local,
+ )
+
+ def _get_target(self, frames) -> torch.Tensor:
+ all_state = [f["state"] for sample in frames for f in sample]
+ batch_size = len(frames)
+ state = np.stack(all_state)
+ state = state.reshape((batch_size, -1) + state.shape[1:])
+ state = (state - self.mean) / self.std
+ target = torch.from_numpy(state)
+ b, t, c, x = range(4)
+ out = target.permute(b, c, t, x)
+ return _reorder_nest_to_hpxpad(out)
+
+ @functools.cached_property
+ def _static_condition(self):
+ condition = _get_static_condition(
+ self.hpx_level_condition, self.variable_config
+ )
+ condition = condition.unsqueeze(0)
+ return _reorder_nest_to_hpxpad(condition)
+
+ @profiling.nvtx
+ def transform(self, times, frames):
+ """
+ frames: [[{state: (c, x), obs_v2: Obs}]]
+ times: [[cftime]]
+ """
+ out = {}
+
+ def _apply_time_func(func):
+ return torch.from_numpy(np.vectorize(func)(times))
+
+ if "obs_v2" in frames[0][0].keys():
+ out["unified_obs"] = self._process_obs(times, frames)
+ out["target"] = self._get_target(frames).float()
+ out["second_of_day"] = _apply_time_func(_compute_second_of_day).float()
+ out["day_of_year"] = _apply_time_func(_compute_day_of_year).float()
+ out["timestamp"] = _apply_time_func(_compute_timestamp)
+ out["condition"] = self._static_condition.float()
+ out["labels"] = torch.empty([len(frames), 0])
+ return out
+
+ @profiling.nvtx
+ def device_transform(self, batch, device):
+ """Transforms to the output of .transform that can occur on gpu
+
+ Typically used with the prefetch_map in the main training process.
+ """
+ batch = batch.copy()
+ out = {}
+
+ for key in batch:
+ if key == "unified_obs":
+ obs_tensors, offsets, sensor_id_to_local = batch["unified_obs"]
+ out[key] = self._device_transform_unified_obs(
+ obs_tensors, offsets, sensor_id_to_local, device
+ )
+ else:
+ out[key] = batch[key].to(device, non_blocking=True)
+ return out
+
+ @profiling.nvtx
+ def _device_transform_unified_obs(
+ self, obs_tensors, offsets, sensor_id_to_local, device
+ ):
+ # Move all tensors to device efficiently
+ def _to_device(tensor, non_blocking=True):
+ if isinstance(tensor, torch.Tensor):
+ return tensor.to(device, non_blocking=non_blocking)
+ else:
+ return torch.from_numpy(tensor).to(device, non_blocking=non_blocking)
+
+ obs_tensors = {key: _to_device(val) for key, val in obs_tensors.items()}
+
+ obs_time_ns = obs_tensors["absolute_obs_time"]
+ lat_tensor = obs_tensors["latitude"]
+ lon_tensor = obs_tensors["longitude"]
+ height_tensor = obs_tensors["height"]
+ pressure_tensor = obs_tensors["pressure"]
+ scan_angle_tensor = obs_tensors["scan_angle"]
+ sat_zenith_tensor = obs_tensors["sat_zenith_angle"]
+ sol_zenith_tensor = obs_tensors["sol_zenith_angle"]
+ platform_id_tensor = obs_tensors["platform_id"].int()
+ obs_type_tensor = obs_tensors["observation_type"].int()
+ sensor_id_tensor = obs_tensors["sensor_id"].int()
+ pix = self._grid.ang2pix(lon_tensor, lat_tensor).int()
+ local_channel_id_tensor = obs_tensors["local_channel_id"].int()
+ global_channel_id_tensor = obs_tensors["global_channel_id"].int()
+ observation_tensor = obs_tensors["observation"]
+
+ # Compute metadata
+ meta = features.compute_unified_metadata(
+ obs_tensors["target_time_sec"],
+ time=obs_time_ns,
+ lat=lat_tensor,
+ lon=lon_tensor,
+ height=height_tensor,
+ pressure=pressure_tensor,
+ scan_angle=scan_angle_tensor,
+ sat_zenith_angle=sat_zenith_tensor,
+ sol_zenith_angle=sol_zenith_tensor,
+ )
+
+ # Build index tensor
+ index = torch.stack(
+ [
+ sensor_id_tensor,
+ pix,
+ local_channel_id_tensor,
+ platform_id_tensor,
+ obs_type_tensor,
+ global_channel_id_tensor,
+ ],
+ dim=1, # Stack along dimension 1 to get (n_obs, 6) instead of (6, n_obs)
+ )
+
+ # Create UnifiedObservation object
+ out = types.UnifiedObservation(
+ obs=observation_tensor,
+ time=obs_time_ns,
+ float_metadata=meta,
+ int_metadata=index,
+ hpx_level=self.hpx_level,
+ offsets=_to_device(offsets),
+ sensor_id_to_local=_to_device(sensor_id_to_local),
+ )
+ return out
+
+
+def collate(obj):
+ return obj
diff --git a/examples/weather/healda/datasets/variable_configs.py b/examples/weather/healda/datasets/variable_configs.py
new file mode 100644
index 0000000000..ec88af281d
--- /dev/null
+++ b/examples/weather/healda/datasets/variable_configs.py
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datasets.base import VariableConfig
+
+VARIABLE_CONFIGS = {}
+VARIABLE_CONFIGS["default"] = VariableConfig(
+ name="ufs",
+ # Data is on model pressure levels, not absolute pressure levels.
+ # For simplicity, we use the nearest model levels to the desired levels.
+ levels=[1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50],
+ variables_3d=["Q", "U", "V", "T", "Z"],
+ variables_2d=[
+ "tas",
+ "uas",
+ "vas",
+ "rlut",
+ "rsut",
+ "pressfc",
+ "pr",
+ "rsds",
+ "sst",
+ "sic",
+ "hfls",
+ "huss",
+ ],
+ variables_static=["orog", "lfrac"],
+)
+VARIABLE_CONFIGS["era5"] = VariableConfig(
+ name="era5",
+ levels=[1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50],
+ variables_3d=["U", "V", "T", "Z", "Q"],
+ variables_2d=[
+ "tcwv",
+ "tas",
+ "uas",
+ "vas",
+ "100u",
+ "100v",
+ "pres_msl",
+ "sst",
+ "sic",
+ ],
+ variables_static=["orog", "lfrac"],
+)
+VARIABLE_CONFIGS["gfs"] = VariableConfig(
+ name="gfs",
+ levels=[1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50],
+ variables_3d=["U", "V", "T", "Z", "Q"],
+ variables_2d=[
+ "tcwv",
+ "tas",
+ "uas",
+ "vas",
+ "100u",
+ "100v",
+ "pres_msl",
+ "sp", # Surface pressure
+ ],
+ variables_static=["orog", "lfrac"],
+)
diff --git a/examples/weather/healda/datasets/zarr_loader.py b/examples/weather/healda/datasets/zarr_loader.py
new file mode 100644
index 0000000000..4009177c03
--- /dev/null
+++ b/examples/weather/healda/datasets/zarr_loader.py
@@ -0,0 +1,188 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import urllib.parse
+
+import cftime
+import numpy as np
+import pandas as pd
+import xarray as xr
+import zarr
+import zarr.storage
+from zarr.core.sync import sync
+
+NO_LEVEL = -1
+
+
+def _is_local(path):
+ url = urllib.parse.urlparse(path)
+ return url.scheme == ""
+
+
+async def _getitem(array, index):
+ return await array.get_orthogonal_selection(index)
+
+
+async def _getitem_static(array, num_times: int):
+ """Return the static field broadcasted to the number of times in the chunk"""
+ field = await array.getitem((slice(None),) * array.ndim)
+ field = field[None, ...]
+ return np.broadcast_to(field, (num_times, *field.shape[1:]))
+
+
+class ZarrLoader:
+ """Load 2d and 3d data from a zarr dataset"""
+
+ def __init__(
+ self,
+ *,
+ path: zarr.storage.StoreLike,
+ variables_3d,
+ variables_2d,
+ levels,
+ level_coord_name: str = "",
+ storage_options=None,
+ time_sel_method: str | None = None,
+ variables_static: list[str] = [],
+ ):
+ """
+ Args:
+ time_sel_method: passed to pd.Index.get_indexer(method=)
+ """
+ self.time_sel_method = time_sel_method
+ self.variables_2d = variables_2d
+ self.variables_3d = variables_3d
+ self.levels = levels
+ self.variables_static = variables_static
+ if isinstance(path, str):
+ if _is_local(path):
+ storage_options = None
+ self.group = sync(
+ zarr.api.asynchronous.open_group(
+ path,
+ storage_options=storage_options,
+ use_consolidated=True,
+ mode="r",
+ )
+ )
+ else:
+ self.group = sync(
+ zarr.api.asynchronous.open_group(
+ path,
+ storage_options=storage_options,
+ use_consolidated=True,
+ mode="r",
+ )
+ )
+
+ if self.variables_3d:
+ self.inds = sync(self._get_vertical_indices(level_coord_name, levels))
+
+ self._arrays = {}
+ self._has_time = self.variables_3d or self.variables_2d
+ if self._has_time:
+ time_num, self.units, self.calendar = sync(self._get_time())
+ if np.issubdtype(time_num.dtype, np.datetime64):
+ self.times = pd.DatetimeIndex(time_num)
+ else:
+ self.times = xr.CFTimeIndex(
+ cftime.num2date(time_num, units=self.units, calendar=self.calendar)
+ )
+
+ async def sel_time(self, times) -> dict[tuple[str, int], np.ndarray]:
+ """
+
+ Returns:
+ dict of output data:
+ keys are like (name, level), level == -1 for 2d variables
+
+ """
+ if self._has_time:
+ index_in_loader = self.times.get_indexer(times, method=self.time_sel_method)
+ if (index_in_loader == -1).any():
+ raise KeyError("Index not found.")
+ else:
+ index_in_loader = np.arange(len(times))
+ arr = await self._get(index_in_loader)
+
+ return arr
+
+ async def _get_time(self):
+ time = await self.group.get("time")
+ time_data = await time.getitem(slice(None))
+ return time_data, time.attrs.get("units"), time.attrs.get("calendar")
+
+ async def _get_vertical_indices(self, coord_name, levels):
+ levels_var = await self.group.get(coord_name)
+ levels_arr = await levels_var.getitem(slice(None))
+ return pd.Index(levels_arr).get_indexer(levels, method="nearest")
+
+ async def _get_array(self, name):
+ if name not in self._arrays:
+ self._arrays[name] = await self.group.get(name)
+ return self._arrays[name]
+
+ async def _get(self, t) -> dict[tuple[str, int | None], np.ndarray]:
+ tasks = []
+ keys = []
+
+ for name in self.variables_3d:
+ arr = await self._get_array(name)
+ if arr is None:
+ raise KeyError(name)
+ for level, k in zip(self.levels, self.inds):
+ key = (name, level)
+ # NOTE creating a length 1 list for this indexer avoids an zarr
+ # bug, when using a scalar value (k_indexer = 1)
+ #
+ # ValueError: could not broadcast input array from shape (2,1,49152) into shape (2,49152)
+ #
+ # not sure when this bug appeared
+ # but it's failing with zarr 3.1.13 on dfw
+ k_indexer = [k]
+ value = _getitem(arr, (t, k_indexer))
+ tasks.append(value)
+ keys.append(key)
+
+ for name in self.variables_2d:
+ arr = await self._get_array(name)
+ if arr is None:
+ raise KeyError(name)
+ key = (name, NO_LEVEL)
+ value = _getitem(arr, (t,))
+ tasks.append(value)
+ keys.append(key)
+
+ for name in self.variables_static:
+ arr = await self._get_array(name)
+ if arr is None:
+ raise KeyError(name)
+ key = (name, NO_LEVEL)
+ value = _getitem_static(arr, len(t))
+ tasks.append(value)
+ keys.append(key)
+
+ arrays = await asyncio.gather(*tasks)
+ # squeeze out the dimenions added to workaround the zarr bug. See NOTE above.
+ out = {}
+ for key, array in zip(keys, arrays):
+ name, _ = key
+ if name in self.variables_3d:
+ out[key] = np.squeeze(array, 1)
+ else:
+ out[key] = array
+
+ return out
diff --git a/examples/weather/healda/inference.py b/examples/weather/healda/inference.py
new file mode 100644
index 0000000000..cc150d0f66
--- /dev/null
+++ b/examples/weather/healda/inference.py
@@ -0,0 +1,155 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+
+import numpy as np
+import torch
+import zarr
+from datasets import samplers
+from datasets.prefetch_map import prefetch_map
+from datasets.transform import TransformV2, collate
+from inference_helpers import (
+ DAConfig,
+ DAModel,
+ scoring_times,
+ setup_zarr_output,
+ write_to_zarr,
+)
+from tqdm import tqdm
+from utils import distributed as dist
+from utils.dataclass_parser import parse_args
+
+
+def _device_transform(batch, transform, device):
+ return transform.device_transform(batch, device=device)
+
+
+def main():
+ args = parse_args(DAConfig, convert_underscore_to_hyphen=False)
+
+ dist.init(timeout_infinite=True)
+ dist.print0("Inference configuration:")
+ dist.print0(f" Dataset: {args.dataset}")
+ dist.print0(f" Innovation type: {args.innovation_type.value}")
+ dist.print0(f" Number of samples: {args.num_samples}")
+ dist.print0(f" Output: {args.output_path}")
+
+ # Load the checkpoint
+ dist.print0(f"Loading checkpoint from {args.checkpoint_path}")
+ da_model = DAModel(args)
+ dataset = da_model.get_dataset(split=args.split)
+
+ # full config
+ times = scoring_times(args.z06_18_inits, args.time_frequency, args.split)
+
+ if args.num_samples != -1:
+ min_samples = max(args.num_samples, dist.get_world_size())
+ tasks = samplers.subsample(times, min_samples=min_samples)
+ else:
+ tasks = list(range(len(times)))
+
+ subsampled_times = times[tasks]
+ subsampled_dataset_idx = dataset.times.get_indexer(subsampled_times)
+
+ gpu_tasks = samplers.distributed_split(subsampled_dataset_idx)
+ gpu_times = dataset.times[gpu_tasks]
+ print(
+ f"Tasks: Length: {len(gpu_tasks)} on rank {dist.get_rank()}. Min time: {gpu_times[0]}. Max time: {gpu_times[-1]}"
+ )
+ batch_size = min(args.batch_gpu, len(gpu_tasks))
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ sampler=gpu_tasks,
+ pin_memory=True,
+ batch_size=batch_size,
+ collate_fn=collate,
+ num_workers=5,
+ prefetch_factor=12,
+ multiprocessing_context="spawn",
+ )
+
+ transform = TransformV2(variable_config=da_model.variable_config)
+ dataloader = prefetch_map(
+ dataloader,
+ functools.partial(
+ _device_transform, transform=transform, device=da_model.device
+ ),
+ queue_size=2,
+ )
+
+ channels = da_model.batch_info.channels
+
+ if dist.get_rank() == 0:
+ group = setup_zarr_output(
+ args.output_path,
+ channels=channels,
+ num_times=len(subsampled_times),
+ batch_size=batch_size,
+ subsampled_times=subsampled_times,
+ )
+
+ if dist.get_world_size() > 1:
+ torch.distributed.barrier()
+
+ # reopen with consolidated metadata to avoid reading extra metadata when writing
+ group = zarr.open_group(args.output_path, mode="r+")
+
+ dist.print0("Setup output zarr file. Starting inference...")
+
+ with torch.no_grad():
+ # denormalize
+ scale = torch.tensor(da_model.batch_info.scales)[:, None, None]
+ mean = torch.tensor(da_model.batch_info.center)[:, None, None]
+
+ for k, batch in enumerate(
+ tqdm(dataloader, disable=dist.get_rank() != 0, desc="Inference")
+ ):
+ if args.use_analysis:
+ analysis = batch["target"]
+ else:
+ analysis = da_model.get_state(batch)["target"]
+
+ analysis_scaled = analysis.cpu() * scale + mean
+ target_scaled = batch["target"].cpu() * scale + mean
+ mse = torch.mean((analysis_scaled - target_scaled) ** 2, dim=(0, 2, 3))
+ rmse = mse.sqrt()
+
+ for field in ["Z500", "T850", "uas"]:
+ cz500 = da_model.batch_info.channels.index(field)
+ value = rmse[cz500].item()
+ print(f"RMSE {field} {value}")
+
+ batch_times = batch["timestamp"][:, -1].cpu()
+ batch_times = batch_times.numpy().astype("datetime64[s]")
+ output_index = subsampled_times.get_indexer(batch_times)
+
+ if np.any(output_index == -1):
+ raise KeyError(output_index, batch_times)
+
+ write_to_zarr(group, channels, output_index, analysis_scaled.numpy())
+
+ if dist.get_world_size() > 1:
+ torch.distributed.barrier()
+
+ if torch.distributed.is_initialized():
+ torch.distributed.destroy_process_group()
+
+ dist.print0("Inference completed.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/weather/healda/inference_helpers.py b/examples/weather/healda/inference_helpers.py
new file mode 100644
index 0000000000..be676a4159
--- /dev/null
+++ b/examples/weather/healda/inference_helpers.py
@@ -0,0 +1,327 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import enum
+import warnings
+from typing import TypedDict
+
+import numpy as np
+import pandas as pd
+import torch
+import zarr
+from datasets.dataset import VARIABLE_CONFIGS
+from datasets.dataset import (
+ get_dataset as get_dataset_ufs,
+)
+from datasets.transform import TransformV2
+from torch.utils.data import Dataset
+from utils import distributed as dist
+from utils.checkpointing import Checkpoint
+from utils.dataclass_parser import Help, a
+
+from physicsnemo.experimental.models.healda import UnifiedObservation
+from config.model_config import ObsConfig
+
+
+class Rolling(Dataset):
+ """Returns window_size consecutive frames from dataset."""
+
+ def __init__(self, dataset, window_size, stride=1, step=1):
+ self.dataset = dataset
+ self.window_size = window_size
+ self.stride = stride
+ self.step = step
+ self.max_start = len(dataset) - (window_size - 1) * step
+ if self.max_start <= 0:
+ raise ValueError("Dataset too small for given window size and step.")
+
+ def __len__(self):
+ return (self.max_start + self.stride - 1) // self.stride
+
+ @property
+ def times(self):
+ return self.dataset.times[: self.max_start : self.stride]
+
+ def __getitem__(self, idx):
+ start = idx * self.stride
+ indices = [start + i * self.step for i in range(self.window_size)]
+ return [self.dataset[i] for i in indices]
+
+
+class Batch(TypedDict):
+ """Input batch structure of DA model"""
+
+ target: torch.Tensor
+ condition: torch.Tensor
+ second_of_day: torch.Tensor
+ day_of_year: torch.Tensor
+ labels: torch.Tensor
+ timestamp: torch.Tensor
+ unified_obs: UnifiedObservation
+
+
+warnings.filterwarnings("ignore", message="Cannot do a zero-copy NCHW to NHWC.")
+
+
+# Copied from training loop
+def _to_batch(x, device, non_blocking=True):
+ if isinstance(x, dict):
+ return {
+ k: _to_batch(v, device, non_blocking=non_blocking) for k, v in x.items()
+ }
+ elif isinstance(x, list):
+ return [_to_batch(i, device, non_blocking=non_blocking) for i in x]
+ elif torch.is_tensor(x):
+ if torch.is_floating_point(x):
+ x = x.float()
+ return x.to(device, non_blocking=non_blocking)
+ elif hasattr(x, "to") and callable(getattr(x, "to")):
+ # custom object with a 'to' method
+ return x.to(device, non_blocking=non_blocking)
+ else:
+ raise NotImplementedError(x)
+
+
+class InnovationType(enum.Enum):
+ """Observation-minus-background (innovation) type"""
+
+ NONE = "none"
+ ADJUSTED = "adjusted"
+ UNADJUSTED = "unadjusted"
+
+
+@dataclasses.dataclass
+class DAConfig:
+ checkpoint_path: a[str, Help("Path to the checkpoint file")]
+ output_path: a[str, Help("Output zarr file path")] = "healda_analysis.zarr"
+ dataset: a[str, Help("Dataset to use (ufs or era5)")] = "era5"
+ innovation_type: a[InnovationType, Help("Obs-minus-background type")] = (
+ InnovationType.NONE
+ )
+ num_samples: a[int, Help("Number of samples (-1 for all)")] = 32
+ time_frequency: a[str, Help("Spacing to sample times from")] = "6h"
+ use_infrared: a[bool, Help("Use infrared observations")] = False
+ use_conv: a[bool, Help("Use conventional observations")] = True
+ context_start: a[int, Help("Obs window start (hours before analysis)")] = -21
+ context_end: a[int, Help("Obs window end (hours after analysis)")] = 3
+ batch_gpu: a[int, Help("Batch size per GPU")] = 8
+ z06_18_inits: a[bool, Help("Use 06z/18z inits instead of 00z/12z")] = False
+ use_analysis: a[
+ bool,
+ Help(
+ "Save out ground truth target dataset instead of predicted HealDA analysis"
+ ),
+ ] = False
+ conv_uv_in_situ_only: a[bool, Help("Exclude satellite UV (keep in-situ)")] = False
+ conv_gps_level1_only: a[bool, Help("Exclude GPS T/Q (keep bending angle)")] = False
+ use_class_labels: a[bool, Help("Use class labels for conditioning")] = False
+ split: a[str, Help("Test (2022) or Train (2021)")] = "test"
+
+
+def scoring_times(
+ z06_z18_inits: bool, time_frequency, split: str = "test"
+) -> pd.DatetimeIndex:
+ year = 2022 if split == "test" else 2021
+ start_date = f"{year}-01-01-00" if not z06_z18_inits else f"{year}-01-01-06"
+ return pd.date_range(start_date, f"{year}-12-31-12", freq=time_frequency)
+
+
+class DAModel:
+ def __init__(self, args: DAConfig):
+ self.args = args
+ with Checkpoint(args.checkpoint_path, mode="r") as ckpt:
+ model = ckpt.read_model(map_location="cuda")
+ model.eval().cuda()
+
+ self.model = model
+ self.split = args.split
+ # Inference time obs config for the dataset.
+ # channel/emb dims are for the model and do not matter here
+ self.obs_config = ObsConfig(
+ use_obs=True, # Always use observations for DA
+ innovation_type=args.innovation_type.value,
+ context_start=args.context_start,
+ context_end=args.context_end,
+ use_infrared=args.use_infrared,
+ use_conv=args.use_conv,
+ conv_uv_in_situ_only=args.conv_uv_in_situ_only,
+ conv_gps_level1_only=args.conv_gps_level1_only,
+ )
+ self._batch_info = None
+ self.use_class_labels = args.use_class_labels
+
+ self.variable_config = (
+ VARIABLE_CONFIGS["default"]
+ if args.dataset == "ufs"
+ else VARIABLE_CONFIGS["era5"]
+ )
+
+ def get_dataset(
+ self,
+ split="test",
+ time_length: int = 1,
+ time_step: int = 12,
+ map_style: bool = True,
+ chunk_size: int = 0,
+ ):
+ args = self.args
+
+ dist.print0(f"Loading {args.dataset} dataset...")
+
+ transform = TransformV2(
+ variable_config=self.variable_config,
+ )
+
+ ds = get_dataset_ufs(
+ dataset=args.dataset,
+ batch_transform=transform.transform,
+ split=split,
+ shuffle=False,
+ obs_config=self.obs_config,
+ map_style=map_style,
+ time_length=time_length,
+ time_step=time_step,
+ chunk_size=chunk_size,
+ )
+ self._batch_info = ds.batch_info
+ return ds
+
+ @property
+ def batch_info(self):
+ if self._batch_info is None:
+ self.get_dataset(split=self.split)
+ return self._batch_info
+
+ @property
+ def device(self):
+ return next(self.model.parameters()).device
+
+ def get_state(self, batch):
+ batch = _to_batch(batch, self.device)
+ target = batch["target"]
+ b, c, t, x = target.shape
+ noise_labels = torch.zeros([b], device=target.device)
+ class_labels = batch["labels"]
+
+ if not self.use_class_labels:
+ class_labels = torch.empty([b, 0], device=target.device)
+
+ condition = batch["condition"]
+ obs = batch["unified_obs"]
+ # TODO generalize this logic in train_regression.py:_step it is like this
+ # time_step = self.time_step * 3600
+ time_step = 0
+ timestamp = (
+ batch["timestamp"].unsqueeze(1)
+ + torch.arange(t, device=target.device) * time_step
+ )
+ second_of_day = batch["second_of_day"]
+ day_of_year = batch["day_of_year"]
+
+ # Get predictions
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ return {
+ "timestamp": timestamp,
+ "second_of_day": second_of_day,
+ "day_of_year": day_of_year,
+ "target": self.model(
+ condition,
+ noise_labels=noise_labels,
+ class_labels=class_labels,
+ second_of_day=second_of_day,
+ day_of_year=day_of_year,
+ unified_obs=obs,
+ timestamp=timestamp,
+ ).out,
+ }
+
+
+def time_length(batch):
+ return batch["target"].shape[2]
+
+
+def find_matching_indices(targets, available):
+ indices = available.get_indexer(targets)
+ valid = indices != -1
+ return indices[valid], available[indices[valid]]
+
+
+def enumerate_to_dict(array):
+ return {int(val): i for i, val in enumerate(array)}
+
+
+def write_to_zarr(group, channels, index, data):
+ """Write denormalized predictions to zarr arrays.
+
+ Args:
+ group: Zarr group to write to
+ channels: List of channel names
+ index: Time indices to write to
+ data: Data array with shape (batch, channels, 1, cells)
+ """
+ for c in range(len(channels)):
+ name = channels[c]
+ array = data[:, c, 0, :]
+ group[name][index] = array
+
+
+def setup_zarr_output(
+ output_path, channels, num_times, batch_size, subsampled_times=None
+):
+ """Setup zarr output structure for inference results.
+
+ Args:
+ output_path: Path to output zarr file
+ channels: List of channel names
+ num_times: Number of time steps
+ batch_size: Batch size for chunking
+ subsampled_times: Optional pandas DatetimeIndex for time coordinate
+
+ Returns:
+ Opened zarr group (mode='w')
+ """
+
+ group = zarr.open_group(output_path, mode="w")
+
+ # Create data arrays for each channel
+ for field in channels:
+ group.create_array(
+ field,
+ shape=(num_times, 49152),
+ chunks=(batch_size, 49152),
+ fill_value=float("NaN"),
+ dimension_names=("time", "cells"),
+ dtype="f",
+ compressors=[],
+ )
+
+ # Create time coordinate if provided
+ if subsampled_times is not None:
+ times_array = subsampled_times.to_numpy()
+ time_v = group.create_array(
+ "time",
+ dtype=np.int64,
+ shape=times_array.shape,
+ chunks=times_array.shape,
+ dimension_names=["time"],
+ )
+ time_v[:] = times_array.astype("datetime64[s]").astype(np.int64)
+ time_v.attrs["units"] = "seconds since 1970-01-01 00:00:00"
+ time_v.attrs["calendar"] = "standard"
+
+ zarr.consolidate_metadata(group.store)
+ return group
diff --git a/examples/weather/healda/models.py b/examples/weather/healda/models.py
new file mode 100644
index 0000000000..5327374a80
--- /dev/null
+++ b/examples/weather/healda/models.py
@@ -0,0 +1,79 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from physicsnemo.experimental.models.healda import HealDA
+from config.model_config import ModelConfigV1
+
+
+def _get_condition_dim(config: ModelConfigV1, hidden_size: int) -> int | None:
+ """Determine condition_dim from config.
+
+ Returns None for VIT mode (no diffusion conditioning),
+ or the embedding dimension for diffusion mode.
+ """
+ if config.as_vit:
+ return None
+ # Default to 4 * hidden_size if not specified
+ return config.emb_channels or 4 * hidden_size
+
+
+def get_model(config: ModelConfigV1) -> torch.nn.Module:
+ """Instantiate HealDA model from config."""
+ if config.architecture == "dit-test":
+ hidden_size = 128 # 2 heads * 64 dim
+ return HealDA(
+ in_channels=config.condition_channels,
+ out_channels=config.out_channels,
+ sensor_embedder_config=config.sensor_embedder_config,
+ sensors=config.sensors,
+ hidden_size=hidden_size,
+ num_layers=1,
+ num_heads=2,
+ level_in=6,
+ level_model=5,
+ time_length=config.time_length,
+ drop_path=config.drop_path,
+ dropout=config.p_dropout,
+ qk_norm_type="rmsnorm" if config.qk_rms_norm else None,
+ condition_dim=_get_condition_dim(config, hidden_size),
+ noise_channels=config.noise_channels or 1024,
+ label_dim=config.label_dim,
+ label_dropout=config.label_dropout if config.label_dropout > 0 else None,
+ )
+ elif config.architecture == "dit-l_reg_hpx6_per_sensor":
+ hidden_size = 1024 # 16 heads * 64 dim
+ return HealDA(
+ in_channels=config.condition_channels,
+ out_channels=config.out_channels,
+ sensor_embedder_config=config.sensor_embedder_config,
+ sensors=config.sensors,
+ hidden_size=hidden_size,
+ num_layers=24,
+ num_heads=16,
+ level_in=6,
+ level_model=5,
+ time_length=config.time_length,
+ drop_path=config.drop_path,
+ dropout=config.p_dropout,
+ qk_norm_type="rmsnorm" if config.qk_rms_norm else None,
+ condition_dim=_get_condition_dim(config, hidden_size),
+ noise_channels=config.noise_channels or 1024,
+ label_dim=config.label_dim,
+ label_dropout=config.label_dropout if config.label_dropout > 0 else None,
+ )
+ else:
+ raise NotImplementedError(config.architecture)
diff --git a/examples/weather/healda/requirements.txt b/examples/weather/healda/requirements.txt
new file mode 100644
index 0000000000..88b456b5bc
--- /dev/null
+++ b/examples/weather/healda/requirements.txt
@@ -0,0 +1,24 @@
+# HealDA Example Dependencies
+# install earth2grid separately first (see README.md)
+
+cartopy>=0.25.0
+cftime>=1.6.5
+diffusers>=0.36.0
+duckdb
+earth2grid>=2025.11.1
+h5py>=3.15.1
+joblib>=1.5.3
+matplotlib>=3.10.8
+numpy>=1.2.0
+obstore>=0.8.2
+pandas>=2.3.3
+psutil>=7.1.3
+pyarrow>=21.0.0
+pytest-asyncio>=0.23.0
+pytest-regtest
+python-dotenv>=1.2.1
+tensorboard>=2.20.0
+torch>=2.6.1
+torchmetrics>=1.8.2
+xarray>=2025.12.0
+zarr>=3.1.5
diff --git a/examples/weather/healda/scripts/__init__.py b/examples/weather/healda/scripts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/weather/healda/scripts/ensemble_metrics.py b/examples/weather/healda/scripts/ensemble_metrics.py
new file mode 100644
index 0000000000..8d9e5d7057
--- /dev/null
+++ b/examples/weather/healda/scripts/ensemble_metrics.py
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copied from modulus
+import torch
+
+Tensor = torch.Tensor
+
+
+@torch.jit.script
+def _kernel_crps_implementation(pred: Tensor, obs: Tensor, biased: bool) -> Tensor:
+ """An O(m log m) implementation of the kernel CRPS formulas"""
+ skill = torch.abs(pred - obs[..., None]).mean(-1)
+ pred, _ = torch.sort(pred)
+
+ # derivation of fast implementation of spread-portion of CRPS formula when x is sorted
+ # sum_(i,j=1)^m |x_i - x_j| = sum_(i j) |x_i - x_j|
+ # = 2 sum_(i <= j) |x_i -x_j|
+ # = 2 sum_(i <= j) (x_j - x_i)
+ # = 2 sum_(i <= j) x_j - 2 sum_(i <= j) x_i
+ # = 2 sum_(j=1)^m j x_j - 2 sum (m - i + 1) x_i
+ # = 2 sum_(i=1)^m (2i - m - 1) x_i
+ m = pred.size(-1)
+ i = torch.arange(1, m + 1, device=pred.device, dtype=pred.dtype)
+ denom = m * m if biased else m * (m - 1)
+ factor = (2 * i - m - 1) / denom
+ spread = torch.sum(factor * pred, dim=-1)
+ return skill - spread
+
+
+def kcrps(
+ pred: Tensor, obs: Tensor, dim: int = 0, biased: bool = False, chunk_size: int = 16
+):
+ """Estimate the CRPS from a finite ensemble in batched/streaming fashion
+
+ Computes the local Continuous Ranked Probability Score (CRPS) by using
+ the kernel version of CRPS. The cost is O(m log m).
+
+ Creates a map of CRPS and does not accumulate over lat/lon regions.
+ Approximates:
+ .. math::
+ CRPS(X, y) = E[X - y] - 0.5 E[X-X']
+
+ with
+ .. math::
+ sum_i=1^m |X_i - y| / m - 1/(2m^2) sum_i,j=1^m |x_i - x_j|
+
+ Parameters
+ ----------
+ pred : Tensor
+ Tensor containing the ensemble predictions. The ensemble dimension
+ is assumed to be the leading dimension unless 'dim' is specified.
+ obs : Union[Tensor, np.ndarray]
+ Tensor or array containing an observation over which the CRPS is computed
+ with respect to.
+ dim : int, optional
+ The dimension over which to compute the CRPS, assumed to be 0.
+ biased :
+ When False, uses the unbiased estimators described in (Zamo and Naveau, 2018)::
+
+ E|X-y|/m - 1/(2m(m-1)) sum_(i,j=1)|x_i - x_j|
+
+ Unlike ``crps`` this is fair for finite ensembles. Non-fair ``crps`` favors less
+ dispersive ensembles since it is biased high by E|X- X'|/ m where m is the
+ ensemble size.
+
+ Returns
+ -------
+ Tensor
+ Map of CRPS
+ """
+ pred = torch.movedim(pred, dim, -1)
+ return _kernel_crps_implementation(pred, obs, biased=biased)
+
+
+def unbiased_ensemble_metrics(prediction, truth):
+ """Unbiased ensemble metrics
+
+ When averaged over many forecasts these formulas are unbiased
+ even for size 2 ensembles.
+
+ Args:
+ prediction: shaped (*, e) - e is the ensemble dimension
+ truth: shaped (*)
+ """
+ scores = {}
+ scores["mse"] = mse = (prediction.mean(-1) - truth) ** 2
+ # ensemble scores
+ if prediction.size(-1) > 1:
+ scores["variance"] = variance = prediction.var(-1)
+ scores["crps"] = kcrps(prediction, truth, dim=-1, biased=False)
+
+ # unbias the ensemble mean mse formula
+ # per reviewer of Brenowitz, et. al. 2024 , A practical benchmark for probabilistic scoring
+ # RMSE for the ensemble mean can also be debiased to match the
+ # limit of infinite ensemble size, using the same math from Fortin, DebiasedMSE =
+ # MSE - (1/n) Var, where Var is the debiased estimate of the ensemble variance.
+ R = prediction.size(-1)
+ scores["mse"] = mse - variance / R
+ scores["mse_biased"] = mse
+
+ # Fortin. eq 15
+ scores["spread_error"] = torch.sqrt(variance / mse * (R + 1) / R)
+ return scores
diff --git a/examples/weather/healda/scripts/fengwu_model.py b/examples/weather/healda/scripts/fengwu_model.py
new file mode 100644
index 0000000000..3c7ce5ff93
--- /dev/null
+++ b/examples/weather/healda/scripts/fengwu_model.py
@@ -0,0 +1,388 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import OrderedDict
+from collections.abc import Generator, Iterator
+from typing import TypeVar
+
+import numpy as np
+import torch
+from earth2studio.models.auto import AutoModelMixin, Package
+from earth2studio.models.batch import batch_coords, batch_func
+from earth2studio.models.px.base import PrognosticModel
+from earth2studio.models.px.utils import PrognosticMixin
+from earth2studio.models.utils import create_ort_session
+from earth2studio.utils import handshake_coords, handshake_dim
+from earth2studio.utils.imports import (
+ OptionalDependencyFailure,
+ check_optional_dependencies,
+)
+from earth2studio.utils.type import CoordSystem
+
+try:
+ import onnxruntime as ort
+ from onnxruntime import InferenceSession
+except ImportError:
+ OptionalDependencyFailure("fengwu")
+ ort = None
+ InferenceSession = TypeVar("InferenceSession") # type: ignore
+# Copied from earth2studio/models/px/fengwu.py. Patched to fix a memory layout issue.
+
+VARIABLES = [
+ "u10m",
+ "v10m",
+ "t2m",
+ "msl",
+ "z50",
+ "z100",
+ "z150",
+ "z200",
+ "z250",
+ "z300",
+ "z400",
+ "z500",
+ "z600",
+ "z700",
+ "z850",
+ "z925",
+ "z1000",
+ "q50",
+ "q100",
+ "q150",
+ "q200",
+ "q250",
+ "q300",
+ "q400",
+ "q500",
+ "q600",
+ "q700",
+ "q850",
+ "q925",
+ "q1000",
+ "u50",
+ "u100",
+ "u150",
+ "u200",
+ "u250",
+ "u300",
+ "u400",
+ "u500",
+ "u600",
+ "u700",
+ "u850",
+ "u925",
+ "u1000",
+ "v50",
+ "v100",
+ "v150",
+ "v200",
+ "v250",
+ "v300",
+ "v400",
+ "v500",
+ "v600",
+ "v700",
+ "v850",
+ "v925",
+ "v1000",
+ "t50",
+ "t100",
+ "t150",
+ "t200",
+ "t250",
+ "t300",
+ "t400",
+ "t500",
+ "t600",
+ "t700",
+ "t850",
+ "t925",
+ "t1000",
+]
+
+
+@check_optional_dependencies()
+class FengWu(torch.nn.Module, AutoModelMixin, PrognosticMixin):
+ """FengWu (operational) weather model consists of single auto-regressive model with
+ a time-step size of 6 hours. FengWu operates on 0.25 degree lat-lon grid (south-pole
+ including) equirectangular grid with 69 atmospheric/surface variables. This model
+ uses two time-steps as an input.
+
+ Note
+ ----
+ This model uses the ONNX checkpoint from the original publication repository. This
+ checkpoint is a operational version to the one used in the paper which requires less
+ variables. For additional information see the following resources:
+
+ - https://arxiv.org/abs/2304.02948
+ - https://github.com/OpenEarthLab/FengWu
+
+ Note
+ ----
+ To avoid ONNX init session overhead of this model we recommend setting the default
+ Pytorch device to the correct target prior to model construction.
+
+ Parameters
+ ----------
+ ort : str
+ Path to FengWu 6 hour onnx file
+ center : torch.Tensor
+ Model variable center normalization tensor of size [69]
+ scale : torch.Tensor
+ Model variable scale normalization tensor of size [69]
+ """
+
+ def __init__(
+ self,
+ ort: str,
+ center: torch.Tensor,
+ scale: torch.Tensor,
+ ) -> None:
+ super().__init__()
+
+ self.device = torch.ones(1).device # Hack to get default device
+ self.ort = create_ort_session(ort, self.device)
+
+ self.register_buffer("center", center.unsqueeze(-1).unsqueeze(-1))
+ self.register_buffer("scale", scale.unsqueeze(-1).unsqueeze(-1))
+
+ def input_coords(self) -> CoordSystem:
+ """Input coordinate system of the prognostic model
+
+ Returns
+ -------
+ CoordSystem
+ Coordinate system dictionary
+ """
+ return OrderedDict(
+ {
+ "batch": np.empty(0),
+ "lead_time": np.array(
+ [np.timedelta64(-6, "h"), np.timedelta64(0, "h")]
+ ),
+ "variable": np.array(VARIABLES),
+ "lat": np.linspace(90, -90, 721, endpoint=True),
+ "lon": np.linspace(0, 360, 1440, endpoint=False),
+ }
+ )
+
+ @batch_coords()
+ def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
+ """Output coordinate system of the prognostic model
+
+ Parameters
+ ----------
+ input_coords : CoordSystem
+ Input coordinate system to transform into output_coords
+
+ Returns
+ -------
+ CoordSystem
+ Coordinate system dictionary
+ """
+ output_coords = OrderedDict(
+ {
+ "batch": np.empty(0),
+ "lead_time": np.array([np.timedelta64(6, "h")]),
+ "variable": np.array(VARIABLES),
+ "lat": np.linspace(90, -90, 721, endpoint=True),
+ "lon": np.linspace(0, 360, 1440, endpoint=False),
+ }
+ )
+
+ test_coords = input_coords.copy()
+ test_coords["lead_time"] = (
+ test_coords["lead_time"] - input_coords["lead_time"][-1]
+ )
+ target_input_coords = self.input_coords()
+ for i, key in enumerate(target_input_coords):
+ if key != "batch":
+ handshake_dim(test_coords, key, i)
+ handshake_coords(test_coords, target_input_coords, key)
+
+ output_coords["batch"] = input_coords["batch"]
+ output_coords["lead_time"] = (
+ input_coords["lead_time"][1:] + output_coords["lead_time"]
+ )
+
+ return output_coords
+
+ def to(self, device: str | torch.device | int) -> PrognosticModel:
+ """Move model (and default ORT session) to device"""
+ device = torch.device(device)
+ if device.index is None:
+ if device.type == "cuda":
+ device = torch.device(device.type, torch.cuda.current_device())
+ else:
+ device = torch.device(device.type, 0)
+
+ super().to(device)
+
+ if device != self.device:
+ self.device = device
+ # Move base ort session
+ if self.ort is not None:
+ model_path = self.ort._model_path
+ del self.ort
+ self.ort = create_ort_session(model_path, device)
+
+ return self
+
+ @classmethod
+ def load_default_package(cls) -> Package:
+ """Load prognostic package"""
+ return Package(
+ "hf://NickGeneva/earth_ai/fengwu",
+ cache_options={
+ "cache_storage": Package.default_cache("fengwu"),
+ "same_names": True,
+ },
+ )
+
+ @classmethod
+ @check_optional_dependencies()
+ def load_model(
+ cls,
+ package: Package,
+ ) -> PrognosticModel:
+ """Load prognostic from package"""
+ onnx_file = package.resolve("fengwu_v1.onnx")
+ global_center = torch.Tensor(np.load(package.open("global_means.npy")))
+ global_std = torch.Tensor(np.load(package.open("global_stds.npy")))
+ return cls(onnx_file, global_center, global_std)
+
+ @torch.inference_mode()
+ def _forward(
+ self,
+ x: torch.Tensor,
+ ort_session: InferenceSession,
+ ) -> torch.Tensor:
+ # Ref https://onnxruntime.ai/docs/api/python/api_summary.html
+ binding = ort_session.io_binding()
+
+ def bind_input(name: str, input: torch.Tensor) -> None:
+ input = input.contiguous()
+ binding.bind_input(
+ name=name,
+ device_type=self.device.type,
+ device_id=self.device.index,
+ element_type=np.float32,
+ shape=tuple(input.shape),
+ buffer_ptr=input.data_ptr(),
+ )
+
+ def bind_output(name: str, like: torch.Tensor) -> torch.Tensor:
+ out = torch.empty_like(like).contiguous()
+ binding.bind_output(
+ name=name,
+ device_type=self.device.type,
+ device_id=self.device.index,
+ element_type=np.float32,
+ shape=tuple(out.shape),
+ buffer_ptr=out.data_ptr(),
+ )
+ return out
+
+ x = (x - self.center) / self.scale # Normalize
+
+ # Patch as view requires contiguous input
+ if not x.is_contiguous():
+ x = x.contiguous()
+
+ x = x.view(x.shape[0], -1, 721, 1440) # Concat time-steps
+ # Forward pass, fengwu onnx supports batched
+ bind_input("input", x)
+ output = bind_output("output", like=x)
+ ort_session.run_with_iobinding(binding)
+
+ # ONNX model outputs two time-steps, take the first
+ output_tensor = output[:].contiguous()
+ x = self.scale * output_tensor[:, :69].unsqueeze(1) + self.center # UnNormalize
+ return x
+
+ @batch_func()
+ def __call__(
+ self,
+ x: torch.Tensor,
+ coords: CoordSystem,
+ ) -> tuple[torch.Tensor, CoordSystem]:
+ """Runs 6 hour prognostic model 1 step.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor
+ coords : CoordSystem
+ Input coordinate system
+
+ Returns
+ -------
+ tuple[torch.Tensor, CoordSystem]
+ Output tensor and coordinate system 6 hours in the future
+ """
+ return self._forward(x, self.ort), self.output_coords(coords)
+
+ @batch_func()
+ def _default_generator(
+ self, x: torch.Tensor, coords: CoordSystem
+ ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
+ coords = coords.copy()
+
+ self.output_coords(coords)
+
+ out = x[:, 1:]
+ out_coords = coords.copy()
+ out_coords["lead_time"] = out_coords["lead_time"][1:]
+ yield out, out_coords
+
+ while True:
+ # Front hook
+ x, coords = self.front_hook(x, coords)
+
+ # Forward is identity operator
+ out = self._forward(x, self.ort)
+ out_coords = self.output_coords(coords)
+
+ # Rear hook
+ out, out_coords = self.rear_hook(out, out_coords)
+
+ # Update inputs for next time-step
+ x = torch.cat([x[:, 1:], out], dim=1)
+ coords["lead_time"] = np.array(
+ [coords["lead_time"][-1], out_coords["lead_time"][-1]]
+ )
+
+ yield out, out_coords.copy()
+
+ def create_iterator(
+ self, x: torch.Tensor, coords: CoordSystem
+ ) -> Iterator[tuple[torch.Tensor, CoordSystem]]:
+ """Creates a iterator which can be used to perform time-integration of the
+ prognostic model. Will return the initial condition first (0th step).
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor
+ coords : CoordSystem
+ Input coordinate system
+
+ Yields
+ ------
+ Iterator[tuple[torch.Tensor, CoordSystem]]
+ Iterator that generates time-steps of the prognostic model container the
+ output data tensor and coordinate system dictionary.
+ """
+ yield from self._default_generator(x, coords)
diff --git a/examples/weather/healda/scripts/forecast.py b/examples/weather/healda/scripts/forecast.py
new file mode 100644
index 0000000000..29b28eedab
--- /dev/null
+++ b/examples/weather/healda/scripts/forecast.py
@@ -0,0 +1,463 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Supports: FCN3, Aurora, FengWu, Pangu, and Mock/Persistence models.
+Output: HEALPix level 6 zarr format.
+"""
+
+import argparse
+import gc
+import logging
+import os
+import sys
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Literal
+
+import earth2grid
+import numpy as np
+import pandas as pd
+import torch
+import xarray as xr
+from earth2grid import healpix, latlon
+from earth2studio.models.px import Persistence
+from earth2studio.models.px.aurora import VARIABLES as AURORA_VARIABLES
+from earth2studio.models.px.aurora import Aurora
+from earth2studio.models.px.fcn3 import FCN3
+from earth2studio.models.px.fcn3 import VARIABLES as FCN3_VARIABLES
+from earth2studio.models.px.pangu import VARIABLES as PANGU_VARIABLES
+from earth2studio.models.px.pangu import Pangu6
+from earth2studio.perturbation import Zero
+from earth2studio.run import deterministic as e2s_deterministic
+from earth2studio.run import ensemble as e2s_ensemble
+from fengwu_model import VARIABLES as FENGWU_VARIABLES
+from fengwu_model import FengWu
+from io_backend import RegriddingZarrBackend
+
+# Add healda utils to path
+sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
+import distributed as dist
+
+logger = logging.getLogger(__name__)
+
+
+# -----------------------------------------------------------------------------
+# Model Configuration
+# -----------------------------------------------------------------------------
+
+
+class ModelConfig:
+ """Configuration for different prognostic models."""
+
+ def __init__(
+ self,
+ name: str,
+ variables: list[str],
+ workflow: Literal["ensemble", "deterministic"],
+ model: torch.nn.Module,
+ nlat: int = 721,
+ nlon: int = 1440,
+ ):
+ self.name = name
+ self.variables = variables
+ self.workflow = workflow
+ self.model = model
+ self.nlat = nlat
+ self.nlon = nlon
+
+ @classmethod
+ def fcn3(cls) -> "ModelConfig":
+ return cls(
+ name="fcn3",
+ variables=FCN3_VARIABLES,
+ workflow="ensemble",
+ model=FCN3.load_model(FCN3.load_default_package()),
+ )
+
+ @classmethod
+ def aurora(cls) -> "ModelConfig":
+ return cls(
+ name="aurora",
+ variables=AURORA_VARIABLES,
+ workflow="deterministic",
+ model=Aurora.load_model(Aurora.load_default_package()),
+ nlat=720,
+ nlon=1440,
+ )
+
+ @classmethod
+ def pangu(cls) -> "ModelConfig":
+ return cls(
+ name="pangu",
+ variables=PANGU_VARIABLES,
+ workflow="deterministic",
+ model=Pangu6.load_model(Pangu6.load_default_package()),
+ )
+
+ @classmethod
+ def fengwu(cls) -> "ModelConfig":
+ return cls(
+ name="fengwu",
+ variables=FENGWU_VARIABLES,
+ workflow="deterministic",
+ model=FengWu.load_model(FengWu.load_default_package()),
+ )
+
+ @classmethod
+ def mock(cls) -> "ModelConfig":
+ grid = earth2grid.latlon.equiangular_lat_lon_grid(721, 1440)
+ coords = OrderedDict([("lat", grid.lat.ravel()), ("lon", grid.lon)])
+ return cls(
+ name="mock",
+ variables=FCN3_VARIABLES,
+ workflow="ensemble",
+ model=Persistence(FCN3_VARIABLES, coords),
+ )
+
+
+# -----------------------------------------------------------------------------
+# Data Source Wrapper
+# -----------------------------------------------------------------------------
+
+
+class DataArrayZarr:
+ """
+ E2Studio-compatible wrapper around xarray zarr that regrids from HEALPix to lat-lon.
+ Output order is always: (time, variable, lat, lon).
+ """
+
+ _2D_VAR_MAP = {
+ "tcwv": "tcwv",
+ "t2m": "tas",
+ "u10m": "uas",
+ "v10m": "vas",
+ "u100m": "100u",
+ "v100m": "100v",
+ "msl": "pres_msl",
+ }
+
+ def __init__(
+ self,
+ file_path: str,
+ *,
+ nlat: int = 721,
+ nlon: int = 1440,
+ xr_open_kwargs: dict[str, Any] | None = None,
+ ):
+ self.file_path = file_path
+ self.nlat = nlat
+ self.nlon = nlon
+ self.include_south_pole = nlat == 721
+
+ xr_open_kwargs = xr_open_kwargs or {}
+ ds = xr.open_zarr(file_path, **xr_open_kwargs)
+
+ if isinstance(ds, xr.Dataset):
+ da = ds.to_array(dim="variable")
+ else:
+ da = ds
+
+ dims = list(da.dims)
+ self._on_latlon_grid = "lat" in da.dims and "lon" in da.dims
+
+ if not self._on_latlon_grid and "cells" not in da.dims:
+ raise ValueError(
+ "DataArray must have 'cells' dimension for non-latlon grids"
+ )
+
+ # Move 'time' and 'variable' to front
+ for lead in ("variable", "time"):
+ if lead in dims:
+ dims.remove(lead)
+ dims.insert(0, lead)
+ da = da.transpose(*dims, missing_dims="ignore")
+
+ self.da = da
+ self.has_type_dim = "type" in da.dims
+ self._setup_regridding()
+
+ def _setup_regridding(self):
+ self._regridder = None
+ self._lat_coords = None
+ self._lon_coords = None
+
+ if self._on_latlon_grid:
+ return
+
+ hpx_grid = healpix.Grid(level=6, pixel_order=earth2grid.healpix.HEALPIX_PAD_XY)
+ ll_grid = latlon.equiangular_lat_lon_grid(
+ nlat=self.nlat, nlon=self.nlon, includes_south_pole=self.include_south_pole
+ )
+
+ self._regridder = earth2grid.get_regridder(hpx_grid, ll_grid).float().cuda()
+
+ lat_arr = np.asarray(ll_grid.lat).squeeze()
+ if lat_arr.ndim != 1:
+ lat_arr = lat_arr[:, 0]
+ lon_arr = np.asarray(ll_grid.lon).squeeze()
+ if lon_arr.ndim != 1:
+ lon_arr = lon_arr[0, :]
+ self._lat_coords = lat_arr
+ self._lon_coords = lon_arr
+
+ def _regrid(self, np_block: np.ndarray) -> np.ndarray:
+ t = torch.from_numpy(np_block).to(torch.float32).cuda()
+ return self._regridder(t).cpu().float().numpy()
+
+ def __call__(self, time, variable) -> xr.DataArray:
+ if isinstance(variable, str):
+ variable = [variable]
+
+ dataset_variables = []
+ available_vars = self.da.coords["variable"].values
+ ours_to_fcn3 = {}
+
+ for v in variable:
+ if v in available_vars:
+ dataset_variables.append(v)
+ elif v.upper() in available_vars:
+ dataset_variables.append(v.upper())
+ elif v in self._2D_VAR_MAP:
+ mapped_var = self._2D_VAR_MAP[v]
+ if mapped_var in available_vars:
+ dataset_variables.append(mapped_var)
+ else:
+ raise ValueError(f"Mapped variable {mapped_var} for {v} not found")
+ else:
+ raise ValueError(f"Variable {v} not found. Available: {available_vars}")
+
+ our_var_name = dataset_variables[-1]
+ ours_to_fcn3[our_var_name] = v
+
+ data = self.da.sel(time=time, variable=dataset_variables)
+ if self.has_type_dim:
+ data = data.isel(type=0)
+
+ if self._on_latlon_grid:
+ return data
+
+ np_block = np.ascontiguousarray(data.values.astype(np.float32))
+ out = self._regrid(np_block)
+
+ out_dims = list(data.dims[:-1]) + ["lat", "lon"]
+ out_coords = {k: data.coords[k].values for k in data.dims if k != "cells"}
+ out_coords["variable"] = [ours_to_fcn3[v] for v in out_coords["variable"]]
+ out_coords["lat"] = self._lat_coords
+ out_coords["lon"] = self._lon_coords
+
+ return xr.DataArray(out, dims=out_dims, coords=out_coords)
+
+
+def filter_paired_times(
+ times: pd.DatetimeIndex, delta: pd.Timedelta
+) -> pd.DatetimeIndex:
+ """Filter times to only those where (t - delta) also exists in times.
+
+ Aurora and FengWu require (t-6h, t) pairs for initialization.
+ """
+ time_set = set(times)
+ mask = [(t - delta) in time_set for t in times]
+ return times[mask]
+
+
+def subsample(dataset, num_samples: int) -> list[int]:
+ """Sample indices using golden ratio for quasi-random uniform distribution."""
+ golden_ratio = 1.618033988749
+ n = len(dataset)
+ indices = [int((i * n * golden_ratio) % n) for i in range(num_samples)]
+ return sorted(indices)
+
+
+def setup_logging():
+ """Setup logging"""
+ logging.basicConfig(level=logging.INFO)
+ logger.setLevel(logging.INFO)
+
+
+def main(argv=None):
+ """Run forecast inference.
+
+ Supports multiple weather models with HEALPix zarr output format.
+ Distributes work (times) across all available GPUs.
+ """
+ parser = argparse.ArgumentParser(description="Run weather forecast inference")
+ parser.add_argument(
+ "--init_path", type=str, required=True, help="Path to input zarr"
+ )
+ parser.add_argument("--out_dir", type=str, required=True, help="Output directory")
+ parser.add_argument(
+ "--num_steps", type=int, default=40, help="Number of forecast steps (6h each)"
+ )
+ parser.add_argument(
+ "--num_times", type=int, default=4, help="Number of initial times to forecast"
+ )
+ parser.add_argument(
+ "--num_ensemble", type=int, default=1, help="Number of ensemble members"
+ )
+ parser.add_argument(
+ "--z06_18_inits",
+ action="store_true",
+ help="Use 06/18 UTC times instead of 00/12",
+ )
+ parser.add_argument(
+ "--all_utc_times",
+ action="store_true",
+ help="Use all 00/06/12/18 UTC times",
+ )
+ parser.add_argument(
+ "--model",
+ type=str.lower,
+ choices=["fcn3", "aurora", "pangu", "fengwu", "mock"],
+ default="mock",
+ help="Model to use for inference",
+ )
+ parser.add_argument(
+ "--no-bfloat16", action="store_false", dest="bfloat16", default=True
+ )
+ args = parser.parse_args(argv)
+
+ dist.init()
+ setup_logging()
+
+ # Create model config
+ if args.model == "fcn3":
+ model_config = ModelConfig.fcn3()
+ elif args.model == "aurora":
+ model_config = ModelConfig.aurora()
+ elif args.model == "pangu":
+ model_config = ModelConfig.pangu()
+ elif args.model == "fengwu":
+ model_config = ModelConfig.fengwu()
+ else:
+ model_config = ModelConfig.mock()
+
+ # Ensemble size
+ nensemble = 1 if model_config.workflow == "deterministic" else args.num_ensemble
+
+ logger.info(f"Model: {model_config.name}, workflow: {model_config.workflow}")
+ logger.info(f"Ensemble members: {nensemble}, steps: {args.num_steps}")
+
+ # Load data
+ if not os.path.exists(args.init_path):
+ raise FileNotFoundError(f"Zarr not found: {args.init_path}")
+
+ ds = DataArrayZarr(args.init_path, nlat=model_config.nlat, nlon=model_config.nlon)
+ times = pd.to_datetime(ds.da.time.values)
+ logger.info(f"Zarr contains {len(times)} times: {times[0]} to {times[-1]}")
+
+ # Aurora and FengWu require (t-6h, t) pairs for initialization
+ if args.model in ["aurora", "fengwu"]:
+ orig_len = len(times)
+ times = filter_paired_times(times, pd.Timedelta(hours=6))
+ logger.info(f"Filtered {orig_len - len(times)} unpaired times for {args.model}")
+
+ # Filter by UTC time
+ if args.all_utc_times:
+ mask = times.hour.isin([0, 6, 12, 18])
+ times = times[mask]
+ elif args.z06_18_inits:
+ mask = times.hour.isin([6, 18])
+ times = times[mask]
+ else:
+ mask = times.hour.isin([0, 12])
+ times = times[mask]
+
+ if len(times) == 0:
+ logger.warning("No valid UTC times found after filtering. Using all available.")
+ times = pd.to_datetime(ds.da.time.values)
+
+ # Remove last 10 days of December (Dec 22-31) to keep forecasts within year
+ mask = ~((times.month == 12) & (times.day >= 22))
+ times = times[mask]
+
+ if len(times) == 0:
+ raise ValueError("No valid times found after filtering")
+
+ if args.num_times < len(times):
+ sample_indices = subsample(times, args.num_times)
+ times = times[sample_indices]
+ logger.info(f"Sampled {len(times)} times across year")
+
+ rank, world_size = dist.get_rank(), dist.get_world_size()
+ times = times[rank::world_size]
+
+ logger.info(
+ f"Rank {rank}: processing {len(times)} times from {times[0] if len(times) > 0 else 'N/A'} to {times[-1] if len(times) > 0 else 'N/A'}"
+ )
+
+ # Setup output (only create directory on rank 0)
+ if dist.get_rank() == 0:
+ os.makedirs(args.out_dir, exist_ok=True)
+
+ if dist.get_world_size() > 1:
+ torch.distributed.barrier()
+
+ zarr_path = os.path.join(args.out_dir, "forecast.zarr")
+
+ io_backend = RegriddingZarrBackend(
+ zarr_path=zarr_path,
+ times=times,
+ rank=rank,
+ out_vars=model_config.variables,
+ n_ensemble=nensemble,
+ nsteps=args.num_steps,
+ init_zarr_path=args.init_path,
+ )
+
+ model = model_config.model
+ perturbation = Zero()
+
+ for i, t in enumerate(times):
+ logger.info(f"Rank {rank}: [{i + 1}/{len(times)}] Forecasting from {t}")
+
+ with torch.autocast(
+ device_type="cuda", dtype=torch.bfloat16, enabled=args.bfloat16
+ ):
+ if model_config.workflow == "deterministic":
+ e2s_deterministic(
+ time=[t],
+ nsteps=args.num_steps,
+ prognostic=model,
+ data=ds,
+ io=io_backend,
+ output_coords={"variable": np.array(model_config.variables)},
+ )
+ else:
+ e2s_ensemble(
+ time=[t],
+ nsteps=args.num_steps,
+ nensemble=nensemble,
+ prognostic=model,
+ data=ds,
+ io=io_backend,
+ perturbation=perturbation,
+ output_coords={"variable": np.array(model_config.variables)},
+ )
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if dist.get_world_size() > 1:
+ torch.distributed.barrier()
+
+ logger.info(f"Forecast complete. Output: {zarr_path}")
+
+ if torch.distributed.is_initialized():
+ torch.distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/weather/healda/scripts/io_backend.py b/examples/weather/healda/scripts/io_backend.py
new file mode 100644
index 0000000000..9c2f4bdcb7
--- /dev/null
+++ b/examples/weather/healda/scripts/io_backend.py
@@ -0,0 +1,334 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Regridding Zarr IO Backend for HEALPix output."""
+
+import itertools
+import os
+from typing import Any
+
+import earth2grid
+import numpy as np
+import pandas as pd
+import torch
+import xarray as xr
+import zarr
+from earth2grid import healpix, latlon
+from earth2studio.utils.type import CoordSystem
+from regridding import ConservativeRegridder, add_south_pole_mean
+
+
+class RegriddingZarrBackend:
+ """Interface for a generic IO backend. Assume 0.25 output grid."""
+
+ _VAR_MAP = {
+ "t2m": "tas",
+ "u10m": "uas",
+ "v10m": "vas",
+ "u100m": "100u",
+ "v100m": "100v",
+ "msl": "pres_msl",
+ }
+
+ def __init__(
+ self,
+ zarr_path,
+ times,
+ out_vars,
+ n_ensemble,
+ nsteps,
+ rank,
+ regrid_conservative=False,
+ init_zarr_path=None,
+ ):
+ self.out_vars = out_vars
+ self.n_ensemble = n_ensemble
+ self.nsteps = nsteps
+ self.regrid_conservative = regrid_conservative
+ self.init_zarr_path = init_zarr_path
+ if rank == 0:
+ self._init_group(zarr_path, times)
+
+ # Load original HEALPix init data for t=0 bypass
+ if init_zarr_path:
+ self.init_ds = xr.open_zarr(init_zarr_path)
+ else:
+ self.init_ds = None
+
+ self.group = zarr.open_consolidated(zarr_path)
+ self.times = pd.DatetimeIndex(self.group["time"][:].astype("datetime64[s]"))
+
+ # Regridder expects 721 lat after add_south_pole_mean()
+ ll_grid = latlon.equiangular_lat_lon_grid(
+ nlat=721, nlon=1440, includes_south_pole=True
+ )
+
+ if not regrid_conservative:
+ hpx_grid = healpix.Grid(
+ level=6, pixel_order=earth2grid.healpix.PixelOrder.NEST
+ )
+ self.regridder_to_hpx = earth2grid.get_regridder(ll_grid, hpx_grid).float()
+ else:
+ self.regridder_to_hpx = ConservativeRegridder(
+ latlon_grid=ll_grid, regrid_level=8, out_level=6
+ )
+
+ def _regrid(self, array):
+ regridder = self.regridder_to_hpx.to(array.device)
+ return regridder(array)
+
+ def _update_coords(self, zarr_path, requested_times):
+ nensemble = self.n_ensemble
+ first_var = self.out_vars[0]
+ group = zarr.open_group(zarr_path, mode="a")
+ if first_var in group:
+ existing_array = group[first_var]
+ existing_ensemble_size = existing_array.shape[1]
+
+ # Validate times match
+ if "time" in group:
+ existing_times = pd.DatetimeIndex(
+ group["time"][:].astype("datetime64[s]")
+ )
+ requested_times_s = pd.DatetimeIndex(
+ np.asarray(requested_times).astype("datetime64[s]")
+ )
+ if not existing_times.equals(requested_times_s):
+ raise ValueError(
+ f"Existing zarr has times {existing_times[0]} to {existing_times[-1]} "
+ f"({len(existing_times)} times), but requested {requested_times_s[0]} to "
+ f"{requested_times_s[-1]} ({len(requested_times_s)} times). "
+ "Please delete the zarr or use a different output path."
+ )
+
+ if existing_ensemble_size > nensemble:
+ raise ValueError(
+ f"Existing zarr has {existing_ensemble_size} ensemble members, "
+ f"but requested {nensemble}. Cannot resize down. "
+ "Please delete the zarr or use a different output path."
+ )
+ elif existing_ensemble_size < nensemble:
+ print(
+ f"Resizing zarr from {existing_ensemble_size} to {nensemble} ensemble members"
+ )
+ # Resize all variable arrays along ensemble dimension
+ for field in self.out_vars:
+ if field in group:
+ var_array = group[field]
+ new_shape = list(var_array.shape)
+ new_shape[1] = nensemble
+ var_array.resize(new_shape)
+
+ # Update ensemble coordinate array
+ ensemble_v = group["ensemble"]
+ ensemble_array = np.arange(nensemble)
+ ensemble_v.resize(ensemble_array.shape)
+ ensemble_v[:] = ensemble_array
+
+ zarr.consolidate_metadata(group.store)
+ print(
+ f"Resized zarr structure at {zarr_path} (HEALPix level 6, {nensemble} ensemble members)"
+ )
+ return
+ else:
+ print(f"Zarr exists with {nensemble} ensemble members. Reusing.")
+ return
+
+ def _init_group(self, zarr_path, times):
+ os.makedirs(zarr_path, exist_ok=True)
+ group = zarr.open_group(zarr_path, mode="a")
+
+ zarr_exists = len(group) > 0 and any(key in group for key in self.out_vars)
+ if zarr_exists:
+ return self._update_coords(zarr_path, times)
+
+ # Create new zarr structure
+ print(f"Creating zarr at {zarr_path}")
+ spatial_shape = (49152,) # HEALPix cells
+ total_times = len(times) # Use full task count for dimensions
+ for field in self.out_vars:
+ group.create_array(
+ field,
+ shape=(
+ total_times,
+ self.n_ensemble,
+ self.nsteps + 1,
+ *spatial_shape,
+ ), # time, ensemble, step, cells
+ chunks=(1, 1, self.nsteps + 1, *spatial_shape),
+ fill_value=float("NaN"),
+ dimension_names=("time", "ensemble", "lead_time", "cells"),
+ dtype="f",
+ )
+
+ # Store actual datetime values for time coordinate
+ time_v = group.create_array(
+ "time",
+ dtype=np.int64,
+ shape=(total_times,),
+ chunks=(total_times,),
+ dimension_names=["time"],
+ )
+ # global_tasks contains indices, get actual times
+ # Ensure times are unique (floor to seconds and check for duplicates)
+ times_s = np.asarray(times).astype("datetime64[s]")
+ unique_times, counts = np.unique(times_s, return_counts=True)
+ if len(unique_times) != len(times_s):
+ duplicates = unique_times[counts > 1]
+ raise ValueError(
+ f"Duplicate times found after converting to seconds: {duplicates}"
+ )
+ time_v[:] = times_s.astype(np.int64)
+ time_v.attrs["units"] = "seconds since 1970-01-01 00:00:00"
+ time_v.attrs["calendar"] = "standard"
+
+ ensemble_array = np.arange(self.n_ensemble)
+ ensemble_v = group.create_array(
+ "ensemble",
+ dtype=np.int32,
+ shape=ensemble_array.shape,
+ chunks=ensemble_array.shape,
+ dimension_names=["ensemble"],
+ )
+ ensemble_v[:] = ensemble_array
+ ensemble_v.attrs["description"] = "ensemble member index"
+
+ # Store forecast step information (hours from initial time)
+ forecast_hours = np.arange(0, (self.nsteps + 1) * 6, 6) # 6-hour steps
+ step_v = group.create_array(
+ "lead_time",
+ dtype=np.int32,
+ shape=forecast_hours.shape,
+ chunks=forecast_hours.shape,
+ dimension_names=["lead_time"],
+ )
+ step_v[:] = forecast_hours
+ step_v.attrs["description"] = "Forecast lead time in hours"
+
+ # Add global attributes for HEALPix grid
+ group.attrs["grid_type"] = "HEALPix"
+ group.attrs["healpix_level"] = 6
+ group.attrs["healpix_nside"] = 64
+ group.attrs["healpix_ncells"] = 49152
+ group.attrs["healpix_pixel_order"] = "NEST"
+ group.attrs["description"] = "ensemble forecasts in HEALPix format"
+
+ zarr.consolidate_metadata(group.store)
+ print(
+ f"Created zarr structure at {zarr_path} (HEALPix level 6, {self.n_ensemble} ensemble members)"
+ )
+
+ def add_array(
+ self, coords: CoordSystem, array_name: str | list[str], **kwargs: dict[str, Any]
+ ) -> None:
+ """
+ Add an array with `array_name` to the existing IO backend object.
+
+ Parameters
+ ----------
+ coords : OrderedDict
+ Ordered dictionary of representing the dimensions and coordinate data
+ of x.
+ array_name : str
+ Name of the arrays that will be initialized with coordinates as dimensions.
+ kwargs : dict[str, Any], optional
+ Optional keyword arguments that will be passed to the IO backend constructor.
+ """
+ return
+
+ def flush(self):
+ return
+
+ def write(
+ self,
+ x: torch.Tensor | list[torch.Tensor],
+ coords: CoordSystem,
+ array_name: str | list[str],
+ ) -> None:
+ """
+ Write data to the current backend using the passed array_name.
+
+ Parameters
+ ----------
+ x : torch.Tensor | list[torch.Tensor]
+ Tensor(s) to be written to zarr store.
+ coords : OrderedDict
+ Coordinates of the passed data.
+ array_name : str | list[str]
+ Name(s) of the array(s) that will be written to.
+ """
+ coords_time = pd.DatetimeIndex(coords["time"]).floor("s")
+ time_idx = self.times.get_indexer(coords_time)
+ if np.any(time_idx == -1):
+ raise ValueError(
+ f"Time mismatch: {coords_time[time_idx == -1]} not in {self.times}"
+ )
+
+ time_idx = np.atleast_1d(time_idx)
+ step = coords["lead_time"] // np.timedelta64(6, "h")
+ step = np.atleast_1d(step)
+ if "ensemble" in coords:
+ ensemble_idx = coords["ensemble"]
+ else:
+ # handle aurora case where no ensemble dim and coords has [t0-6, t0]
+ ensemble_idx = np.array([0])
+ step = step[-1:]
+ ensemble_idx = np.atleast_1d(ensemble_idx)
+
+ # Ensure all models output 721 lat before regridding
+ x_processed = [add_south_pole_mean(array) for array in x]
+ regridded = {
+ name: self._regrid(array) for name, array in zip(array_name, x_processed)
+ }
+
+ for var_name in array_name:
+ nt = len(time_idx)
+ ne = len(ensemble_idx)
+ nstep = len(step)
+ array = regridded[var_name].cpu().numpy()
+
+ # Handle deterministic models (no ensemble dimension)
+ if array.ndim == 3: # Missing ensemble dimension
+ array = array[:, np.newaxis, :, :] # Add ensemble dimension
+
+ # Map e2studio var name to init zarr var name
+ init_var = None
+ if self.init_ds is not None:
+ if var_name in self.init_ds:
+ init_var = var_name
+ elif var_name.upper() in self.init_ds:
+ init_var = var_name.upper()
+ elif self._VAR_MAP.get(var_name) in self.init_ds:
+ init_var = self._VAR_MAP[var_name]
+ can_bypass_t0 = init_var is not None
+
+ for i, j, k in itertools.product(range(nt), range(ne), range(nstep)):
+ if step[k] == 0 and can_bypass_t0:
+ # Write original analysis data directly without regridding effects
+ orig = self.init_ds[init_var].sel(time=coords_time[i]).values
+ orig_tensor = torch.as_tensor(orig)
+ orig_nest = earth2grid.healpix.reorder(
+ orig_tensor,
+ earth2grid.healpix.HEALPIX_PAD_XY,
+ earth2grid.healpix.PixelOrder.NEST,
+ )
+ self.group[var_name][time_idx[i], ensemble_idx[j], 0] = (
+ orig_nest.numpy()
+ )
+ else:
+ # Write regridded data
+ self.group[var_name][time_idx[i], ensemble_idx[j], step[k]] = array[
+ i, j, k
+ ]
diff --git a/examples/weather/healda/scripts/plot_panel.py b/examples/weather/healda/scripts/plot_panel.py
new file mode 100644
index 0000000000..7da0706aa4
--- /dev/null
+++ b/examples/weather/healda/scripts/plot_panel.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Plot multi-panel metric comparison (RMSE, CRPS, spread, etc.)."""
+
+import argparse
+
+import matplotlib.pyplot as plt
+import numpy as np
+import xarray as xr
+
+METRIC_LABELS = {
+ "crps": "CRPS",
+ "rmse_ens": "RMSE",
+ "rmse_m0": "RMSE",
+ "spread": "Spread",
+ "ssr": "Spread-Skill Ratio",
+}
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Plot metric comparison panels")
+ parser.add_argument("--stats", nargs="+", required=True, help="Metrics .nc files")
+ parser.add_argument("--labels", nargs="+", required=True, help="System labels")
+ parser.add_argument(
+ "--metric",
+ default="rmse_ens",
+ help="Metric (rmse_ens, rmse_m0, crps, spread, ssr)",
+ )
+ parser.add_argument("--fields", nargs="+", default=None, help="Fields to plot")
+ parser.add_argument("--output_path", default="panel.pdf", help="Output file")
+ parser.add_argument("--max_lead_time", type=int, default=None)
+ args = parser.parse_args()
+
+ if len(args.stats) != len(args.labels):
+ raise ValueError("Number of stats files must match number of labels")
+
+ # Load datasets
+ datasets = [xr.open_dataset(f, decode_timedelta=False) for f in args.stats]
+
+ # Auto-detect fields if not specified
+ if args.fields is None:
+ field_sets = [set(ds.field.values) for ds in datasets if "field" in ds.dims]
+ common = set.intersection(*field_sets) if field_sets else set()
+ preferred = ["Z500", "T850", "U500", "Q700", "msl", "t2m", "u10m", "tcwv"]
+ fields = [f for f in preferred if f in common]
+ fields.extend(sorted(f for f in common if f not in fields))
+ fields = fields[:8]
+ else:
+ fields = args.fields
+
+ if not fields:
+ raise ValueError("No common fields found across datasets")
+
+ # Setup figure
+ ncols = min(4, len(fields))
+ nrows = (len(fields) + ncols - 1) // ncols
+ fig, axes = plt.subplots(nrows, ncols, figsize=(12, 2.5 * nrows), sharex=True)
+ if nrows == 1:
+ axes = axes.reshape(1, -1)
+
+ # Plot each field
+ for idx, field in enumerate(fields):
+ row, col = divmod(idx, ncols)
+ ax = axes[row, col]
+
+ for ds, label in zip(datasets, args.labels):
+ # Get lead time
+ lead = ds["lead_time"]
+ if np.issubdtype(lead.dtype, np.timedelta64):
+ x = (lead / np.timedelta64(1, "h")).astype(int).values
+ else:
+ x = lead.values
+
+ # Get metric data (case-insensitive field match)
+ if args.metric not in ds:
+ print(f"Metric {args.metric} not in {label}, skipping")
+ continue
+
+ ds_fields = {str(f).lower(): str(f) for f in ds.field.values}
+ field_key = ds_fields.get(field.lower())
+ if field_key is None:
+ continue
+
+ y = ds[args.metric].sel(field=field_key).values
+ if y.ndim > 1:
+ y = y.mean(axis=tuple(range(y.ndim - 1)))
+
+ if args.max_lead_time:
+ mask = x <= args.max_lead_time
+ x, y = x[mask], y[mask]
+
+ ax.plot(x, y, label=label, linewidth=1.5)
+
+ ax.set_title(field)
+ ax.grid(True, alpha=0.3)
+ if row == nrows - 1:
+ ax.set_xlabel("Lead time (hours)")
+ if col == 0:
+ ax.set_ylabel(METRIC_LABELS.get(args.metric, args.metric.upper()))
+
+ # Hide unused axes
+ for idx in range(len(fields), nrows * ncols):
+ row, col = divmod(idx, ncols)
+ axes[row, col].set_visible(False)
+
+ axes[0, 0].legend(loc="best", fontsize=10)
+ plt.tight_layout()
+ fig.savefig(args.output_path, dpi=150, bbox_inches="tight")
+ print(f"Saved to {args.output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/weather/healda/scripts/regridding.py b/examples/weather/healda/scripts/regridding.py
new file mode 100644
index 0000000000..45f11adfe9
--- /dev/null
+++ b/examples/weather/healda/scripts/regridding.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Regridding utilities for lat-lon to HEALPix conversion."""
+
+import earth2grid
+import torch
+from earth2grid import healpix
+
+
+def add_south_pole_mean(x: torch.Tensor) -> torch.Tensor:
+ """Add south pole using zonal mean of southernmost latitude.
+
+ Aurora outputs 720 lat (patch_size=4 constraint) โ adds 721st via zonal mean.
+ Other models output 721 lat โ pass-through unchanged.
+ """
+ if x.shape[-2] == 721:
+ return x
+
+ pole_values = x[..., -1:, :].mean(dim=-1, keepdim=True)
+ pole_values = pole_values.expand(*x.shape[:-2], 1, x.shape[-1])
+
+ return torch.cat([x, pole_values], dim=-2)
+
+
+def get_latlon_bilinear_regridder(latlon_grid, regrid_level, dtype):
+ """Create bilinear regridder from lat-lon grid to HEALPix grid."""
+ hpx_grid = healpix.Grid(level=regrid_level, pixel_order=healpix.PixelOrder.NEST)
+ return earth2grid.get_regridder(latlon_grid, hpx_grid).to(dtype)
+
+
+class ConservativeRegridder(torch.nn.Module):
+ """
+ Bilinear regridder to high-res HPX then block average. More conservative than direct bilinear.
+ Matches ERA5 HPX64 processing.
+ """
+
+ def __init__(
+ self, latlon_grid=None, regrid_level=8, out_level=6, dtype=torch.float32
+ ):
+ super().__init__()
+
+ if latlon_grid is None:
+ latlon_grid = earth2grid.latlon.equiangular_lat_lon_grid(
+ nlat=721, nlon=1440
+ )
+ self.regridder = get_latlon_bilinear_regridder(latlon_grid, regrid_level, dtype)
+ self.coarsen_factor = 4 ** (regrid_level - out_level)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.regridder(x)
+ return x.reshape(x.shape[:-1] + (-1, self.coarsen_factor)).mean(-1)
diff --git a/examples/weather/healda/scripts/score_forecast.py b/examples/weather/healda/scripts/score_forecast.py
new file mode 100644
index 0000000000..573d41970c
--- /dev/null
+++ b/examples/weather/healda/scripts/score_forecast.py
@@ -0,0 +1,363 @@
+#!/usr/bin/env python
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Score forecasts against reference (e.g., ERA5).
+Computes ensemble metrics (RMSE, spread, CRPS) per variable and lead time.
+"""
+
+import argparse
+import logging
+import multiprocessing
+import os
+from concurrent.futures import ProcessPoolExecutor
+from functools import partial
+from typing import Literal, Optional
+
+import earth2grid
+import numpy as np
+import torch
+import xarray as xr
+from ensemble_metrics import unbiased_ensemble_metrics
+from tqdm import tqdm
+
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+HPX_LEVEL = 6
+
+DEFAULT_SCORE_FIELDS = ["Z500", "U500", "T850", "Q700", "t2m", "msl", "u10m", "tcwv"]
+
+# Surface field aliases: canonical name -> list of equivalent names to search for
+SURFACE_ALIASES = {
+ "t2m": ["t2m", "tas", "2t", "2m_temperature"],
+ "u10m": ["u10m", "uas", "10u", "10m_u_component_of_wind"],
+ "v10m": ["v10m", "vas", "10v", "10m_v_component_of_wind"],
+ "msl": ["msl", "pres_msl", "mean_sea_level_pressure"],
+ "tcwv": ["tcwv", "total_column_water_vapour"],
+}
+
+
+def resolve_field(ds: xr.Dataset, field: str) -> Optional[xr.DataArray]:
+ """Resolve field with case-insensitive matching, level-dimension, and surface aliases.
+
+ Handles:
+ - Exact match (Z500)
+ - Case-insensitive (Z500 vs z500)
+ - Level-dimension selection (z with level=500 -> Z500)
+ - Surface field aliases (t2m -> tas, 2t, etc.)
+ """
+ # Exact match
+ if field in ds.data_vars:
+ return ds[field]
+ # Case-insensitive match (handles Z500 vs z500)
+ for var in ds.data_vars:
+ if var.lower() == field.lower():
+ return ds[var]
+ # Level-dimension selection (e.g., Z500 -> z.sel(level=500))
+ if len(field) > 1 and field[0].isalpha() and field[1:].isdigit():
+ base = field[0].lower() # 'z' from 'Z500'
+ level = int(field[1:]) # 500 from 'Z500'
+ for var in [base, base.upper()]:
+ if var in ds.data_vars and "level" in ds[var].dims:
+ return ds[var].sel(level=level)
+ # Surface field aliases (exact match only)
+ if field in SURFACE_ALIASES:
+ for alias in SURFACE_ALIASES[field]:
+ if alias in ds.data_vars:
+ return ds[alias]
+ return None
+
+
+def get_common_fields(
+ datasets: list[xr.Dataset], fields: list[str] = None
+) -> list[str]:
+ """Get fields that exist in all datasets."""
+ fields = fields or DEFAULT_SCORE_FIELDS
+ common = []
+ for f in fields:
+ found_in_all = all(resolve_field(ds, f) is not None for ds in datasets)
+ if found_in_all:
+ common.append(f)
+ else:
+ # Log which dataset is missing the field
+ missing_in = [
+ i for i, ds in enumerate(datasets) if resolve_field(ds, f) is None
+ ]
+ logger.warning(
+ f"Skipping field '{f}' - not found in dataset(s): {missing_in}"
+ )
+ return common
+
+
+def open_any(path: str, storage_options: Optional[dict] = None) -> xr.Dataset:
+ """Open dataset from zarr or netcdf, with optional S3 support."""
+ if path.endswith(".zarr"):
+ return xr.open_zarr(path, storage_options=storage_options)
+ return xr.open_dataset(path, storage_options=storage_options)
+
+
+def setup_hpx_regridder(input_format: Literal["nest", "ring", "hpxpadxy"]) -> callable:
+ """Create regridder to convert input format to NEST."""
+ if input_format == "nest":
+ return lambda x: x
+ src_order = {
+ "ring": earth2grid.healpix.PixelOrder.RING,
+ "hpxpadxy": earth2grid.healpix.HEALPIX_PAD_XY,
+ }[input_format]
+ return lambda x: earth2grid.healpix.reorder(
+ torch.as_tensor(x), src_order, earth2grid.healpix.PixelOrder.NEST
+ ).float()
+
+
+def get_lead_time_hours(forecast: xr.Dataset) -> np.ndarray:
+ """Get lead time values in hours as integers."""
+ lead_time = forecast.lead_time
+ if np.issubdtype(lead_time.dtype, np.timedelta64):
+ if lead_time.dtype == np.dtype("timedelta64[ns]"):
+ return (lead_time / np.timedelta64(1, "h")).astype(int).values
+ return lead_time.astype("timedelta64[h]").astype(int).values
+ return lead_time.values
+
+
+def get_valid_times(forecast: xr.Dataset) -> xr.DataArray:
+ """Calculate valid times for each forecast step."""
+ lead_hours = get_lead_time_hours(forecast)
+ lead_deltas = np.array([np.timedelta64(int(h), "h") for h in lead_hours])
+ lead_time = xr.DataArray(
+ lead_deltas, dims=["lead_time"], coords={"lead_time": lead_hours}
+ )
+ times = forecast.time.expand_dims(lead_time=lead_hours)
+ return times + lead_time
+
+
+@torch.no_grad()
+def compute_metrics_for_field(
+ reference_ds: xr.Dataset,
+ forecast_ds: xr.Dataset,
+ field: str,
+ forecast_format: Literal["nest", "ring", "hpxpadxy"],
+ reference_format: Literal["nest", "ring", "hpxpadxy"],
+ worker_id: int = 0,
+ num_gpus: int = 1,
+) -> xr.Dataset:
+ """Compute metrics for a single field."""
+ gpu_id = worker_id % max(num_gpus, 1)
+ if torch.cuda.is_available():
+ torch.cuda.set_device(gpu_id)
+
+ forecast_regridder = setup_hpx_regridder(forecast_format)
+ reference_regridder = setup_hpx_regridder(reference_format)
+
+ reference = resolve_field(reference_ds, field)
+ forecast = resolve_field(forecast_ds, field)
+ if reference is None or forecast is None:
+ raise ValueError(f"Field {field} not found in datasets")
+
+ if "ensemble" not in forecast.dims:
+ forecast = forecast.expand_dims(ensemble=[0])
+
+ valid_times = get_valid_times(forecast)
+ forecast = forecast.sel(time=valid_times.time)
+ reference = reference.sel(time=valid_times)
+
+ metrics_list = []
+ nan_warning_count = 0
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ logger.info(f"Processing {len(forecast.time)} time steps for {field}")
+
+ for t in tqdm(range(len(forecast.time)), desc=f"Processing {field}"):
+ forecast_t = forecast_regridder(
+ torch.as_tensor(forecast.isel(time=t).values)
+ ).to(device)
+ forecast_t = forecast_t.permute(1, 2, 0) # [lead_time, cells, ensemble]
+ reference_t = reference_regridder(
+ torch.as_tensor(reference.isel(time=t).values)
+ ).to(device)
+
+ # NaN handling
+ if torch.isnan(forecast_t).any():
+ if nan_warning_count < 2:
+ logger.warning(f"NaN in forecast for {field} at time {t}")
+ nan_warning_count += 1
+ forecast_t = torch.nan_to_num(
+ forecast_t, nan=forecast_t[~torch.isnan(forecast_t)].mean()
+ )
+
+ metrics_t = unbiased_ensemble_metrics(forecast_t, reference_t)
+ metrics_t["mse_m0"] = (forecast_t[:, :, 0] - reference_t) ** 2
+
+ metrics_list.append({k: v.cpu() for k, v in metrics_t.items()})
+ del reference_t, forecast_t
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Combine metrics
+ metrics = {
+ key: torch.stack([m[key] for m in metrics_list], dim=0)
+ for key in metrics_list[0].keys()
+ }
+
+ def reduce_avg(x):
+ return x.mean(dim=(0, -1))
+
+ data_vars = {
+ "rmse_ens": (
+ ("lead_time",),
+ torch.sqrt(torch.clamp(reduce_avg(metrics["mse"]), min=0)).numpy(),
+ ),
+ "rmse_m0": (
+ ("lead_time",),
+ torch.sqrt(torch.clamp(reduce_avg(metrics["mse_m0"]), min=0)).numpy(),
+ ),
+ }
+
+ # Ensemble metrics (spread, CRPS, SSR)
+ if "variance" in metrics:
+ spread = torch.sqrt(reduce_avg(metrics["variance"]))
+ rmse_ens = torch.sqrt(torch.clamp(reduce_avg(metrics["mse"]), min=0))
+ R = forecast_ds.sizes.get("ensemble", 1)
+ # SSR = sqrt((R+1)/R) * spread / rmse
+ ssr = (
+ torch.sqrt(torch.tensor((R + 1) / R))
+ * spread
+ / torch.clamp(rmse_ens, min=1e-9)
+ )
+
+ data_vars["spread"] = (("lead_time",), spread.numpy())
+ data_vars["crps"] = (("lead_time",), reduce_avg(metrics["crps"]).numpy())
+ data_vars["ssr"] = (("lead_time",), ssr.numpy())
+
+ return xr.Dataset(
+ data_vars,
+ coords={"lead_time": get_lead_time_hours(forecast_ds), "field": field},
+ )
+
+
+def score_forecast(
+ reference_path: str,
+ forecast_path: str,
+ fields: list[str],
+ forecast_format: Literal["nest", "ring", "hpxpadxy"] = "nest",
+ reference_format: Literal["nest", "ring", "hpxpadxy"] = "hpxpadxy",
+) -> xr.Dataset:
+ """Score forecast against reference with parallel processing across fields."""
+ reference_ds = open_any(reference_path)
+ forecast_ds = open_any(forecast_path)
+
+ fields = get_common_fields([reference_ds, forecast_ds], fields)
+ num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
+ logger.info(f"Scoring {len(fields)} fields on {num_gpus} GPUs")
+
+ worker = partial(
+ compute_metrics_for_field,
+ reference_ds,
+ forecast_ds,
+ forecast_format=forecast_format,
+ reference_format=reference_format,
+ num_gpus=num_gpus,
+ )
+
+ results = []
+ mp_context = multiprocessing.get_context("spawn")
+ with ProcessPoolExecutor(
+ max_workers=min(len(fields), 8), mp_context=mp_context
+ ) as executor:
+ futures = {
+ executor.submit(worker, field, worker_id=i): field
+ for i, field in enumerate(fields)
+ }
+ for future in tqdm(futures.keys(), desc="Scoring fields"):
+ try:
+ results.append(future.result())
+ except Exception as e:
+ logger.error(f"Error scoring {futures[future]}: {e}")
+
+ if not results:
+ raise RuntimeError("No fields were successfully scored")
+
+ combined = xr.concat(results, dim="field")
+ combined.lead_time.attrs = {"long_name": "forecast lead time", "units": "hours"}
+ return combined
+
+
+def main():
+ multiprocessing.set_start_method(
+ "spawn"
+ ) # Required for CUDA with ProcessPoolExecutor
+
+ parser = argparse.ArgumentParser(description="Score forecasts against reference")
+ parser.add_argument(
+ "--forecast_path",
+ type=str,
+ required=True,
+ help="Path to forecast zarr (from forecast.py)",
+ )
+ parser.add_argument(
+ "--reference_path",
+ type=str,
+ required=True,
+ help="Path to reference zarr (from inference.py --use_analysis)",
+ )
+ parser.add_argument(
+ "--output_path", type=str, required=True, help="Path to save metrics (.nc)"
+ )
+ parser.add_argument(
+ "--forecast_format",
+ type=str,
+ default="nest",
+ choices=["nest", "ring", "hpxpadxy"],
+ help="HEALPix format of forecast",
+ )
+ parser.add_argument(
+ "--reference_format",
+ type=str,
+ default="hpxpadxy",
+ choices=["nest", "ring", "hpxpadxy"],
+ help="HEALPix format of reference",
+ )
+ parser.add_argument(
+ "--fields",
+ type=str,
+ nargs="+",
+ default=None,
+ help=f"Fields (default: {DEFAULT_SCORE_FIELDS})",
+ )
+ args = parser.parse_args()
+
+ fields = args.fields or DEFAULT_SCORE_FIELDS
+ logger.info(f"Forecast: {args.forecast_path}")
+ logger.info(f"Reference: {args.reference_path}")
+ logger.info(f"Scoring {len(fields)} fields: {fields}")
+
+ metrics = score_forecast(
+ reference_path=args.reference_path,
+ forecast_path=args.forecast_path,
+ fields=fields,
+ forecast_format=args.forecast_format,
+ reference_format=args.reference_format,
+ )
+
+ os.makedirs(os.path.dirname(args.output_path) or ".", exist_ok=True)
+ metrics.to_netcdf(args.output_path)
+ logger.info(f"Saved metrics to {args.output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/weather/healda/tests/test_checkpoint_handler.py b/examples/weather/healda/tests/test_checkpoint_handler.py
new file mode 100644
index 0000000000..70e463f741
--- /dev/null
+++ b/examples/weather/healda/tests/test_checkpoint_handler.py
@@ -0,0 +1,33 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S101
+from unittest.mock import patch
+
+from training.loop import CheckpointHandler
+
+
+def test_checkpoint_handler_basic_functionality(tmp_path):
+ """Test core CheckpointHandler functionality."""
+ handler = CheckpointHandler(str(tmp_path))
+
+ # Test get_path
+ assert handler.get_path(123) == f"{tmp_path}/training-state-000000123.checkpoint"
+
+ # Test list_checkpoints with mock files
+ with patch("glob.glob") as mock_glob:
+ mock_glob.return_value = ["training-state-000000001.checkpoint", "invalid.txt"]
+ checkpoints = list(handler.list_checkpoints())
+ assert checkpoints == [(f"{tmp_path}/training-state-000000001.checkpoint", 1)]
diff --git a/examples/weather/healda/tests/test_checkpointing.py b/examples/weather/healda/tests/test_checkpointing.py
new file mode 100644
index 0000000000..ab57e382f9
--- /dev/null
+++ b/examples/weather/healda/tests/test_checkpointing.py
@@ -0,0 +1,43 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S101
+import models
+from utils.checkpointing import Checkpoint
+
+from config.model_config import ModelConfigV1
+
+
+def _assert_state_dict_equal(d1, d2):
+ assert set(d1) == set(d2)
+ for k in d1:
+ assert d1[k].equal(d2[k])
+
+
+def test_checkpointing(tmp_path):
+ config = ModelConfigV1(
+ architecture="dit-test",
+ )
+ model = models.get_model(config)
+
+ state_dict = model.state_dict()
+ with Checkpoint(tmp_path / "test.checkpoint", "w") as checkpoint:
+ checkpoint.write_model(model)
+ checkpoint.write_model_config(config)
+
+ with Checkpoint(tmp_path / "test.checkpoint", "r") as checkpoint:
+ model = checkpoint.read_model()
+ assert config == checkpoint.read_model_config()
+ _assert_state_dict_equal(model.state_dict(), state_dict)
diff --git a/examples/weather/healda/tests/test_dataclass_parser.py b/examples/weather/healda/tests/test_dataclass_parser.py
new file mode 100644
index 0000000000..4e6902c9f7
--- /dev/null
+++ b/examples/weather/healda/tests/test_dataclass_parser.py
@@ -0,0 +1,184 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S101
+from dataclasses import dataclass, field
+from enum import Enum, auto
+from typing import Any, Optional
+
+import pytest
+from utils.dataclass_parser import Help, a, parse_args, parse_dict
+
+
+@dataclass(frozen=True)
+class ModelConfig:
+ learning_rate: float = 0.01
+ epochs: int = 10
+ optional: Optional[bool] = False
+
+
+@dataclass
+class Config:
+ """Test config dataclass."""
+
+ model: ModelConfig = ModelConfig()
+ model_name: str = "default_model"
+ opt: a[str, Help("An example option.")] = "a"
+
+
+@pytest.mark.parametrize("convert_underscore_to_hyphen", [True, False])
+def test_parse_args(convert_underscore_to_hyphen):
+ # Usage example
+ sep = "-" if convert_underscore_to_hyphen else "_"
+ args = [
+ f"--model.learning{sep}rate",
+ "0.1",
+ "--model.epochs",
+ "20",
+ f"--model{sep}name",
+ "my_model",
+ ]
+
+ with pytest.raises(SystemExit):
+ parse_args(
+ Config,
+ [f"--model.learning{sep}rate", "not a num"],
+ convert_underscore_to_hyphen=convert_underscore_to_hyphen,
+ )
+
+ expected = Config(
+ model=ModelConfig(learning_rate=0.1, epochs=20), model_name="my_model"
+ )
+ assert (
+ parse_args(
+ Config, args, convert_underscore_to_hyphen=convert_underscore_to_hyphen
+ )
+ == expected
+ )
+
+
+def test_parse_dict():
+ obj = {"model_name": "hello", "model": {"learning_rate": 0.1}}
+ expected = Config(model=ModelConfig(learning_rate=0.1), model_name="hello")
+ assert parse_dict(Config, obj) == expected
+
+ with pytest.raises(ValueError):
+ parse_dict(Config, {"model_name": 1})
+
+
+def test_parse_args_optional():
+ @dataclass
+ class Config:
+ a: Optional[int] = None
+
+ c = parse_args(Config, ["--a", "1"])
+ assert c == Config(1)
+
+
+def test_parse_args_union():
+ @dataclass
+ class Config:
+ a: int | None = None
+
+ c = parse_args(Config, ["--a", "1"])
+ assert c == Config(1)
+
+
+def test_parse_args_any():
+ @dataclass
+ class Config:
+ a: Any = None
+
+ c = parse_args(Config, ["--a", "1"])
+ assert c == Config("1")
+
+
+def test_parse_args_bool_default_false():
+ @dataclass
+ class Config:
+ a: bool = False
+
+ c = parse_args(Config, ["--a"])
+ assert c == Config(True)
+
+
+def test_parse_args_bool_default_true():
+ @dataclass
+ class Config:
+ a: bool = True
+
+ c = parse_args(Config, ["--no-a"])
+ assert c == Config(False)
+
+
+def test_parse_args_bool_default_true_nested():
+ @dataclass
+ class Sub:
+ a: bool = False
+
+ @dataclass
+ class Config:
+ sub: Sub = field(default_factory=lambda: Sub(a=True))
+
+ c = parse_args(Config, ["--sub.no_a"], convert_underscore_to_hyphen=False)
+ assert c == Config(Sub(False))
+
+ c = parse_args(Config, ["--sub.no-a"])
+ assert c == Config(Sub(False))
+
+
+def test_enum():
+ class Options(Enum):
+ a = auto()
+ b = auto()
+
+ @dataclass
+ class CLI:
+ opt: Options = Options.a
+
+ c = parse_args(CLI, ["--opt", "b"])
+ c.opt == Options.b
+
+ c = parse_args(CLI, [])
+ c.opt == Options.a
+
+
+def test_parse_args_double_nested():
+ @dataclass(eq=True)
+ class SubSub:
+ a: int = 1
+
+ @dataclass(eq=True)
+ class Sub:
+ sub: SubSub = field(default_factory=SubSub)
+
+ @dataclass(eq=True)
+ class Config:
+ sub: Sub = field(default_factory=Sub)
+
+ c = parse_args(Config, ["--sub.sub.a", "1"], convert_underscore_to_hyphen=False)
+ assert c == Config()
+
+
+def test_parse_args_with_list_generic_type():
+ """Ensure list[...] | None fields are compatible with parse_args strict checks."""
+
+ @dataclass
+ class Config:
+ values: list[int] | None = None
+
+ # This should not raise TypeError from isinstance() on a parameterized generic.
+ cfg = parse_args(Config, args=[])
+ assert cfg.values is None
diff --git a/examples/weather/healda/tests/test_loop.py b/examples/weather/healda/tests/test_loop.py
new file mode 100644
index 0000000000..e41d3da592
--- /dev/null
+++ b/examples/weather/healda/tests/test_loop.py
@@ -0,0 +1,60 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+from datasets.base import BatchInfo, TimeUnit
+from training import loop
+
+requires_cuda = pytest.mark.skipif(
+ not torch.cuda.is_available(), reason="CUDA not available"
+)
+
+
+class MockLoop(loop.TrainingLoopBase):
+ def get_data_loaders(self, batch_gpu: int):
+ class MockDataset:
+ def __init__(self):
+ self.batch_info = BatchInfo(
+ channels=["channel_0", "channel_1"],
+ time_step=1,
+ time_unit=TimeUnit.HOUR,
+ scales=[1.0, 1.0],
+ center=[0.0, 0.0],
+ )
+
+ dataset = MockDataset()
+ return dataset, None, None
+
+ def get_network(self) -> torch.nn.Module:
+ return torch.nn.Linear(10, 10)
+
+ def get_optimizer(self, parameters):
+ return torch.optim.Adam(parameters)
+
+ def get_loss_fn(self):
+ return torch.nn.MSELoss()
+
+
+@requires_cuda
+def test_loop_save_load(tmp_path):
+ rundir = tmp_path
+ loop = MockLoop(rundir.as_posix(), batch_gpu=1)
+ loop.setup()
+ loop.save_training_state(0)
+
+ loop = MockLoop.loads((rundir / "loop.json").read_text())
+ loop.setup()
+ loop.resume_from_rundir(rundir)
diff --git a/examples/weather/healda/tests/test_merged_dataset.py b/examples/weather/healda/tests/test_merged_dataset.py
new file mode 100644
index 0000000000..23e55859e4
--- /dev/null
+++ b/examples/weather/healda/tests/test_merged_dataset.py
@@ -0,0 +1,399 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S101
+import functools
+
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+from datasets import datetime_utils
+from datasets.merged_dataset import (
+ TimeMergedDataset,
+ TimeMergedMapStyle,
+ _FrameIndexGenerator,
+ _split_array_contiguous,
+)
+
+
+class MockLoader:
+ def __init__(
+ self,
+ num_frames: int,
+ target_channels: int,
+ condition_channels: int,
+ num_pixels: int,
+ start_date: str,
+ ):
+ # Create data where each frame's values equal its index
+ # Shape: (num_frames, num_channels, 1, num_pixels)
+ self.data = (
+ torch.arange(num_frames)
+ .float()
+ .view(num_frames, 1, 1, 1)
+ .expand(-1, target_channels + condition_channels, -1, num_pixels)
+ )
+ self.target_channels = target_channels
+ self.start_time = pd.Timestamp(start_date)
+
+ async def sel_time(self, times):
+ # Convert timestamps to indices (hours since start)
+ indices = [(t - self.start_time).total_seconds() / 3600 for t in times]
+ indices = [int(i) for i in indices]
+ return {
+ "target": self.data[indices, : self.target_channels], # (C, 1, X)
+ "condition": self.data[indices, self.target_channels :],
+ }
+
+
+def temporal_stack(timestamp, frames):
+ out = {key: torch.cat([f[key] for f in frames], dim=1) for key in frames[0].keys()}
+ out["timestamp"] = datetime_utils.as_timestamp(timestamp)
+ return out
+
+
+@pytest.mark.parametrize(
+ "time_length,frame_step,window_stride",
+ [
+ (1, 1, 2), # Image
+ (1, 1, 1), # Image
+ (4, 3, 2), # Video
+ (6, 4, 1), # Video
+ ],
+)
+def test_time_merged_dataset(time_length, frame_step, window_stride):
+ # Setup small dataset where the total length is chunk_size + padding
+ # This way we can verify both in-chunk and between-chunk behavior
+ chunk_size = time_length * frame_step * window_stride + 10
+ num_frames = (
+ chunk_size + (time_length - 1) * frame_step
+ ) # Enough frames for one full chunk
+ target_channels = 5
+ condition_channels = 1
+ num_pixels = 16
+
+ start_date = "2025-01-01"
+ times = pd.date_range(start_date, periods=num_frames, freq="h")
+
+ loader = MockLoader(
+ num_frames, target_channels, condition_channels, num_pixels, start_date
+ )
+
+ transform = functools.partial(
+ temporal_stack,
+ )
+
+ # Test with shuffle=False
+ dataset = TimeMergedDataset(
+ times=times,
+ time_loaders=[loader],
+ transform=transform,
+ time_length=time_length,
+ frame_step=frame_step,
+ window_stride=window_stride,
+ chunk_size=chunk_size,
+ shuffle=False,
+ infinite=False,
+ )
+
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=1,
+ )
+
+ # Collect all samples from the dataset
+ samples = list(dataloader)
+
+ frames_per_window = (time_length - 1) * frame_step + 1
+ valid_length = num_frames - frames_per_window + 1
+ expected_windows = (valid_length + window_stride - 1) // window_stride
+ assert len(samples) == expected_windows, (
+ f"Expected {expected_windows} windows, got {len(samples)}"
+ )
+
+ for sample in samples:
+ assert sample["target"].shape == (1, target_channels, time_length, num_pixels)
+ assert sample["condition"].shape == (
+ 1,
+ condition_channels,
+ time_length,
+ num_pixels,
+ )
+
+ # Get start indices of each window and verify they're window_stride apart
+ start_indices = []
+ for sample in samples:
+ start_idx = int(sample["target"][0, 0, 0, 0].item())
+ start_indices.append(start_idx)
+
+ expected_starts = list(range(0, chunk_size * window_stride, window_stride))[
+ :expected_windows
+ ]
+ assert start_indices == expected_starts, (
+ "Without shuffle, windows should be in exact sequential order"
+ )
+
+ # Test with shuffle=True
+ dataset = TimeMergedDataset(
+ times=times,
+ time_loaders=[loader],
+ transform=transform,
+ time_length=time_length,
+ frame_step=frame_step,
+ window_stride=window_stride,
+ chunk_size=chunk_size,
+ shuffle=True,
+ infinite=True, # Need infinite for multiple chunks
+ )
+
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=1,
+ )
+
+ # Get two chunks worth of data and verify they are in a different order
+ chunk1_indices = []
+ chunk2_indices = []
+ dataloader_iter = iter(dataloader)
+
+ for _ in range(chunk_size // window_stride):
+ sample = next(dataloader_iter)
+ start_idx = int(sample["target"][0, 0, 0, 0].item())
+ chunk1_indices.append(start_idx)
+
+ for _ in range(chunk_size // window_stride):
+ sample = next(dataloader_iter)
+ start_idx = int(sample["target"][0, 0, 0, 0].item())
+ chunk2_indices.append(start_idx)
+
+ assert sorted(chunk1_indices) == sorted(chunk2_indices), (
+ "Chunks should contain same indices"
+ )
+
+ if len(chunk1_indices) != 1:
+ assert chunk1_indices != chunk2_indices, (
+ "Chunks should be in different orders when shuffled"
+ )
+
+
+@pytest.mark.parametrize(
+ "time_length,frame_step",
+ [
+ (1, 1), # Image case
+ (4, 2), # Video case
+ ],
+)
+def test_time_merged_map_style_dataset(time_length, frame_step):
+ num_frames = 100
+ target_channels = 5
+ condition_channels = 1
+ num_pixels = 16
+
+ start_date = "2025-01-01"
+ times = pd.date_range(start_date, periods=num_frames, freq="h")
+
+ loader = MockLoader(
+ num_frames, target_channels, condition_channels, num_pixels, start_date
+ )
+
+ transform = functools.partial(
+ temporal_stack,
+ )
+
+ dataset = TimeMergedMapStyle(
+ times=times,
+ time_loaders=[loader],
+ time_length=time_length,
+ frame_step=frame_step,
+ transform=transform,
+ )
+
+ # Test length
+ frames_per_window = (time_length - 1) * frame_step + 1
+ expected_length = num_frames - frames_per_window + 1
+ assert len(dataset) == expected_length
+
+ # Test first window
+ sample = dataset[0]
+ assert sample["target"].shape == (target_channels, time_length, num_pixels)
+ assert sample["condition"].shape == (condition_channels, time_length, num_pixels)
+
+ # Check values
+ expected_indices = list(
+ range(0, time_length * frame_step, frame_step)
+ ) # [0] for image, [0,2,4,6] for video
+ for t, expected_idx in enumerate(expected_indices):
+ assert torch.all(sample["target"][:, t] == expected_idx)
+ assert torch.all(sample["condition"][:, t] == expected_idx)
+
+ # Test last window
+ last_idx = len(dataset) - 1
+ last_sample = dataset[last_idx]
+ assert last_sample["target"].shape == (target_channels, time_length, num_pixels)
+ assert last_sample["condition"].shape == (
+ condition_channels,
+ time_length,
+ num_pixels,
+ )
+
+ # Test out of bounds
+ try:
+ dataset[len(dataset)]
+ assert False, "Should have raised IndexError"
+ except IndexError:
+ pass
+
+
+def test_map_style_caching():
+ start_date = "2025-01-01"
+ num_frames = 16
+ times = pd.date_range(start_date, periods=num_frames, freq="h")
+ num_frames = 100
+ target_channels = 5
+ condition_channels = 1
+ num_pixels = 16
+
+ loader = MockLoader(
+ num_frames, target_channels, condition_channels, num_pixels, start_date
+ )
+
+ transform = functools.partial(
+ temporal_stack,
+ )
+
+ dataset = TimeMergedMapStyle(
+ times=times,
+ time_loaders=[loader],
+ time_length=1,
+ frame_step=1,
+ cache_chunk_size=8,
+ transform=transform,
+ )
+
+ for i in range(10):
+ out = dataset[i]
+ time = out["timestamp"].astype("datetime64[s]")
+ assert time == times[i]
+
+ assert dataset._cache_data is not None
+
+
+def test_split_array_contiguous():
+ input = np.arange(10)
+ (output,) = _split_array_contiguous(input)
+ assert np.all(input == output)
+
+ input = np.array([0, 1, 2, 5, 6])
+ (out1, out2) = _split_array_contiguous(input)
+ assert out1.tolist() == [0, 1, 2]
+ assert out2.tolist() == [5, 6]
+
+
+def test_map_style_getitems():
+ start_date = "2025-01-01"
+ num_frames = 16
+ times = pd.date_range(start_date, periods=num_frames, freq="h")
+ num_frames = 100
+ target_channels = 5
+ condition_channels = 1
+ num_pixels = 16
+ batch_size = 3
+
+ loader = MockLoader(
+ num_frames, target_channels, condition_channels, num_pixels, start_date
+ )
+
+ transform = functools.partial(
+ temporal_stack,
+ )
+
+ class BatchTransform:
+ def __init__(self) -> None:
+ self._called = False
+
+ def __call__(self, times, frames):
+ self._called = True
+ return torch.ones(len(times), len(times[0]))
+
+ batch_transform = BatchTransform()
+
+ dataset = TimeMergedMapStyle(
+ times=times,
+ time_loaders=[loader],
+ time_length=2,
+ frame_step=1,
+ cache_chunk_size=8,
+ transform=transform,
+ batch_transform=batch_transform,
+ )
+
+ loader = torch.utils.data.DataLoader(
+ dataset, collate_fn=lambda x: x, batch_size=batch_size
+ )
+ out = next(iter(loader))
+ assert out.shape == (batch_size, dataset.time_length)
+ assert batch_transform._called
+
+
+def test_frame_index_generator():
+ """Test the _FrameIndexGenerator functionality."""
+ # Test basic functionality
+ times = np.arange(100) # Simple integer array [0, 1, 2, ..., 99]
+ generator = _FrameIndexGenerator(
+ times=times, time_length=3, frame_step=2, model_rank=0, model_world_size=1
+ )
+
+ start_indices = torch.tensor([0, 10])
+ frame_idxs = generator.generate_frame_indices(start_indices)
+
+ expected = [[0, 2, 4], [10, 12, 14]]
+ assert frame_idxs == expected
+
+ # Test with model rank slicing
+ generator_multi_rank = _FrameIndexGenerator(
+ times=times, time_length=4, frame_step=1, model_rank=1, model_world_size=2
+ )
+
+ start_indices = torch.tensor([5])
+ frame_idxs = generator_multi_rank.generate_frame_indices(start_indices)
+
+ # Expected: [7, 8] for rank 1 of 2 with time_length=4
+ # Full range would be [5, 6, 7, 8], rank 1 gets second half: [7, 8]
+ expected = [7, 8]
+ assert frame_idxs[0] == expected
+
+
+def test_frame_index_generator_multiple_segments():
+ """Test _FrameIndexGenerator with multiple segments."""
+ times = np.concatenate(
+ [
+ np.arange(0, 10), # [0, 1, 2, ..., 9]
+ np.arange(20, 35), # [20, 21, 22, ..., 34]
+ ]
+ )
+
+ generator = _FrameIndexGenerator(
+ times=times, time_length=3, frame_step=1, model_rank=0, model_world_size=1
+ )
+
+ # [7, 8, 9] is the last valid sample in the first segment
+ assert times[generator._map_logical_to_physical(0)] == 0
+ assert times[generator._map_logical_to_physical(1)] == 1
+ assert times[generator._map_logical_to_physical(7)] == 7
+ assert times[generator._map_logical_to_physical(8)] == 20
+
+ assert all(times[generator.generate_frame_indices([7])[0]] == [7, 8, 9])
+ assert all(times[generator.generate_frame_indices([8])[0]] == [20, 21, 22])
diff --git a/examples/weather/healda/tests/test_obs_filtering_utils.py b/examples/weather/healda/tests/test_obs_filtering_utils.py
new file mode 100644
index 0000000000..7d4077832c
--- /dev/null
+++ b/examples/weather/healda/tests/test_obs_filtering_utils.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S101
+"""
+Test script for the obs_filtering_utils.py implementation.
+This tests the vectorized filtering approach for conventional observations.
+"""
+
+import numpy as np
+import pyarrow as pa
+from datasets.etl.combined_schema import (
+ GLOBAL_CHANNEL_ID,
+ get_combined_observation_schema,
+)
+from datasets.etl.etl_unified import get_channel_table
+from datasets.obs_filtering_utils import filter_observations
+from datasets.obs_loader import LOCAL_CHANNEL_ID
+from datasets.sensors import SENSOR_OFFSET
+
+
+def create_test_data():
+ """Create test data with different platform types."""
+ # Create test data with different channel IDs
+ conv_offset = SENSOR_OFFSET["conv"]
+
+ # GPS channels: 0, 1, 2
+ gps_channels = [conv_offset + 0, conv_offset + 1, conv_offset + 2]
+ # PS channel: 3
+ ps_channels = [conv_offset + 3]
+ # Q channel: 4
+ q_channels = [conv_offset + 4]
+ # T channel: 5
+ t_channels = [conv_offset + 5]
+ # UV channels: 6, 7
+ uv_channels = [conv_offset + 6, conv_offset + 7]
+
+ # Combine all channels
+ all_channels = gps_channels + ps_channels + q_channels + t_channels + uv_channels
+
+ # Create test data
+ n_rows = len(all_channels)
+ data = {
+ # Required common fields
+ "Latitude": np.random.uniform(-90, 90, n_rows),
+ "Longitude": np.random.uniform(-180, 180, n_rows),
+ "Absolute_Obs_Time": np.array(
+ [np.datetime64("2023-01-01T00:00:00")] * n_rows, dtype="datetime64[ns]"
+ ),
+ "DA_window": np.array(
+ [np.datetime64("2023-01-01T00:00:00")] * n_rows, dtype="datetime64[ns]"
+ ),
+ "Platform_ID": np.random.randint(1, 100, n_rows),
+ "Global_Channel_ID": all_channels,
+ "Observation": np.random.uniform(0, 100, n_rows),
+ # Satellite-specific fields (nullable)
+ "Sat_Zenith_Angle": np.full(n_rows, None, dtype=object),
+ "Sol_Zenith_Angle": np.full(n_rows, None, dtype=object),
+ "Scan_Angle": np.full(n_rows, None, dtype=object),
+ # Conventional-specific fields
+ "Pressure": np.random.uniform(200, 1100, n_rows),
+ "Height": np.random.uniform(0, 50000, n_rows),
+ "Observation_Type": np.random.randint(1, 10, n_rows),
+ # Analysis fields (nullable)
+ "QC_Flag": np.random.choice([0, 1], n_rows),
+ "Analysis_Use_Flag": np.random.choice([0, 1], n_rows),
+ "Obs_Minus_Forecast_adjusted": np.random.uniform(-10, 10, n_rows),
+ "Obs_Minus_Forecast_unadjusted": np.random.uniform(-10, 10, n_rows),
+ }
+
+ obs_table = pa.table(data, schema=get_combined_observation_schema())
+
+ # Add channel metadata via join (mimics UFSUnifiedLoader._add_channel_metadata)
+ def _add_channel_metadata(table):
+ channel_table = get_channel_table()
+
+ # Add local_channel_id (same as UFSUnifiedLoader.channel_table property)
+ sensor_id = np.asarray(channel_table["sensor_id"])
+ local_channel_ids = []
+ offset = 0
+ for i in range(len(sensor_id)):
+ if sensor_id[i] != sensor_id[i - 1]:
+ offset = i
+ local_channel_ids.append(i - offset)
+ channel_table = channel_table.append_column(
+ LOCAL_CHANNEL_ID.name, pa.array(local_channel_ids, type=pa.uint16())
+ )
+
+ return table.join(
+ channel_table.select(
+ [
+ GLOBAL_CHANNEL_ID.name,
+ LOCAL_CHANNEL_ID.name,
+ "min_valid",
+ "max_valid",
+ "is_conv",
+ ]
+ ),
+ GLOBAL_CHANNEL_ID.name,
+ )
+
+ return _add_channel_metadata(obs_table)
+
+
+def test_vectorized_filtering():
+ """Test that vectorized filtering works correctly."""
+ table = create_test_data()
+
+ # Test filtering using the unified filter_observations function
+ filtered_table = filter_observations(table, qc_filter=False)
+
+ # Test with QC filtering enabled
+ qc_filtered_table = filter_observations(table, qc_filter=True)
+
+ # Verify that filtering produces results
+ assert filtered_table.num_rows >= 0
+ assert qc_filtered_table.num_rows >= 0
+ assert qc_filtered_table.num_rows <= filtered_table.num_rows
diff --git a/examples/weather/healda/tests/test_obs_time_range_loader.py b/examples/weather/healda/tests/test_obs_time_range_loader.py
new file mode 100644
index 0000000000..f68e89d4b3
--- /dev/null
+++ b/examples/weather/healda/tests/test_obs_time_range_loader.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S101
+import pandas as pd
+from datasets.obs_time_range_loader import Loader, _get_file_names
+
+
+def test_file_name():
+ assert _get_file_names(
+ "a",
+ ["atms"],
+ pd.Timestamp(2000, 1, 1, 1),
+ pd.Timestamp(2000, 1, 1, 23),
+ ) == ["a/atms/20000101/0.parquet", "a/atms/20000102/0.parquet"]
+
+
+def test_loader_get_empty():
+ columns = ("Latitude",)
+ loader = Loader(sensors=["amsua", "atms"], columns=columns)
+ out = loader._get_empty()
+ assert tuple(f.name for f in out.schema) == columns
diff --git a/examples/weather/healda/tests/test_prefetch_map.py b/examples/weather/healda/tests/test_prefetch_map.py
new file mode 100644
index 0000000000..761aed3265
--- /dev/null
+++ b/examples/weather/healda/tests/test_prefetch_map.py
@@ -0,0 +1,54 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S101
+
+import pytest
+import torch
+from datasets.prefetch_map import prefetch_map
+
+requires_cuda = pytest.mark.skipif(
+ not torch.cuda.is_available(), reason="CUDA not available"
+)
+
+
+@requires_cuda
+def test_prefetch_map_basic_functionality():
+ """Test basic async dataloader functionality with simple data."""
+ # Create simple test data using range
+ data = list(range(10)) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+
+ # Simple transform that doubles the data
+ def transform(x):
+ return 2 * x
+
+ # Create async loader
+ async_loader = prefetch_map(data, transform)
+ assert list(async_loader) == list(range(0, 20, 2))
+
+
+@requires_cuda
+def test_prefetch_map_error_handling():
+ """Test error handling when transform raises an exception."""
+ data = list(range(4)) # [0, 1, 2, 3]
+
+ def failing_transform(x):
+ raise ValueError("Test error")
+
+ async_loader = prefetch_map(data, failing_transform)
+
+ # Should raise the exception from the background thread
+ with pytest.raises(ValueError, match="Test error"):
+ list(async_loader)
diff --git a/examples/weather/healda/tests/test_round_robin_loader.py b/examples/weather/healda/tests/test_round_robin_loader.py
new file mode 100644
index 0000000000..117a6af761
--- /dev/null
+++ b/examples/weather/healda/tests/test_round_robin_loader.py
@@ -0,0 +1,177 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S101
+"""Tests for RoundRobinLoader."""
+
+import torch
+import torch.utils.data
+from datasets.round_robin import RoundRobinLoader
+
+
+def test_round_robin_loader():
+ """Test RoundRobinLoader round-robin interleaving logic."""
+ # Create simple datasets
+ dataset1 = list(range(0, 10)) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+ dataset2 = list(range(10, 15)) # [10, 11, 12, 13, 14]
+ dataset3 = list(range(15, 20)) # [15, 16, 17, 18, 19]
+
+ # Create dataloaders
+ loader1 = torch.utils.data.DataLoader(dataset1, batch_size=2)
+ loader2 = torch.utils.data.DataLoader(dataset2, batch_size=2)
+ loader3 = torch.utils.data.DataLoader(dataset3, batch_size=2)
+
+ # Create RoundRobinLoader
+ round_robin = RoundRobinLoader([loader1, loader2, loader3])
+
+ # Test __len__
+ assert len(round_robin) == len(loader1) + len(loader2) + len(loader3)
+ assert len(round_robin) == 5 + 3 + 3 # 11 total batches
+
+ # Collect all batches
+ batches = list(round_robin)
+
+ # Should have 11 batches total
+ assert len(batches) == 11
+
+ # First round: one batch from each loader
+ assert torch.equal(batches[0], torch.tensor([0, 1])) # from loader1
+ assert torch.equal(batches[1], torch.tensor([10, 11])) # from loader2
+ assert torch.equal(batches[2], torch.tensor([15, 16])) # from loader3
+
+ # Second round
+ assert torch.equal(batches[3], torch.tensor([2, 3])) # from loader1
+ assert torch.equal(batches[4], torch.tensor([12, 13])) # from loader2
+ assert torch.equal(batches[5], torch.tensor([17, 18])) # from loader3
+
+ # Third round - loader2 and loader3 exhausted, only loader1 continues
+ assert torch.equal(batches[6], torch.tensor([4, 5])) # from loader1
+ assert torch.equal(batches[7], torch.tensor([14])) # last batch from loader2
+ assert torch.equal(batches[8], torch.tensor([19])) # last batch from loader3
+
+ # Remaining batches from loader1
+ assert torch.equal(batches[9], torch.tensor([6, 7])) # from loader1
+ assert torch.equal(batches[10], torch.tensor([8, 9])) # from loader1
+
+
+def test_round_robin_loader_uneven_lengths():
+ """Test RoundRobinLoader with very uneven loader lengths."""
+ # Create datasets of very different sizes
+ dataset1 = list(range(0, 20)) # 10 batches
+ dataset2 = list(range(20, 22)) # 1 batch
+
+ loader1 = torch.utils.data.DataLoader(dataset1, batch_size=2)
+ loader2 = torch.utils.data.DataLoader(dataset2, batch_size=2)
+
+ round_robin = RoundRobinLoader([loader1, loader2])
+ batches = list(round_robin)
+
+ # Total should be 11 batches
+ assert len(batches) == 11
+
+ # First two batches alternate
+ assert torch.equal(batches[0], torch.tensor([0, 1]))
+ assert torch.equal(batches[1], torch.tensor([20, 21]))
+
+ # After loader2 is exhausted, only loader1 continues
+ assert torch.equal(batches[2], torch.tensor([2, 3]))
+ assert torch.equal(batches[3], torch.tensor([4, 5]))
+
+
+def test_round_robin_loader_empty():
+ """Test RoundRobinLoader with empty dataloaders list."""
+ round_robin = RoundRobinLoader([])
+ batches = list(round_robin)
+ assert len(batches) == 0
+
+
+def test_round_robin_loader_single():
+ """Test RoundRobinLoader with a single dataloader."""
+ dataset = list(range(0, 10))
+ loader = torch.utils.data.DataLoader(dataset, batch_size=3)
+
+ round_robin = RoundRobinLoader([loader])
+ batches = list(round_robin)
+
+ # Should have same batches as the original loader
+ expected = list(loader)
+ assert len(batches) == len(expected)
+ for b1, b2 in zip(batches, expected):
+ assert torch.equal(b1, b2)
+
+
+def test_round_robin_loader_epoch_handling():
+ """Test that RoundRobinLoader correctly handles epoch transitions with ChunkedDistributedSampler."""
+ from datasets.samplers import ChunkedDistributedSampler
+
+ # Create a simple dataset
+ dataset = list(range(100))
+
+ # Create samplers for two "workers"
+ sampler1 = ChunkedDistributedSampler(
+ dataset,
+ chunk_size=10,
+ num_replicas=2,
+ rank=0,
+ shuffle=True,
+ shuffle_within_chunk=False,
+ drop_last=False,
+ seed=42,
+ )
+ sampler2 = ChunkedDistributedSampler(
+ dataset,
+ chunk_size=10,
+ num_replicas=2,
+ rank=1,
+ shuffle=True,
+ shuffle_within_chunk=False,
+ drop_last=False,
+ seed=42,
+ )
+
+ # Create dataloaders
+ loader1 = torch.utils.data.DataLoader(dataset, batch_size=5, sampler=sampler1)
+ loader2 = torch.utils.data.DataLoader(dataset, batch_size=5, sampler=sampler2)
+
+ # Create RoundRobinLoader
+ round_robin = RoundRobinLoader([loader1, loader2])
+
+ # Epoch 1: Collect all batches
+ epoch1_batches = list(round_robin)
+ epoch1_count = len(epoch1_batches)
+
+ # Verify we got data
+ assert epoch1_count > 0, "Should have batches in epoch 1"
+
+ # Check that samplers auto-incremented their epoch
+ assert sampler1.epoch == 1, "Sampler 1 should have auto-incremented to epoch 1"
+ assert sampler2.epoch == 1, "Sampler 2 should have auto-incremented to epoch 1"
+
+ # Epoch 2: Iterate again - should work and get different shuffle order
+ epoch2_batches = list(round_robin)
+ epoch2_count = len(epoch2_batches)
+
+ # Should have same number of batches
+ assert epoch2_count == epoch1_count, "Should have same number of batches each epoch"
+
+ # Check that samplers auto-incremented again
+ assert sampler1.epoch == 2, "Sampler 1 should have auto-incremented to epoch 2"
+ assert sampler2.epoch == 2, "Sampler 2 should have auto-incremented to epoch 2"
+
+ # With shuffle=True, the order should be different across epochs
+ # (though this is probabilistic, with seed=42 and shuffle=True it should differ)
+ # We'll just verify we can iterate multiple times without error
+ epoch3_batches = list(round_robin)
+ assert len(epoch3_batches) == epoch1_count
diff --git a/examples/weather/healda/tests/test_sampler.py b/examples/weather/healda/tests/test_sampler.py
new file mode 100644
index 0000000000..21f72e2b18
--- /dev/null
+++ b/examples/weather/healda/tests/test_sampler.py
@@ -0,0 +1,180 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S101
+import itertools
+
+from datasets import samplers
+
+
+def test_InfiniteSequentialSampler():
+ ds = list(range(10))
+ sampler = samplers.InfiniteSequentialSampler(ds, shuffle_every=2)
+ iterator = iter(sampler)
+ i = next(iterator)
+ ip1 = next(iterator)
+ assert (i + 1) % 10 == ip1
+
+
+def test_ChunkedRandomSampler():
+ s = samplers.ChunkedDistributedSampler(list(range(100)), chunk_size=5)
+ it = iter(s)
+ visited = set()
+ for chunk in range(20):
+ last_i = 0
+ for i in range(5):
+ idx = next(it)
+ if i > 0:
+ assert idx - last_i == 1
+ last_i = idx
+ visited.add(idx)
+
+ no_repeats = len(visited) == 100
+ assert no_repeats
+
+
+def test_ChunkedDistributedSampler_with_islice():
+ """Test that ChunkedDistributedSampler works correctly with itertools.islice.
+
+ This test verifies the fix for a bug where calling iter(sampler) multiple times
+ would reset the iterator state, causing itertools.islice to restart from the
+ beginning instead of continuing from where it left off.
+ """
+ # Create a dataset with 100 items and chunk size of 10
+ dataset = list(range(100))
+ sampler = samplers.ChunkedDistributedSampler(
+ dataset, chunk_size=10, drop_last=False
+ )
+
+ # Test 1: Basic iteration works
+ iterator = iter(sampler)
+ first_10 = list(itertools.islice(iterator, 10))
+ assert len(first_10) == 10
+ assert first_10 == list(range(10)), f"Expected [0-9], got {first_10}"
+
+ # Test 2: The key bug fix - calling iter(sampler) again should continue from where we left off
+ # Before the fix, this would restart from the beginning
+ # After the fix, this should continue from where we left off
+ iterator2 = iter(sampler)
+ next_10 = list(itertools.islice(iterator2, 10))
+
+ # Verify we get the next 10 items (10-19), not the same as first_10
+ expected_next_10 = list(range(10, 20))
+ assert next_10 == expected_next_10, f"Expected {expected_next_10}, got {next_10}"
+ assert first_10 != next_10, "Iterator state was reset - this is the bug!"
+
+ # Test 3: Verify the items are sequential within chunks
+ # Items should be consecutive within each chunk of 10
+ for i in range(0, len(first_10), 10):
+ chunk = first_10[i : i + 10]
+ if len(chunk) > 1:
+ for j in range(1, len(chunk)):
+ assert chunk[j] - chunk[j - 1] == 1, (
+ f"Items not consecutive in chunk: {chunk}"
+ )
+
+ # Test 4: Verify we can continue iteration from the same iterator
+ iterator3 = iter(sampler)
+ first_5 = list(itertools.islice(iterator3, 5))
+ next_5 = list(itertools.islice(iterator3, 5))
+
+ # These should be consecutive (continuing from where iterator2 left off at 20)
+ expected_first_5 = list(range(20, 25))
+ expected_next_5 = list(range(25, 30))
+ assert first_5 == expected_first_5, f"Expected {expected_first_5}, got {first_5}"
+ assert next_5 == expected_next_5, f"Expected {expected_next_5}, got {next_5}"
+
+ # Test 5: Verify we can exhaust the iterator and it resets properly
+ iterator4 = iter(sampler)
+ all_items = list(iterator4)
+ # Should continue from where iterator3 left off (at 30)
+ expected_remaining = list(range(30, 100))
+ assert all_items == expected_remaining, (
+ f"Expected {expected_remaining}, got {all_items}"
+ )
+
+ # After exhaustion, a new iterator should start from the beginning
+ iterator5 = iter(sampler)
+ first_item = next(iterator5)
+ # The first item should be from chunk 0 (0-9 range)
+ assert 0 <= first_item < 10, f"Expected first item to be 0-9, got {first_item}"
+
+
+def test_shuffle_within_chunk():
+ """Test shuffle_within_chunk randomizes samples within each chunk."""
+ s = samplers.ChunkedDistributedSampler(
+ list(range(100)),
+ chunk_size=10,
+ shuffle=False, # Keep chunks in sequential order
+ shuffle_within_chunk=True,
+ seed=42,
+ )
+
+ indices = list(s)
+
+ # All indices should be present
+ assert sorted(indices) == list(range(100))
+
+ # First chunk should have indices 0-9 but shuffled
+ first_chunk = indices[:10]
+ assert sorted(first_chunk) == list(range(10))
+ assert first_chunk != list(range(10)), "Within-chunk shuffle should change order"
+
+
+def test_shuffle_epoch_changes_chunks():
+ """Test that epoch auto-increment causes different chunk order between epochs."""
+ s = samplers.ChunkedDistributedSampler(
+ list(range(100)),
+ chunk_size=10,
+ shuffle=True,
+ shuffle_within_chunk=True,
+ seed=42,
+ )
+
+ # First epoch - get first 10 indices
+ epoch1_indices = list(s)
+ epoch1_first_10 = sorted(epoch1_indices[:10])
+
+ # Second epoch (auto-incremented) - get first 10 indices
+ epoch2_indices = list(s)
+ epoch2_first_10 = sorted(epoch2_indices[:10])
+
+ # Both epochs should have all 100 indices
+ assert sorted(epoch1_indices) == list(range(100))
+ assert sorted(epoch2_indices) == list(range(100))
+
+ # First 10 indices should correspond to different chunks across epochs
+ assert epoch1_first_10 != epoch2_first_10, (
+ "First 10 indices should correspond to different chunks across epochs"
+ )
+
+
+def test_restartable_distributed_sampler_iteration():
+ """Test RestartableDistributedSampler iteration and epoch transitions."""
+ dataset = list(range(100))
+ sampler = samplers.RestartableDistributedSampler(
+ dataset, rank=0, num_replicas=2, seed=42
+ )
+ sampler.set_epoch(0)
+
+ # Test distributed splitting - rank 0 should get 50 items
+ assert len(sampler) == 50
+
+ # Iterate through one epoch
+ indices_epoch0 = list(sampler)
+ assert len(set(indices_epoch0)) == 50 # All unique
+
+ # After exhaustion, should auto-transition to next epoch
+ assert sampler.epoch == 1
diff --git a/examples/weather/healda/tests/test_training_stats.py b/examples/weather/healda/tests/test_training_stats.py
new file mode 100644
index 0000000000..1f33c62a7b
--- /dev/null
+++ b/examples/weather/healda/tests/test_training_stats.py
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import pytest
+import torch
+import torch.distributed as dist
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor import DTensor, Replicate
+from training import training_stats
+
+requires_distributed = pytest.mark.skipif(
+ not torch.cuda.is_available() or os.environ.get("RANK") is None,
+ reason="Requires CUDA and distributed environment (RANK env var)",
+)
+
+
+@requires_distributed
+def test_dtensor_report_and_sync():
+ # Initialize distributed environment
+ if not dist.is_initialized():
+ dist.init_process_group(backend="nccl")
+
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ device = torch.device(f"cuda:{rank}")
+
+ print(f"[Rank {rank}] Initialized with world_size={world_size}")
+
+ # Initialize training_stats for multiprocessing
+ training_stats.init_multiprocessing(rank=rank, sync_device=device)
+
+ # Create a device mesh for DTensor
+ mesh = DeviceMesh("cuda", torch.arange(world_size))
+
+ # Create a DTensor with the correct dtype (float64 for _counter_dtype)
+ local_tensor = torch.randn(3, device=device, dtype=torch.float64)
+ dtensor = DTensor.from_local(local_tensor, mesh, [Replicate()])
+ # The bug is that DTensors can end up in _counters through various operations
+ # Let's directly inject one to simulate this scenario
+ # This represents what happens when DTensor operations preserve the DTensor type
+ metric_name = "test_metric"
+ if metric_name not in training_stats._counters:
+ training_stats._counters[metric_name] = dict()
+
+ # Inject the DTensor into _counters
+ # In the real bug, this happens through operations in report() that preserve DTensor type
+ training_stats._counters[metric_name][device] = dtensor
+
+ # Now trigger the sync which should cause the error
+ # This calls _sync() which tries to do: delta.add_(counter.to(device))
+ # where delta is a regular torch.Tensor but counter is a DTensor
+ training_stats.default_collector.update()
diff --git a/examples/weather/healda/tests/test_ufs_combined_schema.py b/examples/weather/healda/tests/test_ufs_combined_schema.py
new file mode 100644
index 0000000000..74775ca438
--- /dev/null
+++ b/examples/weather/healda/tests/test_ufs_combined_schema.py
@@ -0,0 +1,239 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S101
+"""
+Integration test for UFS Unified Loader with combined schema.
+"""
+
+import os
+import tempfile
+from datetime import datetime
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pytest
+from datasets.etl.combined_schema import (
+ get_channel_table_schema,
+ get_combined_observation_schema,
+)
+from datasets.etl.etl_unified import get_channel_table
+from datasets.obs_loader import UFSUnifiedLoader
+from datasets.sensors import SENSOR_CONFIGS
+
+
+@pytest.fixture
+def temp_data_dir():
+ """Create temporary directory with sample data."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ # Create sensor directory
+ sensor_dir = os.path.join(temp_dir, "atms")
+ os.makedirs(sensor_dir, exist_ok=True)
+
+ # Create date directory
+ date_dir = os.path.join(sensor_dir, "20200101")
+ os.makedirs(date_dir, exist_ok=True)
+
+ # Create sample data with all required schema fields
+ n_obs = 50
+ data = {
+ # Common fields
+ "Latitude": np.random.uniform(-90, 90, n_obs).astype(np.float32),
+ "Longitude": np.random.uniform(-180, 180, n_obs).astype(np.float32),
+ "Absolute_Obs_Time": pd.date_range(
+ "2020-01-01", periods=n_obs, freq="1h"
+ ).astype("datetime64[ns]"),
+ "DA_window": pd.date_range("2020-01-01", periods=n_obs, freq="3h").astype(
+ "datetime64[ns]"
+ ),
+ "Platform_ID": np.random.randint(0, 32, n_obs).astype(np.uint16),
+ "Observation": np.random.uniform(0, 400, n_obs).astype(np.float32),
+ "Global_Channel_ID": np.random.randint(0, 100, n_obs).astype(np.uint16),
+ # Satellite-specific fields
+ "Sat_Zenith_Angle": np.random.uniform(0, 90, n_obs).astype(np.float32),
+ "Sol_Zenith_Angle": np.random.uniform(0, 90, n_obs).astype(np.float32),
+ "Scan_Angle": np.random.uniform(-45, 45, n_obs).astype(np.float32),
+ # Conventional fields (nullable)
+ "Pressure": np.full(n_obs, np.nan, dtype=np.float32),
+ "Height": np.full(n_obs, np.nan, dtype=np.float32),
+ "Observation_Type": np.full(n_obs, np.nan, dtype=np.uint16),
+ # Analysis fields (nullable)
+ "QC_Flag": np.random.randint(0, 2, n_obs).astype(np.int32),
+ "Analysis_Use_Flag": np.full(n_obs, np.nan, dtype=np.int8),
+ "Obs_Minus_Forecast_adjusted": np.random.uniform(-10, 10, n_obs).astype(
+ np.float32
+ ),
+ "Obs_Minus_Forecast_unadjusted": np.random.uniform(-10, 10, n_obs).astype(
+ np.float32
+ ),
+ }
+
+ # Write parquet file
+ schema = get_combined_observation_schema()
+ table = pa.table(data, schema=schema)
+ parquet_path = os.path.join(date_dir, "0.parquet")
+ pa.parquet.write_table(table, parquet_path)
+
+ # Create channel table for normalization
+ channel_table = get_channel_table()
+ channel_table_path = os.path.join(temp_dir, "channel_table.parquet")
+ pa.parquet.write_table(channel_table, channel_table_path)
+
+ # No need for availability_df.pkl - using try/catch approach
+
+ yield temp_dir
+
+
+@pytest.mark.parametrize("normalization", ["zscore", "minmax"])
+@pytest.mark.asyncio
+async def test_ufs_unified_loader(temp_data_dir, normalization):
+ """Test UFSUnifiedLoader basic functionality."""
+ # Initialize loader
+ loader = UFSUnifiedLoader(
+ data_path=temp_data_dir,
+ sensors=["atms"],
+ filesystem_type="local",
+ normalization=normalization,
+ )
+
+ # Test basic properties
+ assert loader.sensors == ["atms"]
+
+ # Test data loading
+ times = pd.DatetimeIndex([datetime(2020, 1, 1, 12)])
+ result = await loader.sel_time(times)
+
+ for result in result["obs_v2"]:
+ # Validate schema matches expected output schema
+ expected_schema = loader.output_schema
+ assert result.schema.equals(expected_schema)
+
+ # Check normalization (minmax should give [0,1] range) if data exists
+ if normalization == "minmax" and result.num_rows > 0:
+ df = result.to_pandas()
+ assert (df["Observation"] >= 0).all()
+ assert (df["Observation"] <= 1).all()
+
+
+@pytest.mark.parametrize("normalization", ["zscore", "minmax"])
+@pytest.mark.asyncio
+async def test_ufs_unified_loader_empty_dataset(normalization):
+ """Test UFSUnifiedLoader with empty dataset (no data files)."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ # Create empty directory structure
+ sensor_dir = os.path.join(temp_dir, "atms")
+ os.makedirs(sensor_dir, exist_ok=True)
+
+ # Initialize loader with empty directory
+ loader = UFSUnifiedLoader(
+ data_path=temp_dir,
+ sensors=["atms"],
+ filesystem_type="local",
+ normalization=normalization,
+ )
+
+ # Test basic properties
+ assert loader.sensors == ["atms"]
+
+ # Test data loading with empty dataset
+ times = pd.DatetimeIndex([datetime(2020, 1, 1, 12)])
+ result = await loader.sel_time(times)
+
+ result = result["obs_v2"][0]
+
+ # Check result structure - should return empty table with proper schema
+ assert isinstance(result, pa.Table)
+ assert result.num_rows == 0
+
+ # Validate schema matches expected output schema
+ expected_schema = loader.output_schema
+ assert result.schema.equals(expected_schema)
+
+
+def test_get_channel_table_structure():
+ """Test that get_channel_table returns correct table structure and schema."""
+ table = get_channel_table()
+
+ # Check that it's a PyArrow table
+ assert isinstance(table, pa.Table)
+
+ # Check schema matches expected channel table schema
+ expected_schema = get_channel_table_schema()
+ assert table.schema.equals(expected_schema)
+
+
+def test_get_channel_table_sensor_mapping():
+ """Test that sensor IDs and channel IDs are correctly mapped."""
+ table = get_channel_table()
+
+ # Convert to pandas for easier analysis
+ df = table.to_pandas()
+
+ # Calculate expected total channels
+ expected_total_channels = sum(cfg.channels for cfg in SENSOR_CONFIGS.values())
+ assert len(df) == expected_total_channels
+
+ # Check that Global_Channel_ID is sequential starting from 0
+ assert df["Global_Channel_ID"].min() == 0
+ assert df["Global_Channel_ID"].max() == expected_total_channels - 1
+ assert df["Global_Channel_ID"].is_monotonic_increasing
+
+ # Check sensor_id mapping
+ sensor_names = list(SENSOR_CONFIGS.keys())
+ for i, sensor_name in enumerate(sensor_names):
+ sensor_mask = df["sensor_id"] == i
+ expected_channels = SENSOR_CONFIGS[sensor_name].channels
+ assert sensor_mask.sum() == expected_channels
+
+
+def test_get_channel_table_conventional_handling():
+ """Test that conventional sensors are handled correctly."""
+ table = get_channel_table()
+ df = table.to_pandas()
+
+ # Find conventional sensor index
+ conv_sensor_id = list(SENSOR_CONFIGS.keys()).index("conv")
+ conv_mask = df["sensor_id"] == conv_sensor_id
+
+ # Check is_conv flag
+ assert df[conv_mask]["is_conv"].all()
+ assert not df[~conv_mask]["is_conv"].any()
+
+ # Check conventional sensor naming
+ conv_names = df[conv_mask]["name"].tolist()
+ expected_conv_names = ["gps_angle", "gps_t", "gps_q", "ps", "q", "t", "u", "v"]
+ assert conv_names == expected_conv_names
+
+
+def test_get_channel_table_consistency():
+ """Test that channel table is internally consistent."""
+ table = get_channel_table()
+ df = table.to_pandas()
+
+ # Check that all Global_Channel_IDs are unique
+ assert df["Global_Channel_ID"].nunique() == len(df)
+
+ # Check that sensor_id values are valid
+ max_sensor_id = len(SENSOR_CONFIGS) - 1
+ assert df["sensor_id"].min() >= 0
+ assert df["sensor_id"].max() <= max_sensor_id
+
+ # Check that min_valid <= max_valid for all channels
+ assert (df["min_valid"] <= df["max_valid"]).all()
+
+ # Check that all names are non-empty strings
+ assert df["name"].str.len().min() > 0
+ assert df["name"].notna().all()
diff --git a/examples/weather/healda/tests/test_visualizations.py b/examples/weather/healda/tests/test_visualizations.py
new file mode 100644
index 0000000000..f21a5cf3a5
--- /dev/null
+++ b/examples/weather/healda/tests/test_visualizations.py
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: S101
+import numpy as np
+import pytest
+from utils.visualizations import visualize
+
+
+def test_visualize():
+ """Test visualization with reasonable default settings."""
+ # Create a simple 1D input array
+ x = np.random.rand(12) # HEALPix level 1 has 12 pixels
+
+ # Test with reasonable defaults
+ im = visualize(
+ x,
+ region="Robinson",
+ title="Test Visualization",
+ cmap="viridis",
+ add_colorbar=True,
+ )
+ assert im is not None
+
+ # Test that invalid input raises error
+ with pytest.raises(ValueError):
+ visualize(np.random.rand(4, 3)) # 2D input should raise ValueError
diff --git a/examples/weather/healda/train.py b/examples/weather/healda/train.py
new file mode 100644
index 0000000000..72339c8a5a
--- /dev/null
+++ b/examples/weather/healda/train.py
@@ -0,0 +1,790 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Train DA model for observation-to-state regression.
+
+This is the training script for the HealDA data assimilation model.
+It trains a ViT model to regress atmospheric state from observations.
+
+Usage:
+ python train_da.py --name era5-v2-dense-noInfill-10M-fusion512-lrObs1e-4
+"""
+
+import dataclasses
+import functools
+import json
+import logging
+import os
+import warnings
+
+import config.environment as config
+import matplotlib.pyplot as plt
+import models
+import torch
+import torch.distributed
+import torch.utils
+import torch.utils.data
+import training.loop
+from datasets.base import BatchInfo, TimeUnit, VariableConfig
+from datasets.dataset import (
+ VARIABLE_CONFIGS,
+ get_batch_info,
+ get_dataset,
+ get_sensors_for_config,
+)
+from datasets.prefetch_map import prefetch_map
+from datasets.round_robin import RoundRobinLoader
+from datasets.samplers import (
+ ChunkedDistributedSampler,
+)
+from datasets.sensors import (
+ PLATFORM_NAME_TO_ID,
+ SENSOR_CONFIGS,
+ SENSOR_NAME_TO_ID,
+)
+from datasets.transform import TransformV2, collate
+from training import loop
+from utils import distributed as dist
+from utils.dataclass_parser import parse_args, parse_dict
+from utils.signals import finish_before_quitting
+from utils.visualizations import visualize
+
+from physicsnemo.experimental.models.healda import (
+ ModelSensorConfig,
+ SensorEmbedderConfig,
+)
+from config.model_config import ModelConfigV1, ObsConfig
+from utils import profiling
+
+logger = logging.getLogger(__name__)
+
+
+def build_sensor_config(sensor_names: list[str]) -> dict[str, ModelSensorConfig]:
+ """
+ Args:
+ sensor_names: List of sensor names to include in the model
+
+ Returns:
+ dict mapping sensor_name to ModelSensorConfig
+ """
+ return {
+ name: ModelSensorConfig(
+ sensor_id=SENSOR_NAME_TO_ID[name],
+ nchannel=SENSOR_CONFIGS[name].channels,
+ platform_ids=tuple(
+ PLATFORM_NAME_TO_ID[p] for p in SENSOR_CONFIGS[name].platforms
+ ),
+ )
+ for name in sensor_names
+ }
+
+
+@dataclasses.dataclass
+class DistributedConfig:
+ rank: int
+ world_size: int
+
+
+@dataclasses.dataclass
+class TrainingLoop(loop.TrainingLoopBase):
+ """
+ Training loop for observation-to-state DA model.
+
+ valid_samples_per_season: the number of samples to use when making season
+ average plots
+ """
+
+ valid_min_samples: int = 128
+
+ # loss options
+ loss_type: str = "mse"
+ huber_delta: float = 0.1
+
+ # data loader options
+ dataloader_num_workers: int = 3
+ dataloader_prefetch_factor: int = 8
+ prefetch_to_gpu: bool = True
+
+ # data options
+ embed_v2: bool = True # DEPRECATED: ignored, derived from obs_config.use_obs
+
+ label_dropout: float = 0.0
+ legacy_label_bias: bool = (
+ False # For loading old checkpoints with trained label bias
+ )
+ era5_chunk_size: int = 48
+ start_year: int = -1 # Filter training data to >= this year
+ obs_config: ObsConfig = ObsConfig()
+
+ # model configuration
+ model_channels: int = 256
+ opt: str = "adamw"
+ adam_eps: float = 1e-8
+ adam_beta2: float = 0.95
+ group_norm_eps: float = 1e-6
+ lr_obs: float = 1e-4
+ weight_decay: float = 0.1
+ weight_decay_biases: bool = True
+ drop_path: float = 0.0
+ p_dropout: float = 0.0
+ architecture: str = "dit-l_reg_hpx6_per_sensor"
+ as_vit: bool = False
+ gradient_checkpointing: bool = False
+ pos_emb_gains: bool = False
+ compile_dit: bool = False # Enable torch.compile for DiT _forward_DiT
+ dit_qk_rms_norm: bool = False
+ emb_channels: int | None = None
+ noise_channels: int | None = None
+
+ # When True, apply a custom gradient clipping schedule that
+ # linearly decays the clip value from 1.0 โ 0.015 over the
+ # first 50k images, then keeps it at 0.015 afterwards.
+ use_gradient_clip_schedule: bool = False
+
+ sensor_embedder_config: SensorEmbedderConfig | None = SensorEmbedderConfig()
+
+ finetune_from: str = ""
+ finetune_optimizer: bool = False
+ freeze_obs_embed: bool = False
+ freeze_transformer_blocks: bool = False
+ freeze_decoder: bool = False
+ freeze_pos_embedding: bool = False
+
+ # change defaults for parameter norm logging
+ log_parameter_norm: bool = False
+ log_parameter_grad_norm: bool = False
+
+ def __post_init__(self):
+ super().__post_init__()
+ self._train_sampler = None
+ self._test_sampler = None
+
+ @property
+ def variable_config(self) -> VariableConfig:
+ return VARIABLE_CONFIGS["era5"]
+
+ @functools.cached_property
+ def batch_info(self) -> BatchInfo:
+ return get_batch_info(
+ config=self.variable_config,
+ time_unit=TimeUnit.HOUR,
+ )
+
+ def resume_from_state(
+ self, resume_state_dump, optimizer=True, require_all=True, wandb=False
+ ):
+ super().resume_from_state(resume_state_dump, optimizer, require_all, wandb)
+ self._load_wandb_id()
+ dist.print0(f"Loaded checkpoint from {resume_state_dump}.")
+
+ def _save_wandb_id(self):
+ if self.wandb_id is not None:
+ with open(os.path.join(self.run_dir, "wandb_id"), "w") as f:
+ f.write(self.wandb_id)
+
+ def _load_wandb_id(self):
+ try:
+ with open(os.path.join(self.run_dir, "wandb_id")) as f:
+ self.wandb_id = f.read()
+ except FileNotFoundError:
+ pass
+
+ def save_training_state(self, cur_nimg):
+ if dist.get_rank() != 0:
+ return
+ super().save_training_state(cur_nimg)
+ self._save_wandb_id()
+
+ def save_network_snapshot(self, cur_nimg):
+ if dist.get_rank() != 0:
+ return
+ super().save_network_snapshot(cur_nimg)
+
+ @property
+ def out_channels(self):
+ return len(self.batch_info.channels)
+
+ def setup(self):
+ super().setup()
+ self.net.gradient_checkpointing = self.gradient_checkpointing
+
+ @profiling.nvtx
+ @finish_before_quitting
+ def step_optimizer(self, cur_nimg):
+ """Optionally apply a scheduled gradient clipping value, then
+ delegate to the base implementation for LR scheduling and optimizer step.
+ """
+ if self.use_gradient_clip_schedule:
+ start_clip = 1.0
+ end_clip = 0.015
+ schedule_end = 50_000
+
+ n = max(0, min(cur_nimg, schedule_end))
+ if schedule_end > 0:
+ frac = n / schedule_end
+ else:
+ frac = 1.0
+
+ self.gradient_clip_max_norm = start_clip + (end_clip - start_clip) * frac
+
+ super().step_optimizer(cur_nimg)
+
+ @functools.cached_property
+ def _data_transform(self):
+ """Get the appropriate data transform."""
+ return TransformV2(
+ variable_config=self.variable_config,
+ )
+
+ def get_dataset(self, train: bool):
+ """Returns the dataset for training or validation."""
+ return get_dataset(
+ dataset="era5",
+ split="train" if train else "test",
+ transform=None,
+ batch_transform=self._data_transform.transform,
+ rank=self.distributed_config.rank,
+ world_size=self.distributed_config.world_size,
+ infinite=True,
+ shuffle=True,
+ chunk_size=self.era5_chunk_size,
+ obs_config=self.obs_config,
+ start_year=self.start_year,
+ map_style=True,
+ )
+
+ @property
+ def distributed_config(self) -> DistributedConfig:
+ return DistributedConfig(dist.get_rank(), dist.get_world_size())
+
+ def _create_dataloader(
+ self,
+ dataset,
+ sampler,
+ batch_size,
+ num_workers=None,
+ prefetch_factor=None,
+ pin_memory=True,
+ ):
+ """Helper to create a DataLoader with common settings."""
+ if num_workers is None:
+ num_workers = self.dataloader_num_workers
+
+ return torch.utils.data.DataLoader(
+ dataset,
+ sampler=sampler,
+ prefetch_factor=prefetch_factor if num_workers > 0 else None,
+ multiprocessing_context="spawn" if num_workers > 0 else None,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ collate_fn=collate,
+ pin_memory=pin_memory,
+ persistent_workers=True if num_workers > 0 else False,
+ in_order=True,
+ )
+
+ def _get_loader(self, dataset, batch_size, train: bool = True):
+ workers = self.dataloader_num_workers
+ prefetch_factor = self.dataloader_prefetch_factor
+ if not train and workers != 0:
+ workers = 1
+ prefetch_factor = 4
+
+ if isinstance(dataset, torch.utils.data.IterableDataset):
+ # Iterable datasets don't use samplers
+ loader = self._create_dataloader(
+ dataset, sampler=None, batch_size=batch_size
+ )
+
+ else:
+ # Round-robin loader: one dataloader per worker, each with its own chunk assignment.
+ # This is optimal for chunked zarr data where sequential access within chunks is fast.
+ #
+ # Note: Currently restarts with same sample order every time (default seed).
+ # If we drop chunking in the future, switch to RestartableDistributedSampler
+ # with proper checkpointing for reproducible training restarts.
+ num_loaders = max(workers, 1)
+ dataloaders = []
+
+ for worker_id in range(num_loaders):
+ worker_sampler = ChunkedDistributedSampler(
+ dataset,
+ chunk_size=self.era5_chunk_size,
+ num_replicas=self.distributed_config.world_size * num_loaders,
+ rank=self.distributed_config.rank * num_loaders + worker_id,
+ shuffle=True,
+ shuffle_within_chunk=True,
+ drop_last=True,
+ )
+
+ worker_loader = self._create_dataloader(
+ dataset,
+ sampler=worker_sampler,
+ batch_size=batch_size,
+ num_workers=1, # single worker per loader
+ prefetch_factor=prefetch_factor,
+ )
+ dataloaders.append(worker_loader)
+ loader = RoundRobinLoader(dataloaders)
+
+ # transferring the obs data from cpu -> gpu can be slow, so
+ # running it in a separate thread using prefetch_map improves utilization
+ if self.prefetch_to_gpu:
+ loader = prefetch_map(loader, self._device_transform, queue_size=2)
+ return loader
+
+ def _device_transform(self, batch):
+ """Transformations to occur on device in a separate thread. including device movement"""
+ return self._data_transform.device_transform(batch, device=self.device)
+
+ def get_data_loaders(self, batch_gpu):
+ """Create train and test DataLoaders"""
+ dataset = self.get_dataset(train=True)
+ train_loader = self._get_loader(dataset, batch_size=batch_gpu, train=True)
+ test_dataset = self.get_dataset(train=False)
+ test_loader = self._get_loader(test_dataset, batch_size=batch_gpu, train=False)
+
+ self._test_dataset = test_dataset
+ return dataset, train_loader, test_loader
+
+ def _step(
+ self,
+ *,
+ train=True,
+ plot_image=False,
+ target: torch.Tensor,
+ condition,
+ second_of_day,
+ day_of_year,
+ unified_obs=None,
+ labels=None,
+ return_both=False,
+ timestamp,
+ **batch,
+ ):
+ b, c, t, x = target.shape
+ noise_labels = torch.zeros([b], device=target.device)
+
+ prediction = self.ddp(
+ condition,
+ noise_labels=noise_labels,
+ class_labels=labels,
+ second_of_day=second_of_day,
+ day_of_year=day_of_year,
+ unified_obs=unified_obs,
+ timestamp=timestamp,
+ )
+ pred = prediction.out
+
+ train_tag = "train" if train else "test"
+
+ # log per channel norm of training target and prediction
+ for c in range(len(self.batch_info.channels)):
+ channel = self.batch_info.channels[c]
+ self.log_metric(f"norm/{channel}/target", target[:, c].norm(), print=False)
+ self.log_metric(
+ f"norm/{channel}/pred_{train_tag}", pred[:, c].norm(), print=False
+ )
+ self.log_metric(
+ f"max/{channel}/target", target[:, c].abs().max(), print=False
+ )
+ self.log_metric(
+ f"max/{channel}/pred_{train_tag}", pred[:, c].abs().max(), print=False
+ )
+
+ mse = (target - pred) ** 2
+ huber_loss = torch.nn.functional.huber_loss(
+ target, pred, reduction="none", delta=self.huber_delta
+ )
+
+ self.log_metric(f"Loss/{train_tag}_mse", mse)
+ self.log_metric(f"Loss/{train_tag}_huber", huber_loss)
+
+ scales = torch.as_tensor(self.batch_info.scales)[:, None, None].to(self.device)
+ centers = torch.as_tensor(self.batch_info.center)[:, None, None].to(self.device)
+
+ metrics_pred = pred * scales + centers
+ metrics_target = target * scales + centers
+ # Compute MSE for logging purposes in physical units
+ full_mse_physical = (metrics_target - metrics_pred) ** 2
+
+ for c in range(len(self.batch_info.channels)):
+ channel = self.batch_info.channels[c]
+ this_rmse = torch.sqrt(full_mse_physical[:, c, -1].mean())
+ self.log_metric(f"rmse/{channel}/{train_tag}", this_rmse)
+ self.log_metric(
+ f"huber/{channel}/{train_tag}",
+ huber_loss[:, c, -1].mean(),
+ print=False,
+ )
+
+ if plot_image and dist.get_rank() == 0:
+ for name, field in zip(
+ ["prediction", "target"], [metrics_pred, metrics_target]
+ ):
+ fig = plt.figure()
+ display_field = field[0, c, -1].cpu()
+ visualize(
+ display_field,
+ hpxpad=True,
+ title=channel,
+ )
+ self.writer.add_figure(
+ f"sample/{channel}/{name}", fig, global_step=self.cur_nimg
+ )
+ loss = mse if self.loss_type == "mse" else huber_loss
+
+ if train:
+ self.log_metric("loss", loss, frequency="step")
+
+ if return_both:
+ return mse, huber_loss
+ else:
+ return loss
+
+ def train_step(self, **batch):
+ return self._step(train=True, **batch)
+
+ def test_step(self, **batch):
+ return self._step(train=False, **batch)
+
+ @classmethod
+ def loads(cls, s):
+ fields = json.loads(s)
+ # remove augment_kwargs if present
+ # this is in some older checkpoint
+ fields.pop("split_weight_decay", None)
+ fields.pop("patch_embed_lr", None)
+ return parse_dict(cls, fields)
+
+ def _load_net_state(self, checkpoint, require_all):
+ self.net.load_from_checkpoint(checkpoint)
+
+ @property
+ def model_config(self) -> ModelConfigV1:
+ out_channels = self.out_channels
+ label_dim = 0 # labels not used
+
+ condition_channels = 2 # orog and lfrac static variables
+
+ sensor_embedder_config = None
+ sensors_dict = None
+ if self.obs_config.use_obs and ("per_sensor" in self.architecture):
+ sensor_names = get_sensors_for_config(self.obs_config)
+ sensors_dict = build_sensor_config(sensor_names)
+ sensor_embedder_config = self.sensor_embedder_config
+
+ return models.ModelConfigV1(
+ architecture=self.architecture,
+ condition_channels=condition_channels,
+ out_channels=out_channels,
+ label_dim=label_dim,
+ label_dropout=self.label_dropout,
+ legacy_label_bias=self.legacy_label_bias,
+ obs_config=self.obs_config,
+ p_dropout=self.p_dropout,
+ drop_path=self.drop_path,
+ group_norm_eps=self.group_norm_eps,
+ pos_emb_gains=self.pos_emb_gains,
+ sensor_embedder_config=sensor_embedder_config,
+ sensors=sensors_dict,
+ qk_rms_norm=self.dit_qk_rms_norm,
+ allow_nans_condition=False,
+ compile_dit=self.compile_dit,
+ as_vit=self.as_vit,
+ emb_channels=self.emb_channels,
+ noise_channels=self.noise_channels,
+ )
+
+ def _setup_networks(self):
+ torch.manual_seed(self.seed)
+ net = self.get_network()
+ net.train()
+ net.requires_grad_(True)
+ net.to(self.device)
+
+ if self.freeze_obs_embed:
+ net.embed_obs.requires_grad_(False)
+
+ if self.freeze_transformer_blocks:
+ for block in net.transformer_blocks:
+ block.requires_grad_(False)
+
+ if self.freeze_pos_embedding:
+ net.pos_embed.pos_embed.requires_grad_(False)
+
+ if self.freeze_decoder:
+ net.patch_decode.requires_grad_(False)
+
+ self.net = net
+ if dist.get_world_size() > 1:
+ self.ddp = torch.nn.parallel.DistributedDataParallel(
+ self.net,
+ device_ids=[self.device],
+ broadcast_buffers=False,
+ )
+ else:
+ self.ddp = self.net
+
+ def get_optimizer(self, named_parameters):
+ """Builds optimizer, applying differential learning rate to observation embeddings and transformer blocks"""
+ named_params = list(named_parameters)
+ # Separate the obs embedding and transformer parameters to apply
+ # lower learning rate to the obs embedding
+ obs_param_prefix = "embed_v2_patch"
+
+ def _get_param_groups(params, lr):
+ if self.weight_decay_biases:
+ return ({"params": params, "lr": lr, "base_lr": lr},)
+
+ weights, biases = [], []
+
+ for param in params:
+ if param.ndim > 1:
+ weights.append(param)
+ else:
+ biases.append(param)
+
+ return [
+ {
+ "params": weights,
+ "lr": lr,
+ "base_lr": lr,
+ "weight_decay": self.weight_decay,
+ },
+ {"params": biases, "lr": lr, "base_lr": lr, "weight_decay": 0.0},
+ ]
+
+ xfmr_params, obs_params = [], []
+ for name, param in named_params:
+ if name.startswith(obs_param_prefix):
+ obs_params.append(param)
+ else:
+ xfmr_params.append(param)
+
+ param_groups = []
+ if xfmr_params:
+ param_groups.extend(_get_param_groups(xfmr_params, self.lr))
+ if obs_params and not self.freeze_obs_embed:
+ param_groups.extend(_get_param_groups(obs_params, self.lr_obs))
+
+ if self.opt == "adamw":
+ return torch.optim.AdamW(
+ param_groups,
+ betas=(0.9, self.adam_beta2),
+ eps=self.adam_eps,
+ weight_decay=self.weight_decay,
+ fused=True,
+ )
+ else:
+ return torch.optim.Adam(
+ param_groups,
+ betas=(0.9, self.adam_beta2),
+ eps=self.adam_eps,
+ fused=True,
+ )
+
+ def get_loss_fn(self):
+ """Return loss function."""
+ return None
+
+ @staticmethod
+ def print_network_info(net, device):
+ num_params = sum(p.numel() for p in net.parameters())
+ dist.print0(f"Number of parameters: {num_params}. Network: {net}")
+
+ def validate(self, net=None):
+ if net is None:
+ net = self.net
+ net.eval()
+
+ for batch_num, batch in enumerate(self.valid_loader):
+ if batch_num * self.batch_size >= self.valid_min_samples:
+ break
+ batch = self._stage_dict_batch(batch)
+ with torch.no_grad():
+ with torch.autocast(
+ device_type="cuda", dtype=torch.bfloat16, enabled=self.bf16
+ ):
+ self.test_step(plot_image=batch_num == 0, return_both=True, **batch)
+
+
+@dataclasses.dataclass
+class CLI:
+ """Command-line interface config"""
+
+ name: str = ""
+ output_dir: str = config.CHECKPOINT_ROOT
+ finetune_from: str = ""
+ resume_dir: str = ""
+ loop: TrainingLoop = dataclasses.field(
+ default_factory=lambda: TrainingLoop(
+ architecture="dit-l_reg_hpx6_per_sensor",
+ legacy_label_bias=True,
+ batch_size=8,
+ batch_gpu=1,
+ lr=0.0005,
+ lr_obs=0.0001,
+ lr_rampup_img=50000,
+ flat_imgs=0,
+ decay_imgs=10000000,
+ lr_min=0.0,
+ gradient_clip_max_norm=1.0,
+ steps_per_tick=2500,
+ snapshot_ticks=100,
+ state_dump_ticks=2,
+ print_steps=1,
+ loss_type="huber",
+ loss_reduction="v1",
+ huber_delta=0.1,
+ dataloader_num_workers=5,
+ dataloader_prefetch_factor=12,
+ total_ticks=250,
+ era5_chunk_size=24,
+ weight_decay=0.05,
+ drop_path=0.1,
+ p_dropout=0.05,
+ compile_dit=True,
+ obs_config=ObsConfig(
+ use_obs=True,
+ innovation_type="none",
+ context_start=-21,
+ context_end=3,
+ use_conv=True,
+ dropout=0.0,
+ conv_uv_in_situ_only=False,
+ conv_gps_level1_only=False,
+ ),
+ embed_v2=True,
+ dit_qk_rms_norm=True,
+ sensor_embedder_config=SensorEmbedderConfig(
+ embed_dim=32,
+ fusion_dim=512,
+ use_channel_platform_embedding_table=False,
+ ),
+ )
+ )
+
+
+warnings.filterwarnings(action="ignore", message="Cannot do a zero-copy NCHW to NHWC.")
+
+
+LOOPS = {}
+
+# HealDA v1 configuration: ERA5 observation-to-state training
+LOOPS["era5-v2-dense-noInfill-10M-fusion512-lrObs1e-4"] = TrainingLoop(
+ architecture="dit-l_reg_hpx6_per_sensor",
+ legacy_label_bias=True,
+ batch_size=8,
+ batch_gpu=1,
+ lr=0.0005,
+ lr_obs=0.0001,
+ lr_rampup_img=50000,
+ flat_imgs=0,
+ decay_imgs=10000000,
+ lr_min=0.0,
+ gradient_clip_max_norm=1.0,
+ steps_per_tick=2500,
+ snapshot_ticks=100,
+ state_dump_ticks=2,
+ print_steps=1,
+ loss_type="huber",
+ loss_reduction="v1",
+ huber_delta=0.1,
+ dataloader_num_workers=5,
+ dataloader_prefetch_factor=12,
+ total_ticks=250,
+ era5_chunk_size=24,
+ weight_decay=0.05,
+ drop_path=0.1,
+ p_dropout=0.05,
+ compile_dit=True,
+ obs_config=ObsConfig(
+ use_obs=True,
+ innovation_type="none",
+ context_start=-21,
+ context_end=3,
+ use_conv=True,
+ dropout=0.0,
+ conv_uv_in_situ_only=False,
+ conv_gps_level1_only=False,
+ ),
+ embed_v2=True,
+ dit_qk_rms_norm=True,
+ sensor_embedder_config=SensorEmbedderConfig(
+ embed_dim=32,
+ fusion_dim=512,
+ use_channel_platform_embedding_table=False,
+ ),
+)
+
+
+def main():
+ cli = parse_args(CLI, convert_underscore_to_hyphen=False)
+ dist.init()
+
+ if dist.get_rank() == 0:
+ logging.basicConfig(level=logging.INFO)
+ training.loop.logger.setLevel(level=logging.DEBUG)
+
+ try:
+ dist.print0(f"Using {cli.name=} preset.")
+ loop = LOOPS[cli.name]
+ except KeyError:
+ dist.print0("Using --loop command line arguments")
+ loop = cli.loop
+
+ loop.run_dir = os.path.join(cli.output_dir, cli.name)
+ loop.setup()
+ dist.print0("Training with:", loop)
+
+ if dist.get_rank() == 0:
+ config.print_config()
+
+ new_training = True
+ # attempt resuming from output-dir, and then try the resume_dir CLI
+ # this behavoir makes it easy to submit multiple segments of the run using
+ # the same CLI arguments
+ resume_dirs_in_priority = [loop.run_dir, cli.resume_dir]
+ for rundir in resume_dirs_in_priority:
+ try:
+ loop.resume_from_rundir(rundir, require_all=False)
+ new_training = False
+ break
+ except FileNotFoundError:
+ pass
+
+ if new_training and loop.finetune_from:
+ loop.resume_from_state(
+ loop.finetune_from, optimizer=loop.finetune_optimizer, require_all=False
+ )
+
+ if new_training:
+ loop.wandb_id = None
+ dist.print0("Starting new training")
+
+ try:
+ loop.cur_nimg = loop._cur_nimg_start
+ except AttributeError:
+ pass
+
+ loop.setup_wandb(name=cli.name)
+ loop.train()
+ torch.distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/weather/healda/training/__init__.py b/examples/weather/healda/training/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/weather/healda/training/loop.py b/examples/weather/healda/training/loop.py
new file mode 100644
index 0000000000..b83c783458
--- /dev/null
+++ b/examples/weather/healda/training/loop.py
@@ -0,0 +1,858 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import contextlib
+import dataclasses
+import gc
+import glob
+import itertools
+import json
+import logging
+import os
+import re
+import shutil
+import signal
+import time
+import warnings
+from functools import partial
+from typing import Iterable, Union
+
+import models
+import numpy as np
+import psutil
+import torch
+import torch.utils.tensorboard
+import torchmetrics
+from config.training import loop
+from datasets.base import BatchInfo, SpatioTemporalDataset
+from utils import checkpointing
+from utils import distributed as dist
+from utils.signals import QuitEarly, finish_before_quitting, handler
+
+from config.model_config import ModelConfigV1
+from utils import profiling
+
+from . import training_stats
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+DATASET_METADATA_FILENAME = "dataset-metadata.pth"
+TRAINER_METADATA_FILENAME = "loop.json"
+
+
+logger = logging.getLogger(__name__)
+
+
+# ----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ if not isinstance(module, torch.nn.Module):
+ raise TypeError("module must be a torch.nn.Module")
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
+
+
+# ----------------------------------------------------------------------------
+
+
+def _to_batch(x, device, non_blocking=True):
+ if isinstance(x, dict):
+ return {
+ k: _to_batch(v, device, non_blocking=non_blocking) for k, v in x.items()
+ }
+ elif isinstance(x, list):
+ return [_to_batch(i, device, non_blocking=non_blocking) for i in x]
+ elif torch.is_tensor(x):
+ if torch.is_floating_point(x):
+ x = x.float()
+ return x.to(device, non_blocking=non_blocking)
+ elif hasattr(x, "to") and callable(getattr(x, "to")):
+ # custom object with a 'to' method
+ return x.to(device, non_blocking=non_blocking)
+ elif dataclasses.is_dataclass(x):
+ return x.__class__(
+ **{
+ field.name: _to_batch(
+ getattr(x, field.name), device, non_blocking=non_blocking
+ )
+ for field in dataclasses.fields(x)
+ }
+ )
+ else:
+ raise NotImplementedError(x)
+
+
+def _format_time(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(
+ s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60
+ )
+
+
+class CheckpointHandler:
+ """Manages checkpoint file naming and paths."""
+
+ def __init__(self, run_dir, filename: str = "training-state-{}.checkpoint"):
+ self.filename = filename
+ self.run_dir = run_dir
+
+ def get_filename(self, nimg):
+ """Return checkpoint filename for given image count."""
+ return self.filename.format("%09d" % nimg)
+
+ def get_path(self, nimg):
+ """Return full checkpoint path for given image count."""
+ return os.path.join(self.run_dir, self.get_filename(nimg))
+
+ def list_checkpoints(self, run_dir=None):
+ run_dir = run_dir or self.run_dir
+ files = glob.glob(self.filename.format("*"), root_dir=run_dir)
+ pattern = self.filename.format(r"(\d{9})")
+ files = sorted(files)
+ for file in files:
+ m = re.match(pattern, file)
+ if m:
+ nimg = int(m.group(1))
+ yield os.path.join(run_dir, file), nimg
+
+
+@dataclasses.dataclass
+class TrainingLoopBase(loop.TrainingLoopBase, abc.ABC):
+ """Abstract base class for diffusion trainings loops
+
+ Implementations should define
+ - get_data_loaders
+ - get_network
+
+ """
+
+ device: torch.device | None = None
+
+ def __post_init__(self):
+ if self.steps_per_tick <= 0:
+ ValueError(self.steps_per_tick)
+
+ self._metrics_to_print = set()
+ self.ema: torch.nn.Module | None = None
+ self.iteration = 0
+ self.do_wandb = False
+ self._wandb_run = None
+
+ @abc.abstractmethod
+ def get_data_loaders(
+ self, batch_gpu: int
+ ) -> tuple[SpatioTemporalDataset, Iterable, Iterable]:
+ """Returns dataset, training loader, and validation loader."""
+ pass
+
+ def get_network(self) -> torch.nn.Module:
+ """Instantiates the model from config."""
+ return models.get_model(self.model_config)
+
+ @abc.abstractmethod
+ def get_optimizer(self, parameters):
+ """Returns network optimizer"""
+ pass
+
+ @abc.abstractmethod
+ def get_loss_fn(self):
+ """Returns the loss function."""
+ pass
+
+ @property
+ def model_config(self) -> ModelConfigV1 | None:
+ """Model configuration used for the network. This is used for checkpointing.
+
+ If you are overriding get_network, then be sure to make this consistent.
+ """
+ return None
+
+ def _setup_datasets(self):
+ self.dataset_obj, self.train_loader, self.valid_loader = self.get_data_loaders(
+ self.batch_gpu
+ )
+ if self.test_with_single_batch:
+ self.train_loader = itertools.repeat(next(iter(self.train_loader)))
+ self.valid_loader = [next(iter(self.valid_loader))]
+
+ def _setup_networks(self):
+ self.ddp = self.net = self.get_network()
+ self.net.train().requires_grad_(True).to(self.device)
+ if dist.get_world_size() > 1:
+ self.ddp = torch.nn.parallel.DistributedDataParallel(
+ self.net,
+ device_ids=[self.device],
+ broadcast_buffers=False,
+ )
+
+ @profiling.nvtx
+ def log_tick(
+ self,
+ maintenance_time,
+ tick_start_time,
+ tick_end_time,
+ start_time,
+ cur_tick,
+ cur_nimg,
+ ):
+ # Print status line, accumulating the same information in training_stats.
+ images_per_tick = self.steps_per_tick * self.batch_size
+ fields = []
+ fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
+ fields += [
+ f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"
+ ]
+ fields += [
+ f"time {_format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"
+ ]
+ fields += [
+ f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"
+ ]
+ fields += [
+ f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / images_per_tick * 1e3):<7.2f}"
+ ]
+ fields += [
+ f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"
+ ]
+ fields += [
+ f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"
+ ]
+ fields += [
+ f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(self.device) / 2**30):<6.2f}"
+ ]
+ fields += [
+ f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(self.device) / 2**30):<6.2f}"
+ ]
+ torch.cuda.reset_peak_memory_stats()
+ dist.print0(" ".join(fields))
+
+ def setup_logs(self):
+ if dist.get_rank() != 0:
+ logger.setLevel(logging.CRITICAL)
+
+ self.writer = torch.utils.tensorboard.SummaryWriter(self.run_dir)
+ self._step_metrics = {}
+
+ @property
+ def batch_gpu_total(self) -> int:
+ world_size: int = dist.get_world_size()
+ return self.batch_size // world_size
+
+ def setup_batching(self):
+ # Select batch size per GPU.
+ if self.batch_gpu is None or self.batch_gpu > self.batch_gpu_total:
+ self.batch_gpu = self.batch_gpu_total
+
+ if self.batch_gpu_total % self.batch_gpu != 0:
+ raise ValueError()
+
+ num_accumulation_rounds = self.batch_gpu_total // self.batch_gpu
+ self.num_accumulation_rounds = num_accumulation_rounds
+
+ @staticmethod
+ def print_network_info(net, device):
+ pass
+
+ def _load_iterator_state(self, checkpoint):
+ with checkpoint.open("iterator_state.json") as f:
+ iterator_state = json.loads(f.read())
+ if iterator_state:
+ self.epoch_idx = iterator_state["epoch_idx"]
+ self.samples_processed_this_epoch_per_rank = iterator_state[
+ "samples_processed_this_epoch_per_rank"
+ ]
+
+ def resume_from_state(
+ self,
+ resume_state_dump,
+ optimizer=True,
+ require_all=True,
+ wandb=False,
+ iterator_state=True,
+ ):
+ dist.print0(f'Loading training state from "{resume_state_dump}"...')
+
+ with checkpointing.Checkpoint(resume_state_dump, "r") as checkpoint:
+ self._load_net_state(checkpoint, require_all)
+ gc.collect()
+ if optimizer and self.optimizer is not None:
+ self._load_optimizer_state(checkpoint)
+
+ with checkpoint.open("loop.json") as f:
+ old_loop = self.loads(f.read())
+
+ # Restore iterator state if available (for backward compatibility)
+ if iterator_state:
+ try:
+ self._load_iterator_state(checkpoint)
+ except FileNotFoundError as e:
+ logger.warning(
+ f"Iterator state not found in checkpoint (backward compatibility): {e}. "
+ "Using defaults (epoch_idx=0, samples_processed_this_epoch_per_rank=0)"
+ )
+
+ # handle wandb
+ if wandb:
+ self.wandb_id = old_loop.wandb_id
+
+ def _load_net_state(self, checkpoint, require_all):
+ with checkpoint.open("net_state.pth", "r") as f:
+ net_state = torch.load(f, weights_only=True, map_location="cpu")
+ self.net.load_state_dict(net_state, strict=require_all)
+
+ def _load_optimizer_state(self, checkpoint):
+ with checkpoint.open("optimizer_state.pth", "r") as f:
+ # load to cpu to avoid copies in gpu memory
+ optimizer_state = torch.load(f, map_location="cpu")
+ self.optimizer.load_state_dict(optimizer_state)
+
+ def train_step(
+ self, *, condition=None, target, labels, augment_labels=None, **kwargs
+ ):
+ return self.loss_fn(
+ net=partial(
+ self.ddp,
+ condition=condition,
+ class_labels=labels,
+ augment_labels=augment_labels,
+ **kwargs,
+ ),
+ images=target,
+ )
+
+ def _stage_tuple_batch(self, batch):
+ indict = {}
+ images, labels, condition = batch[:3]
+ if images.ndim != 4:
+ raise ValueError(f"Expected images.ndim == 4, got {images.ndim}")
+ indict["target"] = images.to(self.device)
+ indict["condition"] = condition.to(self.device)
+ indict["labels"] = labels.to(self.device)
+
+ if len(batch) == 4:
+ augment_labels = batch[3]
+ if augment_labels is not None:
+ augment_labels.to(self.device).float()
+ indict["augment_labels"] = batch[3]
+ return indict
+
+ def _stage_dict_batch(self, batch):
+ return _to_batch(batch, self.device)
+
+ @profiling.nvtx
+ def backward_batch(self, dataset_iterator):
+ self.ddp.train()
+ self.optimizer.zero_grad(set_to_none=True)
+ total_loss = 0.0
+ time_start = time.time()
+ for round_idx in range(self.num_accumulation_rounds):
+ with ddp_sync(self.ddp, (round_idx == self.num_accumulation_rounds - 1)):
+ with profiling.nvtx_range("load data"):
+ batch = next(dataset_iterator)
+
+ if isinstance(batch, dict):
+ indict = self._stage_dict_batch(batch)
+ else:
+ warnings.warn(
+ DeprecationWarning(
+ "tuple based dataloaders will be removed soon. please refactor to use dicts."
+ )
+ )
+ indict = self._stage_tuple_batch(batch)
+
+ with torch.autocast(
+ device_type="cuda", dtype=torch.bfloat16, enabled=self.bf16
+ ):
+ # print(f"indict size: {size_of(indict) / 1e9} GB")
+ loss = self.train_step(**indict)
+ self.log_metric("Loss/loss", loss, print=True)
+ time_length = loss.shape[2] # (b, c, t, x)
+
+ if self.loss_reduction == "v1":
+ loss_mean = loss.sum().mul(
+ self.loss_scaling / (self.batch_gpu_total * time_length)
+ )
+ elif self.loss_reduction == "mean":
+ loss_mean = loss.mean() / self.num_accumulation_rounds
+ else:
+ raise NotImplementedError(self.loss_reduction)
+
+ with profiling.nvtx_range("training_loop:backward"):
+ loss_mean.backward()
+
+ total_loss += loss_mean.detach().cpu()
+ time_end = time.time()
+ self.log_debug(f"Final Loss: {total_loss.item()}")
+ self.log_debug(
+ f"Time taken for {self.num_accumulation_rounds} accumulation rounds: {time_end - time_start}"
+ )
+
+ def _log_parameter_and_gradient_norms(self):
+ # Log parameter and gradient norms if enabled
+ if self.log_parameter_norm or self.log_parameter_grad_norm:
+ for name, param in self.net.named_parameters():
+ if self.log_parameter_norm:
+ self.log_metric(
+ f"param_norm/{name}",
+ param.data.norm(2),
+ frequency="tick",
+ print=False,
+ )
+ if self.log_parameter_grad_norm and param.grad is not None:
+ self.log_metric(f"grad_norm/{name}", param.grad.norm(2))
+
+ grad_norm = torch.nn.utils.get_total_norm(
+ [param.grad for param in self.net.parameters() if param.grad is not None]
+ )
+ self.log_metric("grad_norm", grad_norm, frequency="tick", print=True)
+ self.log_metric("grad_norm", grad_norm, frequency="step")
+
+ @profiling.nvtx
+ @finish_before_quitting
+ def step_optimizer(self, cur_nimg):
+ torch.cuda.nvtx.range_push("training_loop:step")
+
+ warmup_imgs = self.lr_rampup_img
+ flat_imgs = self.flat_imgs
+ decay_imgs = self.decay_imgs
+ total_imgs = warmup_imgs + flat_imgs + decay_imgs
+
+ def lr_lambda(cur_nimg):
+ import math
+
+ base_lr = self.lr
+ min_lr = self.lr_min
+
+ min_factor = min_lr / base_lr
+ if cur_nimg < warmup_imgs:
+ # linear ramp from 0 โ 1
+ return float(cur_nimg) / warmup_imgs
+ elif cur_nimg < warmup_imgs + flat_imgs:
+ return 1.0
+ elif cur_nimg < total_imgs:
+ # cosine decay from factor=1 โ factor=min_factor
+ progress = float(cur_nimg - warmup_imgs - flat_imgs) / decay_imgs
+ # standard cosine schedule:
+ return min_factor + 0.5 * (1.0 - min_factor) * (
+ 1.0 + math.cos(math.pi * progress)
+ )
+ else:
+ return min_factor
+
+ def default_scale(cur_nimg):
+ return min(cur_nimg / max(self.lr_rampup_img, 1e-8), 1)
+
+ use_lr_lambda = True
+ scale_fn = lr_lambda if use_lr_lambda else default_scale
+
+ scale = scale_fn(self.cur_nimg)
+ for g in self.optimizer.param_groups:
+ if "base_lr" not in g:
+ if "lr" in g:
+ g["base_lr"] = g["lr"] # lazy init from existing LR
+ else:
+ g["base_lr"] = self.optimizer.defaults["lr"]
+ lr = g["base_lr"] * scale
+ self.log_debug(
+ f"Learning rate: {lr} from base: {g['base_lr']} with scale factor: {scale} (would normally be {default_scale(self.cur_nimg)})"
+ )
+
+ g["lr"] = lr
+ self.writer.add_scalar("lr", lr, global_step=self.cur_nimg)
+
+ self._log_parameter_and_gradient_norms()
+
+ for param in self.net.parameters():
+ if param.grad is not None:
+ torch.nan_to_num(
+ param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad
+ )
+
+ if self.gradient_clip_max_norm is not None:
+ torch.nn.utils.clip_grad_norm_(
+ self.net.parameters(), max_norm=self.gradient_clip_max_norm
+ )
+
+ self._step_optimizer()
+ # increment the number of images processed within the current epoch
+ self.samples_processed_this_epoch_per_rank += self.batch_gpu or 1
+ torch.cuda.nvtx.range_pop()
+
+ self._flush_step_metrics()
+
+ def on_tick(self):
+ pass
+
+ @profiling.nvtx
+ def validate(self, net):
+ loss_key = "Loss/test_loss"
+
+ with torch.no_grad():
+ for batch in self.valid_loader:
+ if len(batch) == 4:
+ images, labels, condition, augment_labels = batch
+ else:
+ images, labels, condition = batch
+ augment_labels = None
+
+ if images.ndim != 4:
+ raise ValueError(f"Expected images.ndim == 4, got {images.ndim}")
+
+ images = images.to(self.device).to(torch.float32)
+ condition = condition.to(self.device).to(torch.float32)
+ labels = labels.to(self.device)
+
+ loss = self.train_step(
+ condition=condition,
+ target=images,
+ labels=labels,
+ augment_labels=augment_labels,
+ )
+ training_stats.report(loss_key, loss)
+
+ def log_metric(self, key, value, print=True, frequency="tick"):
+ """Log a metric. Will be averaged over all calls within a tick
+
+ Args:
+ print: if True then print the metric to the console at the end of the tick
+
+ """
+ if frequency == "tick":
+ training_stats.report(key, value)
+ elif frequency == "step":
+ if key not in self._step_metrics:
+ self._step_metrics[key] = torchmetrics.MeanMetric().to(value)
+ self._step_metrics[key].update(value)
+
+ if print:
+ self._metrics_to_print.add(key)
+
+ def _flush_step_metrics(self):
+ for key, metric in self._step_metrics.items():
+ value = metric.compute()
+ if dist.get_rank() == 0:
+ self.writer.add_scalar(key, value, global_step=self.cur_nimg)
+ metric.reset()
+
+ def _flush_training_stats_to_wandb(self):
+ if dist.get_rank() == 0:
+ info = training_stats.default_collector.as_dict()
+ metrics = {name: info[name]["mean"] for name in info}
+ wandb.log(metrics, step=self.cur_nimg)
+
+ @property
+ def batch_info(self) -> None | BatchInfo:
+ return None
+
+ @profiling.nvtx
+ @finish_before_quitting
+ def _save_checkpoint(self, path, optimizer: bool):
+ # ensure that file updates are atomic to avoid faulty
+ # restart files
+ tmppath = path + ".tmp" + str(os.getpid())
+ with checkpointing.Checkpoint(tmppath, "w") as checkpoint:
+ checkpoint.write_model(self.net)
+ if self.batch_info is not None:
+ checkpoint.write_batch_info(self.batch_info)
+
+ if optimizer:
+ with checkpoint.open("optimizer_state.pth", "w") as f:
+ torch.save(self.optimizer.state_dict(), f)
+
+ with checkpoint.open("loop.json", "w") as f:
+ f.write(self.dumps().encode())
+
+ # Save iterator state for resuming
+ with checkpoint.open("iterator_state.json", "w") as f:
+ iterator_state = {
+ "epoch_idx": self.epoch_idx,
+ "samples_processed_this_epoch_per_rank": self.samples_processed_this_epoch_per_rank,
+ }
+ f.write(json.dumps(iterator_state).encode())
+
+ if self.model_config is not None:
+ checkpoint.write_model_config(self.model_config)
+ shutil.move(tmppath, path)
+
+ @profiling.nvtx
+ def save_training_state(self, cur_nimg):
+ if dist.get_rank() != 0:
+ return
+
+ state_filename = self._state_checkpoint_handler.get_path(cur_nimg)
+ dist.print0(f"Saving checkpoint to {state_filename}")
+ self._save_checkpoint(state_filename, optimizer=True)
+ dist.print0(f"Checkpoint saved to {state_filename}")
+
+ @profiling.nvtx
+ def save_network_snapshot(self, cur_nimg):
+ if dist.get_rank() != 0:
+ return
+
+ filename = self._snapshot_checkpoint_handler.get_path(cur_nimg)
+ dist.print0(f"Saving network snapshot to {filename}")
+ self._save_checkpoint(filename, optimizer=False)
+
+ def flush_training_stats(self):
+ logger = logging.getLogger(__name__)
+ logger.info("Begin. flushing training stats.")
+ training_stats.default_collector.update()
+ if self.do_wandb:
+ self._flush_training_stats_to_wandb()
+
+ if dist.get_rank() == 0:
+ info = training_stats.default_collector.as_dict()
+ try:
+ nimg = info["Progress/kimg"]["mean"] * 1000
+ except KeyError:
+ nimg = self.cur_nimg
+
+ for k, v in info.items():
+ for moment in v:
+ self.writer.add_scalar(f"{k}/{moment}", v[moment], global_step=nimg)
+
+ stats_path = os.path.join(self.run_dir, "stats.jsonl")
+ with open(stats_path, "at") as f:
+ stats = training_stats.default_collector.as_dict()
+ for stat in stats:
+ mean = stats[stat]["mean"]
+ if stat in self._metrics_to_print:
+ print(f"{stat} = {mean:4g}")
+ f.write(
+ json.dumps(
+ dict(
+ training_stats.default_collector.as_dict(),
+ timestamp=time.time(),
+ )
+ )
+ + "\n"
+ )
+
+ @classmethod
+ def from_json(cls, path):
+ with open(path) as f:
+ return cls.loads(f.read())
+
+ @classmethod
+ def from_rundir(cls, run_dir):
+ path = os.path.join(run_dir, TRAINER_METADATA_FILENAME)
+ loop = cls.from_json(path)
+ loop.run_dir = run_dir
+ return loop
+
+ def dumps(self):
+ fields = dataclasses.asdict(self)
+ fields.pop("device", None)
+ return json.dumps(fields)
+
+ @classmethod
+ def loads(cls, s):
+ return cls(**json.loads(s))
+
+ def save_metadata(self):
+ with open(os.path.join(self.run_dir, TRAINER_METADATA_FILENAME), "w") as f:
+ fields = dataclasses.asdict(self)
+ fields.pop("device", None)
+ f.write(self.dumps())
+
+ def setup(self):
+ self.setup_logs()
+ self.save_metadata()
+ self.device = self.device or torch.device("cuda", torch.cuda.current_device())
+ self.loss_fn = self.get_loss_fn()
+
+ # iterators
+ # used to restore the sampler state during restarts
+ self.cur_nimg = 0
+ self.epoch_idx = 0
+ self.samples_processed_this_epoch_per_rank = 0
+
+ self._setup_datasets()
+ self._setup_networks()
+ self.print_network_info(self.net, self.device)
+ self.setup_batching()
+ self._setup_optimizer()
+ self._state_checkpoint_handler = CheckpointHandler(self.run_dir)
+ self._snapshot_checkpoint_handler = CheckpointHandler(
+ self.run_dir, "network-snapshot-{}.checkpoint"
+ )
+
+ def _setup_optimizer(self):
+ self.optimizer = self.get_optimizer(self.net.named_parameters())
+ if self.compile_optimizer:
+ self._step_optimizer = torch.compile(self.optimizer.step)
+ else:
+ self._step_optimizer = self.optimizer.step
+
+ def setup_wandb(self, **kwargs):
+ try:
+ if wandb is not None and dist.get_rank() == 0:
+ os.environ["WANDB_API_KEY"]
+ run = wandb.init(
+ id=self.wandb_id,
+ config=json.loads(self.dumps()),
+ project="ufs-da",
+ entity="nv-research-climate",
+ **kwargs,
+ )
+ self.wandb_id = run.id
+ self.do_wandb = True
+ self._wandb_run = run
+ except KeyError:
+ # cannot init wandb
+ dist.print0("WANDB_API_KEY not set. Cannot use wandb")
+ pass
+
+ def resume_from_rundir(self, run_dir=None, require_all=True):
+ checkpoint_info = None
+ for checkpoint_info in self._state_checkpoint_handler.list_checkpoints(run_dir):
+ pass
+ # now training_state and nimg are the final checkpoints
+ if checkpoint_info is None:
+ raise FileNotFoundError("No checkpoint file found.")
+
+ path, nimg = checkpoint_info
+
+ self.cur_nimg = nimg
+ self.resume_from_state(path, require_all=require_all, wandb=True)
+
+ def log_debug(self, msg):
+ if dist.get_rank() != 0:
+ return
+
+ if self.iteration % self.print_steps != 0:
+ return
+
+ logger.debug(msg)
+
+ def train(self):
+ # signal.signal(signal.SIGINT, handler)
+ signal.signal(signal.SIGTERM, handler)
+ try:
+ self._train()
+ except QuitEarly as e:
+ dist.print0(f"Caught {e}. Quitting early.")
+ self.save_training_state(self.cur_nimg)
+ try:
+ del self.train_loader
+ del self.valid_loader
+ except AttributeError:
+ pass
+
+ def _batch_iterator(self):
+ while True:
+ for batch in self.train_loader:
+ yield batch
+
+ self.epoch_idx += 1
+ self.samples_processed_this_epoch_per_rank = 0
+ self.iteration = 0
+
+ def _train(self):
+ dist.print0("Loss function", self.loss_fn)
+ start_time = time.time()
+ np.random.seed(
+ (self.seed * dist.get_world_size() + dist.get_rank() + self.cur_nimg)
+ % (1 << 31)
+ )
+ torch.manual_seed(np.random.randint(1 << 31))
+ torch.backends.cudnn.benchmark = self.cudnn_benchmark
+ torch.backends.cudnn.allow_tf32 = self.tf32
+ torch.backends.cuda.matmul.allow_tf32 = self.tf32
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+ not self.tf32
+ )
+
+ # Train.
+ tick_start_time = time.time()
+ maintenance_time = tick_start_time - start_time
+ dist.update_progress(0, self.total_ticks)
+ dataset_iterator = self._batch_iterator()
+ top_time = time.time()
+ steps = 0
+ for cur_tick in range(self.total_ticks):
+ for _ in range(self.steps_per_tick):
+ step_start = time.time()
+ self.backward_batch(dataset_iterator)
+ self.cur_nimg += self.batch_size
+ self.step_optimizer(self.cur_nimg)
+
+ step_end = time.time()
+ self.log_debug(
+ f"Step {steps} time: {step_end - step_start}. Avg time: {(step_end - top_time) / (steps + 1)}"
+ )
+ self.log_debug(
+ f"CPU Memory: {psutil.Process(os.getpid()).memory_info().rss / 2**30:<6.2f}GB"
+ )
+ steps += 1
+ self.iteration = steps
+ tick_end_time = time.time()
+ self.log_tick(
+ maintenance_time,
+ tick_start_time,
+ tick_end_time,
+ start_time,
+ cur_tick,
+ self.cur_nimg,
+ )
+
+ # Save network snapshot.
+ if (self.snapshot_ticks is not None) and (
+ cur_tick % self.snapshot_ticks == 0
+ ):
+ self.save_network_snapshot(self.cur_nimg)
+ if (self.state_dump_ticks is not None) and (
+ cur_tick % self.state_dump_ticks == 0
+ ):
+ self.save_training_state(self.cur_nimg)
+
+ self.net.eval()
+ logger.info("Validating...")
+ val_start_time = time.time()
+ self.validate(self.net)
+ val_time = time.time() - val_start_time
+ logger.info(f"Validation time: {val_time:.2f}s.")
+ self.net.train()
+
+ # Update logs.
+ self.flush_training_stats()
+ dist.update_progress(cur_tick, self.total_ticks)
+
+ tick_start_time = time.time()
+ maintenance_time = tick_start_time - tick_end_time
+
+ # Done.
+ self.save_training_state(self.cur_nimg)
+ dist.print0("Exiting...")
diff --git a/examples/weather/healda/training/training_stats.py b/examples/weather/healda/training/training_stats.py
new file mode 100644
index 0000000000..dc45331a52
--- /dev/null
+++ b/examples/weather/healda/training/training_stats.py
@@ -0,0 +1,316 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Facilities for reporting and collecting training statistics across
+multiple processes and devices. The interface is designed to minimize
+synchronization overhead as well as the amount of boilerplate in user
+code."""
+
+import re
+
+import numpy as np
+import torch
+
+# ----------------------------------------------------------------------------
+
+_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
+_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
+_counter_dtype = torch.float64 # Data type to use for the internal counters.
+_rank = 0 # Rank of the current process.
+_sync_device = (
+ None # Device to use for multiprocess communication. None = single-process.
+)
+_sync_called = False # Has _sync() been called yet?
+_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
+_cumulative = (
+ dict()
+) # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
+
+# ----------------------------------------------------------------------------
+
+
+def init_multiprocessing(rank, sync_device):
+ r"""Initializes `torch_utils.training_stats` for collecting statistics
+ across multiple processes.
+
+ This function must be called after
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
+ The call is not necessary if multi-process collection is not needed.
+
+ Args:
+ rank: Rank of the current process.
+ sync_device: PyTorch device to use for inter-process
+ communication, or None to disable multi-process
+ collection. Typically `torch.device('cuda', rank)`.
+ """
+ global _rank, _sync_device
+ if _sync_called:
+ raise RuntimeError("init() has already been called")
+ _rank = rank
+ _sync_device = sync_device
+
+
+# ----------------------------------------------------------------------------
+
+
+def report(name, value):
+ r"""Broadcasts the given set of scalars to all interested instances of
+ `Collector`, across device and process boundaries.
+
+ This function is expected to be extremely cheap and can be safely
+ called from anywhere in the training loop, loss function, or inside a
+ `torch.nn.Module`.
+
+ Warning: The current implementation expects the set of unique names to
+ be consistent across processes. Please make sure that `report()` is
+ called at least once for each unique name by each process, and in the
+ same order. If a given process has no scalars to broadcast, it can do
+ `report(name, [])` (empty list).
+
+ Args:
+ name: Arbitrary string specifying the name of the statistic.
+ Averages are accumulated separately for each unique name.
+ value: Arbitrary set of scalars. Can be a list, tuple,
+ NumPy array, PyTorch tensor, or Python scalar.
+
+ Returns:
+ The same `value` that was passed in.
+ """
+ if name not in _counters:
+ _counters[name] = dict()
+
+ elems = torch.as_tensor(value)
+ if elems.numel() == 0:
+ return value
+
+ elems = elems.detach().flatten().to(_reduce_dtype)
+ moments = torch.stack(
+ [
+ torch.ones_like(elems).sum(),
+ elems.sum(),
+ elems.square().sum(),
+ ]
+ )
+ if moments.ndim != 1 or moments.shape[0] != _num_moments:
+ raise ValueError(
+ f"Expected moments shape (1, {_num_moments}), got {moments.shape}"
+ )
+ moments = moments.to(_counter_dtype)
+
+ device = moments.device
+ if device not in _counters[name]:
+ _counters[name][device] = torch.zeros_like(moments)
+ _counters[name][device].add_(moments)
+ return value
+
+
+# ----------------------------------------------------------------------------
+
+
+def report0(name, value):
+ r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
+ but ignores any scalars provided by the other processes.
+ See `report()` for further details.
+ """
+ report(name, value if _rank == 0 else [])
+ return value
+
+
+# ----------------------------------------------------------------------------
+
+
+class Collector:
+ r"""Collects the scalars broadcasted by `report()` and `report0()` and
+ computes their long-term averages (mean and standard deviation) over
+ user-defined periods of time.
+
+ The averages are first collected into internal counters that are not
+ directly visible to the user. They are then copied to the user-visible
+ state as a result of calling `update()` and can then be queried using
+ `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
+ internal counters for the next round, so that the user-visible state
+ effectively reflects averages collected between the last two calls to
+ `update()`.
+
+ Args:
+ regex: Regular expression defining which statistics to
+ collect. The default is to collect everything.
+ keep_previous: Whether to retain the previous averages if no
+ scalars were collected on a given round
+ (default: True).
+ """
+
+ def __init__(self, regex=".*", keep_previous=True):
+ self._regex = re.compile(regex)
+ self._keep_previous = keep_previous
+ self._cumulative = dict()
+ self._moments = dict()
+ self.update()
+ self._moments.clear()
+
+ def names(self):
+ r"""Returns the names of all statistics broadcasted so far that
+ match the regular expression specified at construction time.
+ """
+ return [name for name in _counters if self._regex.fullmatch(name)]
+
+ def update(self):
+ r"""Copies current values of the internal counters to the
+ user-visible state and resets them for the next round.
+
+ If `keep_previous=True` was specified at construction time, the
+ operation is skipped for statistics that have received no scalars
+ since the last update, retaining their previous averages.
+
+ This method performs a number of GPU-to-CPU transfers and one
+ `torch.distributed.all_reduce()`. It is intended to be called
+ periodically in the main training loop, typically once every
+ N training steps.
+ """
+ if not self._keep_previous:
+ self._moments.clear()
+ for name, cumulative in _sync(self.names()):
+ if name not in self._cumulative:
+ self._cumulative[name] = torch.zeros(
+ [_num_moments], dtype=_counter_dtype
+ )
+ delta = cumulative - self._cumulative[name]
+ self._cumulative[name].copy_(cumulative)
+ if float(delta[0]) != 0:
+ self._moments[name] = delta
+
+ def _get_delta(self, name):
+ r"""Returns the raw moments that were accumulated for the given
+ statistic between the last two calls to `update()`, or zero if
+ no scalars were collected.
+ """
+ if not self._regex.fullmatch(name):
+ raise ValueError(f"Name '{name}' does not match expected pattern")
+ if name not in self._moments:
+ self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+ return self._moments[name]
+
+ def num(self, name):
+ r"""Returns the number of scalars that were accumulated for the given
+ statistic between the last two calls to `update()`, or zero if
+ no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ return int(delta[0])
+
+ def mean(self, name):
+ r"""Returns the mean of the scalars that were accumulated for the
+ given statistic between the last two calls to `update()`, or NaN if
+ no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ if int(delta[0]) == 0:
+ return float("nan")
+ return float(delta[1] / delta[0])
+
+ def std(self, name):
+ r"""Returns the standard deviation of the scalars that were
+ accumulated for the given statistic between the last two calls to
+ `update()`, or NaN if no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
+ return float("nan")
+ if int(delta[0]) == 1:
+ return float(0)
+ mean = float(delta[1] / delta[0])
+ raw_var = float(delta[2] / delta[0])
+ return np.sqrt(max(raw_var - np.square(mean), 0))
+
+ def as_dict(self):
+ r"""Returns the averages accumulated between the last two calls to
+ `update()` as an `dict`. The contents are as follows:
+
+ dict(
+ NAME = dict(num=FLOAT, mean=FLOAT, std=FLOAT),
+ ...
+ )
+ """
+ stats = {}
+ for name in self.names():
+ stats[name] = dict(
+ num=self.num(name), mean=self.mean(name), std=self.std(name)
+ )
+ return stats
+
+ def __getitem__(self, name):
+ r"""Convenience getter.
+ `collector[name]` is a synonym for `collector.mean(name)`.
+ """
+ return self.mean(name)
+
+
+# ----------------------------------------------------------------------------
+
+
+def _as_local(tensor):
+ try:
+ # Try to extract local tensor from DTensor
+ return tensor.to_local()
+ except AttributeError:
+ # Regular tensor, just move to device
+ return tensor
+
+
+# ----------------------------------------------------------------------------
+
+
+def _sync(names):
+ r"""Synchronize the global cumulative counters across devices and
+ processes. Called internally by `Collector.update()`.
+ """
+ if len(names) == 0:
+ return []
+ global _sync_called
+ _sync_called = True
+
+ # Collect deltas within current rank.
+ deltas = []
+ device = _sync_device if _sync_device is not None else torch.device("cpu")
+ for name in names:
+ delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
+ for counter in _counters[name].values():
+ counter = _as_local(counter)
+ delta.add_(counter.to(device))
+ counter.copy_(torch.zeros_like(counter))
+ deltas.append(delta)
+ deltas = torch.stack(deltas)
+
+ # Sum deltas across ranks.
+ if _sync_device is not None:
+ torch.distributed.all_reduce(deltas)
+
+ # Update cumulative values.
+ deltas = deltas.cpu()
+ for idx, name in enumerate(names):
+ if name not in _cumulative:
+ _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+ _cumulative[name].add_(deltas[idx])
+
+ # Return name-value pairs.
+ return [(name, _cumulative[name]) for name in names]
+
+
+# ----------------------------------------------------------------------------
+# Convenience.
+
+default_collector = Collector()
+
+# ----------------------------------------------------------------------------
diff --git a/examples/weather/healda/utils/__init__.py b/examples/weather/healda/utils/__init__.py
new file mode 100644
index 0000000000..cde1bcdab3
--- /dev/null
+++ b/examples/weather/healda/utils/__init__.py
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utility modules for HealDA example."""
+
+from utils import (
+ checkpointing,
+ dataclass_parser,
+ distributed,
+ signals,
+ storage,
+ visualizations,
+)
diff --git a/examples/weather/healda/utils/checkpointing.py b/examples/weather/healda/utils/checkpointing.py
new file mode 100644
index 0000000000..c73eb8b230
--- /dev/null
+++ b/examples/weather/healda/utils/checkpointing.py
@@ -0,0 +1,155 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""File formats for checkpointing
+
+
+# Version 1:
+format = zip file
+contains
+
+```
+net_state.pth
+batch_info.json
+model.json: str # json-serialized ModelConfigV1 (if version == 1)
+metadata.json
+ version: int
+```
+
+# other files are not interpreted by cbottle but can be used for other purposes
+# (like in custom TrainingLoops)
+"""
+
+import json
+import warnings
+import zipfile
+from typing import Literal
+
+import datasets.base
+import models
+import torch
+
+from config.model_config import ModelConfigV1
+
+current_version = 1
+
+
+class Checkpoint:
+ """A checkpoint object
+
+ This is similar to ZipFile, but with convenience methods for reading and
+ writing models.
+ """
+
+ def __init__(self, f, mode: Literal["w", "r"] = "r"):
+ # Set zip to None in case the zipfile fails to open
+ # this will be used in __del__
+ self._zip = None
+ self._zip = zipfile.ZipFile(f, mode)
+
+ def write_model(self, net: torch.nn.Module):
+ with self._zip.open("net_state.pth", "w", force_zip64=True) as f:
+ torch.save(net.state_dict(), f)
+
+ def read_model(self, net=None, map_location=None) -> torch.nn.Module:
+ """Read the model from the checkpoint
+
+ Args:
+ net: If provided, the state dict will be loaded into this net.
+ Otherwise, a new model will be created.
+ """
+ try:
+ metadata = json.loads(self._zip.read("metadata.json").decode())
+ except KeyError:
+ warnings.warn("Old checkpoint format detected. Falling back to old format.")
+ with self._zip.open("net_state.pth", "r") as f:
+ return torch.load(f, weights_only=False)["net"]
+
+ # new checkpoint format
+ if metadata["version"] != current_version:
+ raise ValueError(f"Unsupported checkpoint version: {metadata['version']}")
+
+ model_config = self.read_model_config()
+ if net is None:
+ net = models.get_model(model_config)
+
+ with self._zip.open("net_state.pth", "r") as f:
+ net.load_state_dict(
+ torch.load(f, weights_only=True, map_location=map_location)
+ )
+ return net
+
+ def read_model_config(self) -> ModelConfigV1:
+ return ModelConfigV1.loads(self._zip.open("model.json").read())
+
+ def read_model_state_dict(self) -> dict:
+ with self._zip.open("net_state.pth", "r") as f:
+ return torch.load(f, weights_only=True)
+
+ def write_batch_info(self, batch_info: datasets.base.BatchInfo):
+ d = {
+ "channels": batch_info.channels,
+ "time_step": batch_info.time_step,
+ "time_unit": batch_info.time_unit.name, # enums don't serialize nicely
+ }
+ if batch_info.scales is not None:
+ d["scales"] = list(batch_info.scales)
+ if batch_info.center is not None:
+ d["center"] = list(batch_info.center)
+ self._zip.writestr("batch_info.json", json.dumps(d))
+
+ def read_batch_info(self) -> datasets.base.BatchInfo:
+ with self._zip.open("batch_info.json", "r") as f:
+ d = json.loads(f.read())
+ scales = d.pop("scales", None)
+ center = d.pop("center", None)
+ if d["time_unit"] == "": # backwards compatibility
+ time_unit = datasets.base.TimeUnit.HOUR
+ elif d["time_unit"] == "MINUTE":
+ time_unit = datasets.base.TimeUnit.MINUTE
+ elif d["time_unit"] == "HOUR":
+ time_unit = datasets.base.TimeUnit.HOUR
+ else:
+ time_unit_dict = {v.value: v for v in datasets.base.TimeUnit}
+ time_unit = time_unit_dict[d["time_unit"]]
+ return datasets.base.BatchInfo(
+ time_unit=time_unit,
+ scales=scales,
+ center=center,
+ channels=d["channels"],
+ )
+
+ def write_model_config(self, model_config: ModelConfigV1):
+ self._zip.writestr("model.json", model_config.dumps())
+
+ def open(self, name, mode: Literal["w", "r"] = "r"):
+ return self._zip.open(name, mode, force_zip64=True)
+
+ def close(self):
+ if self._zip.mode == "w":
+ self._zip.writestr(
+ "metadata.json", json.dumps({"version": current_version})
+ )
+ self._zip.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+
+ def __del__(self):
+ if self._zip is not None and self._zip.fp:
+ self.close()
diff --git a/examples/weather/healda/utils/dataclass_parser.py b/examples/weather/healda/utils/dataclass_parser.py
new file mode 100644
index 0000000000..6ceb2965c9
--- /dev/null
+++ b/examples/weather/healda/utils/dataclass_parser.py
@@ -0,0 +1,267 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A small library for parsing cli directly to dataclasses
+
+Example:
+
+"""
+
+import argparse
+import enum
+from dataclasses import MISSING, dataclass, fields, is_dataclass
+from types import UnionType
+from typing import (
+ Annotated,
+ Any,
+ Type,
+ TypeVar,
+ Union,
+ get_origin,
+)
+
+__all__ = ["Help", "parse_args", "parse_dict", "a"]
+
+T = TypeVar("T")
+
+a = Annotated
+
+
+@dataclass(frozen=True)
+class Help:
+ """When used with annotated types this will add the help to argparse"""
+
+ message: str
+
+
+def _get_type_and_meta(t):
+ """return the type and metadta of a potentially annotated type"""
+ if get_origin(t) is Annotated:
+ meta = t.__metadata__
+ t = t.__origin__
+ else:
+ meta = []
+
+ return _is_optional(t), _handle_optional(t), meta
+
+
+def _is_optional(t):
+ return get_origin(t) in [Union, UnionType]
+
+
+def _handle_optional(t):
+ """this returns the specified user type when wrapped in optional or a union type object
+
+ Exmaples:
+
+ _handle_optional(str | None) == str
+ _handle_optional(Optional[str]) == str
+ _handle_optional(Union[str, None]) == str
+
+ """
+ if get_origin(t) in [Union, UnionType]:
+ types = [tt for tt in t.__args__ if tt is not type(None)]
+ if len(types) > 1:
+ raise ValueError(f"Union types not supported: {t}.")
+ return types[0]
+ else:
+ return t
+
+
+def is_enum(T):
+ return isinstance(T, Type) and issubclass(T, enum.Enum)
+
+
+def parse_args(
+ opts: Type[T],
+ args: list[str] | None = None,
+ strict: bool = True,
+ convert_underscore_to_hyphen: bool = True,
+) -> T:
+ """Parse a list of command line arguments into a dataclass
+
+ Args:
+ opts: the dataclass specification of the arguments
+ args: the list of string command line arguments.
+ If not provided, this is read from sys.argv.
+ strict: if true, then check the types at runtime
+
+ Returns:
+ an instance of the `opts` dataclass
+
+ """
+ parser = argparse.ArgumentParser()
+
+ def add_arguments(parser: argparse.ArgumentParser, dataclass_type, prefix=""):
+ for field in fields(dataclass_type):
+ help_str = ""
+ _, T, meta = _get_type_and_meta(field.type)
+ this_parser = parser
+ for item in meta:
+ if isinstance(item, Help):
+ help_str = item.message
+
+ # Construct argument name with prefix if provided
+ # Check for default values
+ if isinstance(dataclass_type, type):
+ default = (
+ field.default
+ if field.default is not MISSING
+ else (
+ field.default_factory()
+ if field.default_factory is not MISSING
+ else MISSING
+ )
+ )
+ else:
+ default = getattr(dataclass_type, field.name)
+
+ def _get_arg_name(field_name: str, prefix: str, required):
+ if convert_underscore_to_hyphen:
+ field_name = field_name.replace("_", "-")
+
+ flag = "" if required and not prefix else "--"
+ if prefix:
+ return f"{flag}{prefix}{field_name}"
+ else:
+ return f"{flag}{field_name}"
+
+ if is_dataclass(T):
+ # Handle nested dataclass by adding arguments for its fields
+ add_arguments(parser, default or T, prefix=f"{prefix}{field.name}.")
+ else:
+ arg_name = _get_arg_name(
+ field.name, prefix, required=default is MISSING
+ )
+ if T is bool:
+ if default:
+ sep = "-" if convert_underscore_to_hyphen else "_"
+ arg_name = _get_arg_name(
+ field.name, prefix + f"no{sep}", required=default is MISSING
+ )
+ dest = f"{prefix}{field.name}" if prefix else f"{field.name}"
+ this_parser.add_argument(
+ arg_name,
+ action="store_false",
+ dest=dest,
+ help=help_str,
+ )
+ else:
+ this_parser.add_argument(
+ arg_name,
+ action="store_true",
+ help=help_str,
+ )
+ elif is_enum(T):
+ this_parser.add_argument(
+ arg_name, default=default.name, choices=[x.name for x in T]
+ )
+ elif T is Any:
+ this_parser.add_argument(
+ arg_name,
+ default=default,
+ help=help_str,
+ )
+ else:
+ help_str += f" [{T.__name__}, default: {default}]"
+ this_parser.add_argument(
+ arg_name,
+ type=T,
+ default=default,
+ help=help_str,
+ )
+
+ add_arguments(parser, opts)
+ parsed_args = parser.parse_args(args)
+
+ def construct_dataclass(dataclass_type, parsed_data, prefix=""):
+ init_kwargs = {}
+ for field in fields(dataclass_type):
+ key = f"{prefix}{field.name}"
+ optional, T, _ = _get_type_and_meta(field.type)
+ if is_dataclass(T):
+ if optional and not _contains_object(key, parsed_data):
+ value = None
+ else:
+ # Recursively build nested dataclass
+ value = construct_dataclass(
+ T, parsed_data, prefix=f"{prefix}{field.name}."
+ )
+ else:
+ # Use the argument value
+ value = getattr(parsed_data, key)
+
+ if is_enum(T):
+ (value,) = [it for it in T if it.name == value]
+
+ if strict and T is not Any:
+ # Normalize parameterized generics (e.g. list[int]) to their origin
+ # type for isinstance checks, since Python disallows using
+ # parameterized generics directly in isinstance().
+ origin = get_origin(T)
+ check_type = origin or T
+ allowed_types = (check_type, type(None)) if optional else (check_type,)
+ if not isinstance(value, allowed_types):
+ raise ValueError(f"{value} is not a {T}")
+
+ init_kwargs[field.name] = value
+ return dataclass_type(**init_kwargs)
+
+ return construct_dataclass(opts, parsed_args)
+
+
+def _contains_object(prefix, data):
+ for key in vars(data):
+ if key.startswith(prefix):
+ return True
+ return False
+
+
+def parse_dict(opts: Type[T], obj: dict, strict: bool = True) -> T:
+ """Parse an untyped nested dictionary ``obj``` into a dataclass ``opts```
+
+ If ``strict`` is true, then the types of the obj will be validated at runtime
+ """
+
+ def construct_dataclass(dataclass_type, data):
+ init_kwargs = {}
+ for field in fields(dataclass_type):
+ field_name = field.name
+ if field_name in data:
+ value = data[field_name]
+ is_optional, T, _ = _get_type_and_meta(field.type)
+ if is_dataclass(T) and isinstance(value, dict):
+ # Recursively construct nested dataclass
+ value = construct_dataclass(T, value)
+ if strict and T is not Any:
+ # Allow None for optional types
+ if not (is_optional and value is None):
+ origin = get_origin(T)
+ check_type = origin or T
+ if not isinstance(value, check_type):
+ raise ValueError(field, "is not a", T)
+ init_kwargs[field_name] = value
+ elif field.default is not MISSING:
+ init_kwargs[field_name] = field.default
+ elif field.default_factory is not MISSING:
+ init_kwargs[field_name] = field.default_factory()
+ else:
+ raise ValueError(
+ f"Field '{field_name}' is required but not provided in data."
+ )
+
+ return dataclass_type(**init_kwargs)
+
+ return construct_dataclass(opts, obj)
diff --git a/examples/weather/healda/utils/distributed.py b/examples/weather/healda/utils/distributed.py
new file mode 100644
index 0000000000..d4857f15a4
--- /dev/null
+++ b/examples/weather/healda/utils/distributed.py
@@ -0,0 +1,114 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import datetime
+import os
+
+import torch
+from training import training_stats
+
+# ----------------------------------------------------------------------------
+
+
+def init(timeout_infinite=False):
+ if "WORLD_SIZE" not in os.environ:
+ if "SLURM_NTASKS" in os.environ:
+ os.environ["WORLD_SIZE"] = os.environ.get("SLURM_NTASKS", "1")
+ else:
+ os.environ["WORLD_SIZE"] = "1"
+ if "MASTER_ADDR" not in os.environ:
+ if (
+ int(os.environ["WORLD_SIZE"]) > 1
+ and "SLURM_LAUNCH_NODE_IPADDR" in os.environ
+ ):
+ os.environ["MASTER_ADDR"] = os.environ.get(
+ "SLURM_LAUNCH_NODE_IPADDR", "localhost"
+ )
+ else:
+ os.environ["MASTER_ADDR"] = "localhost"
+ if "MASTER_PORT" not in os.environ:
+ os.environ["MASTER_PORT"] = "29500"
+ if "RANK" not in os.environ:
+ if "SLURM_PROCID" in os.environ:
+ os.environ["RANK"] = os.environ.get("SLURM_PROCID", "0")
+ else:
+ os.environ["RANK"] = "0"
+ if "LOCAL_RANK" not in os.environ:
+ if "SLURM_LOCALID" in os.environ:
+ os.environ["LOCAL_RANK"] = os.environ.get("SLURM_LOCALID", "0")
+ else:
+ os.environ["LOCAL_RANK"] = "0"
+
+ backend = "gloo" if os.name == "nt" else "nccl"
+ if timeout_infinite:
+ timeout = datetime.timedelta(days=365)
+ else:
+ timeout = None
+
+ device_id = int(os.environ.get("LOCAL_RANK", "0"))
+ torch.cuda.set_device(device_id)
+ torch.distributed.init_process_group(
+ backend=backend,
+ init_method="env://",
+ timeout=timeout,
+ device_id=torch.device("cuda", index=device_id),
+ )
+
+ sync_device = torch.device("cuda") if get_world_size() > 1 else None
+ training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
+
+
+# ----------------------------------------------------------------------------
+
+
+def get_rank():
+ """Return current process rank, or 0 if not distributed."""
+ return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+
+
+# ----------------------------------------------------------------------------
+
+
+def get_world_size():
+ """Return world size, or 1 if not distributed."""
+ return (
+ torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+ )
+
+
+# ----------------------------------------------------------------------------
+
+
+def should_stop():
+ return False
+
+
+# ----------------------------------------------------------------------------
+
+
+def update_progress(cur, total):
+ """Progress callback stub (no-op by default)."""
+ _ = cur, total
+
+
+# ----------------------------------------------------------------------------
+
+
+def print0(*args, **kwargs):
+ if get_rank() == 0:
+ print(*args, **kwargs)
+
+
+# ----------------------------------------------------------------------------
diff --git a/examples/weather/healda/utils/profiling.py b/examples/weather/healda/utils/profiling.py
new file mode 100644
index 0000000000..04c9530253
--- /dev/null
+++ b/examples/weather/healda/utils/profiling.py
@@ -0,0 +1,60 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import functools
+import os
+from contextlib import contextmanager, nullcontext
+
+import torch
+
+NVTX_ENABLED = os.environ.get("HEALDA_NVTX", "0") == "1"
+
+
+def nvtx(func=None, *, enabled: bool | None = None):
+ def decorator(fn):
+ use_nvtx = NVTX_ENABLED if enabled is None else enabled
+ if not use_nvtx:
+ return fn
+
+ tag = fn.__module__ + ":" + fn.__qualname__
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ torch.cuda.nvtx.range_push(tag)
+ out = fn(*args, **kwargs)
+ torch.cuda.nvtx.range_pop()
+ return out
+
+ return wrapper
+
+ if func is not None:
+ return decorator(func)
+ return decorator
+
+
+@contextmanager
+def _nvtx_range_impl(tag: str):
+ torch.cuda.nvtx.range_push(tag)
+ try:
+ yield
+ finally:
+ torch.cuda.nvtx.range_pop()
+
+
+def nvtx_range(tag: str, enabled: bool | None = None):
+ use_nvtx = NVTX_ENABLED if enabled is None else enabled
+ if use_nvtx:
+ return _nvtx_range_impl(tag)
+ return nullcontext()
diff --git a/examples/weather/healda/utils/signals.py b/examples/weather/healda/utils/signals.py
new file mode 100644
index 0000000000..d669aa3825
--- /dev/null
+++ b/examples/weather/healda/utils/signals.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for catching unix signals and gracefully exiting
+
+Usage:
+```
+from utils import signals
+import signal
+
+# now can catch signals with exceptions
+signal.signal(signal.SIGTERM, signals.handler)
+
+
+def do_stuff():
+ do_optional_stuff()
+ with signals.finish_before_quitting():
+ do_stuff_i_need_to_finish()
+
+try:
+ do_stuff()
+except signals.QuitEarly:
+ cleanup()
+
+```
+
+"""
+
+
+class QuitEarly(Exception):
+ """Exception raised to signal early termination from a signal handler."""
+
+ depth = 0
+ quit_requested = False
+
+
+def finish_before_quitting(func):
+ """If signal caught defer quitting until the wrapped line of code completes
+
+ Used to handle sensitive code blocks
+ """
+
+ def newfunc(*args, **kwargs):
+ QuitEarly.depth += 1
+ func(*args, **kwargs)
+ QuitEarly.depth -= 1
+
+ if QuitEarly.quit_requested and QuitEarly.depth == 0:
+ QuitEarly.quit_requested = False
+ raise QuitEarly()
+
+ return newfunc
+
+
+def handler(signum, frame):
+ """Signal handler that raises QuitEarly when safe to do so."""
+ if QuitEarly.depth == 0:
+ raise QuitEarly(signum, frame)
+ else:
+ QuitEarly.quit_requested = True
diff --git a/examples/weather/healda/utils/storage.py b/examples/weather/healda/utils/storage.py
new file mode 100644
index 0000000000..104f5c415a
--- /dev/null
+++ b/examples/weather/healda/utils/storage.py
@@ -0,0 +1,162 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import configparser
+import os
+import shutil
+import sys
+import tempfile
+
+import fsspec
+
+DEFAULT_PATH = os.path.expanduser("~/.config/rclone/rclone.conf")
+
+
+class StorageConfigError(Exception):
+ """Exception raised when storage configuration is invalid or missing."""
+
+ pass
+
+
+def get_remote_config(remote_name, config_path=DEFAULT_PATH):
+ """Parse rclone config and return the section for the given remote."""
+ if not remote_name:
+ return None
+ # Parse the rclone config file
+ config = configparser.ConfigParser()
+ config.read(config_path)
+
+ # Ensure the remote exists in the config
+ if remote_name not in config:
+ raise StorageConfigError(f"Remote '{remote_name}' not found in rclone config.")
+
+ # Extract credentials from the config
+ remote_config = config[remote_name]
+ return remote_config
+
+
+def get_storage_options(remote_name, config_path=DEFAULT_PATH):
+ """Return S3 storage options dict for fsspec from rclone remote config."""
+ remote_config = get_remote_config(remote_name, config_path)
+
+ if remote_config is None:
+ return None
+
+ if remote_config.get("type") != "s3":
+ raise StorageConfigError(f"Remote '{remote_name}' is not an S3 remote.")
+
+ access_key = remote_config.get("access_key_id")
+ secret_key = remote_config.get("secret_access_key")
+ endpoint_url = remote_config.get("endpoint", None) # Optional endpoint
+
+ if not access_key or not secret_key:
+ raise StorageConfigError(
+ f"Access key or secret key missing for remote '{remote_name}'."
+ )
+
+ # Instantiate and return the S3FileSystem object
+ return dict(
+ key=access_key,
+ secret=secret_key,
+ client_kwargs={"endpoint_url": endpoint_url} if endpoint_url else None,
+ )
+
+
+def get_polars_storage_options(profile):
+ """Return S3 storage options dict for Polars from rclone remote config."""
+ opts = get_storage_options(profile)
+ key = opts["key"]
+ secret = opts["secret"]
+ endpoint = opts["client_kwargs"]["endpoint_url"]
+ return {
+ "aws_access_key_id": key,
+ "aws_secret_access_key": secret,
+ "aws_endpoint_url": endpoint,
+ }
+
+
+def get_duckdb_connection(profile):
+ """Return a DuckDB connection configured with S3 credentials from rclone."""
+ import duckdb
+
+ opts = get_storage_options(profile)
+ con = duckdb.connect()
+ key = opts["key"]
+ secret = opts["secret"]
+ endpoint = opts["client_kwargs"]["endpoint_url"]
+ if endpoint.startswith("https://"):
+ endpoint = endpoint[len("https://") :]
+ con.execute(f"""
+ CREATE SECRET (
+ TYPE s3,
+ PROVIDER config,
+ ENDPOINT '{endpoint}',
+ KEY_ID '{key}',
+ SECRET '{secret}'
+ );
+ """)
+
+ return con
+
+
+def ensure_downloaded(url, local):
+ if os.path.exists(local):
+ return
+
+ fs = fsspec.filesystem("http")
+ print(f"Downloading from {url} to {local}", file=sys.stderr)
+ with tempfile.TemporaryDirectory() as d:
+ tmpfile = os.path.join(d, "file")
+ fs.get(url, tmpfile)
+ os.makedirs(os.path.dirname(local), exist_ok=True)
+ shutil.move(tmpfile, local)
+
+
+def _get_endpoint(opts):
+ return opts["client_kwargs"]["endpoint_url"]
+
+
+def get_pyarrow_filesystem(profile: str, **kwargs):
+ """Return a PyArrow S3FileSystem configured from rclone remote profile."""
+ import pyarrow.fs
+
+ opts = get_storage_options(profile)
+ if opts is None:
+ return None
+
+ return pyarrow.fs.S3FileSystem(
+ access_key=opts.get("key"),
+ secret_key=opts.get("secret"),
+ region=opts.get("region", ""),
+ endpoint_override=_get_endpoint(opts),
+ **kwargs,
+ )
+
+
+def get_obstore(profile: str, bucket=None, **kwargs):
+ """Return an obstore S3Store configured from rclone remote profile."""
+ from obstore.store import S3Store
+
+ opts = get_storage_options(profile)
+ if opts is None:
+ return None
+
+ return S3Store(
+ bucket=bucket,
+ access_key_id=opts.get("key"),
+ secret_access_key=opts.get("secret"),
+ endpoint=_get_endpoint(opts),
+ **kwargs,
+ )
diff --git a/examples/weather/healda/utils/visualizations.py b/examples/weather/healda/utils/visualizations.py
new file mode 100644
index 0000000000..f8b0a649b2
--- /dev/null
+++ b/examples/weather/healda/utils/visualizations.py
@@ -0,0 +1,274 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import types
+
+import cartopy.crs
+import cartopy.feature
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from earth2grid import healpix
+from matplotlib.colors import LogNorm
+
+# Parameters from the proj4 string
+central_longitude = -85.0
+central_latitude = 30.0
+standard_parallels = (23.0, 37.0)
+
+# Create the LambertConformalConic projection
+caribbean_se_us_proj = cartopy.crs.LambertConformal(
+ central_longitude=central_longitude,
+ central_latitude=central_latitude,
+ standard_parallels=standard_parallels,
+ globe=cartopy.crs.Globe(ellipse="WGS84"),
+)
+
+
+_projections = {
+ "PlateCarree": cartopy.crs.PlateCarree(),
+ "Robinson": cartopy.crs.Robinson(),
+ "Robinson_180": cartopy.crs.Robinson(180),
+ "conus": cartopy.crs.epsg(5069),
+ "south_pole": cartopy.crs.SouthPolarStereo(),
+ "north_pole": cartopy.crs.NorthPolarStereo(),
+ "carib": caribbean_se_us_proj,
+ "warm_pool": cartopy.crs.PlateCarree(),
+ "conus_large": cartopy.crs.PlateCarree(),
+}
+
+_extents = {
+ "carib": [-100, -60, 10, 40],
+ "warm_pool": [85, 160, -15, 20],
+ "conus_large": [-130, -60, 18, 50],
+}
+
+
+def get_lim(target_crs, extents):
+ """Return axis limits for the given CRS and optional lat/lon extents."""
+ if target_crs in [cartopy.crs.SouthPolarStereo(), cartopy.crs.NorthPolarStereo()]:
+ lim = 4_500_000
+ return -lim, lim, -lim, lim
+ elif extents:
+ return latlon_to_grid_extents(target_crs, extents)
+ else:
+ return target_crs.x_limits + target_crs.y_limits
+
+
+def create_regular_grid_in_projection(projection, nx, ny, extents=None):
+ """
+ Create a regular grid of lat-lon coordinates in a given Cartopy projection.
+
+ Parameters:
+ projection (cartopy.crs.Projection): The desired Cartopy projection
+ resolution (float): The grid resolution in projection units
+
+ Returns:
+ tuple: Two 2D arrays, one for latitudes and one for longitudes
+ """
+ # Get the projection's limits
+ x_min, x_max, y_min, y_max = get_lim(projection, extents)
+ # Create a regular grid in the projection coordinates
+ x = np.linspace(x_min, x_max, nx)
+ y = np.linspace(y_min, y_max, ny)
+ xx, yy = np.meshgrid(x, y)
+
+ # Transform the gridded coordinates back to lat-lon
+ geodetic = cartopy.crs.Geodetic()
+ transformed = geodetic.transform_points(projection, xx, yy)
+
+ lons = transformed[..., 0]
+ lats = transformed[..., 1]
+
+ # Filter out invalid points (those outside the projection's valid domain)
+ valid = np.logical_and(np.isfinite(lons), np.isfinite(lats))
+ lons[~valid] = np.nan
+ lats[~valid] = np.nan
+
+ return lats, lons, xx, yy
+
+
+def latlon_to_grid_extents(proj, extents):
+ lon_min, lon_max, lat_min, lat_max = extents
+ # Create arrays of lat-lon points
+ lons = np.array([lon_min, lon_max, lon_max, lon_min])
+ lats = np.array([lat_min, lat_min, lat_max, lat_max])
+
+ # Transform to projection coordinates
+ x, y = proj.transform_points(cartopy.crs.PlateCarree(), lons, lats).T[:2]
+
+ # Get the min and max x and y values
+ x_min, x_max = np.min(x), np.max(x)
+ y_min, y_max = np.min(y), np.max(y)
+
+ return x_min, x_max, y_min, y_max
+
+
+def visualize(
+ x,
+ region="Robinson",
+ projection=None,
+ extents=None,
+ nest=False,
+ hpxpad=False,
+ pos=None,
+ n=None,
+ title=None,
+ colorbar_label=None,
+ cmap=None,
+ coastlines_color="k",
+ add_colorbar=True,
+ extend=None,
+ nlat=256,
+ nlon=512,
+ border=False,
+ kind="pcolormesh",
+ interp_method="bilinear",
+ **kw,
+):
+ if x.ndim != 1:
+ raise ValueError(f"Expected 1D input but received {x.ndim}.")
+
+ if projection is None:
+ crs = _projections[region]
+ extents = _extents.get(region, None)
+ else:
+ crs = projection
+
+ if nest:
+ pixel_order = healpix.NEST
+ elif hpxpad:
+ pixel_order = healpix.HEALPIX_PAD_XY
+ else:
+ pixel_order = healpix.RING
+
+ hpx = healpix.Grid(healpix.npix2level(x.shape[-1]), pixel_order=pixel_order)
+ lat, lon, xx, yy = create_regular_grid_in_projection(
+ crs, nlat, nlon, extents=extents
+ )
+ cmap = plt.get_cmap(cmap, n)
+ mask = ~np.isnan(lat)
+ latm = lat[mask]
+ lonm = lon[mask]
+ x = torch.as_tensor(x)
+ out = torch.zeros_like(torch.tensor(lat)).to(x)
+ if interp_method == "bilinear":
+ regrid = hpx.get_bilinear_regridder_to(latm, lonm)
+ regrid.to(x)
+ out[mask] = regrid(x)
+ out[~mask] = torch.nan
+ elif interp_method == "bin":
+ pix = hpx.ang2pix(
+ torch.as_tensor(lonm, device=x.device),
+ torch.as_tensor(latm, device=x.device),
+ )
+ out[mask] = x[pix]
+ out[~mask] = torch.nan
+ else:
+ raise ValueError(f"Invalid interpolation method: {interp_method}")
+
+ if isinstance(pos, tuple):
+ subplot_args = pos
+ elif pos is not None:
+ subplot_args = (pos,)
+ else:
+ subplot_args = ()
+
+ ax = plt.subplot(*subplot_args, projection=crs)
+ plot_func = getattr(ax, kind)
+ im = plot_func(xx, yy, out.cpu(), transform=crs, cmap=cmap, **kw)
+ ax.coastlines(color=coastlines_color)
+
+ if border:
+ ax.add_feature(cartopy.feature.BORDERS, linewidth=1, edgecolor="lightgray")
+ ax.add_feature(cartopy.feature.STATES, linewidth=0.4, edgecolor="lightgray")
+
+ cb = None
+ if add_colorbar:
+ cb = plt.colorbar(im, orientation="horizontal", extend=extend)
+ if colorbar_label:
+ cb.set_label(colorbar_label)
+ if title:
+ ax.set_title(title)
+
+ return types.SimpleNamespace(ax=ax, im=im, cb=cb)
+
+
+def plot_in_lat_lon(
+ ax,
+ x,
+ lat,
+ lon,
+ lat_min,
+ lat_max,
+ lon_min,
+ lon_max,
+ pr,
+ regrid,
+ vmin,
+ vmax,
+ nest=False,
+ hpxpad=False,
+):
+ if regrid:
+ if x.ndim != 1:
+ raise ValueError(f"Expected 1D input but received {x.ndim}.")
+ if nest:
+ pixel_order = healpix.NEST
+ if hpxpad:
+ pixel_order = healpix.HEALPIX_PAD_XY
+ else:
+ pixel_order = healpix.RING
+ hpx = healpix.Grid(healpix.npix2level(x.shape[-1]), pixel_order=pixel_order)
+ mask = ~np.isnan(lat)
+ latm = lat[mask]
+ lonm = lon[mask]
+ x = torch.as_tensor(x)
+ regrid = hpx.get_bilinear_regridder_to(latm, lonm)
+ regrid.to(x)
+ data = torch.zeros_like(torch.tensor(lat)).to(x)
+ data[mask] = regrid(x)
+ data[~mask] = torch.nan
+ else:
+ data = x
+ lon_max = max(lon_max, np.min(np.max(lon, axis=1)))
+ ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=cartopy.crs.PlateCarree())
+ if pr:
+ data[data <= np.exp(-12)] = np.exp(-12)
+ vmin = np.exp(-12)
+ vmax = np.exp(-4)
+ mesh = ax.pcolormesh(
+ lon,
+ lat,
+ data,
+ transform=cartopy.crs.PlateCarree(),
+ norm=LogNorm(vmin=vmin, vmax=vmax),
+ cmap="magma",
+ )
+ else:
+ mesh = ax.pcolormesh(
+ lon,
+ lat,
+ data,
+ transform=cartopy.crs.PlateCarree(),
+ cmap="magma",
+ vmin=vmin,
+ vmax=vmax,
+ )
+ mesh.set_rasterized(True)
+ ax.coastlines(resolution="110m", linewidth=1, color="lightgray")
+ ax.add_feature(cartopy.feature.BORDERS, linewidth=1, edgecolor="lightgray")
+ ax.add_feature(cartopy.feature.STATES, linewidth=0.4, edgecolor="lightgray")
+ return mesh
diff --git a/physicsnemo/experimental/models/dit/dit.py b/physicsnemo/experimental/models/dit/dit.py
index 4d2d1c28da..dccfca4735 100644
--- a/physicsnemo/experimental/models/dit/dit.py
+++ b/physicsnemo/experimental/models/dit/dit.py
@@ -17,13 +17,18 @@
from typing import Tuple, Union, Optional, Literal, Dict, Any
import torch
import torch.nn as nn
-
-from physicsnemo.nn import PositionalEmbedding, Linear
from dataclasses import dataclass
from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
from physicsnemo.experimental.models.dit import DiTBlock
-from physicsnemo.experimental.models.dit.layers import get_tokenizer, get_detokenizer, TokenizerModuleBase, DetokenizerModuleBase
+from physicsnemo.experimental.models.dit.layers import (
+ get_tokenizer,
+ get_detokenizer,
+ get_conditioning_embedder,
+ TokenizerModuleBase,
+ DetokenizerModuleBase,
+ ConditioningEmbedderBase,
+)
@dataclass
@@ -106,6 +111,8 @@ class DiT(Module):
Additional keyword arguments to be passed to :class:`physicsnemo.nn.PositionalEmbedding`.
attn_kwargs (Dict[str, Any], optional):
Additional keyword arguments for the attention module constructor, if using a custom attention backend.
+ drop_path (float, optional):
+ DropPath rate for stochastic depth. Uses linear schedule from 0 to drop_path across blocks. Defaults to 0.0.
force_tokenization_fp32 (bool, optional):
If True, forces the tokenization and de-tokenization operations to be run in fp32. Defaults to False.
@@ -121,6 +128,8 @@ class DiT(Module):
The dropout probability for the intermediate dropout module (pre-attention) in the DiTBlock. If None, no dropout will be applied.
If a scalar, the same dropout probability will be applied to all samples in the batch.
Otherwise, it should be a tensor of shape (B,) to apply per-sample dropout to each sample in a batch.
+ tokenizer_kwargs (Optional[Dict[str, Any]]):
+ Additional keyword arguments passed to the tokenizer's forward method.
Returns
-------
@@ -164,12 +173,14 @@ def __init__(
attention_backend: Literal["timm", "transformer_engine", "natten2d"] = "transformer_engine",
layernorm_backend: Literal["apex", "torch"] = "torch",
condition_dim: Optional[int] = None,
+ conditioning_embedder: Union[Literal["post_mlp", "pre_mlp"], Module] = "post_mlp",
dit_initialization: Optional[int] = True,
+ conditioning_embedder_kwargs: Dict[str, Any] = {},
tokenizer_kwargs: Dict[str, Any] = {},
detokenizer_kwargs: Dict[str, Any] = {},
block_kwargs: Dict[str, Any] = {},
- timestep_embed_kwargs: Dict[str, Any] = {},
attn_kwargs: Dict[str, Any] = {},
+ drop_path: float = 0.0,
force_tokenization_fp32: bool = False,
):
super().__init__(meta=MetaData())
@@ -219,20 +230,19 @@ def __init__(
raise TypeError("tokenizer must be a string or a physicsnemo.core.Module instance subclassing physicsnemo.experimental.models.dit.layers.TokenizerModuleBase")
self.tokenizer = tokenizer
- self.t_embedder = PositionalEmbedding(hidden_size, amp_mode=self.meta.amp_gpu, learnable=True, **timestep_embed_kwargs)
- self.cond_embedder = (
- Linear(
- in_features=condition_dim,
- out_features=hidden_size,
- bias=False,
+ # Conditioning embedder: accept string or pre-instantiated Module
+ if isinstance(conditioning_embedder, str):
+ self.conditioning_embedder = get_conditioning_embedder(
+ hidden_size=hidden_size,
+ conditioning_embedder=conditioning_embedder,
+ condition_dim=condition_dim or 0,
amp_mode=self.meta.amp_gpu,
- init_mode="kaiming_uniform",
- init_weight=0,
- init_bias=0,
+ **conditioning_embedder_kwargs,
)
- if condition_dim
- else None
- )
+ else:
+ if not isinstance(conditioning_embedder, ConditioningEmbedderBase):
+ raise TypeError("conditioning_embedder must be a string or a physicsnemo.core.Module instance subclassing physicsnemo.experimental.models.dit.layers.ConditioningEmbedderBase")
+ self.conditioning_embedder = conditioning_embedder
# Detokenizer module: accept string or pre-instantiated PhysicsNeMo Module
if isinstance(detokenizer, str):
@@ -251,8 +261,11 @@ def __init__(
self.detokenizer = detokenizer
+ # Linear drop_path schedule: 0 -> drop_path
+ drop_path_rates = [drop_path * i / max(1, depth - 1) for i in range(depth)]
+
blocks = []
- for _ in range(depth):
+ for i in range(depth):
if isinstance(attention_backend, str):
attn_module = attention_backend
else:
@@ -266,6 +279,8 @@ def __init__(
attention_backend=attn_module,
layernorm_backend=layernorm_backend,
mlp_ratio=mlp_ratio,
+ drop_path=drop_path_rates[i],
+ condition_dim=self.conditioning_embedder.output_dim,
**block_kwargs,
**attn_kwargs,
)
@@ -299,31 +314,20 @@ def forward(
condition: Optional[torch.Tensor] = None,
p_dropout: Optional[float | torch.Tensor] = None,
attn_kwargs: Optional[Dict[str, Any]] = None,
+ tokenizer_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# Tokenize: (B, C, H, W) -> (B, L, D)
if self.force_tokenization_fp32:
dtype = x.dtype
x = x.to(torch.float32)
with torch.autocast(device_type="cuda", enabled=False):
- x = self.tokenizer(x)
+ x = self.tokenizer(x, **(tokenizer_kwargs or {}))
x = x.to(dtype)
else:
- x = self.tokenizer(x)
+ x = self.tokenizer(x, **(tokenizer_kwargs or {}))
- t = self.t_embedder(t) # (B, D)
-
- # Handle conditioning
- if self.cond_embedder is not None:
- if condition is None:
- # Fallback to using only timestep embedding if conditioning is not provided
- c = t
- else:
- condition_embedding = self.cond_embedder(condition) # (B, D)
- c = t + condition_embedding # (B, D)
- else:
- if condition is not None:
- raise ValueError("Conditioning was provided but DiT has no conditioning embedding module.")
- c = t # (B, D)
+ # Compute conditioning embedding
+ c = self.conditioning_embedder(t, condition=condition) # (B, D)
for block in self.blocks:
x = block(x, c, p_dropout=p_dropout, attn_kwargs={**self.attn_kwargs_forward, **(attn_kwargs or {})}) # (B, L, D)
diff --git a/physicsnemo/experimental/models/dit/layers.py b/physicsnemo/experimental/models/dit/layers.py
index e0db7e402c..25b618a241 100644
--- a/physicsnemo/experimental/models/dit/layers.py
+++ b/physicsnemo/experimental/models/dit/layers.py
@@ -13,7 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
import warnings
+from functools import partial
from typing import Any, Dict, Literal, Union, Optional, Tuple
from abc import ABC, abstractmethod
import torch
@@ -34,6 +36,7 @@
from timm.layers.attention import Attention
else:
from timm.models.vision_transformer import Attention
+from timm.layers import RmsNorm
try:
from transformer_engine.pytorch import MultiheadAttention
@@ -55,12 +58,12 @@
NATTEN_AVAILABLE = False
from physicsnemo.core import Module
-from physicsnemo.nn import Mlp
+from physicsnemo.nn import Mlp, PositionalEmbedding, Linear
from physicsnemo.domain_parallel import ShardTensor
from physicsnemo.domain_parallel.shard_utils.natten_patches import partial_na2d
+from physicsnemo.nn import DropPath
from physicsnemo.nn.module.utils import PatchEmbed2D
-
def get_layer_norm(
hidden_size: int,
layernorm_backend: Literal["apex", "torch"],
@@ -175,6 +178,11 @@ class TimmSelfAttention(AttentionModuleBase):
The dropout rate for the attention operation.
proj_drop_rate: float
The dropout rate for the projection operation.
+ qk_norm_type: str, optional
+ QK normalization type. Options: "RMSNorm", "LayerNorm", or None.
+ Translated to timm's ``qk_norm=True`` and ``norm_layer``.
+ qk_norm_affine: bool, optional
+ Whether QK normalization layers should use learnable affine parameters.
**kwargs: Any
Additional keyword arguments for the timm attention module.
@@ -191,8 +199,26 @@ class TimmSelfAttention(AttentionModuleBase):
torch.Tensor
Output tensor of shape (B, L, D).
"""
- def __init__(self, hidden_size: int, num_heads: int, attn_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, **kwargs: Any):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ attn_drop_rate: float = 0.0,
+ proj_drop_rate: float = 0.0,
+ qk_norm_type: Optional[Literal["RMSNorm", "LayerNorm"]] = None,
+ qk_norm_affine: bool = True,
+ **kwargs: Any,
+ ):
super().__init__()
+
+ # Translate qk_norm_type to timm's qk_norm and norm_layer
+ if qk_norm_type == "RMSNorm":
+ kwargs["qk_norm"] = True
+ kwargs["norm_layer"] = partial(RmsNorm, affine=qk_norm_affine)
+ elif qk_norm_type == "LayerNorm":
+ kwargs["qk_norm"] = True
+ kwargs["norm_layer"] = nn.LayerNorm
+
self.attn_op = Attention(dim=hidden_size, num_heads=num_heads, attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, qkv_bias=True, **kwargs)
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -220,6 +246,9 @@ class TESelfAttention(AttentionModuleBase):
The dropout rate for the attention operation.
proj_drop_rate: float
The dropout rate for the projection operation.
+ qkv_format: str, optional
+ Dimension format for Q/K/V tensors. Default ``"bshd"`` for batch-first layout.
+ Use ``"sbhd"`` for sequence-first layout.
**kwargs: Any
Additional keyword arguments for the transformer_engine attention module.
@@ -239,24 +268,46 @@ class TESelfAttention(AttentionModuleBase):
torch.Tensor
Output tensor of shape (B, L, D).
"""
- def __init__(self, hidden_size: int, num_heads: int, attn_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, **kwargs: Any):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ attn_drop_rate: float = 0.0,
+ proj_drop_rate: float = 0.0,
+ qkv_format: str = "bshd",
+ **kwargs: Any,
+ ):
super().__init__()
if not TE_AVAILABLE:
raise ImportError(
"Transformer Engine is not installed. Please install it with `pip install transformer-engine`."
)
- if proj_drop_rate > 0:
+ if "qk_norm_affine" in kwargs and not kwargs["qk_norm_affine"]:
warnings.warn(
- "Transformer Engine MultiheadAttention does not support projection dropout (proj_drop_rate > 0). "
- "The specified proj_drop_rate will be ignored."
+ "Transformer Engine does not support disabling affine parameters for QK norm. "
+ "Ignoring qk_norm_affine=False and using affine parameters.",
+ UserWarning,
+ stacklevel=2,
)
- self.attn_op = MultiheadAttention(hidden_size=hidden_size, num_attention_heads=num_heads, attention_dropout=attn_drop_rate, **kwargs)
+ kwargs.pop("qk_norm_affine", None)
+ self.attn_op = MultiheadAttention(
+ hidden_size=hidden_size,
+ num_attention_heads=num_heads,
+ attention_dropout=attn_drop_rate,
+ qkv_format=qkv_format,
+ **kwargs,
+ )
+ # TE doesn't support proj_drop natively, so we add it manually after the attention output
+ self.proj_drop = nn.Dropout(proj_drop_rate)
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, mask_type: Optional[str] = "no_mask") -> torch.Tensor:
if attn_mask is not None:
mask_type = "arbitrary"
- return self.attn_op(x, attention_mask=attn_mask, attn_mask_type=mask_type)
+ out = self.attn_op(x, attention_mask=attn_mask, attn_mask_type=mask_type)
+ return self.proj_drop(out)
+
+
class Natten2DSelfAttention(AttentionModuleBase):
@@ -507,11 +558,15 @@ def __init__(
num_heads: int,
attention_backend: Union[Literal["transformer_engine", "timm", "natten2d"], Module] = "transformer_engine",
layernorm_backend: Literal["apex", "torch"] = "torch",
+ norm_eps: float = 1e-6,
mlp_ratio: float = 4.0,
intermediate_dropout: bool = False,
attn_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
mlp_drop_rate: float = 0.0,
+ final_mlp_dropout: bool = True,
+ drop_path: float = 0.0,
+ condition_dim: Optional[int] = None,
**attn_kwargs: Any,
):
super().__init__()
@@ -529,10 +584,10 @@ def __init__(
)
self.pre_attention_norm = get_layer_norm(
- hidden_size, layernorm_backend, elementwise_affine=False, eps=1e-6
+ hidden_size, layernorm_backend, elementwise_affine=False, eps=norm_eps
)
self.pre_mlp_norm = get_layer_norm(
- hidden_size, layernorm_backend, elementwise_affine=False, eps=1e-6
+ hidden_size, layernorm_backend, elementwise_affine=False, eps=norm_eps
)
# Optional dropout/per-sample dropout module applied before attention
@@ -546,15 +601,19 @@ def __init__(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=lambda: nn.GELU(approximate="tanh"),
- drop=0,
+ drop=mlp_drop_rate,
+ final_dropout=final_mlp_dropout,
)
+ modulation_input_dim = hidden_size if condition_dim is None else condition_dim
self.adaptive_modulation = nn.Sequential(
- nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ nn.SiLU(), nn.Linear(modulation_input_dim, 6 * hidden_size, bias=True)
)
self.modulation = lambda x, scale, shift: x * (
1 + scale.unsqueeze(1)
) + shift.unsqueeze(1)
+ self.drop_path = DropPath(drop_path)
+
def initialize_weights(self):
# Zero out the adaptive modulation weights
nn.init.constant_(self.adaptive_modulation[-1].weight, 0)
@@ -592,14 +651,14 @@ def forward(
modulated_attn_input,
**(attn_kwargs or {}),
)
- x = x + attention_gate.unsqueeze(1) * attention_output
+ x = torch.addcmul(x, self.drop_path(attention_gate.unsqueeze(1)), attention_output)
# Feed-forward block
modulated_mlp_input = self.modulation(
self.pre_mlp_norm(x), mlp_scale, mlp_shift
)
mlp_output = self.linear(modulated_mlp_input)
- x = x + mlp_gate.unsqueeze(1) * mlp_output
+ x = torch.addcmul(x, self.drop_path(mlp_gate.unsqueeze(1)), mlp_output)
return x
@@ -982,3 +1041,262 @@ def get_detokenizer(
**detokenizer_kwargs,
)
raise ValueError("detokenizer must be 'proj_reshape_2d', no other supported detokenizers are available yet.")
+
+class ConditioningEmbedderBase(Module, ABC):
+ r"""Abstract base class for conditioning embedders used in DiT.
+
+ Computes conditioning embedding from timestep and optional additional inputs.
+ Implementations must define a forward method and specify their output dimension.
+
+ Forward
+ -------
+ t : torch.Tensor
+ Timestep tensor of shape :math:`(B,)`.
+ **kwargs
+ Additional conditioning inputs (e.g., condition, class_labels).
+
+ Returns
+ -------
+ torch.Tensor
+ Conditioning embedding of shape :math:`(B, D)` where D is ``output_dim``.
+ """
+
+ @property
+ @abstractmethod
+ def output_dim(self) -> int:
+ """Output dimension of conditioning embedding (used for block condition_dim)."""
+ pass
+
+ @abstractmethod
+ def forward(self, t: torch.Tensor, **kwargs) -> torch.Tensor:
+ pass
+
+
+class PostMLPConditionEmbedder(ConditioningEmbedderBase):
+ r"""Conditioning embedder that adds condition AFTER the MLP.
+
+ Architecture: ``t โ PositionalEmbedding โ MLP โ output``, then
+ ``condition โ Linear โ ADD`` (added to output).
+
+ Timestep and condition are processed independently and combined at the end.
+
+ Parameters
+ ----------
+ hidden_size : int
+ Output embedding dimension.
+ condition_dim : int, optional
+ Input condition dimension. If 0, no condition embedding is used.
+ max_positions : int, optional
+ Maximum positions for positional embedding. Default 10000.
+ learnable : bool, optional
+ Whether to use learnable MLP after positional embedding. Default True.
+ mlp_hidden_dim : int, optional
+ Hidden dimension of learnable MLP. Defaults to 2 * hidden_size.
+ amp_mode : bool, optional
+ Whether mixed-precision (AMP) training is enabled. Default False.
+
+ Forward
+ -------
+ t : torch.Tensor
+ Timestep tensor of shape :math:`(B,)`.
+ condition : torch.Tensor, optional
+ Condition tensor of shape :math:`(B, condition_dim)`.
+
+ Returns
+ -------
+ torch.Tensor
+ Conditioning embedding of shape :math:`(B, hidden_size)`.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ condition_dim: int = 0,
+ max_positions: int = 10000,
+ learnable: bool = True,
+ mlp_hidden_dim: int | None = None,
+ amp_mode: bool = False,
+ ):
+ super().__init__()
+ self._output_dim = hidden_size
+
+ self.t_embedder = PositionalEmbedding(
+ num_channels=hidden_size,
+ max_positions=max_positions,
+ learnable=learnable,
+ mlp_hidden_dim=mlp_hidden_dim,
+ amp_mode=amp_mode,
+ )
+
+ self.cond_embedder = (
+ Linear(
+ in_features=condition_dim,
+ out_features=hidden_size,
+ bias=False,
+ amp_mode=amp_mode,
+ init_mode="kaiming_uniform",
+ init_weight=0,
+ init_bias=0,
+ )
+ if condition_dim
+ else None
+ )
+
+ @property
+ def output_dim(self) -> int:
+ return self._output_dim
+
+ def forward(
+ self, t: torch.Tensor, condition: torch.Tensor | None = None, **kwargs
+ ) -> torch.Tensor:
+ c = self.t_embedder(t)
+
+ if self.cond_embedder is not None and condition is not None:
+ c = c + self.cond_embedder(condition)
+
+ return c
+
+
+class PreMLPConditionEmbedder(ConditioningEmbedderBase):
+ r"""Conditioning embedder that adds labels BEFORE the MLP.
+
+ Architecture: ``t โ PositionalEmbedding โ flip(sin/cos) โ ADD โ MLP โ output``
+ where labels are added before MLP processing.
+
+ Timestep and labels are combined before the MLP.
+
+ Note: The final SiLU is omitted here; consumers (AdaLN blocks) apply SiLU
+ before their modulation linear layer.
+
+ Parameters
+ ----------
+ emb_channels : int
+ Output embedding dimension (typically 4 * hidden_size).
+ noise_channels : int
+ Dimension of positional embedding (typically hidden_size).
+ label_dim : int, optional
+ Class label dimension. If 0, no label embedding. Default 0.
+ label_dropout : float, optional
+ Dropout probability for labels during training. Default 0.0.
+ legacy_label_bias : bool, optional
+ If ``True`` and ``label_dim`` is 0, add a legacy bias term matching the old
+ ``EmbedNoiseLabels`` behavior. Default ``False``.
+ max_positions : int, optional
+ Maximum positions for positional embedding. Default 10000.
+
+ Forward
+ -------
+ t : torch.Tensor
+ Timestep/noise_labels tensor of shape :math:`(B,)`.
+ condition : torch.Tensor, optional
+ Condition/class labels of shape :math:`(B, label_dim)`.
+
+ Returns
+ -------
+ torch.Tensor
+ Conditioning embedding of shape :math:`(B, emb_channels)`.
+ """
+
+ def __init__(
+ self,
+ emb_channels: int,
+ noise_channels: int,
+ label_dim: int = 0,
+ label_dropout: float = 0.0,
+ legacy_label_bias: bool = False,
+ max_positions: int = 10000,
+ **kwargs, # Accept and ignore extra kwargs (e.g., amp_mode) for compatibility
+ ):
+ super().__init__()
+ self._output_dim = emb_channels
+ self.label_dropout = label_dropout
+ self.legacy_label_bias = legacy_label_bias
+
+ self.map_noise = PositionalEmbedding(
+ num_channels=noise_channels,
+ max_positions=max_positions,
+ endpoint=True,
+ learnable=False, # No MLP here - added below
+ )
+
+ # Label embedding (added before MLP)
+ if label_dim > 0:
+ self.map_label = nn.Linear(label_dim, noise_channels)
+ elif legacy_label_bias:
+ # Preserve legacy bias-only behavior for label_dim=0.
+ self.map_label = nn.Linear(0, noise_channels, bias=True)
+ else:
+ self.map_label = None
+
+ # MLP: Linear โ SiLU โ Linear (no final SiLU - moved to AdaLN)
+ self.map_layer0 = nn.Linear(noise_channels, emb_channels)
+ self.map_layer1 = nn.Linear(emb_channels, emb_channels)
+
+ @property
+ def output_dim(self) -> int:
+ return self._output_dim
+
+ def forward(
+ self, t: torch.Tensor, condition: torch.Tensor | None = None, **kwargs
+ ) -> torch.Tensor:
+ # Positional embedding
+ emb = self.map_noise(t)
+
+ # Swap sin/cos order
+ emb = emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape)
+
+ # Add label embedding before MLP
+ if self.map_label is not None:
+ if condition is None and self.legacy_label_bias and self.map_label.in_features == 0:
+ emb = emb + self.map_label.bias
+ elif condition is not None:
+ tmp = condition
+ if self.training and self.label_dropout:
+ tmp = tmp * (
+ torch.rand([t.shape[0], 1], device=tmp.device) >= self.label_dropout
+ ).to(tmp.dtype)
+ emb = emb + self.map_label(tmp * math.sqrt(self.map_label.in_features))
+
+ # MLP
+ emb = torch.nn.functional.silu(self.map_layer0(emb))
+ emb = self.map_layer1(emb)
+
+ return emb
+
+
+def get_conditioning_embedder(
+ hidden_size: int,
+ conditioning_embedder: Literal["post_mlp", "pre_mlp"] = "post_mlp",
+ condition_dim: int = 0,
+ **embedder_kwargs: Any,
+) -> ConditioningEmbedderBase:
+ r"""Factory function to create conditioning embedders.
+
+ Parameters
+ ----------
+ hidden_size : int
+ The hidden size of the DiT model.
+ conditioning_embedder : Literal["post_mlp", "pre_mlp"]
+ The type of conditioning embedder to use.
+ Options:
+ - 'post_mlp': Maps the timestep and condition independently, then adds them together post-MLP.
+ - 'pre_mlp': Adds the timestep and condition together before the MLP.
+ condition_dim : int
+ Condition dimension. For 'post_mlp', this is input condition dim.
+ For 'pre_mlp', this is output emb_channels.
+ **embedder_kwargs
+ Additional keyword arguments for the embedder.
+ """
+ if conditioning_embedder == "post_mlp":
+ return PostMLPConditionEmbedder(
+ hidden_size=hidden_size,
+ condition_dim=condition_dim,
+ **embedder_kwargs,
+ )
+ if conditioning_embedder == "pre_mlp":
+ return PreMLPConditionEmbedder(
+ emb_channels=condition_dim,
+ noise_channels=embedder_kwargs.pop("noise_channels", hidden_size),
+ **embedder_kwargs,
+ )
+ raise ValueError("conditioning_embedder must be 'post_mlp' or 'pre_mlp'.")
diff --git a/physicsnemo/experimental/models/healda/__init__.py b/physicsnemo/experimental/models/healda/__init__.py
new file mode 100644
index 0000000000..9d230e609d
--- /dev/null
+++ b/physicsnemo/experimental/models/healda/__init__.py
@@ -0,0 +1,63 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Experimental HealDA models and layers.
+
+Warning: This module is experimental and APIs may change without notice.
+Per MOD-002a, new models start here before promotion to physicsnemo.models.
+"""
+from .dit_layers import (
+ HealDA,
+ HealDAMetaData,
+ convert_healda_state_dict,
+ map_healda_to_pnm_block_keys,
+)
+
+# Config types (model constructor params only)
+# Training configs (ModelConfigV1, ObsConfig) are in examples/weather/healda/config/
+from .config import (
+ ModelSensorConfig,
+ SensorEmbedderConfig,
+)
+
+# Data types
+from .types import (
+ Batch,
+ UnifiedObservation,
+ split_by_sensor,
+)
+
+# Domain
+from .domain import Domain, HealPixDomain
+
+# Embedding layers (FrequencyEmbedding, CalendarEmbedding are HealDA-specific)
+# For PositionalEmbedding and FourierEmbedding, use physicsnemo.nn directly
+from .embedding import (
+ CalendarEmbedding,
+ EmbedNoiseLabels,
+ FrequencyEmbedding,
+)
+
+# HEALPix tokenizer/detokenizer
+from .healpix_layers import (
+ HPXPatchDetokenizer,
+ HPXPatchTokenizer,
+)
+
+# Obs embedding
+from .point_embed import MultiSensorObsEmbedding
+from .scatter_aggregator import ScatterAggregator
+from .scatter_mean import scatter_mean
diff --git a/physicsnemo/experimental/models/healda/config.py b/physicsnemo/experimental/models/healda/config.py
new file mode 100644
index 0000000000..1e233a812d
--- /dev/null
+++ b/physicsnemo/experimental/models/healda/config.py
@@ -0,0 +1,61 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Model configuration dataclasses for HealDA.
+
+These are constructor parameter types used by the model.
+Training-specific configs (ModelConfigV1, ObsConfig) are in examples/weather/healda/.
+"""
+import dataclasses
+
+
+@dataclasses.dataclass(frozen=True)
+class ModelSensorConfig:
+ """Configuration for a single sensor type.
+
+ Parameters
+ ----------
+ sensor_id : int
+ Unique identifier for the sensor type.
+ nchannel : int
+ Number of channels for this sensor.
+ platform_ids : tuple[int, ...]
+ Tuple of platform IDs associated with this sensor.
+ """
+ sensor_id: int
+ nchannel: int
+ platform_ids: tuple[int, ...] # Use tuple since frozen=True
+
+
+@dataclasses.dataclass(frozen=True)
+class SensorEmbedderConfig:
+ """Configuration for sensor embedding module.
+
+ Parameters
+ ----------
+ embed_dim : int, optional, default=32
+ Initial tokenization dimension for observations.
+ meta_dim : int, optional, default=28
+ Dimension of static metadata features.
+ fusion_dim : int, optional, default=512
+ Dimension after sensor fusion.
+ use_channel_platform_embedding_table : bool, optional, default=False
+ Whether to use embedding tables for channel and platform IDs.
+ """
+ embed_dim: int = 32
+ meta_dim: int = 28
+ fusion_dim: int = 512
+ use_channel_platform_embedding_table: bool = False
diff --git a/physicsnemo/experimental/models/healda/dit_layers.py b/physicsnemo/experimental/models/healda/dit_layers.py
new file mode 100644
index 0000000000..82467cbf6d
--- /dev/null
+++ b/physicsnemo/experimental/models/healda/dit_layers.py
@@ -0,0 +1,537 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+HealDA model and checkpoint migration utilities.
+"""
+
+from dataclasses import dataclass
+from typing import Literal, Optional
+
+import torch
+
+from physicsnemo.core.meta import ModelMetaData
+from physicsnemo.core.module import Module
+from physicsnemo.experimental.models.dit import DiT
+
+from .config import ModelSensorConfig, SensorEmbedderConfig
+from .healpix_layers import HPXPatchDetokenizer, HPXPatchTokenizer
+from .point_embed import MultiSensorObsEmbedding
+from .types import UnifiedObservation
+
+
+@dataclass
+class HealDAMetaData(ModelMetaData):
+ """Metadata for HealDA model."""
+
+ jit: bool = False
+ cuda_graphs: bool = False
+ amp_cpu: bool = False
+ amp_gpu: bool = True
+ torch_fx: bool = False
+ bf16: bool = True
+ onnx: bool = False
+ func_torch: bool = False
+ auto_grad: bool = False
+
+
+class HealDA(Module):
+ r"""
+ HealDA DiT model that composes preprocessor + PNM experimental DiT.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input state channels.
+ out_channels : int
+ Number of output channels.
+ hidden_size : int, optional, default=1024
+ Transformer hidden dimension.
+ num_layers : int, optional, default=24
+ Number of transformer blocks.
+ num_heads : int, optional, default=16
+ Number of attention heads.
+ mlp_ratio : float, optional, default=4.0
+ MLP hidden dim multiplier.
+ level_in : int, optional, default=6
+ HEALPix input resolution level.
+ level_model : int, optional, default=5
+ HEALPix model resolution level after patching.
+ time_length : int, optional, default=1
+ Number of time steps.
+ sensor_embedder_config : SensorEmbedderConfig
+ Config for observation embedding.
+ sensors : dict[str, ModelSensorConfig]
+ Sensor configurations for obs embedding.
+ condition_channels : int, optional, default=2
+ Number of static input channels that go into tokenizer.
+ Tokenizer input = condition_channels + fusion_dim.
+ qk_norm_type : Literal["RMSNorm", "LayerNorm"], optional, default="RMSNorm"
+ QK normalization type. None disables QK normalization.
+ qk_norm_affine : bool, optional, default=True
+ Whether QK normalization layers use learnable affine parameters (timm backend only).
+ drop_path : float, optional, default=0.0
+ DropPath rate for stochastic depth.
+ dropout : float, optional, default=0.0
+ Dropout rate for projection and MLP layers.
+ condition_dim : int, optional, default=None
+ Conditioning embedding dimension. If None, runs as VIT (no conditioning).
+ If set, enables diffusion-style noise/label conditioning.
+ noise_channels : int, optional, default=1024
+ Channels for noise level positional embedding.
+ label_dim : int, optional, default=0
+ Dimension of class labels. 0 means no label conditioning.
+ label_dropout : float, optional, default=None
+ Dropout rate for labels during training.
+ attention_backend : str, optional, default="transformer_engine"
+ Attention backend to use.
+ layernorm_backend : str, optional, default="apex"
+ LayerNorm backend to use.
+
+ Forward
+ -------
+ x : torch.Tensor
+ Input state tensor of shape :math:`(B, C, T, N_{pix})`.
+ t : torch.Tensor
+ Timestep tensor of shape :math:`(B,)`.
+ unified_obs : UnifiedObservation
+ Observation data (required for obs-to-state DA).
+ second_of_day : torch.Tensor, optional
+ Second of day for calendar embedding.
+ day_of_year : torch.Tensor, optional
+ Day of year for calendar embedding.
+ noise_labels : torch.Tensor, optional
+ Noise levels for diffusion conditioning. Required when condition_dim is set.
+ class_labels : torch.Tensor, optional
+ Class labels for conditioning. Only used when condition_dim is set.
+
+ Outputs
+ -------
+ torch.Tensor
+ Output tensor of shape :math:`(B, C_{out}, T, N_{pix})`.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ sensor_embedder_config: SensorEmbedderConfig,
+ sensors: dict[str, ModelSensorConfig],
+ hidden_size: int = 1024,
+ num_layers: int = 24,
+ num_heads: int = 16,
+ mlp_ratio: float = 4.0,
+ level_in: int = 6,
+ level_model: int = 5,
+ time_length: int = 1,
+ condition_channels: int = 2, # Static input channels (e.g. lat/lon)
+ qk_norm_type: Optional[Literal["RMSNorm", "LayerNorm"]] = "RMSNorm",
+ qk_norm_affine: bool = True,
+ drop_path: float = 0.0,
+ dropout: float = 0.0,
+ condition_dim: Optional[int] = None,
+ noise_channels: int = 1024,
+ label_dim: int = 0,
+ label_dropout: Optional[float] = None,
+ norm_eps: float = 1e-5,
+ attention_backend: str = "transformer_engine",
+ layernorm_backend: str = "apex",
+ ):
+ super().__init__(meta=HealDAMetaData())
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.hidden_size = hidden_size
+ self.level_in = level_in
+ self.level_model = level_model
+ self.time_length = time_length
+ self.condition_channels = condition_channels
+ self.condition_dim = condition_dim
+ self.label_dim = label_dim
+ self.qk_norm_type = qk_norm_type
+ self.qk_norm_affine = qk_norm_affine
+ self.attention_backend = attention_backend
+
+ # Observation encoder (embeds obs and concatenates with state)
+ self.obs_embedder = MultiSensorObsEmbedding(
+ sensor_embedder_config=sensor_embedder_config,
+ sensors=sensors,
+ hpx_level=level_in,
+ )
+ self.fusion_dim = sensor_embedder_config.fusion_dim
+ # Tokenizer input: static condition channels + obs embedding (NOT full in_channels)
+ tokenizer_in_channels = condition_channels + self.fusion_dim
+
+ # Create tokenizer and detokenizer
+ tokenizer = HPXPatchTokenizer(
+ in_channels=tokenizer_in_channels,
+ hidden_size=hidden_size,
+ level_fine=level_in,
+ level_coarse=level_model,
+ )
+
+ detokenizer = HPXPatchDetokenizer(
+ hidden_size=hidden_size,
+ out_channels=out_channels,
+ level_coarse=level_model,
+ level_fine=level_in,
+ time_length=time_length,
+ condition_dim=condition_dim,
+ )
+
+ # Create PNM DiT with custom tokenizer/detokenizer
+ npix_coarse = 12 * 4 ** level_model
+ attn_kwargs = {"qk_norm_type": qk_norm_type} if qk_norm_type else {}
+ if qk_norm_type and attention_backend == "timm":
+ attn_kwargs["qk_norm_affine"] = qk_norm_affine
+ if attention_backend == "transformer_engine":
+ attn_kwargs["qkv_format"] = "bshd"
+
+ # HealDA used dropout after attention projection and in MLP, not on attention weights
+ block_kwargs = {
+ "proj_drop_rate": dropout,
+ "mlp_drop_rate": dropout,
+ "norm_eps": norm_eps,
+ "final_mlp_dropout": False,
+ }
+ # if attention_backend == "diffusers":
+ # block_kwargs["attn_drop_rate"] = dropout
+
+ self.dit = DiT(
+ input_size=(npix_coarse * time_length,),
+ in_channels=tokenizer_in_channels,
+ patch_size=(1,),
+ tokenizer=tokenizer,
+ detokenizer=detokenizer,
+ out_channels=out_channels,
+ hidden_size=hidden_size,
+ depth=num_layers,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ attention_backend=attention_backend,
+ layernorm_backend=layernorm_backend,
+ condition_dim=condition_dim,
+ conditioning_embedder="pre_mlp",
+ conditioning_embedder_kwargs={
+ "noise_channels": noise_channels,
+ "label_dim": label_dim,
+ "label_dropout": label_dropout,
+ "legacy_label_bias": True,
+ },
+ drop_path=drop_path,
+ attn_kwargs=attn_kwargs,
+ block_kwargs=block_kwargs,
+ )
+
+ def _maybe_reset_te_qk_norm_weights(self) -> None:
+ if (
+ self.attention_backend != "transformer_engine"
+ or self.qk_norm_type != "RMSNorm"
+ ):
+ return
+ print("Resetting TE QK RMSNorm weights")
+ with torch.no_grad():
+ for block in self.dit.blocks:
+ attn_op = block.attention.attn_op
+ if hasattr(attn_op, "q_norm") and getattr(attn_op.q_norm, "weight", None) is not None:
+ attn_op.q_norm.weight.fill_(1.0)
+ if hasattr(attn_op, "k_norm") and getattr(attn_op.k_norm, "weight", None) is not None:
+ attn_op.k_norm.weight.fill_(1.0)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ unified_obs: UnifiedObservation,
+ second_of_day: Optional[torch.Tensor] = None,
+ day_of_year: Optional[torch.Tensor] = None,
+ class_labels: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Forward pass.
+
+ Args:
+ x: Input state (B, C, T, npix)
+ t: Timestep/noise_labels (B,) - used for conditioning
+ unified_obs: Observation data (required)
+ second_of_day: Calendar info
+ day_of_year: Calendar info
+ class_labels: Class labels for conditioning (only used when label_dim > 0)
+
+ Returns:
+ Output (B, C_out, T, npix)
+ """
+ # Embed observations and concatenate with state
+ obs_emb = self.obs_embedder(unified_obs) # (B, fusion_dim, T, npix)
+ x = torch.cat([x, obs_emb], dim=1) # (B, C + fusion_dim, T, npix)
+
+ return self.dit(
+ x, t, condition=class_labels,
+ tokenizer_kwargs={"second_of_day": second_of_day, "day_of_year": day_of_year},
+ )
+
+ @classmethod
+ def from_healda_checkpoint(
+ cls,
+ checkpoint_path: str,
+ sensor_embedder_config: SensorEmbedderConfig,
+ sensors: dict[str, ModelSensorConfig],
+ level_in: int = 6,
+ level_model: int = 5,
+ device: str = "cuda",
+ attention_backend: str = "transformer_engine",
+ ) -> "HealDA":
+ """
+ Load a HealDA model from an old HealDA checkpoint.
+
+ Args:
+ checkpoint_path: Path to the .checkpoint file
+ sensor_embedder_config: Configuration for multi-sensor observation embedder
+ sensors: Dictionary mapping sensor names to their configurations
+ level_in: HEALPix input resolution level (default: 6)
+ level_model: HEALPix model resolution level (default: 5)
+ device: Device to load the model on
+
+ Returns:
+ Instantiated HealDA with weights loaded
+ """
+ import json
+ import zipfile
+
+ with zipfile.ZipFile(checkpoint_path, "r") as zf:
+ # Read config
+ with zf.open("model.json") as f:
+ config = json.load(f)
+
+ # Read state dict
+ with zf.open("net_state.pth") as f:
+ old_state_dict = torch.load(f, map_location=device, weights_only=True)
+
+ # Extract model config
+ hidden_size = 1024 # Default for dit-l
+ num_layers = 24
+ num_heads = 16
+ print(f"Checkpoint config: {config}")
+
+ # Determine condition_dim from as_vit flag
+ # as_vit=True means condition_dim=0 (VIT mode, bias-only adaptive modulation)
+ condition_dim = 0 if config.get("as_vit", False) else 4 * hidden_size
+
+ # Create model
+ model = cls(
+ in_channels=config.get("out_channels", 74),
+ out_channels=config.get("out_channels", 74),
+ sensor_embedder_config=sensor_embedder_config,
+ sensors=sensors,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ num_heads=num_heads,
+ level_in=level_in,
+ level_model=level_model,
+ time_length=config.get("time_length", 1),
+ condition_channels=config.get("condition_channels", 2),
+ qk_norm_type="RMSNorm" if config.get("qk_rms_norm", False) else None,
+ qk_norm_affine=False,
+ drop_path=config.get("drop_path", 0.0),
+ dropout=config.get("p_dropout", 0.0),
+ condition_dim=condition_dim,
+ noise_channels=config.get("noise_channels", hidden_size), # Default to inner_dim
+ label_dim=config.get("label_dim", 0),
+ label_dropout=config.get("label_dropout", 0.0),
+ norm_eps=config.get("norm_eps", 1e-5),
+ attention_backend=attention_backend,
+ )
+
+ # Convert and load state dict
+ new_state_dict = convert_healda_state_dict(
+ old_state_dict,
+ num_blocks=num_layers,
+ hidden_size=hidden_size,
+ condition_dim=condition_dim,
+ attention_backend=model.attention_backend,
+ )
+
+ # With condition_dim=0 (VIT mode), shapes now match [n, 0] directly
+ model.load_state_dict(new_state_dict, strict=False)
+ model._maybe_reset_te_qk_norm_weights()
+ return model.to(device)
+
+
+# Weight mapping utilities for checkpoint migration
+def map_healda_to_pnm_block_keys(
+ old_key: str,
+ block_idx: int,
+ attention_backend: Literal["transformer_engine", "timm", "natten2d"] = "transformer_engine",
+) -> str:
+ """
+ Map HealDA DiT block state dict key to PNM DiT block key.
+
+ Args:
+ old_key: Original key like 'transformer_blocks.0.attn1.to_q.weight'
+ block_idx: Block index
+
+ Returns:
+ New key like 'blocks.0.attention.attn_op.qkv.query_weight'
+ """
+ prefix = f"transformer_blocks.{block_idx}."
+ new_prefix = f"dit.blocks.{block_idx}."
+
+ if not old_key.startswith(prefix):
+ return old_key
+
+ suffix = old_key[len(prefix):]
+
+ # Mapping table
+ mappings = {
+ # AdaLN modulation
+ "norm1.linear.weight": "adaptive_modulation.1.weight",
+ "norm1.linear.bias": "adaptive_modulation.1.bias",
+ # MLP
+ "ff.net.0.proj.weight": "linear.layers.0.weight",
+ "ff.net.0.proj.bias": "linear.layers.0.bias",
+ "ff.net.2.weight": "linear.layers.3.weight",
+ "ff.net.2.bias": "linear.layers.3.bias",
+ }
+
+ if attention_backend == "timm":
+ mappings.update(
+ {
+ # Attention output projection
+ "attn1.to_out.0.weight": "attention.attn_op.proj.weight",
+ "attn1.to_out.0.bias": "attention.attn_op.proj.bias",
+ }
+ )
+ else:
+ mappings.update(
+ {
+ # Attention Q/K/V
+ "attn1.to_q.weight": "attention.attn_op.qkv.query_weight",
+ "attn1.to_q.bias": "attention.attn_op.qkv.query_bias",
+ "attn1.to_k.weight": "attention.attn_op.qkv.key_weight",
+ "attn1.to_k.bias": "attention.attn_op.qkv.key_bias",
+ "attn1.to_v.weight": "attention.attn_op.qkv.value_weight",
+ "attn1.to_v.bias": "attention.attn_op.qkv.value_bias",
+ # Attention output projection
+ "attn1.to_out.0.weight": "attention.attn_op.proj.weight",
+ "attn1.to_out.0.bias": "attention.attn_op.proj.bias",
+ }
+ )
+
+ if suffix in mappings:
+ return new_prefix + mappings[suffix]
+
+ return old_key
+
+
+def convert_healda_state_dict(
+ old_state_dict: dict,
+ num_blocks: int = 24,
+ hidden_size: int = 1024,
+ condition_dim: int = 0,
+ attention_backend: Literal["transformer_engine", "timm", "natten2d"] = "transformer_engine",
+) -> dict:
+ """
+ Convert HealDA DiT state dict to PNM DiT format.
+
+ Args:
+ old_state_dict: Original HealDA state dict
+ num_blocks: Number of transformer blocks
+ hidden_size: Model hidden dimension
+ condition_dim: Conditioning dimension. If 0, model is unconditional
+ and empty weights are expanded to zeros.
+
+ Returns:
+ New state dict compatible with PNM DiT
+ """
+ new_state_dict = {}
+ is_unconditional = condition_dim == 0
+
+ qkv_buffers: dict[int, dict[str, torch.Tensor]] = {}
+
+ for old_key, value in old_state_dict.items():
+ # Handle transformer blocks
+ if old_key.startswith("transformer_blocks."):
+ # Extract block index
+ parts = old_key.split(".")
+ block_idx = int(parts[1])
+
+ if attention_backend == "timm" and parts[2] == "attn1" and parts[3] in {"to_q", "to_k", "to_v"}:
+ qkv_buffers.setdefault(block_idx, {})[f"{parts[3]}.{parts[4]}"] = value
+ continue
+
+ new_key = map_healda_to_pnm_block_keys(
+ old_key,
+ block_idx,
+ attention_backend=attention_backend,
+ )
+ # VIT mode: [n, 0] weights are kept as-is (model now supports condition_dim=0)
+ new_state_dict[new_key] = value
+
+ # Handle obs embedder (embed_v2_patch -> obs_embedder)
+ elif old_key.startswith("embed_v2_patch."):
+ suffix = old_key[len("embed_v2_patch."):]
+ new_key = f"obs_embedder.{suffix}"
+ new_state_dict[new_key] = value
+
+ # Handle tokenizer (pos_embed)
+ elif old_key.startswith("pos_embed."):
+ suffix = old_key[len("pos_embed."):]
+ new_key = f"dit.tokenizer.{suffix}"
+ new_state_dict[new_key] = value
+
+ # Handle detokenizer (patch_decode)
+ elif old_key.startswith("patch_decode."):
+ suffix = old_key[len("patch_decode."):]
+ new_key = f"dit.detokenizer.{suffix}"
+ new_state_dict[new_key] = value
+
+ # Handle final projection (proj_out_1 -> detokenizer.adaptive_modulation.1)
+ elif old_key.startswith("proj_out_1."):
+ suffix = old_key[len("proj_out_1."):]
+ new_key = f"dit.detokenizer.adaptive_modulation.1.{suffix}"
+ # VIT mode: [n, 0] weights are kept as-is (model now supports condition_dim=0)
+ new_state_dict[new_key] = value
+
+ elif old_key.startswith("norm_out."):
+ suffix = old_key[len("norm_out."):]
+ new_key = f"dit.detokenizer.norm_out.{suffix}"
+ new_state_dict[new_key] = value
+
+ # Map noise_embed to dit.conditioning_embedder (not present in VIT mode)
+ elif old_key.startswith("noise_embed."):
+ if not is_unconditional:
+ suffix = old_key[len("noise_embed."):]
+ new_state_dict[f"dit.conditioning_embedder.{suffix}"] = value
+
+ else:
+ # Pass through other keys unchanged
+ new_state_dict[old_key] = value
+
+ if attention_backend == "timm":
+ for block_idx, parts in qkv_buffers.items():
+ q_weight = parts.get("to_q.weight")
+ k_weight = parts.get("to_k.weight")
+ v_weight = parts.get("to_v.weight")
+ q_bias = parts.get("to_q.bias")
+ k_bias = parts.get("to_k.bias")
+ v_bias = parts.get("to_v.bias")
+ if q_weight is None or k_weight is None or v_weight is None:
+ raise ValueError(f"Missing q/k/v weights for block {block_idx} while building timm qkv.")
+ qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
+ new_state_dict[f"dit.blocks.{block_idx}.attention.attn_op.qkv.weight"] = qkv_weight
+ if q_bias is not None and k_bias is not None and v_bias is not None:
+ qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0)
+ new_state_dict[f"dit.blocks.{block_idx}.attention.attn_op.qkv.bias"] = qkv_bias
+
+ return new_state_dict
diff --git a/physicsnemo/experimental/models/healda/domain.py b/physicsnemo/experimental/models/healda/domain.py
new file mode 100644
index 0000000000..95a3677f2a
--- /dev/null
+++ b/physicsnemo/experimental/models/healda/domain.py
@@ -0,0 +1,88 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Protocol
+
+import earth2grid
+
+
+class Domain(Protocol):
+ def numel(self) -> int:
+ size = 1
+ for n in self.shape():
+ size *= n
+ return size
+
+ def ndim(self) -> int:
+ return len(self.shape())
+
+ def shape(self) -> tuple[int]:
+ pass
+
+ @property
+ def img_resolution(self) -> int:
+ return max(self.shape())
+
+ # TODO add boundary/padding function
+
+
+class HealPixDomain(Domain):
+ """HEALPix domain."""
+
+ def __init__(self, hpx: earth2grid.healpix.Grid):
+ self._grid = hpx
+
+ def shape(self) -> tuple[int]:
+ return self._grid.shape
+
+ @property
+ def img_resolution(self):
+ return 2**self._grid.level
+
+
+class PatchedHealpixDomain(Domain):
+ """Patch of HEALPix domain"""
+
+ def __init__(self, hpx: earth2grid.healpix.Grid, patch_size: int = 128):
+ self._grid = hpx
+ self.patch_size = patch_size
+
+ def shape(self) -> tuple[int]:
+ return self._grid.shape
+
+ @property
+ def img_resolution(self):
+ return self.patch_size
+
+
+class Plane(Domain):
+ """2D rectangular grid domain."""
+
+ def __init__(self, nx, ny):
+ self.nx = nx
+ self.ny = ny
+
+ def shape(self):
+ return (self.ny, self.nx)
+
+
+class Ring(Domain):
+ """1D ring domain."""
+
+ def __init__(self, n):
+ self.n = n
+
+ def shape(self):
+ return (self.n,)
diff --git a/physicsnemo/experimental/models/healda/embedding.py b/physicsnemo/experimental/models/healda/embedding.py
new file mode 100644
index 0000000000..ea150d16f6
--- /dev/null
+++ b/physicsnemo/experimental/models/healda/embedding.py
@@ -0,0 +1,131 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import torch
+
+from physicsnemo.core.module import Module
+from physicsnemo.nn import PositionalEmbedding
+
+
+class FrequencyEmbedding(Module):
+ """Periodic Embedding.
+
+ Useful for inputs defined on the circle [0, 2pi)
+ """
+
+ def __init__(self, num_channels):
+ super().__init__()
+ self.register_buffer(
+ "freqs", torch.arange(1, num_channels + 1), persistent=False
+ )
+
+ def forward(self, x):
+ freqs = self.freqs[None, :, None, None]
+ x = x[:, None, :, :]
+ x = x * (2 * math.pi * freqs).to(x.dtype)
+ x = torch.cat([x.cos(), x.sin()], dim=1)
+ return x
+
+
+class CalendarEmbedding(Module):
+ """Time embedding assuming 365.25 day years
+
+ Args:
+ day_of_year: (n, t)
+ second_of_day: (n, t)
+ Returns:
+ (n, embed_channels * 4, t, x)
+
+ """
+
+ def __init__(self, lon, embed_channels: int, include_legacy_bug: bool = False):
+ """
+ Args:
+ include_legacy_bug: Provided for backwards compatibility
+ with existing checkpoints. If True, use the incorrect formula
+ for local_time (hour - lon) instead of the correct formula (hour + lon)
+ """
+ super().__init__()
+ self.register_buffer("lon", lon, persistent=False)
+ self.embed_channels = embed_channels
+ self.embed_second = FrequencyEmbedding(embed_channels)
+ self.embed_day = FrequencyEmbedding(embed_channels)
+ self.out_channels = embed_channels * 4
+ self.include_legacy_bug = include_legacy_bug
+
+ def forward(self, day_of_year, second_of_day):
+ if second_of_day.shape != day_of_year.shape:
+ raise ValueError()
+
+ if self.include_legacy_bug:
+ local_time = (second_of_day.unsqueeze(2) - self.lon * 86400 // 360) % 86400
+ else:
+ local_time = (second_of_day.unsqueeze(2) + self.lon * 86400 // 360) % 86400
+
+ a = self.embed_second(local_time / 86400)
+ doy = day_of_year.unsqueeze(2)
+ b = self.embed_day((doy / 365.25) % 1)
+ a, b = torch.broadcast_tensors(a, b)
+ return torch.concat([a, b], dim=1) # (n c x)
+
+
+class EmbedNoiseLabels(Module):
+ """Embedding layer for noise levels and class labels."""
+
+ def __init__(
+ self,
+ emb_channels,
+ label_dim,
+ noise_channels,
+ label_dropout=None,
+ legacy_label_bias: bool = False,
+ ):
+ super().__init__()
+ self.label_dropout = label_dropout
+ self.map_noise = PositionalEmbedding(num_channels=noise_channels, endpoint=True)
+
+ # legacy_label_bias: for loading old checkpoints that had Linear(0, noise_channels)
+ # which contributed a trained bias even with label_dim=0
+ self.map_label = None
+ if label_dim != 0 or legacy_label_bias:
+ self.map_label = torch.nn.Linear(label_dim, noise_channels)
+
+ self.map_layer0 = torch.nn.Linear(
+ in_features=noise_channels, out_features=emb_channels
+ )
+ self.map_layer1 = torch.nn.Linear(
+ in_features=emb_channels, out_features=emb_channels
+ )
+
+ def forward(self, noise_labels, class_labels):
+ emb = self.map_noise(noise_labels)
+ emb = (
+ emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape)
+ ) # swap sin/cos
+
+ if self.map_label is not None:
+ tmp = class_labels
+ if self.training and self.label_dropout:
+ tmp = tmp * (
+ torch.rand([noise_labels.shape[0], 1], device=tmp.device)
+ >= self.label_dropout
+ ).to(tmp.dtype)
+ emb = emb + self.map_label(tmp * math.sqrt(self.map_label.in_features))
+
+ emb = torch.nn.functional.silu(self.map_layer0(emb))
+ emb = self.map_layer1(emb) # No SiLU - consumers (AdaLN) add SiLU before modulation linear
+ return emb
diff --git a/physicsnemo/experimental/models/healda/healpix_layers.py b/physicsnemo/experimental/models/healda/healpix_layers.py
new file mode 100644
index 0000000000..197d5e35b1
--- /dev/null
+++ b/physicsnemo/experimental/models/healda/healpix_layers.py
@@ -0,0 +1,225 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""HEALPix tokenization and detokenization layers for HealDA."""
+
+from typing import Optional
+
+import earth2grid
+import earth2grid.healpix
+import einops
+import torch
+import torch.nn as nn
+
+from physicsnemo.experimental.models.dit.layers import (
+ DetokenizerModuleBase,
+ TokenizerModuleBase,
+)
+
+from .embedding import CalendarEmbedding
+
+
+class HPXPatchTokenizer(TokenizerModuleBase):
+ r"""
+ HEALPix patch tokenizer for DiT integration.
+
+ Folds 12 HEALPix faces into batch, applies conv, unfolds, adds global pos_embed + calendar_embed.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels.
+ hidden_size : int
+ Number of output embedding channels.
+ level_fine : int
+ HEALPix resolution level of input data.
+ level_coarse : int
+ HEALPix resolution level after patch embedding (model level).
+
+ Forward
+ -------
+ x : torch.Tensor
+ Input tensor of shape :math:`(B, C, T, N_{pix})` where :math:`N_{pix} = 12 \\times 4^{level\\_fine}`.
+ second_of_day : torch.Tensor, optional
+ Second of day for calendar embedding.
+ day_of_year : torch.Tensor, optional
+ Day of year for calendar embedding.
+
+ Outputs
+ -------
+ torch.Tensor
+ Output tensor of shape :math:`(B, L, D)` where :math:`L = T \\times 12 \\times patches\\_per\\_face`.
+ """
+
+ pixel_order = earth2grid.healpix.HEALPIX_PAD_XY
+
+ def __init__(
+ self,
+ *,
+ in_channels: int,
+ hidden_size: int,
+ level_fine: int,
+ level_coarse: int,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_size = hidden_size
+ self.level_fine = level_fine
+ self.level_coarse = level_coarse
+ self.nside = 2**level_fine
+ self.patch_size = 2 ** (level_fine - level_coarse)
+
+ # Patch embedding conv
+ self.conv = nn.Conv2d(
+ in_channels, hidden_size,
+ kernel_size=self.patch_size, stride=self.patch_size,
+ )
+
+ # Global positional embedding across all 12 faces
+ npix_coarse = 12 * 4**level_coarse
+ self.pos_embed = nn.Parameter(torch.randn(npix_coarse, hidden_size))
+
+ # Calendar embedding (HEALPix-specific: incorporates longitude)
+ grid = earth2grid.healpix.Grid(level=level_coarse, pixel_order=self.pixel_order)
+ lon = torch.as_tensor(grid.lon)
+ if hidden_size % 4 != 0:
+ raise ValueError(f"hidden_size must be divisible by 4, got {hidden_size}")
+ self.calendar_embed = CalendarEmbedding(lon, hidden_size // 4).float()
+
+ def initialize_weights(self) -> None:
+ w = self.conv.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.conv.bias, 0)
+ nn.init.normal_(self.pos_embed, std=0.02)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ second_of_day: Optional[torch.Tensor] = None,
+ day_of_year: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ b, c, t, npix = x.shape
+
+ # 1. Fold faces into batch: (B, C, T, 12*nsideยฒ) -> (B*T*12, C, nside, nside)
+ x = einops.rearrange(x, "b c t (f x y) -> (b t f) c x y", f=12, x=self.nside, y=self.nside)
+
+ # 2. Conv: (B*T*12, C, nside, nside) -> (B*T*12, D, nside_c, nside_c)
+ x = self.conv(x)
+
+ # 3. Flatten + unfold: (B*T*12, D, n_c, n_c) -> (B, T*12*n_c*n_c, D)
+ x = einops.rearrange(x, "(b t f) d x y -> b (t f x y) d", b=b, t=t, f=12)
+
+ # 4. Add global positional embedding
+ x = x + self.pos_embed
+
+ # 5. Add calendar embedding
+ if second_of_day is not None and day_of_year is not None:
+ calendar_emb = self.calendar_embed(second_of_day=second_of_day, day_of_year=day_of_year)
+ calendar_emb = einops.rearrange(calendar_emb, "b d t x -> b (t x) d")
+ x = x + calendar_emb
+
+ return x
+
+
+class HPXPatchDetokenizer(DetokenizerModuleBase):
+ r"""
+ HEALPix patch detokenizer for DiT integration.
+
+ Applies final AdaLN modulation and conv transpose to upsample patches.
+
+ Parameters
+ ----------
+ hidden_size : int
+ Input embedding dimension.
+ out_channels : int
+ Number of output channels.
+ level_coarse : int
+ HEALPix resolution level of input patches.
+ level_fine : int
+ HEALPix resolution level of output data.
+ time_length : int, optional, default=1
+ Number of time steps.
+
+ Forward
+ -------
+ x : torch.Tensor
+ Input tensor of shape :math:`(B, L, D)`.
+ c : torch.Tensor
+ Conditioning tensor of shape :math:`(B, D)`. Pass zeros for VIT mode.
+
+ Outputs
+ -------
+ torch.Tensor
+ Output tensor of shape :math:`(B, C_{out}, T, N_{pix})`.
+ """
+
+ def __init__(
+ self,
+ *,
+ hidden_size: int,
+ out_channels: int,
+ level_coarse: int,
+ level_fine: int,
+ time_length: int = 1,
+ condition_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.out_channels = out_channels
+ self.level_coarse = level_coarse
+ self.level_fine = level_fine
+ self.time_length = time_length
+ self.nside_coarse = 2**level_coarse
+ self.patch_size = 2 ** (level_fine - level_coarse)
+
+ # AdaLN: c -> (shift, scale)
+ modulation_input_dim = hidden_size if condition_dim is None else condition_dim
+ self.adaptive_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(modulation_input_dim, 2 * hidden_size),
+ )
+ self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ # Conv transpose for upsampling
+ self.conv_t = nn.ConvTranspose2d(
+ hidden_size, out_channels,
+ kernel_size=self.patch_size, stride=self.patch_size,
+ )
+
+ def initialize_weights(self) -> None:
+ nn.init.constant_(self.adaptive_modulation[-1].weight, 0)
+ nn.init.constant_(self.adaptive_modulation[-1].bias, 0)
+
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+ b = x.shape[0]
+ t = self.time_length
+ n = self.nside_coarse
+
+ # 1. Unflatten: (B, L, D) -> (B, T, 12*n*n, D)
+ x = einops.rearrange(x, "b (t f x y) d -> b t (f x y) d", t=t, f=12, x=n, y=n)
+
+ # 2. AdaLN: norm(x) * (1 + scale) + shift
+ shift, scale = self.adaptive_modulation(c).chunk(2, dim=-1)
+ x = self.norm_out(x) * (1 + scale[:, None, None, :]) + shift[:, None, None, :]
+
+ # 3. Fold faces: (B, T, 12*n*n, D) -> (B*T*12, D, n, n)
+ x = einops.rearrange(x, "b t (f x y) d -> (b t f) d x y", f=12, x=n, y=n)
+
+ # 4. Conv transpose
+ x = self.conv_t(x)
+
+ # 5. Unfold: (B*T*12, C, nside_fine, nside_fine) -> (B, C, T, npix)
+ x = einops.rearrange(x, "(b t f) c x y -> b c t (f x y)", f=12, b=b, t=t)
+ return x
diff --git a/physicsnemo/experimental/models/healda/point_embed.py b/physicsnemo/experimental/models/healda/point_embed.py
new file mode 100644
index 0000000000..0677883eac
--- /dev/null
+++ b/physicsnemo/experimental/models/healda/point_embed.py
@@ -0,0 +1,464 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Multi-sensor observation embedding for HealDA."""
+
+import logging
+import math
+
+import earth2grid
+import torch
+
+from physicsnemo.core.module import Module
+
+from .config import ModelSensorConfig, SensorEmbedderConfig
+from .scatter_aggregator import ScatterAggregator
+from .types import UnifiedObservation, split_by_sensor
+
+
+def _prod(shape):
+ out = 1
+ for s in shape:
+ out *= s
+ return out
+
+
+GLOBAL_MAX_CHANNELS = 1024
+GLOBAL_MAX_PLATFORM = 1024
+
+logger = logging.getLogger(__name__)
+
+
+class ObsTokenizer(Module):
+ """Tokenizes individual observations using metadata + measurement + embedding tables into feature tokens.
+
+ This creates intermediate token representations that will be aggregated and projected to final embeddings.
+
+ Args:
+ meta_dim: Dimension of static metadata features
+ out_dim: Output token dimension
+ platform_id_map: Tensor mapping global platform IDs to local indices
+ n_embed: Size of observation type embedding table
+ nchannel: Max number of channels (for channel embedding table)
+ nplatform: Max number of platforms (for platform embedding table)
+ embed_dim: Dimension of observation type embeddings
+ use_channel_platform_embedding_table: Use channel and platform embedding tables
+ """
+
+ def __init__(
+ self,
+ meta_dim: int,
+ out_dim: int,
+ platform_id_map: torch.tensor,
+ n_embed: int = 1024,
+ nchannel: int = 1024,
+ nplatform: int = 1024,
+ embed_dim: int = 4,
+ use_channel_platform_embedding_table: bool = True,
+ ):
+ super().__init__()
+
+ if nchannel > GLOBAL_MAX_CHANNELS or nplatform > GLOBAL_MAX_PLATFORM:
+ raise ValueError(
+ f"nchannel {nchannel} or nplatform {nplatform} is greater than the global max {GLOBAL_MAX_CHANNELS} or {GLOBAL_MAX_PLATFORM}"
+ )
+
+ self.use_channel_platform_embedding_table = use_channel_platform_embedding_table
+ if self.use_channel_platform_embedding_table:
+ self.channel_embedding = torch.nn.Embedding(GLOBAL_MAX_CHANNELS, embed_dim)
+ self.platform_embedding = torch.nn.Embedding(GLOBAL_MAX_PLATFORM, embed_dim)
+ self.embed_table = torch.nn.Embedding(n_embed, embed_dim)
+
+ self.register_buffer("platform_id_map", platform_id_map)
+
+ mlp_in_dim = (
+ 1
+ + meta_dim
+ + embed_dim * (3 if self.use_channel_platform_embedding_table else 1)
+ )
+ mlp_out_dim = out_dim - 1
+ hidden_dim = out_dim * 2 if out_dim <= 32 else out_dim
+
+ self.meta_mlp = torch.nn.Sequential(
+ torch.nn.Linear(mlp_in_dim, hidden_dim),
+ torch.nn.LayerNorm(hidden_dim),
+ torch.nn.SiLU(),
+ torch.nn.Linear(hidden_dim, mlp_out_dim),
+ )
+
+ def forward(self, obs: UnifiedObservation) -> torch.Tensor:
+ """
+ Tokenize observations into feature tokens.
+
+ Args:
+ obs: UnifiedObservation containing observations and metadata
+
+ Returns:
+ (nobs, out_dim)
+ """
+ # Extract columns from int_metadata (n_obs, 6)
+ channel_id = obs.int_metadata[:, obs.bucket_index.local_channel]
+ platform_id_global = obs.int_metadata[:, obs.bucket_index.platform]
+ obs_type_id = obs.int_metadata[:, obs.bucket_index.obs_type]
+ embed_vec = self.embed_table(obs_type_id)
+ if self.use_channel_platform_embedding_table:
+ channel_emb = self.channel_embedding(channel_id)
+ local_platform_id = self.platform_id_map[platform_id_global]
+ platform_emb = self.platform_embedding(local_platform_id)
+
+ x_in = torch.cat(
+ [
+ obs.obs.unsqueeze(-1),
+ obs.float_metadata,
+ embed_vec,
+ *(
+ [channel_emb, platform_emb]
+ if self.use_channel_platform_embedding_table
+ else []
+ ),
+ ],
+ dim=-1,
+ )
+ mlp_out = self.meta_mlp(x_in)
+ encoded = torch.cat([obs.obs.unsqueeze(-1), mlp_out], dim=-1)
+ return encoded
+
+
+class UniformFusion(Module):
+ """
+ Uniform weighting across all sensors with normalization for number of sensors.
+
+ Simple averaging with 1/sqrt(N) scaling to maintain variance.
+ """
+
+ def __init__(self, fusion_dim: int = 256):
+ super().__init__()
+ self.fusion_dim = fusion_dim
+ self.norm = torch.nn.LayerNorm(self.fusion_dim)
+
+ def forward(
+ self, projected: torch.Tensor, sensor_ids: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Args:
+ projected: (num_sensors, ..., fusion_dim)
+ sensor_ids: (num_sensors,) - not used, for API consistency
+ Returns:
+ (..., fusion_dim)
+ """
+ num_sensors = projected.shape[0]
+
+ projected = self.norm(projected)
+ return projected.sum(dim=0) / math.sqrt(num_sensors)
+
+
+class SensorEmbedder(Module):
+ """Unified sensor embedding for any observation source (satellite, conventional, etc.).
+
+ Pipeline:
+ 1. Per-obs tokenization via MLP
+ 2. Scatter aggregation
+ 3. Final projection to output_dim
+
+ Args:
+ platform_ids: List of global platform IDs for this sensor
+ sensor_embed_dim: Internal feature dimension
+ output_dim: Final output dimension of a sensor embedding
+ meta_dim: Dimension of static metadata features
+ hpx_level: HEALPix grid level
+ n_embed: Size of observation type embedding table
+ embed_dim: Dimension of observation type embeddings
+ nchannel: Max number of channels
+ use_checkpoint: Apply gradient checkpointing
+ """
+
+ def __init__(
+ self,
+ platform_ids: list[int],
+ sensor_embed_dim: int = 32,
+ output_dim: int = 256,
+ meta_dim: int = 32,
+ hpx_level: int = 6,
+ # Embedding table config
+ n_embed: int = 1024, # Large sparse table for observation types
+ embed_dim: int = 4,
+ nchannel: int = 1024, # Max channels
+ use_checkpoint: bool = False,
+ use_channel_platform_embedding_table: bool = True,
+ ):
+ super().__init__()
+
+ platform_id_map_size = GLOBAL_MAX_PLATFORM + 1
+ # Map global platform IDs to local indices for embedding lookup
+ if len(platform_ids) == 0:
+ # Platform-agnostic sensor (e.g., conv): all platforms map to index 0
+ # Use a map large enough to cover all possible platform IDs
+ self.register_buffer(
+ "platform_id_map",
+ torch.zeros(
+ platform_id_map_size, dtype=torch.long
+ ), # All platforms โ 0
+ )
+ nplatform = 1 # Single embedding for all platforms
+ else:
+ # Normal sensor: create lookup map for specific platforms
+ self.register_buffer(
+ "platform_id_map",
+ torch.full((platform_id_map_size,), -1, dtype=torch.long),
+ )
+ for local_idx, global_id in enumerate(platform_ids):
+ self.platform_id_map[global_id] = local_idx
+ nplatform = len(platform_ids)
+
+ self.sensor_embed_dim = sensor_embed_dim
+ self.output_dim = output_dim
+ self.hpx_level = hpx_level
+ self.npix = 12 * 4**hpx_level
+ self.use_checkpoint = use_checkpoint
+ self.nchannel = nchannel
+ self.nplatform = nplatform
+
+ self.obs_tokenizer = ObsTokenizer(
+ meta_dim=meta_dim,
+ out_dim=sensor_embed_dim,
+ platform_id_map=self.platform_id_map,
+ n_embed=n_embed,
+ nchannel=nchannel,
+ nplatform=nplatform,
+ embed_dim=embed_dim,
+ use_channel_platform_embedding_table=use_channel_platform_embedding_table,
+ )
+
+ # Aggregation setup - outputs (nbatch, npix, output_dim)
+ self.scatter_infill_aggregator = ScatterAggregator(
+ in_dim=sensor_embed_dim,
+ out_dim=output_dim,
+ nchannel=nchannel,
+ nplatform=nplatform,
+ npix=self.npix,
+ )
+
+ def aggregate(
+ self,
+ embedded_obs: torch.Tensor,
+ obs: UnifiedObservation,
+ batch_idx: torch.Tensor,
+ nbatch: int,
+ ) -> torch.Tensor:
+ """
+ Aggregate observations to spatial grid and project to output dimension.
+
+ Args:
+ embedded_obs: (nobs, sensor_embed_dim) tokenized observations
+ obs: UnifiedObservation
+ batch_idx: (nobs,) batch index for each obs
+ nbatch: product of batch dimensions
+
+ Returns:
+ (nbatch, npix, output_dim) aggregated spatial grid in HEALPIX_PAD_XY format
+ """
+ obs_pix = obs.int_metadata[:, obs.bucket_index.pix]
+ channel = obs.int_metadata[:, obs.bucket_index.local_channel]
+ platform_global = obs.int_metadata[:, obs.bucket_index.platform]
+ platform = self.platform_id_map[platform_global]
+
+ # Convert observation pixels to aggregator grid resolution
+ pix = obs_pix // int(4.0 ** (obs.hpx_level - self.hpx_level))
+
+ # Build combined bucket ID
+ bucket_id = platform * self.nchannel + channel
+ return self.scatter_infill_aggregator(
+ obs_features=embedded_obs,
+ batch_idx=batch_idx,
+ pix=pix,
+ bucket_id=bucket_id,
+ nbatch=nbatch,
+ )
+
+ def _forward(self, obs: UnifiedObservation):
+ batch_dims = obs.batch_dims # () if offsets is None, (B, T) otherwise
+ nbatch = _prod(batch_dims) # 1 if batch_dims==(), B*T otherwise
+
+ embedded_obs = self.obs_tokenizer(obs)
+
+ batch_idx = obs.batch_idx
+
+ # Aggregator handles empty batches internally to keep all parameters in the computation graph
+ output = self.aggregate(
+ embedded_obs, obs, batch_idx, nbatch
+ ) # NEST (nbatch, npix, output_dim)
+
+ if len(batch_dims) == 0:
+ output = output.view(self.npix, self.output_dim)
+ else:
+ output = output.view(*batch_dims, self.npix, self.output_dim)
+
+ return output
+
+ def forward(self, obs: UnifiedObservation) -> torch.Tensor:
+ """
+ Embed observations from a single sensor onto a spatial grid.
+
+ Args:
+ obs: UnifiedObservation for a single sensor. Observations are flattened
+ across batch/time dimensions; `obs.offsets` defines the structure.
+
+ Returns:
+ If offsets=None: (npix, output_dim) - single spatial grid in NEST order
+ If offsets present: (*offsets.shape, npix, output_dim) - grid in NEST order
+ e.g., (batch, time, npix, output_dim) if offsets is (batch, time)
+ """
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(
+ self._forward, obs, use_reentrant=False
+ )
+ else:
+ return self._forward(obs)
+
+
+class MultiSensorObsEmbedding(Module):
+ r"""
+ Multi-sensor observation embedding.
+
+ Embeds observations from multiple sensor types into a unified representation on a HEALPix grid.
+
+ Parameters
+ ----------
+ sensor_embedder_config : SensorEmbedderConfig
+ Configuration with embedding hyperparameters.
+ sensors : dict[str, ModelSensorConfig]
+ Dictionary mapping sensor names to their configurations.
+ hpx_level : int
+ HEALPix grid level for all sensors.
+ use_checkpoint : bool, optional, default=False
+ If True, applies gradient checkpointing to reduce memory usage.
+ compile : bool, optional, default=True
+ If True, compiles the embedding function for improved performance.
+
+ Forward
+ -------
+ obs : UnifiedObservation
+ Unified observation data containing measurements from multiple sensors.
+
+ Outputs
+ -------
+ torch.Tensor
+ Embedded observations of shape :math:`(B, T, N_{pix}, D)` where :math:`B` is batch size,
+ :math:`T` is time steps, :math:`N_{pix}` is number of HEALPix pixels, and :math:`D`
+ is embedding dimension.
+ """
+
+ def __init__(
+ self,
+ sensor_embedder_config: SensorEmbedderConfig,
+ sensors: dict[str, ModelSensorConfig],
+ hpx_level: int,
+ use_checkpoint: bool = False,
+ compile: bool = True,
+ ):
+ super().__init__()
+
+ # Store config values
+ self.sensors = sensors
+ self.sensor_names = list(self.sensors.keys())
+ self.sensor_ids = [cfg.sensor_id for cfg in self.sensors.values()]
+ self.fusion_dim = sensor_embedder_config.fusion_dim
+ self.use_channel_platform_embedding_table = (
+ sensor_embedder_config.use_channel_platform_embedding_table
+ )
+ self.hpx_level = hpx_level
+ self.npix = 12 * 4**hpx_level
+
+ # src grid of sensor embeddings
+ self.grid = earth2grid.healpix.Grid(
+ hpx_level, pixel_order=earth2grid.healpix.NEST
+ )
+
+ embed_cfg = sensor_embedder_config
+ # Separate embedders for each sensor.
+ self.embedder = torch.nn.ModuleDict(
+ {
+ str(sensor_cfg.sensor_id): SensorEmbedder(
+ sensor_embed_dim=embed_cfg.embed_dim,
+ meta_dim=embed_cfg.meta_dim,
+ output_dim=self.fusion_dim,
+ hpx_level=self.hpx_level,
+ nchannel=sensor_cfg.nchannel,
+ platform_ids=sensor_cfg.platform_ids,
+ use_checkpoint=use_checkpoint,
+ use_channel_platform_embedding_table=self.use_channel_platform_embedding_table,
+ )
+ for sensor_cfg in self.sensors.values()
+ }
+ )
+
+ self.sensor_fusion = UniformFusion(fusion_dim=self.fusion_dim)
+
+ self.output_norm = torch.nn.LayerNorm(self.fusion_dim)
+
+ self.register_buffer(
+ "sensor_ids_tensor", torch.tensor(self.sensor_ids, dtype=torch.int32)
+ )
+ if compile:
+ self.forward = torch.compile(self.forward, dynamic=True)
+
+ def _reorder(self, x: torch.Tensor) -> torch.Tensor:
+ """Reorder from NEST to HEALPIX_PAD_XY. Input shape: (..., npix, c)"""
+ x = self.grid.reorder(
+ earth2grid.healpix.HEALPIX_PAD_XY, x.transpose(-1, -2)
+ ).transpose(-1, -2)
+ return x
+
+ def forward(self, obs: UnifiedObservation) -> torch.Tensor:
+ """
+ Args:
+ obs: UnifiedObservation with flattened observations from all sensors
+
+ Returns:
+ (batch, fusion_dim, time, npix) in HEALPIX_PAD_XY format
+ """
+ if obs.batch_dims is None:
+ raise ValueError(
+ f"offset batch dimensions must be (batch, time) in MultiSensorObsEmbedding, got {obs.batch_dims}"
+ )
+
+ # Embed each sensor's obs separately
+ obs_by_sensor = split_by_sensor(obs, self.sensor_ids)
+ sensor_embeddings = []
+
+ for sensor_id_str, embedder in self.embedder.items():
+ sensor_id = int(sensor_id_str)
+ sensor_obs: UnifiedObservation = obs_by_sensor[sensor_id]
+ output = embedder(sensor_obs) # (b, t, x, c)
+ sensor_embeddings.append(output)
+
+ sensor_embeddings = torch.stack(
+ sensor_embeddings, dim=0
+ ) # (num_sensors, b, t, x, c)
+
+ # Fuse sensors
+ num_sensors, b, t, x, c = sensor_embeddings.shape
+ sensor_embeddings_flat = sensor_embeddings.view(num_sensors, b * t * x, c)
+ fused_flat = self.sensor_fusion(
+ sensor_embeddings_flat, self.sensor_ids_tensor
+ ) # (b*t*x, fusion_dim)
+
+ out = fused_flat.view(b, t, x, self.fusion_dim) # (b, t, x, fusion_dim)
+
+ out = self._reorder(out)
+ out = self.output_norm(out)
+ out = out.permute(0, 3, 1, 2).to(memory_format=torch.channels_last)
+
+ return out
diff --git a/physicsnemo/experimental/models/healda/scatter_aggregator.py b/physicsnemo/experimental/models/healda/scatter_aggregator.py
new file mode 100644
index 0000000000..30bf4a2f5e
--- /dev/null
+++ b/physicsnemo/experimental/models/healda/scatter_aggregator.py
@@ -0,0 +1,102 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Scatter aggregation module for observation embedding."""
+
+import torch
+
+from physicsnemo.core.module import Module
+
+from .scatter_mean import scatter_mean
+
+
+class ScatterAggregator(Module):
+ """Dense-bucket scatter aggregation (all batches together, all buckets).
+
+ Pipeline:
+ 1. Aggregate observations onto spatial grid using scatter_mean
+ 2. Fill unobserved values with zeros
+ 3. Mix across all buckets using MLP
+
+ Args:
+ in_dim: Input token dimension
+ out_dim: Output dimension after projection
+ nchannel: Max number of channels
+ nplatform: Max number of platforms
+ npix: Number of spatial pixels in output grid
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ nchannel: int,
+ nplatform: int,
+ npix: int,
+ ):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.nchannel = nchannel
+ self.nplatform = nplatform
+ self.npix = npix
+ self.nbuckets = nchannel * nplatform
+
+ proj_in = self.nbuckets * in_dim + self.nbuckets # features + bucket coverage
+ proj_out = out_dim * 2
+ self.bucket_mixing_mlp = torch.nn.Sequential(
+ torch.nn.Linear(proj_in, proj_out),
+ torch.nn.LayerNorm(proj_out),
+ torch.nn.SiLU(),
+ torch.nn.Linear(proj_out, out_dim),
+ )
+
+ def forward(
+ self,
+ obs_features: torch.Tensor,
+ batch_idx: torch.Tensor,
+ pix: torch.Tensor,
+ bucket_id: torch.Tensor,
+ nbatch: int,
+ ) -> torch.Tensor:
+ """
+ Aggregate observations to spatial grid.
+
+ Args:
+ obs_features: (nobs, in_dim) tokenized observations
+ batch_idx: (nobs,) batch index for each observation
+ pix: (nobs,) spatial pixel index for each observation
+ bucket_id: (nobs,) bucket ID (platform * nchannel + channel) for each observation
+ nbatch: Number of batch elements
+
+ Returns:
+ (nbatch, npix, out_dim) aggregated and projected spatial grid
+ """
+ grid_indices = torch.stack([batch_idx, pix, bucket_id], dim=-1)
+
+ aggregated, has_obs = scatter_mean(
+ tensor=obs_features,
+ index=grid_indices,
+ shape=(nbatch, self.npix, self.nbuckets),
+ ) # (nbatch, npix, nbuckets, in_dim), (nbatch, npix, nbuckets)
+
+ # Reshape and fill unobserved with zeros (scatter_mean fills empty cells with NaN)
+ nbatch, npix, nbuckets, in_dim = aggregated.shape
+ aggregated = aggregated.view(nbatch, npix, nbuckets * in_dim)
+ aggregated = torch.nan_to_num(aggregated, nan=0.0)
+
+ # Concatenate bucket coverage info and project through MLP
+ mlp_input = torch.cat([aggregated, has_obs.float()], dim=-1)
+ return self.bucket_mixing_mlp(mlp_input)
diff --git a/physicsnemo/experimental/models/healda/scatter_mean.py b/physicsnemo/experimental/models/healda/scatter_mean.py
new file mode 100644
index 0000000000..85172c8489
--- /dev/null
+++ b/physicsnemo/experimental/models/healda/scatter_mean.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import torch
+
+
+def _compute_row_major_strides(shape):
+ strides = []
+ stride = 1
+ for size in reversed(shape):
+ strides.insert(0, stride)
+ stride *= size
+ return strides
+
+
+def scatter_mean(
+ tensor: torch.Tensor,
+ index: torch.Tensor,
+ shape: tuple[int, ...],
+ fill_value: float = float("nan"),
+) -> torch.Tensor:
+ """Scatter-mean values onto a multi-dimensional grid
+
+ Args:
+ tensor: [N, c] observation feature vectors
+ index: [N, d] 1D grid cell index for each value in tensor
+ shape: d-tuple. The size of the non value dimensions of the output array
+
+ Returns:
+ aggregated: [*shape, c] with mean-aggregated values,
+ filled with fill_value at grid cells with no values
+ present: (*shape) bool mask indicating which grid cells have values
+ """
+ strides = _compute_row_major_strides(shape)
+ # manually implement the dot product since matmul doesn't support long tensors on cuda
+ # avoids RuntimeError: "addmv_impl_cuda" not implemented for 'Long'
+ grid_indices_flat = (index * torch.tensor(strides, device=index.device)).sum(dim=-1)
+ grid_size = math.prod(shape)
+
+ device = tensor.device
+ dtype = tensor.dtype
+ embedding_dim = tensor.shape[1]
+
+ # Initialize with fill_value (typically NaN)
+ values_mean = torch.full(
+ (grid_size, embedding_dim), fill_value, device=device, dtype=dtype
+ )
+
+ # Use scatter_reduce with mean, expanding indices to match tensor dimensions
+ grid_indices_flat_expanded = grid_indices_flat.unsqueeze(-1).expand(
+ -1, embedding_dim
+ )
+ values_mean.scatter_reduce_(
+ 0, grid_indices_flat_expanded, tensor, reduce="mean", include_self=False
+ )
+
+ # Compute present mask (cells that are not fill_value)
+ if math.isnan(fill_value):
+ present = ~torch.isnan(values_mean[:, 0])
+ else:
+ present = values_mean[:, 0] != fill_value
+
+ # Reshape
+ aggregated = values_mean.view(*shape, embedding_dim)
+ present = present.view(shape)
+
+ return aggregated, present
diff --git a/physicsnemo/experimental/models/healda/types.py b/physicsnemo/experimental/models/healda/types.py
new file mode 100644
index 0000000000..19032d2b69
--- /dev/null
+++ b/physicsnemo/experimental/models/healda/types.py
@@ -0,0 +1,237 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import dataclasses
+from typing import Optional, TypedDict
+
+import torch
+
+
+@dataclasses.dataclass
+class UnifiedObservation:
+ """Unified observation structure for both satellite and conventional observations."""
+
+ # Core observation data
+ obs: torch.Tensor # Observation values
+ time: torch.Tensor # Observation times
+
+ # Pre-computed metadata
+ float_metadata: (
+ torch.Tensor
+ ) # Pre-computed float features (e.g., angles, local solar time)
+
+ # Integer metadata for spatial aggregation and embedding lookups
+ # Shape: (n_obs, n_metadata_fields=6) - transposed for better torch.compile performance
+ # Typical fields: sensor_id, pix, channel_id, platform_id, obs_type, global_channel_id
+ int_metadata: torch.Tensor # dtype=torch.long
+
+ class bucket_index:
+ sensor = 0
+ pix = 1
+ local_channel = 2
+ platform = 3
+ obs_type = 4
+ global_channel = 5
+
+ offsets: torch.Tensor | None = (
+ None # 3D: (n_active_sensors, batch, time) cumulative start indices
+ )
+ batch_idx: torch.Tensor | None = (
+ None # (n_obs,) batch index - computed from offsets in post_init (batch context, not intrinsic property)
+ )
+ sensor_id_to_local: torch.Tensor | None = (
+ None # (max_sensor_id active + 1,) map: sensor_id -> local_idx (-1 if inactive)
+ )
+ hpx_level: int | None = (
+ None # the hpx level (pix is at int_metadata[:, bucket_index.pix])
+ )
+
+ def __post_init__(self):
+ """Automatically compute batch_idx from offsets after construction."""
+ if self.batch_idx is None:
+ if self.offsets is not None:
+ self.batch_idx = offsets_to_batch_idx(self.offsets)
+ else:
+ # No offsets = single batch, all observations belong to batch 0
+ self.batch_idx = torch.zeros(
+ self.obs.shape[0], dtype=torch.long, device=self.obs.device
+ )
+
+ @property
+ def batch_dims(self):
+ """Return (batch, time) shape from 3D offsets (S, B, T)."""
+ if self.offsets is not None:
+ return self.offsets.shape[-2:]
+ else:
+ return ()
+
+ def __repr__(self):
+ nobs = self.obs.shape[0]
+ return f"UnifiedObservation({nobs=}, batch_dims={self.batch_dims})"
+
+ def to(self, device=None, dtype=None, non_blocking=True):
+ """Move all tensors to device and/or convert dtype."""
+
+ def _move_tensor(x):
+ if x is None:
+ return
+ return x.to(device=device, dtype=dtype, non_blocking=non_blocking)
+
+ return UnifiedObservation(
+ obs=_move_tensor(self.obs),
+ time=_move_tensor(self.time),
+ float_metadata=_move_tensor(self.float_metadata),
+ int_metadata=_move_tensor(self.int_metadata),
+ offsets=_move_tensor(self.offsets),
+ batch_idx=_move_tensor(self.batch_idx),
+ sensor_id_to_local=_move_tensor(self.sensor_id_to_local),
+ hpx_level=self.hpx_level,
+ )
+
+ def record_stream(self, stream):
+ """Mark"""
+ self.obs.record_stream(stream)
+ self.time.record_stream(stream)
+ self.float_metadata.record_stream(stream)
+ self.int_metadata.record_stream(stream)
+ self.batch_idx.record_stream(stream)
+ if self.offsets is not None:
+ self.offsets.record_stream(stream)
+ if self.sensor_id_to_local is not None:
+ self.sensor_id_to_local.record_stream(stream)
+
+
+class Batch(TypedDict):
+ """Input of DA model on which Obs Encoder operates"""
+
+ target: torch.Tensor # (b, c, t, x) - main atmospheric variables
+ condition: torch.Tensor # (b, c_cond, t, x) - conditioning variables
+ second_of_day: torch.Tensor # (b, t) - seconds of day
+ day_of_year: torch.Tensor # (b, t) - day of year
+ labels: torch.Tensor # (b, num_classes) - one-hot encoded labels
+ timestamp: torch.Tensor # (b,) - timestamps as seconds since epoch
+ unified_obs: Optional[UnifiedObservation] # Unified observation data (v2)
+ # Residual training fields (optional)
+ background: Optional[
+ torch.Tensor
+ ] # (b, c, t, x) - background data for residual training
+ residual_target: Optional[torch.Tensor] # (b, c, t, x) - residual target
+ residual_denormalized: Optional[
+ torch.Tensor
+ ] # (b, c, t, x) - denormalized residual
+ background_label: Optional[torch.Tensor] # (b, num_classes) - background label
+ lag_steps: Optional[torch.Tensor] # (b,) - lag steps for residual training
+
+
+def offsets_to_batch_idx(offsets):
+ """Convert 3D cumulative-end offsets to (batch, time) indices.
+
+ offsets is (S, B, T) with cumulative ends.
+ Returns index in [0, B*T) for each observation, ignoring sensor dimension.
+ """
+ S, B, T = offsets.shape
+ bt_size = B * T
+
+ offsets_flat = offsets.flatten()
+ offsets_with_zero = torch.cat(
+ [torch.tensor([0], device=offsets.device, dtype=offsets.dtype), offsets_flat]
+ )
+ sizes = offsets_with_zero.diff() # num obs per group of sensor obs
+
+ # Assign each group an index in [0, S*B*T), then map to [0, B*T) with mod
+ window_indices = torch.arange(
+ sizes.shape[0], dtype=torch.long, device=offsets.device
+ )
+ bt_indices = window_indices % bt_size
+
+ return bt_indices.repeat_interleave(sizes)
+
+
+@torch.compiler.disable
+def split_by_sensor(
+ obs: UnifiedObservation, target_sensor_ids: list[int]
+) -> dict[int, UnifiedObservation]:
+ """
+ Slice a UnifiedObservation into per-sensor sub-objects using its precomputed offsets.
+
+ Args:
+ obs: UnifiedObservation
+ target_sensor_ids: list of int sensor IDs to extract
+
+ Returns:
+ dict[int, UnifiedObservation]: mapping from sensor_id -> sliced UnifiedObservation.
+ If a sensor_id has no data, returns an empty slice
+ (same structure, 0 rows).
+ """
+ if obs.offsets is None or obs.sensor_id_to_local is None:
+ raise ValueError("offsets is required for split_by_sensor")
+
+ out = {}
+ offsets = obs.offsets # [S,B,T]
+ sensor_id_to_local = obs.sensor_id_to_local # [max_sensor_id+1]
+
+ device = obs.obs.device
+ B, T = obs.batch_dims
+ total_obs = obs.obs.shape[0]
+
+ obs_count = 0
+ for sensor_id in target_sensor_ids:
+ if sensor_id < len(sensor_id_to_local):
+ s_local = sensor_id_to_local[sensor_id].item()
+ else:
+ s_local = -1
+
+ if s_local < 0:
+ # Not active -> return zero-length slice
+ start = end = 0
+ sensor_offsets = torch.zeros((1, B, T), dtype=offsets.dtype, device=device)
+ else:
+ # Each sensor's last cumulative offset is total rows for that sensor
+ end = offsets[s_local, -1, -1].item()
+ # Adjust offsets to be relative to this sensor's start
+ start = 0 if s_local == 0 else offsets[s_local - 1, -1, -1].item()
+ sensor_offsets = offsets[s_local : s_local + 1] - start
+
+ if not (0 <= start <= total_obs and start <= end <= total_obs):
+ raise ValueError(
+ f"Invalid offsets for sensor {sensor_id}: start={start}, end={end}, "
+ f"total_obs={total_obs}."
+ )
+ length = end - start
+
+ def _narrow(x):
+ return (
+ torch.narrow(x, 0, start, length) if length > 0 else x.narrow(0, 0, 0)
+ )
+
+ # single-sensor map: only this sensor maps to local idx 0
+ single_sensor_map = torch.full(
+ (sensor_id + 1,), -1, dtype=torch.int32, device=device
+ )
+ single_sensor_map[sensor_id] = 0
+
+ out[sensor_id] = UnifiedObservation(
+ obs=_narrow(obs.obs),
+ time=_narrow(obs.time),
+ float_metadata=_narrow(obs.float_metadata),
+ int_metadata=_narrow(obs.int_metadata),
+ hpx_level=obs.hpx_level,
+ offsets=sensor_offsets, # (1, B, T) relative to sliced data
+ batch_idx=_narrow(obs.batch_idx), # Also narrow batch_idx
+ sensor_id_to_local=single_sensor_map,
+ )
+ obs_count += length
+
+ return out
diff --git a/physicsnemo/nn/__init__.py b/physicsnemo/nn/__init__.py
index 26e5adc63f..0a72fe247b 100644
--- a/physicsnemo/nn/__init__.py
+++ b/physicsnemo/nn/__init__.py
@@ -51,6 +51,7 @@
TransposeConvLayer,
)
from .module.dgm_layers import DGMLayer
+from .module.drop import DropPath
from .module.embedding_layers import FourierEmbedding, PositionalEmbedding
from .module.fourier_layers import (
FourierFilter,
diff --git a/physicsnemo/nn/module/__init__.py b/physicsnemo/nn/module/__init__.py
index 03c400f563..9358960a4e 100644
--- a/physicsnemo/nn/module/__init__.py
+++ b/physicsnemo/nn/module/__init__.py
@@ -28,6 +28,7 @@
from .ball_query import BQWarp
from .conv_layers import ConvBlock, CubeEmbedding
from .dgm_layers import DGMLayer
+from .drop import DropPath
from .embedding_layers import FourierEmbedding, PositionalEmbedding
from .fourier_layers import (
FourierFilter,
diff --git a/physicsnemo/nn/module/mlp_layers.py b/physicsnemo/nn/module/mlp_layers.py
index ecb34396e2..f59cf4a6c8 100644
--- a/physicsnemo/nn/module/mlp_layers.py
+++ b/physicsnemo/nn/module/mlp_layers.py
@@ -28,6 +28,7 @@ def __init__(
out_features: int | None = None,
act_layer: nn.Module | str = nn.GELU,
drop: float = 0.0,
+ final_dropout: bool = True,
):
super().__init__()
out_features = out_features or in_features
@@ -65,7 +66,7 @@ def __init__(
# Add the last layers:
layers.append(nn.Linear(input_dim, out_features))
- if drop != 0:
+ if drop != 0 and final_dropout:
layers.append(nn.Dropout(drop))
self.layers = nn.Sequential(*layers)
diff --git a/test/models/healda/__init__.py b/test/models/healda/__init__.py
new file mode 100644
index 0000000000..0e5652267a
--- /dev/null
+++ b/test/models/healda/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/test/models/healda/_regtest_outputs/test_obs_decoder.test_obs_decoder_forward_pass.out b/test/models/healda/_regtest_outputs/test_obs_decoder.test_obs_decoder_forward_pass.out
new file mode 100644
index 0000000000..64ec07d754
--- /dev/null
+++ b/test/models/healda/_regtest_outputs/test_obs_decoder.test_obs_decoder_forward_pass.out
@@ -0,0 +1,6 @@
+Output shape: torch.Size([50, 1])
+Output mean: 0.203019
+Output std: 0.270797
+Output min: -0.618511
+Output max: 0.797052
+Output hash: 3a6b7998eb968dfc2cc37f5343bdacb980bdc07ebd90307d226c5992e9d006b6
diff --git a/test/models/healda/test_dit.py b/test/models/healda/test_dit.py
new file mode 100644
index 0000000000..f9f1ed41a0
--- /dev/null
+++ b/test/models/healda/test_dit.py
@@ -0,0 +1,300 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+
+from physicsnemo.experimental.models.healda import (
+ HealDA,
+ HPXPatchDetokenizer,
+ HPXPatchTokenizer,
+ ModelSensorConfig,
+ SensorEmbedderConfig,
+)
+
+from .utils.obs_test_utils import create_unified_observation
+
+
+def test_tokenizer():
+ """Test HPXPatchTokenizer."""
+ level_fine = 6
+ level_coarse = 5
+ hidden_size = 64
+ n, t = 2, 1
+ npix_fine = 12 * 4**level_fine
+ npix_coarse = 12 * 4**level_coarse
+
+ tokenizer = HPXPatchTokenizer(
+ in_channels=10,
+ hidden_size=hidden_size,
+ level_fine=level_fine,
+ level_coarse=level_coarse,
+ )
+
+ x = torch.randn(n, 10, t, npix_fine)
+ doy = torch.ones([n, t])
+ second = torch.ones([n, t])
+
+ out = tokenizer(x, second_of_day=second, day_of_year=doy)
+ # Output: (B, L, D) where L = T * npix_coarse
+ assert out.shape == (n, t * npix_coarse, hidden_size)
+
+
+def test_detokenizer():
+ """Test HPXPatchDetokenizer."""
+ level_fine = 6
+ level_coarse = 5
+ hidden_size = 64
+ out_channels = 3
+ n, t = 2, 1
+ npix_fine = 12 * 4**level_fine
+ npix_coarse = 12 * 4**level_coarse
+
+ detokenizer = HPXPatchDetokenizer(
+ hidden_size=hidden_size,
+ out_channels=out_channels,
+ level_coarse=level_coarse,
+ level_fine=level_fine,
+ time_length=t,
+ )
+
+ # Input: (B, L, D) where L = T * npix_coarse
+ x = torch.randn(n, t * npix_coarse, hidden_size)
+ c = torch.randn(n, hidden_size) # conditioning vector
+
+ out = detokenizer(x, c)
+ # Output: (B, C_out, T, npix_fine)
+ assert out.shape == (n, out_channels, t, npix_fine)
+
+
+def test_detokenizer_vit_mode():
+ """Test HPXPatchDetokenizer in VIT mode (c=zeros)."""
+ level_fine = 6
+ level_coarse = 5
+ hidden_size = 64
+ out_channels = 3
+ n, t = 2, 1
+ npix_fine = 12 * 4**level_fine
+ npix_coarse = 12 * 4**level_coarse
+
+ detokenizer = HPXPatchDetokenizer(
+ hidden_size=hidden_size,
+ out_channels=out_channels,
+ level_coarse=level_coarse,
+ level_fine=level_fine,
+ time_length=t,
+ )
+ # Zero-init for VIT mode (scale=0, shift=0 -> identity modulation)
+ detokenizer.initialize_weights()
+
+ x = torch.randn(n, t * npix_coarse, hidden_size)
+
+ # VIT mode: pass zeros for conditioning (with zero-init weights -> identity)
+ c = torch.zeros(n, hidden_size)
+ out = detokenizer(x, c)
+ assert out.shape == (n, out_channels, t, npix_fine)
+
+
+# ============================================================================
+# HealDA Model Tests
+# ============================================================================
+
+
+@pytest.mark.parametrize("condition_dim", [None, 256])
+def test_healda_forward(condition_dim, device):
+ """Test HealDA model forward pass in VIT and Diffusion modes."""
+ n, t = 1, 1
+ level_in = 6
+ level_model = 5
+ in_channels = 3
+ out_channels = 3
+ npix = 12 * 4**level_in
+
+ sensor_config = {
+ "sensor_1": ModelSensorConfig(sensor_id=1, nchannel=4, platform_ids=(0, 1, 2)),
+ }
+ sensor_embedder_config = SensorEmbedderConfig(embed_dim=16, fusion_dim=32)
+
+ model = HealDA(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ sensor_embedder_config=sensor_embedder_config,
+ sensors=sensor_config,
+ hidden_size=64,
+ num_layers=1,
+ num_heads=2,
+ level_in=level_in,
+ level_model=level_model,
+ time_length=t,
+ condition_dim=condition_dim,
+ attention_backend="timm", # Use timm for CPU testing
+ layernorm_backend="torch",
+ )
+ model.to(device)
+
+ # Create inputs
+ x = torch.randn(n, in_channels, t, npix, device=device)
+ timestep = torch.zeros(n, device=device)
+ doy = torch.ones([n, t], device=device)
+ second = torch.ones([n, t], device=device)
+
+ # Create mock observation
+ obs = create_unified_observation(
+ nobs=100,
+ batch_size=n,
+ time_steps=t,
+ meta_dim=28,
+ hpx_level=level_in,
+ device=device,
+ sensor_config=sensor_config,
+ )
+
+ # Prepare conditioning args
+ kwargs = {}
+ if condition_dim is not None:
+ kwargs["noise_labels"] = torch.ones(n, device=device)
+
+ out = model(
+ x,
+ timestep,
+ unified_obs=obs,
+ day_of_year=doy,
+ second_of_day=second,
+ **kwargs,
+ )
+
+ # Verify output shape
+ assert out.shape == (n, out_channels, t, npix)
+
+
+def test_healda_backward(device):
+ """Test HealDA model backward pass."""
+ n, t = 1, 1
+ level_in = 6
+ level_model = 5
+ in_channels = 3
+ out_channels = 3
+ npix = 12 * 4**level_in
+
+ sensor_config = {
+ "sensor_1": ModelSensorConfig(sensor_id=1, nchannel=4, platform_ids=(0,)),
+ }
+ sensor_embedder_config = SensorEmbedderConfig(embed_dim=16, fusion_dim=32)
+
+ model = HealDA(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ sensor_embedder_config=sensor_embedder_config,
+ sensors=sensor_config,
+ hidden_size=64,
+ num_layers=1,
+ num_heads=2,
+ level_in=level_in,
+ level_model=level_model,
+ time_length=t,
+ condition_dim=None, # VIT mode
+ attention_backend="timm",
+ layernorm_backend="torch",
+ )
+ model.to(device)
+
+ x = torch.randn(n, in_channels, t, npix, device=device, requires_grad=True)
+ timestep = torch.zeros(n, device=device)
+ doy = torch.ones([n, t], device=device)
+ second = torch.ones([n, t], device=device)
+
+ obs = create_unified_observation(
+ nobs=100,
+ batch_size=n,
+ time_steps=t,
+ meta_dim=28,
+ hpx_level=level_in,
+ device=device,
+ sensor_config=sensor_config,
+ )
+
+ out = model(
+ x,
+ timestep,
+ unified_obs=obs,
+ day_of_year=doy,
+ second_of_day=second,
+ )
+
+ # Backward pass
+ loss = out.sum()
+ loss.backward()
+
+ # Verify gradients exist
+ assert x.grad is not None
+ assert x.grad.shape == x.shape
+
+
+@pytest.mark.parametrize("t", [1, 2])
+def test_healda_time_length(t, device):
+ """Test HealDA with different time lengths."""
+ n = 1
+ level_in = 6
+ level_model = 5
+ in_channels = 3
+ out_channels = 3
+ npix = 12 * 4**level_in
+
+ sensor_config = {
+ "sensor_1": ModelSensorConfig(sensor_id=1, nchannel=4, platform_ids=(0,)),
+ }
+ sensor_embedder_config = SensorEmbedderConfig(embed_dim=16, fusion_dim=32)
+
+ model = HealDA(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ sensor_embedder_config=sensor_embedder_config,
+ sensors=sensor_config,
+ hidden_size=64,
+ num_layers=1,
+ num_heads=2,
+ level_in=level_in,
+ level_model=level_model,
+ time_length=t,
+ condition_dim=None,
+ attention_backend="timm",
+ layernorm_backend="torch",
+ )
+ model.to(device)
+
+ x = torch.randn(n, in_channels, t, npix, device=device)
+ timestep = torch.zeros(n, device=device)
+ doy = torch.ones([n, t], device=device)
+ second = torch.ones([n, t], device=device)
+
+ obs = create_unified_observation(
+ nobs=100,
+ batch_size=n,
+ time_steps=t,
+ meta_dim=28,
+ hpx_level=level_in,
+ device=device,
+ sensor_config=sensor_config,
+ )
+
+ out = model(
+ x,
+ timestep,
+ unified_obs=obs,
+ day_of_year=doy,
+ second_of_day=second,
+ )
+
+ assert out.shape == (n, out_channels, t, npix)
diff --git a/test/models/healda/test_healpix_layers.py b/test/models/healda/test_healpix_layers.py
new file mode 100644
index 0000000000..2cce22f15a
--- /dev/null
+++ b/test/models/healda/test_healpix_layers.py
@@ -0,0 +1,197 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from physicsnemo.experimental.models.healda import (
+ HPXPatchDetokenizer,
+ HPXPatchTokenizer,
+)
+
+
+def test_hpx_patch_tokenizer():
+ """Test HPXPatchTokenizer forward pass."""
+ in_channels = 5
+ hidden_size = 64
+ level_fine = 6
+ level_coarse = 4
+
+ tokenizer = HPXPatchTokenizer(
+ in_channels=in_channels,
+ hidden_size=hidden_size,
+ level_fine=level_fine,
+ level_coarse=level_coarse,
+ )
+
+ b, t = 2, 3
+ npix = 12 * 4**level_fine
+ x = torch.randn(b, in_channels, t, npix)
+
+ second_of_day = torch.randint(0, 86400, (b, t))
+ day_of_year = torch.randint(0, 365, (b, t))
+
+ out = tokenizer(x, second_of_day=second_of_day, day_of_year=day_of_year)
+
+ # Output should be (B, L, D) where L = T * npix_coarse
+ expected_npix_coarse = 12 * 4**level_coarse
+ expected_L = t * expected_npix_coarse
+ assert out.shape == (b, expected_L, hidden_size)
+
+ # Verify output is finite
+ assert torch.isfinite(out).all(), "Output contains non-finite values"
+
+
+def test_hpx_patch_tokenizer_no_calendar():
+ """Test HPXPatchTokenizer without calendar embedding."""
+ tokenizer = HPXPatchTokenizer(
+ in_channels=5,
+ hidden_size=64,
+ level_fine=6,
+ level_coarse=4,
+ )
+
+ b, t = 2, 1
+ npix = 12 * 4**6
+ x = torch.randn(b, 5, t, npix)
+
+ # No calendar embedding
+ out = tokenizer(x)
+
+ expected_L = t * 12 * 4**4
+ assert out.shape == (b, expected_L, 64)
+
+
+def test_hpx_patch_detokenizer():
+ """Test HPXPatchDetokenizer forward pass."""
+ hidden_size = 64
+ out_channels = 5
+ level_coarse = 4
+ level_fine = 6
+ time_length = 3
+
+ detokenizer = HPXPatchDetokenizer(
+ hidden_size=hidden_size,
+ out_channels=out_channels,
+ level_coarse=level_coarse,
+ level_fine=level_fine,
+ time_length=time_length,
+ )
+
+ b = 2
+ npix_coarse = 12 * 4**level_coarse
+ L = time_length * npix_coarse
+ x = torch.randn(b, L, hidden_size)
+ c = torch.randn(b, hidden_size) # Conditioning
+
+ out = detokenizer(x, c)
+
+ # Output should be (B, C, T, npix_fine)
+ npix_fine = 12 * 4**level_fine
+ assert out.shape == (b, out_channels, time_length, npix_fine)
+ assert torch.isfinite(out).all()
+
+
+def test_hpx_patch_detokenizer_vit_mode():
+ """Test HPXPatchDetokenizer in VIT mode (c=zeros)."""
+ hidden_size = 64
+ out_channels = 5
+ level_coarse = 4
+ level_fine = 6
+ time_length = 3
+
+ detokenizer = HPXPatchDetokenizer(
+ hidden_size=hidden_size,
+ out_channels=out_channels,
+ level_coarse=level_coarse,
+ level_fine=level_fine,
+ time_length=time_length,
+ )
+ # Zero-initialize (as in DiT)
+ detokenizer.initialize_weights()
+
+ b = 2
+ npix_coarse = 12 * 4**level_coarse
+ L = time_length * npix_coarse
+ x = torch.randn(b, L, hidden_size)
+
+ # VIT mode: pass zeros for c
+ c = torch.zeros(b, hidden_size)
+ out = detokenizer(x, c)
+
+ npix_fine = 12 * 4**level_fine
+ assert out.shape == (b, out_channels, time_length, npix_fine)
+ assert torch.isfinite(out).all()
+
+
+def test_tokenizer_detokenizer_roundtrip():
+ """Test that tokenizer -> detokenizer roundtrip produces correct shapes."""
+ in_channels = 5
+ hidden_size = 64
+ level_fine = 6
+ level_coarse = 4
+ time_length = 2
+
+ tokenizer = HPXPatchTokenizer(
+ in_channels=in_channels,
+ hidden_size=hidden_size,
+ level_fine=level_fine,
+ level_coarse=level_coarse,
+ )
+
+ detokenizer = HPXPatchDetokenizer(
+ hidden_size=hidden_size,
+ out_channels=in_channels,
+ level_coarse=level_coarse,
+ level_fine=level_fine,
+ time_length=time_length,
+ )
+
+ b = 2
+ npix_fine = 12 * 4**level_fine
+ x = torch.randn(b, in_channels, time_length, npix_fine)
+ second_of_day = torch.randint(0, 86400, (b, time_length))
+ day_of_year = torch.randint(0, 365, (b, time_length))
+ c = torch.randn(b, hidden_size)
+
+ # Tokenize
+ tokens = tokenizer(x, second_of_day=second_of_day, day_of_year=day_of_year)
+ npix_coarse = 12 * 4**level_coarse
+ assert tokens.shape == (b, time_length * npix_coarse, hidden_size)
+
+ # Detokenize
+ out = detokenizer(tokens, c)
+ assert out.shape == x.shape
+
+
+def test_tokenizer_backward():
+ """Test HPXPatchTokenizer backward pass."""
+ tokenizer = HPXPatchTokenizer(
+ in_channels=5,
+ hidden_size=64,
+ level_fine=6,
+ level_coarse=4,
+ )
+
+ b, t = 2, 1
+ npix = 12 * 4**6
+ x = torch.randn(b, 5, t, npix, requires_grad=True)
+ second_of_day = torch.randint(0, 86400, (b, t))
+ day_of_year = torch.randint(0, 365, (b, t))
+
+ out = tokenizer(x, second_of_day=second_of_day, day_of_year=day_of_year)
+ out.sum().backward()
+
+ assert x.grad is not None
+ assert x.grad.shape == x.shape
diff --git a/test/models/healda/test_point_embed.py b/test/models/healda/test_point_embed.py
new file mode 100644
index 0000000000..30a3173b92
--- /dev/null
+++ b/test/models/healda/test_point_embed.py
@@ -0,0 +1,268 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+import torch
+
+from physicsnemo.experimental.models.healda import (
+ ModelSensorConfig,
+ MultiSensorObsEmbedding,
+ SensorEmbedderConfig,
+ UnifiedObservation,
+)
+from physicsnemo.experimental.models.healda.point_embed import SensorEmbedder
+
+from .utils.obs_test_utils import create_unified_observation
+
+# ============================================================================
+# Test Utilities
+# ============================================================================
+
+
+def check_all_params_have_gradients(model: torch.nn.Module) -> tuple[bool, list[str]]:
+ """
+ Check that all parameters in a model have gradients.
+
+ Returns:
+ (all_have_grads, params_without_grads)
+ """
+ params_without_grads = []
+ for name, param in model.named_parameters():
+ if param.requires_grad and param.grad is None:
+ params_without_grads.append(name)
+
+ return len(params_without_grads) == 0, params_without_grads
+
+
+# ============================================================================
+# SensorEmbedder Tests
+# ============================================================================
+
+
+@pytest.mark.parametrize("nobs", [0, 100])
+def test_sensor_embedder(nobs):
+ """Test SensorEmbedder shapes, values, and gradients (includes empty obs for DDP compatibility)."""
+ torch.manual_seed(42)
+
+ sensor_embed_dim = 16
+ output_dim = 32
+ hpx_level = 5
+ meta_dim = 8
+ nchannel = 10
+ nplatform = 5
+ batch_size = 2
+
+ embedder = SensorEmbedder(
+ platform_ids=list(
+ range(nplatform)
+ ), # Simple case: platform IDs = [0, 1, ..., nplatform-1]
+ sensor_embed_dim=sensor_embed_dim,
+ output_dim=output_dim,
+ meta_dim=meta_dim,
+ n_embed=100,
+ hpx_level=hpx_level,
+ nchannel=nchannel,
+ )
+ embedder.train()
+
+ # Create observation
+ obs = create_unified_observation(
+ nobs=nobs,
+ batch_size=batch_size,
+ time_steps=1,
+ hpx_level=hpx_level + 1,
+ meta_dim=meta_dim,
+ nchannel=nchannel,
+ nplatform=nplatform,
+ n_embed=100,
+ )
+
+ # Forward pass
+ output = embedder(obs)
+
+ # Check output shape and values
+ npix = 12 * 4**hpx_level
+ assert output.shape == (batch_size, 1, npix, output_dim)
+ assert torch.isfinite(output).all()
+
+ # Check gradients
+ loss = output.sum()
+ loss.backward()
+ all_have_grads, missing = check_all_params_have_gradients(embedder)
+ assert all_have_grads, f"Parameters without gradients (nobs={nobs}): {missing}"
+
+
+def test_sensor_embedder_no_batching():
+ """Test SensorEmbedder with offsets=None (no explicit batching)."""
+ torch.manual_seed(42)
+
+ sensor_embed_dim = 16
+ output_dim = 128
+ hpx_level = 5
+ nobs = 50
+ nplatform = 5
+
+ embedder = SensorEmbedder(
+ platform_ids=list(range(nplatform)),
+ sensor_embed_dim=sensor_embed_dim,
+ output_dim=output_dim,
+ meta_dim=8,
+ n_embed=100,
+ hpx_level=hpx_level,
+ )
+
+ # Create observation without offsets
+ obs = create_unified_observation(
+ nobs=nobs,
+ batch_size=1,
+ time_steps=1,
+ hpx_level=hpx_level + 1,
+ meta_dim=8,
+ n_embed=100,
+ )
+ obs = UnifiedObservation(
+ obs=obs.obs,
+ time=obs.time,
+ float_metadata=obs.float_metadata,
+ int_metadata=obs.int_metadata,
+ offsets=None, # No explicit batching
+ hpx_level=obs.hpx_level,
+ )
+
+ # Forward pass
+ with torch.no_grad():
+ output = embedder(obs)
+
+ # Check output shape (should be 2D: npix x output_dim)
+ npix = 12 * 4**hpx_level
+ assert output.shape == (npix, output_dim)
+
+
+# ============================================================================
+# MultiSensorObsEmbedding Tests
+# ============================================================================
+
+
+@pytest.mark.parametrize("num_sensors", [1, 2])
+def test_multisensor_obs_embedding(num_sensors):
+ """Test MultiSensorObsEmbedding with different sensor counts."""
+ torch.manual_seed(42)
+
+ sensor_embed_dim = 16
+ fusion_dim = 32
+ hpx_level = 5
+ meta_dim = 8
+
+ # Build sensor configs
+ all_sensor_configs = {
+ "test_sensor_0": ModelSensorConfig(
+ sensor_id=0, nchannel=10, platform_ids=tuple(range(5))
+ ),
+ "test_sensor_1": ModelSensorConfig(
+ sensor_id=1, nchannel=10, platform_ids=tuple(range(5))
+ ),
+ }
+ sensor_config = dict(list(all_sensor_configs.items())[:num_sensors])
+
+ sensor_embedder_config = SensorEmbedderConfig(
+ embed_dim=sensor_embed_dim,
+ meta_dim=meta_dim,
+ fusion_dim=fusion_dim,
+ )
+
+ embedder = MultiSensorObsEmbedding(
+ sensor_embedder_config=sensor_embedder_config,
+ sensors=sensor_config,
+ hpx_level=hpx_level,
+ )
+
+ # Create observation
+ obs = create_unified_observation(
+ nobs=100,
+ batch_size=2,
+ time_steps=1,
+ hpx_level=hpx_level + 1,
+ meta_dim=meta_dim,
+ sensor_config=sensor_config,
+ n_embed=100,
+ ensure_all_sensors=True, # Ensure all obs are assigned to defined sensors
+ )
+
+ # Forward pass
+ with torch.no_grad():
+ output = embedder(obs)
+
+ # Check output shape and values
+ npix = 12 * 4**hpx_level
+ assert output.shape == (2, fusion_dim, 1, npix)
+ assert torch.isfinite(output).all()
+ assert not torch.allclose(output, torch.zeros_like(output))
+
+
+@pytest.mark.parametrize("nobs", [0, 50])
+def test_multisensor_gradients(nobs):
+ """Test gradient flow through MultiSensorObsEmbedding (includes empty obs for DDP compatibility)."""
+ torch.manual_seed(42)
+
+ sensor_embed_dim = 16
+ fusion_dim = 32
+ hpx_level = 5
+ meta_dim = 8
+
+ sensor_config = {
+ "test_sensor_0": ModelSensorConfig(
+ sensor_id=0, nchannel=10, platform_ids=tuple(range(5))
+ ),
+ "test_sensor_1": ModelSensorConfig(
+ sensor_id=1, nchannel=10, platform_ids=tuple(range(5))
+ ),
+ }
+
+ sensor_embedder_config = SensorEmbedderConfig(
+ embed_dim=sensor_embed_dim,
+ meta_dim=meta_dim,
+ fusion_dim=fusion_dim,
+ )
+
+ embedder = MultiSensorObsEmbedding(
+ sensor_embedder_config=sensor_embedder_config,
+ sensors=sensor_config,
+ hpx_level=hpx_level,
+ )
+ embedder.train()
+ embedder.zero_grad()
+
+ obs = create_unified_observation(
+ nobs=nobs,
+ batch_size=2,
+ time_steps=1,
+ hpx_level=hpx_level + 1,
+ meta_dim=meta_dim,
+ sensor_config=sensor_config,
+ n_embed=100,
+ ensure_all_sensors=True,
+ )
+
+ # Forward + backward
+ output = embedder(obs)
+ assert torch.isfinite(output).all()
+ loss = output.sum()
+ loss.backward()
+
+ # Check gradients (critical for DDP - must work with empty obs)
+ all_have_grads, missing = check_all_params_have_gradients(embedder)
+ assert all_have_grads, f"Parameters without gradients (nobs={nobs}): {missing}"
diff --git a/test/models/healda/test_scatter_mean.py b/test/models/healda/test_scatter_mean.py
new file mode 100644
index 0000000000..b8ad3104f4
--- /dev/null
+++ b/test/models/healda/test_scatter_mean.py
@@ -0,0 +1,102 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from physicsnemo.experimental.models.healda import scatter_mean
+
+
+def test_scatter_mean_basic():
+ """Test scatter_mean with simple known values"""
+ # Create test data:
+ # - 5 observations with 2 features each
+ # - Scatter into a 3x2 grid (6 cells total)
+ # - Some cells will have multiple values (need averaging)
+ # - Some cells will be empty (should get fill_value)
+
+ tensor = torch.tensor(
+ [
+ [1.0, 10.0], # goes to cell (0, 0)
+ [2.0, 20.0], # goes to cell (0, 1)
+ [3.0, 30.0], # goes to cell (0, 0) - same as first, should average
+ [4.0, 40.0], # goes to cell (2, 1)
+ [5.0, 50.0], # goes to cell (1, 0)
+ ]
+ )
+
+ index = torch.tensor(
+ [
+ [0, 0], # cell (0, 0)
+ [0, 1], # cell (0, 1)
+ [0, 0], # cell (0, 0)
+ [2, 1], # cell (2, 1)
+ [1, 0], # cell (1, 0)
+ ]
+ )
+
+ shape = (3, 2) # 3 rows, 2 columns
+
+ aggregated, present = scatter_mean(tensor, index, shape)
+
+ # Check shape
+ assert aggregated.shape == (3, 2, 2) # (3, 2) grid with 2 features
+ assert present.shape == (3, 2)
+
+ # Check aggregated values
+ # Cell (0, 0): mean of [1.0, 10.0] and [3.0, 30.0] = [2.0, 20.0]
+ assert torch.allclose(aggregated[0, 0], torch.tensor([2.0, 20.0]))
+
+ # Cell (0, 1): [2.0, 20.0] (single value)
+ assert torch.allclose(aggregated[0, 1], torch.tensor([2.0, 20.0]))
+
+ # Cell (1, 0): [5.0, 50.0] (single value)
+ assert torch.allclose(aggregated[1, 0], torch.tensor([5.0, 50.0]))
+
+ # Cell (1, 1): empty, should be NaN
+ assert torch.isnan(aggregated[1, 1]).all()
+
+ # Cell (2, 0): empty, should be NaN
+ assert torch.isnan(aggregated[2, 0]).all()
+
+ # Cell (2, 1): [4.0, 40.0] (single value)
+ assert torch.allclose(aggregated[2, 1], torch.tensor([4.0, 40.0]))
+
+ # Check present mask
+ expected_present = torch.tensor([[True, True], [True, False], [False, True]])
+ assert torch.equal(present, expected_present)
+
+
+def test_scatter_mean_custom_fill_value():
+ """Test scatter_mean with a custom fill value"""
+ tensor = torch.tensor([[1.0, 2.0]])
+ index = torch.tensor([[0, 0]])
+ shape = (2, 2)
+ fill_value = -999.0
+
+ aggregated, present = scatter_mean(tensor, index, shape, fill_value=fill_value)
+
+ # Cell (0, 0) should have the value
+ assert torch.allclose(aggregated[0, 0], torch.tensor([1.0, 2.0]))
+
+ # Other cells should have the fill value
+ assert (aggregated[0, 1] == fill_value).all()
+ assert (aggregated[1, 0] == fill_value).all()
+ assert (aggregated[1, 1] == fill_value).all()
+
+ # Only (0, 0) should be present
+ assert present[0, 0]
+ assert not present[0, 1]
+ assert not present[1, 0]
+ assert not present[1, 1]
diff --git a/test/models/healda/test_types.py b/test/models/healda/test_types.py
new file mode 100644
index 0000000000..8aded32589
--- /dev/null
+++ b/test/models/healda/test_types.py
@@ -0,0 +1,262 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+
+from physicsnemo.experimental.models.healda.types import (
+ UnifiedObservation,
+ split_by_sensor,
+)
+
+
+def make_realistic_obs(
+ B: int = 2, T: int = 2, sensors: list[int] = [0, 1, 2]
+) -> UnifiedObservation:
+ """Create realistic cyclic observation data matching real UFS patterns.
+
+ Sensors cycle: 0,1,2,0,1,2,... within each (b,t) window, then data is sorted globally by sensor_id.
+ """
+ S = len(sensors)
+
+ # Generate observations: each window has 6 obs cycling through sensors
+ all_obs = []
+ for b in range(B):
+ for t in range(T):
+ for i in range(6): # 6 obs per window
+ sensor_id = sensors[i % S]
+ all_obs.append(
+ (sensor_id, b, t, len(all_obs))
+ ) # (sensor, batch, time, value)
+
+ # Sort by sensor_id (as real data is)
+ all_obs.sort(key=lambda x: (x[0], x[3])) # Sort by sensor, then original index
+
+ # Extract sorted data
+ sensor_ids = torch.tensor([x[0] for x in all_obs], dtype=torch.long)
+ values = torch.tensor([x[3] for x in all_obs], dtype=torch.float32)
+
+ # Build 3D offsets: (S, B, T) cumulative ends
+ # Count how many obs each sensor has in each (b,t) window
+ offsets_3d = torch.zeros((S, B, T), dtype=torch.int32)
+ idx = 0
+ for s_local, s_id in enumerate(sensors):
+ for b in range(B):
+ for t in range(T):
+ # Count obs for this sensor in this window
+ count = sum(
+ 1
+ for obs in all_obs
+ if obs[0] == s_id and obs[1] == b and obs[2] == t
+ )
+ idx += count
+ offsets_3d[s_local, b, t] = idx
+
+ # Create sensor_id_to_local map
+ sensor_id_to_local = torch.full((max(sensors) + 1,), -1, dtype=torch.int32)
+ for local_idx, s_id in enumerate(sensors):
+ sensor_id_to_local[s_id] = local_idx
+
+ # Build UnifiedObservation
+ nobs = len(all_obs)
+ return UnifiedObservation(
+ obs=values.unsqueeze(1).expand(nobs, 3), # (nobs, 3) features
+ time=values,
+ float_metadata=values.unsqueeze(1).expand(nobs, 5),
+ int_metadata=torch.stack(
+ [
+ sensor_ids,
+ torch.arange(nobs),
+ torch.zeros(nobs),
+ torch.zeros(nobs),
+ torch.zeros(nobs),
+ torch.zeros(nobs),
+ ],
+ dim=1,
+ ),
+ offsets=offsets_3d,
+ sensor_id_to_local=sensor_id_to_local,
+ hpx_level=6,
+ )
+
+
+def test_split_preserves_all_observations():
+ """Critical test: verify no observations are lost during split."""
+ obs = make_realistic_obs(B=2, T=2, sensors=[0, 1, 2])
+ total_before = obs.obs.shape[0]
+
+ split = split_by_sensor(obs, [0, 1, 2])
+
+ # Count total after split
+ total_after = sum(split[sid].obs.shape[0] for sid in [0, 1, 2])
+ assert total_after == total_before, (
+ f"LOST OBSERVATIONS: {total_before} โ {total_after}"
+ )
+
+ # Each sensor appears 2 times per window (6 obs / 3 sensors), across B*T=4 windows = 8 total
+ for sid in [0, 1, 2]:
+ assert split[sid].obs.shape[0] == 8, f"Sensor {sid} should have 8 obs"
+
+
+def test_split_content_correctness():
+ """Verify split observations contain correct data for each sensor."""
+ obs = make_realistic_obs(B=2, T=2, sensors=[0, 1, 2])
+ split = split_by_sensor(obs, [0, 1, 2])
+
+ # Verify each split contains only its sensor's data
+ for sid in [0, 1, 2]:
+ s_obs = split[sid]
+ sensor_ids_in_split = s_obs.int_metadata[:, s_obs.bucket_index.sensor]
+
+ # All observations must be for this sensor
+ assert torch.all(sensor_ids_in_split == sid), (
+ f"Sensor {sid} contains wrong sensor IDs: {sensor_ids_in_split.unique().tolist()}"
+ )
+
+ # Verify values match (obs tensor should match time for our test data)
+ assert torch.allclose(s_obs.obs[:, 0], s_obs.time), "Data corruption detected"
+
+
+def test_split_offsets_are_relative():
+ """Verify split offsets are relative to each sensor's slice, not absolute."""
+ obs = make_realistic_obs(B=1, T=2, sensors=[0, 1])
+ split = split_by_sensor(obs, [0, 1])
+
+ for sid in [0, 1]:
+ s_obs = split[sid]
+ # Last offset should equal obs count (not some large absolute index)
+ assert s_obs.offsets[0, -1, -1].item() == s_obs.obs.shape[0], (
+ f"Sensor {sid} offsets not relative"
+ )
+
+
+def test_split_empty_sensor():
+ """Test handling of sensor with no data."""
+ obs = make_realistic_obs(B=1, T=1, sensors=[0, 1])
+ split = split_by_sensor(obs, [0, 1, 2]) # Request sensor 2 which doesn't exist
+
+ assert split[2].obs.shape[0] == 0, "Empty sensor should have 0 observations"
+ assert split[2].offsets.shape == (
+ 1,
+ 1,
+ 1,
+ ), "Empty sensor should preserve batch structure"
+
+
+def test_split_requires_offsets():
+ """Test that split_by_sensor requires offsets."""
+ obs = UnifiedObservation(
+ obs=torch.randn(10, 3),
+ time=torch.arange(10, dtype=torch.float32),
+ float_metadata=torch.randn(10, 5),
+ int_metadata=torch.zeros((10, 6), dtype=torch.long),
+ offsets=None, # No offsets
+ sensor_id_to_local=None,
+ hpx_level=6,
+ )
+
+ with pytest.raises(ValueError, match="offsets is required"):
+ split_by_sensor(obs, [0, 1])
+
+
+def test_offsets_monotonic_row_major():
+ """Offsets must be nondecreasing in row-major (b,t) order for each sensor."""
+ obs = make_realistic_obs(B=2, T=3, sensors=[0, 1, 2])
+ S, B, T = obs.offsets.shape
+
+ for s_local in range(S):
+ flat = obs.offsets[s_local].reshape(-1)
+ assert torch.all(flat[:-1] <= flat[1:]), (
+ f"offsets for sensor {s_local} must be nondecreasing in row-major (b,t)"
+ )
+
+
+def test_split_handles_sparse_windows():
+ """Sensor missing from some (b,t) windows; split must still work."""
+ B, T = 2, 3
+ sensors = [0, 4]
+
+ # Sparse data: sensor 0 everywhere (2 obs/window), sensor 4 only in (b=1,t=2) with 3 obs
+ all_obs = []
+ for b in range(B):
+ for t in range(T):
+ all_obs.extend([(0, b, t)] * 2) # sensor 0: 2 obs/window
+ all_obs.extend([(4, 1, 2)] * 3) # sensor 4: 3 obs only in (1,2)
+
+ sensor_ids = torch.tensor([x[0] for x in all_obs], dtype=torch.long)
+ nobs = len(all_obs)
+
+ # Build offsets: sensor 0 cumulative, sensor 4 mostly zeros except (1,2)
+ offsets_3d = torch.zeros((2, B, T), dtype=torch.int32)
+ idx = 0
+ for b in range(B):
+ for t in range(T):
+ idx += 2 # sensor 0 has 2 obs/window
+ offsets_3d[0, b, t] = idx
+ for b in range(B):
+ for t in range(T):
+ if b == 1 and t == 2:
+ idx += 3 # sensor 4 only here
+ offsets_3d[1, b, t] = idx
+
+ sensor_id_to_local = torch.full((5,), -1, dtype=torch.int32)
+ for local_idx, s_id in enumerate(sensors):
+ sensor_id_to_local[s_id] = local_idx
+
+ obs = UnifiedObservation(
+ obs=torch.arange(nobs, dtype=torch.float32).unsqueeze(1).expand(nobs, 3),
+ time=torch.arange(nobs, dtype=torch.float32),
+ float_metadata=torch.arange(nobs, dtype=torch.float32)
+ .unsqueeze(1)
+ .expand(nobs, 5),
+ int_metadata=torch.stack(
+ [
+ sensor_ids,
+ torch.arange(nobs),
+ torch.zeros(nobs),
+ torch.zeros(nobs),
+ torch.zeros(nobs),
+ torch.zeros(nobs),
+ ],
+ dim=1,
+ ),
+ offsets=offsets_3d,
+ sensor_id_to_local=sensor_id_to_local,
+ hpx_level=6,
+ )
+
+ assert obs.batch_dims == (2, 3)
+
+ split = split_by_sensor(obs, [0, 4, 99])
+
+ # Sensor 0: 12 obs (2 per window * 6 windows)
+ s0 = split[0]
+ assert s0.obs.shape[0] == 12
+ assert s0.offsets.shape == (1, 2, 3)
+ assert s0.batch_dims == (2, 3)
+ assert s0.offsets[0, -1, -1].item() == 12
+
+ # Sensor 4: 3 obs (only in window (1,2))
+ s4 = split[4]
+ assert s4.obs.shape[0] == 3
+ assert s4.offsets.shape == (1, 2, 3)
+ assert s4.offsets[0, -1, -1].item() == 3
+ assert s4.batch_dims == (2, 3)
+
+ # Sensor 99: absent
+ s99 = split[99]
+ assert s99.obs.shape[0] == 0
+ assert s99.offsets.shape == (1, 2, 3)
+ assert torch.all(s99.offsets == 0)
diff --git a/test/models/healda/utils/obs_test_utils.py b/test/models/healda/utils/obs_test_utils.py
new file mode 100644
index 0000000000..85cef6fd74
--- /dev/null
+++ b/test/models/healda/utils/obs_test_utils.py
@@ -0,0 +1,148 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test utilities for observation embedding tests."""
+
+import torch
+
+from physicsnemo.experimental.models.healda import ModelSensorConfig, UnifiedObservation
+
+
+def create_unified_observation(
+ nobs: int,
+ batch_size: int = 1,
+ time_steps: int = 1,
+ meta_dim: int = 8,
+ hpx_level: int = 6,
+ nchannel: int = 10,
+ nplatform: int = 5,
+ n_embed: int = 5,
+ device: str = "cpu",
+ ensure_all_sensors: bool = False,
+ sensor_config: dict[str, ModelSensorConfig] | None = None,
+) -> UnifiedObservation:
+ torch.manual_seed(0)
+
+ # Extract sensor info
+ if sensor_config is not None:
+ sensors = [
+ (cfg.sensor_id, cfg.nchannel, list(cfg.platform_ids))
+ for cfg in sensor_config.values()
+ ]
+ else:
+ sensors = [(0, nchannel, list(range(nplatform)))]
+
+ sensor_ids = [s[0] for s in sensors]
+ n_sensors = len(sensor_ids)
+ npix = 12 * 4**hpx_level
+
+ # Build sensor_id_to_local mapping
+ max_sensor_id = max(sensor_ids) if sensor_ids else 0
+ sensor_id_to_local = torch.full((max_sensor_id + 1,), -1, dtype=torch.long)
+ for local_idx, sid in enumerate(sensor_ids):
+ sensor_id_to_local[sid] = local_idx
+
+ # Handle empty case
+ if nobs == 0:
+ return UnifiedObservation(
+ obs=torch.empty(0, device=device),
+ time=torch.empty(0, dtype=torch.long, device=device),
+ float_metadata=torch.empty((0, meta_dim), device=device),
+ int_metadata=torch.empty((0, 6), dtype=torch.long, device=device),
+ offsets=torch.zeros(
+ (n_sensors, batch_size, time_steps), dtype=torch.long, device=device
+ ),
+ sensor_id_to_local=sensor_id_to_local.to(device),
+ hpx_level=hpx_level,
+ )
+
+ # Generate random observations
+ def random_obs_for_sensor(sid):
+ """Generate one observation for a given sensor."""
+ _, nchan, plat_ids = next(s for s in sensors if s[0] == sid)
+ return {
+ "obs": torch.randn(1).item() * 0.5,
+ "time": 946674000000000000 + torch.randint(0, 86400 * 10**9, (1,)).item(),
+ "pix": torch.randint(0, npix, (1,)).item(),
+ "platform": plat_ids[torch.randint(0, len(plat_ids), (1,)).item()],
+ "channel": torch.randint(0, nchan, (1,)).item(),
+ "embed_id": torch.randint(0, n_embed, (1,)).item(),
+ "float_meta": torch.randn(meta_dim) * 0.8,
+ "sensor_id": sid,
+ }
+
+ # Generate observations
+ if ensure_all_sensors:
+ # One per sensor first, then random
+ observations = [random_obs_for_sensor(sid) for sid in sensor_ids]
+ observations.extend(
+ random_obs_for_sensor(
+ sensor_ids[torch.randint(0, len(sensor_ids), (1,)).item()]
+ )
+ for _ in range(nobs - len(sensor_ids))
+ )
+ else:
+ observations = [
+ random_obs_for_sensor(
+ sensor_ids[torch.randint(0, len(sensor_ids), (1,)).item()]
+ )
+ for _ in range(nobs)
+ ]
+
+ # Sort by sensor_id (required for per-sensor processing)
+ observations.sort(key=lambda x: x["sensor_id"])
+
+ # Build tensors
+ obs = torch.tensor([o["obs"] for o in observations], dtype=torch.float32)
+ time = torch.tensor([o["time"] for o in observations], dtype=torch.long)
+ float_metadata = torch.stack([o["float_meta"] for o in observations])
+ sensor_id_tensor = torch.tensor(
+ [o["sensor_id"] for o in observations], dtype=torch.long
+ )
+
+ pix_tensor = torch.tensor([o["pix"] for o in observations], dtype=torch.long)
+ channel_tensor = torch.tensor(
+ [o["channel"] for o in observations], dtype=torch.long
+ )
+ platform_tensor = torch.tensor(
+ [o["platform"] for o in observations], dtype=torch.long
+ )
+
+ idx = UnifiedObservation.bucket_index
+ int_metadata = torch.zeros((nobs, 6), dtype=torch.long)
+ int_metadata[:, idx.sensor] = sensor_id_tensor
+ int_metadata[:, idx.pix] = pix_tensor
+ int_metadata[:, idx.local_channel] = channel_tensor
+ int_metadata[:, idx.platform] = platform_tensor
+ int_metadata[:, idx.obs_type] = 0
+ int_metadata[:, idx.global_channel] = channel_tensor
+
+ # Build 3D offsets: cumulative end indices over (sensor, batch, time)
+ # All observations go into the first window for simplicity
+ offsets = torch.zeros((n_sensors, batch_size, time_steps), dtype=torch.long)
+ cumulative = 0
+ for s_local, sid in enumerate(sensor_ids):
+ cumulative += (sensor_id_tensor == sid).sum().item()
+ offsets[s_local, :, :] = cumulative
+
+ return UnifiedObservation(
+ obs=obs.to(device),
+ time=time.to(device),
+ float_metadata=float_metadata.to(device),
+ int_metadata=int_metadata.to(device),
+ offsets=offsets.to(device),
+ sensor_id_to_local=sensor_id_to_local.to(device),
+ hpx_level=hpx_level,
+ )