From 2a622f61a52e875839c8e168bde1ce4149d6e06c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 16 Jan 2023 17:57:21 +0100 Subject: [PATCH 1/2] Loop flattening: remove `.connect()` --- docs/source-pytorch/extensions/loops.rst | 87 ------------------- src/lightning_fabric/utilities/seed.py | 2 +- .../loops/dataloader/evaluation_loop.py | 4 - .../loops/dataloader/prediction_loop.py | 4 - .../loops/epoch/prediction_epoch_loop.py | 3 - .../loops/epoch/training_epoch_loop.py | 14 --- src/pytorch_lightning/loops/fit_loop.py | 4 - src/pytorch_lightning/loops/loop.py | 56 +----------- .../loops/optimization/optimizer_loop.py | 3 - src/pytorch_lightning/trainer/trainer.py | 83 ++---------------- tests/tests_pytorch/loops/test_loops.py | 86 +++--------------- .../tests_pytorch/loops/test_training_loop.py | 2 +- 12 files changed, 22 insertions(+), 326 deletions(-) diff --git a/docs/source-pytorch/extensions/loops.rst b/docs/source-pytorch/extensions/loops.rst index 5c7385ec7c0b5..c35fa27f296e0 100644 --- a/docs/source-pytorch/extensions/loops.rst +++ b/docs/source-pytorch/extensions/loops.rst @@ -259,28 +259,6 @@ run (optional) ---------- -Subloops --------- - -When you want to customize nested loops within loops use the :meth:`~pytorch_lightning.loops.loop.Loop.connect` method: - -.. code-block:: python - - # Optional: stitch back the trainer arguments - epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) - # Optional: connect children loops as they might have existing state - epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop) - # Instantiate and connect the loop. - trainer.fit_loop.connect(epoch_loop=epoch_loop) - trainer.fit(model) - -More about the built-in loops and how they are composed is explained in the next section. - -.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/connect-epoch-loop.gif - :alt: Animation showing how to connect a custom subloop - ----------- - Built-in Loops -------------- @@ -342,71 +320,6 @@ Each of these :code:`for`-loops represents a class implementing the :class:`~pyt It simply iterates over each prediction dataloader from one to the next by calling :code:`PredictionEpochLoop.run()` in its :code:`advance()` method. ----------- - -Available Loops in Lightning Flash ----------------------------------- - -`Active Learning `__ is a machine learning practice in which the user interacts with the learner in order to provide new labels when required. - -You can find a real use case in `Lightning Flash `_. - -Flash implements the :code:`ActiveLearningLoop` that you can use together with the :code:`ActiveLearningDataModule` to label new data on the fly. -To run the following demo, install Flash and `BaaL `__ first: - -.. code-block:: bash - - pip install lightning-flash[image] baal - -.. code-block:: python - - import torch - - import flash - from flash.core.classification import Probabilities - from flash.core.data.utils import download_data - from flash.image import ImageClassificationData, ImageClassifier - from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop - - # 1. Create the DataModule - download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data") - - # Implement the research use-case where we mask labels from labelled dataset. 
- datamodule = ActiveLearningDataModule( - ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2), - initial_num_labels=5, - val_split=0.1, - ) - - # 2. Build the task - head = torch.nn.Sequential( - torch.nn.Dropout(p=0.1), - torch.nn.Linear(512, datamodule.num_classes), - ) - model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, output=Probabilities()) - - - # 3.1 Create the trainer - trainer = flash.Trainer(max_epochs=3) - - # 3.2 Create the active learning loop and connect it to the trainer - active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1) - active_learning_loop.connect(trainer.fit_loop) - trainer.fit_loop = active_learning_loop - - # 3.3 Finetune - trainer.finetune(model, datamodule=datamodule, strategy="freeze") - - # 4. Predict what's on a few images! ants or bees? - predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") - print(predictions) - - # 5. Save the model! - trainer.save_checkpoint("image_classification_model.pt") - -Here is the `Active Learning Loop example `_ and the `code for the active learning loop `_. - - ---------- Advanced Examples diff --git a/src/lightning_fabric/utilities/seed.py b/src/lightning_fabric/utilities/seed.py index a6bd1619d370a..b26125951989e 100644 --- a/src/lightning_fabric/utilities/seed.py +++ b/src/lightning_fabric/utilities/seed.py @@ -108,7 +108,7 @@ def _collect_rng_states() -> Dict[str, Any]: """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" return { "torch": torch.get_rng_state(), - "torch.cuda": torch.cuda.get_rng_state_all(), + # "torch.cuda": torch.cuda.get_rng_state_all(), "numpy": np.random.get_state(), "python": python_get_rng_state(), } diff --git a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py index 2c96a9e2fc130..f2d840590e1e0 100644 --- a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -78,10 +78,6 @@ def prefetch_batches(self) -> int: is_unsized = batches[self.current_dataloader_idx] == float("inf") return int(is_unsized) - def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override] - """Connect the evaluation epoch loop with this loop.""" - self.epoch_loop = epoch_loop - @property def done(self) -> bool: """Returns whether all dataloaders are processed or evaluation should be skipped altogether.""" diff --git a/src/pytorch_lightning/loops/dataloader/prediction_loop.py b/src/pytorch_lightning/loops/dataloader/prediction_loop.py index 606bfcc4024ce..1f9df89a00501 100644 --- a/src/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/src/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -66,10 +66,6 @@ def dataloaders(self) -> Sequence[DataLoader]: def skip(self) -> bool: return sum(self.max_batches) == 0 - def connect(self, epoch_loop: PredictionEpochLoop) -> None: # type: ignore[override] - """Connect the prediction epoch loop with this loop.""" - self.epoch_loop = epoch_loop - def reset(self) -> None: """Resets the internal state of the loop for a new run.""" self.predictions = [] diff --git a/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 77bca172d56b8..66794a8caf0ac 100644 --- a/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ 
b/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -38,9 +38,6 @@ def should_store_predictions(self) -> bool: any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks) return self.return_predictions or any_pred - def connect(self, **kwargs: "Loop") -> None: - raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") - def reset(self) -> None: """Resets the loops internal state.""" self._seen_batch_indices = [] diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index e63569f67a960..3abcfb95204d4 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -121,20 +121,6 @@ def done(self) -> bool: return False - def connect( # type: ignore[override] - self, - optimizer_loop: Optional[OptimizerLoop] = None, - manual_loop: Optional[ManualOptimization] = None, - val_loop: Optional["loops.EvaluationLoop"] = None, - ) -> None: - """Optionally connect a custom batch or validation loop to this training epoch loop.""" - if optimizer_loop is not None: - self.optimizer_loop = optimizer_loop - if manual_loop is not None: - self.manual_loop = manual_loop - if val_loop is not None: - self.val_loop = val_loop - def reset(self) -> None: """Resets the internal state of the loop for a new run.""" if self.restarting: diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index 6e39d0b59f11a..38b662b66e04b 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -170,10 +170,6 @@ def skip(self) -> bool: # until `on_run_start`, we use `limit_train_batches` instead return self.done or self.trainer.limit_train_batches == 0 - def connect(self, epoch_loop: TrainingEpochLoop) -> None: # type: ignore[override] - """Connects a training epoch loop to this fit loop.""" - self.epoch_loop = epoch_loop - def reset(self) -> None: """Resets the internal state of this loop.""" if self.restarting: diff --git a/src/pytorch_lightning/loops/loop.py b/src/pytorch_lightning/loops/loop.py index abe25a97c3a16..461b0a6a2f6f1 100644 --- a/src/pytorch_lightning/loops/loop.py +++ b/src/pytorch_lightning/loops/loop.py @@ -11,16 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union +from typing import Any, Dict, Generic, Optional, TypeVar from torchmetrics import Metric import pytorch_lightning as pl from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import BaseProgress -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training T = TypeVar("T") # the output type of `run` @@ -102,58 +100,6 @@ def skip(self): """ return False - def connect(self, **kwargs: "Loop") -> None: - """Optionally connect one or multiple loops to this one. - - Linked loops should form a tree. - """ - - def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None: - """Optionally replace one or multiple of this loop's sub-loops. 
- - This method takes care of instantiating the class (if necessary) with all existing arguments, connecting all - sub-loops of the old loop to the new instance, setting the ``Trainer`` reference, and connecting the new loop to - the parent. - - Args: - **loops: ``Loop`` subclasses or instances. The name used should match the loop attribute name you want to - replace. - - Raises: - MisconfigurationException: When passing a ``Loop`` class, if the ``__init__`` arguments do not match those - of the Loop class it replaces. - """ - new_loops = {} - - for name, type_or_object in loops.items(): - old_loop = getattr(self, name) - - if isinstance(type_or_object, type): - # compare the signatures - old_parameters = inspect.signature(old_loop.__class__.__init__).parameters - current_parameters = inspect.signature(type_or_object.__init__).parameters - if old_parameters != current_parameters: - raise MisconfigurationException( - f"`{self.__class__.__name__}.replace({type_or_object.__name__})` can only be used if the" - f" `__init__` signatures match but `{old_loop.__class__.__name__}` does not." - ) - # instantiate the loop - kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"} - loop = type_or_object(**kwargs) - else: - loop = type_or_object - - # connect sub-loops - kwargs = {n: l for n, l in old_loop.__dict__.items() if isinstance(l, Loop)} - if kwargs: - loop.connect(**kwargs) - # set the trainer reference - loop.trainer = self.trainer - - new_loops[name] = loop - # connect to self - self.connect(**new_loops) - def on_skip(self) -> T: """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`. diff --git a/src/pytorch_lightning/loops/optimization/optimizer_loop.py b/src/pytorch_lightning/loops/optimization/optimizer_loop.py index 819178cff15c0..07284198aa183 100644 --- a/src/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/src/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -172,9 +172,6 @@ def done(self) -> bool: """Returns ``True`` when the last optimizer in the sequence has run.""" return self.optim_progress.optimizer_position >= len(self._indices) - def connect(self, **kwargs: "Loop") -> None: - raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") - def reset(self) -> None: if not self.restarting: # when reset() is called from outside (manually), we reset the loop progress diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index e332fbf68b824..321e4ae189f17 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -371,21 +371,16 @@ def __init__( self._signal_connector = SignalConnector(self) self.tuner = Tuner(self) - fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs) - training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps) - fit_loop.connect(epoch_loop=training_epoch_loop) - - # default .fit() loop - self.fit_loop = fit_loop - - # default .validate() loop + # init loops + self.fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs) + self.fit_loop.epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps) self.validate_loop = EvaluationLoop() - - # default .test() loop self.test_loop = EvaluationLoop() - - # default .predict() loop self.predict_loop = PredictionLoop() + self.fit_loop.trainer = self + self.validate_loop.trainer = self + self.test_loop.trainer = self + self.predict_loop.trainer = self # init callbacks # 
Declare attributes to be set in _callback_connector on_trainer_init @@ -1103,8 +1098,6 @@ def _run_train(self) -> None: self.model.train() torch.set_grad_enabled(True) - self.fit_loop.trainer = self - with torch.autograd.set_detect_anomaly(self._detect_anomaly): self.fit_loop.run() @@ -1114,9 +1107,6 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: # reload dataloaders self._evaluation_loop._reload_evaluation_dataloaders() - # reset trainer on this loop and all child loops in case user connected a custom loop - self._evaluation_loop.trainer = self - with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context( self.accelerator, self._inference_mode ): @@ -1133,8 +1123,6 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: self.reset_predict_dataloader(self.lightning_module) - # reset trainer on this loop and all child loops in case user connected a custom loop - self.predict_loop.trainer = self with _evaluation_context(self.accelerator, self._inference_mode): return self.predict_loop.run() @@ -1969,63 +1957,6 @@ def is_last_batch(self) -> bool: """Whether trainer is executing the last batch.""" return self.fit_loop.epoch_loop.batch_progress.is_last_batch - @property - def fit_loop(self) -> FitLoop: - return self._fit_loop - - @fit_loop.setter - def fit_loop(self, loop: FitLoop) -> None: - """Attach a custom fit loop to this Trainer. - - It will run with - :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`. - """ - loop.trainer = self - self._fit_loop = loop - - @property - def validate_loop(self) -> EvaluationLoop: - return self._validate_loop - - @validate_loop.setter - def validate_loop(self, loop: EvaluationLoop) -> None: - """Attach a custom validation loop to this Trainer. - - It will run with - :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is different from the one - running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call. - """ - loop.trainer = self - self._validate_loop = loop - - @property - def test_loop(self) -> EvaluationLoop: - return self._test_loop - - @test_loop.setter - def test_loop(self, loop: EvaluationLoop) -> None: - """Attach a custom test loop to this Trainer. - - It will run with - :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`. - """ - loop.trainer = self - self._test_loop = loop - - @property - def predict_loop(self) -> PredictionLoop: - return self._predict_loop - - @predict_loop.setter - def predict_loop(self, loop: PredictionLoop) -> None: - """Attach a custom prediction loop to this Trainer. - - It will run with - :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. 
- """ - loop.trainer = self - self._predict_loop = loop - @property def _evaluation_loop(self) -> EvaluationLoop: if self.state.fn == TrainerFn.FITTING: diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 0701071df2a7c..dbb944ae33352 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -25,72 +25,27 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset -from pytorch_lightning.loops import Loop, OptimizerLoop +from pytorch_lightning.loops import Loop from pytorch_lightning.trainer.progress import BaseProgress from tests_pytorch.helpers.runif import RunIf -class NestedLoop(Loop): - def __init__(self): - super().__init__() - self.child_loop0 = None - self.child_loop1 = None - - @property - def done(self) -> bool: - return False - - def connect(self, child0, child1): - self.child_loop0 = child0 - self.child_loop1 = child1 - - def reset(self) -> None: - pass - - def advance(self, *args, **kwargs): - pass - - -@pytest.mark.parametrize("loop_name", ["fit_loop", "validate_loop", "test_loop", "predict_loop"]) -def test_connect_loops_direct(loop_name): - """Test Trainer references in loops on assignment.""" - loop = NestedLoop() - - with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"): - _ = loop.trainer - - trainer = Trainer() - - # trainer.loop_name = loop - setattr(trainer, loop_name, loop) - assert loop.trainer is trainer - - -def test_connect_loops_recursive(): - """Test Trainer references in a nested loop assigned to a Trainer.""" - main_loop = NestedLoop() - child0 = NestedLoop() - child1 = NestedLoop() - main_loop.connect(child0, child1) - - with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"): - _ = main_loop.trainer - - with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"): - _ = main_loop.child_loop0.trainer - - trainer = Trainer() - trainer.fit_loop = main_loop - assert child0.trainer is child1.trainer - assert child0.trainer is trainer - - def test_restarting_loops_recursive(): - class MyLoop(NestedLoop): + class MyLoop(Loop): def __init__(self, loop=None): super().__init__() self.child = loop + @property + def done(self) -> bool: + return False + + def reset(self) -> None: + pass + + def advance(self, *args, **kwargs): + pass + loop = MyLoop(MyLoop(MyLoop())) assert not loop.restarting @@ -102,23 +57,6 @@ def __init__(self, loop=None): assert loop.child.child.restarting -def test_connect_subloops(tmpdir): - """Test connecting individual subloops by calling `trainer.x.y.connect()`""" - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - - epoch_loop = trainer.fit_loop.epoch_loop - new_optimizer_loop = OptimizerLoop() - epoch_loop.connect(optimizer_loop=new_optimizer_loop) - assert epoch_loop.optimizer_loop is new_optimizer_loop - - with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"): - _ = new_optimizer_loop.trainer - - trainer.fit(model) - assert new_optimizer_loop.trainer is trainer - - class CustomException(Exception): pass diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index 8cc06315e3e55..ffa3e40995c04 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -158,7 +158,7 @@ 
def test_fit_loop_done_log_messages(caplog):
     epoch_loop = Mock()
     epoch_loop.global_step = 10
-    fit_loop.connect(epoch_loop=epoch_loop)
+    fit_loop.epoch_loop = epoch_loop
     fit_loop.max_steps = 10
     assert fit_loop.done
     assert "max_steps=10` reached" in caplog.text

From a228d35a856169f462f3d2afd49fce4ad92a9043 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Mon, 16 Jan 2023 18:04:15 +0100
Subject: [PATCH 2/2] Restore CUDA RNG state collection and update CHANGELOG

---
 src/lightning_fabric/utilities/seed.py | 2 +-
 src/pytorch_lightning/CHANGELOG.md     | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/lightning_fabric/utilities/seed.py b/src/lightning_fabric/utilities/seed.py
index b26125951989e..a6bd1619d370a 100644
--- a/src/lightning_fabric/utilities/seed.py
+++ b/src/lightning_fabric/utilities/seed.py
@@ -108,7 +108,7 @@ def _collect_rng_states() -> Dict[str, Any]:
     """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
     return {
         "torch": torch.get_rng_state(),
-        # "torch.cuda": torch.cuda.get_rng_state_all(),
+        "torch.cuda": torch.cuda.get_rng_state_all(),
         "numpy": np.random.get_state(),
         "python": python_get_rng_state(),
     }
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index dc9778fdbabd6..e5fc8cc2cb470 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -47,6 +47,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Removed support for loop customization
     * Removed `Loop.replace()` ([#16361](https://github.com/Lightning-AI/lightning/pull/16361))
+    * Removed `Loop.connect()` ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
+    * Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))

 - Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
     * Removed the `LightningModule.truncated_bptt_steps` attribute
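
Post-series note: the pattern both commits standardize on is plain attribute
assignment in place of `.connect()`, as seen in the `Trainer.__init__` hunk and
in the updated `test_fit_loop_done_log_messages`. The sketch below restates
that pattern outside the diff. It is illustrative only: it assumes
`TrainingEpochLoop` is importable from `pytorch_lightning.loops` (the updated
test imports `Loop` from that package), it reuses the `min_steps`/`max_steps`
attribute access shown in the docs example this series deletes, and it relies
on the `Loop.trainer` setter propagating to child loops, which the removed
`test_connect_loops_recursive` exercised but this series does not re-test.

    from pytorch_lightning import Trainer
    from pytorch_lightning.loops import TrainingEpochLoop

    trainer = Trainer(fast_dev_run=True)

    # Before this series: trainer.fit_loop.connect(epoch_loop=new_epoch_loop)
    # After it, child loops are ordinary attributes:
    old = trainer.fit_loop.epoch_loop
    trainer.fit_loop.epoch_loop = TrainingEpochLoop(old.min_steps, old.max_steps)

    # The removed property setters no longer attach the trainer implicitly, so
    # the reference is re-set explicitly, mirroring the assignment order used
    # in `Trainer.__init__` (assumed to propagate to the new child loop).
    trainer.fit_loop.trainer = trainer

With the `trainer.{fit,validate,test,predict}_loop` properties gone, these
attributes are initialized once in `Trainer.__init__`; assigning to them from
outside no longer triggers any trainer-attachment logic, which is why the
reference is set by hand above.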