Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pallas] Improve some error messages and add API tests. #22173

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Jun 28, 2024

We make the following improvements:

  • pytree structural disequality messages attempt to localize the
    mismatch
  • we check that the rank of the block_shape matches the rank of
    the overall array. Without this we used to get a safe_zip
    error. We also carry the pytree paths to localize the error.
  • We check that the kernel function returns None. Without this
    we used to get body_fun output and input must have same type structure
    in the interpreter, assert len(jaxpr.outvars) == 0 on GPU,
    and INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0
    on TPU.

To simplify the generation of the error messages we added a helper
function tree_util.equality_errors_pytreedef, which is just like
tree_util.equality_errors but takes PyTreeDef inputs rather than
PyTrees. We then used this new helper function in pjit.py and stages.py.

@gnecula gnecula self-assigned this Jun 28, 2024
@gnecula gnecula added the pull ready Ready for copybara import and testing label Jun 28, 2024
@gnecula gnecula force-pushed the pallas_errors branch 2 times, most recently from 327c5dd to 1397dad Compare July 1, 2024 12:22
We make the following improvements:

  * pytree structural disequality messages attempt to localize the
    mismatch
  * we check that the rank of the block_shape matches the rank of
    the overall array. Without this we used to get a `safe_zip`
    error. We also carry the pytree paths to localize the error.
  * We check that the kernel function returns None. Without this
    we used to get "body_fun output and input must have same type structure"
    in the interpreter, `assert len(jaxpr.outvars) == 0` on GPU,
    and "INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0"
    on TPU.

To simplify the generation of the error messages we added a helper
function `tree_util.equality_errors_pytreedef`, which is just like
`tree_util.equality_errors` but takes `PyTreeDef` inputs rather than
PyTrees. We then used this new helper function in `pjit.py` and `stages.py`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant