From ec59816975da5abe6a01cbe31feaf681bb5c4e34 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Thu, 20 Nov 2025 23:59:11 +0000 Subject: [PATCH 01/29] support all tests in layers to tpu --- .github/workflows/actions.yml | 2 +- .gitignore | 1 + .../embedding/distributed_embedding_test.py | 101 +++------- .../src/layers/embedding/embed_reduce_test.py | 178 +++++++++++++----- .../tensorflow/config_conversion_test.py | 5 + .../dot_interaction_test.py | 77 +++++--- .../feature_interaction/feature_cross_test.py | 63 +++++-- .../retrieval/brute_force_retrieval_test.py | 56 ++++-- .../retrieval/hard_negative_mining_test.py | 66 +++++-- .../retrieval/remove_accidental_hits_test.py | 82 +++++--- .../src/layers/retrieval/retrieval_test.py | 37 +++- .../sampling_probability_correction_test.py | 59 ++++-- keras_rs/src/testing/test_case.py | 8 + keras_rs/src/utils/tpu_test_utils.py | 78 ++++++++ 14 files changed, 572 insertions(+), 241 deletions(-) create mode 100644 keras_rs/src/utils/tpu_test_utils.py diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 5088f571..9b06868e 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -90,7 +90,7 @@ jobs: run: python3 -c "import jax; print('JAX devices:', jax.devices())" - name: Test with pytest - run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py + run: pytest keras_rs/src/layers/ check_format: name: Check the code format diff --git a/.gitignore b/.gitignore index eacd3be8..dfbf3335 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ build/ .idea/ venv/ +venv_tf/ diff --git a/keras_rs/src/layers/embedding/distributed_embedding_test.py b/keras_rs/src/layers/embedding/distributed_embedding_test.py index cb4df82f..c7ced5bc 100644 --- a/keras_rs/src/layers/embedding/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/distributed_embedding_test.py @@ -1,4 +1,3 @@ -import contextlib import functools import math import os @@ -14,6 +13,7 @@ from keras_rs.src import types from keras_rs.src.layers.embedding import distributed_embedding from keras_rs.src.layers.embedding import distributed_embedding_config as config +from keras_rs.src.utils import tpu_test_utils try: import jax @@ -30,28 +30,6 @@ SEQUENCE_LENGTH = 13 -class DummyStrategy: - def scope(self): - return contextlib.nullcontext() - - @property - def num_replicas_in_sync(self): - return 1 - - def run(self, fn, args): - return fn(*args) - - def experimental_distribute_dataset(self, dataset, options=None): - del options - return dataset - - -class JaxDummyStrategy(DummyStrategy): - @property - def num_replicas_in_sync(self): - return jax.device_count("tpu") - - def ragged_bool_true(self): return True @@ -74,46 +52,11 @@ def setUp(self): # FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16 # FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16 - resolver = tf.distribute.cluster_resolver.TPUClusterResolver() - tf.config.experimental_connect_to_cluster(resolver) - - topology = tf.tpu.experimental.initialize_tpu_system(resolver) - tpu_metadata = resolver.get_tpu_system_metadata() - - device_assignment = tf.tpu.experimental.DeviceAssignment.build( - topology, num_replicas=tpu_metadata.num_hosts - ) - self._strategy = tf.distribute.TPUStrategy( - resolver, experimental_device_assignment=device_assignment - ) - print("### num_replicas", self._strategy.num_replicas_in_sync) - self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver) - elif keras.backend.backend() == "jax" and self.on_tpu: - self._strategy = JaxDummyStrategy() - else: - self._strategy = DummyStrategy() - + self._strategy = tpu_test_utils.get_tpu_strategy(self) self.batch_size = ( BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync ) - def run_with_strategy(self, fn, *args, jit_compile=False): - """Wrapper for running a function under a strategy.""" - - if keras.backend.backend() == "tensorflow": - - @tf.function(jit_compile=jit_compile) - def tf_function_wrapper(*tf_function_args): - def strategy_fn(*strategy_fn_args): - return fn(*strategy_fn_args) - - return self._strategy.run(strategy_fn, args=tf_function_args) - - return tf_function_wrapper(*args) - else: - self.assertFalse(jit_compile) - return fn(*args) - def get_embedding_config(self, input_type, placement): sequence_length = 1 if input_type == "dense" else SEQUENCE_LENGTH @@ -263,7 +206,9 @@ def test_basics(self, input_type, placement): preprocessed_inputs = layer.preprocess(inputs, weights) res = layer(preprocessed_inputs) else: - res = self.run_with_strategy(layer.__call__, inputs, weights) + res = tpu_test_utils.run_with_strategy( + self._strategy, layer.__call__, inputs, weights + ) if placement == "default_device" or not self.on_tpu: # verify sublayers and variables are tracked @@ -422,14 +367,14 @@ def test_dataset_generator(): model.compile(optimizer="adam", loss="mse") model_inputs, _ = next(iter(test_dataset)) - test_output_before = self.run_with_strategy( - model.__call__, model_inputs + test_output_before = tpu_test_utils.run_with_strategy( + self._strategy, model.__call__, model_inputs ) model.fit(train_dataset, steps_per_epoch=1, epochs=1) - test_output_after = self.run_with_strategy( - model.__call__, model_inputs + test_output_after = tpu_test_utils.run_with_strategy( + self._strategy, model.__call__, model_inputs ) # Verify that the embedding has actually trained. @@ -610,10 +555,16 @@ def test_correctness( preprocessed, ) else: - res = self.run_with_strategy(layer.__call__, preprocessed) + res = tpu_test_utils.run_with_strategy( + self._strategy, layer.__call__, preprocessed + ) else: - res = self.run_with_strategy( - layer.__call__, inputs, weights, jit_compile=jit_compile + res = tpu_test_utils.run_with_strategy( + self._strategy, + layer.__call__, + inputs, + weights, + jit_compile=jit_compile, ) self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM)) @@ -686,7 +637,9 @@ def test_shared_table(self): with self._strategy.scope(): layer = distributed_embedding.DistributedEmbedding(embedding_config) - res = self.run_with_strategy(layer.__call__, inputs) + res = tpu_test_utils.run_with_strategy( + self._strategy, layer.__call__, inputs + ) if self.placement == "default_device": self.assertLen(layer._flatten_layers(include_self=False), 1) @@ -760,7 +713,9 @@ def test_mixed_placement(self): with self._strategy.scope(): layer = distributed_embedding.DistributedEmbedding(embedding_config) - res = self.run_with_strategy(layer.__call__, inputs) + res = tpu_test_utils.run_with_strategy( + self._strategy, layer.__call__, inputs + ) self.assertEqual( res["feature1"].shape, (self.batch_size, embedding_output_dim1) @@ -793,13 +748,15 @@ def test_save_load_model(self): keras_outputs = layer(keras_inputs) model = keras.Model(inputs=keras_inputs, outputs=keras_outputs) - output_before = self.run_with_strategy(model.__call__, inputs) + output_before = tpu_test_utils.run_with_strategy( + self._strategy, model.__call__, inputs + ) model.save(path) with self._strategy.scope(): reloaded_model = keras.models.load_model(path) - output_after = self.run_with_strategy( - reloaded_model.__call__, inputs + output_after = tpu_test_utils.run_with_strategy( + self._strategy, reloaded_model.__call__, inputs ) if self.placement == "sparsecore": diff --git a/keras_rs/src/layers/embedding/embed_reduce_test.py b/keras_rs/src/layers/embedding/embed_reduce_test.py index 1d7fb456..908067ca 100644 --- a/keras_rs/src/layers/embedding/embed_reduce_test.py +++ b/keras_rs/src/layers/embedding/embed_reduce_test.py @@ -1,6 +1,9 @@ import math +import os import keras +import tensorflow as tf +from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -8,9 +11,26 @@ from keras_rs.src import testing from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce +from keras_rs.src.utils import tpu_test_utils + +try: + import jax + from jax.experimental import sparse as jax_sparse +except ImportError: + jax = None + jax_sparse = None class EmbedReduceTest(testing.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.on_tpu = "TPU_NAME" in os.environ + + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + + self._strategy = tpu_test_utils.get_tpu_strategy(self) + @parameterized.named_parameters( [ ( @@ -42,50 +62,87 @@ def test_call(self, combiner, input_type, input_rank, use_weights): ): self.skipTest(f"sparse not supported on {keras.backend.backend()}") + if self.on_tpu and input_type in ["ragged", "sparse"]: + self.skipTest("Ragged and sparse are not compilable on TPU.") + + batch_size = 2 * self._strategy.num_replicas_in_sync + + def repeat_input(item, times): + return [item[i % len(item)] for i in range(times)] + if input_type == "dense" and input_rank == 1: - inputs = ops.convert_to_tensor([1, 2]) - weights = ops.convert_to_tensor([1.0, 2.0]) + inputs = ops.convert_to_tensor(repeat_input([1, 2], batch_size)) + weights = ops.convert_to_tensor( + repeat_input([1.0, 2.0], batch_size) + ) elif input_type == "dense" and input_rank == 2: - inputs = ops.convert_to_tensor([[1, 2], [3, 4]]) - weights = ops.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]]) + inputs = ops.convert_to_tensor( + repeat_input([[1, 2], [3, 4]], batch_size) + ) + weights = ops.convert_to_tensor( + repeat_input([[1.0, 2.0], [3.0, 4.0]], batch_size) + ) elif input_type == "ragged" and input_rank == 2: - import tensorflow as tf - - inputs = tf.ragged.constant([[1], [2, 3, 4, 5]]) - weights = tf.ragged.constant([[1.0], [1.0, 2.0, 3.0, 4.0]]) + inputs = tf.ragged.constant( + repeat_input([[1], [2, 3, 4, 5]], batch_size) + ) + weights = tf.ragged.constant( + repeat_input([[1.0], [1.0, 2.0, 3.0, 4.0]], batch_size) + ) elif input_type == "sparse" and input_rank == 2: - indices = [[0, 0], [1, 0], [1, 1], [1, 2], [1, 3]] + base_indices = [[0, 0], [1, 0], [1, 1], [1, 2], [1, 3]] + base_values = [1, 2, 3, 4, 5] + base_weights = [1.0, 1.0, 2.0, 3.0, 4.0] + indices = [] + values = [] + weight_values = [] + for i in range(batch_size // 2): + for idx, val, wgt in zip( + base_indices, base_values, base_weights + ): + indices.append([i * 2 + idx[0], idx[1]]) + values.append(val) + weight_values.append(wgt) if keras.backend.backend() == "tensorflow": - import tensorflow as tf - inputs = tf.sparse.reorder( - tf.SparseTensor(indices, [1, 2, 3, 4, 5], (2, 4)) + tf.SparseTensor(indices, values, (batch_size, 4)) ) weights = tf.sparse.reorder( - tf.SparseTensor(indices, [1.0, 1.0, 2.0, 3.0, 4.0], (2, 4)) + tf.SparseTensor(indices, weight_values, (batch_size, 4)) ) elif keras.backend.backend() == "jax": - from jax.experimental import sparse as jax_sparse - inputs = jax_sparse.BCOO( - ([1, 2, 3, 4, 5], indices), - shape=(2, 4), + (ops.array(values), ops.array(indices)), + shape=(batch_size, 4), unique_indices=True, ) weights = jax_sparse.BCOO( - ([1.0, 1.0, 2.0, 3.0, 4.0], indices), - shape=(2, 4), + (ops.array(weight_values), ops.array(indices)), + shape=(batch_size, 4), unique_indices=True, ) if not use_weights: weights = None - layer = EmbedReduce(10, 20, combiner=combiner) - res = layer(inputs, weights) + with self._strategy.scope(): + layer = EmbedReduce(10, 20, combiner=combiner) + + if keras.backend.backend() == "tensorflow": + # TF requires weights to be None or match input type + if input_type == "sparse" and not use_weights: + res = tpu_test_utils.run_with_strategy( + self._strategy, layer.__call__, inputs + ) + else: + res = tpu_test_utils.run_with_strategy( + self._strategy, layer.__call__, inputs, weights + ) + else: # JAX or other + res = layer(inputs, weights) - self.assertEqual(res.shape, (2, 20)) + self.assertEqual(res.shape, (batch_size, 20)) e = layer.embeddings if input_type == "dense" and input_rank == 1: @@ -116,6 +173,7 @@ def test_call(self, combiner, input_type, input_rank, use_weights): elif combiner == "sqrtn": expected[1] /= math.sqrt(30.0 if use_weights else 4.0) + expected = repeat_input(expected, batch_size) self.assertAllClose(res, expected) @parameterized.named_parameters( @@ -141,50 +199,70 @@ def test_call(self, combiner, input_type, input_rank, use_weights): def test_symbolic_call(self, input_type, input_rank, use_weights): if input_type == "ragged" and keras.backend.backend() != "tensorflow": self.skipTest(f"ragged not supported on {keras.backend.backend()}") - if input_type == "sparse" and keras.backend.backend() not in ( - "jax", - "tensorflow", - ): - self.skipTest(f"sparse not supported on {keras.backend.backend()}") + if input_type == "sparse": + if keras.backend.backend() == "jax": + self.assertTrue( + jax is not None, "JAX not found for JAX backend test." + ) + elif keras.backend.backend() != "tensorflow": + self.skipTest( + f"sparse not supported on {keras.backend.backend()}" + ) - input = keras.layers.Input( - shape=(2,) if input_rank == 2 else (), - sparse=input_type == "sparse", - ragged=input_type == "ragged", - dtype="int32", - ) + with self._strategy.scope(): + layer = EmbedReduce(10, 20, dtype="float32") - if use_weights: - weights = keras.layers.Input( + input_tensor = keras.layers.Input( shape=(2,) if input_rank == 2 else (), sparse=input_type == "sparse", ragged=input_type == "ragged", - dtype="float32", + dtype="int32", ) - output = EmbedReduce(10, 20, dtype="float32")(input, weights) - else: - output = EmbedReduce(10, 20, dtype="float32")(input) - self.assertEqual(output.shape, (None, 20)) - self.assertEqual(output.dtype, "float32") - self.assertFalse(output.sparse) - self.assertFalse(output.ragged) + if use_weights: + weights = keras.layers.Input( + shape=(2,) if input_rank == 2 else (), + sparse=input_type == "sparse", + ragged=input_type == "ragged", + dtype="float32", + ) + output = layer(input_tensor, weights) + else: + output = layer(input_tensor) + + self.assertEqual(output.shape, (None, 20)) + self.assertEqual(output.dtype, "float32") + self.assertFalse(output.sparse) + self.assertFalse(output.ragged) def test_predict(self): - input = keras.random.randint((5, 7), minval=0, maxval=10) - model = keras.models.Sequential([EmbedReduce(10, 20)]) - model.predict(input, batch_size=2) + input_data = keras.random.randint((5, 7), minval=0, maxval=10) + with self._strategy.scope(): + model = keras.models.Sequential([EmbedReduce(10, 20)]) + # Compilation is often needed for strategies to be fully utilized + model.compile(optimizer="adam", loss="mse") + + # model.predict itself handles the strategy distribution + model.predict(input_data, batch_size=2) def test_serialization(self): - layer = EmbedReduce(10, 20, combiner="sqrtn") + with self._strategy.scope(): + layer = EmbedReduce(10, 20, combiner="sqrtn") + restored = deserialize(serialize(layer)) self.assertDictEqual(layer.get_config(), restored.get_config()) def test_model_saving(self): - input = keras.random.randint((5, 7), minval=0, maxval=10) - model = keras.models.Sequential([EmbedReduce(10, 20)]) + input_data = keras.random.randint((5, 7), minval=0, maxval=10) + + with self._strategy.scope(): + model = keras.models.Sequential([EmbedReduce(10, 20)]) self.run_model_saving_test( model=model, - input_data=input, + input_data=input_data, ) + + +if __name__ == "__main__": + absltest.main() diff --git a/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py b/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py index f314f887..57675d0f 100644 --- a/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py +++ b/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py @@ -1,4 +1,5 @@ import keras +import pytest import tensorflow as tf from absl.testing import parameterized @@ -7,6 +8,10 @@ from keras_rs.src.layers.embedding.tensorflow import config_conversion +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="Backend specific test", +) class ConfigConversionTest(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters( ( diff --git a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py index 99c38abc..8c9182c0 100644 --- a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py +++ b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py @@ -1,4 +1,8 @@ +import os + import keras +import tensorflow as tf +from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -8,10 +12,18 @@ from keras_rs.src.layers.feature_interaction.dot_interaction import ( DotInteraction, ) +from keras_rs.src.utils import tpu_test_utils class DotInteractionTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + + self.on_tpu = "TPU_NAME" in os.environ + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.input = [ ops.array([[0.1, -4.3, 0.2, 1.1, 0.3]]), ops.array([[2.0, 3.2, -1.0, 0.0, 1.0]]), @@ -77,25 +89,36 @@ def setUp(self): ), ) def test_call(self, self_interaction, skip_gather, exp_output_idx): - layer = DotInteraction( - self_interaction=self_interaction, skip_gather=skip_gather + with self._strategy.scope(): + layer = DotInteraction( + self_interaction=self_interaction, skip_gather=skip_gather + ) + output = tpu_test_utils.run_with_strategy( + self._strategy, layer, self.input + ) + self.assertAllClose( + output, self.exp_outputs[exp_output_idx], is_tpu=self.on_tpu ) - output = layer(self.input) - self.assertAllClose(output, self.exp_outputs[exp_output_idx]) def test_invalid_input_rank(self): rank_1_input = [ops.ones((3,)), ops.ones((3,))] - layer = DotInteraction() + with self._strategy.scope(): + layer = DotInteraction() with self.assertRaises(ValueError): - layer(rank_1_input) + tpu_test_utils.run_with_strategy( + self._strategy, layer, rank_1_input + ) def test_invalid_input_different_shapes(self): unequal_shape_input = [ops.ones((1, 3)), ops.ones((1, 4))] - layer = DotInteraction() + with self._strategy.scope(): + layer = DotInteraction() with self.assertRaises(ValueError): - layer(unequal_shape_input) + tpu_test_utils.run_with_strategy( + self._strategy, layer, unequal_shape_input + ) @parameterized.named_parameters( ( @@ -120,31 +143,39 @@ def test_invalid_input_different_shapes(self): ), ) def test_predict(self, self_interaction, skip_gather): - feature1 = keras.layers.Input(shape=(5,)) - feature2 = keras.layers.Input(shape=(5,)) - feature3 = keras.layers.Input(shape=(5,)) - x = DotInteraction( - self_interaction=self_interaction, skip_gather=skip_gather - )([feature1, feature2, feature3]) - x = keras.layers.Dense(units=1)(x) - model = keras.Model([feature1, feature2, feature3], x) + with self._strategy.scope(): + feature1 = keras.layers.Input(shape=(5,)) + feature2 = keras.layers.Input(shape=(5,)) + feature3 = keras.layers.Input(shape=(5,)) + x = DotInteraction( + self_interaction=self_interaction, skip_gather=skip_gather + )([feature1, feature2, feature3]) + x = keras.layers.Dense(units=1)(x) + model = keras.Model([feature1, feature2, feature3], x) + model.compile(optimizer="adam", loss="mse") model.predict(self.input, batch_size=2) def test_serialization(self): - layer = DotInteraction() + with self._strategy.scope(): + layer = DotInteraction() restored = deserialize(serialize(layer)) self.assertDictEqual(layer.get_config(), restored.get_config()) def test_model_saving(self): - feature1 = keras.layers.Input(shape=(5,)) - feature2 = keras.layers.Input(shape=(5,)) - feature3 = keras.layers.Input(shape=(5,)) - x = DotInteraction()([feature1, feature2, feature3]) - x = keras.layers.Dense(units=1)(x) - model = keras.Model([feature1, feature2, feature3], x) + with self._strategy.scope(): + feature1 = keras.layers.Input(shape=(5,)) + feature2 = keras.layers.Input(shape=(5,)) + feature3 = keras.layers.Input(shape=(5,)) + x = DotInteraction()([feature1, feature2, feature3]) + x = keras.layers.Dense(units=1)(x) + model = keras.Model([feature1, feature2, feature3], x) self.run_model_saving_test( model=model, input_data=self.input, ) + + +if __name__ == "__main__": + absltest.main() diff --git a/keras_rs/src/layers/feature_interaction/feature_cross_test.py b/keras_rs/src/layers/feature_interaction/feature_cross_test.py index 8724ab53..89df5009 100644 --- a/keras_rs/src/layers/feature_interaction/feature_cross_test.py +++ b/keras_rs/src/layers/feature_interaction/feature_cross_test.py @@ -1,4 +1,6 @@ import keras +import tensorflow as tf +from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -6,10 +8,17 @@ from keras_rs.src import testing from keras_rs.src.layers.feature_interaction.feature_cross import FeatureCross +from keras_rs.src.utils import tpu_test_utils class FeatureCrossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32") self.x = ops.array([[0.4, 0.5, 0.6]], dtype="float32") self.exp_output = ops.array([[0.55, 0.8, 1.05]]) @@ -17,7 +26,8 @@ def setUp(self): self.one_inp_exp_output = ops.array([[0.16, 0.32, 0.48]]) def test_full_layer(self): - layer = FeatureCross(projection_dim=None, kernel_initializer="ones") + with self._strategy.scope(): + layer = FeatureCross(projection_dim=None, kernel_initializer="ones") output = layer(self.x0, self.x) # Test output. @@ -30,7 +40,8 @@ def test_full_layer(self): self.assertEqual(layer.weights[1].shape, (3,)) def test_low_rank_layer(self): - layer = FeatureCross(projection_dim=1, kernel_initializer="ones") + with self._strategy.scope(): + layer = FeatureCross(projection_dim=1, kernel_initializer="ones") output = layer(self.x0, self.x) # Test output. @@ -45,7 +56,8 @@ def test_low_rank_layer(self): self.assertEqual(layer.weights[2].shape, (3,)) def test_one_input(self): - layer = FeatureCross(projection_dim=None, kernel_initializer="ones") + with self._strategy.scope(): + layer = FeatureCross(projection_dim=None, kernel_initializer="ones") output = layer(self.x0) self.assertAllClose(self.one_inp_exp_output, output) @@ -53,7 +65,8 @@ def test_invalid_input_shapes(self): x0 = ops.ones((12, 5)) x = ops.ones((12, 7)) - layer = FeatureCross() + with self._strategy.scope(): + layer = FeatureCross() with self.assertRaises(ValueError): layer(x0, x) @@ -63,41 +76,53 @@ def test_invalid_diag_scale(self): FeatureCross(diag_scale=-1.0) def test_diag_scale(self): - layer = FeatureCross( - projection_dim=None, diag_scale=1.0, kernel_initializer="ones" - ) + with self._strategy.scope(): + layer = FeatureCross( + projection_dim=None, diag_scale=1.0, kernel_initializer="ones" + ) output = layer(self.x0, self.x) self.assertAllClose(ops.array([[0.59, 0.9, 1.23]]), output) def test_pre_activation(self): - layer = FeatureCross(projection_dim=None, pre_activation=ops.zeros_like) + with self._strategy.scope(): + layer = FeatureCross( + projection_dim=None, pre_activation=ops.zeros_like + ) output = layer(self.x0, self.x) self.assertAllClose(self.x, output) def test_predict(self): - x0 = keras.layers.Input(shape=(3,)) - x1 = FeatureCross(projection_dim=None)(x0, x0) - x2 = FeatureCross(projection_dim=None)(x0, x1) - logits = keras.layers.Dense(units=1)(x2) - model = keras.Model(x0, logits) + with self._strategy.scope(): + x0 = keras.layers.Input(shape=(3,)) + x1 = FeatureCross(projection_dim=None)(x0, x0) + x2 = FeatureCross(projection_dim=None)(x0, x1) + logits = keras.layers.Dense(units=1)(x2) + model = keras.Model(x0, logits) + model.compile(optimizer="adam", loss="mse") model.predict(self.x0, batch_size=2) def test_serialization(self): - sampler = FeatureCross(projection_dim=None, pre_activation="swish") + with self._strategy.scope(): + sampler = FeatureCross(projection_dim=None, pre_activation="swish") restored = deserialize(serialize(sampler)) self.assertDictEqual(sampler.get_config(), restored.get_config()) def test_model_saving(self): - x0 = keras.layers.Input(shape=(3,)) - x1 = FeatureCross(projection_dim=None)(x0, x0) - x2 = FeatureCross(projection_dim=None)(x0, x1) - logits = keras.layers.Dense(units=1)(x2) - model = keras.Model(x0, logits) + with self._strategy.scope(): + x0 = keras.layers.Input(shape=(3,)) + x1 = FeatureCross(projection_dim=None)(x0, x0) + x2 = FeatureCross(projection_dim=None)(x0, x1) + logits = keras.layers.Dense(units=1)(x2) + model = keras.Model(x0, logits) self.run_model_saving_test( model=model, input_data=self.x0, ) + + +if __name__ == "__main__": + absltest.main() diff --git a/keras_rs/src/layers/retrieval/brute_force_retrieval_test.py b/keras_rs/src/layers/retrieval/brute_force_retrieval_test.py index a5f8a86f..a31aa417 100644 --- a/keras_rs/src/layers/retrieval/brute_force_retrieval_test.py +++ b/keras_rs/src/layers/retrieval/brute_force_retrieval_test.py @@ -1,11 +1,24 @@ +import os + import keras +import tensorflow as tf +from absl.testing import absltest from absl.testing import parameterized from keras_rs.src import testing from keras_rs.src.layers.retrieval import brute_force_retrieval +from keras_rs.src.utils import tpu_test_utils class BruteForceRetrievalTest(testing.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.on_tpu = "TPU_NAME" in os.environ + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + + self._strategy = tpu_test_utils.get_tpu_strategy(self) + @parameterized.product( has_identifiers=(True, False), return_scores=(True, False), @@ -25,12 +38,13 @@ def test_brute_force_retrieval(self, has_identifiers, return_scores): else None ) - layer = brute_force_retrieval.BruteForceRetrieval( - k=k, - candidate_embeddings=candidates, - candidate_ids=candidate_indices, - return_scores=return_scores, - ) + with self._strategy.scope(): + layer = brute_force_retrieval.BruteForceRetrieval( + k=k, + candidate_embeddings=candidates, + candidate_ids=candidate_indices, + return_scores=return_scores, + ) query = keras.random.normal((num_queries, 4), dtype="float32", seed=rng) scores = keras.ops.matmul(query, keras.ops.transpose(candidates)) @@ -51,14 +65,34 @@ def test_brute_force_retrieval(self, has_identifiers, return_scores): for i in range(2): if i: # First time uses values from __init__, second time uses update. - layer.update_candidates(candidates, candidate_indices) + with self._strategy.scope(): + layer.update_candidates(candidates, candidate_indices) if return_scores: - top_scores, top_indices = layer(query) + top_scores, top_indices = tpu_test_utils.run_with_strategy( + self._strategy, layer, query + ) self.assertEqual(top_scores.shape, expected_top_scores.shape) - self.assertAllClose(top_scores, expected_top_scores, atol=1e-4) + self.assertAllClose( + top_scores, + expected_top_scores, + atol=1e-4, + is_tpu=self.on_tpu, + ) else: - top_indices = layer(query) + top_indices = tpu_test_utils.run_with_strategy( + self._strategy, layer, query + ) self.assertEqual(top_indices.shape, expected_top_indices.shape) - self.assertAllClose(top_indices, expected_top_indices) + self.assertAllClose( + top_indices, + expected_top_indices, + tpu_atol=5, + tpu_rtol=10, + is_tpu=self.on_tpu, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py index d7ab74d0..ddbc0832 100644 --- a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py +++ b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py @@ -1,4 +1,6 @@ import keras +import tensorflow as tf +from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -6,9 +8,17 @@ from keras_rs.src import testing from keras_rs.src.layers.retrieval import hard_negative_mining +from keras_rs.src.utils import tpu_test_utils class HardNegativeMiningTest(testing.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + + self._strategy = tpu_test_utils.get_tpu_strategy(self) + def create_inputs(self, rank=2): shape_3d = (15, 20, 10) shape = shape_3d[-rank:] @@ -59,9 +69,11 @@ def test_call(self, rank, num_hard_negatives): logits, labels = self.create_inputs(rank=rank) num_logits = logits.shape[-1] - out_logits, out_labels = hard_negative_mining.HardNegativeMining( - num_hard_negatives - )(logits, labels) + with self._strategy.scope(): + layer = hard_negative_mining.HardNegativeMining(num_hard_negatives) + out_logits, out_labels = tpu_test_utils.run_with_strategy( + self._strategy, layer, logits, labels + ) self.assertEqual( out_logits.shape[-1], min(num_hard_negatives + 1, num_logits) @@ -76,9 +88,11 @@ def test_call(self, rank, num_hard_negatives): # Set the logits for labels to be highest to ignore effect of labels. logits = logits + labels * 1000.0 - out_logits, _ = hard_negative_mining.HardNegativeMining( - num_hard_negatives - )(logits, labels) + with self._strategy.scope(): + layer = hard_negative_mining.HardNegativeMining(num_hard_negatives) + out_logits, _ = tpu_test_utils.run_with_strategy( + self._strategy, layer, logits, labels + ) # Highest K logits are always returned. self.assertAllClose( @@ -89,28 +103,42 @@ def test_call(self, rank, num_hard_negatives): def test_predict(self): logits, labels = self.create_inputs() - in_logits = keras.layers.Input(shape=logits.shape[1:]) - in_labels = keras.layers.Input(shape=labels.shape[1:]) - out_logits, out_labels = hard_negative_mining.HardNegativeMining( - num_hard_negatives=3 - )(in_logits, in_labels) - model = keras.Model([in_logits, in_labels], [out_logits, out_labels]) + with self._strategy.scope(): + in_logits = keras.layers.Input(shape=logits.shape[1:]) + in_labels = keras.layers.Input(shape=labels.shape[1:]) + out_logits, out_labels = hard_negative_mining.HardNegativeMining( + num_hard_negatives=3 + )(in_logits, in_labels) + model = keras.Model( + [in_logits, in_labels], [out_logits, out_labels] + ) + model.compile(optimizer="adam", loss="mse") model.predict([logits, labels], batch_size=8) def test_serialization(self): - layer = hard_negative_mining.HardNegativeMining(num_hard_negatives=3) + with self._strategy.scope(): + layer = hard_negative_mining.HardNegativeMining( + num_hard_negatives=3 + ) restored = deserialize(serialize(layer)) self.assertDictEqual(layer.get_config(), restored.get_config()) def test_model_saving(self): logits, labels = self.create_inputs() - in_logits = keras.layers.Input(shape=logits.shape[1:]) - in_labels = keras.layers.Input(shape=labels.shape[1:]) - out_logits, out_labels = hard_negative_mining.HardNegativeMining( - num_hard_negatives=3 - )(in_logits, in_labels) - model = keras.Model([in_logits, in_labels], [out_logits, out_labels]) + with self._strategy.scope(): + in_logits = keras.layers.Input(shape=logits.shape[1:]) + in_labels = keras.layers.Input(shape=labels.shape[1:]) + out_logits, out_labels = hard_negative_mining.HardNegativeMining( + num_hard_negatives=3 + )(in_logits, in_labels) + model = keras.Model( + [in_logits, in_labels], [out_logits, out_labels] + ) self.run_model_saving_test(model=model, input_data=[logits, labels]) + + +if __name__ == "__main__": + absltest.main() diff --git a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py index 8cb4fa71..cb1e7357 100644 --- a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py +++ b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py @@ -1,4 +1,6 @@ import keras +import tensorflow as tf +from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -6,9 +8,17 @@ from keras_rs.src import testing from keras_rs.src.layers.retrieval import remove_accidental_hits +from keras_rs.src.utils import tpu_test_utils class RemoveAccidentalHitsTest(testing.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + + self._strategy = tpu_test_utils.get_tpu_strategy(self) + def create_inputs(self, logits_rank=2, candidate_ids_rank=1): shape_3d = (15, 20, 10) shape = shape_3d[-logits_rank:] @@ -68,8 +78,10 @@ def test_call(self, logits_rank, candidate_ids_rank): logits_rank=logits_rank, candidate_ids_rank=candidate_ids_rank ) - out_logits = remove_accidental_hits.RemoveAccidentalHits()( - logits, labels, candidate_ids + with self._strategy.scope(): + layer = remove_accidental_hits.RemoveAccidentalHits() + out_logits = tpu_test_utils.run_with_strategy( + self._strategy, layer, logits, labels, candidate_ids ) # Logits of labels are unchanged. @@ -131,54 +143,78 @@ def test_call(self, logits_rank, candidate_ids_rank): ) def test_mismatched_labels_logits_shapes(self): - layer = remove_accidental_hits.RemoveAccidentalHits() + with self._strategy.scope(): + layer = remove_accidental_hits.RemoveAccidentalHits() with self.assertRaisesRegex( ValueError, "`labels` and `logits` should have the same shape" ): - layer(ops.zeros((10, 20)), ops.zeros((10, 30)), ops.zeros((20,))) + tpu_test_utils.run_with_strategy( + self._strategy, + layer, + ops.zeros((10, 20)), + ops.zeros((10, 30)), + ops.zeros((20,)), + ) def test_mismatched_labels_candidates_shapes(self): - layer = remove_accidental_hits.RemoveAccidentalHits() + with self._strategy.scope(): + layer = remove_accidental_hits.RemoveAccidentalHits() with self.assertRaisesRegex( ValueError, "`candidate_ids` should have the same shape as .* `labels`", ): - layer(ops.zeros((10, 20)), ops.zeros((10, 20)), ops.zeros((30,))) + tpu_test_utils.run_with_strategy( + self._strategy, + layer, + ops.zeros((10, 20)), + ops.zeros((10, 20)), + ops.zeros((30,)), + ) def test_predict(self): # Note: for predict, we test with probabilities that have a batch dim. logits, labels, candidate_ids = self.create_inputs(candidate_ids_rank=2) - layer = remove_accidental_hits.RemoveAccidentalHits() - in_logits = keras.layers.Input(logits.shape[1:]) - in_labels = keras.layers.Input(labels.shape[1:]) - in_candidate_ids = keras.layers.Input(labels.shape[1:]) - out_logits = layer(in_logits, in_labels, in_candidate_ids) - model = keras.Model( - [in_logits, in_labels, in_candidate_ids], out_logits - ) + with self._strategy.scope(): + layer = remove_accidental_hits.RemoveAccidentalHits() + in_logits = keras.layers.Input(logits.shape[1:]) + in_labels = keras.layers.Input(labels.shape[1:]) + in_candidate_ids = keras.layers.Input(labels.shape[1:]) + out_logits = layer(in_logits, in_labels, in_candidate_ids) + model = keras.Model( + [in_logits, in_labels, in_candidate_ids], out_logits + ) + model.compile(optimizer="adam", loss="mse") model.predict([logits, labels, candidate_ids], batch_size=8) def test_serialization(self): - layer = remove_accidental_hits.RemoveAccidentalHits() + with self._strategy.scope(): + layer = remove_accidental_hits.RemoveAccidentalHits() restored = deserialize(serialize(layer)) self.assertDictEqual(layer.get_config(), restored.get_config()) def test_model_saving(self): logits, labels, candidate_ids = self.create_inputs() - layer = remove_accidental_hits.RemoveAccidentalHits() - in_logits = keras.layers.Input(logits.shape[1:]) - in_labels = keras.layers.Input(labels.shape[1:]) - in_candidate_ids = keras.layers.Input(batch_shape=candidate_ids.shape) - out_logits = layer(in_logits, in_labels, in_candidate_ids) - model = keras.Model( - [in_logits, in_labels, in_candidate_ids], out_logits - ) + with self._strategy.scope(): + layer = remove_accidental_hits.RemoveAccidentalHits() + in_logits = keras.layers.Input(logits.shape[1:]) + in_labels = keras.layers.Input(labels.shape[1:]) + in_candidate_ids = keras.layers.Input( + batch_shape=candidate_ids.shape + ) + out_logits = layer(in_logits, in_labels, in_candidate_ids) + model = keras.Model( + [in_logits, in_labels, in_candidate_ids], out_logits + ) self.run_model_saving_test( model=model, input_data=[logits, labels, candidate_ids] ) + + +if __name__ == "__main__": + absltest.main() diff --git a/keras_rs/src/layers/retrieval/retrieval_test.py b/keras_rs/src/layers/retrieval/retrieval_test.py index a5887e28..05ee0305 100644 --- a/keras_rs/src/layers/retrieval/retrieval_test.py +++ b/keras_rs/src/layers/retrieval/retrieval_test.py @@ -1,8 +1,11 @@ import keras +import tensorflow as tf +from absl.testing import absltest from absl.testing import parameterized from keras_rs.src import testing from keras_rs.src.layers.retrieval.retrieval import Retrieval +from keras_rs.src.utils import tpu_test_utils class DummyRetrieval(Retrieval): @@ -15,7 +18,15 @@ def call(self, inputs): class RetrievalTest(testing.TestCase, parameterized.TestCase): def setUp(self): - self.layer = DummyRetrieval(k=5) + super().setUp() + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + + self._strategy = tpu_test_utils.get_tpu_strategy(self) + if self._strategy is None: + self._strategy = tpu_test_utils.DummyStrategy() + with self._strategy.scope(): + self.layer = DummyRetrieval(k=5) @parameterized.named_parameters( ("embeddings_none", None, None, "`candidate_embeddings` is required."), @@ -48,19 +59,27 @@ def test_validate_candidate_embeddings_and_ids( ) def test_call_not_overridden(self): - class DummyRetrieval(Retrieval): - def update_candidates( - self, candidate_embeddings, candidate_ids=None - ): - pass + with self._strategy.scope(): + + class DummyRetrieval(Retrieval): + def update_candidates( + self, candidate_embeddings, candidate_ids=None + ): + pass with self.assertRaises(TypeError): DummyRetrieval(k=5) def test_update_candidates_not_overridden(self): - class DummyRetrieval(Retrieval): - def call(self, inputs): - pass + with self._strategy.scope(): + + class DummyRetrieval(Retrieval): + def call(self, inputs): + pass with self.assertRaises(TypeError): DummyRetrieval(k=5) + + +if __name__ == "__main__": + absltest.main() diff --git a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py index 8dc8ff73..6d88fd2a 100644 --- a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py +++ b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py @@ -1,4 +1,6 @@ import keras +import tensorflow as tf +from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -6,11 +8,19 @@ from keras_rs.src import testing from keras_rs.src.layers.retrieval import sampling_probability_correction +from keras_rs.src.utils import tpu_test_utils class SamplingProbabilityCorrectionTest( testing.TestCase, parameterized.TestCase ): + def setUp(self): + super().setUp() + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + + self._strategy = tpu_test_utils.get_tpu_strategy(self) + def create_inputs(self, logits_rank=2, probs_rank=1): shape_3d = (15, 20, 10) logits_shape = shape_3d[-logits_rank:] @@ -61,8 +71,13 @@ def test_call(self, logits_rank, probs_rank): ) # Verifies logits are always less than corrected logits. - layer = sampling_probability_correction.SamplingProbabilityCorrection() - corrected_logits = layer(logits, probs) + with self._strategy.scope(): + layer = ( + sampling_probability_correction.SamplingProbabilityCorrection() + ) + corrected_logits = tpu_test_utils.run_with_strategy( + self._strategy, layer, logits, probs + ) self.assertAllClose( ops.less(logits, corrected_logits), ops.ones(logits.shape) ) @@ -77,7 +92,9 @@ def test_call(self, logits_rank, probs_rank): ) # Verifies logits are always less than corrected logits. - corrected_logits_with_zeros = layer(logits, probs_with_zeros) + corrected_logits_with_zeros = tpu_test_utils.run_with_strategy( + self._strategy, layer, logits, probs_with_zeros + ) self.assertAllClose( ops.less(logits, corrected_logits_with_zeros), ops.ones(logits.shape), @@ -87,26 +104,40 @@ def test_predict(self): # Note: for predict, we test with probabilities that have a batch dim. logits, probs = self.create_inputs(probs_rank=2) - layer = sampling_probability_correction.SamplingProbabilityCorrection() - in_logits = keras.layers.Input(logits.shape[1:]) - in_probs = keras.layers.Input(probs.shape[1:]) - out_logits = layer(in_logits, in_probs) - model = keras.Model([in_logits, in_probs], out_logits) + with self._strategy.scope(): + layer = ( + sampling_probability_correction.SamplingProbabilityCorrection() + ) + in_logits = keras.layers.Input(logits.shape[1:]) + in_probs = keras.layers.Input(probs.shape[1:]) + out_logits = layer(in_logits, in_probs) + model = keras.Model([in_logits, in_probs], out_logits) + model.compile(optimizer="adam", loss="mse") model.predict([logits, probs], batch_size=4) def test_serialization(self): - layer = sampling_probability_correction.SamplingProbabilityCorrection() + with self._strategy.scope(): + layer = ( + sampling_probability_correction.SamplingProbabilityCorrection() + ) restored = deserialize(serialize(layer)) self.assertDictEqual(layer.get_config(), restored.get_config()) def test_model_saving(self): logits, probs = self.create_inputs() - layer = sampling_probability_correction.SamplingProbabilityCorrection() - in_logits = keras.layers.Input(shape=logits.shape[1:]) - in_probs = keras.layers.Input(batch_shape=probs.shape) - out_logits = layer(in_logits, in_probs) - model = keras.Model([in_logits, in_probs], out_logits) + with self._strategy.scope(): + layer = ( + sampling_probability_correction.SamplingProbabilityCorrection() + ) + in_logits = keras.layers.Input(shape=logits.shape[1:]) + in_probs = keras.layers.Input(batch_shape=probs.shape) + out_logits = layer(in_logits, in_probs) + model = keras.Model([in_logits, in_probs], out_logits) self.run_model_saving_test(model=model, input_data=[logits, probs]) + + +if __name__ == "__main__": + absltest.main() diff --git a/keras_rs/src/testing/test_case.py b/keras_rs/src/testing/test_case.py index a764abf3..a30f52d2 100644 --- a/keras_rs/src/testing/test_case.py +++ b/keras_rs/src/testing/test_case.py @@ -23,6 +23,9 @@ def assertAllClose( desired: types.Tensor, atol: float = 1e-6, rtol: float = 1e-6, + tpu_atol: float = 1e-2, + tpu_rtol: float = 1e-2, + is_tpu: bool = False, msg: str = "", ) -> None: """Verify that two tensors are close in value element by element. @@ -38,6 +41,11 @@ def assertAllClose( actual = keras.ops.convert_to_numpy(actual) if not isinstance(desired, np.ndarray): desired = keras.ops.convert_to_numpy(desired) + if tpu_atol is not None and is_tpu: + atol = tpu_atol + if tpu_rtol is not None and is_tpu: + rtol = tpu_rtol + np.testing.assert_allclose( actual, desired, atol=atol, rtol=rtol, err_msg=msg ) diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py new file mode 100644 index 00000000..a27bdecb --- /dev/null +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -0,0 +1,78 @@ +import contextlib +import os + +import keras +import tensorflow as tf + +try: + import jax +except ImportError: + jax = None + + +class DummyStrategy: + def scope(self): + return contextlib.nullcontext() + + @property + def num_replicas_in_sync(self): + return 1 + + def run(self, fn, args): + return fn(*args) + + def experimental_distribute_dataset(self, dataset, options=None): + del options + return dataset + + +class JaxDummyStrategy(DummyStrategy): + @property + def num_replicas_in_sync(self): + if jax is None: + return 0 + return jax.device_count("tpu") + + +def get_tpu_strategy(test_case): + """Get TPU strategy if on TPU, otherwise return DummyStrategy.""" + if "TPU_NAME" not in os.environ: + return DummyStrategy() + if keras.backend.backend() == "tensorflow": + resolver = tf.distribute.cluster_resolver.TPUClusterResolver() + tf.config.experimental_connect_to_cluster(resolver) + topology = tf.tpu.experimental.initialize_tpu_system(resolver) + tpu_metadata = resolver.get_tpu_system_metadata() + device_assignment = tf.tpu.experimental.DeviceAssignment.build( + topology, num_replicas=tpu_metadata.num_hosts + ) + strategy = tf.distribute.TPUStrategy( + resolver, experimental_device_assignment=device_assignment + ) + print("### num_replicas", strategy.num_replicas_in_sync) + test_case.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver) + return strategy + elif keras.backend.backend() == "jax": + if jax is None: + raise ImportError( + "JAX backend requires jax to be installed for TPU." + ) + return JaxDummyStrategy() + else: + return DummyStrategy() + + +def run_with_strategy(strategy, fn, *args, jit_compile=False): + """Wrapper for running a function under a strategy.""" + if keras.backend.backend() == "tensorflow": + + @tf.function(jit_compile=jit_compile) + def tf_function_wrapper(*tf_function_args): + def strategy_fn(*strategy_fn_args): + return fn(*strategy_fn_args) + + return strategy.run(strategy_fn, args=tf_function_args) + + return tf_function_wrapper(*args) + else: + assert not jit_compile From 9d1b08915245ea4972b39fc677bfacaff2de90e5 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Sun, 23 Nov 2025 20:28:51 +0000 Subject: [PATCH 02/29] fix jax test --- .gitignore | 1 + keras_rs/src/layers/retrieval/retrieval_test.py | 2 -- keras_rs/src/utils/tpu_test_utils.py | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index dfbf3335..df148986 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ build/ venv/ venv_tf/ +venv_jax/ \ No newline at end of file diff --git a/keras_rs/src/layers/retrieval/retrieval_test.py b/keras_rs/src/layers/retrieval/retrieval_test.py index 05ee0305..bfd1b88b 100644 --- a/keras_rs/src/layers/retrieval/retrieval_test.py +++ b/keras_rs/src/layers/retrieval/retrieval_test.py @@ -23,8 +23,6 @@ def setUp(self): tf.debugging.disable_traceback_filtering() self._strategy = tpu_test_utils.get_tpu_strategy(self) - if self._strategy is None: - self._strategy = tpu_test_utils.DummyStrategy() with self._strategy.scope(): self.layer = DummyRetrieval(k=5) diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index a27bdecb..ceadb54b 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -76,3 +76,4 @@ def strategy_fn(*strategy_fn_args): return tf_function_wrapper(*args) else: assert not jit_compile + return fn(*args) From 3c348df354fc5b9242cefcb45c7d3a480c91ed78 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Sun, 23 Nov 2025 20:45:47 +0000 Subject: [PATCH 03/29] add jax strategy print --- keras_rs/src/utils/tpu_test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index ceadb54b..8e7af81a 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -57,6 +57,7 @@ def get_tpu_strategy(test_case): raise ImportError( "JAX backend requires jax to be installed for TPU." ) + print("### num_replicas", jax.device_count("tpu")) return JaxDummyStrategy() else: return DummyStrategy() From 86c96dcbca1423aba7996f28c4233d9f1110f4c4 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Mon, 24 Nov 2025 08:59:35 +0000 Subject: [PATCH 04/29] support tpu for losses tests and some metrics tests --- keras_rs/src/losses/list_mle_loss_test.py | 35 +++- .../src/losses/pairwise_hinge_loss_test.py | 35 +++- .../src/losses/pairwise_logistic_loss_test.py | 34 +++- .../pairwise_mean_squared_error_test.py | 35 +++- .../pairwise_soft_zero_one_loss_test.py | 35 +++- keras_rs/src/metrics/dcg_test.py | 160 ++++++++++++------ .../metrics/mean_average_precision_test.py | 149 ++++++++++------ .../src/metrics/mean_reciprocal_rank_test.py | 123 +++++++++----- keras_rs/src/metrics/ndcg_test.py | 31 ++-- keras_rs/src/utils/tpu_test_utils.py | 11 +- 10 files changed, 438 insertions(+), 210 deletions(-) diff --git a/keras_rs/src/losses/list_mle_loss_test.py b/keras_rs/src/losses/list_mle_loss_test.py index 3656354b..57212a5c 100644 --- a/keras_rs/src/losses/list_mle_loss_test.py +++ b/keras_rs/src/losses/list_mle_loss_test.py @@ -6,10 +6,16 @@ from keras_rs.src import testing from keras_rs.src.losses.list_mle_loss import ListMLELoss +import tensorflow as tf +from keras_rs.src.utils import tpu_test_utils class ListMLELossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.unbatched_scores = ops.array( [1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32" ) @@ -83,15 +89,26 @@ def test_scalar_sample_weight(self): ) def test_model_fit(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile(loss=ListMLELoss(), optimizer="adam") - model.fit( - x=keras.random.normal((2, 20)), - y=keras.random.randint((2, 5), minval=0, maxval=2), - ) + def create_model(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=ListMLELoss(), optimizer="adam") + return model + + if self._strategy: + with self._strategy.scope(): + model = create_model() + else: + model = create_model() + + x_data = keras.random.normal((2, 20)) + y_data = keras.random.randint((2, 5), minval=0, maxval=2) + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(self._strategy.num_replicas_in_sync if self._strategy else 1) + if self._strategy: + dataset = self._strategy.experimental_distribute_dataset(dataset) + + model.fit(dataset, epochs=1, steps_per_epoch=2) def test_serialization(self): loss = ListMLELoss(temperature=0.8) diff --git a/keras_rs/src/losses/pairwise_hinge_loss_test.py b/keras_rs/src/losses/pairwise_hinge_loss_test.py index f5aedb20..991ff310 100644 --- a/keras_rs/src/losses/pairwise_hinge_loss_test.py +++ b/keras_rs/src/losses/pairwise_hinge_loss_test.py @@ -6,10 +6,16 @@ from keras_rs.src import testing from keras_rs.src.losses.pairwise_hinge_loss import PairwiseHingeLoss +from keras_rs.src.utils import tpu_test_utils +import tensorflow as tf class PairwiseHingeLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -110,15 +116,26 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile(loss=PairwiseHingeLoss(), optimizer="adam") - model.fit( - x=keras.random.normal((2, 20)), - y=keras.random.randint((2, 5), minval=0, maxval=2), - ) + def create_model(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseHingeLoss(), optimizer="adam") + return model + + if self._strategy: + with self._strategy.scope(): + model = create_model() + else: + model = create_model() + + x_data = keras.random.normal((2, 20)) + y_data = keras.random.randint((2, 5), minval=0, maxval=2) + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(self._strategy.num_replicas_in_sync if self._strategy else 1) + if self._strategy: + dataset = self._strategy.experimental_distribute_dataset(dataset) + + model.fit(dataset, epochs=1, steps_per_epoch=2) def test_serialization(self): loss = PairwiseHingeLoss() diff --git a/keras_rs/src/losses/pairwise_logistic_loss_test.py b/keras_rs/src/losses/pairwise_logistic_loss_test.py index ffba4b05..3e853fcb 100644 --- a/keras_rs/src/losses/pairwise_logistic_loss_test.py +++ b/keras_rs/src/losses/pairwise_logistic_loss_test.py @@ -6,10 +6,15 @@ from keras_rs.src import testing from keras_rs.src.losses.pairwise_logistic_loss import PairwiseLogisticLoss +from keras_rs.src.utils import tpu_test_utils +import tensorflow as tf class PairwiseLogisticLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self._strategy = tpu_test_utils.get_tpu_strategy(self) self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -110,15 +115,26 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile(loss=PairwiseLogisticLoss(), optimizer="adam") - model.fit( - x=keras.random.normal((2, 20)), - y=keras.random.randint((2, 5), minval=0, maxval=2), - ) + def create_model(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseLogisticLoss(), optimizer="adam") + return model + + if self._strategy: + with self._strategy.scope(): + model = create_model() + else: + model = create_model() + + x_data = keras.random.normal((2, 20)) + y_data = keras.random.randint((2, 5), minval=0, maxval=2) + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(self._strategy.num_replicas_in_sync if self._strategy else 1) + if self._strategy: + dataset = self._strategy.experimental_distribute_dataset(dataset) + + model.fit(dataset, epochs=1, steps_per_epoch=2) def test_serialization(self): loss = PairwiseLogisticLoss() diff --git a/keras_rs/src/losses/pairwise_mean_squared_error_test.py b/keras_rs/src/losses/pairwise_mean_squared_error_test.py index 4b93eff9..2d299b76 100644 --- a/keras_rs/src/losses/pairwise_mean_squared_error_test.py +++ b/keras_rs/src/losses/pairwise_mean_squared_error_test.py @@ -8,10 +8,16 @@ from keras_rs.src.losses.pairwise_mean_squared_error import ( PairwiseMeanSquaredError, ) +import tensorflow as tf +from keras_rs.src.utils import tpu_test_utils class PairwiseMeanSquaredErrorTest(testing.TestCase, parameterized.TestCase): def setUp(self): + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -109,15 +115,26 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile(loss=PairwiseMeanSquaredError(), optimizer="adam") - model.fit( - x=keras.random.normal((2, 20)), - y=keras.random.randint((2, 5), minval=0, maxval=2), - ) + def create_model(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseMeanSquaredError(), optimizer="adam") + return model + + if self._strategy: + with self._strategy.scope(): + model = create_model() + else: + model = create_model() + + x_data = keras.random.normal((2, 20)) + y_data = keras.random.randint((2, 5), minval=0, maxval=2) + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(self._strategy.num_replicas_in_sync if self._strategy else 1) + if self._strategy: + dataset = self._strategy.experimental_distribute_dataset(dataset) + + model.fit(dataset, epochs=1, steps_per_epoch=2) def test_serialization(self): loss = PairwiseMeanSquaredError() diff --git a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py index 66e7d634..b6fb7184 100644 --- a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py +++ b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py @@ -8,10 +8,16 @@ from keras_rs.src.losses.pairwise_soft_zero_one_loss import ( PairwiseSoftZeroOneLoss, ) +import tensorflow as tf +from keras_rs.src.utils import tpu_test_utils class PairwiseSoftZeroOneLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -112,15 +118,26 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile(loss=PairwiseSoftZeroOneLoss(), optimizer="adam") - model.fit( - x=keras.random.normal((2, 20)), - y=keras.random.randint((2, 5), minval=0, maxval=2), - ) + def create_model(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseSoftZeroOneLoss(), optimizer="adam") + return model + + if self._strategy: + with self._strategy.scope(): + model = create_model() + else: + model = create_model() + + x_data = keras.random.normal((2, 20)) + y_data = keras.random.randint((2, 5), minval=0, maxval=2) + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(self._strategy.num_replicas_in_sync if self._strategy else 1) + if self._strategy: + dataset = self._strategy.experimental_distribute_dataset(dataset) + + model.fit(dataset, epochs=1, steps_per_epoch=2) def test_serialization(self): loss = PairwiseSoftZeroOneLoss() diff --git a/keras_rs/src/metrics/dcg_test.py b/keras_rs/src/metrics/dcg_test.py index 430214ac..cee15915 100644 --- a/keras_rs/src/metrics/dcg_test.py +++ b/keras_rs/src/metrics/dcg_test.py @@ -8,6 +8,8 @@ from keras_rs.src import testing from keras_rs.src.metrics.dcg import DCG +from keras_rs.src.utils import tpu_test_utils +import tensorflow as tf def _compute_dcg(labels, ranks): @@ -19,6 +21,10 @@ def _compute_dcg(labels, ranks): class DCGTest(testing.TestCase, parameterized.TestCase): def setUp(self): + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -50,18 +56,19 @@ def setUp(self): ) def test_invalid_k_init(self): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - DCG(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - DCG(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - DCG(k=3.5) + with self._strategy.scope(): + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + DCG(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + DCG(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + DCG(k=3.5) @parameterized.named_parameters( ( @@ -131,14 +138,25 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - dcg_metric = DCG() - dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + dcg_metric = DCG() + y_true_t = ops.array(y_true, dtype="float32") + y_pred_t = ops.array(y_pred, dtype="float32") + sw = ops.array(sample_weight, dtype="float32") if sample_weight is not None else None + args = (y_true_t, y_pred_t, sw) if sw is not None else (y_true_t, y_pred_t) + tpu_test_utils.run_with_strategy(self._strategy, dcg_metric.update_state, *args) result = dcg_metric.result() self.assertAllClose(result, expected_output) def test_batched_input(self): - dcg_metric = DCG() - dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) + with self._strategy.scope(): + dcg_metric = DCG() + tpu_test_utils.run_with_strategy( + self._strategy, + dcg_metric.update_state, + self.y_true_batched, + self.y_pred_batched + ) result = dcg_metric.result() self.assertAllClose(result, self.expected_output_batched) @@ -148,11 +166,15 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 2.7288804), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - dcg_metric = DCG() - dcg_metric.update_state( + with self._strategy.scope(): + dcg_metric = DCG() + sw = ops.array(sample_weight, dtype="float32") + tpu_test_utils.run_with_strategy( + self._strategy, + dcg_metric.update_state, self.y_true_batched, self.y_pred_batched, - sample_weight=sample_weight, + sample_weight=sw, ) result = dcg_metric.result() self.assertAllClose(result, expected_output) @@ -215,9 +237,14 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - dcg_metric = DCG() - - dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + dcg_metric = DCG() + tpu_test_utils.run_with_strategy( + self._strategy, + dcg_metric.update_state, + y_true, + y_pred, + sample_weight=sample_weight) result = dcg_metric.result() self.assertAllClose(result, expected_output) @@ -278,9 +305,14 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - dcg_metric = DCG() - - dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + dcg_metric = DCG() + tpu_test_utils.run_with_strategy( + self._strategy, + dcg_metric.update_state, + y_true, + y_pred, + sample_weight=sample_weight) result = dcg_metric.result() self.assertAllClose(result, expected_output) @@ -291,27 +323,41 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): ("4", 4, 3.39040), ) def test_k(self, k, exp_value): - dcg_metric = DCG(k=k) + with self._strategy.scope(): + dcg_metric = DCG(k=k) + tpu_test_utils.run_with_strategy( + self._strategy, + dcg_metric.update_state, + self.y_true_batched, + self.y_pred_batched + ) dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) result = dcg_metric.result() self.assertAllClose(result, exp_value, rtol=1e-5) def test_statefulness(self): - dcg_metric = DCG() + with self._strategy.scope(): + dcg_metric = DCG() # Batch 1 - dcg_metric.update_state( - self.y_true_batched[:2], self.y_pred_batched[:2] + tpu_test_utils.run_with_strategy( + self._strategy, + dcg_metric.update_state, + self.y_true_batched[:2], + self.y_pred_batched[:2] ) result = dcg_metric.result() self.assertAllClose( result, - sum([_compute_dcg([1], [1]), _compute_dcg([3, 2, 1], [1, 3, 4])]) - / 2, + sum([_compute_dcg([1], [1]), + _compute_dcg([3, 2, 1], [1, 3, 4])]) / 2, ) # Batch 2 - dcg_metric.update_state( - self.y_true_batched[2:], self.y_pred_batched[2:] + tpu_test_utils.run_with_strategy( + self._strategy, + dcg_metric.update_state, + self.y_true_batched[2:], + self.y_pred_batched[2:] ) result = dcg_metric.result() self.assertAllClose(result, self.expected_output_batched) @@ -322,7 +368,8 @@ def test_statefulness(self): self.assertAllClose(result, 0.0) def test_serialization(self): - metric = DCG() + with self._strategy.scope(): + metric = DCG() restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) @@ -333,10 +380,15 @@ def linear_gain_fn(label): def inverse_discount_fn(rank): return ops.divide(1.0, rank) - dcg_metric = DCG( - gain_fn=linear_gain_fn, rank_discount_fn=inverse_discount_fn - ) - dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) + with self._strategy.scope(): + dcg_metric = DCG( + gain_fn=linear_gain_fn, rank_discount_fn=inverse_discount_fn + ) + tpu_test_utils.run_with_strategy( + self._strategy, + dcg_metric.update_state, + self.y_true_batched, + self.y_pred_batched) result = dcg_metric.result() expected_output = ( @@ -345,16 +397,24 @@ def inverse_discount_fn(rank): self.assertAllClose(result, expected_output, rtol=1e-5) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[DCG()], - optimizer="adam", - ) - model.evaluate( - x=keras.random.normal((2, 20)), - y=keras.random.randint((2, 5), minval=0, maxval=4), - ) + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[DCG()], + optimizer="adam", + ) + + x_data = keras.random.normal((2, 20)) + y_data = keras.random.randint((2, 5), minval=0, maxval=4) + + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) + dataset = dataset.batch(self._strategy.num_replicas_in_sync if isinstance(self._strategy, tf.distribute.Strategy) else 1) + + if isinstance(self._strategy, tf.distribute.TPUStrategy): + dataset = self._strategy.experimental_distribute_dataset(dataset) + + model.evaluate(dataset, steps=2) diff --git a/keras_rs/src/metrics/mean_average_precision_test.py b/keras_rs/src/metrics/mean_average_precision_test.py index 9c16d25e..c0695c79 100644 --- a/keras_rs/src/metrics/mean_average_precision_test.py +++ b/keras_rs/src/metrics/mean_average_precision_test.py @@ -6,10 +6,16 @@ from keras_rs.src import testing from keras_rs.src.metrics.mean_average_precision import MeanAveragePrecision +from keras_rs.src.utils import tpu_test_utils +import tensorflow as tf class MeanAveragePrecisionTest(testing.TestCase, parameterized.TestCase): def setUp(self): + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -30,18 +36,19 @@ def setUp(self): ) def test_invalid_k_init(self): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanAveragePrecision(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanAveragePrecision(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanAveragePrecision(k=3.5) + with self._strategy.scope(): + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanAveragePrecision(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanAveragePrecision(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanAveragePrecision(k=3.5) @parameterized.named_parameters( ( @@ -104,14 +111,25 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - map_metric = MeanAveragePrecision() - map_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + map_metric = MeanAveragePrecision() + y_true_t = ops.array(y_true, dtype="float32") + y_pred_t = ops.array(y_pred, dtype="float32") + sw = ops.array(sample_weight, dtype="float32") if sample_weight is not None else None + args = (y_true_t, y_pred_t, sw) if sw is not None else (y_true_t, y_pred_t) + tpu_test_utils.run_with_strategy(self._strategy, map_metric.update_state, *args) result = map_metric.result() self.assertAllClose(result, expected_output) def test_batched_input(self): - map_metric = MeanAveragePrecision() - map_metric.update_state(self.y_true_batched, self.y_pred_batched) + with self._strategy.scope(): + map_metric = MeanAveragePrecision() + tpu_test_utils.run_with_strategy( + self._strategy, + map_metric.update_state, + self.y_true_batched, + self.y_pred_batched + ) result = map_metric.result() self.assertAllClose(result, 0.5625) @@ -121,11 +139,15 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 0.6), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - map_metric = MeanAveragePrecision() - map_metric.update_state( + with self._strategy.scope(): + map_metric = MeanAveragePrecision() + sw = ops.array(sample_weight, dtype="float32") + tpu_test_utils.run_with_strategy( + self._strategy, + map_metric.update_state, self.y_true_batched, self.y_pred_batched, - sample_weight=sample_weight, + sample_weight=sw, ) result = map_metric.result() self.assertAllClose(result, expected_output) @@ -163,9 +185,14 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - map_metric = MeanAveragePrecision() - - map_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + map_metric = MeanAveragePrecision() + tpu_test_utils.run_with_strategy( + self._strategy, + map_metric.update_state, + y_true, + y_pred, + sample_weight=sample_weight) result = map_metric.result() self.assertAllClose(result, expected_output) @@ -206,9 +233,14 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - map_metric = MeanAveragePrecision() - - map_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + map_metric = MeanAveragePrecision() + tpu_test_utils.run_with_strategy( + self._strategy, + map_metric.update_state, + y_true, + y_pred, + sample_weight=sample_weight) result = map_metric.result() self.assertAllClose(result, expected_output) @@ -225,17 +257,24 @@ def test_k(self, k, expected_map): self.assertAllClose(result, expected_map) def test_statefulness(self): - map_metric = MeanAveragePrecision() - # Batch 1: First two lists - map_metric.update_state( - self.y_true_batched[:2], self.y_pred_batched[:2] + with self._strategy.scope(): + map_metric = MeanAveragePrecision() + # Batch 1: First two lists + tpu_test_utils.run_with_strategy( + self._strategy, + map_metric.update_state, + self.y_true_batched[:2], + self.y_pred_batched[:2] ) result = map_metric.result() self.assertAllClose(result, 0.75) # Batch 2: Last two lists - map_metric.update_state( - self.y_true_batched[2:], self.y_pred_batched[2:] + tpu_test_utils.run_with_strategy( + self._strategy, + map_metric.update_state, + self.y_true_batched[2:], + self.y_pred_batched[2:] ) result = map_metric.result() self.assertAllClose(result, 0.5625) @@ -250,8 +289,11 @@ def test_statefulness(self): ("weight_0", 0.0, 0.0), ) def test_scalar_sample_weight(self, sample_weight, expected_output): - map_metric = MeanAveragePrecision() - map_metric.update_state( + with self._strategy.scope(): + map_metric = MeanAveragePrecision() + tpu_test_utils.run_with_strategy( + self._strategy, + map_metric.update_state, self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, @@ -260,9 +302,12 @@ def test_scalar_sample_weight(self, sample_weight, expected_output): self.assertAllClose(result, expected_output) def test_1d_sample_weight(self): - map_metric = MeanAveragePrecision() + with self._strategy.scope(): + map_metric = MeanAveragePrecision() sample_weight = ops.array([1.0, 0.5, 2.0, 1.0], dtype="float32") - map_metric.update_state( + tpu_test_utils.run_with_strategy( + self._strategy, + map_metric.update_state, self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, @@ -271,21 +316,29 @@ def test_1d_sample_weight(self): self.assertAllClose(result, 0.6) def test_serialization(self): - metric = MeanAveragePrecision() + with self._strategy.scope(): + metric = MeanAveragePrecision() restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[MeanAveragePrecision()], - optimizer="adam", - ) - model.evaluate( - x=keras.random.normal((2, 20)), - y=keras.random.randint((2, 5), minval=0, maxval=4), - ) + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[MeanAveragePrecision()], + optimizer="adam", + ) + x_data = keras.random.normal((2, 20)) + y_data = keras.random.randint((2, 5), minval=0, maxval=4) + + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) + dataset = dataset.batch(self._strategy.num_replicas_in_sync if isinstance(self._strategy, tf.distribute.Strategy) else 1) + + if isinstance(self._strategy, tf.distribute.TPUStrategy): + dataset = self._strategy.experimental_distribute_dataset(dataset) + + model.evaluate(dataset, steps=2) diff --git a/keras_rs/src/metrics/mean_reciprocal_rank_test.py b/keras_rs/src/metrics/mean_reciprocal_rank_test.py index 02940c36..06ea5853 100644 --- a/keras_rs/src/metrics/mean_reciprocal_rank_test.py +++ b/keras_rs/src/metrics/mean_reciprocal_rank_test.py @@ -6,10 +6,16 @@ from keras_rs.src import testing from keras_rs.src.metrics.mean_reciprocal_rank import MeanReciprocalRank +from keras_rs.src.utils import tpu_test_utils +import tensorflow as tf class MeanReciprocalRankTest(testing.TestCase, parameterized.TestCase): def setUp(self): + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -30,18 +36,19 @@ def setUp(self): ) def test_invalid_k_init(self): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanReciprocalRank(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanReciprocalRank(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanReciprocalRank(k=3.5) # type: ignore + with self._strategy.scope(): + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanReciprocalRank(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanReciprocalRank(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanReciprocalRank(k=3.5) # type: ignore @parameterized.named_parameters( ( @@ -104,14 +111,22 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - mrr_metric = MeanReciprocalRank() - mrr_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + mrr_metric = MeanReciprocalRank() + tpu_test_utils.run_with_strategy( + self._strategy, + mrr_metric.update_state, + y_true_t, + y_pred_t, + sample_weight=sample_weight + ) result = mrr_metric.result() self.assertAllClose(result, expected_output) def test_batched_input(self): - mrr_metric = MeanReciprocalRank() - mrr_metric.update_state(self.y_true_batched, self.y_pred_batched) + with self._strategy.scope(): + mrr_metric = MeanReciprocalRank() + tpu_test_utils.run_with_strategy(mrr_metric.update_state, self.y_true_batched, self.y_pred_batched) result = mrr_metric.result() self.assertAllClose(result, 0.625) @@ -121,12 +136,14 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 0.675), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - mrr_metric = MeanReciprocalRank() - mrr_metric.update_state( + with self._strategy.scope(): + mrr_metric = MeanReciprocalRank() + tpu_test_utils.run_with_strategy( + mrr_metric.update_state, self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, - ) + ) result = mrr_metric.result() self.assertAllClose(result, expected_output) @@ -163,9 +180,14 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - mrr_metric = MeanReciprocalRank() - - mrr_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + mrr_metric = MeanReciprocalRank() + tpu_test_utils.run_with_strategy( + mrr_metric.update_state, + y_true, + y_pred, + sample_weight=sample_weight + ) result = mrr_metric.result() self.assertAllClose(result, expected_output) @@ -206,9 +228,9 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - mrr_metric = MeanReciprocalRank() - - mrr_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + mrr_metric = MeanReciprocalRank() + tpu_test_utils.run_with_strategy(mrr_metric.update_state, y_true, y_pred, sample_weight=sample_weight) result = mrr_metric.result() self.assertAllClose(result, expected_output) @@ -216,24 +238,22 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): ("1", 1, 0.5), ("2", 2, 0.625), ("3", 3, 0.625), ("4", 4, 0.625) ) def test_k(self, k, expected_mrr): - mrr_metric = MeanReciprocalRank(k=k) - mrr_metric.update_state(self.y_true_batched, self.y_pred_batched) + with self._strategy.scope(): + mrr_metric = MeanReciprocalRank(k=k) + tpu_test_utils.run_with_strategy(mrr_metric.update_state, self.y_true_batched, self.y_pred_batched) result = mrr_metric.result() self.assertAllClose(result, expected_mrr) def test_statefulness(self): - mrr_metric = MeanReciprocalRank() + with self._strategy.scope(): + mrr_metric = MeanReciprocalRank() # Batch 1: First two lists - mrr_metric.update_state( - self.y_true_batched[:2], self.y_pred_batched[:2] - ) + tpu_test_utils.run_with_strategy(mrr_metric.update_state, self.y_true_batched[:2], self.y_pred_batched[:2]) result = mrr_metric.result() self.assertAllClose(result, 0.75) # Batch 2: Last two lists - mrr_metric.update_state( - self.y_true_batched[2:], self.y_pred_batched[2:] - ) + tpu_test_utils.run_with_strategy(mrr_metric.update_state, self.y_true_batched[2:], self.y_pred_batched[2:]) result = mrr_metric.result() self.assertAllClose(result, 0.625) @@ -243,21 +263,30 @@ def test_statefulness(self): self.assertAllClose(result, 0.0) def test_serialization(self): - metric = MeanReciprocalRank() + with self._strategy.scope(): + metric = MeanReciprocalRank() restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[MeanReciprocalRank()], - optimizer="adam", - ) - model.evaluate( - x=keras.random.normal((2, 20)), - y=keras.random.randint((2, 5), minval=0, maxval=4), - ) + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[MeanReciprocalRank()], + optimizer="adam", + ) + + x_data = keras.random.normal((2, 20)) + y_data = keras.random.randint((2, 5), minval=0, maxval=4) + + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) + dataset = dataset.batch(self._strategy.num_replicas_in_sync if isinstance(self._strategy, tf.distribute.Strategy) else 1) + + if isinstance(self._strategy, tf.distribute.TPUStrategy): + dataset = self._strategy.experimental_distribute_dataset(dataset) + + model.evaluate(dataset, steps=2) diff --git a/keras_rs/src/metrics/ndcg_test.py b/keras_rs/src/metrics/ndcg_test.py index 8c86e01c..ce29aa16 100644 --- a/keras_rs/src/metrics/ndcg_test.py +++ b/keras_rs/src/metrics/ndcg_test.py @@ -8,6 +8,8 @@ from keras_rs.src import testing from keras_rs.src.metrics.ndcg import NDCG +from keras_rs.src.utils import tpu_test_utils +import tensorflow as tf def _compute_dcg(labels, ranks): @@ -19,6 +21,10 @@ def _compute_dcg(labels, ranks): class NDCGTest(testing.TestCase, parameterized.TestCase): def setUp(self): + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -57,18 +63,19 @@ def setUp(self): self.expected_output_batched = sum(expected_ndcg) / 4.0 def test_invalid_k_init(self): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - NDCG(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - NDCG(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - NDCG(k=3.5) + with self._strategy.scope(): + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + NDCG(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + NDCG(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + NDCG(k=3.5) @parameterized.named_parameters( ( diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index 8e7af81a..3dd9493d 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -63,18 +63,13 @@ def get_tpu_strategy(test_case): return DummyStrategy() -def run_with_strategy(strategy, fn, *args, jit_compile=False): +def run_with_strategy(strategy, fn, *args, jit_compile=False, **kwargs): """Wrapper for running a function under a strategy.""" if keras.backend.backend() == "tensorflow": - @tf.function(jit_compile=jit_compile) def tf_function_wrapper(*tf_function_args): - def strategy_fn(*strategy_fn_args): - return fn(*strategy_fn_args) - - return strategy.run(strategy_fn, args=tf_function_args) - + return strategy.run(fn, args=tf_function_args, kwargs=kwargs) return tf_function_wrapper(*args) else: assert not jit_compile - return fn(*args) + return fn(*args, **kwargs) From 13fa93de1abb9b45482588de9c3342cf85ef5e52 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Mon, 24 Nov 2025 09:50:45 +0000 Subject: [PATCH 05/29] support all tpu tests in metrics and format --- keras_rs/src/losses/list_mle_loss_test.py | 10 +- .../src/losses/pairwise_hinge_loss_test.py | 8 +- .../src/losses/pairwise_logistic_loss_test.py | 10 +- .../pairwise_mean_squared_error_test.py | 12 +- .../pairwise_soft_zero_one_loss_test.py | 10 +- keras_rs/src/metrics/dcg_test.py | 43 +++-- .../metrics/mean_average_precision_test.py | 36 +++-- .../src/metrics/mean_reciprocal_rank_test.py | 58 +++++-- keras_rs/src/metrics/ndcg_test.py | 152 ++++++++++++------ keras_rs/src/metrics/precision_at_k_test.py | 150 ++++++++++++----- keras_rs/src/metrics/recall_at_k_test.py | 150 ++++++++++++----- keras_rs/src/utils/tpu_test_utils.py | 2 + 12 files changed, 451 insertions(+), 190 deletions(-) diff --git a/keras_rs/src/losses/list_mle_loss_test.py b/keras_rs/src/losses/list_mle_loss_test.py index 57212a5c..bdf4ba92 100644 --- a/keras_rs/src/losses/list_mle_loss_test.py +++ b/keras_rs/src/losses/list_mle_loss_test.py @@ -1,4 +1,5 @@ import keras +import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.losses import deserialize @@ -6,7 +7,6 @@ from keras_rs.src import testing from keras_rs.src.losses.list_mle_loss import ListMLELoss -import tensorflow as tf from keras_rs.src.utils import tpu_test_utils @@ -101,12 +101,14 @@ def create_model(): model = create_model() else: model = create_model() - + x_data = keras.random.normal((2, 20)) y_data = keras.random.randint((2, 5), minval=0, maxval=2) - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(self._strategy.num_replicas_in_sync if self._strategy else 1) + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch( + self._strategy.num_replicas_in_sync if self._strategy else 1 + ) if self._strategy: - dataset = self._strategy.experimental_distribute_dataset(dataset) + dataset = self._strategy.experimental_distribute_dataset(dataset) model.fit(dataset, epochs=1, steps_per_epoch=2) diff --git a/keras_rs/src/losses/pairwise_hinge_loss_test.py b/keras_rs/src/losses/pairwise_hinge_loss_test.py index 991ff310..a7c6e167 100644 --- a/keras_rs/src/losses/pairwise_hinge_loss_test.py +++ b/keras_rs/src/losses/pairwise_hinge_loss_test.py @@ -1,4 +1,5 @@ import keras +import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.losses import deserialize @@ -7,7 +8,6 @@ from keras_rs.src import testing from keras_rs.src.losses.pairwise_hinge_loss import PairwiseHingeLoss from keras_rs.src.utils import tpu_test_utils -import tensorflow as tf class PairwiseHingeLossTest(testing.TestCase, parameterized.TestCase): @@ -131,9 +131,11 @@ def create_model(): x_data = keras.random.normal((2, 20)) y_data = keras.random.randint((2, 5), minval=0, maxval=2) - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(self._strategy.num_replicas_in_sync if self._strategy else 1) + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch( + self._strategy.num_replicas_in_sync if self._strategy else 1 + ) if self._strategy: - dataset = self._strategy.experimental_distribute_dataset(dataset) + dataset = self._strategy.experimental_distribute_dataset(dataset) model.fit(dataset, epochs=1, steps_per_epoch=2) diff --git a/keras_rs/src/losses/pairwise_logistic_loss_test.py b/keras_rs/src/losses/pairwise_logistic_loss_test.py index 3e853fcb..07c743d8 100644 --- a/keras_rs/src/losses/pairwise_logistic_loss_test.py +++ b/keras_rs/src/losses/pairwise_logistic_loss_test.py @@ -1,4 +1,5 @@ import keras +import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.losses import deserialize @@ -7,7 +8,6 @@ from keras_rs.src import testing from keras_rs.src.losses.pairwise_logistic_loss import PairwiseLogisticLoss from keras_rs.src.utils import tpu_test_utils -import tensorflow as tf class PairwiseLogisticLossTest(testing.TestCase, parameterized.TestCase): @@ -127,12 +127,14 @@ def create_model(): model = create_model() else: model = create_model() - + x_data = keras.random.normal((2, 20)) y_data = keras.random.randint((2, 5), minval=0, maxval=2) - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(self._strategy.num_replicas_in_sync if self._strategy else 1) + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch( + self._strategy.num_replicas_in_sync if self._strategy else 1 + ) if self._strategy: - dataset = self._strategy.experimental_distribute_dataset(dataset) + dataset = self._strategy.experimental_distribute_dataset(dataset) model.fit(dataset, epochs=1, steps_per_epoch=2) diff --git a/keras_rs/src/losses/pairwise_mean_squared_error_test.py b/keras_rs/src/losses/pairwise_mean_squared_error_test.py index 2d299b76..bdbb3da3 100644 --- a/keras_rs/src/losses/pairwise_mean_squared_error_test.py +++ b/keras_rs/src/losses/pairwise_mean_squared_error_test.py @@ -1,4 +1,5 @@ import keras +import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.losses import deserialize @@ -8,7 +9,6 @@ from keras_rs.src.losses.pairwise_mean_squared_error import ( PairwiseMeanSquaredError, ) -import tensorflow as tf from keras_rs.src.utils import tpu_test_utils @@ -121,18 +121,20 @@ def create_model(): model = keras.Model(inputs=inputs, outputs=outputs) model.compile(loss=PairwiseMeanSquaredError(), optimizer="adam") return model - + if self._strategy: with self._strategy.scope(): model = create_model() else: model = create_model() - + x_data = keras.random.normal((2, 20)) y_data = keras.random.randint((2, 5), minval=0, maxval=2) - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(self._strategy.num_replicas_in_sync if self._strategy else 1) + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch( + self._strategy.num_replicas_in_sync if self._strategy else 1 + ) if self._strategy: - dataset = self._strategy.experimental_distribute_dataset(dataset) + dataset = self._strategy.experimental_distribute_dataset(dataset) model.fit(dataset, epochs=1, steps_per_epoch=2) diff --git a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py index b6fb7184..0ca05646 100644 --- a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py +++ b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py @@ -1,4 +1,5 @@ import keras +import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.losses import deserialize @@ -8,7 +9,6 @@ from keras_rs.src.losses.pairwise_soft_zero_one_loss import ( PairwiseSoftZeroOneLoss, ) -import tensorflow as tf from keras_rs.src.utils import tpu_test_utils @@ -124,7 +124,7 @@ def create_model(): model = keras.Model(inputs=inputs, outputs=outputs) model.compile(loss=PairwiseSoftZeroOneLoss(), optimizer="adam") return model - + if self._strategy: with self._strategy.scope(): model = create_model() @@ -133,9 +133,11 @@ def create_model(): x_data = keras.random.normal((2, 20)) y_data = keras.random.randint((2, 5), minval=0, maxval=2) - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(self._strategy.num_replicas_in_sync if self._strategy else 1) + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch( + self._strategy.num_replicas_in_sync if self._strategy else 1 + ) if self._strategy: - dataset = self._strategy.experimental_distribute_dataset(dataset) + dataset = self._strategy.experimental_distribute_dataset(dataset) model.fit(dataset, epochs=1, steps_per_epoch=2) diff --git a/keras_rs/src/metrics/dcg_test.py b/keras_rs/src/metrics/dcg_test.py index cee15915..1a1341bc 100644 --- a/keras_rs/src/metrics/dcg_test.py +++ b/keras_rs/src/metrics/dcg_test.py @@ -1,6 +1,7 @@ import math import keras +import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -9,7 +10,6 @@ from keras_rs.src import testing from keras_rs.src.metrics.dcg import DCG from keras_rs.src.utils import tpu_test_utils -import tensorflow as tf def _compute_dcg(labels, ranks): @@ -142,9 +142,17 @@ def test_unbatched_inputs( dcg_metric = DCG() y_true_t = ops.array(y_true, dtype="float32") y_pred_t = ops.array(y_pred, dtype="float32") - sw = ops.array(sample_weight, dtype="float32") if sample_weight is not None else None - args = (y_true_t, y_pred_t, sw) if sw is not None else (y_true_t, y_pred_t) - tpu_test_utils.run_with_strategy(self._strategy, dcg_metric.update_state, *args) + sw = ( + ops.array(sample_weight, dtype="float32") + if sample_weight is not None + else None + ) + args = ( + (y_true_t, y_pred_t, sw) if sw is not None else (y_true_t, y_pred_t) + ) + tpu_test_utils.run_with_strategy( + self._strategy, dcg_metric.update_state, *args + ) result = dcg_metric.result() self.assertAllClose(result, expected_output) @@ -155,7 +163,7 @@ def test_batched_input(self): self._strategy, dcg_metric.update_state, self.y_true_batched, - self.y_pred_batched + self.y_pred_batched, ) result = dcg_metric.result() self.assertAllClose(result, self.expected_output_batched) @@ -244,7 +252,8 @@ def test_2d_sample_weight( dcg_metric.update_state, y_true, y_pred, - sample_weight=sample_weight) + sample_weight=sample_weight, + ) result = dcg_metric.result() self.assertAllClose(result, expected_output) @@ -312,7 +321,8 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): dcg_metric.update_state, y_true, y_pred, - sample_weight=sample_weight) + sample_weight=sample_weight, + ) result = dcg_metric.result() self.assertAllClose(result, expected_output) @@ -329,7 +339,7 @@ def test_k(self, k, exp_value): self._strategy, dcg_metric.update_state, self.y_true_batched, - self.y_pred_batched + self.y_pred_batched, ) dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) result = dcg_metric.result() @@ -343,13 +353,13 @@ def test_statefulness(self): self._strategy, dcg_metric.update_state, self.y_true_batched[:2], - self.y_pred_batched[:2] + self.y_pred_batched[:2], ) result = dcg_metric.result() self.assertAllClose( result, - sum([_compute_dcg([1], [1]), - _compute_dcg([3, 2, 1], [1, 3, 4])]) / 2, + sum([_compute_dcg([1], [1]), _compute_dcg([3, 2, 1], [1, 3, 4])]) + / 2, ) # Batch 2 @@ -357,7 +367,7 @@ def test_statefulness(self): self._strategy, dcg_metric.update_state, self.y_true_batched[2:], - self.y_pred_batched[2:] + self.y_pred_batched[2:], ) result = dcg_metric.result() self.assertAllClose(result, self.expected_output_batched) @@ -388,7 +398,8 @@ def inverse_discount_fn(rank): self._strategy, dcg_metric.update_state, self.y_true_batched, - self.y_pred_batched) + self.y_pred_batched, + ) result = dcg_metric.result() expected_output = ( @@ -412,7 +423,11 @@ def test_model_evaluate(self): y_data = keras.random.randint((2, 5), minval=0, maxval=4) dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) - dataset = dataset.batch(self._strategy.num_replicas_in_sync if isinstance(self._strategy, tf.distribute.Strategy) else 1) + dataset = dataset.batch( + self._strategy.num_replicas_in_sync + if isinstance(self._strategy, tf.distribute.Strategy) + else 1 + ) if isinstance(self._strategy, tf.distribute.TPUStrategy): dataset = self._strategy.experimental_distribute_dataset(dataset) diff --git a/keras_rs/src/metrics/mean_average_precision_test.py b/keras_rs/src/metrics/mean_average_precision_test.py index c0695c79..39c6bcb8 100644 --- a/keras_rs/src/metrics/mean_average_precision_test.py +++ b/keras_rs/src/metrics/mean_average_precision_test.py @@ -1,4 +1,5 @@ import keras +import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -7,7 +8,6 @@ from keras_rs.src import testing from keras_rs.src.metrics.mean_average_precision import MeanAveragePrecision from keras_rs.src.utils import tpu_test_utils -import tensorflow as tf class MeanAveragePrecisionTest(testing.TestCase, parameterized.TestCase): @@ -115,9 +115,17 @@ def test_unbatched_inputs( map_metric = MeanAveragePrecision() y_true_t = ops.array(y_true, dtype="float32") y_pred_t = ops.array(y_pred, dtype="float32") - sw = ops.array(sample_weight, dtype="float32") if sample_weight is not None else None - args = (y_true_t, y_pred_t, sw) if sw is not None else (y_true_t, y_pred_t) - tpu_test_utils.run_with_strategy(self._strategy, map_metric.update_state, *args) + sw = ( + ops.array(sample_weight, dtype="float32") + if sample_weight is not None + else None + ) + args = ( + (y_true_t, y_pred_t, sw) if sw is not None else (y_true_t, y_pred_t) + ) + tpu_test_utils.run_with_strategy( + self._strategy, map_metric.update_state, *args + ) result = map_metric.result() self.assertAllClose(result, expected_output) @@ -128,7 +136,7 @@ def test_batched_input(self): self._strategy, map_metric.update_state, self.y_true_batched, - self.y_pred_batched + self.y_pred_batched, ) result = map_metric.result() self.assertAllClose(result, 0.5625) @@ -192,7 +200,8 @@ def test_2d_sample_weight( map_metric.update_state, y_true, y_pred, - sample_weight=sample_weight) + sample_weight=sample_weight, + ) result = map_metric.result() self.assertAllClose(result, expected_output) @@ -240,7 +249,8 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): map_metric.update_state, y_true, y_pred, - sample_weight=sample_weight) + sample_weight=sample_weight, + ) result = map_metric.result() self.assertAllClose(result, expected_output) @@ -259,12 +269,12 @@ def test_k(self, k, expected_map): def test_statefulness(self): with self._strategy.scope(): map_metric = MeanAveragePrecision() - # Batch 1: First two lists + # Batch 1: First two lists tpu_test_utils.run_with_strategy( self._strategy, map_metric.update_state, self.y_true_batched[:2], - self.y_pred_batched[:2] + self.y_pred_batched[:2], ) result = map_metric.result() self.assertAllClose(result, 0.75) @@ -274,7 +284,7 @@ def test_statefulness(self): self._strategy, map_metric.update_state, self.y_true_batched[2:], - self.y_pred_batched[2:] + self.y_pred_batched[2:], ) result = map_metric.result() self.assertAllClose(result, 0.5625) @@ -336,7 +346,11 @@ def test_model_evaluate(self): y_data = keras.random.randint((2, 5), minval=0, maxval=4) dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) - dataset = dataset.batch(self._strategy.num_replicas_in_sync if isinstance(self._strategy, tf.distribute.Strategy) else 1) + dataset = dataset.batch( + self._strategy.num_replicas_in_sync + if isinstance(self._strategy, tf.distribute.Strategy) + else 1 + ) if isinstance(self._strategy, tf.distribute.TPUStrategy): dataset = self._strategy.experimental_distribute_dataset(dataset) diff --git a/keras_rs/src/metrics/mean_reciprocal_rank_test.py b/keras_rs/src/metrics/mean_reciprocal_rank_test.py index 06ea5853..c78aef34 100644 --- a/keras_rs/src/metrics/mean_reciprocal_rank_test.py +++ b/keras_rs/src/metrics/mean_reciprocal_rank_test.py @@ -1,4 +1,5 @@ import keras +import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -7,7 +8,6 @@ from keras_rs.src import testing from keras_rs.src.metrics.mean_reciprocal_rank import MeanReciprocalRank from keras_rs.src.utils import tpu_test_utils -import tensorflow as tf class MeanReciprocalRankTest(testing.TestCase, parameterized.TestCase): @@ -116,9 +116,9 @@ def test_unbatched_inputs( tpu_test_utils.run_with_strategy( self._strategy, mrr_metric.update_state, - y_true_t, - y_pred_t, - sample_weight=sample_weight + y_true, + y_pred, + sample_weight=sample_weight, ) result = mrr_metric.result() self.assertAllClose(result, expected_output) @@ -126,7 +126,12 @@ def test_unbatched_inputs( def test_batched_input(self): with self._strategy.scope(): mrr_metric = MeanReciprocalRank() - tpu_test_utils.run_with_strategy(mrr_metric.update_state, self.y_true_batched, self.y_pred_batched) + tpu_test_utils.run_with_strategy( + self._strategy, + mrr_metric.update_state, + self.y_true_batched, + self.y_pred_batched, + ) result = mrr_metric.result() self.assertAllClose(result, 0.625) @@ -138,12 +143,13 @@ def test_batched_input(self): def test_batched_inputs_sample_weight(self, sample_weight, expected_output): with self._strategy.scope(): mrr_metric = MeanReciprocalRank() - tpu_test_utils.run_with_strategy( + tpu_test_utils.run_with_strategy( + self._strategy, mrr_metric.update_state, self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, - ) + ) result = mrr_metric.result() self.assertAllClose(result, expected_output) @@ -183,10 +189,11 @@ def test_2d_sample_weight( with self._strategy.scope(): mrr_metric = MeanReciprocalRank() tpu_test_utils.run_with_strategy( + self._strategy, mrr_metric.update_state, y_true, y_pred, - sample_weight=sample_weight + sample_weight=sample_weight, ) result = mrr_metric.result() self.assertAllClose(result, expected_output) @@ -230,7 +237,13 @@ def test_2d_sample_weight( def test_masking(self, y_true, y_pred, sample_weight, expected_output): with self._strategy.scope(): mrr_metric = MeanReciprocalRank() - tpu_test_utils.run_with_strategy(mrr_metric.update_state, y_true, y_pred, sample_weight=sample_weight) + tpu_test_utils.run_with_strategy( + self._strategy, + mrr_metric.update_state, + y_true, + y_pred, + sample_weight=sample_weight, + ) result = mrr_metric.result() self.assertAllClose(result, expected_output) @@ -240,7 +253,12 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): def test_k(self, k, expected_mrr): with self._strategy.scope(): mrr_metric = MeanReciprocalRank(k=k) - tpu_test_utils.run_with_strategy(mrr_metric.update_state, self.y_true_batched, self.y_pred_batched) + tpu_test_utils.run_with_strategy( + self._strategy, + mrr_metric.update_state, + self.y_true_batched, + self.y_pred_batched, + ) result = mrr_metric.result() self.assertAllClose(result, expected_mrr) @@ -248,12 +266,22 @@ def test_statefulness(self): with self._strategy.scope(): mrr_metric = MeanReciprocalRank() # Batch 1: First two lists - tpu_test_utils.run_with_strategy(mrr_metric.update_state, self.y_true_batched[:2], self.y_pred_batched[:2]) + tpu_test_utils.run_with_strategy( + self._strategy, + mrr_metric.update_state, + self.y_true_batched[:2], + self.y_pred_batched[:2], + ) result = mrr_metric.result() self.assertAllClose(result, 0.75) # Batch 2: Last two lists - tpu_test_utils.run_with_strategy(mrr_metric.update_state, self.y_true_batched[2:], self.y_pred_batched[2:]) + tpu_test_utils.run_with_strategy( + self._strategy, + mrr_metric.update_state, + self.y_true_batched[2:], + self.y_pred_batched[2:], + ) result = mrr_metric.result() self.assertAllClose(result, 0.625) @@ -284,7 +312,11 @@ def test_model_evaluate(self): y_data = keras.random.randint((2, 5), minval=0, maxval=4) dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) - dataset = dataset.batch(self._strategy.num_replicas_in_sync if isinstance(self._strategy, tf.distribute.Strategy) else 1) + dataset = dataset.batch( + self._strategy.num_replicas_in_sync + if isinstance(self._strategy, tf.distribute.Strategy) + else 1 + ) if isinstance(self._strategy, tf.distribute.TPUStrategy): dataset = self._strategy.experimental_distribute_dataset(dataset) diff --git a/keras_rs/src/metrics/ndcg_test.py b/keras_rs/src/metrics/ndcg_test.py index ce29aa16..76b36b0c 100644 --- a/keras_rs/src/metrics/ndcg_test.py +++ b/keras_rs/src/metrics/ndcg_test.py @@ -1,6 +1,7 @@ import math import keras +import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -9,7 +10,6 @@ from keras_rs.src import testing from keras_rs.src.metrics.ndcg import NDCG from keras_rs.src.utils import tpu_test_utils -import tensorflow as tf def _compute_dcg(labels, ranks): @@ -148,15 +148,28 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - dcg_metric = NDCG() - dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) - result = dcg_metric.result() + with self._strategy.scope(): + ndcg_metric = NDCG() + tpu_test_utils.run_with_strategy( + self._strategy, + ndcg_metric.update_state, + y_true, + y_pred, + sample_weight=sample_weight, + ) + result = ndcg_metric.result() self.assertAllClose(result, expected_output) def test_batched_input(self): - dcg_metric = NDCG() - dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) - result = dcg_metric.result() + with self._strategy.scope(): + ndcg_metric = NDCG() + tpu_test_utils.run_with_strategy( + self._strategy, + ndcg_metric.update_state, + self.y_true_batched, + self.y_pred_batched, + ) + result = ndcg_metric.result() self.assertAllClose(result, self.expected_output_batched) @parameterized.named_parameters( @@ -165,13 +178,16 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 0.74262), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - dcg_metric = NDCG() - dcg_metric.update_state( + with self._strategy.scope(): + ndcg_metric = NDCG() + tpu_test_utils.run_with_strategy( + self._strategy, + ndcg_metric.update_state, self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, ) - result = dcg_metric.result() + result = ndcg_metric.result() self.assertAllClose(result, expected_output) @parameterized.named_parameters( @@ -232,10 +248,16 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - dcg_metric = NDCG() - - dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) - result = dcg_metric.result() + with self._strategy.scope(): + ndcg_metric = NDCG() + tpu_test_utils.run_with_strategy( + self._strategy, + ndcg_metric.update_state, + y_true, + y_pred, + sample_weight=sample_weight, + ) + result = ndcg_metric.result() self.assertAllClose(result, expected_output) @parameterized.named_parameters( @@ -295,10 +317,16 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - dcg_metric = NDCG() - - dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) - result = dcg_metric.result() + with self._strategy.scope(): + ndcg_metric = NDCG() + tpu_test_utils.run_with_strategy( + self._strategy, + ndcg_metric.update_state, + y_true, + y_pred, + sample_weight=sample_weight, + ) + result = ndcg_metric.result() self.assertAllClose(result, expected_output) @parameterized.named_parameters( @@ -308,18 +336,28 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): ("4", 4, 0.7377), ) def test_k(self, k, exp_value): - dcg_metric = NDCG(k=k) - dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) - result = dcg_metric.result() + with self._strategy.scope(): + ndcg_metric = NDCG(k=k) + tpu_test_utils.run_with_strategy( + self._strategy, + ndcg_metric.update_state, + self.y_true_batched, + self.y_pred_batched, + ) + result = ndcg_metric.result() self.assertAllClose(result, exp_value, rtol=1e-5) def test_statefulness(self): - dcg_metric = NDCG() + with self._strategy.scope(): + ndcg_metric = NDCG() # Batch 1 - dcg_metric.update_state( - self.y_true_batched[:2], self.y_pred_batched[:2] + tpu_test_utils.run_with_strategy( + self._strategy, + ndcg_metric.update_state, + self.y_true_batched[:2], + self.y_pred_batched[:2], ) - result = dcg_metric.result() + result = ndcg_metric.result() dcg = [_compute_dcg([1], [1]), _compute_dcg([3, 2, 1], [1, 3, 4])] idcg = [_compute_dcg([1], [1]), _compute_dcg([3, 2, 1], [1, 2, 3])] ndcg = sum([a / b if b != 0.0 else 0.0 for a, b in zip(dcg, idcg)]) / 2 @@ -329,19 +367,23 @@ def test_statefulness(self): ) # Batch 2 - dcg_metric.update_state( - self.y_true_batched[2:], self.y_pred_batched[2:] + tpu_test_utils.run_with_strategy( + self._strategy, + ndcg_metric.update_state, + self.y_true_batched[2:], + self.y_pred_batched[2:], ) - result = dcg_metric.result() + result = ndcg_metric.result() self.assertAllClose(result, self.expected_output_batched) # Reset state - dcg_metric.reset_state() - result = dcg_metric.result() + ndcg_metric.reset_state() + result = ndcg_metric.result() self.assertAllClose(result, 0.0) def test_serialization(self): - metric = NDCG() + with self._strategy.scope(): + metric = NDCG() restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) @@ -352,11 +394,17 @@ def linear_gain_fn(label): def inverse_discount_fn(rank): return ops.divide(1.0, rank) - dcg_metric = NDCG( - gain_fn=linear_gain_fn, rank_discount_fn=inverse_discount_fn + with self._strategy.scope(): + ndcg_metric = NDCG( + gain_fn=linear_gain_fn, rank_discount_fn=inverse_discount_fn + ) + tpu_test_utils.run_with_strategy( + self._strategy, + ndcg_metric.update_state, + self.y_true_batched, + self.y_pred_batched, ) - dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) - result = dcg_metric.result() + result = ndcg_metric.result() dcg = [1 / 1, 3 / 1 + 2 / 3 + 1 / 4, 0, 2 / 1 + 1 / 2] idcg = [1 / 1, 3 / 1 + 2 / 2 + 1 / 3, 0.0, 2 / 1 + 1 / 2] @@ -364,16 +412,28 @@ def inverse_discount_fn(rank): self.assertAllClose(result, ndcg, rtol=1e-5) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[NDCG()], - optimizer="adam", - ) - model.evaluate( - x=keras.random.normal((2, 20)), - y=keras.random.randint((2, 5), minval=0, maxval=4), + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[NDCG()], + optimizer="adam", + ) + + x_data = keras.random.normal((2, 20)) + y_data = keras.random.randint((2, 5), minval=0, maxval=4) + + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) + dataset = dataset.batch( + self._strategy.num_replicas_in_sync + if isinstance(self._strategy, tf.distribute.Strategy) + else 1 ) + + if isinstance(self._strategy, tf.distribute.TPUStrategy): + dataset = self._strategy.experimental_distribute_dataset(dataset) + + model.evaluate(dataset, steps=2) diff --git a/keras_rs/src/metrics/precision_at_k_test.py b/keras_rs/src/metrics/precision_at_k_test.py index d83c5c9d..1e8805f6 100644 --- a/keras_rs/src/metrics/precision_at_k_test.py +++ b/keras_rs/src/metrics/precision_at_k_test.py @@ -1,4 +1,5 @@ import keras +import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -6,10 +7,15 @@ from keras_rs.src import testing from keras_rs.src.metrics.precision_at_k import PrecisionAtK +from keras_rs.src.utils import tpu_test_utils class PrecisionAtKTest(testing.TestCase, parameterized.TestCase): def setUp(self): + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -30,18 +36,19 @@ def setUp(self): ) def test_invalid_k_init(self): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - PrecisionAtK(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - PrecisionAtK(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - PrecisionAtK(k=3.5) + with self._strategy.scope(): + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + PrecisionAtK(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + PrecisionAtK(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + PrecisionAtK(k=3.5) @parameterized.named_parameters( ( @@ -90,14 +97,27 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - p_at_k = PrecisionAtK(k=3) - p_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + p_at_k = PrecisionAtK(k=3) + tpu_test_utils.run_with_strategy( + self._strategy, + p_at_k.update_state, + y_true, + y_pred, + sample_weight=sample_weight, + ) result = p_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) def test_batched_input(self): - p_at_k = PrecisionAtK(k=3) - p_at_k.update_state(self.y_true_batched, self.y_pred_batched) + with self._strategy.scope(): + p_at_k = PrecisionAtK(k=3) + tpu_test_utils.run_with_strategy( + self._strategy, + p_at_k.update_state, + self.y_true_batched, + self.y_pred_batched, + ) result = p_at_k.result() self.assertAllClose(result, 1 / 3) @@ -107,8 +127,11 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 0.3), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - p_at_k = PrecisionAtK(k=3) - p_at_k.update_state( + with self._strategy.scope(): + p_at_k = PrecisionAtK(k=3) + tpu_test_utils.run_with_strategy( + self._strategy, + p_at_k.update_state, self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, @@ -149,8 +172,15 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - p_at_k = PrecisionAtK(k=3) - p_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + p_at_k = PrecisionAtK(k=3) + tpu_test_utils.run_with_strategy( + self._strategy, + p_at_k.update_state, + y_true, + y_pred, + sample_weight=sample_weight, + ) result = p_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) @@ -191,8 +221,15 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - p_at_k = PrecisionAtK(k=3) - p_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + p_at_k = PrecisionAtK(k=3) + tpu_test_utils.run_with_strategy( + self._strategy, + p_at_k.update_state, + y_true, + y_pred, + sample_weight=sample_weight, + ) result = p_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) @@ -203,18 +240,35 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): ("4", 4, 0.375), ) def test_k(self, k, expected_precision): - p_at_k = PrecisionAtK(k=k) - p_at_k.update_state(self.y_true_batched, self.y_pred_batched) + with self._strategy.scope(): + p_at_k = PrecisionAtK(k=k) + tpu_test_utils.run_with_strategy( + self._strategy, + p_at_k.update_state, + self.y_true_batched, + self.y_pred_batched, + ) result = p_at_k.result() self.assertAllClose(result, expected_precision) def test_statefulness(self): - p_at_k = PrecisionAtK(k=3) - p_at_k.update_state(self.y_true_batched[:2], self.y_pred_batched[:2]) + with self._strategy.scope(): + p_at_k = PrecisionAtK(k=3) + tpu_test_utils.run_with_strategy( + self._strategy, + p_at_k.update_state, + self.y_true_batched[:2], + self.y_pred_batched[:2], + ) result = p_at_k.result() self.assertAllClose(result, 0.5, rtol=1e-6) - p_at_k.update_state(self.y_true_batched[2:], self.y_pred_batched[2:]) + tpu_test_utils.run_with_strategy( + self._strategy, + p_at_k.update_state, + self.y_true_batched[2:], + self.y_pred_batched[2:], + ) result = p_at_k.result() self.assertAllClose(result, 1 / 3) @@ -223,24 +277,34 @@ def test_statefulness(self): self.assertAllClose(result, 0.0) def test_serialization(self): - metric = PrecisionAtK(k=3) + with self._strategy.scope(): + metric = PrecisionAtK(k=3) restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[PrecisionAtK(k=3)], - optimizer="adam", - ) - model.evaluate( - x=keras.random.normal((2, 20)), - y=keras.random.randint( - (2, 5), minval=0, maxval=2 - ), # Using 0/1 for y_true - verbose=0, + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[PrecisionAtK(k=3)], + optimizer="adam", + ) + + x_data = keras.random.normal((2, 20)) + y_data = keras.random.randint((2, 5), minval=0, maxval=2) + + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) + dataset = dataset.batch( + self._strategy.num_replicas_in_sync + if isinstance(self._strategy, tf.distribute.Strategy) + else 1 ) + + if isinstance(self._strategy, tf.distribute.TPUStrategy): + dataset = self._strategy.experimental_distribute_dataset(dataset) + + model.evaluate(dataset, steps=2, verbose=0) diff --git a/keras_rs/src/metrics/recall_at_k_test.py b/keras_rs/src/metrics/recall_at_k_test.py index 1a6672ce..d536bafb 100644 --- a/keras_rs/src/metrics/recall_at_k_test.py +++ b/keras_rs/src/metrics/recall_at_k_test.py @@ -1,4 +1,5 @@ import keras +import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -6,10 +7,15 @@ from keras_rs.src import testing from keras_rs.src.metrics.recall_at_k import RecallAtK +from keras_rs.src.utils import tpu_test_utils class RecallAtKTest(testing.TestCase, parameterized.TestCase): def setUp(self): + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -30,18 +36,19 @@ def setUp(self): ) def test_invalid_k_init(self): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - RecallAtK(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - RecallAtK(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - RecallAtK(k=3.5) + with self._strategy.scope(): + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + RecallAtK(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + RecallAtK(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + RecallAtK(k=3.5) @parameterized.named_parameters( ( @@ -90,14 +97,27 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - r_at_k = RecallAtK(k=3) - r_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + r_at_k = RecallAtK(k=3) + tpu_test_utils.run_with_strategy( + self._strategy, + r_at_k.update_state, + y_true, + y_pred, + sample_weight=sample_weight, + ) result = r_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) def test_batched_input(self): - r_at_k = RecallAtK(k=3) - r_at_k.update_state(self.y_true_batched, self.y_pred_batched) + with self._strategy.scope(): + r_at_k = RecallAtK(k=3) + tpu_test_utils.run_with_strategy( + self._strategy, + r_at_k.update_state, + self.y_true_batched, + self.y_pred_batched, + ) result = r_at_k.result() self.assertAllClose(result, 0.541667) @@ -107,8 +127,11 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 0.55), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - r_at_k = RecallAtK(k=3) - r_at_k.update_state( + with self._strategy.scope(): + r_at_k = RecallAtK(k=3) + tpu_test_utils.run_with_strategy( + self._strategy, + r_at_k.update_state, self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, @@ -149,8 +172,15 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - r_at_k = RecallAtK(k=2) - r_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + r_at_k = RecallAtK(k=2) + tpu_test_utils.run_with_strategy( + self._strategy, + r_at_k.update_state, + y_true, + y_pred, + sample_weight=sample_weight, + ) result = r_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) @@ -194,8 +224,15 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - r_at_k = RecallAtK(k=2) - r_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) + with self._strategy.scope(): + r_at_k = RecallAtK(k=2) + tpu_test_utils.run_with_strategy( + self._strategy, + r_at_k.update_state, + y_true, + y_pred, + sample_weight=sample_weight, + ) result = r_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) @@ -206,18 +243,35 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): ("4", 4, 0.75), ) def test_k(self, k, expected_recall): - r_at_k = RecallAtK(k=k) - r_at_k.update_state(self.y_true_batched, self.y_pred_batched) + with self._strategy.scope(): + r_at_k = RecallAtK(k=k) + tpu_test_utils.run_with_strategy( + self._strategy, + r_at_k.update_state, + self.y_true_batched, + self.y_pred_batched, + ) result = r_at_k.result() self.assertAllClose(result, expected_recall) def test_statefulness(self): - r_at_k = RecallAtK(k=3) - r_at_k.update_state(self.y_true_batched[:2], self.y_pred_batched[:2]) + with self._strategy.scope(): + r_at_k = RecallAtK(k=3) + tpu_test_utils.run_with_strategy( + self._strategy, + r_at_k.update_state, + self.y_true_batched[:2], + self.y_pred_batched[:2], + ) result = r_at_k.result() self.assertAllClose(result, 0.833333, rtol=1e-6) - r_at_k.update_state(self.y_true_batched[2:], self.y_pred_batched[2:]) + tpu_test_utils.run_with_strategy( + self._strategy, + r_at_k.update_state, + self.y_true_batched[2:], + self.y_pred_batched[2:], + ) result = r_at_k.result() self.assertAllClose(result, 0.541667) @@ -226,24 +280,34 @@ def test_statefulness(self): self.assertAllClose(result, 0.0) def test_serialization(self): - metric = RecallAtK(k=3) + with self._strategy.scope(): + metric = RecallAtK(k=3) restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[RecallAtK(k=3)], - optimizer="adam", - ) - model.evaluate( - x=keras.random.normal((2, 20)), - y=keras.random.randint( - (2, 5), minval=0, maxval=2 - ), # Using 0/1 for y_true - verbose=0, + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[RecallAtK(k=3)], + optimizer="adam", + ) + + x_data = keras.random.normal((2, 20)) + y_data = keras.random.randint((2, 5), minval=0, maxval=2) + + dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) + dataset = dataset.batch( + self._strategy.num_replicas_in_sync + if isinstance(self._strategy, tf.distribute.Strategy) + else 1 ) + + if isinstance(self._strategy, tf.distribute.TPUStrategy): + dataset = self._strategy.experimental_distribute_dataset(dataset) + + model.evaluate(dataset, steps=2, verbose=0) diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index 3dd9493d..e1bdbcfd 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -66,9 +66,11 @@ def get_tpu_strategy(test_case): def run_with_strategy(strategy, fn, *args, jit_compile=False, **kwargs): """Wrapper for running a function under a strategy.""" if keras.backend.backend() == "tensorflow": + @tf.function(jit_compile=jit_compile) def tf_function_wrapper(*tf_function_args): return strategy.run(fn, args=tf_function_args, kwargs=kwargs) + return tf_function_wrapper(*args) else: assert not jit_compile From 214d40ec2adeab13adf00208a3f86ab424b0e6df Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Mon, 24 Nov 2025 10:03:23 +0000 Subject: [PATCH 06/29] update tpu test workflow --- .github/workflows/actions.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 9b06868e..b3bc8aff 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -90,7 +90,10 @@ jobs: run: python3 -c "import jax; print('JAX devices:', jax.devices())" - name: Test with pytest - run: pytest keras_rs/src/layers/ + if: ${{ matrix.backend == 'tensorflow'}} + run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax + if: ${{ matrix.backend == 'jax'}} + run: pytest keras_rs/ check_format: name: Check the code format From 461a9ea4be68e8fa76c4a52f364e3ecbec66ec69 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Mon, 24 Nov 2025 10:06:46 +0000 Subject: [PATCH 07/29] update actions.yml --- .github/workflows/actions.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index b3bc8aff..1f23ab51 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -89,10 +89,12 @@ jobs: if: ${{ matrix.backend == 'jax'}} run: python3 -c "import jax; print('JAX devices:', jax.devices())" - - name: Test with pytest - if: ${{ matrix.backend == 'tensorflow'}} + - name: Test with pytest (TensorFlow) + if: ${{ matrix.backend == 'tensorflow' }} run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax - if: ${{ matrix.backend == 'jax'}} + + - name: Test with pytest (JAX) + if: ${{ matrix.backend == 'jax' }} run: pytest keras_rs/ check_format: From c4df6186ed0fc7e567e10c3afef45128524e446b Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Mon, 24 Nov 2025 23:23:26 +0000 Subject: [PATCH 08/29] fix cpu tf error on run with strategy taking kwargs and format --- keras_rs/src/utils/tpu_test_utils.py | 45 +++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index e1bdbcfd..67518389 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -4,10 +4,12 @@ import keras import tensorflow as tf +jax: Optional[ModuleType] = None + try: import jax except ImportError: - jax = None + pass class DummyStrategy: @@ -33,8 +35,9 @@ def num_replicas_in_sync(self): return 0 return jax.device_count("tpu") +StrategyType = Union[tf.distribute.Strategy, DummyStrategy, JaxDummyStrategy] -def get_tpu_strategy(test_case): +def get_tpu_strategy(test_case: Any) -> StrategyType: """Get TPU strategy if on TPU, otherwise return DummyStrategy.""" if "TPU_NAME" not in os.environ: return DummyStrategy() @@ -63,15 +66,41 @@ def get_tpu_strategy(test_case): return DummyStrategy() -def run_with_strategy(strategy, fn, *args, jit_compile=False, **kwargs): - """Wrapper for running a function under a strategy.""" +def run_with_strategy( + strategy: Any, + fn: Callable[..., Any], + *args: Any, + jit_compile: bool = False, + **kwargs: Any +) -> Any: + """ + Final wrapper fix: Flattens allowed kwargs into positional args before + entering tf.function to guarantee a fixed graph signature. + """ if keras.backend.backend() == "tensorflow": + # Extract sample_weight and treat it as an explicit third positional argument. + # If not present, use a placeholder (None). + sample_weight_value = kwargs.get('sample_weight', None) + all_inputs = args + (sample_weight_value,) @tf.function(jit_compile=jit_compile) - def tf_function_wrapper(*tf_function_args): - return strategy.run(fn, args=tf_function_args, kwargs=kwargs) - - return tf_function_wrapper(*args) + def tf_function_wrapper(input_tuple: Tuple[Any, ...]) -> Any: + num_original_args = len(args) + core_args = input_tuple[:num_original_args] + sw_value = input_tuple[-1] + + if sw_value is not None: + all_positional_args = core_args + (sw_value,) + return strategy.run( + fn, + args=all_positional_args + ) + else: + return strategy.run( + fn, + args=core_args + ) + return tf_function_wrapper(all_inputs) else: assert not jit_compile return fn(*args, **kwargs) From 5691b4da0174518b7dd8d73d1d85236e17bb9f43 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Mon, 24 Nov 2025 23:25:35 +0000 Subject: [PATCH 09/29] format --- keras_rs/src/utils/tpu_test_utils.py | 31 +++++++++++++++------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index 67518389..7b868f3a 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -1,8 +1,14 @@ import contextlib import os +import Any +import Callable import keras +import ModuleType +import Optional import tensorflow as tf +import Tuple +import Union jax: Optional[ModuleType] = None @@ -35,8 +41,10 @@ def num_replicas_in_sync(self): return 0 return jax.device_count("tpu") + StrategyType = Union[tf.distribute.Strategy, DummyStrategy, JaxDummyStrategy] + def get_tpu_strategy(test_case: Any) -> StrategyType: """Get TPU strategy if on TPU, otherwise return DummyStrategy.""" if "TPU_NAME" not in os.environ: @@ -71,16 +79,16 @@ def run_with_strategy( fn: Callable[..., Any], *args: Any, jit_compile: bool = False, - **kwargs: Any + **kwargs: Any, ) -> Any: """ - Final wrapper fix: Flattens allowed kwargs into positional args before + Final wrapper fix: Flattens allowed kwargs into positional args before entering tf.function to guarantee a fixed graph signature. """ if keras.backend.backend() == "tensorflow": - # Extract sample_weight and treat it as an explicit third positional argument. - # If not present, use a placeholder (None). - sample_weight_value = kwargs.get('sample_weight', None) + # Extract sample_weight and treat it as an explicit third positional + # argument. If not present, use a placeholder (None). + sample_weight_value = kwargs.get("sample_weight", None) all_inputs = args + (sample_weight_value,) @tf.function(jit_compile=jit_compile) @@ -88,18 +96,13 @@ def tf_function_wrapper(input_tuple: Tuple[Any, ...]) -> Any: num_original_args = len(args) core_args = input_tuple[:num_original_args] sw_value = input_tuple[-1] - + if sw_value is not None: all_positional_args = core_args + (sw_value,) - return strategy.run( - fn, - args=all_positional_args - ) + return strategy.run(fn, args=all_positional_args) else: - return strategy.run( - fn, - args=core_args - ) + return strategy.run(fn, args=core_args) + return tf_function_wrapper(all_inputs) else: assert not jit_compile From e5578055130dfbce7c876e5f78089d6571d9c868 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Mon, 24 Nov 2025 23:41:35 +0000 Subject: [PATCH 10/29] fix import --- keras_rs/src/utils/tpu_test_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index 7b868f3a..2ad15549 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -1,14 +1,10 @@ import contextlib import os +from types import ModuleType +from typing import Any, Callable, Optional, Tuple, Union -import Any -import Callable import keras -import ModuleType -import Optional import tensorflow as tf -import Tuple -import Union jax: Optional[ModuleType] = None From 91e82e7814a73dfc4876870c3d5d27e64b0a8881 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 25 Nov 2025 00:28:14 +0000 Subject: [PATCH 11/29] fix test errors --- .../embedding/jax/distributed_embedding_test.py | 6 ++++-- keras_rs/src/utils/tpu_test_utils.py | 16 +++++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py index 1dd3525d..79e49059 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py @@ -28,6 +28,7 @@ keras.config.disable_traceback_filtering() +from keras_rs.src import testing def _create_sparsecore_layout( sharding_axis: str = "sparsecore", @@ -308,7 +309,7 @@ def my_initializer(shape: tuple[int, int], dtype: Any): keras.backend.backend() != "jax", reason="Backend specific test", ) -class DistributedEmbeddingLayerTest(parameterized.TestCase): +class DistributedEmbeddingLayerTest(testing.TestCase, parameterized.TestCase): @parameterized.product( ragged=[True, False], combiner=["sum", "mean", "sqrtn"], @@ -326,6 +327,7 @@ def test_call( table_stacking: str | list[str] | list[list[str]], jit: bool, ): + self.on_tpu = "TPU_NAME" in os.environ table_configs = keras_test_utils.create_random_table_configs( combiner=combiner, seed=10 ) @@ -374,7 +376,7 @@ def test_call( ) keras.tree.map_structure( - lambda a, b: np.testing.assert_allclose(a, b, atol=1e-5), + lambda a, b: self.assertAllClose(a, b, atol=1e-3, is_tpu=self.on_tpu), outputs, expected_outputs, ) diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index 2ad15549..dcab2eae 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -1,7 +1,7 @@ import contextlib import os from types import ModuleType -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, ContextManager, Optional, Tuple, Union import keras import tensorflow as tf @@ -15,24 +15,26 @@ class DummyStrategy: - def scope(self): + def scope(self) -> ContextManager[None]: return contextlib.nullcontext() @property - def num_replicas_in_sync(self): + def num_replicas_in_sync(self) -> int: return 1 - def run(self, fn, args): + def run(self, fn: Callable[..., Any], args: Tuple[Any, ...]) -> Any: return fn(*args) - def experimental_distribute_dataset(self, dataset, options=None): + def experimental_distribute_dataset( + self, dataset: Any, options: Optional[Any] = None + ) -> Any: del options return dataset class JaxDummyStrategy(DummyStrategy): @property - def num_replicas_in_sync(self): + def num_replicas_in_sync(self) -> int: if jax is None: return 0 return jax.device_count("tpu") @@ -87,7 +89,7 @@ def run_with_strategy( sample_weight_value = kwargs.get("sample_weight", None) all_inputs = args + (sample_weight_value,) - @tf.function(jit_compile=jit_compile) + @tf.function(jit_compile=jit_compile) # type: ignore[misc] def tf_function_wrapper(input_tuple: Tuple[Any, ...]) -> Any: num_original_args = len(args) core_args = input_tuple[:num_original_args] From a756439de0f8883cb3d8c262fbabd841fa67f314 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 25 Nov 2025 00:34:36 +0000 Subject: [PATCH 12/29] format --- .../src/layers/embedding/jax/distributed_embedding_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py index 79e49059..cc41e584 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py @@ -18,6 +18,7 @@ ) from jax_tpu_embedding.sparsecore.utils import utils as jte_utils +from keras_rs.src import testing from keras_rs.src.layers.embedding import test_utils as keras_test_utils from keras_rs.src.layers.embedding.jax import checkpoint_utils from keras_rs.src.layers.embedding.jax import config_conversion @@ -28,7 +29,6 @@ keras.config.disable_traceback_filtering() -from keras_rs.src import testing def _create_sparsecore_layout( sharding_axis: str = "sparsecore", @@ -376,7 +376,9 @@ def test_call( ) keras.tree.map_structure( - lambda a, b: self.assertAllClose(a, b, atol=1e-3, is_tpu=self.on_tpu), + lambda a, b: self.assertAllClose( + a, b, atol=1e-3, is_tpu=self.on_tpu + ), outputs, expected_outputs, ) From 4e0363a954510b749663c229406d8ab83192a2b4 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 25 Nov 2025 00:37:56 +0000 Subject: [PATCH 13/29] fix type --- keras_rs/src/utils/tpu_test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index dcab2eae..d6e1d514 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -34,7 +34,7 @@ def experimental_distribute_dataset( class JaxDummyStrategy(DummyStrategy): @property - def num_replicas_in_sync(self) -> int: + def num_replicas_in_sync(self) -> Any: if jax is None: return 0 return jax.device_count("tpu") From df8c1ff0d9df77059ce3d63c579e332437f5964d Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 25 Nov 2025 01:18:29 +0000 Subject: [PATCH 14/29] ignore long runnign tpu test --- .github/workflows/actions.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 1f23ab51..f9ce4452 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -95,7 +95,7 @@ jobs: - name: Test with pytest (JAX) if: ${{ matrix.backend == 'jax' }} - run: pytest keras_rs/ + run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax/distributed_embedding_test.py::DistributedEmbeddingLayerTest check_format: name: Check the code format From 0b81a274db0caf05382728a1987aad0e2688b091 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 25 Nov 2025 01:30:15 +0000 Subject: [PATCH 15/29] update ignore --- .github/workflows/actions.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index f9ce4452..b125b6d1 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -95,7 +95,7 @@ jobs: - name: Test with pytest (JAX) if: ${{ matrix.backend == 'jax' }} - run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax/distributed_embedding_test.py::DistributedEmbeddingLayerTest + run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax/distributed_embedding_test.py check_format: name: Check the code format From 376a954598eecd3a96fa969517eb96a0e8aaefca Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 25 Nov 2025 18:44:03 +0000 Subject: [PATCH 16/29] clean up --- .../src/layers/embedding/jax/distributed_embedding_test.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py index cc41e584..d54cce4f 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py @@ -309,7 +309,7 @@ def my_initializer(shape: tuple[int, int], dtype: Any): keras.backend.backend() != "jax", reason="Backend specific test", ) -class DistributedEmbeddingLayerTest(testing.TestCase, parameterized.TestCase): +class DistributedEmbeddingLayerTest(parameterized.TestCase): @parameterized.product( ragged=[True, False], combiner=["sum", "mean", "sqrtn"], @@ -327,7 +327,6 @@ def test_call( table_stacking: str | list[str] | list[list[str]], jit: bool, ): - self.on_tpu = "TPU_NAME" in os.environ table_configs = keras_test_utils.create_random_table_configs( combiner=combiner, seed=10 ) @@ -376,9 +375,7 @@ def test_call( ) keras.tree.map_structure( - lambda a, b: self.assertAllClose( - a, b, atol=1e-3, is_tpu=self.on_tpu - ), + lambda a, b: np.testing.assert_allclose(a, b, atol=1e-5), outputs, expected_outputs, ) From a1dedc65ba31d1e0b04b7e2079d5f7e3ea4fab4f Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 26 Nov 2025 00:45:17 +0000 Subject: [PATCH 17/29] revert unnecessary tpu strategy for eager --- .../src/layers/embedding/embed_reduce_test.py | 157 ++++++------------ .../jax/distributed_embedding_test.py | 1 - .../dot_interaction_test.py | 49 ++---- .../feature_interaction/feature_cross_test.py | 44 ++--- .../retrieval/brute_force_retrieval_test.py | 56 ++----- .../retrieval/hard_negative_mining_test.py | 41 ++--- .../retrieval/remove_accidental_hits_test.py | 55 ++---- .../src/layers/retrieval/retrieval_test.py | 35 +--- .../sampling_probability_correction_test.py | 36 +--- 9 files changed, 135 insertions(+), 339 deletions(-) diff --git a/keras_rs/src/layers/embedding/embed_reduce_test.py b/keras_rs/src/layers/embedding/embed_reduce_test.py index 908067ca..c3923d4a 100644 --- a/keras_rs/src/layers/embedding/embed_reduce_test.py +++ b/keras_rs/src/layers/embedding/embed_reduce_test.py @@ -3,7 +3,6 @@ import keras import tensorflow as tf -from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -25,10 +24,8 @@ class EmbedReduceTest(testing.TestCase, parameterized.TestCase): def setUp(self): super().setUp() self.on_tpu = "TPU_NAME" in os.environ - if keras.backend.backend() == "tensorflow": tf.debugging.disable_traceback_filtering() - self._strategy = tpu_test_utils.get_tpu_strategy(self) @parameterized.named_parameters( @@ -62,87 +59,50 @@ def test_call(self, combiner, input_type, input_rank, use_weights): ): self.skipTest(f"sparse not supported on {keras.backend.backend()}") - if self.on_tpu and input_type in ["ragged", "sparse"]: - self.skipTest("Ragged and sparse are not compilable on TPU.") - - batch_size = 2 * self._strategy.num_replicas_in_sync - - def repeat_input(item, times): - return [item[i % len(item)] for i in range(times)] - if input_type == "dense" and input_rank == 1: - inputs = ops.convert_to_tensor(repeat_input([1, 2], batch_size)) - weights = ops.convert_to_tensor( - repeat_input([1.0, 2.0], batch_size) - ) + inputs = ops.convert_to_tensor([1, 2]) + weights = ops.convert_to_tensor([1.0, 2.0]) elif input_type == "dense" and input_rank == 2: - inputs = ops.convert_to_tensor( - repeat_input([[1, 2], [3, 4]], batch_size) - ) - weights = ops.convert_to_tensor( - repeat_input([[1.0, 2.0], [3.0, 4.0]], batch_size) - ) + inputs = ops.convert_to_tensor([[1, 2], [3, 4]]) + weights = ops.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]]) elif input_type == "ragged" and input_rank == 2: - inputs = tf.ragged.constant( - repeat_input([[1], [2, 3, 4, 5]], batch_size) - ) - weights = tf.ragged.constant( - repeat_input([[1.0], [1.0, 2.0, 3.0, 4.0]], batch_size) - ) + import tensorflow as tf + + inputs = tf.ragged.constant([[1], [2, 3, 4, 5]]) + weights = tf.ragged.constant([[1.0], [1.0, 2.0, 3.0, 4.0]]) elif input_type == "sparse" and input_rank == 2: - base_indices = [[0, 0], [1, 0], [1, 1], [1, 2], [1, 3]] - base_values = [1, 2, 3, 4, 5] - base_weights = [1.0, 1.0, 2.0, 3.0, 4.0] - indices = [] - values = [] - weight_values = [] - for i in range(batch_size // 2): - for idx, val, wgt in zip( - base_indices, base_values, base_weights - ): - indices.append([i * 2 + idx[0], idx[1]]) - values.append(val) - weight_values.append(wgt) + indices = [[0, 0], [1, 0], [1, 1], [1, 2], [1, 3]] if keras.backend.backend() == "tensorflow": + import tensorflow as tf + inputs = tf.sparse.reorder( - tf.SparseTensor(indices, values, (batch_size, 4)) + tf.SparseTensor(indices, [1, 2, 3, 4, 5], (2, 4)) ) weights = tf.sparse.reorder( - tf.SparseTensor(indices, weight_values, (batch_size, 4)) + tf.SparseTensor(indices, [1.0, 1.0, 2.0, 3.0, 4.0], (2, 4)) ) elif keras.backend.backend() == "jax": + from jax.experimental import sparse as jax_sparse + inputs = jax_sparse.BCOO( - (ops.array(values), ops.array(indices)), - shape=(batch_size, 4), + ([1, 2, 3, 4, 5], indices), + shape=(2, 4), unique_indices=True, ) weights = jax_sparse.BCOO( - (ops.array(weight_values), ops.array(indices)), - shape=(batch_size, 4), + ([1.0, 1.0, 2.0, 3.0, 4.0], indices), + shape=(2, 4), unique_indices=True, ) if not use_weights: weights = None - with self._strategy.scope(): - layer = EmbedReduce(10, 20, combiner=combiner) - - if keras.backend.backend() == "tensorflow": - # TF requires weights to be None or match input type - if input_type == "sparse" and not use_weights: - res = tpu_test_utils.run_with_strategy( - self._strategy, layer.__call__, inputs - ) - else: - res = tpu_test_utils.run_with_strategy( - self._strategy, layer.__call__, inputs, weights - ) - else: # JAX or other - res = layer(inputs, weights) + layer = EmbedReduce(10, 20, combiner=combiner) + res = layer(inputs, weights) - self.assertEqual(res.shape, (batch_size, 20)) + self.assertEqual(res.shape, (2, 20)) e = layer.embeddings if input_type == "dense" and input_rank == 1: @@ -173,7 +133,6 @@ def repeat_input(item, times): elif combiner == "sqrtn": expected[1] /= math.sqrt(30.0 if use_weights else 4.0) - expected = repeat_input(expected, batch_size) self.assertAllClose(res, expected) @parameterized.named_parameters( @@ -199,70 +158,52 @@ def repeat_input(item, times): def test_symbolic_call(self, input_type, input_rank, use_weights): if input_type == "ragged" and keras.backend.backend() != "tensorflow": self.skipTest(f"ragged not supported on {keras.backend.backend()}") - if input_type == "sparse": - if keras.backend.backend() == "jax": - self.assertTrue( - jax is not None, "JAX not found for JAX backend test." - ) - elif keras.backend.backend() != "tensorflow": - self.skipTest( - f"sparse not supported on {keras.backend.backend()}" - ) + if input_type == "sparse" and keras.backend.backend() not in ( + "jax", + "tensorflow", + ): + self.skipTest(f"sparse not supported on {keras.backend.backend()}") - with self._strategy.scope(): - layer = EmbedReduce(10, 20, dtype="float32") + input = keras.layers.Input( + shape=(2,) if input_rank == 2 else (), + sparse=input_type == "sparse", + ragged=input_type == "ragged", + dtype="int32", + ) - input_tensor = keras.layers.Input( + if use_weights: + weights = keras.layers.Input( shape=(2,) if input_rank == 2 else (), sparse=input_type == "sparse", ragged=input_type == "ragged", - dtype="int32", + dtype="float32", ) + output = EmbedReduce(10, 20, dtype="float32")(input, weights) + else: + output = EmbedReduce(10, 20, dtype="float32")(input) - if use_weights: - weights = keras.layers.Input( - shape=(2,) if input_rank == 2 else (), - sparse=input_type == "sparse", - ragged=input_type == "ragged", - dtype="float32", - ) - output = layer(input_tensor, weights) - else: - output = layer(input_tensor) - - self.assertEqual(output.shape, (None, 20)) - self.assertEqual(output.dtype, "float32") - self.assertFalse(output.sparse) - self.assertFalse(output.ragged) + self.assertEqual(output.shape, (None, 20)) + self.assertEqual(output.dtype, "float32") + self.assertFalse(output.sparse) + self.assertFalse(output.ragged) def test_predict(self): - input_data = keras.random.randint((5, 7), minval=0, maxval=10) + input = keras.random.randint((5, 7), minval=0, maxval=10) with self._strategy.scope(): model = keras.models.Sequential([EmbedReduce(10, 20)]) - # Compilation is often needed for strategies to be fully utilized model.compile(optimizer="adam", loss="mse") - - # model.predict itself handles the strategy distribution - model.predict(input_data, batch_size=2) + model.predict(input, batch_size=2) def test_serialization(self): - with self._strategy.scope(): - layer = EmbedReduce(10, 20, combiner="sqrtn") - + layer = EmbedReduce(10, 20, combiner="sqrtn") restored = deserialize(serialize(layer)) self.assertDictEqual(layer.get_config(), restored.get_config()) def test_model_saving(self): - input_data = keras.random.randint((5, 7), minval=0, maxval=10) - - with self._strategy.scope(): - model = keras.models.Sequential([EmbedReduce(10, 20)]) + input = keras.random.randint((5, 7), minval=0, maxval=10) + model = keras.models.Sequential([EmbedReduce(10, 20)]) self.run_model_saving_test( model=model, - input_data=input_data, + input_data=input, ) - - -if __name__ == "__main__": - absltest.main() diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py index d54cce4f..1dd3525d 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py @@ -18,7 +18,6 @@ ) from jax_tpu_embedding.sparsecore.utils import utils as jte_utils -from keras_rs.src import testing from keras_rs.src.layers.embedding import test_utils as keras_test_utils from keras_rs.src.layers.embedding.jax import checkpoint_utils from keras_rs.src.layers.embedding.jax import config_conversion diff --git a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py index 8c9182c0..a565297e 100644 --- a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py +++ b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py @@ -2,7 +2,6 @@ import keras import tensorflow as tf -from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -20,7 +19,6 @@ def setUp(self): super().setUp() if keras.backend.backend() == "tensorflow": tf.debugging.disable_traceback_filtering() - self.on_tpu = "TPU_NAME" in os.environ self._strategy = tpu_test_utils.get_tpu_strategy(self) @@ -89,36 +87,25 @@ def setUp(self): ), ) def test_call(self, self_interaction, skip_gather, exp_output_idx): - with self._strategy.scope(): - layer = DotInteraction( - self_interaction=self_interaction, skip_gather=skip_gather - ) - output = tpu_test_utils.run_with_strategy( - self._strategy, layer, self.input - ) - self.assertAllClose( - output, self.exp_outputs[exp_output_idx], is_tpu=self.on_tpu + layer = DotInteraction( + self_interaction=self_interaction, skip_gather=skip_gather ) + output = layer(self.input) + self.assertAllClose(output, self.exp_outputs[exp_output_idx]) def test_invalid_input_rank(self): rank_1_input = [ops.ones((3,)), ops.ones((3,))] - with self._strategy.scope(): - layer = DotInteraction() + layer = DotInteraction() with self.assertRaises(ValueError): - tpu_test_utils.run_with_strategy( - self._strategy, layer, rank_1_input - ) + layer(rank_1_input) def test_invalid_input_different_shapes(self): unequal_shape_input = [ops.ones((1, 3)), ops.ones((1, 4))] - with self._strategy.scope(): - layer = DotInteraction() + layer = DotInteraction() with self.assertRaises(ValueError): - tpu_test_utils.run_with_strategy( - self._strategy, layer, unequal_shape_input - ) + layer(unequal_shape_input) @parameterized.named_parameters( ( @@ -157,25 +144,19 @@ def test_predict(self, self_interaction, skip_gather): model.predict(self.input, batch_size=2) def test_serialization(self): - with self._strategy.scope(): - layer = DotInteraction() + layer = DotInteraction() restored = deserialize(serialize(layer)) self.assertDictEqual(layer.get_config(), restored.get_config()) def test_model_saving(self): - with self._strategy.scope(): - feature1 = keras.layers.Input(shape=(5,)) - feature2 = keras.layers.Input(shape=(5,)) - feature3 = keras.layers.Input(shape=(5,)) - x = DotInteraction()([feature1, feature2, feature3]) - x = keras.layers.Dense(units=1)(x) - model = keras.Model([feature1, feature2, feature3], x) + feature1 = keras.layers.Input(shape=(5,)) + feature2 = keras.layers.Input(shape=(5,)) + feature3 = keras.layers.Input(shape=(5,)) + x = DotInteraction()([feature1, feature2, feature3]) + x = keras.layers.Dense(units=1)(x) + model = keras.Model([feature1, feature2, feature3], x) self.run_model_saving_test( model=model, input_data=self.input, ) - - -if __name__ == "__main__": - absltest.main() diff --git a/keras_rs/src/layers/feature_interaction/feature_cross_test.py b/keras_rs/src/layers/feature_interaction/feature_cross_test.py index 89df5009..d81afb2a 100644 --- a/keras_rs/src/layers/feature_interaction/feature_cross_test.py +++ b/keras_rs/src/layers/feature_interaction/feature_cross_test.py @@ -1,6 +1,5 @@ import keras import tensorflow as tf -from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -16,7 +15,6 @@ def setUp(self): super().setUp() if keras.backend.backend() == "tensorflow": tf.debugging.disable_traceback_filtering() - self._strategy = tpu_test_utils.get_tpu_strategy(self) self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32") @@ -26,8 +24,7 @@ def setUp(self): self.one_inp_exp_output = ops.array([[0.16, 0.32, 0.48]]) def test_full_layer(self): - with self._strategy.scope(): - layer = FeatureCross(projection_dim=None, kernel_initializer="ones") + layer = FeatureCross(projection_dim=None, kernel_initializer="ones") output = layer(self.x0, self.x) # Test output. @@ -40,8 +37,7 @@ def test_full_layer(self): self.assertEqual(layer.weights[1].shape, (3,)) def test_low_rank_layer(self): - with self._strategy.scope(): - layer = FeatureCross(projection_dim=1, kernel_initializer="ones") + layer = FeatureCross(projection_dim=1, kernel_initializer="ones") output = layer(self.x0, self.x) # Test output. @@ -56,8 +52,7 @@ def test_low_rank_layer(self): self.assertEqual(layer.weights[2].shape, (3,)) def test_one_input(self): - with self._strategy.scope(): - layer = FeatureCross(projection_dim=None, kernel_initializer="ones") + layer = FeatureCross(projection_dim=None, kernel_initializer="ones") output = layer(self.x0) self.assertAllClose(self.one_inp_exp_output, output) @@ -65,8 +60,7 @@ def test_invalid_input_shapes(self): x0 = ops.ones((12, 5)) x = ops.ones((12, 7)) - with self._strategy.scope(): - layer = FeatureCross() + layer = FeatureCross() with self.assertRaises(ValueError): layer(x0, x) @@ -76,19 +70,15 @@ def test_invalid_diag_scale(self): FeatureCross(diag_scale=-1.0) def test_diag_scale(self): - with self._strategy.scope(): - layer = FeatureCross( - projection_dim=None, diag_scale=1.0, kernel_initializer="ones" - ) + layer = FeatureCross( + projection_dim=None, diag_scale=1.0, kernel_initializer="ones" + ) output = layer(self.x0, self.x) self.assertAllClose(ops.array([[0.59, 0.9, 1.23]]), output) def test_pre_activation(self): - with self._strategy.scope(): - layer = FeatureCross( - projection_dim=None, pre_activation=ops.zeros_like - ) + layer = FeatureCross(projection_dim=None, pre_activation=ops.zeros_like) output = layer(self.x0, self.x) self.assertAllClose(self.x, output) @@ -105,24 +95,18 @@ def test_predict(self): model.predict(self.x0, batch_size=2) def test_serialization(self): - with self._strategy.scope(): - sampler = FeatureCross(projection_dim=None, pre_activation="swish") + sampler = FeatureCross(projection_dim=None, pre_activation="swish") restored = deserialize(serialize(sampler)) self.assertDictEqual(sampler.get_config(), restored.get_config()) def test_model_saving(self): - with self._strategy.scope(): - x0 = keras.layers.Input(shape=(3,)) - x1 = FeatureCross(projection_dim=None)(x0, x0) - x2 = FeatureCross(projection_dim=None)(x0, x1) - logits = keras.layers.Dense(units=1)(x2) - model = keras.Model(x0, logits) + x0 = keras.layers.Input(shape=(3,)) + x1 = FeatureCross(projection_dim=None)(x0, x0) + x2 = FeatureCross(projection_dim=None)(x0, x1) + logits = keras.layers.Dense(units=1)(x2) + model = keras.Model(x0, logits) self.run_model_saving_test( model=model, input_data=self.x0, ) - - -if __name__ == "__main__": - absltest.main() diff --git a/keras_rs/src/layers/retrieval/brute_force_retrieval_test.py b/keras_rs/src/layers/retrieval/brute_force_retrieval_test.py index a31aa417..a5f8a86f 100644 --- a/keras_rs/src/layers/retrieval/brute_force_retrieval_test.py +++ b/keras_rs/src/layers/retrieval/brute_force_retrieval_test.py @@ -1,24 +1,11 @@ -import os - import keras -import tensorflow as tf -from absl.testing import absltest from absl.testing import parameterized from keras_rs.src import testing from keras_rs.src.layers.retrieval import brute_force_retrieval -from keras_rs.src.utils import tpu_test_utils class BruteForceRetrievalTest(testing.TestCase, parameterized.TestCase): - def setUp(self): - super().setUp() - self.on_tpu = "TPU_NAME" in os.environ - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - - self._strategy = tpu_test_utils.get_tpu_strategy(self) - @parameterized.product( has_identifiers=(True, False), return_scores=(True, False), @@ -38,13 +25,12 @@ def test_brute_force_retrieval(self, has_identifiers, return_scores): else None ) - with self._strategy.scope(): - layer = brute_force_retrieval.BruteForceRetrieval( - k=k, - candidate_embeddings=candidates, - candidate_ids=candidate_indices, - return_scores=return_scores, - ) + layer = brute_force_retrieval.BruteForceRetrieval( + k=k, + candidate_embeddings=candidates, + candidate_ids=candidate_indices, + return_scores=return_scores, + ) query = keras.random.normal((num_queries, 4), dtype="float32", seed=rng) scores = keras.ops.matmul(query, keras.ops.transpose(candidates)) @@ -65,34 +51,14 @@ def test_brute_force_retrieval(self, has_identifiers, return_scores): for i in range(2): if i: # First time uses values from __init__, second time uses update. - with self._strategy.scope(): - layer.update_candidates(candidates, candidate_indices) + layer.update_candidates(candidates, candidate_indices) if return_scores: - top_scores, top_indices = tpu_test_utils.run_with_strategy( - self._strategy, layer, query - ) + top_scores, top_indices = layer(query) self.assertEqual(top_scores.shape, expected_top_scores.shape) - self.assertAllClose( - top_scores, - expected_top_scores, - atol=1e-4, - is_tpu=self.on_tpu, - ) + self.assertAllClose(top_scores, expected_top_scores, atol=1e-4) else: - top_indices = tpu_test_utils.run_with_strategy( - self._strategy, layer, query - ) + top_indices = layer(query) self.assertEqual(top_indices.shape, expected_top_indices.shape) - self.assertAllClose( - top_indices, - expected_top_indices, - tpu_atol=5, - tpu_rtol=10, - is_tpu=self.on_tpu, - ) - - -if __name__ == "__main__": - absltest.main() + self.assertAllClose(top_indices, expected_top_indices) diff --git a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py index ddbc0832..431e3d34 100644 --- a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py +++ b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py @@ -1,6 +1,5 @@ import keras import tensorflow as tf -from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -69,11 +68,9 @@ def test_call(self, rank, num_hard_negatives): logits, labels = self.create_inputs(rank=rank) num_logits = logits.shape[-1] - with self._strategy.scope(): - layer = hard_negative_mining.HardNegativeMining(num_hard_negatives) - out_logits, out_labels = tpu_test_utils.run_with_strategy( - self._strategy, layer, logits, labels - ) + out_logits, out_labels = hard_negative_mining.HardNegativeMining( + num_hard_negatives + )(logits, labels) self.assertEqual( out_logits.shape[-1], min(num_hard_negatives + 1, num_logits) @@ -88,11 +85,9 @@ def test_call(self, rank, num_hard_negatives): # Set the logits for labels to be highest to ignore effect of labels. logits = logits + labels * 1000.0 - with self._strategy.scope(): - layer = hard_negative_mining.HardNegativeMining(num_hard_negatives) - out_logits, _ = tpu_test_utils.run_with_strategy( - self._strategy, layer, logits, labels - ) + out_logits, _ = hard_negative_mining.HardNegativeMining( + num_hard_negatives + )(logits, labels) # Highest K logits are always returned. self.assertAllClose( @@ -117,28 +112,18 @@ def test_predict(self): model.predict([logits, labels], batch_size=8) def test_serialization(self): - with self._strategy.scope(): - layer = hard_negative_mining.HardNegativeMining( - num_hard_negatives=3 - ) + layer = hard_negative_mining.HardNegativeMining(num_hard_negatives=3) restored = deserialize(serialize(layer)) self.assertDictEqual(layer.get_config(), restored.get_config()) def test_model_saving(self): logits, labels = self.create_inputs() - with self._strategy.scope(): - in_logits = keras.layers.Input(shape=logits.shape[1:]) - in_labels = keras.layers.Input(shape=labels.shape[1:]) - out_logits, out_labels = hard_negative_mining.HardNegativeMining( - num_hard_negatives=3 - )(in_logits, in_labels) - model = keras.Model( - [in_logits, in_labels], [out_logits, out_labels] - ) + in_logits = keras.layers.Input(shape=logits.shape[1:]) + in_labels = keras.layers.Input(shape=labels.shape[1:]) + out_logits, out_labels = hard_negative_mining.HardNegativeMining( + num_hard_negatives=3 + )(in_logits, in_labels) + model = keras.Model([in_logits, in_labels], [out_logits, out_labels]) self.run_model_saving_test(model=model, input_data=[logits, labels]) - - -if __name__ == "__main__": - absltest.main() diff --git a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py index cb1e7357..4a5e5fdd 100644 --- a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py +++ b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py @@ -1,6 +1,5 @@ import keras import tensorflow as tf -from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -78,10 +77,8 @@ def test_call(self, logits_rank, candidate_ids_rank): logits_rank=logits_rank, candidate_ids_rank=candidate_ids_rank ) - with self._strategy.scope(): - layer = remove_accidental_hits.RemoveAccidentalHits() - out_logits = tpu_test_utils.run_with_strategy( - self._strategy, layer, logits, labels, candidate_ids + out_logits = remove_accidental_hits.RemoveAccidentalHits()( + logits, labels, candidate_ids ) # Logits of labels are unchanged. @@ -143,35 +140,21 @@ def test_call(self, logits_rank, candidate_ids_rank): ) def test_mismatched_labels_logits_shapes(self): - with self._strategy.scope(): - layer = remove_accidental_hits.RemoveAccidentalHits() + layer = remove_accidental_hits.RemoveAccidentalHits() with self.assertRaisesRegex( ValueError, "`labels` and `logits` should have the same shape" ): - tpu_test_utils.run_with_strategy( - self._strategy, - layer, - ops.zeros((10, 20)), - ops.zeros((10, 30)), - ops.zeros((20,)), - ) + layer(ops.zeros((10, 20)), ops.zeros((10, 30)), ops.zeros((20,))) def test_mismatched_labels_candidates_shapes(self): - with self._strategy.scope(): - layer = remove_accidental_hits.RemoveAccidentalHits() + layer = remove_accidental_hits.RemoveAccidentalHits() with self.assertRaisesRegex( ValueError, "`candidate_ids` should have the same shape as .* `labels`", ): - tpu_test_utils.run_with_strategy( - self._strategy, - layer, - ops.zeros((10, 20)), - ops.zeros((10, 20)), - ops.zeros((30,)), - ) + layer(ops.zeros((10, 20)), ops.zeros((10, 20)), ops.zeros((30,))) def test_predict(self): # Note: for predict, we test with probabilities that have a batch dim. @@ -191,30 +174,22 @@ def test_predict(self): model.predict([logits, labels, candidate_ids], batch_size=8) def test_serialization(self): - with self._strategy.scope(): - layer = remove_accidental_hits.RemoveAccidentalHits() + layer = remove_accidental_hits.RemoveAccidentalHits() restored = deserialize(serialize(layer)) self.assertDictEqual(layer.get_config(), restored.get_config()) def test_model_saving(self): logits, labels, candidate_ids = self.create_inputs() - with self._strategy.scope(): - layer = remove_accidental_hits.RemoveAccidentalHits() - in_logits = keras.layers.Input(logits.shape[1:]) - in_labels = keras.layers.Input(labels.shape[1:]) - in_candidate_ids = keras.layers.Input( - batch_shape=candidate_ids.shape - ) - out_logits = layer(in_logits, in_labels, in_candidate_ids) - model = keras.Model( - [in_logits, in_labels, in_candidate_ids], out_logits - ) + layer = remove_accidental_hits.RemoveAccidentalHits() + in_logits = keras.layers.Input(logits.shape[1:]) + in_labels = keras.layers.Input(labels.shape[1:]) + in_candidate_ids = keras.layers.Input(batch_shape=candidate_ids.shape) + out_logits = layer(in_logits, in_labels, in_candidate_ids) + model = keras.Model( + [in_logits, in_labels, in_candidate_ids], out_logits + ) self.run_model_saving_test( model=model, input_data=[logits, labels, candidate_ids] ) - - -if __name__ == "__main__": - absltest.main() diff --git a/keras_rs/src/layers/retrieval/retrieval_test.py b/keras_rs/src/layers/retrieval/retrieval_test.py index bfd1b88b..a5887e28 100644 --- a/keras_rs/src/layers/retrieval/retrieval_test.py +++ b/keras_rs/src/layers/retrieval/retrieval_test.py @@ -1,11 +1,8 @@ import keras -import tensorflow as tf -from absl.testing import absltest from absl.testing import parameterized from keras_rs.src import testing from keras_rs.src.layers.retrieval.retrieval import Retrieval -from keras_rs.src.utils import tpu_test_utils class DummyRetrieval(Retrieval): @@ -18,13 +15,7 @@ def call(self, inputs): class RetrievalTest(testing.TestCase, parameterized.TestCase): def setUp(self): - super().setUp() - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - - self._strategy = tpu_test_utils.get_tpu_strategy(self) - with self._strategy.scope(): - self.layer = DummyRetrieval(k=5) + self.layer = DummyRetrieval(k=5) @parameterized.named_parameters( ("embeddings_none", None, None, "`candidate_embeddings` is required."), @@ -57,27 +48,19 @@ def test_validate_candidate_embeddings_and_ids( ) def test_call_not_overridden(self): - with self._strategy.scope(): - - class DummyRetrieval(Retrieval): - def update_candidates( - self, candidate_embeddings, candidate_ids=None - ): - pass + class DummyRetrieval(Retrieval): + def update_candidates( + self, candidate_embeddings, candidate_ids=None + ): + pass with self.assertRaises(TypeError): DummyRetrieval(k=5) def test_update_candidates_not_overridden(self): - with self._strategy.scope(): - - class DummyRetrieval(Retrieval): - def call(self, inputs): - pass + class DummyRetrieval(Retrieval): + def call(self, inputs): + pass with self.assertRaises(TypeError): DummyRetrieval(k=5) - - -if __name__ == "__main__": - absltest.main() diff --git a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py index 6d88fd2a..9b4fee3d 100644 --- a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py +++ b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py @@ -1,6 +1,5 @@ import keras import tensorflow as tf -from absl.testing import absltest from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -71,13 +70,8 @@ def test_call(self, logits_rank, probs_rank): ) # Verifies logits are always less than corrected logits. - with self._strategy.scope(): - layer = ( - sampling_probability_correction.SamplingProbabilityCorrection() - ) - corrected_logits = tpu_test_utils.run_with_strategy( - self._strategy, layer, logits, probs - ) + layer = sampling_probability_correction.SamplingProbabilityCorrection() + corrected_logits = layer(logits, probs) self.assertAllClose( ops.less(logits, corrected_logits), ops.ones(logits.shape) ) @@ -92,9 +86,7 @@ def test_call(self, logits_rank, probs_rank): ) # Verifies logits are always less than corrected logits. - corrected_logits_with_zeros = tpu_test_utils.run_with_strategy( - self._strategy, layer, logits, probs_with_zeros - ) + corrected_logits_with_zeros = layer(logits, probs_with_zeros) self.assertAllClose( ops.less(logits, corrected_logits_with_zeros), ops.ones(logits.shape), @@ -117,27 +109,17 @@ def test_predict(self): model.predict([logits, probs], batch_size=4) def test_serialization(self): - with self._strategy.scope(): - layer = ( - sampling_probability_correction.SamplingProbabilityCorrection() - ) + layer = sampling_probability_correction.SamplingProbabilityCorrection() restored = deserialize(serialize(layer)) self.assertDictEqual(layer.get_config(), restored.get_config()) def test_model_saving(self): logits, probs = self.create_inputs() - with self._strategy.scope(): - layer = ( - sampling_probability_correction.SamplingProbabilityCorrection() - ) - in_logits = keras.layers.Input(shape=logits.shape[1:]) - in_probs = keras.layers.Input(batch_shape=probs.shape) - out_logits = layer(in_logits, in_probs) - model = keras.Model([in_logits, in_probs], out_logits) + layer = sampling_probability_correction.SamplingProbabilityCorrection() + in_logits = keras.layers.Input(shape=logits.shape[1:]) + in_probs = keras.layers.Input(batch_shape=probs.shape) + out_logits = layer(in_logits, in_probs) + model = keras.Model([in_logits, in_probs], out_logits) self.run_model_saving_test(model=model, input_data=[logits, probs]) - - -if __name__ == "__main__": - absltest.main() From d9a5aeb2bd4155ce4c6a764f352faa012ea4e89e Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 26 Nov 2025 02:11:04 +0000 Subject: [PATCH 18/29] revert more unnecessary changes and resolve comments --- .../embedding/distributed_embedding_test.py | 11 +- .../src/layers/embedding/embed_reduce_test.py | 8 +- .../tensorflow/config_conversion_test.py | 2 +- .../dot_interaction_test.py | 15 +- .../feature_interaction/feature_cross_test.py | 6 +- .../retrieval/hard_negative_mining_test.py | 7 +- .../retrieval/remove_accidental_hits_test.py | 7 +- .../sampling_probability_correction_test.py | 7 +- keras_rs/src/losses/list_mle_loss_test.py | 20 +-- .../src/losses/pairwise_hinge_loss_test.py | 20 +-- .../src/losses/pairwise_logistic_loss_test.py | 19 +-- .../pairwise_mean_squared_error_test.py | 20 +-- .../pairwise_soft_zero_one_loss_test.py | 20 +-- keras_rs/src/metrics/dcg_test.py | 148 +++++----------- .../metrics/mean_average_precision_test.py | 139 +++++---------- .../src/metrics/mean_reciprocal_rank_test.py | 126 ++++---------- keras_rs/src/metrics/ndcg_test.py | 160 ++++++------------ keras_rs/src/metrics/precision_at_k_test.py | 127 ++++---------- keras_rs/src/metrics/recall_at_k_test.py | 127 ++++---------- keras_rs/src/testing/test_case.py | 28 ++- 20 files changed, 290 insertions(+), 727 deletions(-) diff --git a/keras_rs/src/layers/embedding/distributed_embedding_test.py b/keras_rs/src/layers/embedding/distributed_embedding_test.py index c7ced5bc..49a72d93 100644 --- a/keras_rs/src/layers/embedding/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/distributed_embedding_test.py @@ -52,7 +52,6 @@ def setUp(self): # FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16 # FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16 - self._strategy = tpu_test_utils.get_tpu_strategy(self) self.batch_size = ( BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync ) @@ -207,7 +206,7 @@ def test_basics(self, input_type, placement): res = layer(preprocessed_inputs) else: res = tpu_test_utils.run_with_strategy( - self._strategy, layer.__call__, inputs, weights + self.strategy, layer.__call__, inputs, weights ) if placement == "default_device" or not self.on_tpu: @@ -368,13 +367,13 @@ def test_dataset_generator(): model_inputs, _ = next(iter(test_dataset)) test_output_before = tpu_test_utils.run_with_strategy( - self._strategy, model.__call__, model_inputs + self.strategy, model.__call__, model_inputs ) model.fit(train_dataset, steps_per_epoch=1, epochs=1) test_output_after = tpu_test_utils.run_with_strategy( - self._strategy, model.__call__, model_inputs + self.strategy, model.__call__, model_inputs ) # Verify that the embedding has actually trained. @@ -556,11 +555,11 @@ def test_correctness( ) else: res = tpu_test_utils.run_with_strategy( - self._strategy, layer.__call__, preprocessed + self.strategy, layer.__call__, preprocessed ) else: res = tpu_test_utils.run_with_strategy( - self._strategy, + self.strategy, layer.__call__, inputs, weights, diff --git a/keras_rs/src/layers/embedding/embed_reduce_test.py b/keras_rs/src/layers/embedding/embed_reduce_test.py index c3923d4a..d30a8dcf 100644 --- a/keras_rs/src/layers/embedding/embed_reduce_test.py +++ b/keras_rs/src/layers/embedding/embed_reduce_test.py @@ -1,8 +1,6 @@ import math -import os import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -23,10 +21,6 @@ class EmbedReduceTest(testing.TestCase, parameterized.TestCase): def setUp(self): super().setUp() - self.on_tpu = "TPU_NAME" in os.environ - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - self._strategy = tpu_test_utils.get_tpu_strategy(self) @parameterized.named_parameters( [ @@ -189,7 +183,7 @@ def test_symbolic_call(self, input_type, input_rank, use_weights): def test_predict(self): input = keras.random.randint((5, 7), minval=0, maxval=10) - with self._strategy.scope(): + with self.strategy.scope(): model = keras.models.Sequential([EmbedReduce(10, 20)]) model.compile(optimizer="adam", loss="mse") model.predict(input, batch_size=2) diff --git a/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py b/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py index 57675d0f..3d8e4daa 100644 --- a/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py +++ b/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py @@ -10,7 +10,7 @@ @pytest.mark.skipif( keras.backend.backend() != "tensorflow", - reason="Backend specific test", + reason="Tensorflow specific test", ) class ConfigConversionTest(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters( diff --git a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py index a565297e..24763207 100644 --- a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py +++ b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py @@ -1,7 +1,4 @@ -import os - import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -17,10 +14,6 @@ class DotInteractionTest(testing.TestCase, parameterized.TestCase): def setUp(self): super().setUp() - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - self.on_tpu = "TPU_NAME" in os.environ - self._strategy = tpu_test_utils.get_tpu_strategy(self) self.input = [ ops.array([[0.1, -4.3, 0.2, 1.1, 0.3]]), @@ -91,7 +84,11 @@ def test_call(self, self_interaction, skip_gather, exp_output_idx): self_interaction=self_interaction, skip_gather=skip_gather ) output = layer(self.input) - self.assertAllClose(output, self.exp_outputs[exp_output_idx]) + self.assertAllClose( + output, + self.exp_outputs[exp_output_idx], + tpu_atol=1e-2, + tpu_rtol=1e-2) def test_invalid_input_rank(self): rank_1_input = [ops.ones((3,)), ops.ones((3,))] @@ -130,7 +127,7 @@ def test_invalid_input_different_shapes(self): ), ) def test_predict(self, self_interaction, skip_gather): - with self._strategy.scope(): + with self.strategy.scope(): feature1 = keras.layers.Input(shape=(5,)) feature2 = keras.layers.Input(shape=(5,)) feature3 = keras.layers.Input(shape=(5,)) diff --git a/keras_rs/src/layers/feature_interaction/feature_cross_test.py b/keras_rs/src/layers/feature_interaction/feature_cross_test.py index d81afb2a..37148648 100644 --- a/keras_rs/src/layers/feature_interaction/feature_cross_test.py +++ b/keras_rs/src/layers/feature_interaction/feature_cross_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -13,9 +12,6 @@ class FeatureCrossTest(testing.TestCase, parameterized.TestCase): def setUp(self): super().setUp() - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - self._strategy = tpu_test_utils.get_tpu_strategy(self) self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32") self.x = ops.array([[0.4, 0.5, 0.6]], dtype="float32") @@ -84,7 +80,7 @@ def test_pre_activation(self): self.assertAllClose(self.x, output) def test_predict(self): - with self._strategy.scope(): + with self.strategy.scope(): x0 = keras.layers.Input(shape=(3,)) x1 = FeatureCross(projection_dim=None)(x0, x0) x2 = FeatureCross(projection_dim=None)(x0, x1) diff --git a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py index 431e3d34..794a97cc 100644 --- a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py +++ b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -13,10 +12,6 @@ class HardNegativeMiningTest(testing.TestCase, parameterized.TestCase): def setUp(self): super().setUp() - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - - self._strategy = tpu_test_utils.get_tpu_strategy(self) def create_inputs(self, rank=2): shape_3d = (15, 20, 10) @@ -98,7 +93,7 @@ def test_call(self, rank, num_hard_negatives): def test_predict(self): logits, labels = self.create_inputs() - with self._strategy.scope(): + with self.strategy.scope(): in_logits = keras.layers.Input(shape=logits.shape[1:]) in_labels = keras.layers.Input(shape=labels.shape[1:]) out_logits, out_labels = hard_negative_mining.HardNegativeMining( diff --git a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py index 4a5e5fdd..6b8780cd 100644 --- a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py +++ b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -13,10 +12,6 @@ class RemoveAccidentalHitsTest(testing.TestCase, parameterized.TestCase): def setUp(self): super().setUp() - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - - self._strategy = tpu_test_utils.get_tpu_strategy(self) def create_inputs(self, logits_rank=2, candidate_ids_rank=1): shape_3d = (15, 20, 10) @@ -160,7 +155,7 @@ def test_predict(self): # Note: for predict, we test with probabilities that have a batch dim. logits, labels, candidate_ids = self.create_inputs(candidate_ids_rank=2) - with self._strategy.scope(): + with self.strategy.scope(): layer = remove_accidental_hits.RemoveAccidentalHits() in_logits = keras.layers.Input(logits.shape[1:]) in_labels = keras.layers.Input(labels.shape[1:]) diff --git a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py index 9b4fee3d..9c896c8e 100644 --- a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py +++ b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.layers import deserialize @@ -15,10 +14,6 @@ class SamplingProbabilityCorrectionTest( ): def setUp(self): super().setUp() - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - - self._strategy = tpu_test_utils.get_tpu_strategy(self) def create_inputs(self, logits_rank=2, probs_rank=1): shape_3d = (15, 20, 10) @@ -96,7 +91,7 @@ def test_predict(self): # Note: for predict, we test with probabilities that have a batch dim. logits, probs = self.create_inputs(probs_rank=2) - with self._strategy.scope(): + with self.strategy.scope(): layer = ( sampling_probability_correction.SamplingProbabilityCorrection() ) diff --git a/keras_rs/src/losses/list_mle_loss_test.py b/keras_rs/src/losses/list_mle_loss_test.py index bdf4ba92..b28ad859 100644 --- a/keras_rs/src/losses/list_mle_loss_test.py +++ b/keras_rs/src/losses/list_mle_loss_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.losses import deserialize @@ -12,10 +11,6 @@ class ListMLELossTest(testing.TestCase, parameterized.TestCase): def setUp(self): - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - self._strategy = tpu_test_utils.get_tpu_strategy(self) - self.unbatched_scores = ops.array( [1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32" ) @@ -96,21 +91,16 @@ def create_model(): model.compile(loss=ListMLELoss(), optimizer="adam") return model - if self._strategy: - with self._strategy.scope(): + if self.strategy: + with self.strategy.scope(): model = create_model() else: model = create_model() - x_data = keras.random.normal((2, 20)) - y_data = keras.random.randint((2, 5), minval=0, maxval=2) - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch( - self._strategy.num_replicas_in_sync if self._strategy else 1 + model.fit( + x=keras.random.normal((2, 20)), + y=keras.random.randint((2, 5), minval=0, maxval=2), ) - if self._strategy: - dataset = self._strategy.experimental_distribute_dataset(dataset) - - model.fit(dataset, epochs=1, steps_per_epoch=2) def test_serialization(self): loss = ListMLELoss(temperature=0.8) diff --git a/keras_rs/src/losses/pairwise_hinge_loss_test.py b/keras_rs/src/losses/pairwise_hinge_loss_test.py index a7c6e167..81577634 100644 --- a/keras_rs/src/losses/pairwise_hinge_loss_test.py +++ b/keras_rs/src/losses/pairwise_hinge_loss_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.losses import deserialize @@ -12,10 +11,6 @@ class PairwiseHingeLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - self._strategy = tpu_test_utils.get_tpu_strategy(self) - self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -123,21 +118,16 @@ def create_model(): model.compile(loss=PairwiseHingeLoss(), optimizer="adam") return model - if self._strategy: - with self._strategy.scope(): + if self.strategy: + with self.strategy.scope(): model = create_model() else: model = create_model() - x_data = keras.random.normal((2, 20)) - y_data = keras.random.randint((2, 5), minval=0, maxval=2) - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch( - self._strategy.num_replicas_in_sync if self._strategy else 1 + model.fit( + x=keras.random.normal((2, 20)), + y=keras.random.randint((2, 5), minval=0, maxval=2), ) - if self._strategy: - dataset = self._strategy.experimental_distribute_dataset(dataset) - - model.fit(dataset, epochs=1, steps_per_epoch=2) def test_serialization(self): loss = PairwiseHingeLoss() diff --git a/keras_rs/src/losses/pairwise_logistic_loss_test.py b/keras_rs/src/losses/pairwise_logistic_loss_test.py index 07c743d8..0d10643b 100644 --- a/keras_rs/src/losses/pairwise_logistic_loss_test.py +++ b/keras_rs/src/losses/pairwise_logistic_loss_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.losses import deserialize @@ -12,9 +11,6 @@ class PairwiseLogisticLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - self._strategy = tpu_test_utils.get_tpu_strategy(self) self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -122,21 +118,16 @@ def create_model(): model.compile(loss=PairwiseLogisticLoss(), optimizer="adam") return model - if self._strategy: - with self._strategy.scope(): + if self.strategy: + with self.strategy.scope(): model = create_model() else: model = create_model() - x_data = keras.random.normal((2, 20)) - y_data = keras.random.randint((2, 5), minval=0, maxval=2) - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch( - self._strategy.num_replicas_in_sync if self._strategy else 1 + model.fit( + x=keras.random.normal((2, 20)), + y=keras.random.randint((2, 5), minval=0, maxval=2), ) - if self._strategy: - dataset = self._strategy.experimental_distribute_dataset(dataset) - - model.fit(dataset, epochs=1, steps_per_epoch=2) def test_serialization(self): loss = PairwiseLogisticLoss() diff --git a/keras_rs/src/losses/pairwise_mean_squared_error_test.py b/keras_rs/src/losses/pairwise_mean_squared_error_test.py index bdbb3da3..d5524da1 100644 --- a/keras_rs/src/losses/pairwise_mean_squared_error_test.py +++ b/keras_rs/src/losses/pairwise_mean_squared_error_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.losses import deserialize @@ -14,10 +13,6 @@ class PairwiseMeanSquaredErrorTest(testing.TestCase, parameterized.TestCase): def setUp(self): - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - self._strategy = tpu_test_utils.get_tpu_strategy(self) - self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -122,21 +117,16 @@ def create_model(): model.compile(loss=PairwiseMeanSquaredError(), optimizer="adam") return model - if self._strategy: - with self._strategy.scope(): + if self.strategy: + with self.strategy.scope(): model = create_model() else: model = create_model() - x_data = keras.random.normal((2, 20)) - y_data = keras.random.randint((2, 5), minval=0, maxval=2) - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch( - self._strategy.num_replicas_in_sync if self._strategy else 1 + model.fit( + x=keras.random.normal((2, 20)), + y=keras.random.randint((2, 5), minval=0, maxval=2), ) - if self._strategy: - dataset = self._strategy.experimental_distribute_dataset(dataset) - - model.fit(dataset, epochs=1, steps_per_epoch=2) def test_serialization(self): loss = PairwiseMeanSquaredError() diff --git a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py index 0ca05646..e3ced187 100644 --- a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py +++ b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.losses import deserialize @@ -14,10 +13,6 @@ class PairwiseSoftZeroOneLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() - self._strategy = tpu_test_utils.get_tpu_strategy(self) - self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -125,21 +120,16 @@ def create_model(): model.compile(loss=PairwiseSoftZeroOneLoss(), optimizer="adam") return model - if self._strategy: - with self._strategy.scope(): + if self.strategy: + with self.strategy.scope(): model = create_model() else: model = create_model() - x_data = keras.random.normal((2, 20)) - y_data = keras.random.randint((2, 5), minval=0, maxval=2) - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch( - self._strategy.num_replicas_in_sync if self._strategy else 1 + model.fit( + x=keras.random.normal((2, 20)), + y=keras.random.randint((2, 5), minval=0, maxval=2), ) - if self._strategy: - dataset = self._strategy.experimental_distribute_dataset(dataset) - - model.fit(dataset, epochs=1, steps_per_epoch=2) def test_serialization(self): loss = PairwiseSoftZeroOneLoss() diff --git a/keras_rs/src/metrics/dcg_test.py b/keras_rs/src/metrics/dcg_test.py index 1a1341bc..73b7c6f5 100644 --- a/keras_rs/src/metrics/dcg_test.py +++ b/keras_rs/src/metrics/dcg_test.py @@ -1,7 +1,6 @@ import math import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -21,8 +20,6 @@ def _compute_dcg(labels, ranks): class DCGTest(testing.TestCase, parameterized.TestCase): def setUp(self): - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() self._strategy = tpu_test_utils.get_tpu_strategy(self) self.y_true_batched = ops.array( @@ -56,19 +53,18 @@ def setUp(self): ) def test_invalid_k_init(self): - with self._strategy.scope(): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - DCG(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - DCG(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - DCG(k=3.5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + DCG(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + DCG(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + DCG(k=3.5) @parameterized.named_parameters( ( @@ -138,33 +134,14 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - dcg_metric = DCG() - y_true_t = ops.array(y_true, dtype="float32") - y_pred_t = ops.array(y_pred, dtype="float32") - sw = ( - ops.array(sample_weight, dtype="float32") - if sample_weight is not None - else None - ) - args = ( - (y_true_t, y_pred_t, sw) if sw is not None else (y_true_t, y_pred_t) - ) - tpu_test_utils.run_with_strategy( - self._strategy, dcg_metric.update_state, *args - ) + dcg_metric = DCG() + dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) result = dcg_metric.result() self.assertAllClose(result, expected_output) def test_batched_input(self): - with self._strategy.scope(): - dcg_metric = DCG() - tpu_test_utils.run_with_strategy( - self._strategy, - dcg_metric.update_state, - self.y_true_batched, - self.y_pred_batched, - ) + dcg_metric = DCG() + dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) result = dcg_metric.result() self.assertAllClose(result, self.expected_output_batched) @@ -174,15 +151,11 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 2.7288804), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - with self._strategy.scope(): - dcg_metric = DCG() - sw = ops.array(sample_weight, dtype="float32") - tpu_test_utils.run_with_strategy( - self._strategy, - dcg_metric.update_state, + dcg_metric = DCG() + dcg_metric.update_state( self.y_true_batched, self.y_pred_batched, - sample_weight=sw, + sample_weight=sample_weight, ) result = dcg_metric.result() self.assertAllClose(result, expected_output) @@ -245,15 +218,9 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - dcg_metric = DCG() - tpu_test_utils.run_with_strategy( - self._strategy, - dcg_metric.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + dcg_metric = DCG() + + dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) result = dcg_metric.result() self.assertAllClose(result, expected_output) @@ -314,15 +281,9 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - with self._strategy.scope(): - dcg_metric = DCG() - tpu_test_utils.run_with_strategy( - self._strategy, - dcg_metric.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + dcg_metric = DCG() + + dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) result = dcg_metric.result() self.assertAllClose(result, expected_output) @@ -333,27 +294,16 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): ("4", 4, 3.39040), ) def test_k(self, k, exp_value): - with self._strategy.scope(): - dcg_metric = DCG(k=k) - tpu_test_utils.run_with_strategy( - self._strategy, - dcg_metric.update_state, - self.y_true_batched, - self.y_pred_batched, - ) + dcg_metric = DCG(k=k) dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) result = dcg_metric.result() self.assertAllClose(result, exp_value, rtol=1e-5) def test_statefulness(self): - with self._strategy.scope(): - dcg_metric = DCG() + dcg_metric = DCG() # Batch 1 - tpu_test_utils.run_with_strategy( - self._strategy, - dcg_metric.update_state, - self.y_true_batched[:2], - self.y_pred_batched[:2], + dcg_metric.update_state( + self.y_true_batched[:2], self.y_pred_batched[:2] ) result = dcg_metric.result() self.assertAllClose( @@ -363,11 +313,8 @@ def test_statefulness(self): ) # Batch 2 - tpu_test_utils.run_with_strategy( - self._strategy, - dcg_metric.update_state, - self.y_true_batched[2:], - self.y_pred_batched[2:], + dcg_metric.update_state( + self.y_true_batched[2:], self.y_pred_batched[2:] ) result = dcg_metric.result() self.assertAllClose(result, self.expected_output_batched) @@ -378,8 +325,7 @@ def test_statefulness(self): self.assertAllClose(result, 0.0) def test_serialization(self): - with self._strategy.scope(): - metric = DCG() + metric = DCG() restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) @@ -390,16 +336,10 @@ def linear_gain_fn(label): def inverse_discount_fn(rank): return ops.divide(1.0, rank) - with self._strategy.scope(): - dcg_metric = DCG( - gain_fn=linear_gain_fn, rank_discount_fn=inverse_discount_fn - ) - tpu_test_utils.run_with_strategy( - self._strategy, - dcg_metric.update_state, - self.y_true_batched, - self.y_pred_batched, + dcg_metric = DCG( + gain_fn=linear_gain_fn, rank_discount_fn=inverse_discount_fn ) + dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) result = dcg_metric.result() expected_output = ( @@ -419,17 +359,7 @@ def test_model_evaluate(self): optimizer="adam", ) - x_data = keras.random.normal((2, 20)) - y_data = keras.random.randint((2, 5), minval=0, maxval=4) - - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) - dataset = dataset.batch( - self._strategy.num_replicas_in_sync - if isinstance(self._strategy, tf.distribute.Strategy) - else 1 - ) - - if isinstance(self._strategy, tf.distribute.TPUStrategy): - dataset = self._strategy.experimental_distribute_dataset(dataset) - - model.evaluate(dataset, steps=2) + model.evaluate( + x=keras.random.normal((2, 20)), + y=keras.random.randint((2, 5), minval=0, maxval=4), + ) \ No newline at end of file diff --git a/keras_rs/src/metrics/mean_average_precision_test.py b/keras_rs/src/metrics/mean_average_precision_test.py index 39c6bcb8..7161534a 100644 --- a/keras_rs/src/metrics/mean_average_precision_test.py +++ b/keras_rs/src/metrics/mean_average_precision_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -12,8 +11,6 @@ class MeanAveragePrecisionTest(testing.TestCase, parameterized.TestCase): def setUp(self): - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() self._strategy = tpu_test_utils.get_tpu_strategy(self) self.y_true_batched = ops.array( @@ -36,19 +33,18 @@ def setUp(self): ) def test_invalid_k_init(self): - with self._strategy.scope(): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanAveragePrecision(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanAveragePrecision(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanAveragePrecision(k=3.5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanAveragePrecision(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanAveragePrecision(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanAveragePrecision(k=3.5) @parameterized.named_parameters( ( @@ -111,33 +107,14 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - map_metric = MeanAveragePrecision() - y_true_t = ops.array(y_true, dtype="float32") - y_pred_t = ops.array(y_pred, dtype="float32") - sw = ( - ops.array(sample_weight, dtype="float32") - if sample_weight is not None - else None - ) - args = ( - (y_true_t, y_pred_t, sw) if sw is not None else (y_true_t, y_pred_t) - ) - tpu_test_utils.run_with_strategy( - self._strategy, map_metric.update_state, *args - ) + map_metric = MeanAveragePrecision() + map_metric.update_state(y_true, y_pred, sample_weight=sample_weight) result = map_metric.result() self.assertAllClose(result, expected_output) def test_batched_input(self): - with self._strategy.scope(): - map_metric = MeanAveragePrecision() - tpu_test_utils.run_with_strategy( - self._strategy, - map_metric.update_state, - self.y_true_batched, - self.y_pred_batched, - ) + map_metric = MeanAveragePrecision() + map_metric.update_state(self.y_true_batched, self.y_pred_batched) result = map_metric.result() self.assertAllClose(result, 0.5625) @@ -147,15 +124,11 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 0.6), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - with self._strategy.scope(): - map_metric = MeanAveragePrecision() - sw = ops.array(sample_weight, dtype="float32") - tpu_test_utils.run_with_strategy( - self._strategy, - map_metric.update_state, + map_metric = MeanAveragePrecision() + map_metric.update_state( self.y_true_batched, self.y_pred_batched, - sample_weight=sw, + sample_weight=sample_weight, ) result = map_metric.result() self.assertAllClose(result, expected_output) @@ -193,15 +166,9 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - map_metric = MeanAveragePrecision() - tpu_test_utils.run_with_strategy( - self._strategy, - map_metric.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + map_metric = MeanAveragePrecision() + + map_metric.update_state(y_true, y_pred, sample_weight=sample_weight) result = map_metric.result() self.assertAllClose(result, expected_output) @@ -242,15 +209,9 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - with self._strategy.scope(): - map_metric = MeanAveragePrecision() - tpu_test_utils.run_with_strategy( - self._strategy, - map_metric.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + map_metric = MeanAveragePrecision() + + map_metric.update_state(y_true, y_pred, sample_weight=sample_weight) result = map_metric.result() self.assertAllClose(result, expected_output) @@ -267,24 +228,17 @@ def test_k(self, k, expected_map): self.assertAllClose(result, expected_map) def test_statefulness(self): - with self._strategy.scope(): - map_metric = MeanAveragePrecision() + map_metric = MeanAveragePrecision() # Batch 1: First two lists - tpu_test_utils.run_with_strategy( - self._strategy, - map_metric.update_state, - self.y_true_batched[:2], - self.y_pred_batched[:2], + map_metric.update_state( + self.y_true_batched[:2], self.y_pred_batched[:2] ) result = map_metric.result() self.assertAllClose(result, 0.75) # Batch 2: Last two lists - tpu_test_utils.run_with_strategy( - self._strategy, - map_metric.update_state, - self.y_true_batched[2:], - self.y_pred_batched[2:], + map_metric.update_state( + self.y_true_batched[2:], self.y_pred_batched[2:] ) result = map_metric.result() self.assertAllClose(result, 0.5625) @@ -299,11 +253,8 @@ def test_statefulness(self): ("weight_0", 0.0, 0.0), ) def test_scalar_sample_weight(self, sample_weight, expected_output): - with self._strategy.scope(): - map_metric = MeanAveragePrecision() - tpu_test_utils.run_with_strategy( - self._strategy, - map_metric.update_state, + map_metric = MeanAveragePrecision() + map_metric.update_state( self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, @@ -312,12 +263,9 @@ def test_scalar_sample_weight(self, sample_weight, expected_output): self.assertAllClose(result, expected_output) def test_1d_sample_weight(self): - with self._strategy.scope(): - map_metric = MeanAveragePrecision() + map_metric = MeanAveragePrecision() sample_weight = ops.array([1.0, 0.5, 2.0, 1.0], dtype="float32") - tpu_test_utils.run_with_strategy( - self._strategy, - map_metric.update_state, + map_metric.update_state( self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, @@ -326,8 +274,7 @@ def test_1d_sample_weight(self): self.assertAllClose(result, 0.6) def test_serialization(self): - with self._strategy.scope(): - metric = MeanAveragePrecision() + metric = MeanAveragePrecision() restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) @@ -342,17 +289,7 @@ def test_model_evaluate(self): metrics=[MeanAveragePrecision()], optimizer="adam", ) - x_data = keras.random.normal((2, 20)) - y_data = keras.random.randint((2, 5), minval=0, maxval=4) - - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) - dataset = dataset.batch( - self._strategy.num_replicas_in_sync - if isinstance(self._strategy, tf.distribute.Strategy) - else 1 + model.evaluate( + x=keras.random.normal((2, 20)), + y=keras.random.randint((2, 5), minval=0, maxval=4), ) - - if isinstance(self._strategy, tf.distribute.TPUStrategy): - dataset = self._strategy.experimental_distribute_dataset(dataset) - - model.evaluate(dataset, steps=2) diff --git a/keras_rs/src/metrics/mean_reciprocal_rank_test.py b/keras_rs/src/metrics/mean_reciprocal_rank_test.py index c78aef34..3d5264bc 100644 --- a/keras_rs/src/metrics/mean_reciprocal_rank_test.py +++ b/keras_rs/src/metrics/mean_reciprocal_rank_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -12,8 +11,6 @@ class MeanReciprocalRankTest(testing.TestCase, parameterized.TestCase): def setUp(self): - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() self._strategy = tpu_test_utils.get_tpu_strategy(self) self.y_true_batched = ops.array( @@ -36,19 +33,18 @@ def setUp(self): ) def test_invalid_k_init(self): - with self._strategy.scope(): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanReciprocalRank(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanReciprocalRank(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - MeanReciprocalRank(k=3.5) # type: ignore + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanReciprocalRank(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanReciprocalRank(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + MeanReciprocalRank(k=3.5) # type: ignore @parameterized.named_parameters( ( @@ -111,27 +107,14 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - mrr_metric = MeanReciprocalRank() - tpu_test_utils.run_with_strategy( - self._strategy, - mrr_metric.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + mrr_metric = MeanReciprocalRank() + mrr_metric.update_state(y_true, y_pred, sample_weight=sample_weight) result = mrr_metric.result() self.assertAllClose(result, expected_output) def test_batched_input(self): - with self._strategy.scope(): - mrr_metric = MeanReciprocalRank() - tpu_test_utils.run_with_strategy( - self._strategy, - mrr_metric.update_state, - self.y_true_batched, - self.y_pred_batched, - ) + mrr_metric = MeanReciprocalRank() + mrr_metric.update_state(self.y_true_batched, self.y_pred_batched) result = mrr_metric.result() self.assertAllClose(result, 0.625) @@ -141,11 +124,8 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 0.675), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - with self._strategy.scope(): - mrr_metric = MeanReciprocalRank() - tpu_test_utils.run_with_strategy( - self._strategy, - mrr_metric.update_state, + mrr_metric = MeanReciprocalRank() + mrr_metric.update_state( self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, @@ -186,15 +166,9 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - mrr_metric = MeanReciprocalRank() - tpu_test_utils.run_with_strategy( - self._strategy, - mrr_metric.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + mrr_metric = MeanReciprocalRank() + + mrr_metric.update_state(y_true, y_pred, sample_weight=sample_weight) result = mrr_metric.result() self.assertAllClose(result, expected_output) @@ -235,15 +209,9 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - with self._strategy.scope(): - mrr_metric = MeanReciprocalRank() - tpu_test_utils.run_with_strategy( - self._strategy, - mrr_metric.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + mrr_metric = MeanReciprocalRank() + + mrr_metric.update_state(y_true, y_pred, sample_weight=sample_weight) result = mrr_metric.result() self.assertAllClose(result, expected_output) @@ -251,36 +219,23 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): ("1", 1, 0.5), ("2", 2, 0.625), ("3", 3, 0.625), ("4", 4, 0.625) ) def test_k(self, k, expected_mrr): - with self._strategy.scope(): - mrr_metric = MeanReciprocalRank(k=k) - tpu_test_utils.run_with_strategy( - self._strategy, - mrr_metric.update_state, - self.y_true_batched, - self.y_pred_batched, - ) + mrr_metric = MeanReciprocalRank(k=k) + mrr_metric.update_state(self.y_true_batched, self.y_pred_batched) result = mrr_metric.result() self.assertAllClose(result, expected_mrr) def test_statefulness(self): - with self._strategy.scope(): - mrr_metric = MeanReciprocalRank() + mrr_metric = MeanReciprocalRank() # Batch 1: First two lists - tpu_test_utils.run_with_strategy( - self._strategy, - mrr_metric.update_state, - self.y_true_batched[:2], - self.y_pred_batched[:2], + mrr_metric.update_state( + self.y_true_batched[:2], self.y_pred_batched[:2] ) result = mrr_metric.result() self.assertAllClose(result, 0.75) # Batch 2: Last two lists - tpu_test_utils.run_with_strategy( - self._strategy, - mrr_metric.update_state, - self.y_true_batched[2:], - self.y_pred_batched[2:], + mrr_metric.update_state( + self.y_true_batched[2:], self.y_pred_batched[2:] ) result = mrr_metric.result() self.assertAllClose(result, 0.625) @@ -291,8 +246,7 @@ def test_statefulness(self): self.assertAllClose(result, 0.0) def test_serialization(self): - with self._strategy.scope(): - metric = MeanReciprocalRank() + metric = MeanReciprocalRank() restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) @@ -308,17 +262,7 @@ def test_model_evaluate(self): optimizer="adam", ) - x_data = keras.random.normal((2, 20)) - y_data = keras.random.randint((2, 5), minval=0, maxval=4) - - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) - dataset = dataset.batch( - self._strategy.num_replicas_in_sync - if isinstance(self._strategy, tf.distribute.Strategy) - else 1 + model.evaluate( + x=keras.random.normal((2, 20)), + y=keras.random.randint((2, 5), minval=0, maxval=4), ) - - if isinstance(self._strategy, tf.distribute.TPUStrategy): - dataset = self._strategy.experimental_distribute_dataset(dataset) - - model.evaluate(dataset, steps=2) diff --git a/keras_rs/src/metrics/ndcg_test.py b/keras_rs/src/metrics/ndcg_test.py index 76b36b0c..68fc2c96 100644 --- a/keras_rs/src/metrics/ndcg_test.py +++ b/keras_rs/src/metrics/ndcg_test.py @@ -1,7 +1,6 @@ import math import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -21,8 +20,6 @@ def _compute_dcg(labels, ranks): class NDCGTest(testing.TestCase, parameterized.TestCase): def setUp(self): - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() self._strategy = tpu_test_utils.get_tpu_strategy(self) self.y_true_batched = ops.array( @@ -63,19 +60,18 @@ def setUp(self): self.expected_output_batched = sum(expected_ndcg) / 4.0 def test_invalid_k_init(self): - with self._strategy.scope(): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - NDCG(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - NDCG(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - NDCG(k=3.5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + NDCG(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + NDCG(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + NDCG(k=3.5) @parameterized.named_parameters( ( @@ -148,28 +144,15 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - ndcg_metric = NDCG() - tpu_test_utils.run_with_strategy( - self._strategy, - ndcg_metric.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) - result = ndcg_metric.result() + dcg_metric = NDCG() + dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + result = dcg_metric.result() self.assertAllClose(result, expected_output) def test_batched_input(self): - with self._strategy.scope(): - ndcg_metric = NDCG() - tpu_test_utils.run_with_strategy( - self._strategy, - ndcg_metric.update_state, - self.y_true_batched, - self.y_pred_batched, - ) - result = ndcg_metric.result() + dcg_metric = NDCG() + dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) + result = dcg_metric.result() self.assertAllClose(result, self.expected_output_batched) @parameterized.named_parameters( @@ -178,16 +161,13 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 0.74262), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - with self._strategy.scope(): - ndcg_metric = NDCG() - tpu_test_utils.run_with_strategy( - self._strategy, - ndcg_metric.update_state, + dcg_metric = NDCG() + dcg_metric.update_state( self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, ) - result = ndcg_metric.result() + result = dcg_metric.result() self.assertAllClose(result, expected_output) @parameterized.named_parameters( @@ -248,16 +228,10 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - ndcg_metric = NDCG() - tpu_test_utils.run_with_strategy( - self._strategy, - ndcg_metric.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) - result = ndcg_metric.result() + dcg_metric = NDCG() + + dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + result = dcg_metric.result() self.assertAllClose(result, expected_output) @parameterized.named_parameters( @@ -317,16 +291,10 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - with self._strategy.scope(): - ndcg_metric = NDCG() - tpu_test_utils.run_with_strategy( - self._strategy, - ndcg_metric.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) - result = ndcg_metric.result() + dcg_metric = NDCG() + + dcg_metric.update_state(y_true, y_pred, sample_weight=sample_weight) + result = dcg_metric.result() self.assertAllClose(result, expected_output) @parameterized.named_parameters( @@ -336,28 +304,18 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): ("4", 4, 0.7377), ) def test_k(self, k, exp_value): - with self._strategy.scope(): - ndcg_metric = NDCG(k=k) - tpu_test_utils.run_with_strategy( - self._strategy, - ndcg_metric.update_state, - self.y_true_batched, - self.y_pred_batched, - ) - result = ndcg_metric.result() + dcg_metric = NDCG(k=k) + dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) + result = dcg_metric.result() self.assertAllClose(result, exp_value, rtol=1e-5) def test_statefulness(self): - with self._strategy.scope(): - ndcg_metric = NDCG() + dcg_metric = NDCG() # Batch 1 - tpu_test_utils.run_with_strategy( - self._strategy, - ndcg_metric.update_state, - self.y_true_batched[:2], - self.y_pred_batched[:2], + dcg_metric.update_state( + self.y_true_batched[:2], self.y_pred_batched[:2] ) - result = ndcg_metric.result() + result = dcg_metric.result() dcg = [_compute_dcg([1], [1]), _compute_dcg([3, 2, 1], [1, 3, 4])] idcg = [_compute_dcg([1], [1]), _compute_dcg([3, 2, 1], [1, 2, 3])] ndcg = sum([a / b if b != 0.0 else 0.0 for a, b in zip(dcg, idcg)]) / 2 @@ -367,23 +325,19 @@ def test_statefulness(self): ) # Batch 2 - tpu_test_utils.run_with_strategy( - self._strategy, - ndcg_metric.update_state, - self.y_true_batched[2:], - self.y_pred_batched[2:], + dcg_metric.update_state( + self.y_true_batched[2:], self.y_pred_batched[2:] ) - result = ndcg_metric.result() + result = dcg_metric.result() self.assertAllClose(result, self.expected_output_batched) # Reset state - ndcg_metric.reset_state() - result = ndcg_metric.result() + dcg_metric.reset_state() + result = dcg_metric.result() self.assertAllClose(result, 0.0) def test_serialization(self): - with self._strategy.scope(): - metric = NDCG() + metric = NDCG() restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) @@ -394,17 +348,11 @@ def linear_gain_fn(label): def inverse_discount_fn(rank): return ops.divide(1.0, rank) - with self._strategy.scope(): - ndcg_metric = NDCG( - gain_fn=linear_gain_fn, rank_discount_fn=inverse_discount_fn - ) - tpu_test_utils.run_with_strategy( - self._strategy, - ndcg_metric.update_state, - self.y_true_batched, - self.y_pred_batched, + dcg_metric = NDCG( + gain_fn=linear_gain_fn, rank_discount_fn=inverse_discount_fn ) - result = ndcg_metric.result() + dcg_metric.update_state(self.y_true_batched, self.y_pred_batched) + result = dcg_metric.result() dcg = [1 / 1, 3 / 1 + 2 / 3 + 1 / 4, 0, 2 / 1 + 1 / 2] idcg = [1 / 1, 3 / 1 + 2 / 2 + 1 / 3, 0.0, 2 / 1 + 1 / 2] @@ -423,17 +371,7 @@ def test_model_evaluate(self): optimizer="adam", ) - x_data = keras.random.normal((2, 20)) - y_data = keras.random.randint((2, 5), minval=0, maxval=4) - - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) - dataset = dataset.batch( - self._strategy.num_replicas_in_sync - if isinstance(self._strategy, tf.distribute.Strategy) - else 1 + model.evaluate( + x=keras.random.normal((2, 20)), + y=keras.random.randint((2, 5), minval=0, maxval=4), ) - - if isinstance(self._strategy, tf.distribute.TPUStrategy): - dataset = self._strategy.experimental_distribute_dataset(dataset) - - model.evaluate(dataset, steps=2) diff --git a/keras_rs/src/metrics/precision_at_k_test.py b/keras_rs/src/metrics/precision_at_k_test.py index 1e8805f6..62b8348a 100644 --- a/keras_rs/src/metrics/precision_at_k_test.py +++ b/keras_rs/src/metrics/precision_at_k_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -12,8 +11,6 @@ class PrecisionAtKTest(testing.TestCase, parameterized.TestCase): def setUp(self): - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() self._strategy = tpu_test_utils.get_tpu_strategy(self) self.y_true_batched = ops.array( @@ -36,19 +33,18 @@ def setUp(self): ) def test_invalid_k_init(self): - with self._strategy.scope(): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - PrecisionAtK(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - PrecisionAtK(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - PrecisionAtK(k=3.5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + PrecisionAtK(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + PrecisionAtK(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + PrecisionAtK(k=3.5) @parameterized.named_parameters( ( @@ -97,27 +93,14 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - p_at_k = PrecisionAtK(k=3) - tpu_test_utils.run_with_strategy( - self._strategy, - p_at_k.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + p_at_k = PrecisionAtK(k=3) + p_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) result = p_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) def test_batched_input(self): - with self._strategy.scope(): - p_at_k = PrecisionAtK(k=3) - tpu_test_utils.run_with_strategy( - self._strategy, - p_at_k.update_state, - self.y_true_batched, - self.y_pred_batched, - ) + p_at_k = PrecisionAtK(k=3) + p_at_k.update_state(self.y_true_batched, self.y_pred_batched) result = p_at_k.result() self.assertAllClose(result, 1 / 3) @@ -127,11 +110,8 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 0.3), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - with self._strategy.scope(): - p_at_k = PrecisionAtK(k=3) - tpu_test_utils.run_with_strategy( - self._strategy, - p_at_k.update_state, + p_at_k = PrecisionAtK(k=3) + p_at_k.update_state( self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, @@ -172,15 +152,8 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - p_at_k = PrecisionAtK(k=3) - tpu_test_utils.run_with_strategy( - self._strategy, - p_at_k.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + p_at_k = PrecisionAtK(k=3) + p_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) result = p_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) @@ -221,15 +194,8 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - with self._strategy.scope(): - p_at_k = PrecisionAtK(k=3) - tpu_test_utils.run_with_strategy( - self._strategy, - p_at_k.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + p_at_k = PrecisionAtK(k=3) + p_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) result = p_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) @@ -240,35 +206,18 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): ("4", 4, 0.375), ) def test_k(self, k, expected_precision): - with self._strategy.scope(): - p_at_k = PrecisionAtK(k=k) - tpu_test_utils.run_with_strategy( - self._strategy, - p_at_k.update_state, - self.y_true_batched, - self.y_pred_batched, - ) + p_at_k = PrecisionAtK(k=k) + p_at_k.update_state(self.y_true_batched, self.y_pred_batched) result = p_at_k.result() self.assertAllClose(result, expected_precision) def test_statefulness(self): - with self._strategy.scope(): - p_at_k = PrecisionAtK(k=3) - tpu_test_utils.run_with_strategy( - self._strategy, - p_at_k.update_state, - self.y_true_batched[:2], - self.y_pred_batched[:2], - ) + p_at_k = PrecisionAtK(k=3) + p_at_k.update_state(self.y_true_batched[:2], self.y_pred_batched[:2]) result = p_at_k.result() self.assertAllClose(result, 0.5, rtol=1e-6) - tpu_test_utils.run_with_strategy( - self._strategy, - p_at_k.update_state, - self.y_true_batched[2:], - self.y_pred_batched[2:], - ) + p_at_k.update_state(self.y_true_batched[2:], self.y_pred_batched[2:]) result = p_at_k.result() self.assertAllClose(result, 1 / 3) @@ -277,8 +226,7 @@ def test_statefulness(self): self.assertAllClose(result, 0.0) def test_serialization(self): - with self._strategy.scope(): - metric = PrecisionAtK(k=3) + metric = PrecisionAtK(k=3) restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) @@ -294,17 +242,10 @@ def test_model_evaluate(self): optimizer="adam", ) - x_data = keras.random.normal((2, 20)) - y_data = keras.random.randint((2, 5), minval=0, maxval=2) - - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) - dataset = dataset.batch( - self._strategy.num_replicas_in_sync - if isinstance(self._strategy, tf.distribute.Strategy) - else 1 + model.evaluate( + x=keras.random.normal((2, 20)), + y=keras.random.randint( + (2, 5), minval=0, maxval=2 + ), # Using 0/1 for y_true + verbose=0, ) - - if isinstance(self._strategy, tf.distribute.TPUStrategy): - dataset = self._strategy.experimental_distribute_dataset(dataset) - - model.evaluate(dataset, steps=2, verbose=0) diff --git a/keras_rs/src/metrics/recall_at_k_test.py b/keras_rs/src/metrics/recall_at_k_test.py index d536bafb..d397f9cd 100644 --- a/keras_rs/src/metrics/recall_at_k_test.py +++ b/keras_rs/src/metrics/recall_at_k_test.py @@ -1,5 +1,4 @@ import keras -import tensorflow as tf from absl.testing import parameterized from keras import ops from keras.metrics import deserialize @@ -12,8 +11,6 @@ class RecallAtKTest(testing.TestCase, parameterized.TestCase): def setUp(self): - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() self._strategy = tpu_test_utils.get_tpu_strategy(self) self.y_true_batched = ops.array( @@ -36,19 +33,18 @@ def setUp(self): ) def test_invalid_k_init(self): - with self._strategy.scope(): - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - RecallAtK(k=0) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - RecallAtK(k=-5) - with self.assertRaisesRegex( - ValueError, "`k` should be a positive integer" - ): - RecallAtK(k=3.5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + RecallAtK(k=0) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + RecallAtK(k=-5) + with self.assertRaisesRegex( + ValueError, "`k` should be a positive integer" + ): + RecallAtK(k=3.5) @parameterized.named_parameters( ( @@ -97,27 +93,14 @@ def test_invalid_k_init(self): def test_unbatched_inputs( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - r_at_k = RecallAtK(k=3) - tpu_test_utils.run_with_strategy( - self._strategy, - r_at_k.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + r_at_k = RecallAtK(k=3) + r_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) result = r_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) def test_batched_input(self): - with self._strategy.scope(): - r_at_k = RecallAtK(k=3) - tpu_test_utils.run_with_strategy( - self._strategy, - r_at_k.update_state, - self.y_true_batched, - self.y_pred_batched, - ) + r_at_k = RecallAtK(k=3) + r_at_k.update_state(self.y_true_batched, self.y_pred_batched) result = r_at_k.result() self.assertAllClose(result, 0.541667) @@ -127,11 +110,8 @@ def test_batched_input(self): ("1d", [1.0, 0.5, 2.0, 1.0], 0.55), ) def test_batched_inputs_sample_weight(self, sample_weight, expected_output): - with self._strategy.scope(): - r_at_k = RecallAtK(k=3) - tpu_test_utils.run_with_strategy( - self._strategy, - r_at_k.update_state, + r_at_k = RecallAtK(k=3) + r_at_k.update_state( self.y_true_batched, self.y_pred_batched, sample_weight=sample_weight, @@ -172,15 +152,8 @@ def test_batched_inputs_sample_weight(self, sample_weight, expected_output): def test_2d_sample_weight( self, y_true, y_pred, sample_weight, expected_output ): - with self._strategy.scope(): - r_at_k = RecallAtK(k=2) - tpu_test_utils.run_with_strategy( - self._strategy, - r_at_k.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + r_at_k = RecallAtK(k=2) + r_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) result = r_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) @@ -224,15 +197,8 @@ def test_2d_sample_weight( ), ) def test_masking(self, y_true, y_pred, sample_weight, expected_output): - with self._strategy.scope(): - r_at_k = RecallAtK(k=2) - tpu_test_utils.run_with_strategy( - self._strategy, - r_at_k.update_state, - y_true, - y_pred, - sample_weight=sample_weight, - ) + r_at_k = RecallAtK(k=2) + r_at_k.update_state(y_true, y_pred, sample_weight=sample_weight) result = r_at_k.result() self.assertAllClose(result, expected_output, rtol=1e-6) @@ -243,35 +209,18 @@ def test_masking(self, y_true, y_pred, sample_weight, expected_output): ("4", 4, 0.75), ) def test_k(self, k, expected_recall): - with self._strategy.scope(): - r_at_k = RecallAtK(k=k) - tpu_test_utils.run_with_strategy( - self._strategy, - r_at_k.update_state, - self.y_true_batched, - self.y_pred_batched, - ) + r_at_k = RecallAtK(k=k) + r_at_k.update_state(self.y_true_batched, self.y_pred_batched) result = r_at_k.result() self.assertAllClose(result, expected_recall) def test_statefulness(self): - with self._strategy.scope(): - r_at_k = RecallAtK(k=3) - tpu_test_utils.run_with_strategy( - self._strategy, - r_at_k.update_state, - self.y_true_batched[:2], - self.y_pred_batched[:2], - ) + r_at_k = RecallAtK(k=3) + r_at_k.update_state(self.y_true_batched[:2], self.y_pred_batched[:2]) result = r_at_k.result() self.assertAllClose(result, 0.833333, rtol=1e-6) - tpu_test_utils.run_with_strategy( - self._strategy, - r_at_k.update_state, - self.y_true_batched[2:], - self.y_pred_batched[2:], - ) + r_at_k.update_state(self.y_true_batched[2:], self.y_pred_batched[2:]) result = r_at_k.result() self.assertAllClose(result, 0.541667) @@ -280,8 +229,7 @@ def test_statefulness(self): self.assertAllClose(result, 0.0) def test_serialization(self): - with self._strategy.scope(): - metric = RecallAtK(k=3) + metric = RecallAtK(k=3) restored = deserialize(serialize(metric)) self.assertDictEqual(metric.get_config(), restored.get_config()) @@ -297,17 +245,10 @@ def test_model_evaluate(self): optimizer="adam", ) - x_data = keras.random.normal((2, 20)) - y_data = keras.random.randint((2, 5), minval=0, maxval=2) - - dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) - dataset = dataset.batch( - self._strategy.num_replicas_in_sync - if isinstance(self._strategy, tf.distribute.Strategy) - else 1 + model.evaluate( + x=keras.random.normal((2, 20)), + y=keras.random.randint( + (2, 5), minval=0, maxval=2 + ), # Using 0/1 for y_true + verbose=0, ) - - if isinstance(self._strategy, tf.distribute.TPUStrategy): - dataset = self._strategy.experimental_distribute_dataset(dataset) - - model.evaluate(dataset, steps=2, verbose=0) diff --git a/keras_rs/src/testing/test_case.py b/keras_rs/src/testing/test_case.py index a30f52d2..a10a5fd0 100644 --- a/keras_rs/src/testing/test_case.py +++ b/keras_rs/src/testing/test_case.py @@ -7,7 +7,8 @@ import numpy as np from keras_rs.src import types - +import tensorflow as tf +from keras_rs.src.utils import tpu_test_utils class TestCase(unittest.TestCase): """TestCase class for all Keras Recommenders tests.""" @@ -16,6 +17,16 @@ def setUp(self) -> None: super().setUp() keras.utils.clear_session() keras.config.disable_traceback_filtering() + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self.on_tpu = "TPU_NAME" in os.environ + + @property + def strategy(self): + if hasattr(self, "_strategy"): + return self._strategy + self._strategy = tpu_test_utils.get_tpu_strategy(self) + return self._strategy def assertAllClose( self, @@ -23,9 +34,8 @@ def assertAllClose( desired: types.Tensor, atol: float = 1e-6, rtol: float = 1e-6, - tpu_atol: float = 1e-2, - tpu_rtol: float = 1e-2, - is_tpu: bool = False, + tpu_atol=None, + tpu_rtol=None, msg: str = "", ) -> None: """Verify that two tensors are close in value element by element. @@ -37,15 +47,15 @@ def assertAllClose( rtol: Relative tolerance. msg: Optional error message. """ + if tpu_atol is not None and self.on_tpu: + atol = tpu_atol + if tpu_rtol is not None and self.on_tpu: + rtol = tpu_rtol + if not isinstance(actual, np.ndarray): actual = keras.ops.convert_to_numpy(actual) if not isinstance(desired, np.ndarray): desired = keras.ops.convert_to_numpy(desired) - if tpu_atol is not None and is_tpu: - atol = tpu_atol - if tpu_rtol is not None and is_tpu: - rtol = tpu_rtol - np.testing.assert_allclose( actual, desired, atol=atol, rtol=rtol, err_msg=msg ) From b634f7888ab45405618278dfc2568de870f50d83 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 26 Nov 2025 02:38:47 +0000 Subject: [PATCH 19/29] remove venv and reformat --- .gitignore | 2 -- keras_rs/src/layers/embedding/embed_reduce_test.py | 1 - .../src/layers/feature_interaction/dot_interaction_test.py | 4 ++-- keras_rs/src/layers/feature_interaction/feature_cross_test.py | 1 - keras_rs/src/layers/retrieval/hard_negative_mining_test.py | 1 - keras_rs/src/layers/retrieval/remove_accidental_hits_test.py | 1 - .../layers/retrieval/sampling_probability_correction_test.py | 1 - keras_rs/src/losses/list_mle_loss_test.py | 1 - keras_rs/src/losses/pairwise_hinge_loss_test.py | 1 - keras_rs/src/losses/pairwise_logistic_loss_test.py | 1 - keras_rs/src/losses/pairwise_mean_squared_error_test.py | 1 - keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py | 1 - keras_rs/src/metrics/dcg_test.py | 2 +- keras_rs/src/testing/test_case.py | 3 ++- 14 files changed, 5 insertions(+), 16 deletions(-) diff --git a/.gitignore b/.gitignore index df148986..eacd3be8 100644 --- a/.gitignore +++ b/.gitignore @@ -20,5 +20,3 @@ build/ .idea/ venv/ -venv_tf/ -venv_jax/ \ No newline at end of file diff --git a/keras_rs/src/layers/embedding/embed_reduce_test.py b/keras_rs/src/layers/embedding/embed_reduce_test.py index d30a8dcf..440259a9 100644 --- a/keras_rs/src/layers/embedding/embed_reduce_test.py +++ b/keras_rs/src/layers/embedding/embed_reduce_test.py @@ -8,7 +8,6 @@ from keras_rs.src import testing from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce -from keras_rs.src.utils import tpu_test_utils try: import jax diff --git a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py index 24763207..b5aa1f6c 100644 --- a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py +++ b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py @@ -8,7 +8,6 @@ from keras_rs.src.layers.feature_interaction.dot_interaction import ( DotInteraction, ) -from keras_rs.src.utils import tpu_test_utils class DotInteractionTest(testing.TestCase, parameterized.TestCase): @@ -88,7 +87,8 @@ def test_call(self, self_interaction, skip_gather, exp_output_idx): output, self.exp_outputs[exp_output_idx], tpu_atol=1e-2, - tpu_rtol=1e-2) + tpu_rtol=1e-2, + ) def test_invalid_input_rank(self): rank_1_input = [ops.ones((3,)), ops.ones((3,))] diff --git a/keras_rs/src/layers/feature_interaction/feature_cross_test.py b/keras_rs/src/layers/feature_interaction/feature_cross_test.py index 37148648..485f776e 100644 --- a/keras_rs/src/layers/feature_interaction/feature_cross_test.py +++ b/keras_rs/src/layers/feature_interaction/feature_cross_test.py @@ -6,7 +6,6 @@ from keras_rs.src import testing from keras_rs.src.layers.feature_interaction.feature_cross import FeatureCross -from keras_rs.src.utils import tpu_test_utils class FeatureCrossTest(testing.TestCase, parameterized.TestCase): diff --git a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py index 794a97cc..69964d98 100644 --- a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py +++ b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py @@ -6,7 +6,6 @@ from keras_rs.src import testing from keras_rs.src.layers.retrieval import hard_negative_mining -from keras_rs.src.utils import tpu_test_utils class HardNegativeMiningTest(testing.TestCase, parameterized.TestCase): diff --git a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py index 6b8780cd..f436678b 100644 --- a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py +++ b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py @@ -6,7 +6,6 @@ from keras_rs.src import testing from keras_rs.src.layers.retrieval import remove_accidental_hits -from keras_rs.src.utils import tpu_test_utils class RemoveAccidentalHitsTest(testing.TestCase, parameterized.TestCase): diff --git a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py index 9c896c8e..0ade230a 100644 --- a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py +++ b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py @@ -6,7 +6,6 @@ from keras_rs.src import testing from keras_rs.src.layers.retrieval import sampling_probability_correction -from keras_rs.src.utils import tpu_test_utils class SamplingProbabilityCorrectionTest( diff --git a/keras_rs/src/losses/list_mle_loss_test.py b/keras_rs/src/losses/list_mle_loss_test.py index b28ad859..ab3a4db9 100644 --- a/keras_rs/src/losses/list_mle_loss_test.py +++ b/keras_rs/src/losses/list_mle_loss_test.py @@ -6,7 +6,6 @@ from keras_rs.src import testing from keras_rs.src.losses.list_mle_loss import ListMLELoss -from keras_rs.src.utils import tpu_test_utils class ListMLELossTest(testing.TestCase, parameterized.TestCase): diff --git a/keras_rs/src/losses/pairwise_hinge_loss_test.py b/keras_rs/src/losses/pairwise_hinge_loss_test.py index 81577634..e1d3e6e4 100644 --- a/keras_rs/src/losses/pairwise_hinge_loss_test.py +++ b/keras_rs/src/losses/pairwise_hinge_loss_test.py @@ -6,7 +6,6 @@ from keras_rs.src import testing from keras_rs.src.losses.pairwise_hinge_loss import PairwiseHingeLoss -from keras_rs.src.utils import tpu_test_utils class PairwiseHingeLossTest(testing.TestCase, parameterized.TestCase): diff --git a/keras_rs/src/losses/pairwise_logistic_loss_test.py b/keras_rs/src/losses/pairwise_logistic_loss_test.py index 0d10643b..54a379c2 100644 --- a/keras_rs/src/losses/pairwise_logistic_loss_test.py +++ b/keras_rs/src/losses/pairwise_logistic_loss_test.py @@ -6,7 +6,6 @@ from keras_rs.src import testing from keras_rs.src.losses.pairwise_logistic_loss import PairwiseLogisticLoss -from keras_rs.src.utils import tpu_test_utils class PairwiseLogisticLossTest(testing.TestCase, parameterized.TestCase): diff --git a/keras_rs/src/losses/pairwise_mean_squared_error_test.py b/keras_rs/src/losses/pairwise_mean_squared_error_test.py index d5524da1..c181d859 100644 --- a/keras_rs/src/losses/pairwise_mean_squared_error_test.py +++ b/keras_rs/src/losses/pairwise_mean_squared_error_test.py @@ -8,7 +8,6 @@ from keras_rs.src.losses.pairwise_mean_squared_error import ( PairwiseMeanSquaredError, ) -from keras_rs.src.utils import tpu_test_utils class PairwiseMeanSquaredErrorTest(testing.TestCase, parameterized.TestCase): diff --git a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py index e3ced187..98059882 100644 --- a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py +++ b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py @@ -8,7 +8,6 @@ from keras_rs.src.losses.pairwise_soft_zero_one_loss import ( PairwiseSoftZeroOneLoss, ) -from keras_rs.src.utils import tpu_test_utils class PairwiseSoftZeroOneLossTest(testing.TestCase, parameterized.TestCase): diff --git a/keras_rs/src/metrics/dcg_test.py b/keras_rs/src/metrics/dcg_test.py index 73b7c6f5..3e2b4b20 100644 --- a/keras_rs/src/metrics/dcg_test.py +++ b/keras_rs/src/metrics/dcg_test.py @@ -362,4 +362,4 @@ def test_model_evaluate(self): model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=4), - ) \ No newline at end of file + ) diff --git a/keras_rs/src/testing/test_case.py b/keras_rs/src/testing/test_case.py index a10a5fd0..37c76c4f 100644 --- a/keras_rs/src/testing/test_case.py +++ b/keras_rs/src/testing/test_case.py @@ -5,11 +5,12 @@ import keras import numpy as np +import tensorflow as tf from keras_rs.src import types -import tensorflow as tf from keras_rs.src.utils import tpu_test_utils + class TestCase(unittest.TestCase): """TestCase class for all Keras Recommenders tests.""" From 58c78978a7edd51d2fb6fba8705fc2bd1b2ba423 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 26 Nov 2025 06:50:26 +0000 Subject: [PATCH 20/29] use a shared strategy in conftest.py --- conftest.py | 16 +++++ .../embedding/distributed_embedding_test.py | 32 +++++----- keras_rs/src/losses/list_mle_loss_test.py | 1 + .../src/losses/pairwise_hinge_loss_test.py | 1 + .../src/losses/pairwise_logistic_loss_test.py | 1 + .../pairwise_mean_squared_error_test.py | 1 + .../pairwise_soft_zero_one_loss_test.py | 1 + keras_rs/src/testing/test_case.py | 32 +++++++--- keras_rs/src/utils/tpu_test_utils.py | 58 +++++++++++++++++++ 9 files changed, 119 insertions(+), 24 deletions(-) create mode 100644 conftest.py diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..cb811b57 --- /dev/null +++ b/conftest.py @@ -0,0 +1,16 @@ +import pytest +import os +from keras_rs.src.utils import tpu_test_utils + +@pytest.fixture(scope="session", autouse=True) +def prime_shared_tpu_strategy(request): + """ + Eagerly initializes the shared TPU strategy at the beginning of the session + if running on a TPU. This helps catch initialization errors early. + """ + strategy = tpu_test_utils.get_shared_tpu_strategy() + if not strategy: + pytest.fail( + "Failed to initialize shared TPUStrategy for the test session. " + "Check logs for details from create_tpu_strategy." + ) \ No newline at end of file diff --git a/keras_rs/src/layers/embedding/distributed_embedding_test.py b/keras_rs/src/layers/embedding/distributed_embedding_test.py index 49a72d93..e493931f 100644 --- a/keras_rs/src/layers/embedding/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/distributed_embedding_test.py @@ -53,7 +53,7 @@ def setUp(self): # FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16 self.batch_size = ( - BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync + BATCH_SIZE_PER_CORE * self.strategy.num_replicas_in_sync ) def get_embedding_config(self, input_type, placement): @@ -194,11 +194,11 @@ def test_basics(self, input_type, placement): if placement == "sparsecore" and not self.on_tpu: with self.assertRaisesRegex(Exception, "sparsecore"): - with self._strategy.scope(): + with self.strategy.scope(): distributed_embedding.DistributedEmbedding(feature_configs) return - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding(feature_configs) if keras.backend.backend() == "jax": @@ -276,7 +276,7 @@ def test_model_fit(self, input_type, use_weights): (test_model_inputs, test_labels) ) - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding(feature_configs) def _create_keras_input( @@ -347,7 +347,7 @@ def test_dataset_generator(): # New preprocessed data removes the `weights` component. dataset_has_weights = False else: - train_dataset = self._strategy.experimental_distribute_dataset( + train_dataset = self.strategy.experimental_distribute_dataset( train_dataset, options=tf.distribute.InputOptions( experimental_fetch_to_device=False @@ -362,7 +362,7 @@ def test_dataset_generator(): inputs=keras_model_inputs, outputs=keras_model_outputs ) - with self._strategy.scope(): + with self.strategy.scope(): model.compile(optimizer="adam", loss="mse") model_inputs, _ = next(iter(test_dataset)) @@ -511,7 +511,7 @@ def test_correctness( if not use_weights: weights = None - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding(feature_config) if keras.backend.backend() == "jax": @@ -568,7 +568,7 @@ def test_correctness( self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM)) - with self._strategy.scope(): + with self.strategy.scope(): tables = layer.get_embedding_tables() emb = tables["table"] @@ -633,11 +633,11 @@ def test_shared_table(self): "dense", embedding_config ) - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding(embedding_config) res = tpu_test_utils.run_with_strategy( - self._strategy, layer.__call__, inputs + self.strategy, layer.__call__, inputs ) if self.placement == "default_device": @@ -709,11 +709,11 @@ def test_mixed_placement(self): "dense", embedding_config ) - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding(embedding_config) res = tpu_test_utils.run_with_strategy( - self._strategy, layer.__call__, inputs + self.strategy, layer.__call__, inputs ) self.assertEqual( @@ -740,7 +740,7 @@ def test_save_load_model(self): with tempfile.TemporaryDirectory() as temp_dir: path = os.path.join(temp_dir, "model.keras") - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding( feature_configs ) @@ -748,14 +748,14 @@ def test_save_load_model(self): model = keras.Model(inputs=keras_inputs, outputs=keras_outputs) output_before = tpu_test_utils.run_with_strategy( - self._strategy, model.__call__, inputs + self.strategy, model.__call__, inputs ) model.save(path) - with self._strategy.scope(): + with self.strategy.scope(): reloaded_model = keras.models.load_model(path) output_after = tpu_test_utils.run_with_strategy( - self._strategy, reloaded_model.__call__, inputs + self.strategy, reloaded_model.__call__, inputs ) if self.placement == "sparsecore": diff --git a/keras_rs/src/losses/list_mle_loss_test.py b/keras_rs/src/losses/list_mle_loss_test.py index ab3a4db9..ebf9a1d1 100644 --- a/keras_rs/src/losses/list_mle_loss_test.py +++ b/keras_rs/src/losses/list_mle_loss_test.py @@ -10,6 +10,7 @@ class ListMLELossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array( [1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32" ) diff --git a/keras_rs/src/losses/pairwise_hinge_loss_test.py b/keras_rs/src/losses/pairwise_hinge_loss_test.py index e1d3e6e4..7c782015 100644 --- a/keras_rs/src/losses/pairwise_hinge_loss_test.py +++ b/keras_rs/src/losses/pairwise_hinge_loss_test.py @@ -10,6 +10,7 @@ class PairwiseHingeLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) diff --git a/keras_rs/src/losses/pairwise_logistic_loss_test.py b/keras_rs/src/losses/pairwise_logistic_loss_test.py index 54a379c2..74c383a0 100644 --- a/keras_rs/src/losses/pairwise_logistic_loss_test.py +++ b/keras_rs/src/losses/pairwise_logistic_loss_test.py @@ -10,6 +10,7 @@ class PairwiseLogisticLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) diff --git a/keras_rs/src/losses/pairwise_mean_squared_error_test.py b/keras_rs/src/losses/pairwise_mean_squared_error_test.py index c181d859..e1f865c6 100644 --- a/keras_rs/src/losses/pairwise_mean_squared_error_test.py +++ b/keras_rs/src/losses/pairwise_mean_squared_error_test.py @@ -12,6 +12,7 @@ class PairwiseMeanSquaredErrorTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) diff --git a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py index 98059882..92ddeae2 100644 --- a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py +++ b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py @@ -12,6 +12,7 @@ class PairwiseSoftZeroOneLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) diff --git a/keras_rs/src/testing/test_case.py b/keras_rs/src/testing/test_case.py index 37c76c4f..b6aee246 100644 --- a/keras_rs/src/testing/test_case.py +++ b/keras_rs/src/testing/test_case.py @@ -1,7 +1,7 @@ import os import tempfile import unittest -from typing import Any +from typing import Any, Optional, Union import keras import numpy as np @@ -10,6 +10,12 @@ from keras_rs.src import types from keras_rs.src.utils import tpu_test_utils +StrategyType = Union[ + tf.distribute.Strategy, + tpu_test_utils.DummyStrategy, + tpu_test_utils.JaxDummyStrategy, +] + class TestCase(unittest.TestCase): """TestCase class for all Keras Recommenders tests.""" @@ -21,13 +27,23 @@ def setUp(self) -> None: if keras.backend.backend() == "tensorflow": tf.debugging.disable_traceback_filtering() self.on_tpu = "TPU_NAME" in os.environ + self._strategy: Optional[StrategyType] = None @property - def strategy(self): - if hasattr(self, "_strategy"): - return self._strategy - self._strategy = tpu_test_utils.get_tpu_strategy(self) - return self._strategy + def strategy(self) -> StrategyType: + strat = tpu_test_utils.get_shared_tpu_strategy() + + if strat is None: + # This case should ideally be caught by the conftest.py fixture + self.fail( + "TPU environment detected, but the shared TPUStrategy is None. " + "Initialization likely failed." + ) + return strat + # if self._strategy is not None: + # return self._strategy + # self._strategy = tpu_test_utils.get_tpu_strategy(self) + # return self._strategy def assertAllClose( self, @@ -35,8 +51,8 @@ def assertAllClose( desired: types.Tensor, atol: float = 1e-6, rtol: float = 1e-6, - tpu_atol=None, - tpu_rtol=None, + tpu_atol: float = None, + tpu_rtol: float = None, msg: str = "", ) -> None: """Verify that two tensors are close in value element by element. diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index d6e1d514..2a66d8ea 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -1,5 +1,6 @@ import contextlib import os +import threading from types import ModuleType from typing import Any, Callable, ContextManager, Optional, Tuple, Union @@ -42,6 +43,63 @@ def num_replicas_in_sync(self) -> Any: StrategyType = Union[tf.distribute.Strategy, DummyStrategy, JaxDummyStrategy] +_shared_strategy: Optional[StrategyType] = None +_lock = threading.Lock() + +def create_tpu_strategy() -> Optional[StrategyType]: + """Initializes the TPU system and returns a TPUStrategy.""" + print("Attempting to create TPUStrategy...") + try: + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + tf.config.experimental_connect_to_cluster(resolver) + tf.tpu.experimental.initialize_tpu_system(resolver) + strategy = tf.distribute.TPUStrategy(resolver) + print(f"TPUStrategy created successfully. Devices: {strategy.extended.num_replicas_in_sync}") + return strategy + except Exception as e: + print(f"Error creating TPUStrategy: {e}") + return None + +def get_shared_tpu_strategy() -> Optional[StrategyType]: + """ + Returns a session-wide shared TPUStrategy instance. + Creates the instance on the first call. + Returns None if not in a TPU environment or if creation fails. + """ + global _shared_strategy + if _shared_strategy is not None: + return _shared_strategy + + with _lock: + if _shared_strategy is None: + if "TPU_NAME" not in os.environ: + _shared_strategy = DummyStrategy() + return _shared_strategy + if keras.backend.backend() == "tensorflow": + resolver = tf.distribute.cluster_resolver.TPUClusterResolver() + tf.config.experimental_connect_to_cluster(resolver) + topology = tf.tpu.experimental.initialize_tpu_system(resolver) + tpu_metadata = resolver.get_tpu_system_metadata() + device_assignment = tf.tpu.experimental.DeviceAssignment.build( + topology, num_replicas=tpu_metadata.num_hosts + ) + _shared_strategy = tf.distribute.TPUStrategy( + resolver, experimental_device_assignment=device_assignment + ) + print("### num_replicas", _shared_strategy.num_replicas_in_sync) + elif keras.backend.backend() == "jax": + if jax is None: + raise ImportError( + "JAX backend requires jax to be installed for TPU." + ) + print("### num_replicas", jax.device_count("tpu")) + _shared_strategy = JaxDummyStrategy() + else: + _shared_strategy = DummyStrategy() + if _shared_strategy is None: + print("Failed to create the shared TPUStrategy.") + return _shared_strategy + def get_tpu_strategy(test_case: Any) -> StrategyType: """Get TPU strategy if on TPU, otherwise return DummyStrategy.""" From ea18299dd33fd5f7e407e1bd92fb720a40032448 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 26 Nov 2025 19:42:59 +0000 Subject: [PATCH 21/29] format --- conftest.py | 5 +++-- keras_rs/src/testing/test_case.py | 5 ++--- keras_rs/src/utils/tpu_test_utils.py | 11 ++++++++--- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/conftest.py b/conftest.py index cb811b57..e8cac721 100644 --- a/conftest.py +++ b/conftest.py @@ -1,7 +1,8 @@ import pytest -import os + from keras_rs.src.utils import tpu_test_utils + @pytest.fixture(scope="session", autouse=True) def prime_shared_tpu_strategy(request): """ @@ -13,4 +14,4 @@ def prime_shared_tpu_strategy(request): pytest.fail( "Failed to initialize shared TPUStrategy for the test session. " "Check logs for details from create_tpu_strategy." - ) \ No newline at end of file + ) diff --git a/keras_rs/src/testing/test_case.py b/keras_rs/src/testing/test_case.py index b6aee246..6dff2bdd 100644 --- a/keras_rs/src/testing/test_case.py +++ b/keras_rs/src/testing/test_case.py @@ -34,7 +34,6 @@ def strategy(self) -> StrategyType: strat = tpu_test_utils.get_shared_tpu_strategy() if strat is None: - # This case should ideally be caught by the conftest.py fixture self.fail( "TPU environment detected, but the shared TPUStrategy is None. " "Initialization likely failed." @@ -51,8 +50,8 @@ def assertAllClose( desired: types.Tensor, atol: float = 1e-6, rtol: float = 1e-6, - tpu_atol: float = None, - tpu_rtol: float = None, + tpu_atol: float | None = None, + tpu_rtol: float | None = None, msg: str = "", ) -> None: """Verify that two tensors are close in value element by element. diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index 2a66d8ea..7d5be954 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -46,20 +46,25 @@ def num_replicas_in_sync(self) -> Any: _shared_strategy: Optional[StrategyType] = None _lock = threading.Lock() + def create_tpu_strategy() -> Optional[StrategyType]: """Initializes the TPU system and returns a TPUStrategy.""" print("Attempting to create TPUStrategy...") try: - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="") tf.config.experimental_connect_to_cluster(resolver) tf.tpu.experimental.initialize_tpu_system(resolver) strategy = tf.distribute.TPUStrategy(resolver) - print(f"TPUStrategy created successfully. Devices: {strategy.extended.num_replicas_in_sync}") + print( + "TPUStrategy created successfully." + "Devices: {strategy.extended.num_replicas_in_sync}" + ) return strategy except Exception as e: print(f"Error creating TPUStrategy: {e}") return None + def get_shared_tpu_strategy() -> Optional[StrategyType]: """ Returns a session-wide shared TPUStrategy instance. @@ -97,7 +102,7 @@ def get_shared_tpu_strategy() -> Optional[StrategyType]: else: _shared_strategy = DummyStrategy() if _shared_strategy is None: - print("Failed to create the shared TPUStrategy.") + print("Failed to create the shared TPUStrategy.") return _shared_strategy From dd2219fc4c221d266f1191f4b9cc29c1389257aa Mon Sep 17 00:00:00 2001 From: Wenyi Guo <41378453+wenyi-guo@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:21:49 -0800 Subject: [PATCH 22/29] format conftest Added type hint for prime_shared_tpu_strategy function. --- conftest.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index e8cac721..4f4a4002 100644 --- a/conftest.py +++ b/conftest.py @@ -1,10 +1,18 @@ +from typing import Union + import pytest from keras_rs.src.utils import tpu_test_utils +StrategyType = Union[ + tf.distribute.Strategy, + tpu_test_utils.DummyStrategy, + tpu_test_utils.JaxDummyStrategy, +] + @pytest.fixture(scope="session", autouse=True) -def prime_shared_tpu_strategy(request): +def prime_shared_tpu_strategy(request) -> StrategyType: """ Eagerly initializes the shared TPU strategy at the beginning of the session if running on a TPU. This helps catch initialization errors early. From 29c7f29c58459339ab9ed484f376983d01e71c04 Mon Sep 17 00:00:00 2001 From: Wenyi Guo <41378453+wenyi-guo@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:23:35 -0800 Subject: [PATCH 23/29] format import --- conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/conftest.py b/conftest.py index 4f4a4002..ff8e114a 100644 --- a/conftest.py +++ b/conftest.py @@ -1,6 +1,7 @@ from typing import Union import pytest +import tensorflow as tf from keras_rs.src.utils import tpu_test_utils From 836de86d561400c92f3d8073d803457ef47bdd83 Mon Sep 17 00:00:00 2001 From: Wenyi Guo <41378453+wenyi-guo@users.noreply.github.com> Date: Wed, 26 Nov 2025 22:20:59 -0800 Subject: [PATCH 24/29] format --- conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index ff8e114a..f80e27f6 100644 --- a/conftest.py +++ b/conftest.py @@ -13,7 +13,7 @@ @pytest.fixture(scope="session", autouse=True) -def prime_shared_tpu_strategy(request) -> StrategyType: +def prime_shared_tpu_strategy() -> None: """ Eagerly initializes the shared TPU strategy at the beginning of the session if running on a TPU. This helps catch initialization errors early. From 16dd280a6dd770c832130f6858a4a8451a72b244 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Mon, 1 Dec 2025 23:27:36 +0000 Subject: [PATCH 25/29] resolve comments --- .gitignore | 3 +- .../embedding/distributed_embedding_test.py | 66 ++++++++----------- .../src/layers/embedding/embed_reduce_test.py | 14 +--- .../dot_interaction_test.py | 19 +++--- .../retrieval/hard_negative_mining_test.py | 16 ++--- .../retrieval/remove_accidental_hits_test.py | 18 +++-- .../sampling_probability_correction_test.py | 14 ++-- keras_rs/src/losses/list_mle_loss_test.py | 16 ++--- .../src/losses/pairwise_hinge_loss_test.py | 16 ++--- .../src/losses/pairwise_logistic_loss_test.py | 16 ++--- .../pairwise_mean_squared_error_test.py | 16 ++--- .../pairwise_soft_zero_one_loss_test.py | 16 ++--- keras_rs/src/metrics/dcg_test.py | 21 +++--- .../metrics/mean_average_precision_test.py | 20 +++--- .../src/metrics/mean_reciprocal_rank_test.py | 21 +++--- keras_rs/src/metrics/ndcg_test.py | 21 +++--- keras_rs/src/metrics/precision_at_k_test.py | 23 +++---- keras_rs/src/metrics/recall_at_k_test.py | 23 +++---- keras_rs/src/testing/test_case.py | 8 ++- 19 files changed, 140 insertions(+), 227 deletions(-) diff --git a/.gitignore b/.gitignore index eacd3be8..2e927895 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ build/ # Pycharm files .idea/ -venv/ +venv/] +venv_tf/ diff --git a/keras_rs/src/layers/embedding/distributed_embedding_test.py b/keras_rs/src/layers/embedding/distributed_embedding_test.py index e493931f..27a592dc 100644 --- a/keras_rs/src/layers/embedding/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/distributed_embedding_test.py @@ -194,12 +194,10 @@ def test_basics(self, input_type, placement): if placement == "sparsecore" and not self.on_tpu: with self.assertRaisesRegex(Exception, "sparsecore"): - with self.strategy.scope(): - distributed_embedding.DistributedEmbedding(feature_configs) + distributed_embedding.DistributedEmbedding(feature_configs) return - with self.strategy.scope(): - layer = distributed_embedding.DistributedEmbedding(feature_configs) + layer = distributed_embedding.DistributedEmbedding(feature_configs) if keras.backend.backend() == "jax": preprocessed_inputs = layer.preprocess(inputs, weights) @@ -276,8 +274,7 @@ def test_model_fit(self, input_type, use_weights): (test_model_inputs, test_labels) ) - with self.strategy.scope(): - layer = distributed_embedding.DistributedEmbedding(feature_configs) + layer = distributed_embedding.DistributedEmbedding(feature_configs) def _create_keras_input( feature_config: config.FeatureConfig, dtype: types.DType @@ -362,19 +359,18 @@ def test_dataset_generator(): inputs=keras_model_inputs, outputs=keras_model_outputs ) - with self.strategy.scope(): - model.compile(optimizer="adam", loss="mse") + model.compile(optimizer="adam", loss="mse") - model_inputs, _ = next(iter(test_dataset)) - test_output_before = tpu_test_utils.run_with_strategy( - self.strategy, model.__call__, model_inputs - ) + model_inputs, _ = next(iter(test_dataset)) + test_output_before = tpu_test_utils.run_with_strategy( + self.strategy, model.__call__, model_inputs + ) - model.fit(train_dataset, steps_per_epoch=1, epochs=1) + model.fit(train_dataset, steps_per_epoch=1, epochs=1) - test_output_after = tpu_test_utils.run_with_strategy( - self.strategy, model.__call__, model_inputs - ) + test_output_after = tpu_test_utils.run_with_strategy( + self.strategy, model.__call__, model_inputs + ) # Verify that the embedding has actually trained. for before, after in zip( @@ -511,8 +507,7 @@ def test_correctness( if not use_weights: weights = None - with self.strategy.scope(): - layer = distributed_embedding.DistributedEmbedding(feature_config) + layer = distributed_embedding.DistributedEmbedding(feature_config) if keras.backend.backend() == "jax": preprocessed = layer.preprocess(inputs, weights) @@ -568,8 +563,7 @@ def test_correctness( self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM)) - with self.strategy.scope(): - tables = layer.get_embedding_tables() + tables = layer.get_embedding_tables() emb = tables["table"] @@ -633,8 +627,7 @@ def test_shared_table(self): "dense", embedding_config ) - with self.strategy.scope(): - layer = distributed_embedding.DistributedEmbedding(embedding_config) + layer = distributed_embedding.DistributedEmbedding(embedding_config) res = tpu_test_utils.run_with_strategy( self.strategy, layer.__call__, inputs @@ -709,8 +702,7 @@ def test_mixed_placement(self): "dense", embedding_config ) - with self.strategy.scope(): - layer = distributed_embedding.DistributedEmbedding(embedding_config) + layer = distributed_embedding.DistributedEmbedding(embedding_config) res = tpu_test_utils.run_with_strategy( self.strategy, layer.__call__, inputs @@ -740,23 +732,19 @@ def test_save_load_model(self): with tempfile.TemporaryDirectory() as temp_dir: path = os.path.join(temp_dir, "model.keras") - with self.strategy.scope(): - layer = distributed_embedding.DistributedEmbedding( - feature_configs - ) - keras_outputs = layer(keras_inputs) - model = keras.Model(inputs=keras_inputs, outputs=keras_outputs) + layer = distributed_embedding.DistributedEmbedding(feature_configs) + keras_outputs = layer(keras_inputs) + model = keras.Model(inputs=keras_inputs, outputs=keras_outputs) - output_before = tpu_test_utils.run_with_strategy( - self.strategy, model.__call__, inputs - ) - model.save(path) + output_before = tpu_test_utils.run_with_strategy( + self.strategy, model.__call__, inputs + ) + model.save(path) - with self.strategy.scope(): - reloaded_model = keras.models.load_model(path) - output_after = tpu_test_utils.run_with_strategy( - self.strategy, reloaded_model.__call__, inputs - ) + reloaded_model = keras.models.load_model(path) + output_after = tpu_test_utils.run_with_strategy( + self.strategy, reloaded_model.__call__, inputs + ) if self.placement == "sparsecore": self.skipTest("TODO table reloading.") diff --git a/keras_rs/src/layers/embedding/embed_reduce_test.py b/keras_rs/src/layers/embedding/embed_reduce_test.py index 440259a9..1d7fb456 100644 --- a/keras_rs/src/layers/embedding/embed_reduce_test.py +++ b/keras_rs/src/layers/embedding/embed_reduce_test.py @@ -9,18 +9,8 @@ from keras_rs.src import testing from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce -try: - import jax - from jax.experimental import sparse as jax_sparse -except ImportError: - jax = None - jax_sparse = None - class EmbedReduceTest(testing.TestCase, parameterized.TestCase): - def setUp(self): - super().setUp() - @parameterized.named_parameters( [ ( @@ -182,9 +172,7 @@ def test_symbolic_call(self, input_type, input_rank, use_weights): def test_predict(self): input = keras.random.randint((5, 7), minval=0, maxval=10) - with self.strategy.scope(): - model = keras.models.Sequential([EmbedReduce(10, 20)]) - model.compile(optimizer="adam", loss="mse") + model = keras.models.Sequential([EmbedReduce(10, 20)]) model.predict(input, batch_size=2) def test_serialization(self): diff --git a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py index b5aa1f6c..eb634478 100644 --- a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py +++ b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py @@ -127,16 +127,15 @@ def test_invalid_input_different_shapes(self): ), ) def test_predict(self, self_interaction, skip_gather): - with self.strategy.scope(): - feature1 = keras.layers.Input(shape=(5,)) - feature2 = keras.layers.Input(shape=(5,)) - feature3 = keras.layers.Input(shape=(5,)) - x = DotInteraction( - self_interaction=self_interaction, skip_gather=skip_gather - )([feature1, feature2, feature3]) - x = keras.layers.Dense(units=1)(x) - model = keras.Model([feature1, feature2, feature3], x) - model.compile(optimizer="adam", loss="mse") + feature1 = keras.layers.Input(shape=(5,)) + feature2 = keras.layers.Input(shape=(5,)) + feature3 = keras.layers.Input(shape=(5,)) + x = DotInteraction( + self_interaction=self_interaction, skip_gather=skip_gather + )([feature1, feature2, feature3]) + x = keras.layers.Dense(units=1)(x) + model = keras.Model([feature1, feature2, feature3], x) + # model.compile(optimizer="adam", loss="mse") model.predict(self.input, batch_size=2) diff --git a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py index 69964d98..382e0742 100644 --- a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py +++ b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py @@ -92,16 +92,12 @@ def test_call(self, rank, num_hard_negatives): def test_predict(self): logits, labels = self.create_inputs() - with self.strategy.scope(): - in_logits = keras.layers.Input(shape=logits.shape[1:]) - in_labels = keras.layers.Input(shape=labels.shape[1:]) - out_logits, out_labels = hard_negative_mining.HardNegativeMining( - num_hard_negatives=3 - )(in_logits, in_labels) - model = keras.Model( - [in_logits, in_labels], [out_logits, out_labels] - ) - model.compile(optimizer="adam", loss="mse") + in_logits = keras.layers.Input(shape=logits.shape[1:]) + in_labels = keras.layers.Input(shape=labels.shape[1:]) + out_logits, out_labels = hard_negative_mining.HardNegativeMining( + num_hard_negatives=3 + )(in_logits, in_labels) + model = keras.Model([in_logits, in_labels], [out_logits, out_labels]) model.predict([logits, labels], batch_size=8) diff --git a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py index f436678b..38f9ef27 100644 --- a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py +++ b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py @@ -154,16 +154,14 @@ def test_predict(self): # Note: for predict, we test with probabilities that have a batch dim. logits, labels, candidate_ids = self.create_inputs(candidate_ids_rank=2) - with self.strategy.scope(): - layer = remove_accidental_hits.RemoveAccidentalHits() - in_logits = keras.layers.Input(logits.shape[1:]) - in_labels = keras.layers.Input(labels.shape[1:]) - in_candidate_ids = keras.layers.Input(labels.shape[1:]) - out_logits = layer(in_logits, in_labels, in_candidate_ids) - model = keras.Model( - [in_logits, in_labels, in_candidate_ids], out_logits - ) - model.compile(optimizer="adam", loss="mse") + layer = remove_accidental_hits.RemoveAccidentalHits() + in_logits = keras.layers.Input(logits.shape[1:]) + in_labels = keras.layers.Input(labels.shape[1:]) + in_candidate_ids = keras.layers.Input(labels.shape[1:]) + out_logits = layer(in_logits, in_labels, in_candidate_ids) + model = keras.Model( + [in_logits, in_labels, in_candidate_ids], out_logits + ) model.predict([logits, labels, candidate_ids], batch_size=8) diff --git a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py index 0ade230a..629b5a66 100644 --- a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py +++ b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py @@ -90,15 +90,11 @@ def test_predict(self): # Note: for predict, we test with probabilities that have a batch dim. logits, probs = self.create_inputs(probs_rank=2) - with self.strategy.scope(): - layer = ( - sampling_probability_correction.SamplingProbabilityCorrection() - ) - in_logits = keras.layers.Input(logits.shape[1:]) - in_probs = keras.layers.Input(probs.shape[1:]) - out_logits = layer(in_logits, in_probs) - model = keras.Model([in_logits, in_probs], out_logits) - model.compile(optimizer="adam", loss="mse") + layer = sampling_probability_correction.SamplingProbabilityCorrection() + in_logits = keras.layers.Input(logits.shape[1:]) + in_probs = keras.layers.Input(probs.shape[1:]) + out_logits = layer(in_logits, in_probs) + model = keras.Model([in_logits, in_probs], out_logits) model.predict([logits, probs], batch_size=4) diff --git a/keras_rs/src/losses/list_mle_loss_test.py b/keras_rs/src/losses/list_mle_loss_test.py index ebf9a1d1..306fcd26 100644 --- a/keras_rs/src/losses/list_mle_loss_test.py +++ b/keras_rs/src/losses/list_mle_loss_test.py @@ -84,19 +84,11 @@ def test_scalar_sample_weight(self): ) def test_model_fit(self): - def create_model(): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - model.compile(loss=ListMLELoss(), optimizer="adam") - return model - - if self.strategy: - with self.strategy.scope(): - model = create_model() - else: - model = create_model() + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=ListMLELoss(), optimizer="adam") model.fit( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=2), diff --git a/keras_rs/src/losses/pairwise_hinge_loss_test.py b/keras_rs/src/losses/pairwise_hinge_loss_test.py index 7c782015..0e82533e 100644 --- a/keras_rs/src/losses/pairwise_hinge_loss_test.py +++ b/keras_rs/src/losses/pairwise_hinge_loss_test.py @@ -111,19 +111,11 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - def create_model(): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - model.compile(loss=PairwiseHingeLoss(), optimizer="adam") - return model - - if self.strategy: - with self.strategy.scope(): - model = create_model() - else: - model = create_model() + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseHingeLoss(), optimizer="adam") model.fit( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=2), diff --git a/keras_rs/src/losses/pairwise_logistic_loss_test.py b/keras_rs/src/losses/pairwise_logistic_loss_test.py index 74c383a0..f0644634 100644 --- a/keras_rs/src/losses/pairwise_logistic_loss_test.py +++ b/keras_rs/src/losses/pairwise_logistic_loss_test.py @@ -111,19 +111,11 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - def create_model(): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - model.compile(loss=PairwiseLogisticLoss(), optimizer="adam") - return model - - if self.strategy: - with self.strategy.scope(): - model = create_model() - else: - model = create_model() + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseLogisticLoss(), optimizer="adam") model.fit( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=2), diff --git a/keras_rs/src/losses/pairwise_mean_squared_error_test.py b/keras_rs/src/losses/pairwise_mean_squared_error_test.py index e1f865c6..8d2912f3 100644 --- a/keras_rs/src/losses/pairwise_mean_squared_error_test.py +++ b/keras_rs/src/losses/pairwise_mean_squared_error_test.py @@ -110,19 +110,11 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - def create_model(): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - model.compile(loss=PairwiseMeanSquaredError(), optimizer="adam") - return model - - if self.strategy: - with self.strategy.scope(): - model = create_model() - else: - model = create_model() + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseMeanSquaredError(), optimizer="adam") model.fit( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=2), diff --git a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py index 92ddeae2..6063b17b 100644 --- a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py +++ b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py @@ -113,19 +113,11 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - def create_model(): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - model.compile(loss=PairwiseSoftZeroOneLoss(), optimizer="adam") - return model - - if self.strategy: - with self.strategy.scope(): - model = create_model() - else: - model = create_model() + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseSoftZeroOneLoss(), optimizer="adam") model.fit( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=2), diff --git a/keras_rs/src/metrics/dcg_test.py b/keras_rs/src/metrics/dcg_test.py index 3e2b4b20..e3e000c5 100644 --- a/keras_rs/src/metrics/dcg_test.py +++ b/keras_rs/src/metrics/dcg_test.py @@ -8,7 +8,6 @@ from keras_rs.src import testing from keras_rs.src.metrics.dcg import DCG -from keras_rs.src.utils import tpu_test_utils def _compute_dcg(labels, ranks): @@ -20,7 +19,7 @@ def _compute_dcg(labels, ranks): class DCGTest(testing.TestCase, parameterized.TestCase): def setUp(self): - self._strategy = tpu_test_utils.get_tpu_strategy(self) + super().setUp() self.y_true_batched = ops.array( [ @@ -348,17 +347,15 @@ def inverse_discount_fn(rank): self.assertAllClose(result, expected_output, rtol=1e-5) def test_model_evaluate(self): - with self._strategy.scope(): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[DCG()], - optimizer="adam", - ) + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[DCG()], + optimizer="adam", + ) model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=4), diff --git a/keras_rs/src/metrics/mean_average_precision_test.py b/keras_rs/src/metrics/mean_average_precision_test.py index 7161534a..574e8af3 100644 --- a/keras_rs/src/metrics/mean_average_precision_test.py +++ b/keras_rs/src/metrics/mean_average_precision_test.py @@ -6,12 +6,11 @@ from keras_rs.src import testing from keras_rs.src.metrics.mean_average_precision import MeanAveragePrecision -from keras_rs.src.utils import tpu_test_utils class MeanAveragePrecisionTest(testing.TestCase, parameterized.TestCase): def setUp(self): - self._strategy = tpu_test_utils.get_tpu_strategy(self) + super().setUp() self.y_true_batched = ops.array( [ @@ -279,16 +278,15 @@ def test_serialization(self): self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - with self._strategy.scope(): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[MeanAveragePrecision()], - optimizer="adam", - ) + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[MeanAveragePrecision()], + optimizer="adam", + ) model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=4), diff --git a/keras_rs/src/metrics/mean_reciprocal_rank_test.py b/keras_rs/src/metrics/mean_reciprocal_rank_test.py index 3d5264bc..5277d29d 100644 --- a/keras_rs/src/metrics/mean_reciprocal_rank_test.py +++ b/keras_rs/src/metrics/mean_reciprocal_rank_test.py @@ -6,12 +6,11 @@ from keras_rs.src import testing from keras_rs.src.metrics.mean_reciprocal_rank import MeanReciprocalRank -from keras_rs.src.utils import tpu_test_utils class MeanReciprocalRankTest(testing.TestCase, parameterized.TestCase): def setUp(self): - self._strategy = tpu_test_utils.get_tpu_strategy(self) + super().setUp() self.y_true_batched = ops.array( [ @@ -251,17 +250,15 @@ def test_serialization(self): self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - with self._strategy.scope(): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[MeanReciprocalRank()], - optimizer="adam", - ) + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[MeanReciprocalRank()], + optimizer="adam", + ) model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=4), diff --git a/keras_rs/src/metrics/ndcg_test.py b/keras_rs/src/metrics/ndcg_test.py index 68fc2c96..155b8843 100644 --- a/keras_rs/src/metrics/ndcg_test.py +++ b/keras_rs/src/metrics/ndcg_test.py @@ -8,7 +8,6 @@ from keras_rs.src import testing from keras_rs.src.metrics.ndcg import NDCG -from keras_rs.src.utils import tpu_test_utils def _compute_dcg(labels, ranks): @@ -20,7 +19,7 @@ def _compute_dcg(labels, ranks): class NDCGTest(testing.TestCase, parameterized.TestCase): def setUp(self): - self._strategy = tpu_test_utils.get_tpu_strategy(self) + super().setUp() self.y_true_batched = ops.array( [ @@ -360,17 +359,15 @@ def inverse_discount_fn(rank): self.assertAllClose(result, ndcg, rtol=1e-5) def test_model_evaluate(self): - with self._strategy.scope(): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[NDCG()], - optimizer="adam", - ) + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[NDCG()], + optimizer="adam", + ) model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=4), diff --git a/keras_rs/src/metrics/precision_at_k_test.py b/keras_rs/src/metrics/precision_at_k_test.py index 62b8348a..c1b345a0 100644 --- a/keras_rs/src/metrics/precision_at_k_test.py +++ b/keras_rs/src/metrics/precision_at_k_test.py @@ -6,12 +6,11 @@ from keras_rs.src import testing from keras_rs.src.metrics.precision_at_k import PrecisionAtK -from keras_rs.src.utils import tpu_test_utils class PrecisionAtKTest(testing.TestCase, parameterized.TestCase): def setUp(self): - self._strategy = tpu_test_utils.get_tpu_strategy(self) + super().setUp() self.y_true_batched = ops.array( [ @@ -231,17 +230,15 @@ def test_serialization(self): self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - with self._strategy.scope(): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[PrecisionAtK(k=3)], - optimizer="adam", - ) - + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[PrecisionAtK(k=3)], + optimizer="adam", + ) model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint( diff --git a/keras_rs/src/metrics/recall_at_k_test.py b/keras_rs/src/metrics/recall_at_k_test.py index d397f9cd..4c1c3d67 100644 --- a/keras_rs/src/metrics/recall_at_k_test.py +++ b/keras_rs/src/metrics/recall_at_k_test.py @@ -6,12 +6,11 @@ from keras_rs.src import testing from keras_rs.src.metrics.recall_at_k import RecallAtK -from keras_rs.src.utils import tpu_test_utils class RecallAtKTest(testing.TestCase, parameterized.TestCase): def setUp(self): - self._strategy = tpu_test_utils.get_tpu_strategy(self) + super().setUp() self.y_true_batched = ops.array( [ @@ -234,17 +233,15 @@ def test_serialization(self): self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - with self._strategy.scope(): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[RecallAtK(k=3)], - optimizer="adam", - ) - + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[RecallAtK(k=3)], + optimizer="adam", + ) model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint( diff --git a/keras_rs/src/testing/test_case.py b/keras_rs/src/testing/test_case.py index 6dff2bdd..29d7e2c4 100644 --- a/keras_rs/src/testing/test_case.py +++ b/keras_rs/src/testing/test_case.py @@ -24,10 +24,14 @@ def setUp(self) -> None: super().setUp() keras.utils.clear_session() keras.config.disable_traceback_filtering() - if keras.backend.backend() == "tensorflow": - tf.debugging.disable_traceback_filtering() self.on_tpu = "TPU_NAME" in os.environ self._strategy: Optional[StrategyType] = None + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + if self.on_tpu: + scope = tpu_test_utils.get_shared_tpu_strategy().scope() + scope.__enter__() + self.addCleanup(lambda: scope.__exit__(None, None, None)) @property def strategy(self) -> StrategyType: From 8092f3f808991a339ce0a3835dc4fb7df9a9be7f Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 2 Dec 2025 01:11:40 +0000 Subject: [PATCH 26/29] clean gitignore --- .gitignore | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 2e927895..4e067da8 100644 --- a/.gitignore +++ b/.gitignore @@ -19,5 +19,4 @@ build/ # Pycharm files .idea/ -venv/] -venv_tf/ +venv/ \ No newline at end of file From e8a12a6377a39049bbd0658147daa248a1d5c6c2 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 2 Dec 2025 20:51:12 +0000 Subject: [PATCH 27/29] format mypy --- keras_rs/src/metrics/ranking_metrics_utils.py | 4 ++-- keras_rs/src/testing/test_case.py | 8 +++++--- keras_rs/src/utils/tpu_test_utils.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/keras_rs/src/metrics/ranking_metrics_utils.py b/keras_rs/src/metrics/ranking_metrics_utils.py index ea3ddaaf..91712847 100644 --- a/keras_rs/src/metrics/ranking_metrics_utils.py +++ b/keras_rs/src/metrics/ranking_metrics_utils.py @@ -224,12 +224,12 @@ def get_list_weights( return final_weights -@keras.saving.register_keras_serializable() # type: ignore[misc] +@keras.saving.register_keras_serializable() # type: ignore[untyped-decorator] def default_gain_fn(label: types.Tensor) -> types.Tensor: return ops.subtract(ops.power(2.0, label), 1.0) -@keras.saving.register_keras_serializable() # type: ignore[misc] +@keras.saving.register_keras_serializable() # type: ignore[untyped-decorator] def default_rank_discount_fn(rank: types.Tensor) -> types.Tensor: return ops.divide( ops.cast(1, dtype=rank.dtype), diff --git a/keras_rs/src/testing/test_case.py b/keras_rs/src/testing/test_case.py index 29d7e2c4..7341e9eb 100644 --- a/keras_rs/src/testing/test_case.py +++ b/keras_rs/src/testing/test_case.py @@ -29,9 +29,11 @@ def setUp(self) -> None: if keras.backend.backend() == "tensorflow": tf.debugging.disable_traceback_filtering() if self.on_tpu: - scope = tpu_test_utils.get_shared_tpu_strategy().scope() - scope.__enter__() - self.addCleanup(lambda: scope.__exit__(None, None, None)) + strategy = tpu_test_utils.get_shared_tpu_strategy() + if strategy is not None: + scope = strategy.scope() + scope.__enter__() + self.addCleanup(lambda: scope.__exit__(None, None, None)) @property def strategy(self) -> StrategyType: diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index 7d5be954..ecd69973 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -152,7 +152,7 @@ def run_with_strategy( sample_weight_value = kwargs.get("sample_weight", None) all_inputs = args + (sample_weight_value,) - @tf.function(jit_compile=jit_compile) # type: ignore[misc] + @tf.function(jit_compile=jit_compile) # type: ignore[untyped-decorator] def tf_function_wrapper(input_tuple: Tuple[Any, ...]) -> Any: num_original_args = len(args) core_args = input_tuple[:num_original_args] From 3bfcb5cdb99d690d9bf5ac1f574f3eda1f4165ba Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 2 Dec 2025 21:13:52 +0000 Subject: [PATCH 28/29] resolve comments --- .gitignore | 2 +- .../dot_interaction_test.py | 1 - .../feature_interaction/feature_cross_test.py | 12 ++--- .../retrieval/hard_negative_mining_test.py | 3 -- .../retrieval/remove_accidental_hits_test.py | 3 -- .../sampling_probability_correction_test.py | 3 -- keras_rs/src/utils/tpu_test_utils.py | 49 ++----------------- 7 files changed, 9 insertions(+), 64 deletions(-) diff --git a/.gitignore b/.gitignore index 4e067da8..eacd3be8 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,4 @@ build/ # Pycharm files .idea/ -venv/ \ No newline at end of file +venv/ diff --git a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py index eb634478..da714a8a 100644 --- a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py +++ b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py @@ -135,7 +135,6 @@ def test_predict(self, self_interaction, skip_gather): )([feature1, feature2, feature3]) x = keras.layers.Dense(units=1)(x) model = keras.Model([feature1, feature2, feature3], x) - # model.compile(optimizer="adam", loss="mse") model.predict(self.input, batch_size=2) diff --git a/keras_rs/src/layers/feature_interaction/feature_cross_test.py b/keras_rs/src/layers/feature_interaction/feature_cross_test.py index 485f776e..6bb200a2 100644 --- a/keras_rs/src/layers/feature_interaction/feature_cross_test.py +++ b/keras_rs/src/layers/feature_interaction/feature_cross_test.py @@ -79,13 +79,11 @@ def test_pre_activation(self): self.assertAllClose(self.x, output) def test_predict(self): - with self.strategy.scope(): - x0 = keras.layers.Input(shape=(3,)) - x1 = FeatureCross(projection_dim=None)(x0, x0) - x2 = FeatureCross(projection_dim=None)(x0, x1) - logits = keras.layers.Dense(units=1)(x2) - model = keras.Model(x0, logits) - model.compile(optimizer="adam", loss="mse") + x0 = keras.layers.Input(shape=(3,)) + x1 = FeatureCross(projection_dim=None)(x0, x0) + x2 = FeatureCross(projection_dim=None)(x0, x1) + logits = keras.layers.Dense(units=1)(x2) + model = keras.Model(x0, logits) model.predict(self.x0, batch_size=2) diff --git a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py index 382e0742..d7ab74d0 100644 --- a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py +++ b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py @@ -9,9 +9,6 @@ class HardNegativeMiningTest(testing.TestCase, parameterized.TestCase): - def setUp(self): - super().setUp() - def create_inputs(self, rank=2): shape_3d = (15, 20, 10) shape = shape_3d[-rank:] diff --git a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py index 38f9ef27..8cb4fa71 100644 --- a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py +++ b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py @@ -9,9 +9,6 @@ class RemoveAccidentalHitsTest(testing.TestCase, parameterized.TestCase): - def setUp(self): - super().setUp() - def create_inputs(self, logits_rank=2, candidate_ids_rank=1): shape_3d = (15, 20, 10) shape = shape_3d[-logits_rank:] diff --git a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py index 629b5a66..8dc8ff73 100644 --- a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py +++ b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py @@ -11,9 +11,6 @@ class SamplingProbabilityCorrectionTest( testing.TestCase, parameterized.TestCase ): - def setUp(self): - super().setUp() - def create_inputs(self, logits_rank=2, probs_rank=1): shape_3d = (15, 20, 10) logits_shape = shape_3d[-logits_rank:] diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index ecd69973..d4822a73 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -1,19 +1,11 @@ import contextlib import os import threading -from types import ModuleType from typing import Any, Callable, ContextManager, Optional, Tuple, Union import keras import tensorflow as tf -jax: Optional[ModuleType] = None - -try: - import jax -except ImportError: - pass - class DummyStrategy: def scope(self) -> ContextManager[None]: @@ -36,8 +28,8 @@ def experimental_distribute_dataset( class JaxDummyStrategy(DummyStrategy): @property def num_replicas_in_sync(self) -> Any: - if jax is None: - return 0 + import jax + return jax.device_count("tpu") @@ -93,48 +85,13 @@ def get_shared_tpu_strategy() -> Optional[StrategyType]: ) print("### num_replicas", _shared_strategy.num_replicas_in_sync) elif keras.backend.backend() == "jax": - if jax is None: - raise ImportError( - "JAX backend requires jax to be installed for TPU." - ) - print("### num_replicas", jax.device_count("tpu")) _shared_strategy = JaxDummyStrategy() + print("### num_replicas", _shared_strategy.num_replicas_in_sync) else: _shared_strategy = DummyStrategy() - if _shared_strategy is None: - print("Failed to create the shared TPUStrategy.") return _shared_strategy -def get_tpu_strategy(test_case: Any) -> StrategyType: - """Get TPU strategy if on TPU, otherwise return DummyStrategy.""" - if "TPU_NAME" not in os.environ: - return DummyStrategy() - if keras.backend.backend() == "tensorflow": - resolver = tf.distribute.cluster_resolver.TPUClusterResolver() - tf.config.experimental_connect_to_cluster(resolver) - topology = tf.tpu.experimental.initialize_tpu_system(resolver) - tpu_metadata = resolver.get_tpu_system_metadata() - device_assignment = tf.tpu.experimental.DeviceAssignment.build( - topology, num_replicas=tpu_metadata.num_hosts - ) - strategy = tf.distribute.TPUStrategy( - resolver, experimental_device_assignment=device_assignment - ) - print("### num_replicas", strategy.num_replicas_in_sync) - test_case.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver) - return strategy - elif keras.backend.backend() == "jax": - if jax is None: - raise ImportError( - "JAX backend requires jax to be installed for TPU." - ) - print("### num_replicas", jax.device_count("tpu")) - return JaxDummyStrategy() - else: - return DummyStrategy() - - def run_with_strategy( strategy: Any, fn: Callable[..., Any], From 620e7974b2a36ebd36a0ad32fdf44690cf3388d2 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 2 Dec 2025 23:57:49 +0000 Subject: [PATCH 29/29] address new comments --- conftest.py | 26 -------------------------- keras_rs/src/testing/test_case.py | 12 +++--------- keras_rs/src/utils/tpu_test_utils.py | 19 ++++++------------- 3 files changed, 9 insertions(+), 48 deletions(-) delete mode 100644 conftest.py diff --git a/conftest.py b/conftest.py deleted file mode 100644 index f80e27f6..00000000 --- a/conftest.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Union - -import pytest -import tensorflow as tf - -from keras_rs.src.utils import tpu_test_utils - -StrategyType = Union[ - tf.distribute.Strategy, - tpu_test_utils.DummyStrategy, - tpu_test_utils.JaxDummyStrategy, -] - - -@pytest.fixture(scope="session", autouse=True) -def prime_shared_tpu_strategy() -> None: - """ - Eagerly initializes the shared TPU strategy at the beginning of the session - if running on a TPU. This helps catch initialization errors early. - """ - strategy = tpu_test_utils.get_shared_tpu_strategy() - if not strategy: - pytest.fail( - "Failed to initialize shared TPUStrategy for the test session. " - "Check logs for details from create_tpu_strategy." - ) diff --git a/keras_rs/src/testing/test_case.py b/keras_rs/src/testing/test_case.py index 7341e9eb..dae87e97 100644 --- a/keras_rs/src/testing/test_case.py +++ b/keras_rs/src/testing/test_case.py @@ -1,7 +1,7 @@ import os import tempfile import unittest -from typing import Any, Optional, Union +from typing import Any, Optional import keras import numpy as np @@ -10,12 +10,6 @@ from keras_rs.src import types from keras_rs.src.utils import tpu_test_utils -StrategyType = Union[ - tf.distribute.Strategy, - tpu_test_utils.DummyStrategy, - tpu_test_utils.JaxDummyStrategy, -] - class TestCase(unittest.TestCase): """TestCase class for all Keras Recommenders tests.""" @@ -25,7 +19,7 @@ def setUp(self) -> None: keras.utils.clear_session() keras.config.disable_traceback_filtering() self.on_tpu = "TPU_NAME" in os.environ - self._strategy: Optional[StrategyType] = None + self._strategy: Optional[tpu_test_utils.StrategyType] = None if keras.backend.backend() == "tensorflow": tf.debugging.disable_traceback_filtering() if self.on_tpu: @@ -36,7 +30,7 @@ def setUp(self) -> None: self.addCleanup(lambda: scope.__exit__(None, None, None)) @property - def strategy(self) -> StrategyType: + def strategy(self) -> tpu_test_utils.StrategyType: strat = tpu_test_utils.get_shared_tpu_strategy() if strat is None: diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py index d4822a73..4798e6e0 100644 --- a/keras_rs/src/utils/tpu_test_utils.py +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -33,7 +33,7 @@ def num_replicas_in_sync(self) -> Any: return jax.device_count("tpu") -StrategyType = Union[tf.distribute.Strategy, DummyStrategy, JaxDummyStrategy] +StrategyType = Union[tf.distribute.Strategy, DummyStrategy] _shared_strategy: Optional[StrategyType] = None _lock = threading.Lock() @@ -104,20 +104,13 @@ def run_with_strategy( entering tf.function to guarantee a fixed graph signature. """ if keras.backend.backend() == "tensorflow": - # Extract sample_weight and treat it as an explicit third positional - # argument. If not present, use a placeholder (None). - sample_weight_value = kwargs.get("sample_weight", None) - all_inputs = args + (sample_weight_value,) + all_inputs = (args, kwargs) @tf.function(jit_compile=jit_compile) # type: ignore[untyped-decorator] - def tf_function_wrapper(input_tuple: Tuple[Any, ...]) -> Any: - num_original_args = len(args) - core_args = input_tuple[:num_original_args] - sw_value = input_tuple[-1] - - if sw_value is not None: - all_positional_args = core_args + (sw_value,) - return strategy.run(fn, args=all_positional_args) + def tf_function_wrapper(input_tuple: Tuple[Any, Any]) -> Any: + core_args, core_kwargs = input_tuple + if core_kwargs: + return strategy.run(fn, args=core_args, kwargs=core_kwargs) else: return strategy.run(fn, args=core_args)