Skip to content

Commit

Permalink
Merge pull request #2302 from balaganesh102004:master
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 667491664
  • Loading branch information
copybara-github committed Aug 26, 2024
2 parents 57d0938 + 426ab2a commit 5c5a466
Showing 1 changed file with 67 additions and 15 deletions.
82 changes: 67 additions & 15 deletions site/en/guide/ragged_tensor.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -674,14 +674,14 @@
"source": [
"### Keras\n",
"\n",
"[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level API for building and training deep learning models. Ragged tensors may be passed as inputs to a Keras model by setting `ragged=True` on `tf.keras.Input` or `tf.keras.layers.InputLayer`. Ragged tensors may also be passed between Keras layers, and returned by Keras models. The following example shows a toy LSTM model that is trained using ragged tensors."
"[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level API for building and training deep learning models. It doesn't have ragged support. But it does support masked tensors. So the easiest way to use a ragged tensor in a Keras model is to convert the ragged tensor to a dense tensor, using `.to_tensor()` and then using Keras's builtin masking:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pHls7hQVJlk5"
"id": "ucYf2sSzTvQo"
},
"outputs": [],
"source": [
Expand All @@ -691,26 +691,77 @@
" 'She turned me into a newt.',\n",
" 'A newt?',\n",
" 'Well, I got better.'])\n",
"is_question = tf.constant([True, False, True, False])\n",
"\n",
"is_question = tf.constant([True, False, True, False])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MGYKmizJTw8B"
},
"outputs": [],
"source": [
"# Preprocess the input strings.\n",
"hash_buckets = 1000\n",
"words = tf.strings.split(sentences, ' ')\n",
"hashed_words = tf.strings.to_hash_bucket_fast(words, hash_buckets)\n",
"\n",
"hashed_words.to_list()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7FTujwOlUT8J"
},
"outputs": [],
"source": [
"hashed_words.to_tensor()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vzWudaESUBOZ"
},
"outputs": [],
"source": [
"tf.keras.Input?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pHls7hQVJlk5"
},
"outputs": [],
"source": [
"# Build the Keras model.\n",
"keras_model = tf.keras.Sequential([\n",
" tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True),\n",
" tf.keras.layers.Embedding(hash_buckets, 16),\n",
" tf.keras.layers.LSTM(32, use_bias=False),\n",
" tf.keras.layers.Embedding(hash_buckets, 16, mask_zero=True),\n",
" tf.keras.layers.LSTM(32, return_sequences=True, use_bias=False),\n",
" tf.keras.layers.GlobalAveragePooling1D(),\n",
" tf.keras.layers.Dense(32),\n",
" tf.keras.layers.Activation(tf.nn.relu),\n",
" tf.keras.layers.Dense(1)\n",
"])\n",
"\n",
"keras_model.compile(loss='binary_crossentropy', optimizer='rmsprop')\n",
"keras_model.fit(hashed_words, is_question, epochs=5)\n",
"print(keras_model.predict(hashed_words))"
"keras_model.fit(hashed_words.to_tensor(), is_question, epochs=5)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1IAjjmdTU9OU"
},
"outputs": [],
"source": [
"print(keras_model.predict(hashed_words.to_tensor()))"
]
},
{
Expand Down Expand Up @@ -799,7 +850,7 @@
"source": [
"### Datasets\n",
"\n",
"[tf.data](https://www.tensorflow.org/guide/data) is an API that enables you to build complex input pipelines from simple, reusable pieces. Its core data structure is `tf.data.Dataset`, which represents a sequence of elements, in which each element consists of one or more components. "
"[tf.data](https://www.tensorflow.org/guide/data) is an API that enables you to build complex input pipelines from simple, reusable pieces. Its core data structure is `tf.data.Dataset`, which represents a sequence of elements, in which each element consists of one or more components."
]
},
{
Expand Down Expand Up @@ -1078,9 +1129,11 @@
"import tempfile\n",
"\n",
"keras_module_path = tempfile.mkdtemp()\n",
"tf.saved_model.save(keras_model, keras_module_path)\n",
"imported_model = tf.saved_model.load(keras_module_path)\n",
"imported_model(hashed_words)"
"keras_model.save(keras_module_path+\"/my_model.keras\")\n",
"\n",
"imported_model = tf.keras.models.load_model(keras_module_path+\"/my_model.keras\")\n",
"\n",
"imported_model(hashed_words.to_tensor())"
]
},
{
Expand Down Expand Up @@ -2125,7 +2178,6 @@
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "ragged_tensor.ipynb",
"toc_visible": true
},
Expand Down

0 comments on commit 5c5a466

Please sign in to comment.