diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 567838a8cab1..88e36d99e664 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -292,10 +292,8 @@ LogicalResult canonicalize_elementwise(int hardware_generation_, return failure(); } auto element_type = ty.getElementType(); - // PowFOp and DivFOp do not seem to be supported in bf16 on later - // hardware. - bool needs_cast = hardware_generation_ <= 5 || isa(op) || - isa(op); + // PowFOp does not seem to be supported in bf16 on later hardware. + bool needs_cast = hardware_generation_ <= 5 || isa(op); if (needs_cast && element_type.isBF16()) { auto target_f32 = builder.create(op.getLoc(), target_f32_ty, operand) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index d848bc437df9..00750766c0c9 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1116,18 +1116,31 @@ def kernel(x_ref, y_ref, out_ref): @parameterized.parameters( ("int32", "float32"), ("float32", "float32"), + ("bfloat16", "bfloat16"), ) def test_true_divide(self, dtype, out_dtype): + if jtu.test_device_matches(["tpu"]): + if out_dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6): + self.skipTest("bfloat16 is not supported on older TPU generations") + if not jtu.if_cloud_tpu_at_least(2024, 12, 21): + self.skipTest("Requires libtpu built after 2024-12-21") @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((8,), out_dtype), + out_shape=jax.ShapeDtypeStruct((8, 8), out_dtype), ) def kernel(x_ref, y_ref, o_ref): o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...]) x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) - np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y)) + x = jnp.repeat(x, 8, axis=0).reshape(8, 8) + y = jnp.tile(y, 8).reshape(8, 8) + rtol = 8e-3 if dtype == "bfloat16" else 1e-6 + np.testing.assert_allclose( + jnp.true_divide(x, y).astype(jnp.float32), + kernel(x, y).astype(jnp.float32), + rtol=rtol, + ) @parameterized.parameters("float16", "bfloat16") def test_true_divide_unsupported(self, dtype):