Commit: fix
nmcardoso committed Aug 19, 2023
1 parent 71d95ee commit efef510
Showing 1 changed file with 10 additions and 10 deletions.
mergernet/estimators/base.py (10 additions & 10 deletions)

@@ -253,24 +253,24 @@ def get_scheduler(self, scheduler: str, lr: float) -> tf.keras.optimizers.schedu
     if scheduler == 'cosine_restarts':
       return tf.keras.optimizers.schedules.CosineDecayRestarts(
         initial_learning_rate=lr,
-        first_decay_steps=self.hp.get('lr_decay_steps', 40),
-        t_mul=self.hp.get('lr_decay_t', 2.0),
-        m_mul=self.hp.get('lr_decay_m', 1.0),
-        alpha=self.hp.get('lr_decay_alpha', 0.0),
+        first_decay_steps=self.hp.get('lr_decay_steps', default=40),
+        t_mul=self.hp.get('lr_decay_t', default=2.0),
+        m_mul=self.hp.get('lr_decay_m', default=1.0),
+        alpha=self.hp.get('lr_decay_alpha', default=0.0),
       )
     elif scheduler == 'cosine':
       return tf.keras.optimizers.schedules.CosineDecay(
         initial_learning_rate=lr,
-        decay_steps=self.hp.get('lr_decay_steps', 40),
-        alpha=self.hp.get('lr_decay_alpha', 0.0),
-        warmup_target=self.get('lr_warmup_target', None),
-        warmup_steps=self.get('lr_warmup_steps', 0)
+        decay_steps=self.hp.get('lr_decay_steps', default=40),
+        alpha=self.hp.get('lr_decay_alpha', default=0.0),
+        warmup_target=self.get('lr_warmup_target', default=None),
+        warmup_steps=self.get('lr_warmup_steps', default=0)
       )
     elif scheduler == 'exponential':
       return tf.keras.optimizers.schedules.ExponentialDecay(
         initial_learning_rate=lr,
-        decay_steps=self.hp.get('lr_decay_steps', 40),
-        decay_rate=self.hp.get('lr_decay_rate', 40),
+        decay_steps=self.hp.get('lr_decay_steps', default=40),
+        decay_rate=self.hp.get('lr_decay_rate', default=40),
       )
     else:
       return None
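For context: the only change in this commit is switching the hyperparameter lookups from a positional default (e.g. self.hp.get('lr_decay_steps', 40)) to an explicit keyword default (default=40). A plausible reading, assumed here and not confirmed by the diff, is that the accessor treats its second positional slot as something other than a default, or requires default to be keyword-only, so the positional form failed or was misinterpreted. A minimal sketch of such a signature, with hypothetical names:

# Hypothetical sketch only -- not the actual mergernet hyperparameter API.
class HyperParameterSet:
  def __init__(self, values: dict):
    self._values = values

  def get(self, name: str, *, default=None):
    """Return the value registered for `name`, falling back to `default`.

    Because `default` is keyword-only in this sketch,
    hp.get('lr_decay_steps', 40) raises TypeError, while
    hp.get('lr_decay_steps', default=40) works as intended.
    """
    return self._values.get(name, default)

Under that assumption, the commit title "fix" would correspond to a TypeError (or silently wrong argument) in the scheduler-construction paths shown above.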
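Separately, a minimal usage sketch (not part of the commit) of how a schedule returned by get_scheduler is typically consumed by a Keras optimizer. The learning rate and optimizer choice are illustrative assumptions; the other values mirror the defaults in the diff:

import tensorflow as tf

# Build the same cosine-restarts schedule the 'cosine_restarts' branch constructs,
# using the default hyperparameter values from the diff.
schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
  initial_learning_rate=1e-3,  # the `lr` argument (illustrative value)
  first_decay_steps=40,        # lr_decay_steps default
  t_mul=2.0,                   # lr_decay_t default
  m_mul=1.0,                   # lr_decay_m default
  alpha=0.0,                   # lr_decay_alpha default
)

# A LearningRateSchedule can be passed directly to a Keras optimizer,
# which evaluates it at every training step.
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)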

0 comments on commit efef510
