Replace jax.tree_multimap() with jax.tree_util.tree_map() #3

Open · wants to merge 1 commit into main
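jax.tree_multimap was deprecated and later removed from JAX; jax.tree_util.tree_map is the drop-in replacement and accepts any number of pytrees with matching structure. A minimal sketch of the substitution this PR applies, using hypothetical trees:

import jax

tree_a = {"w": 1.0, "b": 2.0}
tree_b = {"w": 10.0, "b": 20.0}
# Previously: jax.tree_multimap(lambda x, y: x + y, tree_a, tree_b)
print(jax.tree_util.tree_map(lambda x, y: x + y, tree_a, tree_b))  # adds matching leaves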
12 changes: 6 additions & 6 deletions Tutorial_2_JAX_HeroPro+_Colab.ipynb
@@ -407,7 +407,7 @@
},
"source": [
"another_list_of_lists = list_of_lists\n",
"print(jax.tree_multimap(lambda x, y: x+y, list_of_lists, another_list_of_lists))"
"print(jax.jax.tree_util.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists))"
],
"execution_count": null,
"outputs": []
@@ -418,10 +418,10 @@
"id": "09Pdhyh2ISb4"
},
"source": [
"# PyTrees need to have the same structure if we are to apply tree_multimap!\n",
"# PyTrees need to have the same structure if we are to apply jax.tree_util.tree_map!\n",
"another_list_of_lists = deepcopy(list_of_lists)\n",
"another_list_of_lists.append([23])\n",
"print(jax.tree_multimap(lambda x, y: x+y, list_of_lists, another_list_of_lists))"
"print(jax.jax.tree_util.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists))"
],
"execution_count": null,
"outputs": []
@@ -493,7 +493,7 @@
" # Task: analyze grads and make sure it has the same structure as params\n",
"\n",
" # SGD update\n",
" return jax.tree_multimap(\n",
" return jax.jax.tree_util.tree_map(\n",
" lambda p, g: p - lr * g, params, grads # for every leaf i.e. for every param of MLP\n",
" )"
],
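For reference, a self-contained sketch of the SGD update pattern in this hunk, with hypothetical parameter shapes, stand-in gradients, and learning rate: tree_map walks params and grads together and updates every leaf.

import jax
import jax.numpy as jnp

lr = 1e-2
params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
grads = {"w": jnp.full((3, 2), 0.1), "b": jnp.full((2,), 0.1)}  # stand-in gradients

# One SGD step applied to every leaf, i.e. to every param of the MLP.
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)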
@@ -979,7 +979,7 @@
"\n",
" # Each device performs its own SGD update, but since we start with the same params\n",
" # and synchronise gradients, the params stay in sync on each device.\n",
" new_params = jax.tree_multimap(\n",
" new_params = jax.jax.tree_util.tree_map(\n",
" lambda param, g: param - g * lr, params, grads)\n",
" \n",
" # If we were using Adam or another stateful optimizer,\n",
@@ -1316,4 +1316,4 @@
"outputs": []
}
]
-}
+}