Skip to content

Commit

Permalink
Support minimum interval seconds since last save in Continuous checkpoint save policy.
Browse files Browse the repository at this point in the history
Cloned from CL 732815358.

PiperOrigin-RevId: 733282182
  • Loading branch information
Orbax Authors committed Mar 5, 2025
1 parent d010cfa commit 0c7ff1a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Defines policies for when a checkpoint is saved."""

import dataclasses
import datetime
import typing
from typing import Container, Protocol, Sequence

Expand All @@ -23,6 +24,7 @@
class StepInfo:
"""Relevant information about a checkpoint step."""
step: int
time: datetime.datetime


@dataclasses.dataclass(kw_only=True)
Expand Down Expand Up @@ -87,19 +89,26 @@ def should_save(
return step.step in self.steps


@dataclasses.dataclass
class ContinuousCheckpointingPolicy(SaveDecisionPolicy):
  """Checkpoint as often as possible, as long as a save is not ongoing.

  Attributes:
    interval: Optional minimum number of seconds that must separate the
      current step's timestamp from the timestamp of the last entry in
      `previous_steps` before a new save is allowed. `None` (the default)
      disables the time-based throttle, so a save is requested whenever no
      other save is in progress.
  """

  interval: int | None = None

  def should_save(
      self,
      step: StepInfo,
      previous_steps: Sequence[StepInfo],
      *,
      context: DecisionContext,
  ) -> bool:
    """Decides whether a checkpoint should be saved at `step`.

    Args:
      step: The step under consideration; its `time` is compared against the
        last previous step when `interval` is set.
      previous_steps: Steps already known to the caller. Only the last element
        is consulted. NOTE(review): this presumes the sequence is ordered
        oldest-to-newest - confirm against the caller.
      context: Supplies `is_saving_in_progress`.

    Returns:
      False while a save is in progress. Otherwise True when there is no
      previous step or no `interval` configured, and else True only if at
      least `interval` seconds separate `step.time` from the last previous
      step's `time`.
    """
    # Never start a new save on top of one that is still running.
    if context.is_saving_in_progress:
      return False
    # Nothing to throttle against: first checkpoint, or throttling disabled.
    if not previous_steps or self.interval is None:
      return True
    # NOTE(review): both timestamps must be consistently timezone-aware (or
    # consistently naive); subtracting a mix raises TypeError.
    elapsed = step.time - previous_steps[-1].time
    return elapsed >= datetime.timedelta(seconds=self.interval)


class PreemptionCheckpointingPolicy(SaveDecisionPolicy):
Expand Down
6 changes: 4 additions & 2 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,10 +1084,12 @@ def should_save(self, step: int) -> bool:
is_saving_in_progress = self.is_saving_in_progress()
reached_preemption = self.reached_preemption(step)
previous_step_infos = [
save_decision_policy_lib.StepInfo(step=ckpt.step)
save_decision_policy_lib.StepInfo(step=ckpt.step, time=ckpt.time)
for ckpt in self._checkpoints
]
current_step_info = save_decision_policy_lib.StepInfo(step=step)
current_step_info = save_decision_policy_lib.StepInfo(
step=step, time=datetime.datetime.now(tz=datetime.timezone.utc),
)
context = save_decision_policy_lib.DecisionContext(
is_saving_in_progress=is_saving_in_progress,
reached_preemption=reached_preemption,
Expand Down

0 comments on commit 0c7ff1a

Please sign in to comment.