Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support stash_type attribute for onnx.LayerNormalization #3888

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

jinchen62
Copy link
Collaborator

@jinchen62 jinchen62 commented Nov 22, 2024

Fixes nod-ai/SHARK-ModelDev#888

If stash_type is different from input_dtype/result_dtype:

  1. convert x dtype to stash_type
  2. calculate mean and var in stash_type since x is in stash_type already
  3. convert back to result_dtype before stage two calculation
  4. convert mean_dtype and var_dtype if they are different from stash_type

@zjgarvey
Copy link
Collaborator

I think we should probably support the stash type arg by separating the two stages of computation as is suggested by ONNX in https://onnx.ai/onnx/operators/onnx__LayerNormalization.html. If an onnx op actually has different result types and stash types, we would likely see numeric mismatches for those situations unless we perform the computation correctly.

Another option is to allow LayerNormalization to be function-expanded on import via

function_expansion_allowlists_by_domain: Optional[Dict[str, set[str]]] = field(

In any case, we should put together a few e2e tests for this op:

  1. With bf16 result type and bf16 stash type
  2. With bf16 result type and unspecified stash type

@jinchen62 jinchen62 changed the title Remove stash_type check for onnx.LayerNormalization lowering Support stash_type attribute for onnx.LayerNormalization Nov 24, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

(torch-to-onnx) FLUX.1 - bf16 onnx.LayerNormalization failing to legalize
2 participants