Skip to content

Commit

Permalink
vae
Browse files Browse the repository at this point in the history
  • Loading branch information
a-mhamdi committed Dec 23, 2024
1 parent fa96da8 commit 2f8dab8
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 45 deletions.
26 changes: 12 additions & 14 deletions Codes/Julia/Part-3/vae/vae.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "beab227b-2d88-4b30-ac4d-52f0ebbaf1c3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Julia Version 1.11.1\n",
"Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)\n",
"Julia Version 1.11.2\n",
"Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)\n",
"Build Info:\n",
" Official https://julialang.org/ release\n",
"Platform Info:\n",
Expand All @@ -30,14 +30,14 @@
" LLVM: libLLVM-16.0.6 (ORCJIT, skylake)\n",
"Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)\n",
"Environment:\n",
" LD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:\n",
" DYLD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:\n",
" LD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:\n",
" JULIA_NUM_THREADS = 8\n"
]
}
],
"source": [
"versioninfo() # -> v\"1.11.1\""
"versioninfo() # -> v\"1.11.2\""
]
},
{
Expand All @@ -63,7 +63,7 @@
"metadata": {},
"outputs": [],
"source": [
"using Flux # v0.14.25\n",
"using Flux # v\"0.16.0\"\n",
"using Flux: @functor\n",
"using Flux: DataLoader\n",
"using Flux: onecold, onehotbatch"
Expand Down Expand Up @@ -151,7 +151,7 @@
"id": "bc40899e-5ae7-4e21-8315-cc3c9fb71571",
"metadata": {},
"source": [
"Define the `encoder` network"
"Define the `encoder` network: The encoder network should return the parameters of the _latent distribution_ (μ and σ)."
]
},
{
Expand All @@ -161,7 +161,6 @@
"metadata": {},
"outputs": [],
"source": [
"# The encoder network should return the parameters of the _latent distribution_ (μ and σ).\n",
"struct Encoder\n",
" linear\n",
" μ\n",
Expand Down Expand Up @@ -211,7 +210,7 @@
"id": "2b653535-0cfb-44e7-ab0f-363638bb0cd5",
"metadata": {},
"source": [
"Define the `decoder` network"
"Define the `decoder` network: The decoder network should return the reconstruction of the input data"
]
},
{
Expand All @@ -221,7 +220,6 @@
"metadata": {},
"outputs": [],
"source": [
"# The decoder network should return the reconstruction of the input data\n",
"decoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Chain(\n",
" Dense(latent_dim, hidden_dim, tanh),\n",
" Dense(hidden_dim, input_dim)\n",
Expand Down Expand Up @@ -308,7 +306,7 @@
" end\n",
" end\n",
" \n",
" md\"Save the model\"\n",
" # Save the model\n",
" #=\n",
" using DrWatson: struct2dict\n",
" using BSON\n",
Expand Down Expand Up @@ -337,15 +335,15 @@
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.11.1",
"display_name": "IJulia 1.11.2",
"language": "julia",
"name": "julia-1.11"
"name": "ijulia-1.11"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.11.1"
"version": "1.11.2"
}
},
"nbformat": 4,
Expand Down
74 changes: 43 additions & 31 deletions Codes/Julia/Part-3/vae/vae.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
### A Pluto.jl notebook ###
# v0.20.3
# v0.20.4

using Markdown
using InteractiveUtils

# ╔═╡ be2eb522-922c-4cd9-9f60-3305eacef23a
###################################
#= VARIATIONAL AUTOENCODER (VAE) =#
###################################
# `versioninfo()` -> 1.11.1

using Markdown
# ╔═╡ 4c4e5a73-dfb5-4925-9b3e-5100879feee6
import Pkg; Pkg.activate(".")

# ╔═╡ 5431e024-d2ae-4279-bf9e-c72e1455c7bc
using Flux # v0.14.25
using Flux # v"0.16.0"

# ╔═╡ 1d0e8c52-3758-4520-afde-2c7cadd3b427
using Flux: @functor
Expand All @@ -30,8 +25,14 @@ using ProgressMeter: Progress, next!
# ╔═╡ 165e4355-ada6-4c29-9687-01ca368735a2
using MLDatasets

# ╔═╡ be2eb522-922c-4cd9-9f60-3305eacef23a
md"# VARIATIONAL AUTOENCODER (VAE)"

# ╔═╡ aae8ee36-56b0-4594-9454-e0fb13bb9e32
versioninfo() # -> v"1.11.2"

# ╔═╡ 6cdf0ae0-ac6b-4938-9598-e834cad5a94f
md"VAE implemented in `Julia` using the `Flux.jl` library"
md"**VAE** implemented in `Julia` using the `Flux.jl` library"

# ╔═╡ a294646d-fc11-472e-9b3a-f7567387a373
md"Import the machine learning library `Flux`"
Expand Down Expand Up @@ -72,11 +73,23 @@ train_loader = get_data();
test_loader = get_data(split=:test);

# ╔═╡ 52d26401-72b9-421f-85bb-42cab7acf738
md"Define the `encoder` network"
# The encoder network should return the parameters of the _latent distribution_ (μ and σ).
md"Define the `encoder` network: The encoder network should return the parameters of the _latent distribution_ (μ and σ)."

# ╔═╡ 166f94e0-e865-4bb8-96e8-0383180db4a3
@functor Encoder
# ╔═╡ 1f68c46b-443f-404a-b872-f2ac0f5b0b0f
Encoder = begin
struct Encoder
linear
μ
log_σ
end

@functor Encoder

function (encoder::Encoder)(x)
h = encoder.linear(x)
encoder.μ(h), encoder.log_σ(h)
end
end

# ╔═╡ 8940c576-206d-42e2-bb3f-12ecedf963c2
encoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Encoder(
Expand All @@ -86,8 +99,7 @@ encoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Encoder(
)

# ╔═╡ b413a818-19f5-4d89-b8de-04bf085a2ffb
md"Define the `decoder` network"
# The decoder network should return the reconstruction of the input data
md"Define the `decoder` network: The decoder network should return the reconstruction of the input data"

# ╔═╡ f2b696d0-06c5-4570-8419-aa921609500c
decoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Chain(
Expand Down Expand Up @@ -150,7 +162,7 @@ function train(; kws...)
end
end

md"Save the model"
# Save the model
#=
using DrWatson: struct2dict
using BSON
Expand All @@ -168,22 +180,23 @@ end
# ╔═╡ 73aec761-ab49-4235-9d52-3c80419de7e2
enc_model, dec_model = train()

# ╔═╡ c4f262cb-938a-457c-807f-e17b41caee5e
function (encoder::Encoder)(x)
h = encoder.linear(x)
encoder.μ(h), encoder.log_σ(h)
end

# ╔═╡ 1f68c46b-443f-404a-b872-f2ac0f5b0b0f
struct Encoder
linear
μ
log_σ
end
# ╔═╡ a43c274c-81fd-4d5d-a37c-858ead7dfd01
html"""
<style>
main {
margin: 0 auto;
max-width: 2000px;
padding-left: max(160px, 10%);
padding-right: max(160px, 10%);
}
</style>
"""

# ╔═╡ Cell order:
# ╠═be2eb522-922c-4cd9-9f60-3305eacef23a
# ╠═aae8ee36-56b0-4594-9454-e0fb13bb9e32
# ╠═6cdf0ae0-ac6b-4938-9598-e834cad5a94f
# ╠═4c4e5a73-dfb5-4925-9b3e-5100879feee6
# ╠═a294646d-fc11-472e-9b3a-f7567387a373
# ╠═5431e024-d2ae-4279-bf9e-c72e1455c7bc
# ╠═1d0e8c52-3758-4520-afde-2c7cadd3b427
Expand All @@ -199,13 +212,12 @@ end
# ╠═6a0cc987-638e-45cf-9f8c-e7cdbbfa189b
# ╠═52d26401-72b9-421f-85bb-42cab7acf738
# ╠═1f68c46b-443f-404a-b872-f2ac0f5b0b0f
# ╠═166f94e0-e865-4bb8-96e8-0383180db4a3
# ╠═8940c576-206d-42e2-bb3f-12ecedf963c2
# ╠═c4f262cb-938a-457c-807f-e17b41caee5e
# ╠═b413a818-19f5-4d89-b8de-04bf085a2ffb
# ╠═f2b696d0-06c5-4570-8419-aa921609500c
# ╠═bf736186-b254-42be-a2ec-a6d27af51b6e
# ╠═7dc4a830-9289-49bb-a690-32a9e82ff262
# ╠═db131d6e-c245-46cb-a292-b14c44a845b7
# ╠═91d78390-28d4-456b-9a24-8708e9172233
# ╠═73aec761-ab49-4235-9d52-3c80419de7e2
# ╟─a43c274c-81fd-4d5d-a37c-858ead7dfd01

0 comments on commit 2f8dab8

Please sign in to comment.