Skip to content

Commit

Permalink
[pallas] Improve some error messages and add API tests.
Browse files Browse the repository at this point in the history
We make the following improvements:

  * pytree structural disequality messages attempt to localize the
    mismatch
  * we check that the rank of the block_shape matches the rank of
    the overall array. Without this we used to get a `safe_zip`
    error. We also carry the pytree paths to localize the error.

To simplify the generation of the error messages we added a helper
function `tree_util.equality_errors_pytreedef`, which is just like
`tree_util.equality_errors` but takes `PyTreeDef` inputs rather than
PyTrees. We then used this new helper function in `pjit.py` and `stages.py`.
  • Loading branch information
gnecula committed Jul 1, 2024
1 parent 727d120 commit 1397dad
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 60 deletions.
88 changes: 58 additions & 30 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,34 +300,47 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
shape = tuple(s for s in block_shape if s is not None)
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))


def _get_ref_avals(grid, in_avals, in_specs, out_avals, out_specs):
def _get_memory_space(spec):
def _get_ref_avals(grid, in_avals, in_specs, in_paths,
out_avals, out_specs, out_paths):
assert grid is not None # TODO(necula): can it ever be None?
# if grid is None:
# in_specs = [None] * len(in_avals)
# out_specs = [None] * len(out_avals)

def make_ref(aval: jax_core.ShapedArray, spec: BlockSpec,
path: tree_util.KeyPath,
what: str) -> state.AbstractRef:
if spec is no_block_spec:
return None
return spec.memory_space
memory_space = None
block_shape = None
else:
memory_space = spec.memory_space
block_shape = spec.block_shape

ref_aval = AbstractMemoryRef(aval, memory_space)
if block_shape is None:
return ref_aval
if len(block_shape) != len(aval.shape):
raise ValueError(
f"The rank of the {what} block_shape at "
f"{what}{tree_util.keystr(path)} (= {len(block_shape)}) "
"does not match the rank of the corresponding "
f"{what} shape (= {aval.shape})")
for axis, (block_d, shape_d) in enumerate(zip(block_shape, aval.shape)):
assert shape_d is not None # when does this happen?
assert block_d is not None
shape = tuple(s for s in block_shape if s is not None)
return ref_aval.update(inner_aval=ref_aval.inner_aval.update(shape=shape))

in_ref_avals = [
AbstractMemoryRef(aval, _get_memory_space(in_spec))
for aval, in_spec in zip(in_avals, in_specs)
make_ref(aval, in_spec, in_path, "input")
for aval, in_spec, in_path in zip(in_avals, in_specs, in_paths)
]
out_ref_avals = [
AbstractMemoryRef(aval, _get_memory_space(out_spec))
for aval, out_spec in zip(out_avals, out_specs)
make_ref(aval, out_spec, out_path, "output")
for aval, out_spec, out_path in zip(out_avals, out_specs, out_paths)
]
if grid is None:
in_specs = [None] * len(in_avals)
out_specs = [None] * len(out_avals)
tiled_in_ref_avals = [
aval if in_spec is no_block_spec
else _tile_ref(aval, in_spec.block_shape)
for aval, in_spec in zip(in_ref_avals, in_specs)
]
tiled_out_ref_avals = [
aval if out_spec is no_block_spec
else _tile_ref(aval, out_spec.block_shape)
for aval, out_spec in zip(out_ref_avals, out_specs)
]
return in_specs, tiled_in_ref_avals, out_specs, tiled_out_ref_avals
return in_specs, in_ref_avals, out_specs, out_ref_avals

class NoBlockSpec:
pass
Expand Down Expand Up @@ -375,20 +388,20 @@ def _get_in_out_specs(self, in_avals, in_tree, out_avals, out_tree):
flat_in_specs = self.in_specs
if self.in_specs_tree != in_tree:
raise ValueError(
"Pytree specs for arguments and `in_specs` must match: "
f"{in_tree} vs. {self.in_specs_tree}")
pytreedef_mismatch_err_msg("`in_specs`", self.in_specs_tree,
"inputs", in_tree))
if self.out_specs is no_block_spec:
flat_out_specs = [no_block_spec] * len(out_avals)
else:
flat_out_specs = self.out_specs
if self.out_specs_tree != out_tree:
raise ValueError(
"Pytree specs for `out_shape` and `out_specs` must match: "
f"{out_tree} vs. {self.out_specs_tree}")
pytreedef_mismatch_err_msg("`out_specs`", self.out_specs_tree,
"`out_shape`", out_tree))
return flat_in_specs, flat_out_specs

def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
Expand All @@ -397,8 +410,8 @@ def get_grid_mapping(
flat_in_specs, flat_out_specs = self._get_in_out_specs(
in_avals, in_tree, out_avals, out_tree)
in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
self.grid, in_avals, flat_in_specs, out_avals,
flat_out_specs)
self.grid, in_avals, flat_in_specs, in_paths,
out_avals, flat_out_specs, out_paths)
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
Expand Down Expand Up @@ -445,3 +458,18 @@ def unzip_dynamic_grid_bounds(
static_self = copy.copy(self)
static_self.grid = static_grid # type: ignore
return static_self, dynamic_bounds

def pytreedef_mismatch_err_msg(
    what1: str, tree1: tree_util.PyTreeDef,
    what2: str, tree2: tree_util.PyTreeDef) -> str:
  """Formats an error message for two pytree structures that do not match.

  Args:
    what1: description of where the first pytree comes from (e.g. "`in_specs`").
    tree1: the structure of the first pytree.
    what2: description of where the second pytree comes from.
    tree2: the structure of the second pytree.

  Returns:
    A multi-line message enumerating the structural mismatches reported by
    `tree_util.equality_errors_pytreedef`.
  """
  mismatches = list(tree_util.equality_errors_pytreedef(tree1, tree2))
  lines = [f"Pytree for {what1} and {what2} do not match. "
           f"There are {len(mismatches)} mismatches, including:"]
  for key_path, got1, got2, why in mismatches:
    # Omit the location prefix when the mismatch is at the pytree root.
    at = f"at {tree_util.keystr(key_path)}, " if key_path else ""
    lines.append(f" * {at}{what1} is a {got1} but"
                 f" {what2} is a {got2}, so {why}")
  return "\n".join(lines)
6 changes: 3 additions & 3 deletions jax/_src/pallas/mosaic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(
self.scratch_shapes = tuple(scratch_shapes)

def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
Expand All @@ -189,8 +189,8 @@ def get_grid_mapping(
in_avals, in_avals_tree, out_avals, out_tree)
in_specs, in_ref_avals, out_specs, out_ref_avals = (
pallas_core._get_ref_avals(
self.grid, in_avals, flat_in_specs,
out_avals, flat_out_specs))
self.grid, in_avals, flat_in_specs, in_paths,
out_avals, flat_out_specs, out_paths))
scalar_ref_avals = [
AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
TPUMemorySpace.SMEM)
Expand Down
29 changes: 19 additions & 10 deletions jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
import jax
from jax import api_util
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import checkify
from jax._src import config
from jax._src import core as jax_core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
Expand All @@ -45,6 +45,7 @@
safe_zip,
split_list,
tuple_insert,
unzip2,
weakref_lru_cache,
)
import jax.numpy as jnp
Expand Down Expand Up @@ -893,10 +894,16 @@ def checked_kernel_fn(*args):
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule

@weakref_lru_cache
def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals,
flat_out_avals, in_tree, out_tree, interpret: bool):
avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree,
flat_out_avals, out_tree)
def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
flat_in_avals: Sequence[jax_core.AbstractValue],
flat_out_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
interpret: bool):
avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree, in_paths,
flat_out_avals, out_tree, out_paths)
if interpret:
avals = jax.tree_util.tree_map(_logical_aval_to_interpret_mode_aval, avals)
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
Expand Down Expand Up @@ -1058,19 +1065,21 @@ def pallas_call(
grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
if isinstance(out_shape, list):
out_shape = tuple(out_shape)
flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape)
out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths)
flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore
for x in flat_out_shapes]
@jax.jit
def wrapped(*args):
flat_args, in_tree = tree_util.tree_flatten(args)
flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
in_paths, flat_args = unzip2(flat_args_with_paths)
flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
for a in flat_args)
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)
grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
f, grid_spec, flat_in_avals, flat_out_avals, in_tree,
out_tree, interpret=interpret)
f, grid_spec, flat_in_avals, flat_out_avals, in_tree, in_paths,
out_tree, out_paths, interpret=interpret)
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *consts, *flat_args,
jaxpr=jaxpr, name=name,
Expand Down
6 changes: 1 addition & 5 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,11 +1209,7 @@ def unpack(key):
p(f" never seen input pytree{in_tree_str}")
dont_match = [t for t, *_ in seen_keys if t != in_tree]
closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves))
# TODO(mattjj): make equality_errors not print type name, avoid metaclass
leaf = type('LeafMeta', (type,), dict(__repr__=lambda _: 'leaf'))('Leaf', (), {})()
this_dummy = tree_unflatten(in_tree, [leaf] * in_tree.num_leaves)
close_dummy = tree_unflatten(closest_tree, [leaf] * closest_tree.num_leaves) # type: ignore
errs = list(tree_util.equality_errors(this_dummy, close_dummy))
errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree)) # type: ignore[arg-type]
p(f" closest seen input pytree has {len(errs)} mismatches, including:")
for path, thing1, thing2, explanation in errs:
fst, *path = path # type: ignore
Expand Down
13 changes: 2 additions & 11 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
from jax._src.tree_util import tree_unflatten, keystr
from jax._src import util
from jax._src.sharding_impls import is_unspecified_or_auto
from jax._src.layout import Layout
Expand Down Expand Up @@ -590,11 +589,7 @@ def call(*args, **kwargs):
f"keyword arguments, but called with keyword arguments: {kws}")
args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
if in_tree != params.in_tree:
leaf = PytreeLeaf()
this_dummy = tree_unflatten(in_tree, [leaf] * in_tree.num_leaves)
other_dummy = tree_unflatten(
params.in_tree, [leaf] * params.in_tree.num_leaves)
errs = list(tree_util.equality_errors(this_dummy, other_dummy))
errs = list(tree_util.equality_errors_pytreedef(in_tree, in_tree))
msg = []
msg.append(
"Function compiled with input pytree does not match the input pytree"
Expand All @@ -603,7 +598,7 @@ def call(*args, **kwargs):
fst, *rest = path
base = ['args', 'kwargs'][fst.idx]
msg.append(
f" * at {base}{keystr(tuple(rest))}, seen {thing2} but now"
f" * at {base}{tree_util.keystr(tuple(rest))}, seen {thing2} but now"
f" given {thing1}, so {explanation}")
raise TypeError('\n'.join(msg))
try:
Expand Down Expand Up @@ -641,10 +636,6 @@ def cpp_call_fallback(*args, **kwargs):
return self._call(*args, **kwargs)


class PytreeLeaf:
def __repr__(self): return "pytree leaf"


class Lowered(Stage):
"""Lowering of a function specialized to argument types and values.
Expand Down
11 changes: 10 additions & 1 deletion jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def equality_errors(
"""Helper to describe structural differences between two pytrees.
Args:
tree1, tree2: pytrees to compare.
tree1, tree2: pytrees known to have different structure.
Usage:
Expand All @@ -636,6 +636,15 @@ def equality_errors(
"""
yield from _equality_errors((), tree1, tree2, is_leaf)

def equality_errors_pytreedef(
    tree1: PyTreeDef,
    tree2: PyTreeDef) -> Iterable[tuple[KeyPath, str, str, str]]:
  """Like `equality_errors` but invoked on PyTreeDef."""
  # `equality_errors` wants actual pytrees, so rebuild one from each treedef
  # using a dummy leaf whose type prints as "pytree leaf" in the messages.
  # TODO(mattjj): make equality_errors not print type name, avoid metaclass
  meta = type("LeafMeta", (type,), {"__repr__": lambda _: "pytree leaf"})
  leaf = meta("Leaf", (), {})()
  dummy1 = tree_unflatten(tree1, [leaf] * tree1.num_leaves)
  dummy2 = tree_unflatten(tree2, [leaf] * tree2.num_leaves)
  return equality_errors(dummy1, dummy2)

# TODO(mattjj): maybe share some logic with _prefix_error?
def _equality_errors(path, t1, t2, is_leaf):
# If both are leaves, this isn't a structure equality error.
Expand Down
28 changes: 28 additions & 0 deletions tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,34 @@ package(

jax_generate_backend_suites()

# Pallas API tests (added alongside the improved error messages).
jax_test(
    name = "api_test",
    srcs = ["api_test.py"],
    config_tags_overrides = {
        "gpu_a100_x32": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    # Disabled on all GPU configs except gpu_a100_x32 (enabled below);
    # presumably the test needs A100-class hardware — confirm with owners.
    disable_configs = [
        "gpu",
        "gpu_x32",
        "gpu_a100",
        "gpu_h100",
        "gpu_p100",
        "gpu_p100_x32",
        "gpu_pjrt_c_api",
    ],
    enable_configs = [
        "gpu_a100_x32",
    ],
    tags = [],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu",  # build_cleaner: keep
    ],
)

jax_test(
name = "pallas_test",
srcs = [
Expand Down
Loading

0 comments on commit 1397dad

Please sign in to comment.