diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 5088f571..b125b6d1 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -89,8 +89,13 @@ jobs: if: ${{ matrix.backend == 'jax'}} run: python3 -c "import jax; print('JAX devices:', jax.devices())" - - name: Test with pytest - run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py + - name: Test with pytest (TensorFlow) + if: ${{ matrix.backend == 'tensorflow' }} + run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax + + - name: Test with pytest (JAX) + if: ${{ matrix.backend == 'jax' }} + run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax/distributed_embedding_test.py check_format: name: Check the code format diff --git a/keras_rs/src/layers/embedding/distributed_embedding_test.py b/keras_rs/src/layers/embedding/distributed_embedding_test.py index cb4df82f..27a592dc 100644 --- a/keras_rs/src/layers/embedding/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/distributed_embedding_test.py @@ -1,4 +1,3 @@ -import contextlib import functools import math import os @@ -14,6 +13,7 @@ from keras_rs.src import types from keras_rs.src.layers.embedding import distributed_embedding from keras_rs.src.layers.embedding import distributed_embedding_config as config +from keras_rs.src.utils import tpu_test_utils try: import jax @@ -30,28 +30,6 @@ SEQUENCE_LENGTH = 13 -class DummyStrategy: - def scope(self): - return contextlib.nullcontext() - - @property - def num_replicas_in_sync(self): - return 1 - - def run(self, fn, args): - return fn(*args) - - def experimental_distribute_dataset(self, dataset, options=None): - del options - return dataset - - -class JaxDummyStrategy(DummyStrategy): - @property - def num_replicas_in_sync(self): - return jax.device_count("tpu") - - def ragged_bool_true(self): return True @@ -74,46 +52,10 @@ def setUp(self): # FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16 # 
FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16 - resolver = tf.distribute.cluster_resolver.TPUClusterResolver() - tf.config.experimental_connect_to_cluster(resolver) - - topology = tf.tpu.experimental.initialize_tpu_system(resolver) - tpu_metadata = resolver.get_tpu_system_metadata() - - device_assignment = tf.tpu.experimental.DeviceAssignment.build( - topology, num_replicas=tpu_metadata.num_hosts - ) - self._strategy = tf.distribute.TPUStrategy( - resolver, experimental_device_assignment=device_assignment - ) - print("### num_replicas", self._strategy.num_replicas_in_sync) - self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver) - elif keras.backend.backend() == "jax" and self.on_tpu: - self._strategy = JaxDummyStrategy() - else: - self._strategy = DummyStrategy() - self.batch_size = ( - BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync + BATCH_SIZE_PER_CORE * self.strategy.num_replicas_in_sync ) - def run_with_strategy(self, fn, *args, jit_compile=False): - """Wrapper for running a function under a strategy.""" - - if keras.backend.backend() == "tensorflow": - - @tf.function(jit_compile=jit_compile) - def tf_function_wrapper(*tf_function_args): - def strategy_fn(*strategy_fn_args): - return fn(*strategy_fn_args) - - return self._strategy.run(strategy_fn, args=tf_function_args) - - return tf_function_wrapper(*args) - else: - self.assertFalse(jit_compile) - return fn(*args) - def get_embedding_config(self, input_type, placement): sequence_length = 1 if input_type == "dense" else SEQUENCE_LENGTH @@ -252,18 +194,18 @@ def test_basics(self, input_type, placement): if placement == "sparsecore" and not self.on_tpu: with self.assertRaisesRegex(Exception, "sparsecore"): - with self._strategy.scope(): - distributed_embedding.DistributedEmbedding(feature_configs) + distributed_embedding.DistributedEmbedding(feature_configs) return - with self._strategy.scope(): - layer = distributed_embedding.DistributedEmbedding(feature_configs) + 
layer = distributed_embedding.DistributedEmbedding(feature_configs) if keras.backend.backend() == "jax": preprocessed_inputs = layer.preprocess(inputs, weights) res = layer(preprocessed_inputs) else: - res = self.run_with_strategy(layer.__call__, inputs, weights) + res = tpu_test_utils.run_with_strategy( + self.strategy, layer.__call__, inputs, weights + ) if placement == "default_device" or not self.on_tpu: # verify sublayers and variables are tracked @@ -332,8 +274,7 @@ def test_model_fit(self, input_type, use_weights): (test_model_inputs, test_labels) ) - with self._strategy.scope(): - layer = distributed_embedding.DistributedEmbedding(feature_configs) + layer = distributed_embedding.DistributedEmbedding(feature_configs) def _create_keras_input( feature_config: config.FeatureConfig, dtype: types.DType @@ -403,7 +344,7 @@ def test_dataset_generator(): # New preprocessed data removes the `weights` component. dataset_has_weights = False else: - train_dataset = self._strategy.experimental_distribute_dataset( + train_dataset = self.strategy.experimental_distribute_dataset( train_dataset, options=tf.distribute.InputOptions( experimental_fetch_to_device=False @@ -418,19 +359,18 @@ def test_dataset_generator(): inputs=keras_model_inputs, outputs=keras_model_outputs ) - with self._strategy.scope(): - model.compile(optimizer="adam", loss="mse") + model.compile(optimizer="adam", loss="mse") - model_inputs, _ = next(iter(test_dataset)) - test_output_before = self.run_with_strategy( - model.__call__, model_inputs - ) + model_inputs, _ = next(iter(test_dataset)) + test_output_before = tpu_test_utils.run_with_strategy( + self.strategy, model.__call__, model_inputs + ) - model.fit(train_dataset, steps_per_epoch=1, epochs=1) + model.fit(train_dataset, steps_per_epoch=1, epochs=1) - test_output_after = self.run_with_strategy( - model.__call__, model_inputs - ) + test_output_after = tpu_test_utils.run_with_strategy( + self.strategy, model.__call__, model_inputs + ) # Verify that 
the embedding has actually trained. for before, after in zip( @@ -567,8 +507,7 @@ def test_correctness( if not use_weights: weights = None - with self._strategy.scope(): - layer = distributed_embedding.DistributedEmbedding(feature_config) + layer = distributed_embedding.DistributedEmbedding(feature_config) if keras.backend.backend() == "jax": preprocessed = layer.preprocess(inputs, weights) @@ -610,16 +549,21 @@ def test_correctness( preprocessed, ) else: - res = self.run_with_strategy(layer.__call__, preprocessed) + res = tpu_test_utils.run_with_strategy( + self.strategy, layer.__call__, preprocessed + ) else: - res = self.run_with_strategy( - layer.__call__, inputs, weights, jit_compile=jit_compile + res = tpu_test_utils.run_with_strategy( + self.strategy, + layer.__call__, + inputs, + weights, + jit_compile=jit_compile, ) self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM)) - with self._strategy.scope(): - tables = layer.get_embedding_tables() + tables = layer.get_embedding_tables() emb = tables["table"] @@ -683,10 +627,11 @@ def test_shared_table(self): "dense", embedding_config ) - with self._strategy.scope(): - layer = distributed_embedding.DistributedEmbedding(embedding_config) + layer = distributed_embedding.DistributedEmbedding(embedding_config) - res = self.run_with_strategy(layer.__call__, inputs) + res = tpu_test_utils.run_with_strategy( + self.strategy, layer.__call__, inputs + ) if self.placement == "default_device": self.assertLen(layer._flatten_layers(include_self=False), 1) @@ -757,10 +702,11 @@ def test_mixed_placement(self): "dense", embedding_config ) - with self._strategy.scope(): - layer = distributed_embedding.DistributedEmbedding(embedding_config) + layer = distributed_embedding.DistributedEmbedding(embedding_config) - res = self.run_with_strategy(layer.__call__, inputs) + res = tpu_test_utils.run_with_strategy( + self.strategy, layer.__call__, inputs + ) self.assertEqual( res["feature1"].shape, (self.batch_size, 
embedding_output_dim1) @@ -786,21 +732,19 @@ def test_save_load_model(self): with tempfile.TemporaryDirectory() as temp_dir: path = os.path.join(temp_dir, "model.keras") - with self._strategy.scope(): - layer = distributed_embedding.DistributedEmbedding( - feature_configs - ) - keras_outputs = layer(keras_inputs) - model = keras.Model(inputs=keras_inputs, outputs=keras_outputs) + layer = distributed_embedding.DistributedEmbedding(feature_configs) + keras_outputs = layer(keras_inputs) + model = keras.Model(inputs=keras_inputs, outputs=keras_outputs) - output_before = self.run_with_strategy(model.__call__, inputs) - model.save(path) + output_before = tpu_test_utils.run_with_strategy( + self.strategy, model.__call__, inputs + ) + model.save(path) - with self._strategy.scope(): - reloaded_model = keras.models.load_model(path) - output_after = self.run_with_strategy( - reloaded_model.__call__, inputs - ) + reloaded_model = keras.models.load_model(path) + output_after = tpu_test_utils.run_with_strategy( + self.strategy, reloaded_model.__call__, inputs + ) if self.placement == "sparsecore": self.skipTest("TODO table reloading.") diff --git a/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py b/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py index f314f887..3d8e4daa 100644 --- a/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py +++ b/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py @@ -1,4 +1,5 @@ import keras +import pytest import tensorflow as tf from absl.testing import parameterized @@ -7,6 +8,10 @@ from keras_rs.src.layers.embedding.tensorflow import config_conversion +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="Tensorflow specific test", +) class ConfigConversionTest(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters( ( diff --git a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py 
b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py index 99c38abc..da714a8a 100644 --- a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py +++ b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py @@ -12,6 +12,8 @@ class DotInteractionTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() + self.input = [ ops.array([[0.1, -4.3, 0.2, 1.1, 0.3]]), ops.array([[2.0, 3.2, -1.0, 0.0, 1.0]]), @@ -81,7 +83,12 @@ def test_call(self, self_interaction, skip_gather, exp_output_idx): self_interaction=self_interaction, skip_gather=skip_gather ) output = layer(self.input) - self.assertAllClose(output, self.exp_outputs[exp_output_idx]) + self.assertAllClose( + output, + self.exp_outputs[exp_output_idx], + tpu_atol=1e-2, + tpu_rtol=1e-2, + ) def test_invalid_input_rank(self): rank_1_input = [ops.ones((3,)), ops.ones((3,))] diff --git a/keras_rs/src/layers/feature_interaction/feature_cross_test.py b/keras_rs/src/layers/feature_interaction/feature_cross_test.py index 8724ab53..6bb200a2 100644 --- a/keras_rs/src/layers/feature_interaction/feature_cross_test.py +++ b/keras_rs/src/layers/feature_interaction/feature_cross_test.py @@ -10,6 +10,8 @@ class FeatureCrossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() + self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32") self.x = ops.array([[0.4, 0.5, 0.6]], dtype="float32") self.exp_output = ops.array([[0.55, 0.8, 1.05]]) diff --git a/keras_rs/src/losses/list_mle_loss_test.py b/keras_rs/src/losses/list_mle_loss_test.py index 3656354b..306fcd26 100644 --- a/keras_rs/src/losses/list_mle_loss_test.py +++ b/keras_rs/src/losses/list_mle_loss_test.py @@ -10,6 +10,7 @@ class ListMLELossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array( [1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32" ) diff --git a/keras_rs/src/losses/pairwise_hinge_loss_test.py 
b/keras_rs/src/losses/pairwise_hinge_loss_test.py index f5aedb20..0e82533e 100644 --- a/keras_rs/src/losses/pairwise_hinge_loss_test.py +++ b/keras_rs/src/losses/pairwise_hinge_loss_test.py @@ -10,6 +10,7 @@ class PairwiseHingeLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) diff --git a/keras_rs/src/losses/pairwise_logistic_loss_test.py b/keras_rs/src/losses/pairwise_logistic_loss_test.py index ffba4b05..f0644634 100644 --- a/keras_rs/src/losses/pairwise_logistic_loss_test.py +++ b/keras_rs/src/losses/pairwise_logistic_loss_test.py @@ -10,6 +10,7 @@ class PairwiseLogisticLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) diff --git a/keras_rs/src/losses/pairwise_mean_squared_error_test.py b/keras_rs/src/losses/pairwise_mean_squared_error_test.py index 4b93eff9..8d2912f3 100644 --- a/keras_rs/src/losses/pairwise_mean_squared_error_test.py +++ b/keras_rs/src/losses/pairwise_mean_squared_error_test.py @@ -12,6 +12,7 @@ class PairwiseMeanSquaredErrorTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) diff --git a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py index 66e7d634..6063b17b 100644 --- a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py +++ b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py @@ -12,6 +12,7 @@ class PairwiseSoftZeroOneLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 
0.0, 1.0, 3.0, 2.0]) diff --git a/keras_rs/src/metrics/dcg_test.py b/keras_rs/src/metrics/dcg_test.py index 430214ac..e3e000c5 100644 --- a/keras_rs/src/metrics/dcg_test.py +++ b/keras_rs/src/metrics/dcg_test.py @@ -19,6 +19,8 @@ def _compute_dcg(labels, ranks): class DCGTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() + self.y_true_batched = ops.array( [ [0, 0, 1, 0], diff --git a/keras_rs/src/metrics/mean_average_precision_test.py b/keras_rs/src/metrics/mean_average_precision_test.py index 9c16d25e..574e8af3 100644 --- a/keras_rs/src/metrics/mean_average_precision_test.py +++ b/keras_rs/src/metrics/mean_average_precision_test.py @@ -10,6 +10,8 @@ class MeanAveragePrecisionTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() + self.y_true_batched = ops.array( [ [0, 0, 1, 0], diff --git a/keras_rs/src/metrics/mean_reciprocal_rank_test.py b/keras_rs/src/metrics/mean_reciprocal_rank_test.py index 02940c36..5277d29d 100644 --- a/keras_rs/src/metrics/mean_reciprocal_rank_test.py +++ b/keras_rs/src/metrics/mean_reciprocal_rank_test.py @@ -10,6 +10,8 @@ class MeanReciprocalRankTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() + self.y_true_batched = ops.array( [ [0, 0, 1, 0], diff --git a/keras_rs/src/metrics/ndcg_test.py b/keras_rs/src/metrics/ndcg_test.py index 8c86e01c..155b8843 100644 --- a/keras_rs/src/metrics/ndcg_test.py +++ b/keras_rs/src/metrics/ndcg_test.py @@ -19,6 +19,8 @@ def _compute_dcg(labels, ranks): class NDCGTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() + self.y_true_batched = ops.array( [ [0, 0, 1, 0], diff --git a/keras_rs/src/metrics/precision_at_k_test.py b/keras_rs/src/metrics/precision_at_k_test.py index d83c5c9d..c1b345a0 100644 --- a/keras_rs/src/metrics/precision_at_k_test.py +++ b/keras_rs/src/metrics/precision_at_k_test.py @@ -10,6 +10,8 @@ class PrecisionAtKTest(testing.TestCase, parameterized.TestCase): def 
setUp(self): + super().setUp() + self.y_true_batched = ops.array( [ [0, 0, 1, 0], diff --git a/keras_rs/src/metrics/ranking_metrics_utils.py b/keras_rs/src/metrics/ranking_metrics_utils.py index ea3ddaaf..91712847 100644 --- a/keras_rs/src/metrics/ranking_metrics_utils.py +++ b/keras_rs/src/metrics/ranking_metrics_utils.py @@ -224,12 +224,12 @@ def get_list_weights( return final_weights -@keras.saving.register_keras_serializable() # type: ignore[misc] +@keras.saving.register_keras_serializable() # type: ignore[untyped-decorator] def default_gain_fn(label: types.Tensor) -> types.Tensor: return ops.subtract(ops.power(2.0, label), 1.0) -@keras.saving.register_keras_serializable() # type: ignore[misc] +@keras.saving.register_keras_serializable() # type: ignore[untyped-decorator] def default_rank_discount_fn(rank: types.Tensor) -> types.Tensor: return ops.divide( ops.cast(1, dtype=rank.dtype), diff --git a/keras_rs/src/metrics/recall_at_k_test.py b/keras_rs/src/metrics/recall_at_k_test.py index 1a6672ce..4c1c3d67 100644 --- a/keras_rs/src/metrics/recall_at_k_test.py +++ b/keras_rs/src/metrics/recall_at_k_test.py @@ -10,6 +10,8 @@ class RecallAtKTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() + self.y_true_batched = ops.array( [ [0, 0, 1, 0], diff --git a/keras_rs/src/testing/test_case.py b/keras_rs/src/testing/test_case.py index a764abf3..dae87e97 100644 --- a/keras_rs/src/testing/test_case.py +++ b/keras_rs/src/testing/test_case.py @@ -1,12 +1,14 @@ import os import tempfile import unittest -from typing import Any +from typing import Any, Optional import keras import numpy as np +import tensorflow as tf from keras_rs.src import types +from keras_rs.src.utils import tpu_test_utils class TestCase(unittest.TestCase): @@ -16,6 +18,31 @@ def setUp(self) -> None: super().setUp() keras.utils.clear_session() keras.config.disable_traceback_filtering() + self.on_tpu = "TPU_NAME" in os.environ + self._strategy: 
Optional[tpu_test_utils.StrategyType] = None + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + if self.on_tpu: + strategy = tpu_test_utils.get_shared_tpu_strategy() + if strategy is not None: + scope = strategy.scope() + scope.__enter__() + self.addCleanup(lambda: scope.__exit__(None, None, None)) + + @property + def strategy(self) -> tpu_test_utils.StrategyType: + strat = tpu_test_utils.get_shared_tpu_strategy() + + if strat is None: + self.fail( + "TPU environment detected, but the shared TPUStrategy is None. " + "Initialization likely failed." + ) + return strat + # if self._strategy is not None: + # return self._strategy + # self._strategy = tpu_test_utils.get_tpu_strategy(self) + # return self._strategy def assertAllClose( self, @@ -23,6 +50,8 @@ def assertAllClose( desired: types.Tensor, atol: float = 1e-6, rtol: float = 1e-6, + tpu_atol: float | None = None, + tpu_rtol: float | None = None, msg: str = "", ) -> None: """Verify that two tensors are close in value element by element. @@ -34,6 +63,11 @@ def assertAllClose( rtol: Relative tolerance. msg: Optional error message. 
""" + if tpu_atol is not None and self.on_tpu: + atol = tpu_atol + if tpu_rtol is not None and self.on_tpu: + rtol = tpu_rtol + if not isinstance(actual, np.ndarray): actual = keras.ops.convert_to_numpy(actual) if not isinstance(desired, np.ndarray): diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py new file mode 100644 index 00000000..4798e6e0 --- /dev/null +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -0,0 +1,120 @@ +import contextlib +import os +import threading +from typing import Any, Callable, ContextManager, Optional, Tuple, Union + +import keras +import tensorflow as tf + + +class DummyStrategy: + def scope(self) -> ContextManager[None]: + return contextlib.nullcontext() + + @property + def num_replicas_in_sync(self) -> int: + return 1 + + def run(self, fn: Callable[..., Any], args: Tuple[Any, ...]) -> Any: + return fn(*args) + + def experimental_distribute_dataset( + self, dataset: Any, options: Optional[Any] = None + ) -> Any: + del options + return dataset + + +class JaxDummyStrategy(DummyStrategy): + @property + def num_replicas_in_sync(self) -> Any: + import jax + + return jax.device_count("tpu") + + +StrategyType = Union[tf.distribute.Strategy, DummyStrategy] + +_shared_strategy: Optional[StrategyType] = None +_lock = threading.Lock() + + +def create_tpu_strategy() -> Optional[StrategyType]: + """Initializes the TPU system and returns a TPUStrategy.""" + print("Attempting to create TPUStrategy...") + try: + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="") + tf.config.experimental_connect_to_cluster(resolver) + tf.tpu.experimental.initialize_tpu_system(resolver) + strategy = tf.distribute.TPUStrategy(resolver) + print( + "TPUStrategy created successfully." 
+ f" Devices: {strategy.extended.num_replicas_in_sync}" + ) + return strategy + except Exception as e: + print(f"Error creating TPUStrategy: {e}") + return None + + +def get_shared_tpu_strategy() -> Optional[StrategyType]: + """ + Returns a session-wide shared strategy instance. + Creates the instance on the first call. + Returns a DummyStrategy when TPU_NAME is unset; never returns None. + """ + global _shared_strategy + if _shared_strategy is not None: + return _shared_strategy + + with _lock: + if _shared_strategy is None: + if "TPU_NAME" not in os.environ: + _shared_strategy = DummyStrategy() + return _shared_strategy + if keras.backend.backend() == "tensorflow": + resolver = tf.distribute.cluster_resolver.TPUClusterResolver() + tf.config.experimental_connect_to_cluster(resolver) + topology = tf.tpu.experimental.initialize_tpu_system(resolver) + tpu_metadata = resolver.get_tpu_system_metadata() + device_assignment = tf.tpu.experimental.DeviceAssignment.build( + topology, num_replicas=tpu_metadata.num_hosts + ) + _shared_strategy = tf.distribute.TPUStrategy( + resolver, experimental_device_assignment=device_assignment + ) + print("### num_replicas", _shared_strategy.num_replicas_in_sync) + elif keras.backend.backend() == "jax": + _shared_strategy = JaxDummyStrategy() + print("### num_replicas", _shared_strategy.num_replicas_in_sync) + else: + _shared_strategy = DummyStrategy() + return _shared_strategy + + +def run_with_strategy( + strategy: Any, + fn: Callable[..., Any], + *args: Any, + jit_compile: bool = False, + **kwargs: Any, +) -> Any: + """ + Runs `fn` under `strategy`. On TensorFlow, args and kwargs are packed into + one tuple passed to a tf.function that calls `strategy.run`. 
+ """ + if keras.backend.backend() == "tensorflow": + all_inputs = (args, kwargs) + + @tf.function(jit_compile=jit_compile) # type: ignore[untyped-decorator] + def tf_function_wrapper(input_tuple: Tuple[Any, Any]) -> Any: + core_args, core_kwargs = input_tuple + if core_kwargs: + return strategy.run(fn, args=core_args, kwargs=core_kwargs) + else: + return strategy.run(fn, args=core_args) + + return tf_function_wrapper(all_inputs) + else: + assert not jit_compile + return fn(*args, **kwargs)