Skip to content

Commit c7cae51

Browse files
committed
autoencoder
1 parent d4cd5ee commit c7cae51

File tree

5 files changed

+90
-19
lines changed

5 files changed

+90
-19
lines changed

src/Autoencoder/main.jl renamed to src/Autoencoder/Autoencoder.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ using Images
33
using Flux.Data, MLDatasets
44
using Flux.Data:DataLoader
55
using Noise
6+
7+
include("train.jl")
8+
69
device = cpu # where will the calculations be performed?
710

811
x_train, y_train = MNIST.traindata(Float32);
@@ -11,13 +14,10 @@ x_test, y_test = MNIST.testdata(Float32);
1114
m, n = size(x_test[:,:,1])
1215

1316
x_train = reshape(x_train, (m*n,size(x_train)[3]))
14-
# add some salt pepper noise to the image
15-
x_train_noise = add_gauss(x_train, 0.05)
1617
x_test = reshape(x_test, (m*n,size(x_test)[3]))
1718

1819
loader = DataLoader((data = x_train, label = x_train), batchsize=512, shuffle=true)
1920

20-
f
2121

2222
autoencoder1 = Chain(
2323
#ENCODER
@@ -65,6 +65,6 @@ params1 = Flux.params(autoencoder1)
6565
params2 = Flux.params(autoencoder2)
6666
params3 = Flux.params(autoencoder3) # parameters
6767

68-
train!(loss, params1, opt, loader, 10)
69-
train!(loss, params2, opt, loader, 35)
70-
train!(loss, params3, opt, loader, 35)
68+
# train!(loss,1 params1, opt, loader, 10)
69+
# train!(loss2, params2, opt, loader, 35)
70+
train!(loss3, params3, opt, loader, 35)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using Flux
2+
using Images
3+
using Flux.Data, MLDatasets
4+
using Flux.Data:DataLoader
5+
using Noise
6+
7+
include("train.jl")
8+
device = cpu
9+
10+
x_train, y_train = CIFAR10.traindata(Float32);
11+
x_test, y_test = CIFAR10.testdata(Float32);
12+
13+
14+
height, width, channels, number_of_pictures = size(x_train)
15+
16+
x_train = reshape(x_train, (m*n,size(x_train)[3]))
17+
x_test = reshape(x_test, (m*n,size(x_test)[3]))

src/Autoencoder/DAE.jl

Lines changed: 0 additions & 7 deletions
This file was deleted.
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# DENOISING AUTOENCODER
2+
using Flux
3+
using Images
4+
using Flux.Data, MLDatasets
5+
using Flux.Data:DataLoader
6+
using Noise
7+
8+
include("train.jl")
9+
device = cpu
10+
11+
device = cpu # where will the calculations be performed?
12+
13+
x_train, y_train = MNIST.traindata(Float32);
14+
x_test, y_test = MNIST.testdata(Float32);
15+
16+
x_train = 0.299 * x_train[:,:,1,:] + 0.587* x_train[:,:,2,:] + 0.114 * x_train[:,:,3,:]
17+
x_test = 0.299 * x_test[:,:,1,:] + 0.587* x_test[:,:,2,:] + 0.114 * x_test[:,:,3,:]
18+
19+
m, n = size(x_test[:,:,1])
20+
21+
x_train = reshape(x_train, (m*n,size(x_train)[3]))
22+
x_test = reshape(x_test, (m*n,size(x_test)[3]))
23+
24+
# add some salt pepper noise to the image
25+
x_train_noise = add_gauss(x_train, 0.15)
26+
x_test_noise = add_gauss(x_test, 0.15)
27+
loader = DataLoader((data = x_train_noise, label = x_train), batchsize=512, shuffle=true)
28+
29+
input_size = m*n
30+
31+
DAE1 = Chain(
32+
#ENCODER
33+
Dense(input_size, input_size),
34+
BatchNorm(input_size, relu),
35+
36+
Dense(input_size, input_size),
37+
38+
Dense(input_size, input_size),
39+
BatchNorm(input_size, relu),
40+
41+
Dense(input_size, input_size, sigmoid)
42+
43+
) |> device
44+
45+
46+
DAE2 = Chain(
47+
#ENCODER
48+
Dense(input_size, 128),
49+
BatchNorm(128, relu),
50+
51+
Dense(128, 16),
52+
53+
Dense(16,128),
54+
55+
# Dense(input_size, input_size),
56+
BatchNorm(128, relu),
57+
58+
Dense(128, input_size, sigmoid)
59+
60+
) |> device
61+
62+
loss(x, y) = Flux.Losses.mse(DAE2(x), y)
63+
64+
dae_params = Flux.params(DAE2)
65+
optim = ADAM(0.05)
66+
67+
train!(loss, dae_params, optim, loader, 35)

src/Autoencoder/train.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,3 @@ function train!(model_loss, model_params, opt, loader, epochs = 10)
1919
end
2020
"Total train steps: $train_steps" |> println
2121
end
22-
23-
function fit(model, data, optimizer, loss_func)
24-
Loss(x,y) = loss_func(x,y)
25-
model_params = Flux.params(model)
26-
27-
end

0 commit comments

Comments
 (0)