Skip to content

Commit

Permalink
TypeError: build_encoder_decoder_net() takes 0 positional arguments b…
Browse files Browse the repository at this point in the history
…ut 3 were given
  • Loading branch information
cl0ver012 committed May 16, 2018
1 parent 75adb0f commit 8fc25dd
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 15 deletions.
4 changes: 2 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
import numpy as np

from data_generator import generate_trimap, random_choice, get_alpha_test
from model import build_encoder_decoder_net
from model import build_encoder_decoder
from utils import get_final_output, safe_crop

if __name__ == '__main__':
img_rows, img_cols = 320, 320
channel = 4

model_weights_path = 'models/model.62-0.0524.hdf5'
model = build_encoder_decoder_net()
model = build_encoder_decoder()
model.load_weights(model_weights_path)
print(model.summary())

Expand Down
4 changes: 2 additions & 2 deletions migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from config import channel
from model import build_encoder_decoder_net
from model import build_encoder_decoder
from vgg16 import vgg16_model


Expand Down Expand Up @@ -43,7 +43,7 @@ def migrate_model(new_model):


if __name__ == '__main__':
model = build_encoder_decoder_net()
model = build_encoder_decoder()
migrate_model(model)
print(model.summary())
model.save_weights('models/model_weights.h5')
Expand Down
8 changes: 4 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from custom_layers.unpooling_layer import Unpooling


def build_encoder_decoder_net():
def build_encoder_decoder():
# Encoder
input_tensor = Input(shape=(320, 320, 4))
x = ZeroPadding2D((1, 1))(input_tensor)
Expand Down Expand Up @@ -96,7 +96,7 @@ def build_encoder_decoder_net():
return model


def build_refinement_net(encoder_decoder):
def build_refinement(encoder_decoder):
input_tensor = encoder_decoder.input
x = encoder_decoder.output
x = Concatenate(axis=1)([input_tensor, x])
Expand All @@ -111,12 +111,12 @@ def build_refinement_net(encoder_decoder):


if __name__ == '__main__':
encoder_decoder = build_encoder_decoder_net(320, 320, 4)
encoder_decoder = build_encoder_decoder()
# input_layer = model.get_layer('input')
print(encoder_decoder.summary())
plot_model(encoder_decoder, to_file='encoder_decoder.svg', show_layer_names=True, show_shapes=True)

refinement = build_refinement_net(encoder_decoder)
refinement = build_refinement(encoder_decoder)
print(refinement.summary())
plot_model(refinement, to_file='refinement.svg', show_layer_names=True, show_shapes=True)

Expand Down
4 changes: 2 additions & 2 deletions plot_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# dependency: pip install pydot & brew install graphviz
from model import build_encoder_decoder_net
from model import build_encoder_decoder
from keras.utils import plot_model

if __name__ == '__main__':
img_rows, img_cols = 320, 320
channel = 3
model = build_encoder_decoder_net(img_rows, img_cols, channel)
model = build_encoder_decoder(img_rows, img_cols, channel)
plot_model(model, to_file='model.svg', show_layer_names=True, show_shapes=True)
4 changes: 2 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import cv2 as cv
import numpy as np

from model import build_encoder_decoder_net
from model import build_encoder_decoder

# python test.py -i "images/image.png" -t "images/trimap.png"
if __name__ == '__main__':
img_rows, img_cols = 320, 320
channel = 4

model_weights_path = 'models/model.35-0.03.hdf5'
model = build_encoder_decoder_net()
model = build_encoder_decoder()
model.load_weights(model_weights_path)
print(model.summary())

Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import migrate
from config import patience, batch_size, epochs, num_train_samples, num_valid_samples
from data_generator import train_gen, valid_gen
from model import build_encoder_decoder_net
from model import build_encoder_decoder
from utils import custom_loss_wrapper, get_available_cpus, get_available_gpus

if __name__ == '__main__':
Expand Down Expand Up @@ -57,10 +57,10 @@ def on_epoch_end(self, epoch, logs=None):
# model_checkpoint = MyCbk(model)
# else:
if pretrained_path is not None:
new_model = build_encoder_decoder_net()
new_model = build_encoder_decoder()
new_model.load_weights(pretrained_path)
else:
new_model = build_encoder_decoder_net()
new_model = build_encoder_decoder()
migrate.migrate_model(new_model)

# sgd = SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True)
Expand Down

0 comments on commit 8fc25dd

Please sign in to comment.