Skip to content

Commit 3e46b8e

Browse files
authored
fix: correct error message (#193)
* fix: correct error message It was supposed to print the received array but it was not doing it. Also, the type is now displayed. * review: fix format with pyink
1 parent 26908e9 commit 3e46b8e

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

jetstream_pt/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
flags.DEFINE_bool("enable_model_warmup", False, "enable model warmup")
4141

4242

43+
4344
def shard_weights(env, weights, weight_shardings):
4445
"""Shard weights according to weight_shardings"""
4546
sharded = {}

jetstream_pt/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def prefill(
310310
else:
311311
raise TypeError(
312312
"Input tokens should be of type Jax Array, but receiving:"
313-
" {prefill_inputs}"
313+
f" {prefill_inputs} of type {type(prefill_inputs)}"
314314
)
315315
seq_len = padded_tokens.shape[0]
316316
input_indexes = jnp.arange(0, seq_len)

0 commit comments

Comments
 (0)