From f32412cb611ac41917bfbc7024fd51c0149f35f6 Mon Sep 17 00:00:00 2001 From: a-mhamdi Date: Wed, 25 Dec 2024 02:14:13 +0100 Subject: [PATCH] vae --- Codes/Julia/Part-3/vae/vae.ipynb | 271 +++++++++++++++++++++++++++---- Codes/Julia/Part-3/vae/vae.jl | 15 +- 2 files changed, 241 insertions(+), 45 deletions(-) diff --git a/Codes/Julia/Part-3/vae/vae.ipynb b/Codes/Julia/Part-3/vae/vae.ipynb index 8f931f5..38ade51 100644 --- a/Codes/Julia/Part-3/vae/vae.ipynb +++ b/Codes/Julia/Part-3/vae/vae.ipynb @@ -9,6 +9,14 @@ "---" ] }, + { + "cell_type": "markdown", + "id": "58bdb6b7-1714-405f-8112-699ff9c37c6a", + "metadata": {}, + "source": [ + "VAE implemented in `Julia` using the `Flux.jl` library" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -30,8 +38,8 @@ " LLVM: libLLVM-16.0.6 (ORCJIT, skylake)\n", "Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)\n", "Environment:\n", - " DYLD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:\n", " LD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:\n", + " DYLD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:\n", " JULIA_NUM_THREADS = 8\n" ] } @@ -41,11 +49,35 @@ ] }, { - "cell_type": "markdown", - "id": "58bdb6b7-1714-405f-8112-699ff9c37c6a", + "cell_type": "code", + "execution_count": 2, + "id": "d9d8e2d6-eb3d-4446-a59a-dc8927ea4449", "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `~/Work/git-repos/AI-ML-DL/jlai/Codes/Julia/Part-3/vae`\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m\u001b[1mStatus\u001b[22m\u001b[39m `~/Work/git-repos/AI-ML-DL/jlai/Codes/Julia/Part-3/vae/Project.toml`\n", + " \u001b[90m[587475ba] \u001b[39mFlux v0.16.0\n", + " \u001b[90m[eb30cadb] \u001b[39mMLDatasets v0.7.18\n", + " \u001b[90m[91a5bcdd] \u001b[39mPlots v1.40.9\n", + " \u001b[90m[c3e4b0f8] \u001b[39mPluto v0.20.4\n", + " \u001b[90m[7f904dfe] \u001b[39mPlutoUI v0.7.60\n", + " \u001b[90m[92933f4c] \u001b[39mProgressMeter v1.10.2\n", + " \u001b[90m[d6f4376e] \u001b[39mMarkdown v1.11.0\n" + ] + } + ], "source": [ - "VAE implemented in `Julia` using the `Flux.jl` library" + "using Pkg; pkg\"activate .\"; pkg\"status\"" ] }, { @@ -58,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "127ab0b7-5b37-4265-a66e-4c35d7df1ba6", "metadata": {}, "outputs": [], @@ -71,7 +103,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "ed08fb17-bcad-4ee8-ab55-0a11588804ea", "metadata": {}, "outputs": [], @@ -81,10 +113,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "f7c37037-9d7e-4e72-8346-f8328326c610", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "dataset MNIST:\n", + " metadata => Dict{String, Any} with 3 entries\n", + " split => :train\n", + " features => 28×28×60000 Array{Float32, 3}\n", + " targets => 60000-element Vector{Int64}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "using MLDatasets\n", "d = MNIST()" @@ -92,10 +139,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "74719038-eea2-4101-b795-71f3210c040c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "HyperParams" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "Base.@kwdef mutable struct HyperParams\n", " η = 3f-3 # Learning rate\n", @@ -120,14 +178,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "bc126bdd-1dfd-44c1-b6fb-bfa019516e44", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "get_data (generic function with 1 method)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "function get_data(; kws...)\n", " args = HyperParams(; kws...);\n", - " md\"Split data\"\n", + " # Split data\n", " data = MNIST(split=args.split);\n", " X = reshape(data.features, (args.input_dim, :));\n", " loader = DataLoader(X; batchsize=args.batchsize, shuffle=true);\n", @@ -137,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "5e8eeff3-a9db-42de-9005-aa5cc2adfa13", "metadata": {}, "outputs": [], @@ -156,7 +225,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "79fcdbea-c6d4-4d51-ab1f-e75fb33940ec", "metadata": {}, "outputs": [], @@ -170,20 +239,43 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "9f1c171e-a2f6-4a65-9843-6f7f5ec3a46c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[33m\u001b[1mWarning: \u001b[22m\u001b[39mThe use of `Flux.@functor` is deprecated.\n", + "\u001b[33m\u001b[1m│ \u001b[22m\u001b[39mMost likely, you should write `Flux.@layer MyLayer`which will add various convenience methods for your type,such as pretty-printing and use with Adapt.jl.\n", + "\u001b[33m\u001b[1m│ \u001b[22m\u001b[39mHowever, this is not required. Flux.jl v0.15 uses Functors.jl v0.5,which makes exploration of most nested `struct`s opt-out instead of opt-in...so Flux will automatically see inside any custom struct definitions.\n", + "\u001b[33m\u001b[1m│ \u001b[22m\u001b[39mIf you really want to apply the `@functor` macro to a custom struct, use `Functors.@functor` instead.\n", + "\u001b[33m\u001b[1m└ \u001b[22m\u001b[39m\u001b[90m@ Flux ~/.julia/packages/Flux/Mhg1r/src/deprecations.jl:101\u001b[39m\n" + ] + } + ], "source": [ "@functor Encoder" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "5f6dfed6-ee27-4d16-89da-510dde2e132d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "encoder (generic function with 1 method)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "encoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Encoder(\n", " Dense(input_dim, hidden_dim, tanh), # linear\n", @@ -194,7 +286,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "938a17e8-4a49-46ab-815a-73fd70cd2e46", "metadata": {}, "outputs": [], @@ -215,10 +307,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "6bc6817b-1f4a-4d65-89ca-8d9ede2f4b11", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "decoder (generic function with 1 method)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "decoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Chain(\n", " Dense(latent_dim, hidden_dim, tanh),\n", @@ -236,10 +339,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "9ae574d6-17b2-4aba-8749-09d2b84bf95c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "vae (generic function with 1 method)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "function vae(x, enc, dec)\n", " # Encode `x` into the latent space\n", @@ -255,10 +369,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "99660b4d-c3ce-469c-bb9c-fdd86a3feb13", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "l (generic function with 1 method)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "function l(x, enc, dec, λ)\n", " μ, log_σ, x̂ = vae(x, enc, dec)\n", @@ -276,10 +401,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "e8ace135-e555-4266-a007-ae7c91b3adc7", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "train (generic function with 1 method)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "function train(; kws...)\n", " args = HyperParams(; kws...)\n", @@ -324,10 +460,83 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "94970592-419d-4779-9f21-7a7f626503f8", - "metadata": {}, - "outputs": [], + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[33m\u001b[1mWarning: \u001b[22m\u001b[39mProgressMeter by default refresh meters with additional information in IJulia via `IJulia.clear_output`, which clears all outputs in the cell. \n", + "\u001b[33m\u001b[1m│ \u001b[22m\u001b[39m - To prevent this behaviour, do `ProgressMeter.ijulia_behavior(:append)`. \n", + "\u001b[33m\u001b[1m│ \u001b[22m\u001b[39m - To disable this warning message, do `ProgressMeter.ijulia_behavior(:clear)`.\n", + "\u001b[33m\u001b[1m└ \u001b[22m\u001b[39m\u001b[90m@ ProgressMeter ~/.julia/packages/ProgressMeter/kVZZH/src/ProgressMeter.jl:594\u001b[39m\n", + "\u001b[32mProgress: 15%|██████▎ | ETA: 0:12:19\u001b[39m\n", + "\u001b[34m loss: 192.82417\u001b[39m" + ] + }, + { + "ename": "LoadError", + "evalue": "InterruptException:", + "output_type": "error", + "traceback": [ + "InterruptException:", + "", + "Stacktrace:", + " [1] _fast_broadcast!(f::ComposedFunction{typeof(identity), typeof(+)}, x::Matrix{Float32}, yz::Vector{Float32})", + " @ NNlib ~/.julia/packages/NNlib/mRRJu/src/utils.jl:131", + " [2] bias_act!", + " @ ~/.julia/packages/NNlib/mRRJu/src/bias_act.jl:32 [inlined]", + " [3] rrule", + " @ ~/.julia/packages/NNlib/mRRJu/src/bias_act.jl:101 [inlined]", + " [4] chain_rrule", + " @ ~/.julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:224 [inlined]", + " [5] macro expansion", + " @ ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0 [inlined]", + " [6] _pullback", + " @ ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:87 [inlined]", + " [7] Dense", + " @ ~/.julia/packages/Flux/Mhg1r/src/layers/basic.jl:199 [inlined]", + " [8] _pullback(ctx::Zygote.Context{false}, f::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, args::Matrix{Float32})", + " @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0", + " [9] _applychain", + " @ ~/.julia/packages/Flux/Mhg1r/src/layers/basic.jl:68 [inlined]", + " [10] Chain", + " @ ~/.julia/packages/Flux/Mhg1r/src/layers/basic.jl:65 [inlined]", + " [11] _pullback(ctx::Zygote.Context{false}, f::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, args::Matrix{Float32})", + " @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0", + " [12] vae", + " @ ./In[14]:7 [inlined]", + " [13] _pullback(::Zygote.Context{false}, ::typeof(vae), ::Matrix{Float32}, ::Encoder, ::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})", + " @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0", + " [14] l", + " @ ./In[15]:2 [inlined]", + " [15] _pullback(::Zygote.Context{false}, ::typeof(l), ::Matrix{Float32}, ::Encoder, ::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, ::Float32)", + " @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0", + " [16] #9", + " @ ./In[16]:17 [inlined]", + " [17] _pullback(::Zygote.Context{false}, ::var\"#9#10\"{HyperParams, Matrix{Float32}}, ::Encoder, ::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})", + " @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0", + " [18] pullback(::Function, ::Zygote.Context{false}, ::Encoder, ::Vararg{Any})", + " @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:90", + " [19] pullback", + " @ ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:88 [inlined]", + " [20] train(; kws::@Kwargs{})", + " @ Main ./In[16]:16", + " [21] train()", + " @ Main ./In[16]:1", + " [22] top-level scope", + " @ In[17]:1" + ] + } + ], "source": [ "enc_model, dec_model = train()" ] diff --git a/Codes/Julia/Part-3/vae/vae.jl b/Codes/Julia/Part-3/vae/vae.jl index 42ccee9..d5cc165 100644 --- a/Codes/Julia/Part-3/vae/vae.jl +++ b/Codes/Julia/Part-3/vae/vae.jl @@ -5,7 +5,7 @@ using Markdown using InteractiveUtils # ╔═╡ 4c4e5a73-dfb5-4925-9b3e-5100879feee6 - import Pkg; Pkg.activate(".") + import Pkg; Pkg.activate("."); Pkg.status() # ╔═╡ 5431e024-d2ae-4279-bf9e-c72e1455c7bc using Flux # v"0.16.0" @@ -180,18 +180,6 @@ end # ╔═╡ 73aec761-ab49-4235-9d52-3c80419de7e2 enc_model, dec_model = train() -# ╔═╡ a43c274c-81fd-4d5d-a37c-858ead7dfd01 -html""" - -""" - # ╔═╡ Cell order: # ╠═be2eb522-922c-4cd9-9f60-3305eacef23a # ╠═aae8ee36-56b0-4594-9454-e0fb13bb9e32 @@ -220,4 +208,3 @@ html""" # ╠═db131d6e-c245-46cb-a292-b14c44a845b7 # ╠═91d78390-28d4-456b-9a24-8708e9172233 # ╠═73aec761-ab49-4235-9d52-3c80419de7e2 -# ╟─a43c274c-81fd-4d5d-a37c-858ead7dfd01