Skip to content

Commit

Permalink
Update train_embedding.py (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
DLPerf authored Aug 27, 2021
1 parent f35d964 commit 5885441
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/train_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def transform(_):
# Note: Train and test are drawn from same distribution for demonstration purposes.
# We should get near identical scores on both of them.
ds_train = ds_full.take(train_size)
ds_train = ds_train.map(transform)
ds_train = ds_train.map(transform,num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.batch(BATCH_SIZE)

ds_test = ds_full.skip(train_size)
ds_test = ds_test.map(transform)
ds_test = ds_test.map(transform,num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(BATCH_SIZE)


Expand Down

0 comments on commit 5885441

Please sign in to comment.