feat: Added support for PyTorch Lightning in the DDP backend. #162

Open · wants to merge 4 commits into master
Changes from 1 commit
50 changes: 49 additions & 1 deletion neps/runtime.py
@@ -26,7 +26,10 @@
WorkerRaiseError,
)
from neps.state._eval import evaluate_trial
from neps.state.filebased import create_or_load_filebased_neps_state
from neps.state.filebased import (
create_or_load_filebased_neps_state,
load_filebased_neps_state,
)
from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo
from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings
from neps.state.trial import Trial
@@ -43,6 +46,24 @@ def _default_worker_name() -> str:
return f"{os.getpid()}-{isoformat}"


def _is_ddp_and_not_rank_zero() -> bool:
import torch.distributed as dist

# Check for environment variables typically set by DDP
ddp_env_vars = ["WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
rank_env_vars = ["RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"]

# Check if PyTorch distributed is initialized
if (dist.is_available() and dist.is_initialized()) or all(
var in os.environ for var in ddp_env_vars
):
for var in rank_env_vars:
rank = os.environ.get(var)
if rank is not None:
return int(rank) != 0
return False


N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 0
N_FAILED_TO_SET_TRIAL_STATE = 10
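The check above treats a process as a non-zero DDP rank if either an initialized torch.distributed process group is found or the launcher-style environment variables are all present. Below is a small sanity check of that behavior; it assumes this PR's branch of neps plus a torch install, and uses the environment variables that torchrun (and PyTorch Lightning's DDP strategy) normally export per process.

```python
# Simulate the env vars a DDP launcher sets and probe the detection logic.
# Assumes this PR's branch of neps and a torch install.
import os

from neps.runtime import _is_ddp_and_not_rank_zero

# No DDP env vars and no initialized process group: treated as a plain worker.
assert not _is_ddp_and_not_rank_zero()

# What torchrun exports for rank 1 of a 2-process job.
os.environ.update(
    {
        "WORLD_SIZE": "2",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "29500",
        "RANK": "1",
        "LOCAL_RANK": "1",
    }
)
assert _is_ddp_and_not_rank_zero()

# The same job as seen by rank 0, which should launch the normal NePS worker.
os.environ["RANK"] = os.environ["LOCAL_RANK"] = "0"
assert not _is_ddp_and_not_rank_zero()
```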

@@ -488,6 +509,26 @@ def run(self) -> None: # noqa: C901, PLR0915
)


def _launch_ddp_runtime(
*,
evaluation_fn: Callable[..., float | Mapping[str, Any]],
optimization_dir: Path,
) -> None:
neps_state = load_filebased_neps_state(directory=optimization_dir)

# TODO: This is a bit of a hack to get the current trial to evaluate. Sometimes
# the previous trial is still marked as evaluating and would be picked up again,
# so we track the last seen trial id to avoid re-evaluating it.
prev_trial = None
while True:
current_trial = neps_state.get_current_evaluating_trial()
if current_trial is not None and (
prev_trial is None or current_trial.id != prev_trial.id # type: ignore[unreachable]
):
evaluation_fn(**current_trial.config)
prev_trial = current_trial


# TODO: This should be done directly in `api.run` at some point to make it clearer at
# an entry point how the worker is set up to run if someone reads the entry point code.
def _launch_runtime( # noqa: PLR0913
Expand All @@ -506,6 +547,13 @@ def _launch_runtime( # noqa: PLR0913
max_evaluations_for_worker: int | None,
pre_load_hooks: Iterable[Callable[[BaseOptimizer], BaseOptimizer]] | None,
) -> None:
if _is_ddp_and_not_rank_zero():
# Do not launch a new worker if we are in a DDP setup and not rank 0
_launch_ddp_runtime(
evaluation_fn=evaluation_fn, optimization_dir=optimization_dir
)
return

if overwrite_optimization_dir and optimization_dir.exists():
logger.info(
f"Overwriting optimization directory '{optimization_dir}' as"
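Taken together, these runtime changes let every rank call the normal entry point: rank 0 becomes the usual NePS worker, and the other ranks replay the currently evaluating config. The sketch below shows how a PyTorch Lightning user might drive this. It is hypothetical usage, not code from this PR: the module is a stand-in, and the `neps.run` / `neps.FloatParameter` keyword names are assumptions about the neps API around this change.

```python
# Hypothetical end-to-end usage, launched with:
#   torchrun --nproc_per_node=2 train.py
import lightning as L
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import neps


class TinyModule(L.LightningModule):
    def __init__(self, lr: float):
        super().__init__()
        self.lr = lr
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


def run_pipeline(lr: float) -> float:
    data = DataLoader(
        TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16
    )
    trainer = L.Trainer(
        max_epochs=1,
        accelerator="auto",
        devices=2,
        strategy="ddp",
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(TinyModule(lr=lr), data)
    return float(trainer.callback_metrics["train_loss"])


# Every rank calls neps.run. With this PR, rank 0 samples and records trials as
# usual, while the other ranks drop into _launch_ddp_runtime and re-invoke
# run_pipeline with the config of the trial currently marked EVALUATING, so all
# ranks reach trainer.fit together and DDP rendezvous can proceed.
neps.run(
    run_pipeline=run_pipeline,
    pipeline_space={"lr": neps.FloatParameter(lower=1e-4, upper=1e-1, log=True)},
    root_directory="results/ddp_example",
    max_evaluations_total=5,
)
```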
9 changes: 9 additions & 0 deletions neps/state/filebased.py
@@ -209,6 +209,15 @@ def pending(self) -> Iterable[tuple[str, Synced[Trial, Path]]]:
]
return iter((_id, t) for _id, t, _ in sorted(pending, key=lambda x: x[2]))

@override
def evaluating(self) -> Iterable[tuple[str, Synced[Trial, Path]]]:
evaluating = [
(_id, t, trial.metadata.time_sampled)
for (_id, t) in self.all().items()
if (trial := t.synced()).state == Trial.State.EVALUATING
]
return iter((_id, t) for _id, t, _ in sorted(evaluating, key=lambda x: x[2]))


@dataclass
class ReaderWriterTrial(ReaderWriter[Trial, Path]):
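The new `evaluating()` accessor follows the same contract as `pending()` directly above it: filter trials by state, then yield them sorted ascending by `metadata.time_sampled`. A self-contained toy illustration of that ordering (stand-in names, not neps code):

```python
# Toy model of the ordering contract: EVALUATING trials come back oldest
# time_sampled first, so the earliest-sampled one is treated as "current".
from dataclasses import dataclass


@dataclass
class FakeTrial:
    state: str
    time_sampled: float


trials = {
    "t2": FakeTrial("EVALUATING", time_sampled=200.0),
    "t1": FakeTrial("EVALUATING", time_sampled=100.0),
    "t3": FakeTrial("PENDING", time_sampled=50.0),
}

evaluating = [
    (_id, t, t.time_sampled) for _id, t in trials.items() if t.state == "EVALUATING"
]
ordered = [_id for _id, _, _ in sorted(evaluating, key=lambda x: x[2])]
assert ordered == ["t1", "t2"]
```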
6 changes: 6 additions & 0 deletions neps/state/neps_state.py
@@ -214,6 +214,12 @@ def get_next_pending_trial(self, n: int | None = None) -> Trial | list[Trial] | None:
return take(n, _pending_itr)
return next(_pending_itr, None)

def get_current_evaluating_trial(self) -> Trial | None:
"""Get the current evaluating trial."""
for _, shared_trial in self._trials.evaluating():
return shared_trial.synced()
return None

def all_trial_ids(self) -> set[str]:
"""Get all the trial ids that are known about."""
return self._trials.all_trial_ids()
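On the consuming side, a non-zero rank only needs to load the shared file-based state and ask for this trial. A minimal sketch, assuming this PR's branch of neps and an optimization directory already populated by the rank-0 worker (the path is illustrative):

```python
from pathlib import Path

from neps.state.filebased import load_filebased_neps_state

state = load_filebased_neps_state(directory=Path("results/ddp_example"))
trial = state.get_current_evaluating_trial()
if trial is None:
    print("no trial is currently evaluating")
else:
    # The same config that the rank-0 worker is evaluating right now.
    print(trial.id, trial.config)
```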
4 changes: 4 additions & 0 deletions neps/state/protocols.py
@@ -138,6 +138,10 @@ def pending(self) -> Iterable[tuple[str, Synced[Trial, K]]]:
"""
...

def evaluating(self) -> Iterable[tuple[str, Synced[Trial, K]]]:
"""Get all evaluating trials in the repo."""
...


@dataclass
class VersionedResource(Generic[T, K]):