Loop flattening: remove .connect() #16384

Merged · 3 commits · Jan 17, 2023
87 changes: 0 additions & 87 deletions docs/source-pytorch/extensions/loops.rst
@@ -259,28 +259,6 @@ run (optional)

----------

Subloops
--------

When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.loop.Loop.connect` method:

.. code-block:: python

    # Optional: stitch back the trainer arguments
    epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps)
    # Optional: connect child loops as they might have existing state
    epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop)
    # Instantiate and connect the loop.
    trainer.fit_loop.connect(epoch_loop=epoch_loop)
    trainer.fit(model)

More about the built-in loops and how they are composed is explained in the next section.

.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/connect-epoch-loop.gif
    :alt: Animation showing how to connect a custom subloop

----------

Built-in Loops
--------------

@@ -342,71 +320,6 @@ Each of these :code:`for`-loops represents a class implementing the :class:`~pytorch_lightning.loops.loop.Loop`
It simply iterates over the prediction dataloaders, one after the other, by calling :code:`PredictionEpochLoop.run()` in its :code:`advance()` method.
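As a reading aid, here is a minimal, self-contained sketch of the contract all of these loops share: :code:`run()` calls :code:`advance()` until :code:`done` is true. The class below is hypothetical and only illustrates the pattern; it is not part of this diff.

.. code-block:: python

    # Hypothetical loop illustrating the run()/advance()/done contract
    # shared by the built-in loops (illustrative only, not Lightning code).
    class CountingLoop:
        def __init__(self, limit: int) -> None:
            self.limit = limit
            self.count = 0

        @property
        def done(self) -> bool:
            # A real loop checks progress, e.g. batches seen vs. max_steps.
            return self.count >= self.limit

        def reset(self) -> None:
            self.count = 0

        def advance(self) -> None:
            # A real loop would process one batch or run a child loop here.
            self.count += 1

        def run(self) -> int:
            self.reset()
            while not self.done:
                self.advance()
            return self.count


    assert CountingLoop(limit=3).run() == 3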


----------

Available Loops in Lightning Flash
----------------------------------

`Active Learning <https://en.wikipedia.org/wiki/Active_learning_(machine_learning)>`__ is a machine learning practice in which the user interacts with the learner to provide new labels when required.

You can find a real use case in `Lightning Flash <https://github.com/Lightning-AI/lightning-flash>`_.

Flash implements the :code:`ActiveLearningLoop` that you can use together with the :code:`ActiveLearningDataModule` to label new data on the fly.
To run the following demo, install Flash and `BaaL <https://github.com/ElementAI/baal>`__ first:

.. code-block:: bash

    pip install lightning-flash[image] baal

.. code-block:: python

    import torch

    import flash
    from flash.core.classification import Probabilities
    from flash.core.data.utils import download_data
    from flash.image import ImageClassificationData, ImageClassifier
    from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop

    # 1. Create the DataModule
    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

    # Implement the research use-case where we mask labels from the labelled dataset.
    datamodule = ActiveLearningDataModule(
        ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2),
        initial_num_labels=5,
        val_split=0.1,
    )

    # 2. Build the task
    head = torch.nn.Sequential(
        torch.nn.Dropout(p=0.1),
        torch.nn.Linear(512, datamodule.num_classes),
    )
    model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, output=Probabilities())


    # 3.1 Create the trainer
    trainer = flash.Trainer(max_epochs=3)

    # 3.2 Create the active learning loop and connect it to the trainer
    active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
    active_learning_loop.connect(trainer.fit_loop)
    trainer.fit_loop = active_learning_loop

    # 3.3 Finetune
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")

    # 4. Predict what's on a few images! ants or bees?
    predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
    print(predictions)

    # 5. Save the model!
    trainer.save_checkpoint("image_classification_model.pt")
Here is the `Active Learning Loop example <https://github.com/Lightning-AI/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/Lightning-AI/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py>`_.


----------

Advanced Examples
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -55,6 +55,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed support for loop customization
  * Removed `Loop.replace()` ([#16361](https://github.com/Lightning-AI/lightning/pull/16361))
  * Removed `Loop.connect()` ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
  * Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))

- Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
  * Removed the `LightningModule.truncated_bptt_steps` attribute
4 changes: 0 additions & 4 deletions src/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -78,10 +78,6 @@ def prefetch_batches(self) -> int:
        is_unsized = batches[self.current_dataloader_idx] == float("inf")
        return int(is_unsized)

    def connect(self, epoch_loop: EvaluationEpochLoop) -> None:  # type: ignore[override]
        """Connect the evaluation epoch loop with this loop."""
        self.epoch_loop = epoch_loop

    @property
    def done(self) -> bool:
        """Returns whether all dataloaders are processed or evaluation should be skipped altogether."""
4 changes: 0 additions & 4 deletions src/pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -66,10 +66,6 @@ def dataloaders(self) -> Sequence[DataLoader]:
    def skip(self) -> bool:
        return sum(self.max_batches) == 0

    def connect(self, epoch_loop: PredictionEpochLoop) -> None:  # type: ignore[override]
        """Connect the prediction epoch loop with this loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of the loop for a new run."""
        self.predictions = []
3 changes: 0 additions & 3 deletions src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -38,9 +38,6 @@ def should_store_predictions(self) -> bool:
        any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
        return self.return_predictions or any_pred

    def connect(self, **kwargs: "Loop") -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loop's internal state."""
        self._seen_batch_indices = []
14 changes: 0 additions & 14 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -121,20 +121,6 @@ def done(self) -> bool:

        return False

    def connect(  # type: ignore[override]
        self,
        optimizer_loop: Optional[OptimizerLoop] = None,
        manual_loop: Optional[ManualOptimization] = None,
        val_loop: Optional["loops.EvaluationLoop"] = None,
    ) -> None:
        """Optionally connect a custom batch or validation loop to this training epoch loop."""
        if optimizer_loop is not None:
            self.optimizer_loop = optimizer_loop
        if manual_loop is not None:
            self.manual_loop = manual_loop
        if val_loop is not None:
            self.val_loop = val_loop

    def reset(self) -> None:
        """Resets the internal state of the loop for a new run."""
        if self.restarting:
4 changes: 0 additions & 4 deletions src/pytorch_lightning/loops/fit_loop.py
@@ -169,10 +169,6 @@ def skip(self) -> bool:
        # until `on_run_start`, we use `limit_train_batches` instead
        return self.done or self.trainer.limit_train_batches == 0

    def connect(self, epoch_loop: TrainingEpochLoop) -> None:  # type: ignore[override]
        """Connects a training epoch loop to this fit loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of this loop."""
        if self.restarting:
6 changes: 0 additions & 6 deletions src/pytorch_lightning/loops/loop.py
@@ -100,12 +100,6 @@ def skip(self):
"""
return False

def connect(self, **kwargs: "Loop") -> None:
"""Optionally connect one or multiple loops to this one.

Linked loops should form a tree.
"""

def on_skip(self) -> T:
"""The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.

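For context on the removed docstring above ("Linked loops should form a tree"), here is the tree those `connect` calls used to assemble, reconstructed only from attribute names visible elsewhere in this diff (fit branch shown; the sketch is illustrative, not part of the change):

# Loop tree implied by the attributes touched in this diff (sketch only):
#
#   FitLoop
#   └── TrainingEpochLoop          trainer.fit_loop.epoch_loop
#       ├── OptimizerLoop          epoch_loop.optimizer_loop
#       ├── ManualOptimization     epoch_loop.manual_loop
#       └── EvaluationLoop         epoch_loop.val_loop
#
# Before this PR a child could be swapped via connect(); after it, children
# are plain attributes assigned once when the Trainer is constructed.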
3 changes: 0 additions & 3 deletions src/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -172,9 +172,6 @@ def done(self) -> bool:
"""Returns ``True`` when the last optimizer in the sequence has run."""
return self.optim_progress.optimizer_position >= len(self._indices)

def connect(self, **kwargs: "Loop") -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

def reset(self) -> None:
if not self.restarting:
# when reset() is called from outside (manually), we reset the loop progress
83 changes: 7 additions & 76 deletions src/pytorch_lightning/trainer/trainer.py
@@ -371,21 +371,16 @@ def __init__(
        self._signal_connector = SignalConnector(self)
        self.tuner = Tuner(self)

        fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
        training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
        fit_loop.connect(epoch_loop=training_epoch_loop)

        # default .fit() loop
        self.fit_loop = fit_loop

        # default .validate() loop
        # init loops
        self.fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
        self.fit_loop.epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
        self.validate_loop = EvaluationLoop()

        # default .test() loop
        self.test_loop = EvaluationLoop()

        # default .predict() loop
        self.predict_loop = PredictionLoop()
        self.fit_loop.trainer = self
        self.validate_loop.trainer = self
        self.test_loop.trainer = self
        self.predict_loop.trainer = self

        # init callbacks
        # Declare attributes to be set in _callback_connector on_trainer_init
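To summarize the hunk above: the connect-based wiring is replaced by direct attribute assignment, and attaching the trainer reference now happens explicitly in `__init__`. A runnable sketch with hypothetical stand-in classes (only the attribute flow mirrors the diff; these are not the real Lightning objects):

# Stub classes mirroring only the wiring shown above (not real Lightning code).
class TrainingEpochLoop:
    def __init__(self, min_steps=None, max_steps=-1):
        self.min_steps, self.max_steps = min_steps, max_steps
        self.trainer = None

class FitLoop:
    def __init__(self, min_epochs=0, max_epochs=None):
        self.min_epochs, self.max_epochs = min_epochs, max_epochs
        self.epoch_loop = None
        self.trainer = None

class Trainer:
    def __init__(self, min_epochs=0, max_epochs=None, min_steps=None, max_steps=-1):
        # New style (this PR): assign child loops directly, no connect().
        self.fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
        self.fit_loop.epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
        # The removed property setters used to do this; now it is explicit.
        self.fit_loop.trainer = self

trainer = Trainer(max_epochs=3)
assert trainer.fit_loop.trainer is trainer
assert trainer.fit_loop.epoch_loop.max_steps == -1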
@@ -1103,8 +1098,6 @@ def _run_train(self) -> None:
        self.model.train()
        torch.set_grad_enabled(True)

        self.fit_loop.trainer = self

        with torch.autograd.set_detect_anomaly(self._detect_anomaly):
            self.fit_loop.run()

@@ -1114,9 +1107,6 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
        # reload dataloaders
        self._evaluation_loop._reload_evaluation_dataloaders()

        # reset trainer on this loop and all child loops in case user connected a custom loop
        self._evaluation_loop.trainer = self

        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(
            self.accelerator, self._inference_mode
        ):
@@ -1133,8 +1123,6 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:

    def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
        self.reset_predict_dataloader(self.lightning_module)
        # reset trainer on this loop and all child loops in case user connected a custom loop
        self.predict_loop.trainer = self
        with _evaluation_context(self.accelerator, self._inference_mode):
            return self.predict_loop.run()

@@ -1969,63 +1957,6 @@ def is_last_batch(self) -> bool:
"""Whether trainer is executing the last batch."""
return self.fit_loop.epoch_loop.batch_progress.is_last_batch

@property
def fit_loop(self) -> FitLoop:
return self._fit_loop

@fit_loop.setter
def fit_loop(self, loop: FitLoop) -> None:
"""Attach a custom fit loop to this Trainer.

It will run with
:meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`.
"""
loop.trainer = self
self._fit_loop = loop

@property
def validate_loop(self) -> EvaluationLoop:
return self._validate_loop

@validate_loop.setter
def validate_loop(self, loop: EvaluationLoop) -> None:
"""Attach a custom validation loop to this Trainer.

It will run with
:meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is different from the one
running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call.
"""
loop.trainer = self
self._validate_loop = loop

@property
def test_loop(self) -> EvaluationLoop:
return self._test_loop

@test_loop.setter
def test_loop(self, loop: EvaluationLoop) -> None:
"""Attach a custom test loop to this Trainer.

It will run with
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
"""
loop.trainer = self
self._test_loop = loop

@property
def predict_loop(self) -> PredictionLoop:
return self._predict_loop

@predict_loop.setter
def predict_loop(self, loop: PredictionLoop) -> None:
"""Attach a custom prediction loop to this Trainer.

It will run with
:meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
"""
loop.trainer = self
self._predict_loop = loop

@property
def _evaluation_loop(self) -> EvaluationLoop:
if self.state.fn == TrainerFn.FITTING: