Skip to content

Commit

Permalink
Fix init_value and end_value in cosine decay
Browse files Browse the repository at this point in the history
Problem:
* the starting value was not always init_value
* the last value was not always end_value

Solution:
* changed the formulas (they are now compatible with the pytorch implementation https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html)
* added a test to check that the first and last values coincide with `init_value` and `end_value` respectively

Misc:
   warmup_cosine_decay_schedule now passes end_value directly to cosine_decay_schedule (renamed alpha -> end_value) instead of the derived alpha ratio
PiperOrigin-RevId: 619109148
  • Loading branch information
fabianp authored and OptaxDev committed Mar 26, 2024
1 parent 5d8c7a3 commit cb6cef2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 17 deletions.
14 changes: 6 additions & 8 deletions optax/schedules/_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,11 @@ def cosine_decay_schedule(
.. math::
\frac{I (1 - E)}{2}(1+\cos(\pi\,\frac{t}{T})^p) + E\,,
(I - E) (\frac{1+\cos(\pi\,\frac{t}{T})}{2})^p + E\,,
where :math:`T` is the number of decay steps (``decay_steps``), :math:`p` is
the ``exponent``, :math:`I` is the initial value (``init_value``) and
:math:`E` is the end value,.
:math:`E` is the end value (``end_value``).
References:
Loshchilov et al., `SGDR: Stochastic Gradient Descent with Warm Restarts
Expand All @@ -286,8 +286,8 @@ def cosine_decay_schedule(
``t`` is the current timestep and ``T`` is the ``decay_steps``. The
exponent modifies this to be ``(0.5 * (1 + cos(pi * t/T))) ** exponent``.
Defaults to 1.0.
alpha: The minimum value of the multiplier used to adjust the
learning rate. Defaults to 0.0.
alpha: Deprecated, use end_value instead. The minimum value of the
multiplier used to adjust the learning rate. Defaults to 0.0.
Returns:
schedule
Expand Down Expand Up @@ -316,8 +316,7 @@ def cosine_decay_schedule(
def schedule(count):
count = jnp.minimum(count, decay_steps)
cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * count / decay_steps))
decayed = (1 - end_value) * cosine_decay ** exponent + end_value
return init_value * decayed
return (init_value - end_value) * cosine_decay ** exponent + end_value

return schedule

Expand Down Expand Up @@ -501,7 +500,6 @@ def warmup_cosine_decay_schedule(
schedule
A function that maps step counts to values
"""
alpha = 0. if peak_value == 0. else end_value / peak_value
schedules = [
linear_schedule(
init_value=init_value,
Expand All @@ -511,7 +509,7 @@ def warmup_cosine_decay_schedule(
cosine_decay_schedule(
init_value=peak_value,
decay_steps=decay_steps - warmup_steps,
alpha=alpha,
end_value=end_value,
exponent=exponent,
),
]
Expand Down
41 changes: 32 additions & 9 deletions optax/schedules/_schedule_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,24 @@ def test_immutable_count(self):

class CosineDecayTest(chex.TestCase):

@chex.all_variants
def test_init_value_end_value(self):
  """Check the schedule starts at init_value and ends at end_value."""
  initial_value = 1.5
  end_value = 0.2
  num_steps = 10
  schedule_fn = self.variant(
      _schedule.cosine_decay_schedule(initial_value, num_steps, end_value))
  # Evaluate the schedule at every step from 0 to num_steps inclusive.
  generated_vals = [schedule_fn(step) for step in range(num_steps + 1)]

  # Only the two boundary values are checked: the first must equal the
  # requested initial value and the last must equal the requested end value.
  self.assertAlmostEqual(generated_vals[0], initial_value)
  self.assertAlmostEqual(generated_vals[-1], end_value)

@chex.all_variants
def test_decay_count_smaller_count(self):
"""Check cosine schedule decay for the entire training schedule."""
Expand Down Expand Up @@ -345,23 +363,28 @@ def test_decay_count_greater_count(self):
def test_decay_count_greater_count_with_end_value(self):
"""Check cosine schedule decay for a part of the training schedule."""
# Get schedule function.
initial_value = 0.1
initial_value = 0.2
end_value = 0.1
num_steps = 5
schedule_fn = self.variant(
_schedule.cosine_decay_schedule(initial_value, 5, 0.1))
_schedule.cosine_decay_schedule(initial_value, num_steps, end_value))
# Test that generated values equal the expected schedule values.
generated_vals = []
for count in range(12):
for count in range(2 * num_steps):
# Compute next value.
generated_vals.append(schedule_fn(count))

# Test output.
expected_multipliers = np.array(
0.5 + 0.5 * np.cos(
np.pi * np.array(
[0.0, 0.2, 0.4, 0.6, 0.8, 1., 1., 1., 1., 1., 1., 1.])))
expected_multipliers = 0.9 * expected_multipliers + 0.1
cos_values = 0.5 * (1 + np.cos(np.pi * np.linspace(0, 1, num_steps + 1)))
expected_values = (
(initial_value - end_value) * cos_values + end_value
)
# Pad with [end_value] at the end.
expected_values = np.concatenate(
(expected_values, [end_value] * (num_steps - 1))
)
np.testing.assert_allclose(
initial_value * expected_multipliers,
expected_values,
np.array(generated_vals), atol=1e-3)

def test_cosine_alpha_exception(self):
Expand Down

0 comments on commit cb6cef2

Please sign in to comment.