[pallas] Improve some error messages and add API tests. #22173
+261
−61
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
We make the following improvements:
mismatch
the overall array. Without this we used to get a
safe_zip
error. We also carry the pytree paths to localize the error.
we used to get
body_fun output and input must have same type structure
in the interpreter,
assert len(jaxpr.outvars) == 0
on GPU,and
INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0
on TPU.
To simplify the generation of the error messages we added a helper
function
tree_util.equality_errors_pytreedef
, which is just liketree_util.equality_errors
but takesPyTreeDef
inputs rather thanPyTrees. We then used this new helper function in
pjit.py
andstages.py
.