diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 73a69f2b04..c419164f30 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: - id: detect-private-key - repo: https://github.com/asottile/pyupgrade - rev: v3.8.0 + rev: v3.14.0 hooks: - id: pyupgrade args: [--py38-plus] @@ -48,20 +48,20 @@ repos: - id: nbstripout - repo: https://github.com/PyCQA/docformatter - rev: v1.7.3 + rev: v1.7.5 hooks: - id: docformatter additional_dependencies: [tomli] args: ["--in-place"] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.9.1 hooks: - id: black name: Format code - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.276 + rev: v0.0.292 hooks: - id: ruff args: ["--fix"] diff --git a/src/flash/audio/classification/data.py b/src/flash/audio/classification/data.py index 412713a9e0..fec824122f 100644 --- a/src/flash/audio/classification/data.py +++ b/src/flash/audio/classification/data.py @@ -42,8 +42,8 @@ class AudioClassificationData(DataModule): - """The ``AudioClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of - class methods for loading data for audio classification.""" + """The ``AudioClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of class + methods for loading data for audio classification.""" input_transform_cls = AudioClassificationInputTransform @@ -141,6 +141,7 @@ def from_files( >>> import os >>> _ = [os.remove(f"spectrogram_{i}.png") for i in range(1, 4)] >>> _ = [os.remove(f"predict_spectrogram_{i}.png") for i in range(1, 4)] + """ ds_kw = { @@ -275,6 +276,7 @@ def from_folders( >>> import shutil >>> shutil.rmtree("train_folder") >>> shutil.rmtree("predict_folder") + """ ds_kw = { @@ -365,6 +367,7 @@ def from_numpy( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { @@ -453,6 +456,7 @@ def from_tensors( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { @@ -607,6 +611,7 @@ def from_data_frame( >>> shutil.rmtree("predict_folder") >>> del train_data_frame >>> del predict_data_frame + """ ds_kw = { @@ -854,6 +859,7 @@ def from_csv( >>> shutil.rmtree("predict_folder") >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") + """ ds_kw = { diff --git a/src/flash/audio/speech_recognition/data.py b/src/flash/audio/speech_recognition/data.py index 90117bf90b..85a5008210 100644 --- a/src/flash/audio/speech_recognition/data.py +++ b/src/flash/audio/speech_recognition/data.py @@ -117,6 +117,7 @@ def from_files( >>> import os >>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)] >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] + """ ds_kw = {"sampling_rate": sampling_rate} @@ -302,6 +303,7 @@ def from_csv( >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") + """ ds_kw = {"input_key": input_field, "sampling_rate": sampling_rate} @@ -424,6 +426,7 @@ def from_json( >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] >>> os.remove("train_data.json") >>> os.remove("predict_data.json") + """ ds_kw = {"input_key": input_field, "sampling_rate": sampling_rate, "field": field} @@ -570,6 +573,7 @@ def from_datasets( >>> import os >>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)] >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] + """ ds_kw = {"sampling_rate": sampling_rate} diff --git a/src/flash/audio/speech_recognition/model.py b/src/flash/audio/speech_recognition/model.py index a3c379c642..b88e3a02bd 100644 --- a/src/flash/audio/speech_recognition/model.py +++ b/src/flash/audio/speech_recognition/model.py @@ -45,6 +45,7 @@ class SpeechRecognition(Task): learning_rate: Learning rate to use for training, defaults to ``1e-5``. optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. + """ backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES diff --git a/src/flash/core/adapter.py b/src/flash/core/adapter.py index 559c90ce13..71a5768ce2 100644 --- a/src/flash/core/adapter.py +++ b/src/flash/core/adapter.py @@ -35,6 +35,7 @@ def from_task(cls, task: "flash.Task", **kwargs) -> "Adapter": """Instantiate the adapter from the given :class:`~flash.core.model.Task`. This includes resolution / creation of backbones / heads and any other provider specific options. + """ def forward(self, x: Any) -> Any: @@ -73,6 +74,7 @@ class AdapterTask(Task): Args: adapter: The :class:`~flash.core.adapter.Adapter` to wrap. kwargs: Keyword arguments to be passed to the base :class:`~flash.core.model.Task`. + """ def __init__(self, adapter: Adapter, **kwargs): diff --git a/src/flash/core/data/base_viz.py b/src/flash/core/data/base_viz.py index 9f8c90cae7..4370f00588 100644 --- a/src/flash/core/data/base_viz.py +++ b/src/flash/core/data/base_viz.py @@ -96,6 +96,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage): As the :class:`~flash.core.data.io.input_transform.InputTransform` hooks are injected within the threaded workers of the DataLoader, the data won't be accessible when using ``num_workers > 0``. + """ def _show( diff --git a/src/flash/core/data/callback.py b/src/flash/core/data/callback.py index c99c0ddc9a..6f6872f34f 100644 --- a/src/flash/core/data/callback.py +++ b/src/flash/core/data/callback.py @@ -19,6 +19,7 @@ class FlashCallback(Callback): Same as PyTorch Lightning, Callbacks can be provided directly to the Trainer:: trainer = Trainer(callbacks=[MyCustomCallback()]) + """ def on_per_sample_transform(self, sample: Tensor, running_stage: RunningStage) -> None: @@ -146,6 +147,7 @@ def from_inputs( 'val': {}, 'predict': {} } + """ batches: dict diff --git a/src/flash/core/data/data_module.py b/src/flash/core/data/data_module.py index 9118300c96..af30a47554 100644 --- a/src/flash/core/data/data_module.py +++ b/src/flash/core/data/data_module.py @@ -44,8 +44,7 @@ class DatasetInput(Input): """The ``DatasetInput`` implements default behaviours for data sources which expect the input to - :meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset` - """ + :meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset`""" def load_sample(self, sample: Any) -> Dict[str, Any]: if isinstance(sample, tuple) and len(sample) == 2: @@ -103,6 +102,7 @@ class DataModule(pl.LightningDataModule): >>> datamodule = DataModule(train_input, sampler=WeightedRandomSampler([0.1, 0.5], 2), batch_size=1) >>> print(datamodule.train_dataloader().sampler) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ input_transform_cls = InputTransform @@ -399,6 +399,7 @@ def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: """This function is used to configure a :class:`~flash.core.data.callback.BaseDataFetcher`. Override with your custom one. + """ return BaseDataFetcher() diff --git a/src/flash/core/data/io/classification_input.py b/src/flash/core/data/io/classification_input.py index 31daf8273c..579696f432 100644 --- a/src/flash/core/data/io/classification_input.py +++ b/src/flash/core/data/io/classification_input.py @@ -25,6 +25,7 @@ class ClassificationInputMixin(Properties): targets and store metadata like ``labels`` and ``num_classes``. * In the ``load_sample`` method, use ``format_target`` to convert the target to a standard format for use with our tasks. + """ target_formatter: TargetFormatter @@ -46,6 +47,7 @@ def load_target_metadata( rather than inferring from the targets. add_background: If ``True``, a background class will be inserted as class zero if ``labels`` and ``num_classes`` are being inferred. + """ self.target_formatter = target_formatter if target_formatter is None and targets is not None: diff --git a/src/flash/core/data/io/input.py b/src/flash/core/data/io/input.py index 199aee6f00..60e7a630e4 100644 --- a/src/flash/core/data/io/input.py +++ b/src/flash/core/data/io/input.py @@ -108,6 +108,7 @@ def _validate_input(input: "InputBase") -> None: Raises: RuntimeError: If the ``input`` is of type ``Input`` and it's ``data`` attribute does not support ``len``. RuntimeError: If the ``input`` is of type ``IterableInput`` and it's ``data`` attribute does support ``len``. + """ if input.data is not None: if isinstance(input, Input) and not _has_len(input.data): @@ -122,6 +123,7 @@ def _wrap_init(class_dict: Dict[str, Any]) -> None: Args: class_dict: The class construction dict, optionally containing an init to wrap. + """ if "__init__" in class_dict: fn = class_dict["__init__"] @@ -153,14 +155,15 @@ def __new__(mcs, name: str, bases: Tuple, class_dict: Dict[str, Any]) -> "_Itera class InputBase(Properties, metaclass=_InputMeta): """``InputBase`` is the base class for the :class:`~flash.core.data.io.input.Input` and - :class:`~flash.core.data.io.input.IterableInput` dataset implementations in Flash. These datasets are - constructed via the ``load_data`` and ``load_sample`` hooks, which allow a single dataset object to include custom - loading logic according to the running stage (e.g. train, validate, test, predict). + :class:`~flash.core.data.io.input.IterableInput` dataset implementations in Flash. These datasets are constructed + via the ``load_data`` and ``load_sample`` hooks, which allow a single dataset object to include custom loading logic + according to the running stage (e.g. train, validate, test, predict). Args: running_stage: The running stage for which the input will be used. *args: Any arguments that are to be passed to the ``load_data`` hook. **kwargs: Any additional keyword arguments to pass to the ``load_data`` hook. + """ def __init__(self, running_stage: RunningStage, *args: Any, **kwargs: Any) -> None: @@ -194,6 +197,7 @@ def load_data(*args: Any, **kwargs: Any) -> Union[Sequence, Iterable]: Args: *args: Any arguments that the input requires. **kwargs: Any additional keyword arguments that the input requires. + """ return args[0] @@ -203,6 +207,7 @@ def train_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable Args: *args: Any arguments that the input requires. **kwargs: Any additional keyword arguments that the input requires. + """ return self.load_data(*args, **kwargs) @@ -212,6 +217,7 @@ def val_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]: Args: *args: Any arguments that the input requires. **kwargs: Any additional keyword arguments that the input requires. + """ return self.load_data(*args, **kwargs) @@ -221,6 +227,7 @@ def test_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable] Args: *args: Any arguments that the input requires. **kwargs: Any additional keyword arguments that the input requires. + """ return self.load_data(*args, **kwargs) @@ -230,6 +237,7 @@ def predict_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterab Args: *args: Any arguments that the input requires. **kwargs: Any additional keyword arguments that the input requires. + """ return self.load_data(*args, **kwargs) @@ -240,6 +248,7 @@ def load_sample(sample: Dict[str, Any]) -> Any: Args: sample: A single sample from the output of the ``load_data`` hook. + """ return sample @@ -248,6 +257,7 @@ def train_load_sample(self, sample: Dict[str, Any]) -> Any: Args: sample: A single sample from the output of the ``load_data`` hook. + """ return self.load_sample(sample) @@ -256,6 +266,7 @@ def val_load_sample(self, sample: Dict[str, Any]) -> Any: Args: sample: A single sample from the output of the ``load_data`` hook. + """ return self.load_sample(sample) @@ -264,6 +275,7 @@ def test_load_sample(self, sample: Dict[str, Any]) -> Any: Args: sample: A single sample from the output of the ``load_data`` hook. + """ return self.load_sample(sample) @@ -272,6 +284,7 @@ def predict_load_sample(self, sample: Dict[str, Any]) -> Any: Args: sample: A single sample from the output of the ``load_data`` hook. + """ return self.load_sample(sample) @@ -279,6 +292,7 @@ def __bool__(self): """If ``self.data`` is ``None`` then the ``InputBase`` is considered falsey. This allows for quickly checking whether or not the ``InputBase`` is populated with data. + """ return self.data is not None diff --git a/src/flash/core/data/io/transform_predictions.py b/src/flash/core/data/io/transform_predictions.py index a486c1659e..fc4588898f 100644 --- a/src/flash/core/data/io/transform_predictions.py +++ b/src/flash/core/data/io/transform_predictions.py @@ -28,6 +28,7 @@ class TransformPredictions(Callback): Args: output_transform: The :class:`~flash.core.data.io.output_transform.OutputTransform` to apply. output: The :class:`~flash.core.data.io.output.Output` to apply. + """ def __init__(self, output_transform: OutputTransform, output: Output): diff --git a/src/flash/core/data/transforms.py b/src/flash/core/data/transforms.py index 6aa428e8c0..9825103522 100644 --- a/src/flash/core/data/transforms.py +++ b/src/flash/core/data/transforms.py @@ -66,6 +66,7 @@ class ApplyToKeys(nn.Sequential): Args: keys: The key (``str``) or sequence of keys (``Sequence[str]``) to extract and forward to the transforms. args: The transforms, passed to the ``nn.Sequential`` super constructor. + """ def __init__(self, keys: Union[str, Sequence[str]], *args): diff --git a/src/flash/core/data/utilities/classification.py b/src/flash/core/data/utilities/classification.py index 6bd6992a5a..0492bfa3ff 100644 --- a/src/flash/core/data/utilities/classification.py +++ b/src/flash/core/data/utilities/classification.py @@ -73,6 +73,7 @@ class TargetFormatter: >>> formatter = CustomStringTargetFormatter() >>> formatter("#1") 1 + """ multi_label: ClassVar[Optional[bool]] = None @@ -109,6 +110,7 @@ class SingleNumericTargetFormatter(TargetFormatter): 5 >>> formatter(torch.tensor(5)) 5 + """ multi_label: ClassVar[Optional[bool]] = False @@ -137,6 +139,7 @@ class SingleLabelTargetFormatter(TargetFormatter): 0 >>> formatter(["dog"]) 1 + """ multi_label: ClassVar[Optional[bool]] = False @@ -167,6 +170,7 @@ class SingleBinaryTargetFormatter(TargetFormatter): 0 >>> formatter(torch.tensor([0, 1])) 1 + """ multi_label: ClassVar[Optional[bool]] = False @@ -196,6 +200,7 @@ class MultiNumericTargetFormatter(TargetFormatter): [0, 0, 1, 0, 0, 1, 0, 0, 0, 0] >>> formatter(torch.tensor([2, 5])) [0, 0, 1, 0, 0, 1, 0, 0, 0, 0] + """ multi_label: ClassVar[Optional[bool]] = True @@ -224,6 +229,7 @@ class MultiLabelTargetFormatter(SingleLabelTargetFormatter): [0, 1, 1] >>> formatter(["bird"]) [1, 0, 0] + """ multi_label: ClassVar[Optional[bool]] = True @@ -253,6 +259,7 @@ class CommaDelimitedMultiLabelTargetFormatter(MultiLabelTargetFormatter): [0, 1, 1] >>> formatter("bird") [1, 0, 0] + """ multi_label: ClassVar[Optional[bool]] = True @@ -278,6 +285,7 @@ class SpaceDelimitedTargetFormatter(MultiLabelTargetFormatter): [0, 1, 1] >>> formatter("bird") [1, 0, 0] + """ multi_label: ClassVar[Optional[bool]] = True @@ -304,6 +312,7 @@ class MultiBinaryTargetFormatter(TargetFormatter): [0, 1, 1] >>> formatter(torch.tensor([1, 0, 0])) [1, 0, 0] + """ multi_label: ClassVar[Optional[bool]] = True @@ -330,6 +339,7 @@ class MultiSoftTargetFormatter(MultiBinaryTargetFormatter): [0.1, 0.9, 0.6] >>> formatter(torch.tensor([0.9, 0.6, 0.7])) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE [0..., 0..., 0...] + """ binary: ClassVar[Optional[bool]] = False @@ -353,6 +363,7 @@ def _get_target_formatter_type(target: Any) -> Type[TargetFormatter]: Args: target: A target that is one of: a single target, a list of targets, a comma delimited string. + """ if isinstance(target, str): target = _strip(target) @@ -392,6 +403,7 @@ def _resolve_target_formatter(a: Type[TargetFormatter], b: Type[TargetFormatter] Raises: ValueError: If the two target formatters could not be resolved. + """ if a is b: return a @@ -423,6 +435,7 @@ def _get_target_details( Returns: (labels, num_classes): Tuple containing the inferred ``labels`` (or ``None`` if no labels could be inferred) and ``num_classes``. + """ targets = _as_list(targets) if target_formatter_type.numeric: @@ -481,6 +494,7 @@ def get_target_formatter( Returns: The target formatter to use when formatting targets. + """ targets = _as_list(targets) target_formatter_type: Type[TargetFormatter] = reduce( diff --git a/src/flash/core/data/utilities/collate.py b/src/flash/core/data/utilities/collate.py index 54f4a75b72..5e82c7434a 100644 --- a/src/flash/core/data/utilities/collate.py +++ b/src/flash/core/data/utilities/collate.py @@ -44,6 +44,7 @@ def wrap_collate(collate): Returns: The wrapped collate function. + """ return functools.partial(_wrap_collate, collate) @@ -61,5 +62,6 @@ def default_collate(batch: List[Any]) -> Any: Returns: The collated batch. + """ return _default_collate(batch) diff --git a/src/flash/core/data/utilities/sort.py b/src/flash/core/data/utilities/sort.py index 521b2550f5..060d37e1a8 100644 --- a/src/flash/core/data/utilities/sort.py +++ b/src/flash/core/data/utilities/sort.py @@ -28,5 +28,6 @@ def sorted_alphanumeric(iterable: Iterable[str]) -> Iterable[str]: this returns ``["class_1", "class_2", "class_11"]``. Copied from: https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/ + """ return sorted(iterable, key=_alphanumeric_key) diff --git a/src/flash/core/integrations/fiftyone/utils.py b/src/flash/core/integrations/fiftyone/utils.py index 8a14dd21a5..5470079ea6 100644 --- a/src/flash/core/integrations/fiftyone/utils.py +++ b/src/flash/core/integrations/fiftyone/utils.py @@ -96,6 +96,7 @@ class FiftyOneLabelUtilities: label_field: The field in the ``SampleCollection`` containing the ground truth labels. label_cls: The ``FiftyOne.Label`` subclass to expect ground truth labels to be instances of. If ``None``, defaults to ``FiftyOne.Label``. + """ def __init__(self, label_field: str = "ground_truth", label_cls: Optional[Type[Label]] = None): diff --git a/src/flash/core/integrations/labelstudio/input.py b/src/flash/core/integrations/labelstudio/input.py index ed2aa63bdb..b2a720d859 100644 --- a/src/flash/core/integrations/labelstudio/input.py +++ b/src/flash/core/integrations/labelstudio/input.py @@ -222,19 +222,22 @@ def _split_train_val_data(data: Dict, split: float = 0) -> List[Dict]: class LabelStudioInput(BaseLabelStudioInput, Input): - """The ``LabelStudioInput`` expects the input to - :meth:`~flash.core.data.io.input.Input.load_data` to be a json export from label studio.""" + """The ``LabelStudioInput`` expects the input to :meth:`~flash.core.data.io.input.Input.load_data` to be a json + export from label studio.""" class LabelStudioIterableInput(BaseLabelStudioInput, IterableInput): - """The ``LabelStudioInput`` expects the input to - :meth:`~flash.core.data.io.input.Input.load_data` to be a json export from label studio.""" + """The ``LabelStudioInput`` expects the input to :meth:`~flash.core.data.io.input.Input.load_data` to be a json + export from label studio.""" class LabelStudioImageClassificationInput(LabelStudioInput): - """The ``LabelStudioImageInput`` expects the input to - :meth:`~flash.core.data.io.input.Input.load_data` to be a json export from label studio. - Export data should point to image files""" + """The ``LabelStudioImageInput`` expects the input to :meth:`~flash.core.data.io.input.Input.load_data` to be a json + export from label studio. + + Export data should point to image files + + """ def load_sample(self, sample: Mapping[str, Any] = None) -> Any: """Load 1 sample from dataset.""" @@ -248,9 +251,11 @@ def load_sample(self, sample: Mapping[str, Any] = None) -> Any: class LabelStudioTextClassificationInput(LabelStudioInput): - """The ``LabelStudioTextInput`` expects the input to - :meth:`~flash.core.data.io.input.Input.load_data` to be a json export from label studio. + """The ``LabelStudioTextInput`` expects the input to :meth:`~flash.core.data.io.input.Input.load_data` to be a json + export from label studio. + Export data should point to text data + """ def load_sample(self, sample: Mapping[str, Any] = None) -> Any: @@ -265,9 +270,12 @@ def load_sample(self, sample: Mapping[str, Any] = None) -> Any: class LabelStudioVideoClassificationInput(LabelStudioIterableInput): - """The ``LabelStudioVideoInput`` expects the input to - :meth:`~flash.core.data.io.input.Input.load_data` to be a json export from label studio. - Export data should point to video files""" + """The ``LabelStudioVideoInput`` expects the input to :meth:`~flash.core.data.io.input.Input.load_data` to be a json + export from label studio. + + Export data should point to video files + + """ def __init__( self, diff --git a/src/flash/core/model.py b/src/flash/core/model.py index df1c92399c..617f7d8943 100644 --- a/src/flash/core/model.py +++ b/src/flash/core/model.py @@ -73,6 +73,7 @@ class ModuleWrapperBase: ``LightningModule`` instances so that nested calls to ``.log`` are handled correctly. The ``ModuleWrapperBase`` is also stateful, meaning that a :class:`~flash.core.data.data_pipeline.DataPipelineState` can be attached. Attached state will be forwarded to any nested ``ModuleWrapperBase`` instances. + """ def __init__(self): @@ -485,6 +486,7 @@ def modules_to_freeze(self) -> Optional[nn.Module]: Returns: The backbone ``Module`` to freeze or ``None`` if this task does not have a ``backbone`` attribute. + """ return getattr(self, "backbone", None) @@ -579,6 +581,7 @@ def as_embedder(self, layer: str): Args: layer: The layer to embed to. This should be one of the :meth:`~flash.core.model.Task.available_layers`. + """ from flash.core.utilities.embedder import Embedder # Avoid circular import @@ -672,6 +675,7 @@ def available_outputs(cls) -> List[str]: >>> print(Task.available_outputs()) ['preds', 'raw'] + """ return cls.outputs.available_keys() @@ -861,6 +865,7 @@ def serve( input_cls: The ``ServeInput`` type to use. transform: The transform to use when serving. transform_kwargs: Keyword arguments used to instantiate the transform. + """ from flash.core.serve.flash_components import build_flash_serve_model_component diff --git a/src/flash/core/registry.py b/src/flash/core/registry.py index 924cce8219..a1d830f55e 100644 --- a/src/flash/core/registry.py +++ b/src/flash/core/registry.py @@ -187,6 +187,7 @@ class ExternalRegistry(FlashRegistry): Args: getter: A function whose first argument is a key that can optionally take additional args and kwargs. providers: The provider(/s) of entries in this registry. + """ # Prevent users from trying to remove or register items diff --git a/src/flash/core/serve/dag/optimization.py b/src/flash/core/serve/dag/optimization.py index 5b8de1196a..967607daf3 100644 --- a/src/flash/core/serve/dag/optimization.py +++ b/src/flash/core/serve/dag/optimization.py @@ -368,6 +368,7 @@ def default_fused_keys_renamer(keys, max_fused_key_length=120): The optional parameter `max_fused_key_length` is used to limit the maximum string length for each renamed key. If this parameter is set to `None`, there is no limit. + """ it = reversed(keys) first_key = next(it) @@ -476,6 +477,7 @@ def fuse( dependencies dict mapping dependencies after fusion. Useful side effect to accelerate other downstream optimizations. + """ if keys is not None and not isinstance(keys, set): diff --git a/src/flash/core/serve/dag/order.py b/src/flash/core/serve/dag/order.py index 69447fad22..07dab3ed48 100644 --- a/src/flash/core/serve/dag/order.py +++ b/src/flash/core/serve/dag/order.py @@ -557,6 +557,7 @@ def graph_metrics(dependencies, dependents, total_dependencies): Returns ------- metrics: Dict[key, Tuple[int, int, int, int, int]] + """ result = {} num_needed = {k: len(v) for k, v in dependents.items() if v} diff --git a/src/flash/core/serve/dag/rewrite.py b/src/flash/core/serve/dag/rewrite.py index aebe058ba6..c742961233 100644 --- a/src/flash/core/serve/dag/rewrite.py +++ b/src/flash/core/serve/dag/rewrite.py @@ -248,6 +248,7 @@ def __init__(self, *rules): ---------- rules One or more instances of RewriteRule + """ self._net = Node() self.rules = [] @@ -358,6 +359,7 @@ def rewrite(self, task, strategy="bottom_up"): >>> rs.rewrite(term) # doctest: +ELLIPSIS (, (, 2)) + """ return strategies[strategy](self, task) diff --git a/src/flash/core/serve/dag/task.py b/src/flash/core/serve/dag/task.py index 59bb7875f8..57d49f7a67 100644 --- a/src/flash/core/serve/dag/task.py +++ b/src/flash/core/serve/dag/task.py @@ -392,6 +392,7 @@ def getcycle(d, keys): See Also -------- isdag + """ return _toposort(d, keys=keys, returncycle=True) @@ -410,6 +411,7 @@ def isdag(d, keys): See Also -------- getcycle + """ return not getcycle(d, keys) diff --git a/src/flash/core/serve/interfaces/models.py b/src/flash/core/serve/interfaces/models.py index 2177640b14..11f7266b96 100644 --- a/src/flash/core/serve/interfaces/models.py +++ b/src/flash/core/serve/interfaces/models.py @@ -99,6 +99,7 @@ def request_model(self) -> RequestModel: the key "image" -> the raw data... The field names are NOT altered, and therefore this workaround should pose very little issue for our end users). + """ attrib_dict = {} inputs = self._endpoint.inputs @@ -164,6 +165,7 @@ def response_model(self) -> ResponseModel: the key "image" -> the raw data... The field names are NOT altered, and therefore this workaround should pose very little issue for our end users). + """ attrib_dict = {} outputs = self._endpoint.outputs diff --git a/src/flash/core/trainer.py b/src/flash/core/trainer.py index fe7e39b83a..36d58651b9 100644 --- a/src/flash/core/trainer.py +++ b/src/flash/core/trainer.py @@ -53,6 +53,7 @@ def _defaults_from_env_vars(fn: Callable) -> Callable: """Copy of ``pytorch_lightning.trainer.connectors.env_vars_connector._defaults_from_env_vars``. Required to fix build error in readthedocs. + """ @wraps(fn) @@ -116,6 +117,7 @@ def fit( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped + """ if any(isinstance(c, BaseFinetuning) for c in self.callbacks): # TODO: if we find a finetuning callback in the trainer should we remove it? or just warn the user? @@ -160,6 +162,7 @@ def finetune( By default, ``no_freeze`` strategy will be used. train_bn: Whether to train Batch Norm layer + """ self._resolve_callbacks(model, strategy, train_bn=train_bn) return super().fit(model, train_dataloader, val_dataloaders, datamodule) diff --git a/src/flash/core/utilities/flash_cli.py b/src/flash/core/utilities/flash_cli.py index 132fc85479..87a4985101 100644 --- a/src/flash/core/utilities/flash_cli.py +++ b/src/flash/core/utilities/flash_cli.py @@ -152,6 +152,7 @@ def __init__( the corresponding ``DataModule.from_*`` method. - ``Callable``. A custom method. kwargs: See the parent arguments + """ if datamodule_attributes is None: datamodule_attributes = {"num_classes"} diff --git a/src/flash/core/utilities/lightning_cli.py b/src/flash/core/utilities/lightning_cli.py index 898df68854..9f8e293f24 100644 --- a/src/flash/core/utilities/lightning_cli.py +++ b/src/flash/core/utilities/lightning_cli.py @@ -261,6 +261,7 @@ def __init__( subclass_mode_data: Whether datamodule can be any `subclass `_ of the given class. + """ self.model_class = model_class self.datamodule_class = datamodule_class diff --git a/src/flash/core/utilities/stability.py b/src/flash/core/utilities/stability.py index 8dd553a8a8..34c9ecfe73 100644 --- a/src/flash/core/utilities/stability.py +++ b/src/flash/core/utilities/stability.py @@ -68,6 +68,7 @@ def beta(message: str = "This feature is currently in Beta."): ... MyBetaFeatureWithCustomMessage() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE ... <...> + """ def decorator(callable: Union[Callable, Type]): diff --git a/src/flash/graph/classification/data.py b/src/flash/graph/classification/data.py index 457d7b45c2..b52730788f 100644 --- a/src/flash/graph/classification/data.py +++ b/src/flash/graph/classification/data.py @@ -164,6 +164,7 @@ def from_datasets( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { diff --git a/src/flash/graph/classification/input_transform.py b/src/flash/graph/classification/input_transform.py index 66a0134f00..78543bedeb 100644 --- a/src/flash/graph/classification/input_transform.py +++ b/src/flash/graph/classification/input_transform.py @@ -33,6 +33,7 @@ class PyGTransformAdapter: Args: transform: Transform to apply. + """ transform: Callable[[Data], Data] diff --git a/src/flash/graph/classification/model.py b/src/flash/graph/classification/model.py index de7abbf07c..4eff4d15f2 100644 --- a/src/flash/graph/classification/model.py +++ b/src/flash/graph/classification/model.py @@ -50,6 +50,7 @@ class GraphClassifier(ClassificationTask): optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. + """ backbones: FlashRegistry = GRAPH_BACKBONES diff --git a/src/flash/graph/embedding/model.py b/src/flash/graph/embedding/model.py index e214c25cf1..22ce1842f7 100644 --- a/src/flash/graph/embedding/model.py +++ b/src/flash/graph/embedding/model.py @@ -29,6 +29,7 @@ class GraphEmbedder(Task): Args: backbone: A model to use to extract image features. pooling_fn: The global pooling operation to use (one of: "max", "max", "add" or a callable). + """ required_extras: str = "graph" diff --git a/src/flash/image/classification/data.py b/src/flash/image/classification/data.py index 3bf13bb91e..d5d6c60c42 100644 --- a/src/flash/image/classification/data.py +++ b/src/flash/image/classification/data.py @@ -149,6 +149,7 @@ def from_files( >>> import os >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] + """ ds_kw = { "target_formatter": target_formatter, @@ -269,6 +270,7 @@ def from_folders( >>> import shutil >>> shutil.rmtree("train_folder") >>> shutil.rmtree("predict_folder") + """ ds_kw = { "target_formatter": target_formatter, @@ -355,6 +357,7 @@ def from_numpy( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { "target_formatter": target_formatter, @@ -446,6 +449,7 @@ def from_images( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { "target_formatter": target_formatter, @@ -532,6 +536,7 @@ def from_tensors( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { "target_formatter": target_formatter, @@ -671,6 +676,7 @@ def from_data_frame( >>> shutil.rmtree("predict_folder") >>> del train_data_frame >>> del predict_data_frame + """ ds_kw = { "target_formatter": target_formatter, @@ -902,6 +908,7 @@ def from_csv( >>> shutil.rmtree("predict_folder") >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") + """ ds_kw = { "target_formatter": target_formatter, @@ -1021,6 +1028,7 @@ def from_fiftyone( >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] >>> del train_dataset >>> del predict_dataset + """ ds_kw = { "target_formatter": target_formatter, @@ -1092,6 +1100,7 @@ def from_labelstudio( data_folder='label-studio/media/upload', val_split=0.8, ) + """ train_data, val_data, test_data, predict_data = _parse_labelstudio_arguments( @@ -1158,6 +1167,7 @@ def from_datasets( data_module = DataModule.from_datasets( train_dataset=train_dataset, ) + """ ds_kw = {} diff --git a/src/flash/image/classification/integrations/baal/data.py b/src/flash/image/classification/integrations/baal/data.py index 7fe1a5fe5a..fe1b913f2a 100644 --- a/src/flash/image/classification/integrations/baal/data.py +++ b/src/flash/image/classification/integrations/baal/data.py @@ -75,6 +75,7 @@ def __init__( initial_num_labels: Number of samples to randomly label to start the training with. query_size: Number of samples to be labelled at each Active Learning loop based on the fed heuristic. val_split: Float to split train dataset into train and validation set. + """ super().__init__(batch_size=1) self.labelled = labelled diff --git a/src/flash/image/classification/integrations/baal/loop.py b/src/flash/image/classification/integrations/baal/loop.py index 61785b668a..256a3b4179 100644 --- a/src/flash/image/classification/integrations/baal/loop.py +++ b/src/flash/image/classification/integrations/baal/loop.py @@ -58,6 +58,7 @@ def __init__(self, label_epoch_frequency: int, inference_iteration: int = 2, sho Args: label_epoch_frequency: Number of epoch to train on before requesting labellisation. inference_iteration: Number of inference to perform to compute uncertainty. + """ super().__init__() self.label_epoch_frequency = label_epoch_frequency diff --git a/src/flash/image/classification/model.py b/src/flash/image/classification/model.py index 2b28ecc35b..7b6d422abb 100644 --- a/src/flash/image/classification/model.py +++ b/src/flash/image/classification/model.py @@ -77,6 +77,7 @@ def fn_resnet(pretrained: bool = True): training_strategy: string indicating the training strategy. Adjust if you want to use `learn2learn` for doing meta-learning research training_strategy_kwargs: Additional kwargs for setting the training strategy + """ backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES diff --git a/src/flash/image/detection/data.py b/src/flash/image/detection/data.py index b6c8135675..d5e1c2e1b2 100644 --- a/src/flash/image/detection/data.py +++ b/src/flash/image/detection/data.py @@ -163,6 +163,7 @@ def from_files( >>> import os >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] + """ ds_kw = { @@ -281,6 +282,7 @@ def from_numpy( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { @@ -404,6 +406,7 @@ def from_images( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { @@ -522,6 +525,7 @@ def from_tensors( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { @@ -893,6 +897,7 @@ def from_voc( >>> shutil.rmtree("train_annotations") .. _PASCAL: http://host.robots.ox.ac.uk/pascal/VOC/ + """ return cls.from_icedata( train_folder=train_folder, @@ -1162,6 +1167,7 @@ def from_fiftyone( >>> import os >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] + """ ds_kw = {} @@ -1195,6 +1201,7 @@ def from_folders( Returns: The constructed data module. + """ return cls( predict_input=input_cls(RunningStage.PREDICTING, predict_folder), diff --git a/src/flash/image/detection/model.py b/src/flash/image/detection/model.py index 15def3da8f..f58cfe6923 100644 --- a/src/flash/image/detection/model.py +++ b/src/flash/image/detection/model.py @@ -41,6 +41,7 @@ class ObjectDetector(AdapterTask): learning_rate: The learning rate to use for training. predict_kwargs: dictionary containing parameters that will be used during the prediction phase. kwargs: additional kwargs nessesary for initializing the backbone task + """ heads: FlashRegistry = OBJECT_DETECTION_HEADS diff --git a/src/flash/image/embedding/model.py b/src/flash/image/embedding/model.py index 2f02beafee..cd1796d58d 100644 --- a/src/flash/image/embedding/model.py +++ b/src/flash/image/embedding/model.py @@ -66,6 +66,7 @@ class ImageEmbedder(AdapterTask): backbone_kwargs: arguments to be passed to VISSL backbones, i.e. ``vision_transformer`` and ``resnet``. training_strategy_kwargs: arguments passed to VISSL loss function, projection head and training hooks. pretraining_transform_kwargs: arguments passed to VISSL transforms. + """ training_strategies: FlashRegistry = IMAGE_EMBEDDER_STRATEGIES @@ -173,6 +174,7 @@ def available_training_strategies(cls) -> List[str]: >>> from flash.image import ImageEmbedder >>> ImageEmbedder.available_training_strategies() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE ['barlow_twins', ..., 'swav'] + """ registry: Optional[FlashRegistry] = getattr(cls, "training_strategies", None) if registry is None: diff --git a/src/flash/image/embedding/vissl/adapter.py b/src/flash/image/embedding/vissl/adapter.py index ad901f7ed1..9633aeffb2 100644 --- a/src/flash/image/embedding/vissl/adapter.py +++ b/src/flash/image/embedding/vissl/adapter.py @@ -38,6 +38,7 @@ class _VISSLBackboneWrapper(nn.Module): """VISSL backbones take additional arguments in ``forward`` that are not needed for our integration. This wrapper can be applied to a Flash backbone to ignore any additional arguments to ``forward``. + """ def __init__(self, backbone: nn.Module): @@ -75,6 +76,7 @@ class VISSLAdapter(Adapter, AdaptVISSLHooks): """The ``VISSLAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with VISSL. Also inherits from ``AdaptVISSLHooks`` to support VISSL hooks. + """ required_extras: str = "image" diff --git a/src/flash/image/face_detection/model.py b/src/flash/image/face_detection/model.py index 50e4b7ae20..ca72d85d72 100644 --- a/src/flash/image/face_detection/model.py +++ b/src/flash/image/face_detection/model.py @@ -44,6 +44,7 @@ class FaceDetector(Task): lr_scheduler: The LR scheduler to use during training. learning_rate: The learning rate to use for training. kwargs: additional kwargs nessesary for initializing face detector backbone + """ required_extras: str = "image" diff --git a/src/flash/image/instance_segmentation/data.py b/src/flash/image/instance_segmentation/data.py index 0f9430e42a..1fc68b3f15 100644 --- a/src/flash/image/instance_segmentation/data.py +++ b/src/flash/image/instance_segmentation/data.py @@ -431,6 +431,7 @@ def from_voc( >>> shutil.rmtree("train_annotations") .. _PASCAL: http://host.robots.ox.ac.uk/pascal/VOC/ + """ return cls.from_icedata( train_folder=train_folder, @@ -472,6 +473,7 @@ def from_folders( Returns: The constructed data module. + """ return cls( predict_input=input_cls(RunningStage.PREDICTING, predict_folder), @@ -502,6 +504,7 @@ def from_files( Returns: The constructed data module. + """ return cls( predict_input=input_cls(RunningStage.PREDICTING, predict_files), diff --git a/src/flash/image/instance_segmentation/model.py b/src/flash/image/instance_segmentation/model.py index 1d4bd681d8..81521f3f85 100644 --- a/src/flash/image/instance_segmentation/model.py +++ b/src/flash/image/instance_segmentation/model.py @@ -34,6 +34,7 @@ class InstanceSegmentation(AdapterTask): learning_rate: The learning rate to use for training. predict_kwargs: dictionary containing parameters that will be used during the prediction phase. **kwargs: additional kwargs used for initializing the task + """ heads: FlashRegistry = INSTANCE_SEGMENTATION_HEADS diff --git a/src/flash/image/keypoint_detection/data.py b/src/flash/image/keypoint_detection/data.py index 98fb0f2078..4bd614dd6b 100644 --- a/src/flash/image/keypoint_detection/data.py +++ b/src/flash/image/keypoint_detection/data.py @@ -290,6 +290,7 @@ def from_folders( Returns: The constructed data module. + """ return cls( predict_input=input_cls(RunningStage.PREDICTING, predict_folder), @@ -320,6 +321,7 @@ def from_files( Returns: The constructed data module. + """ return cls( predict_input=input_cls(RunningStage.PREDICTING, predict_files), diff --git a/src/flash/image/keypoint_detection/model.py b/src/flash/image/keypoint_detection/model.py index 1bb65ed3dd..38ccf9f8cb 100644 --- a/src/flash/image/keypoint_detection/model.py +++ b/src/flash/image/keypoint_detection/model.py @@ -34,6 +34,7 @@ class KeypointDetector(AdapterTask): learning_rate: The learning rate to use for training. predict_kwargs: dictionary containing parameters that will be used during the prediction phase. **kwargs: additional kwargs used for initializing the task + """ heads: FlashRegistry = KEYPOINT_DETECTION_HEADS diff --git a/src/flash/image/segmentation/data.py b/src/flash/image/segmentation/data.py index c9f98b9a46..ed26ee4684 100644 --- a/src/flash/image/segmentation/data.py +++ b/src/flash/image/segmentation/data.py @@ -141,6 +141,7 @@ def from_files( >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] >>> _ = [os.remove(f"mask_{i}.npy") for i in range(1, 4)] >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] + """ ds_kw = { @@ -282,6 +283,7 @@ def from_folders( >>> shutil.rmtree("train_images") >>> shutil.rmtree("train_masks") >>> shutil.rmtree("predict_folder") + """ ds_kw = { @@ -370,6 +372,7 @@ def from_numpy( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { @@ -458,6 +461,7 @@ def from_tensors( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { @@ -569,6 +573,7 @@ def from_fiftyone( >>> import os >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] + """ return cls( diff --git a/src/flash/image/segmentation/model.py b/src/flash/image/segmentation/model.py index 121150e60c..75fab675a1 100644 --- a/src/flash/image/segmentation/model.py +++ b/src/flash/image/segmentation/model.py @@ -92,6 +92,7 @@ class SemanticSegmentation(ClassificationTask): multi_label: Whether the targets are multi-label or not. output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` use for post processing samples. + """ output_transform_cls = SemanticSegmentationOutputTransform diff --git a/src/flash/image/style_transfer/data.py b/src/flash/image/style_transfer/data.py index 8a3bb45dd3..d40e230c1d 100644 --- a/src/flash/image/style_transfer/data.py +++ b/src/flash/image/style_transfer/data.py @@ -100,6 +100,7 @@ def from_files( >>> import os >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] + """ return cls( @@ -185,6 +186,7 @@ def from_folders( >>> import shutil >>> shutil.rmtree("train_folder") >>> shutil.rmtree("predict_folder") + """ return cls( @@ -243,6 +245,7 @@ def from_numpy( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ return cls( @@ -302,6 +305,7 @@ def from_tensors( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ return cls( diff --git a/src/flash/image/style_transfer/model.py b/src/flash/image/style_transfer/model.py index eb67e88961..c88a00e612 100644 --- a/src/flash/image/style_transfer/model.py +++ b/src/flash/image/style_transfer/model.py @@ -62,6 +62,7 @@ class StyleTransfer(Task): optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. learning_rate: Learning rate to use for training, defaults to ``1e-3``. + """ backbones: FlashRegistry = STYLE_TRANSFER_BACKBONES diff --git a/src/flash/pointcloud/detection/model.py b/src/flash/pointcloud/detection/model.py index a74ba701aa..e67e7966ff 100644 --- a/src/flash/pointcloud/detection/model.py +++ b/src/flash/pointcloud/detection/model.py @@ -50,6 +50,7 @@ class PointCloudObjectDetector(Task): lambda_loss_cls: The value to scale the loss classification. lambda_loss_bbox: The value to scale the bounding boxes loss. lambda_loss_dir: The value to scale the bounding boxes direction loss. + """ backbones: FlashRegistry = POINTCLOUD_OBJECT_DETECTION_BACKBONES diff --git a/src/flash/pointcloud/segmentation/model.py b/src/flash/pointcloud/segmentation/model.py index 9d38c9044c..d83ea1e55a 100644 --- a/src/flash/pointcloud/segmentation/model.py +++ b/src/flash/pointcloud/segmentation/model.py @@ -58,6 +58,7 @@ class PointCloudSegmentation(ClassificationTask): by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. multi_label: If ``True``, this will be treated as a multi-label classification problem. + """ backbones: FlashRegistry = POINTCLOUD_SEGMENTATION_BACKBONES diff --git a/src/flash/tabular/classification/data.py b/src/flash/tabular/classification/data.py index 9e5f89a6b8..da3b86c000 100644 --- a/src/flash/tabular/classification/data.py +++ b/src/flash/tabular/classification/data.py @@ -156,6 +156,7 @@ def from_data_frame( >>> del train_data >>> del predict_data + """ ds_kw = { "target_formatter": target_formatter, @@ -358,6 +359,7 @@ def from_csv( >>> import os >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") + """ ds_kw = { "target_formatter": target_formatter, @@ -494,6 +496,7 @@ def from_dicts( >>> del train_data >>> del predict_data + """ ds_kw = { "target_formatter": target_formatter, @@ -632,6 +635,7 @@ def from_lists( >>> del train_data >>> del predict_data + """ ds_kw = { "target_formatter": target_formatter, diff --git a/src/flash/tabular/classification/model.py b/src/flash/tabular/classification/model.py index 4c36eac568..5c9fb44968 100644 --- a/src/flash/tabular/classification/model.py +++ b/src/flash/tabular/classification/model.py @@ -53,6 +53,7 @@ class TabularClassifier(ClassificationAdapterTask): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training. **backbone_kwargs: Optional additional arguments for the model. + """ required_extras: str = "tabular" @@ -127,6 +128,7 @@ def data_parameters(self) -> Dict[str, Any]: output="classes", ) Predicting... + """ return self._data_parameters diff --git a/src/flash/tabular/forecasting/data.py b/src/flash/tabular/forecasting/data.py index 36e1fc2f01..dc782e4c40 100644 --- a/src/flash/tabular/forecasting/data.py +++ b/src/flash/tabular/forecasting/data.py @@ -155,6 +155,7 @@ def from_data_frame( .. testcleanup:: >>> del data + """ ds_kw = dict( diff --git a/src/flash/tabular/forecasting/model.py b/src/flash/tabular/forecasting/model.py index 1fe2144205..9777316c84 100644 --- a/src/flash/tabular/forecasting/model.py +++ b/src/flash/tabular/forecasting/model.py @@ -68,6 +68,7 @@ def pytorch_forecasting_model(self) -> LightningModule: This can be used with :func:`~flash.core.integrations.pytorch_forecasting.transforms.convert_predictions` to access the visualization features built in to PyTorch Forecasting. + """ if not isinstance(self.adapter, PyTorchForecastingAdapter): raise AttributeError( diff --git a/src/flash/tabular/regression/data.py b/src/flash/tabular/regression/data.py index 9e1dd0d01c..2cffadff94 100644 --- a/src/flash/tabular/regression/data.py +++ b/src/flash/tabular/regression/data.py @@ -146,6 +146,7 @@ def from_data_frame( >>> del train_data >>> del predict_data + """ ds_kw = { "categorical_fields": categorical_fields, @@ -333,6 +334,7 @@ def from_csv( >>> import os >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") + """ ds_kw = { "categorical_fields": categorical_fields, @@ -459,6 +461,7 @@ def from_dicts( >>> del train_data >>> del predict_data + """ ds_kw = { "categorical_fields": categorical_fields, @@ -588,6 +591,7 @@ def from_lists( >>> del train_data >>> del predict_data + """ ds_kw = { "categorical_fields": categorical_fields, diff --git a/src/flash/tabular/regression/model.py b/src/flash/tabular/regression/model.py index 9bae8f76b3..96f5bdad18 100644 --- a/src/flash/tabular/regression/model.py +++ b/src/flash/tabular/regression/model.py @@ -52,6 +52,7 @@ class TabularRegressor(RegressionAdapterTask): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training. **backbone_kwargs: Optional additional arguments for the model. + """ required_extras: str = "tabular" @@ -121,6 +122,7 @@ def data_parameters(self) -> Dict[str, Any]: model, datamodule=datamodule, ) + """ return self._data_parameters diff --git a/src/flash/template/classification/data.py b/src/flash/template/classification/data.py index 9b6ebdeba9..f040164119 100644 --- a/src/flash/template/classification/data.py +++ b/src/flash/template/classification/data.py @@ -54,6 +54,7 @@ def load_data( Returns: A sequence of samples / sample metadata. + """ if not self.predicting and isinstance(examples, np.ndarray): self.num_features = examples.shape[1] @@ -79,6 +80,7 @@ def load_data(self, data: Bunch, target_formatter: Optional[TargetFormatter] = N Returns: A sequence of samples / sample metadata. + """ return super().load_data(data.data, data.target, target_formatter=target_formatter) @@ -118,6 +120,7 @@ class TemplateData(DataModule): Next, we add a ``from_numpy`` method and a ``from_sklearn`` method. Finally, we define the ``num_features`` property for convenience. + """ input_transform_cls = TemplateInputTransform @@ -156,6 +159,7 @@ def from_numpy( Returns: The constructed data module. + """ ds_kw = {} @@ -201,6 +205,7 @@ def from_sklearn( Returns: The constructed data module. + """ ds_kw = {} @@ -238,6 +243,7 @@ class TemplateVisualization(BaseVisualization): the data. If you want to provide a visualization with your task, you can override these hooks. + """ def show_load_sample( diff --git a/src/flash/template/classification/model.py b/src/flash/template/classification/model.py index d8de0b2a15..eb19539fcc 100644 --- a/src/flash/template/classification/model.py +++ b/src/flash/template/classification/model.py @@ -39,6 +39,7 @@ class TemplateSKLearnClassifier(ClassificationTask): by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. multi_label: If ``True``, this will be treated as a multi-label classification problem. + """ backbones: FlashRegistry = TEMPLATE_BACKBONES diff --git a/src/flash/text/classification/data.py b/src/flash/text/classification/data.py index aaafaf4f49..027df62632 100644 --- a/src/flash/text/classification/data.py +++ b/src/flash/text/classification/data.py @@ -210,6 +210,7 @@ def from_csv( >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") + """ ds_kw = { "target_formatter": target_formatter, @@ -333,6 +334,7 @@ def from_json( >>> os.remove("train_data.json") >>> os.remove("predict_data.json") + """ ds_kw = { "target_formatter": target_formatter, @@ -457,6 +459,7 @@ def from_parquet( >>> os.remove("train_data.parquet") >>> os.remove("predict_data.parquet") + """ ds_kw = { "target_formatter": target_formatter, @@ -560,6 +563,7 @@ def from_hf_datasets( >>> del train_data >>> del predict_data + """ ds_kw = { "target_formatter": target_formatter, @@ -664,6 +668,7 @@ def from_data_frame( >>> del train_data >>> del predict_data + """ ds_kw = { "target_formatter": target_formatter, @@ -750,6 +755,7 @@ def from_lists( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = { "target_formatter": target_formatter, @@ -813,6 +819,7 @@ def from_labelstudio( Returns: The constructed data module. + """ train_data, val_data, test_data, predict_data = _parse_labelstudio_arguments( diff --git a/src/flash/text/classification/model.py b/src/flash/text/classification/model.py index 89a1273836..7725f5e61a 100644 --- a/src/flash/text/classification/model.py +++ b/src/flash/text/classification/model.py @@ -52,6 +52,7 @@ class TextClassifier(ClassificationAdapterTask): learning_rate: Learning rate to use for training, defaults to `1e-3` multi_label: Whether the targets are multi-label or not. enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training + """ required_extras: str = "text" diff --git a/src/flash/text/embedding/model.py b/src/flash/text/embedding/model.py index 9c73b975b5..da73c6c18d 100644 --- a/src/flash/text/embedding/model.py +++ b/src/flash/text/embedding/model.py @@ -47,6 +47,7 @@ class TextEmbedder(Task): Args: backbone: backbone model to use for the task. enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training + """ required_extras: str = "text" diff --git a/src/flash/text/question_answering/data.py b/src/flash/text/question_answering/data.py index d74b4047e0..fdf29f7bb9 100644 --- a/src/flash/text/question_answering/data.py +++ b/src/flash/text/question_answering/data.py @@ -207,6 +207,7 @@ def from_csv( >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") + """ ds_kw = { @@ -352,6 +353,7 @@ def from_json( >>> os.remove("train_data.json") >>> os.remove("predict_data.json") + """ ds_kw = { @@ -634,6 +636,7 @@ def from_squad_v2( >>> import os >>> os.remove("train_data.json") >>> os.remove("predict_data.json") + """ ds_kw = { @@ -747,6 +750,7 @@ def from_dicts( >>> del train_data >>> del predict_data + """ ds_kw = { diff --git a/src/flash/text/question_answering/model.py b/src/flash/text/question_answering/model.py index e999825777..baa04f407b 100644 --- a/src/flash/text/question_answering/model.py +++ b/src/flash/text/question_answering/model.py @@ -86,6 +86,7 @@ class QuestionAnsweringTask(Task): less than the score of the null answer minus this threshold, the null answer is selected for this example. Only useful when `version_2_with_negative=True`. use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching. + """ required_extras: str = "text" diff --git a/src/flash/text/seq2seq/summarization/data.py b/src/flash/text/seq2seq/summarization/data.py index 04c00547a8..6a7d755781 100644 --- a/src/flash/text/seq2seq/summarization/data.py +++ b/src/flash/text/seq2seq/summarization/data.py @@ -190,6 +190,7 @@ def from_csv( >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") + """ ds_kw = { @@ -300,6 +301,7 @@ def from_json( >>> os.remove("train_data.json") >>> os.remove("predict_data.json") + """ ds_kw = { @@ -393,6 +395,7 @@ def from_hf_datasets( >>> del train_data >>> del predict_data + """ ds_kw = { @@ -467,6 +470,7 @@ def from_lists( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = {} diff --git a/src/flash/text/seq2seq/summarization/model.py b/src/flash/text/seq2seq/summarization/model.py index 926a9823b6..fd74ba9d8a 100644 --- a/src/flash/text/seq2seq/summarization/model.py +++ b/src/flash/text/seq2seq/summarization/model.py @@ -42,6 +42,7 @@ class SummarizationTask(Seq2SeqTask): num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching. enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training + """ def __init__( diff --git a/src/flash/text/seq2seq/translation/data.py b/src/flash/text/seq2seq/translation/data.py index aaf5c54aee..35865f929b 100644 --- a/src/flash/text/seq2seq/translation/data.py +++ b/src/flash/text/seq2seq/translation/data.py @@ -184,6 +184,7 @@ def from_csv( >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") + """ ds_kw = { @@ -293,6 +294,7 @@ def from_json( >>> os.remove("train_data.json") >>> os.remove("predict_data.json") + """ ds_kw = { @@ -386,6 +388,7 @@ def from_hf_datasets( >>> del train_data >>> del predict_data + """ ds_kw = { @@ -460,6 +463,7 @@ def from_lists( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + """ ds_kw = {} diff --git a/src/flash/text/seq2seq/translation/model.py b/src/flash/text/seq2seq/translation/model.py index 71fe3834aa..e6efc36c6f 100644 --- a/src/flash/text/seq2seq/translation/model.py +++ b/src/flash/text/seq2seq/translation/model.py @@ -42,6 +42,7 @@ class TranslationTask(Seq2SeqTask): n_gram: Maximum n_grams to use in metric calculation. Defaults to `4` smooth: Apply smoothing in BLEU calculation. Defaults to `True` enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training + """ def __init__( diff --git a/src/flash/video/classification/data.py b/src/flash/video/classification/data.py index e55392b482..d47d5d714b 100644 --- a/src/flash/video/classification/data.py +++ b/src/flash/video/classification/data.py @@ -159,6 +159,7 @@ def from_files( >>> import os >>> _ = [os.remove(f"video_{i}.mp4") for i in range(1, 4)] >>> _ = [os.remove(f"predict_video_{i}.mp4") for i in range(1, 4)] + """ ds_kw = { "clip_sampler": clip_sampler, @@ -324,6 +325,7 @@ def from_folders( >>> import shutil >>> shutil.rmtree("train_folder") >>> shutil.rmtree("predict_folder") + """ ds_kw = { "clip_sampler": clip_sampler, @@ -507,6 +509,7 @@ def from_data_frame( >>> shutil.rmtree("predict_folder") >>> del train_data_frame >>> del predict_data_frame + """ ds_kw = { "clip_sampler": clip_sampler, @@ -635,6 +638,7 @@ def from_tensors( .. testcleanup:: >>> del frame + """ train_input = input_cls( @@ -905,6 +909,7 @@ def from_csv( >>> shutil.rmtree("predict_folder") >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") + """ ds_kw = { "clip_sampler": clip_sampler, @@ -1062,6 +1067,7 @@ def from_fiftyone( >>> _ = [os.remove(f"predict_video_{i}.mp4") for i in range(1, 4)] >>> del train_dataset >>> del predict_dataset + """ ds_kw = { "clip_sampler": clip_sampler, @@ -1171,6 +1177,7 @@ def from_labelstudio( data_folder='label-studio/media/upload', val_split=0.8, ) + """ train_data, val_data, test_data, predict_data = _parse_labelstudio_arguments( diff --git a/tests/helpers/task_tester.py b/tests/helpers/task_tester.py index c480c1641f..d54a5bad1e 100644 --- a/tests/helpers/task_tester.py +++ b/tests/helpers/task_tester.py @@ -154,6 +154,7 @@ class TaskTesterMeta(ABCMeta): These tests will also be wrapped with the appropriate marks to skip them if the required dependencies are not available. + """ @staticmethod @@ -240,6 +241,7 @@ class TaskTester(metaclass=TaskTesterMeta): Use the class attributes to control which tests will be run. For example, if ``traceable`` is ``False`` then no JIT tracing test will be performed. + """ task: Task