Merged · 29 commits, all by wenyi-guo
- ec59816  support all tests in layers to tpu  (Nov 20, 2025)
- 9d1b089  fix jax test  (Nov 23, 2025)
- 3c348df  add jax strategy print  (Nov 23, 2025)
- 86c96dc  support tpu for losses tests and some metrics tests  (Nov 24, 2025)
- 13fa93d  support all tpu tests in metrics and format  (Nov 24, 2025)
- 214d40e  update tpu test workflow  (Nov 24, 2025)
- 461a9ea  update actions.yml  (Nov 24, 2025)
- c4df618  fix cpu tf error on run with strategy taking kwargs and format  (Nov 24, 2025)
- 5691b4d  format  (Nov 24, 2025)
- e557805  fix import  (Nov 24, 2025)
- 91e82e7  fix test errors  (Nov 25, 2025)
- a756439  format  (Nov 25, 2025)
- 4e0363a  fix type  (Nov 25, 2025)
- df8c1ff  ignore long runnign tpu test  (Nov 25, 2025)
- 0b81a27  update ignore  (Nov 25, 2025)
- 376a954  clean up  (Nov 25, 2025)
- a1dedc6  revert unnecessary tpu strategy for eager  (Nov 26, 2025)
- d9a5aeb  revert more unnecessary changes and resolve comments  (Nov 26, 2025)
- b634f78  remove venv and reformat  (Nov 26, 2025)
- 58c7897  use a shared strategy in conftest.py  (Nov 26, 2025)
- ea18299  format  (Nov 26, 2025)
- dd2219f  format conftest  (Nov 26, 2025)
- 29c7f29  format import  (Nov 26, 2025)
- 836de86  format  (Nov 27, 2025)
- 16dd280  resolve comments  (Dec 1, 2025)
- 8092f3f  clean gitignore  (Dec 2, 2025)
- e8a12a6  format mypy  (Dec 2, 2025)
- 3bfcb5c  resolve comments  (Dec 2, 2025)
- 620e797  address new comments  (Dec 2, 2025)
2 changes: 1 addition & 1 deletion .github/workflows/actions.yml

@@ -90,7 +90,7 @@ jobs:
         run: python3 -c "import jax; print('JAX devices:', jax.devices())"
 
       - name: Test with pytest
-        run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py
+        run: pytest keras_rs/src/layers/
 
   check_format:
     name: Check the code format
1 change: 1 addition & 0 deletions .gitignore

@@ -20,3 +20,4 @@ build/
 .idea/
 
 venv/
+venv_tf/
101 changes: 29 additions & 72 deletions keras_rs/src/layers/embedding/distributed_embedding_test.py

@@ -1,4 +1,3 @@
-import contextlib
 import functools
 import math
 import os
@@ -14,6 +13,7 @@
 from keras_rs.src import types
 from keras_rs.src.layers.embedding import distributed_embedding
 from keras_rs.src.layers.embedding import distributed_embedding_config as config
+from keras_rs.src.utils import tpu_test_utils
 
 try:
     import jax
@@ -30,28 +30,6 @@
 SEQUENCE_LENGTH = 13
 
 
-class DummyStrategy:
-    def scope(self):
-        return contextlib.nullcontext()
-
-    @property
-    def num_replicas_in_sync(self):
-        return 1
-
-    def run(self, fn, args):
-        return fn(*args)
-
-    def experimental_distribute_dataset(self, dataset, options=None):
-        del options
-        return dataset
-
-
-class JaxDummyStrategy(DummyStrategy):
-    @property
-    def num_replicas_in_sync(self):
-        return jax.device_count("tpu")
-
-
 def ragged_bool_true(self):
     return True
 
@@ -74,46 +52,11 @@ def setUp(self):
             # FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16
             # FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16
 
-            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
-            tf.config.experimental_connect_to_cluster(resolver)
-
-            topology = tf.tpu.experimental.initialize_tpu_system(resolver)
-            tpu_metadata = resolver.get_tpu_system_metadata()
-
-            device_assignment = tf.tpu.experimental.DeviceAssignment.build(
-                topology, num_replicas=tpu_metadata.num_hosts
-            )
-            self._strategy = tf.distribute.TPUStrategy(
-                resolver, experimental_device_assignment=device_assignment
-            )
-            print("### num_replicas", self._strategy.num_replicas_in_sync)
-            self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver)
-        elif keras.backend.backend() == "jax" and self.on_tpu:
-            self._strategy = JaxDummyStrategy()
-        else:
-            self._strategy = DummyStrategy()
-
+        self._strategy = tpu_test_utils.get_tpu_strategy(self)
         self.batch_size = (
             BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync
         )
 
-    def run_with_strategy(self, fn, *args, jit_compile=False):
-        """Wrapper for running a function under a strategy."""
-
-        if keras.backend.backend() == "tensorflow":
-
-            @tf.function(jit_compile=jit_compile)
-            def tf_function_wrapper(*tf_function_args):
-                def strategy_fn(*strategy_fn_args):
-                    return fn(*strategy_fn_args)
-
-                return self._strategy.run(strategy_fn, args=tf_function_args)
-
-            return tf_function_wrapper(*args)
-        else:
-            self.assertFalse(jit_compile)
-            return fn(*args)
-
     def get_embedding_config(self, input_type, placement):
         sequence_length = 1 if input_type == "dense" else SEQUENCE_LENGTH
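
The new shared module keras_rs/src/utils/tpu_test_utils.py is imported above, but its body is not shown on this page. The helpers it must export, get_tpu_strategy and run_with_strategy, correspond almost line for line to the code this hunk deletes, so a reconstruction is straightforward. Treat the sketch below as exactly that: the strategy classes and the TPU bootstrap logic are copied from the removed test code, while the module layout, the docstrings, and the plain assert that replaces self.assertFalse are assumptions.

# Hypothetical reconstruction of keras_rs/src/utils/tpu_test_utils.py,
# reassembled from the code deleted in this diff; not the merged file.
import contextlib

import keras
import tensorflow as tf

try:
    import jax
except ImportError:
    jax = None


class DummyStrategy:
    """No-op stand-in exposing the tf.distribute strategy interface."""

    def scope(self):
        return contextlib.nullcontext()

    @property
    def num_replicas_in_sync(self):
        return 1

    def run(self, fn, args):
        return fn(*args)

    def experimental_distribute_dataset(self, dataset, options=None):
        del options
        return dataset


class JaxDummyStrategy(DummyStrategy):
    """Dummy strategy reporting one replica per JAX TPU device."""

    @property
    def num_replicas_in_sync(self):
        return jax.device_count("tpu")


def get_tpu_strategy(test_case):
    """Returns a real TPUStrategy on TF+TPU, a dummy strategy otherwise."""
    if keras.backend.backend() == "tensorflow" and test_case.on_tpu:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        topology = tf.tpu.experimental.initialize_tpu_system(resolver)
        tpu_metadata = resolver.get_tpu_system_metadata()
        device_assignment = tf.tpu.experimental.DeviceAssignment.build(
            topology, num_replicas=tpu_metadata.num_hosts
        )
        test_case.addCleanup(
            tf.tpu.experimental.shutdown_tpu_system, resolver
        )
        return tf.distribute.TPUStrategy(
            resolver, experimental_device_assignment=device_assignment
        )
    if keras.backend.backend() == "jax" and test_case.on_tpu:
        return JaxDummyStrategy()
    return DummyStrategy()


def run_with_strategy(strategy, fn, *args, jit_compile=False):
    """Runs `fn` under `strategy` on TensorFlow, directly otherwise."""
    if keras.backend.backend() == "tensorflow":

        @tf.function(jit_compile=jit_compile)
        def tf_function_wrapper(*tf_function_args):
            def strategy_fn(*strategy_fn_args):
                return fn(*strategy_fn_args)

            return strategy.run(strategy_fn, args=tf_function_args)

        return tf_function_wrapper(*args)
    # The original method asserted jit_compile is False outside TF.
    assert not jit_compile
    return fn(*args)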

Expand Down Expand Up @@ -263,7 +206,9 @@ def test_basics(self, input_type, placement):
preprocessed_inputs = layer.preprocess(inputs, weights)
res = layer(preprocessed_inputs)
else:
res = self.run_with_strategy(layer.__call__, inputs, weights)
res = tpu_test_utils.run_with_strategy(
self._strategy, layer.__call__, inputs, weights
)

if placement == "default_device" or not self.on_tpu:
# verify sublayers and variables are tracked
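
Every remaining hunk in this file repeats the same mechanical substitution: the bound test-case method becomes a module-level helper that receives the strategy explicitly. In shorthand (both forms appear verbatim in the surrounding hunks):

# Before: strategy reached implicitly through the test case.
res = self.run_with_strategy(layer.__call__, inputs, weights)

# After: strategy passed explicitly to the shared helper.
res = tpu_test_utils.run_with_strategy(
    self._strategy, layer.__call__, inputs, weights
)

Passing the strategy as an argument is what lets the helper live outside the TestCase and be reused by the losses and metrics tests mentioned in the commit history.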
@@ -422,14 +367,14 @@ def test_dataset_generator():
         model.compile(optimizer="adam", loss="mse")
 
         model_inputs, _ = next(iter(test_dataset))
-        test_output_before = self.run_with_strategy(
-            model.__call__, model_inputs
+        test_output_before = tpu_test_utils.run_with_strategy(
+            self._strategy, model.__call__, model_inputs
         )
 
         model.fit(train_dataset, steps_per_epoch=1, epochs=1)
 
-        test_output_after = self.run_with_strategy(
-            model.__call__, model_inputs
+        test_output_after = tpu_test_utils.run_with_strategy(
+            self._strategy, model.__call__, model_inputs
         )
 
         # Verify that the embedding has actually trained.
@@ -610,10 +555,16 @@ def test_correctness(
                     preprocessed,
                 )
             else:
-                res = self.run_with_strategy(layer.__call__, preprocessed)
+                res = tpu_test_utils.run_with_strategy(
+                    self._strategy, layer.__call__, preprocessed
+                )
         else:
-            res = self.run_with_strategy(
-                layer.__call__, inputs, weights, jit_compile=jit_compile
+            res = tpu_test_utils.run_with_strategy(
+                self._strategy,
+                layer.__call__,
+                inputs,
+                weights,
+                jit_compile=jit_compile,
             )
 
         self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM))
@@ -686,7 +637,9 @@ def test_shared_table(self):
         with self._strategy.scope():
             layer = distributed_embedding.DistributedEmbedding(embedding_config)
 
-        res = self.run_with_strategy(layer.__call__, inputs)
+        res = tpu_test_utils.run_with_strategy(
+            self._strategy, layer.__call__, inputs
+        )
 
         if self.placement == "default_device":
             self.assertLen(layer._flatten_layers(include_self=False), 1)
@@ -760,7 +713,9 @@ def test_mixed_placement(self):
         with self._strategy.scope():
             layer = distributed_embedding.DistributedEmbedding(embedding_config)
 
-        res = self.run_with_strategy(layer.__call__, inputs)
+        res = tpu_test_utils.run_with_strategy(
+            self._strategy, layer.__call__, inputs
+        )
 
         self.assertEqual(
             res["feature1"].shape, (self.batch_size, embedding_output_dim1)
@@ -793,13 +748,15 @@ def test_save_load_model(self):
         keras_outputs = layer(keras_inputs)
         model = keras.Model(inputs=keras_inputs, outputs=keras_outputs)
 
-        output_before = self.run_with_strategy(model.__call__, inputs)
+        output_before = tpu_test_utils.run_with_strategy(
+            self._strategy, model.__call__, inputs
+        )
         model.save(path)
 
         with self._strategy.scope():
             reloaded_model = keras.models.load_model(path)
-            output_after = self.run_with_strategy(
-                reloaded_model.__call__, inputs
+            output_after = tpu_test_utils.run_with_strategy(
+                self._strategy, reloaded_model.__call__, inputs
             )
 
         if self.placement == "sparsecore":