
HealDA Integration#1356

Draft
aayushg55 wants to merge 21 commits into NVIDIA:main from aayushg55:healda

Conversation

@aayushg55
Contributor

@aayushg55 aayushg55 commented Jan 28, 2026

PhysicsNeMo Pull Request

Description

This PR integrates the HealDA AI data assimilation model's training and inference pipelines into PhysicsNeMo. It adds the following under the examples/weather/healda directory:

  • UFS Replay Observation dataset ETL
  • HealDA training recipe
  • End-to-end forecasting example using HealDA initial conditions with FourCastNet3

The HealDA Observation Embedding and DiT architectures are integrated at physicsnemo/models/healda.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work,
nor an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@aayushg55 aayushg55 marked this pull request as ready for review January 28, 2026 19:54
@aayushg55
Contributor Author

@greptileai

@greptile-apps
Contributor

greptile-apps bot commented Jan 28, 2026

Greptile Overview

Greptile Summary

This PR integrates the HealDA (Hierarchical Lattice Data Assimilation) AI model into PhysicsNeMo, adding training and inference pipelines for weather forecasting with observation data assimilation.

Major Changes:

  • Added DiT (Diffusion Transformer) model architecture in physicsnemo/models/healda/
  • Implemented HEALPix-based patch embedding and decoding layers
  • Added multi-sensor observation embedding system for satellite and conventional weather observations
  • Included UFS Replay Observation dataset ETL pipeline and training recipes
  • Added end-to-end forecasting example integrating HealDA with FourCastNet3

Critical Issues Found:

  1. MOD-001 Violation (MUST FIX): All main model classes (DiT, HPXPatchEmbed, HPXPatchDecode, MultiSensorObsEmbedding, ObsDecoder) inherit from torch.nn.Module instead of physicsnemo.Module. This violates the coding standard and prevents these models from benefiting from PhysicsNeMo's serialization, versioning, and registry features.

  2. MOD-003 Violations: Multiple docstring formatting issues including missing raw string prefix r""", incomplete docstring sections (missing Parameters, Forward, Outputs), and improper NumPy style formatting.

  3. Missing Comprehensive Tests: While basic smoke tests exist, the models lack the required constructor/attributes tests (MOD-008a), non-regression tests with reference data (MOD-008b), and checkpoint loading tests (MOD-008c).

  4. Unused Code: self.silu is initialized but never used in AdaLayerNormZero and AdaLayerNormTemporalAttn classes.

Recommendations:

  • Fix MOD-001 violations by changing all model base classes to inherit from physicsnemo.Module
  • Add proper super().__init__(meta=...) calls with appropriate metadata
  • Complete docstring sections following MOD-003 guidelines
  • Add comprehensive test coverage per MOD-008 requirements before moving out of experimental status
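To make the first recommendation concrete, here is a minimal stdlib-only sketch of the pattern a framework-level base class enables. The `Module` class, its `_registry`, and the `instantiate` factory below are hypothetical stand-ins to illustrate why subclassing a shared base (rather than torch.nn.Module directly) supports generic instantiation and versioning; they are not PhysicsNeMo's actual API:

```python
# Hypothetical sketch (NOT PhysicsNeMo's real API): a shared base class can
# keep a registry of subclasses, which is what makes generic checkpoint
# loading and model versioning possible for every model that inherits it.
class Module:
    _registry = {}  # maps class name -> class, filled automatically

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        Module._registry[cls.__name__] = cls

    @classmethod
    def instantiate(cls, name, **kwargs):
        # Generic factory: rebuild any registered model from name + args.
        return cls._registry[name](**kwargs)


class DiT(Module):  # inheriting the framework base registers the model
    def __init__(self, num_layers=28):
        self.num_layers = num_layers


model = Module.instantiate("DiT", num_layers=4)
print(type(model).__name__, model.num_layers)
```

A model that inherits only torch.nn.Module never enters this registry, which is the practical cost the MOD-001 rule is guarding against.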

Important Files Changed

  • physicsnemo/models/healda/dit.py: adds the DiT transformer model for data assimilation, but violates MOD-001 (inherits from torch.nn.Module instead of physicsnemo.Module) and has docstring formatting issues
  • physicsnemo/models/healda/healpix_layers.py: adds HEALPix patch embedding/decoding layers, but violates MOD-001 and MOD-003c (missing proper docstring sections)
  • physicsnemo/models/healda/obs_embedding/point_embed.py: implements multi-sensor observation embedding, but violates MOD-001 (inherits from torch.nn.Module)
  • physicsnemo/models/healda/obs_embedding/decoder.py: adds the observation decoder from latent representations, but violates MOD-001
  • test/models/healda/test_dit.py: basic smoke tests for DiT model instantiation and forward pass, but lacks the comprehensive tests required by MOD-008a, MOD-008b, and MOD-008c
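As a rough illustration of the constructor/attribute coverage MOD-008a asks for, here is a hedged sketch written against a local stand-in class so it runs on its own; a real test would import DiT from physicsnemo and run under pytest. The `hidden_dim` attribute is hypothetical, while the constructor arguments and defaults are taken from the DiT docstring quoted later in this review:

```python
# Stand-in for the real DiT, using the constructor arguments and defaults
# discussed in this PR (num_attention_heads=16, attention_head_dim=72,
# in_channels=4). `hidden_dim` is a hypothetical derived attribute.
class DiT:
    def __init__(self, num_attention_heads=16, attention_head_dim=72, in_channels=4):
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        self.hidden_dim = num_attention_heads * attention_head_dim


def test_dit_constructor_and_attributes():
    # MOD-008a style: construct with non-default args and assert that
    # attributes round-trip and derived values are consistent.
    model = DiT(num_attention_heads=2, attention_head_dim=8)
    assert model.num_attention_heads == 2
    assert model.hidden_dim == 16
    assert model.in_channels == 4  # untouched default preserved


test_dit_constructor_and_attributes()
print("constructor test passed")
```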


@greptile-apps greptile-apps bot left a comment


5 files reviewed, 9 comments


self._parallel_group = group


class DiT(torch.nn.Module):

violates MOD-001: inherits from torch.nn.Module instead of physicsnemo.Module

Suggested change
class DiT(torch.nn.Module):
class DiT(physicsnemo.Module):

)


class HPXPatchEmbed(torch.nn.Module):

violates MOD-001: inherits from torch.nn.Module instead of physicsnemo.Module

Suggested change
class HPXPatchEmbed(torch.nn.Module):
class HPXPatchEmbed(physicsnemo.Module):

return x + pos_embed + calendar_embed


class HPXPatchDecode(torch.nn.Module):

violates MOD-001: inherits from torch.nn.Module instead of physicsnemo.Module

Suggested change
class HPXPatchDecode(torch.nn.Module):
class HPXPatchDecode(physicsnemo.Module):

return self._forward(obs)


class MultiSensorObsEmbedding(torch.nn.Module):

violates MOD-001: inherits from torch.nn.Module instead of physicsnemo.Module

Suggested change
class MultiSensorObsEmbedding(torch.nn.Module):
class MultiSensorObsEmbedding(physicsnemo.Module):

return x


class ObsDecoder(torch.nn.Module):

violates MOD-001: inherits from torch.nn.Module instead of physicsnemo.Module

Suggested change
class ObsDecoder(torch.nn.Module):
class ObsDecoder(physicsnemo.Module):

Comment on lines 610 to 640
r"""
A 2D Transformer model as introduced in DiT (https://huggingface.co/papers/2212.09748).

Parameters:
num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
in_channels (int, defaults to 4): The number of channels in the input.
out_channels (int, optional):
The number of channels in the output. Specify this parameter if the output channel number differs from the
input.
num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
norm_num_groups (int, optional, defaults to 32):
Number of groups for group normalization within Transformer blocks.
attention_bias (bool, optional, defaults to True):
Configure if the Transformer blocks' attention should contain a bias parameter.
sample_size (int, defaults to 32):
The width of the latent images. This parameter is fixed during training.
patch_size (int, defaults to 2):
Size of the patches the model processes, relevant for architectures working on non-sequential data.
activation_fn (str, optional, defaults to "gelu-approximate"):
Activation function to use in feed-forward networks within Transformer blocks.
upcast_attention (bool, optional, defaults to False):
If true, upcasts the attention mechanism dimensions for potentially improved performance.
norm_type (str, optional, defaults to "ada_norm_zero"):
Specifies the type of normalization used, can be 'ada_norm_zero'.
norm_elementwise_affine (bool, optional, defaults to False):
If true, enables element-wise affine parameters in the normalization layers.
norm_eps (float, optional, defaults to 1e-5):
A small constant added to the denominator in normalization layers to prevent division by zero.
"""

violates MOD-003b: docstring does not use raw string prefix r"""

The docstring should start with r""" to properly render LaTeX math in Sphinx documentation.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
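A small self-contained illustration of why the r""" prefix matters; the function and its math are illustrative examples, not code from this PR:

```python
import math


def gelu_approx(x: float) -> float:
    r"""Approximate GELU activation.

    Implements the tanh approximation
    :math:`0.5 x \left(1 + \tanh\left(\sqrt{2/\pi}\,(x + 0.044715 x^3)\right)\right)`.
    The ``r`` prefix keeps backslash sequences such as ``\tanh`` intact so
    Sphinx can render the math; without it, ``\t`` would become a literal tab
    character inside ``__doc__`` and break the rendered documentation.
    """
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# The backslash survived because of the raw-string prefix.
print("\\tanh" in gelu_approx.__doc__)  # -> True
```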

Comment on lines 98 to 106
class HPXPatchEmbed(torch.nn.Module):
"""

Args:
input: (b c t x)
Returns
output: (b t x_model c_model)

"""

violates MOD-003c: missing required Parameters, Forward, and Outputs docstring sections

The docstring should follow NumPy style with proper sections as specified in the coding standards.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
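For reference, a hedged sketch of the docstring shape this comment is asking for, applied to the quoted HPXPatchEmbed signature; the class body is a non-functional stand-in and the parameter names (`in_channels`, `embed_dim`) are illustrative, with the tensor shapes taken from the original docstring:

```python
# Stand-in showing the NumPy-style sections (Parameters, Forward, Outputs);
# parameter names are hypothetical, shapes come from the quoted docstring.
class HPXPatchEmbed:
    """Patch embedding over a HEALPix grid.

    Parameters
    ----------
    in_channels : int
        Number of input channels ``c``.
    embed_dim : int
        Embedding dimension ``c_model``.

    Forward
    -------
    input : Tensor
        Input of shape ``(b, c, t, x)``.

    Outputs
    -------
    output : Tensor
        Output of shape ``(b, t, x_model, c_model)``.
    """

    def __init__(self, in_channels: int, embed_dim: int):
        self.in_channels = in_channels
        self.embed_dim = embed_dim


print("Parameters" in HPXPatchEmbed.__doc__)  # -> True
```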

Comment on lines 93 to 94
# TODO silu unused. Is this a bug? --noah 9/5/25
self.silu = nn.SiLU()

unused attribute self.silu initialized but never used

Consider removing if unnecessary, or document why it's kept for potential future use.

def __init__(self, embedding_dim: int, emb_channels: int, bias=True):
super().__init__()
# TODO silu unused. Is this a bug? --noah 9/5/25
self.silu = nn.SiLU()

unused attribute self.silu initialized but never used (duplicate issue)

Same as AdaLayerNormZero - consider removing if unnecessary.

Collaborator

@coreyjadams coreyjadams left a comment


Hi @aayushg55 ,

There are some things to address in this PR before we can review it properly. I haven't looked at the logic of the code, I'll let a subject matter expert review.

The code in physicsnemo/models/healda is not aligned with the standards we're moving to for physicsnemo v2.0. All of this feedback is arising because physicsnemo was becoming very fragmented, circular, and unmanageable with such large PRs. We're attempting to improve the user and developer experience by reducing some of this duplication, coding practice violations, etc. Your PR is coming as we're starting to really enforce this stuff - sorry. But we'll have to get the model implementation up to spec before merging.

Some specific things I have found looking quickly, though certainly not everything:

  • model code must enter physicsnemo through the experimental folder. We already have a DiT in the experimental folder, and if they can be combined to maximize overlap and minimize code repetition we should do that.
  • We also have already some substantial amount of healpix and earth2grid layers / work. We can't duplicate here - please extend the existing layers, as needed, but repetition isn't maintainable.
  • There is a lot of tooling in the models folder that does not belong there. profiling.py and types.py come to mind as utilities. sharding looks like it might be domain-parallel specific but I don't know what it's doing. We have a whole suite of distributed and domain parallel tooling, perhaps those should be there.
  • bare earth2grid imports in physicsnemo are not allowed.
  • We have a number of embeddings already that seem very similar to some of the ones in embedding.py. Let's not duplicate.
  • avoid importing from a subdirectory (obs_embedding) into a higher directory like you do, it's just a circular import waiting to happen.
  • you have files missing license headers in places.

We have a pre-commit system in physicsnemo that should have caught a lot of these - did you try it?

I think this PR has a ways to go before proper review can start. Would you like to convert it to draft and we can help you? We can schedule a meeting with someone on the physicsnemo team to give you some guidance.

All the best,
Corey

@aayushg55
Contributor Author

Thanks @coreyjadams for taking a look. I agree this needs significant work and refactoring for a proper integration with physicsnemo, and this PR was mainly intended to start the public release process.

I’ll mark this as a draft for now. I believe @NickGeneva will be helping me bring this into better shape.

Re: pre-commit — Yes, I had run the pre-commit hooks and resolved the linting/license issues it flagged.

@aayushg55 aayushg55 marked this pull request as draft January 28, 2026 22:58
@nbren12
Collaborator

nbren12 commented Jan 29, 2026

Keep in mind that we need to maintain checkpoint compatibility, since this checkpoint has been heavily validated and released via a publication. Some of the concerns about apparent code duplication need to be weighed against that. Refactoring and changing state-dict names is fine, but these changes should not change the answer. Hundreds of person-hours are invested in this checkpoint, so we shouldn't be too rigid.

The hpx embedding layer strikes me as a risky component to refactor, and also a fairly simple one which shouldn’t increase maintenance burden much. Not sure if you have a hpx layer yet.
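One compatible way to reconcile renames with a frozen checkpoint is a pure key remap that never touches the stored tensors. A minimal sketch, assuming a hypothetical old-to-new prefix mapping (the key names below are illustrative, not HealDA's actual state dict):

```python
# Rename state-dict keys during a refactor without changing the stored
# values, so a published checkpoint still loads with identical numerics.
def remap_state_dict(state_dict, renames):
    out = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in renames.items():
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix):]
                break
        out[new_key] = value  # values untouched: numerics are preserved
    return out


# Hypothetical keys for illustration only.
old = {"obs_embedding.mlp.weight": [1.0, 2.0], "dit.block0.weight": [3.0]}
new = remap_state_dict(old, {"obs_embedding.": "point_embed."})
print(sorted(new))  # -> ['dit.block0.weight', 'point_embed.mlp.weight']
```

A remap like this can be applied once when loading the legacy checkpoint, which keeps the refactor free while leaving the validated weights bit-for-bit identical.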

@NickGeneva NickGeneva self-assigned this Jan 29, 2026
- removed ObsDecoder/Sharding/Profiling/SubDomain
- HpxPatch Embed/Decode subclass the existing DiT tokenizers
- refactored the noise+condition embedding into a module to be compatible
- moved HealDA to experimental
- added DropPath to DiT
- breaking changes for DropOut in DiT MlpLayer and TE Attn backend
  (proj_out)
- moved HealDA model config classes to examples, but retain sensor
  related configs
- flattened the obs_embedding subdir
