Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[pre-commit.ci] pre-commit suggestions #1697

Merged
merged 2 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ repos:
- id: detect-private-key

- repo: https://github.com/asottile/pyupgrade
rev: v3.8.0
rev: v3.14.0
hooks:
- id: pyupgrade
args: [--py38-plus]
Expand All @@ -48,20 +48,20 @@ repos:
- id: nbstripout

- repo: https://github.com/PyCQA/docformatter
rev: v1.7.3
rev: v1.7.5
hooks:
- id: docformatter
additional_dependencies: [tomli]
args: ["--in-place"]

- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.9.1
hooks:
- id: black
name: Format code

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.276
rev: v0.0.292
hooks:
- id: ruff
args: ["--fix"]
10 changes: 8 additions & 2 deletions src/flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@


class AudioClassificationData(DataModule):
"""The ``AudioClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
class methods for loading data for audio classification."""
"""The ``AudioClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of class
methods for loading data for audio classification."""

input_transform_cls = AudioClassificationInputTransform

Expand Down Expand Up @@ -141,6 +141,7 @@ def from_files(
>>> import os
>>> _ = [os.remove(f"spectrogram_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_spectrogram_{i}.png") for i in range(1, 4)]

"""

ds_kw = {
Expand Down Expand Up @@ -275,6 +276,7 @@ def from_folders(
>>> import shutil
>>> shutil.rmtree("train_folder")
>>> shutil.rmtree("predict_folder")

"""

ds_kw = {
Expand Down Expand Up @@ -365,6 +367,7 @@ def from_numpy(
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...

"""

ds_kw = {
Expand Down Expand Up @@ -453,6 +456,7 @@ def from_tensors(
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...

"""

ds_kw = {
Expand Down Expand Up @@ -607,6 +611,7 @@ def from_data_frame(
>>> shutil.rmtree("predict_folder")
>>> del train_data_frame
>>> del predict_data_frame

"""

ds_kw = {
Expand Down Expand Up @@ -854,6 +859,7 @@ def from_csv(
>>> shutil.rmtree("predict_folder")
>>> os.remove("train_data.tsv")
>>> os.remove("predict_data.tsv")

"""

ds_kw = {
Expand Down
4 changes: 4 additions & 0 deletions src/flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def from_files(
>>> import os
>>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]

"""

ds_kw = {"sampling_rate": sampling_rate}
Expand Down Expand Up @@ -302,6 +303,7 @@ def from_csv(
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
>>> os.remove("train_data.tsv")
>>> os.remove("predict_data.tsv")

"""

ds_kw = {"input_key": input_field, "sampling_rate": sampling_rate}
Expand Down Expand Up @@ -424,6 +426,7 @@ def from_json(
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
>>> os.remove("train_data.json")
>>> os.remove("predict_data.json")

"""

ds_kw = {"input_key": input_field, "sampling_rate": sampling_rate, "field": field}
Expand Down Expand Up @@ -570,6 +573,7 @@ def from_datasets(
>>> import os
>>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]

"""

ds_kw = {"sampling_rate": sampling_rate}
Expand Down
1 change: 1 addition & 0 deletions src/flash/audio/speech_recognition/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class SpeechRecognition(Task):
learning_rate: Learning rate to use for training, defaults to ``1e-5``.
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.

"""

backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES
Expand Down
2 changes: 2 additions & 0 deletions src/flash/core/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def from_task(cls, task: "flash.Task", **kwargs) -> "Adapter":
"""Instantiate the adapter from the given :class:`~flash.core.model.Task`.

This includes resolution / creation of backbones / heads and any other provider specific options.

"""

def forward(self, x: Any) -> Any:
Expand Down Expand Up @@ -73,6 +74,7 @@ class AdapterTask(Task):
Args:
adapter: The :class:`~flash.core.adapter.Adapter` to wrap.
kwargs: Keyword arguments to be passed to the base :class:`~flash.core.model.Task`.

"""

def __init__(self, adapter: Adapter, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions src/flash/core/data/base_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage):
As the :class:`~flash.core.data.io.input_transform.InputTransform` hooks are injected within
the threaded workers of the DataLoader,
the data won't be accessible when using ``num_workers > 0``.

"""

def _show(
Expand Down
2 changes: 2 additions & 0 deletions src/flash/core/data/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class FlashCallback(Callback):
Same as PyTorch Lightning, Callbacks can be provided directly to the Trainer::

trainer = Trainer(callbacks=[MyCustomCallback()])

"""

def on_per_sample_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
Expand Down Expand Up @@ -146,6 +147,7 @@ def from_inputs(
'val': {},
'predict': {}
}

"""

batches: dict
Expand Down
5 changes: 3 additions & 2 deletions src/flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@

class DatasetInput(Input):
"""The ``DatasetInput`` implements default behaviours for data sources which expect the input to
:meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset`
"""
:meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset`"""

def load_sample(self, sample: Any) -> Dict[str, Any]:
if isinstance(sample, tuple) and len(sample) == 2:
Expand Down Expand Up @@ -103,6 +102,7 @@ class DataModule(pl.LightningDataModule):
>>> datamodule = DataModule(train_input, sampler=WeightedRandomSampler([0.1, 0.5], 2), batch_size=1)
>>> print(datamodule.train_dataloader().sampler) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<torch.utils.data.sampler.WeightedRandomSampler object at ...>

"""

input_transform_cls = InputTransform
Expand Down Expand Up @@ -399,6 +399,7 @@ def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
"""This function is used to configure a :class:`~flash.core.data.callback.BaseDataFetcher`.

Override with your custom one.

"""
return BaseDataFetcher()

Expand Down
2 changes: 2 additions & 0 deletions src/flash/core/data/io/classification_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ClassificationInputMixin(Properties):
targets and store metadata like ``labels`` and ``num_classes``.
* In the ``load_sample`` method, use ``format_target`` to convert the target to a standard format for use with our
tasks.

"""

target_formatter: TargetFormatter
Expand All @@ -46,6 +47,7 @@ def load_target_metadata(
rather than inferring from the targets.
add_background: If ``True``, a background class will be inserted as class zero if ``labels`` and
``num_classes`` are being inferred.

"""
self.target_formatter = target_formatter
if target_formatter is None and targets is not None:
Expand Down
20 changes: 17 additions & 3 deletions src/flash/core/data/io/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _validate_input(input: "InputBase") -> None:
Raises:
RuntimeError: If the ``input`` is of type ``Input`` and it's ``data`` attribute does not support ``len``.
RuntimeError: If the ``input`` is of type ``IterableInput`` and it's ``data`` attribute does support ``len``.

"""
if input.data is not None:
if isinstance(input, Input) and not _has_len(input.data):
Expand All @@ -122,6 +123,7 @@ def _wrap_init(class_dict: Dict[str, Any]) -> None:

Args:
class_dict: The class construction dict, optionally containing an init to wrap.

"""
if "__init__" in class_dict:
fn = class_dict["__init__"]
Expand Down Expand Up @@ -153,14 +155,15 @@ def __new__(mcs, name: str, bases: Tuple, class_dict: Dict[str, Any]) -> "_Itera

class InputBase(Properties, metaclass=_InputMeta):
"""``InputBase`` is the base class for the :class:`~flash.core.data.io.input.Input` and
:class:`~flash.core.data.io.input.IterableInput` dataset implementations in Flash. These datasets are
constructed via the ``load_data`` and ``load_sample`` hooks, which allow a single dataset object to include custom
loading logic according to the running stage (e.g. train, validate, test, predict).
:class:`~flash.core.data.io.input.IterableInput` dataset implementations in Flash. These datasets are constructed
via the ``load_data`` and ``load_sample`` hooks, which allow a single dataset object to include custom loading logic
according to the running stage (e.g. train, validate, test, predict).

Args:
running_stage: The running stage for which the input will be used.
*args: Any arguments that are to be passed to the ``load_data`` hook.
**kwargs: Any additional keyword arguments to pass to the ``load_data`` hook.

"""

def __init__(self, running_stage: RunningStage, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -194,6 +197,7 @@ def load_data(*args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
Args:
*args: Any arguments that the input requires.
**kwargs: Any additional keyword arguments that the input requires.

"""
return args[0]

Expand All @@ -203,6 +207,7 @@ def train_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable
Args:
*args: Any arguments that the input requires.
**kwargs: Any additional keyword arguments that the input requires.

"""
return self.load_data(*args, **kwargs)

Expand All @@ -212,6 +217,7 @@ def val_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
Args:
*args: Any arguments that the input requires.
**kwargs: Any additional keyword arguments that the input requires.

"""
return self.load_data(*args, **kwargs)

Expand All @@ -221,6 +227,7 @@ def test_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]
Args:
*args: Any arguments that the input requires.
**kwargs: Any additional keyword arguments that the input requires.

"""
return self.load_data(*args, **kwargs)

Expand All @@ -230,6 +237,7 @@ def predict_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterab
Args:
*args: Any arguments that the input requires.
**kwargs: Any additional keyword arguments that the input requires.

"""
return self.load_data(*args, **kwargs)

Expand All @@ -240,6 +248,7 @@ def load_sample(sample: Dict[str, Any]) -> Any:

Args:
sample: A single sample from the output of the ``load_data`` hook.

"""
return sample

Expand All @@ -248,6 +257,7 @@ def train_load_sample(self, sample: Dict[str, Any]) -> Any:

Args:
sample: A single sample from the output of the ``load_data`` hook.

"""
return self.load_sample(sample)

Expand All @@ -256,6 +266,7 @@ def val_load_sample(self, sample: Dict[str, Any]) -> Any:

Args:
sample: A single sample from the output of the ``load_data`` hook.

"""
return self.load_sample(sample)

Expand All @@ -264,6 +275,7 @@ def test_load_sample(self, sample: Dict[str, Any]) -> Any:

Args:
sample: A single sample from the output of the ``load_data`` hook.

"""
return self.load_sample(sample)

Expand All @@ -272,13 +284,15 @@ def predict_load_sample(self, sample: Dict[str, Any]) -> Any:

Args:
sample: A single sample from the output of the ``load_data`` hook.

"""
return self.load_sample(sample)

def __bool__(self):
"""If ``self.data`` is ``None`` then the ``InputBase`` is considered falsey.

This allows for quickly checking whether or not the ``InputBase`` is populated with data.

"""
return self.data is not None

Expand Down
1 change: 1 addition & 0 deletions src/flash/core/data/io/transform_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class TransformPredictions(Callback):
Args:
output_transform: The :class:`~flash.core.data.io.output_transform.OutputTransform` to apply.
output: The :class:`~flash.core.data.io.output.Output` to apply.

"""

def __init__(self, output_transform: OutputTransform, output: Output):
Expand Down
1 change: 1 addition & 0 deletions src/flash/core/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class ApplyToKeys(nn.Sequential):
Args:
keys: The key (``str``) or sequence of keys (``Sequence[str]``) to extract and forward to the transforms.
args: The transforms, passed to the ``nn.Sequential`` super constructor.

"""

def __init__(self, keys: Union[str, Sequence[str]], *args):
Expand Down
Loading
Loading