Skip to content

Commit

Permalink
Loop flattening: remove .connect() (#16384)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Jan 19, 2023
1 parent 63eec5e commit 05876b6
Show file tree
Hide file tree
Showing 12 changed files with 22 additions and 276 deletions.
87 changes: 0 additions & 87 deletions docs/source-pytorch/extensions/loops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -259,28 +259,6 @@ run (optional)

----------

Subloops
--------

When you want to customize nested loops within loops use the :meth:`~pytorch_lightning.loops.loop.Loop.connect` method:

.. code-block:: python
# Optional: stitch back the trainer arguments
epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps)
# Optional: connect children loops as they might have existing state
epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop)
# Instantiate and connect the loop.
trainer.fit_loop.connect(epoch_loop=epoch_loop)
trainer.fit(model)
More about the built-in loops and how they are composed is explained in the next section.

.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/connect-epoch-loop.gif
:alt: Animation showing how to connect a custom subloop

----------

Built-in Loops
--------------

Expand Down Expand Up @@ -342,71 +320,6 @@ Each of these :code:`for`-loops represents a class implementing the :class:`~pyt
It simply iterates over each prediction dataloader from one to the next by calling :code:`PredictionEpochLoop.run()` in its :code:`advance()` method.


----------

Available Loops in Lightning Flash
----------------------------------

`Active Learning <https://en.wikipedia.org/wiki/Active_learning_(machine_learning)>`__ is a machine learning practice in which the user interacts with the learner in order to provide new labels when required.

You can find a real use case in `Lightning Flash <https://github.com/Lightning-AI/lightning-flash>`_.

Flash implements the :code:`ActiveLearningLoop` that you can use together with the :code:`ActiveLearningDataModule` to label new data on the fly.
To run the following demo, install Flash and `BaaL <https://github.com/ElementAI/baal>`__ first:

.. code-block:: bash
pip install lightning-flash[image] baal
.. code-block:: python
import torch
import flash
from flash.core.classification import Probabilities
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
# Implement the research use-case where we mask labels from labelled dataset.
datamodule = ActiveLearningDataModule(
ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2),
initial_num_labels=5,
val_split=0.1,
)
# 2. Build the task
head = torch.nn.Sequential(
torch.nn.Dropout(p=0.1),
torch.nn.Linear(512, datamodule.num_classes),
)
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, output=Probabilities())
# 3.1 Create the trainer
trainer = flash.Trainer(max_epochs=3)
# 3.2 Create the active learning loop and connect it to the trainer
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
active_learning_loop.connect(trainer.fit_loop)
trainer.fit_loop = active_learning_loop
# 3.3 Finetune
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Predict what's on a few images! ants or bees?
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
Here is the `Active Learning Loop example <https://github.com/Lightning-AI/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/Lightning-AI/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py>`_.


----------

Advanced Examples
Expand Down
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed support for loop customization
* Removed `Loop.replace()` ([#16361](https://github.com/Lightning-AI/lightning/pull/16361))
* Removed `Loop.connect()` ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
* Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))

- Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
* Removed the `LightningModule.truncated_bptt_steps` attribute
Expand Down
4 changes: 0 additions & 4 deletions src/pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,6 @@ def prefetch_batches(self) -> int:
is_unsized = batches[self.current_dataloader_idx] == float("inf")
return int(is_unsized)

def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override]
"""Connect the evaluation epoch loop with this loop."""
self.epoch_loop = epoch_loop

@property
def done(self) -> bool:
"""Returns whether all dataloaders are processed or evaluation should be skipped altogether."""
Expand Down
4 changes: 0 additions & 4 deletions src/pytorch_lightning/loops/dataloader/prediction_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ def dataloaders(self) -> Sequence[DataLoader]:
def skip(self) -> bool:
return sum(self.max_batches) == 0

def connect(self, epoch_loop: PredictionEpochLoop) -> None: # type: ignore[override]
"""Connect the prediction epoch loop with this loop."""
self.epoch_loop = epoch_loop

def reset(self) -> None:
"""Resets the internal state of the loop for a new run."""
self.predictions = []
Expand Down
3 changes: 0 additions & 3 deletions src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ def should_store_predictions(self) -> bool:
any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
return self.return_predictions or any_pred

def connect(self, **kwargs: "Loop") -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

def reset(self) -> None:
"""Resets the loops internal state."""
self._seen_batch_indices = []
Expand Down
14 changes: 0 additions & 14 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,6 @@ def done(self) -> bool:

return False

def connect( # type: ignore[override]
self,
optimizer_loop: Optional[OptimizerLoop] = None,
manual_loop: Optional[ManualOptimization] = None,
val_loop: Optional["loops.EvaluationLoop"] = None,
) -> None:
"""Optionally connect a custom batch or validation loop to this training epoch loop."""
if optimizer_loop is not None:
self.optimizer_loop = optimizer_loop
if manual_loop is not None:
self.manual_loop = manual_loop
if val_loop is not None:
self.val_loop = val_loop

def reset(self) -> None:
"""Resets the internal state of the loop for a new run."""
if self.restarting:
Expand Down
4 changes: 0 additions & 4 deletions src/pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,6 @@ def skip(self) -> bool:
# until `on_run_start`, we use `limit_train_batches` instead
return self.done or self.trainer.limit_train_batches == 0

def connect(self, epoch_loop: TrainingEpochLoop) -> None: # type: ignore[override]
"""Connects a training epoch loop to this fit loop."""
self.epoch_loop = epoch_loop

def reset(self) -> None:
"""Resets the internal state of this loop."""
if self.restarting:
Expand Down
6 changes: 0 additions & 6 deletions src/pytorch_lightning/loops/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,6 @@ def skip(self):
"""
return False

def connect(self, **kwargs: "Loop") -> None:
"""Optionally connect one or multiple loops to this one.
Linked loops should form a tree.
"""

def on_skip(self) -> T:
"""The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.
Expand Down
3 changes: 0 additions & 3 deletions src/pytorch_lightning/loops/optimization/optimizer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,6 @@ def done(self) -> bool:
"""Returns ``True`` when the last optimizer in the sequence has run."""
return self.optim_progress.optimizer_position >= len(self._indices)

def connect(self, **kwargs: "Loop") -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

def reset(self) -> None:
if not self.restarting:
# when reset() is called from outside (manually), we reset the loop progress
Expand Down
83 changes: 7 additions & 76 deletions src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,21 +371,16 @@ def __init__(
self._signal_connector = SignalConnector(self)
self.tuner = Tuner(self)

fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
fit_loop.connect(epoch_loop=training_epoch_loop)

# default .fit() loop
self.fit_loop = fit_loop

# default .validate() loop
# init loops
self.fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
self.fit_loop.epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
self.validate_loop = EvaluationLoop()

# default .test() loop
self.test_loop = EvaluationLoop()

# default .predict() loop
self.predict_loop = PredictionLoop()
self.fit_loop.trainer = self
self.validate_loop.trainer = self
self.test_loop.trainer = self
self.predict_loop.trainer = self

# init callbacks
# Declare attributes to be set in _callback_connector on_trainer_init
Expand Down Expand Up @@ -1103,8 +1098,6 @@ def _run_train(self) -> None:
self.model.train()
torch.set_grad_enabled(True)

self.fit_loop.trainer = self

with torch.autograd.set_detect_anomaly(self._detect_anomaly):
self.fit_loop.run()

Expand All @@ -1114,9 +1107,6 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
# reload dataloaders
self._evaluation_loop._reload_evaluation_dataloaders()

# reset trainer on this loop and all child loops in case user connected a custom loop
self._evaluation_loop.trainer = self

with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(
self.accelerator, self._inference_mode
):
Expand All @@ -1133,8 +1123,6 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:

def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
self.reset_predict_dataloader(self.lightning_module)
# reset trainer on this loop and all child loops in case user connected a custom loop
self.predict_loop.trainer = self
with _evaluation_context(self.accelerator, self._inference_mode):
return self.predict_loop.run()

Expand Down Expand Up @@ -1955,63 +1943,6 @@ def is_last_batch(self) -> bool:
"""Whether trainer is executing the last batch."""
return self.fit_loop.epoch_loop.batch_progress.is_last_batch

@property
def fit_loop(self) -> FitLoop:
return self._fit_loop

@fit_loop.setter
def fit_loop(self, loop: FitLoop) -> None:
"""Attach a custom fit loop to this Trainer.
It will run with
:meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`.
"""
loop.trainer = self
self._fit_loop = loop

@property
def validate_loop(self) -> EvaluationLoop:
return self._validate_loop

@validate_loop.setter
def validate_loop(self, loop: EvaluationLoop) -> None:
"""Attach a custom validation loop to this Trainer.
It will run with
:meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is different from the one
running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call.
"""
loop.trainer = self
self._validate_loop = loop

@property
def test_loop(self) -> EvaluationLoop:
return self._test_loop

@test_loop.setter
def test_loop(self, loop: EvaluationLoop) -> None:
"""Attach a custom test loop to this Trainer.
It will run with
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
"""
loop.trainer = self
self._test_loop = loop

@property
def predict_loop(self) -> PredictionLoop:
return self._predict_loop

@predict_loop.setter
def predict_loop(self, loop: PredictionLoop) -> None:
"""Attach a custom prediction loop to this Trainer.
It will run with
:meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
"""
loop.trainer = self
self._predict_loop = loop

@property
def _evaluation_loop(self) -> EvaluationLoop:
if self.state.fn == TrainerFn.FITTING:
Expand Down
Loading

0 comments on commit 05876b6

Please sign in to comment.