From 2f8dab870bbd6f5391434c767be2ae97e750a233 Mon Sep 17 00:00:00 2001 From: a-mhamdi Date: Mon, 23 Dec 2024 11:39:23 +0100 Subject: [PATCH] vae --- Codes/Julia/Part-3/vae/vae.ipynb | 26 ++++++----- Codes/Julia/Part-3/vae/vae.jl | 74 +++++++++++++++++++------------- 2 files changed, 55 insertions(+), 45 deletions(-) diff --git a/Codes/Julia/Part-3/vae/vae.ipynb b/Codes/Julia/Part-3/vae/vae.ipynb index d2a9d07..8f931f5 100644 --- a/Codes/Julia/Part-3/vae/vae.ipynb +++ b/Codes/Julia/Part-3/vae/vae.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "beab227b-2d88-4b30-ac4d-52f0ebbaf1c3", "metadata": {}, "outputs": [ @@ -19,8 +19,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Julia Version 1.11.1\n", - "Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)\n", + "Julia Version 1.11.2\n", + "Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)\n", "Build Info:\n", " Official https://julialang.org/ release\n", "Platform Info:\n", @@ -30,14 +30,14 @@ " LLVM: libLLVM-16.0.6 (ORCJIT, skylake)\n", "Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)\n", "Environment:\n", - " LD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:\n", " DYLD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:\n", + " LD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:\n", " JULIA_NUM_THREADS = 8\n" ] } ], "source": [ - "versioninfo() # -> v\"1.11.1\"" + "versioninfo() # -> v\"1.11.2\"" ] }, { @@ -63,7 +63,7 @@ "metadata": {}, "outputs": [], "source": [ - "using Flux # v0.14.25\n", + "using Flux # v\"0.16.0\"\n", "using Flux: @functor\n", "using Flux: DataLoader\n", "using Flux: onecold, onehotbatch" @@ -151,7 +151,7 @@ "id": "bc40899e-5ae7-4e21-8315-cc3c9fb71571", "metadata": {}, "source": [ - "Define the `encoder` network" + "Define the `encoder` network: The encoder network should return the parameters of the _latent distribution_ (μ and σ)." ] }, { @@ -161,7 +161,6 @@ "metadata": {}, "outputs": [], "source": [ - "# The encoder network should return the parameters of the _latent distribution_ (μ and σ).\n", "struct Encoder\n", " linear\n", " μ\n", @@ -211,7 +210,7 @@ "id": "2b653535-0cfb-44e7-ab0f-363638bb0cd5", "metadata": {}, "source": [ - "Define the `decoder` network" + "Define the `decoder` network: The decoder network should return the reconstruction of the input data" ] }, { @@ -221,7 +220,6 @@ "metadata": {}, "outputs": [], "source": [ - "# The decoder network should return the reconstruction of the input data\n", "decoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Chain(\n", " Dense(latent_dim, hidden_dim, tanh),\n", " Dense(hidden_dim, input_dim)\n", @@ -308,7 +306,7 @@ " end\n", " end\n", " \n", - " md\"Save the model\"\n", + " # Save the model\n", " #=\n", " using DrWatson: struct2dict\n", " using BSON\n", @@ -337,15 +335,15 @@ ], "metadata": { "kernelspec": { - "display_name": "Julia 1.11.1", + "display_name": "IJulia 1.11.2", "language": "julia", - "name": "julia-1.11" + "name": "ijulia-1.11" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", - "version": "1.11.1" + "version": "1.11.2" } }, "nbformat": 4, diff --git a/Codes/Julia/Part-3/vae/vae.jl b/Codes/Julia/Part-3/vae/vae.jl index 4440cf0..42ccee9 100644 --- a/Codes/Julia/Part-3/vae/vae.jl +++ b/Codes/Julia/Part-3/vae/vae.jl @@ -1,19 +1,14 @@ ### A Pluto.jl notebook ### -# v0.20.3 +# v0.20.4 using Markdown using InteractiveUtils -# ╔═╡ be2eb522-922c-4cd9-9f60-3305eacef23a -################################### -#= VARIATIONAL AUTOENCODER (VAE) =# -################################### -# `versioninfo()` -> 1.11.1 - -using Markdown +# ╔═╡ 4c4e5a73-dfb5-4925-9b3e-5100879feee6 + import Pkg; Pkg.activate(".") # ╔═╡ 5431e024-d2ae-4279-bf9e-c72e1455c7bc -using Flux # v0.14.25 +using Flux # v"0.16.0" # ╔═╡ 1d0e8c52-3758-4520-afde-2c7cadd3b427 using Flux: @functor @@ -30,8 +25,14 @@ using ProgressMeter: Progress, next! # ╔═╡ 165e4355-ada6-4c29-9687-01ca368735a2 using MLDatasets +# ╔═╡ be2eb522-922c-4cd9-9f60-3305eacef23a +md"# VARIATIONAL AUTOENCODER (VAE)" + +# ╔═╡ aae8ee36-56b0-4594-9454-e0fb13bb9e32 +versioninfo() # -> v"1.11.2" + # ╔═╡ 6cdf0ae0-ac6b-4938-9598-e834cad5a94f -md"VAE implemented in `Julia` using the `Flux.jl` library" +md"**VAE** implemented in `Julia` using the `Flux.jl` library" # ╔═╡ a294646d-fc11-472e-9b3a-f7567387a373 md"Import the machine learning library `Flux`" @@ -72,11 +73,23 @@ train_loader = get_data(); test_loader = get_data(split=:test); # ╔═╡ 52d26401-72b9-421f-85bb-42cab7acf738 -md"Define the `encoder` network" -# The encoder network should return the parameters of the _latent distribution_ (μ and σ). +md"Define the `encoder` network: The encoder network should return the parameters of the _latent distribution_ (μ and σ)." -# ╔═╡ 166f94e0-e865-4bb8-96e8-0383180db4a3 -@functor Encoder +# ╔═╡ 1f68c46b-443f-404a-b872-f2ac0f5b0b0f +Encoder = begin + struct Encoder + linear + μ + log_σ + end + + @functor Encoder + + function (encoder::Encoder)(x) + h = encoder.linear(x) + encoder.μ(h), encoder.log_σ(h) + end +end # ╔═╡ 8940c576-206d-42e2-bb3f-12ecedf963c2 encoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Encoder( @@ -86,8 +99,7 @@ encoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Encoder( ) # ╔═╡ b413a818-19f5-4d89-b8de-04bf085a2ffb -md"Define the `decoder` network" -# The decoder network should return the reconstruction of the input data +md"Define the `decoder` network: The decoder network should return the reconstruction of the input data" # ╔═╡ f2b696d0-06c5-4570-8419-aa921609500c decoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Chain( @@ -150,7 +162,7 @@ function train(; kws...) end end - md"Save the model" + # Save the model #= using DrWatson: struct2dict using BSON @@ -168,22 +180,23 @@ end # ╔═╡ 73aec761-ab49-4235-9d52-3c80419de7e2 enc_model, dec_model = train() -# ╔═╡ c4f262cb-938a-457c-807f-e17b41caee5e -function (encoder::Encoder)(x) - h = encoder.linear(x) - encoder.μ(h), encoder.log_σ(h) -end - -# ╔═╡ 1f68c46b-443f-404a-b872-f2ac0f5b0b0f -struct Encoder - linear - μ - log_σ -end +# ╔═╡ a43c274c-81fd-4d5d-a37c-858ead7dfd01 +html""" + +""" # ╔═╡ Cell order: # ╠═be2eb522-922c-4cd9-9f60-3305eacef23a +# ╠═aae8ee36-56b0-4594-9454-e0fb13bb9e32 # ╠═6cdf0ae0-ac6b-4938-9598-e834cad5a94f +# ╠═4c4e5a73-dfb5-4925-9b3e-5100879feee6 # ╠═a294646d-fc11-472e-9b3a-f7567387a373 # ╠═5431e024-d2ae-4279-bf9e-c72e1455c7bc # ╠═1d0e8c52-3758-4520-afde-2c7cadd3b427 @@ -199,9 +212,7 @@ end # ╠═6a0cc987-638e-45cf-9f8c-e7cdbbfa189b # ╠═52d26401-72b9-421f-85bb-42cab7acf738 # ╠═1f68c46b-443f-404a-b872-f2ac0f5b0b0f -# ╠═166f94e0-e865-4bb8-96e8-0383180db4a3 # ╠═8940c576-206d-42e2-bb3f-12ecedf963c2 -# ╠═c4f262cb-938a-457c-807f-e17b41caee5e # ╠═b413a818-19f5-4d89-b8de-04bf085a2ffb # ╠═f2b696d0-06c5-4570-8419-aa921609500c # ╠═bf736186-b254-42be-a2ec-a6d27af51b6e @@ -209,3 +220,4 @@ end # ╠═db131d6e-c245-46cb-a292-b14c44a845b7 # ╠═91d78390-28d4-456b-9a24-8708e9172233 # ╠═73aec761-ab49-4235-9d52-3c80419de7e2 +# ╟─a43c274c-81fd-4d5d-a37c-858ead7dfd01