From 2649bca01779af81c20c5c69c34ac2f47e749a39 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 18 Oct 2024 16:12:50 -0700 Subject: [PATCH 01/11] Add module base class and begin filter refactor Signed-off-by: Ryan Wolf --- nemo_curator/filters/base.py | 202 +++++++++++++++++++++++++ nemo_curator/modules/add_id.py | 9 +- nemo_curator/modules/base.py | 42 +++++ nemo_curator/modules/dataset_ops.py | 9 +- nemo_curator/modules/exact_dedup.py | 7 +- nemo_curator/modules/fuzzy_dedup.py | 25 ++- nemo_curator/modules/modify.py | 9 +- nemo_curator/modules/semantic_dedup.py | 9 +- nemo_curator/modules/task.py | 9 +- nemo_curator/modules/to_backend.py | 29 ++++ 10 files changed, 333 insertions(+), 17 deletions(-) create mode 100644 nemo_curator/filters/base.py create mode 100644 nemo_curator/modules/base.py create mode 100644 nemo_curator/modules/to_backend.py diff --git a/nemo_curator/filters/base.py b/nemo_curator/filters/base.py new file mode 100644 index 000000000..c48d9809d --- /dev/null +++ b/nemo_curator/filters/base.py @@ -0,0 +1,202 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, List, Optional, Union + +from nemo_curator.datasets import DocumentDataset +from nemo_curator.modules.base import Module +from nemo_curator.utils.module_utils import is_batched + + +class FilterMode(Enum): + SCORE_FILTER = "score_filter" + SCORE = "score" + FILTER = "filter" + + +class DocumentFilter(Module, ABC): + """ + An abstract base class for text-based document filters. + + This class serves as a template for creating specific document filters + in the library. Subclasses should implement the abstract methods to + define custom filtering behavior. + """ + + def __init__( + self, + text_fields: List[str] = ["text"], + score_fields: List[str] = ["score"], + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__() + self.text_fields = text_fields + self.score_fields = score_fields + self.removed_path = removed_path + self.invert = invert + self.filter_mode = filter_mode + self.save_score = save_score + + @property + def input_backend(self) -> str: + return "pandas" + + @abstractmethod + def score_document(self, text: str) -> Any: + """ + Calculate a score for the given document text. + + This method should be implemented by subclasses to define how + a document's text is evaluated and scored. + + Args: + text (str): The text content of the document to be scored. + + Returns: + Any: A score or set of scores representing the document's + relevance or quality. The type and structure of the + return value should be consistent for each subclass. + + Raises: + NotImplementedError: If the method is not implemented in a subclass. 
+ """ + raise NotImplementedError( + "score_document method must be implemented by subclasses" + ) + + @abstractmethod + def keep_document(self, scores: Any) -> bool: + """ + Determine whether to keep a document based on its scores. + + This method should be implemented by subclasses to define the + criteria for keeping or discarding a document based on the + scores calculated by score_document(). + + Args: + scores (Any): The score or set of scores returned by score_document(). + The type should match what is returned by score_document(). + + Returns: + bool: True if the document should be kept, False otherwise. + + Raises: + NotImplementedError: If the method is not implemented in a subclass. + """ + raise NotImplementedError( + "keep_document method must be implemented by subclasses" + ) + + @abstractmethod + @property + def score_type(self): + raise NotImplementedError( + "keep_document method must be implemented by subclasses" + ) + + def call(self, dataset: DocumentDataset) -> DocumentDataset: + match self.filter_mode: + case FilterMode.SCORE: + meta = (None, self.score_type) + + if is_batched(self.score_document): + scores = dataset.df[self.text_fields].map_partitions( + self.score_document, meta=meta + ) + else: + scores = dataset.df[self.text_fields].apply( + self.score_document, meta=meta + ) + + if self.save_score: + dataset.df[self.score_fields] = scores + + return dataset + case FilterMode.FILTER: + scores = dataset.df[self.score_fields] + + if is_batched(self.keep_document): + bool_mask = scores.map_partitions( + self.keep_document, meta=(None, bool) + ) + else: + bool_mask = scores.apply(self.keep_document, meta=(None, bool)) + if self.invert: + bool_mask = ~bool_mask + + if self.removed_path: + removed_docs = DocumentDataset(dataset.df[~bool_mask]) + removed_docs.to_parquet(output_file_dir=self.removed_path) + + return DocumentDataset(dataset.df[bool_mask]) + + case FilterMode.SCORE_FILTER: + meta = (None, self.score_type) + + if is_batched(self.score_document): + scores = dataset.df[self.text_fields].map_partitions( + self.score_document, meta=meta + ) + else: + scores = dataset.df[self.text_fields].apply( + self.score_document, meta=meta + ) + + if self.save_score: + dataset.df[self.score_fields] = scores + + if is_batched(self.keep_document): + bool_mask = scores.map_partitions( + self.keep_document, meta=(None, bool) + ) + else: + bool_mask = scores.apply(self.keep_document, meta=(None, bool)) + if self.invert: + bool_mask = ~bool_mask + + if self.removed_path: + removed_docs = DocumentDataset(dataset.df[~bool_mask]) + removed_docs.to_parquet(output_file_dir=self.removed_path) + + return DocumentDataset(dataset.df[bool_mask]) + + +def import_filter(filter_path: str) -> DocumentFilter: + """ + Imports a filter under nemo_curator.filters given the module path + + Args: + filter_path (str): The path to the filter in the form of "nemo_curator.filters.filter_name" + + Returns: + DocumentFilter: The filter that is at the given path + + Raises: + ValueError: If the filter_path does not point to a DocumentFilter + """ + module_path, filter_name = filter_path.rsplit(".", 1) + filter_module = importlib.import_module(module_path) + filter_class = getattr(filter_module, filter_name) + if not issubclass(filter_class, DocumentFilter): + raise ValueError( + f"Input filter {filter_class.__name__} must be derived " + "from DocumentFilter defined in nemo_curator.filters.doc_filter" + ) + return filter_class diff --git a/nemo_curator/modules/add_id.py b/nemo_curator/modules/add_id.py 
index 244163912..894f25bfe 100644 --- a/nemo_curator/modules/add_id.py +++ b/nemo_curator/modules/add_id.py @@ -19,10 +19,11 @@ from dask import delayed from nemo_curator.datasets import DocumentDataset +from nemo_curator.modules.base import Module from nemo_curator.utils.module_utils import count_digits -class AddId: +class AddId(Module): def __init__( self, id_field, id_prefix: str = "doc_id", start_index: Optional[int] = None ) -> None: @@ -30,7 +31,11 @@ def __init__( self.id_prefix = id_prefix self.start_index = start_index - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + @property + def input_backend(self) -> str: + return "pandas" + + def call(self, dataset: DocumentDataset) -> DocumentDataset: if self.start_index is None: return self._add_id_fast(dataset) else: diff --git a/nemo_curator/modules/base.py b/nemo_curator/modules/base.py new file mode 100644 index 000000000..260badee1 --- /dev/null +++ b/nemo_curator/modules/base.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +from nemo_curator.datasets import DocumentDataset + + +class Module(ABC): + def __init__(self, name=None) -> None: + super().__init__() + self.name = name or self.__class__.__name__ + + @abstractmethod + @property + def input_backend(self) -> str: + raise NotImplementedError( + "input_backend method must be implemented by subclasses" + ) + + @abstractmethod + def call(self, dataset: DocumentDataset) -> DocumentDataset: + raise NotImplementedError("call method must be implemented by subclasses") + + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + if self.input_backend != "any" and dataset.df.backend != self.input_backend: + raise ValueError( + f"Module {self.name} requires dataset to have backend {self.input_backend} but got backend {dataset.df.backend}" + ) + + return self.call(dataset) diff --git a/nemo_curator/modules/dataset_ops.py b/nemo_curator/modules/dataset_ops.py index 38589b1e9..e5979d4c6 100644 --- a/nemo_curator/modules/dataset_ops.py +++ b/nemo_curator/modules/dataset_ops.py @@ -5,13 +5,14 @@ import numpy as np from nemo_curator.datasets.doc_dataset import DocumentDataset +from nemo_curator.modules.base import Module def default_filename(partition_num: int) -> str: return f"file_{partition_num:010d}.jsonl" -class Shuffle: +class Shuffle(Module): def __init__( self, seed: Optional[int] = None, @@ -36,7 +37,11 @@ def __init__( self.partition_to_filename = partition_to_filename self.rand_col = "_shuffle_rand" - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + @property + def input_backend(self) -> str: + return "pandas" + + def call(self, dataset: DocumentDataset) -> DocumentDataset: if self.seed is None: return self.shuffle_nondeterministic(dataset) else: diff --git a/nemo_curator/modules/exact_dedup.py b/nemo_curator/modules/exact_dedup.py index d274d0e35..121a2a4c1 100644 --- a/nemo_curator/modules/exact_dedup.py +++ 
b/nemo_curator/modules/exact_dedup.py @@ -29,11 +29,12 @@ from nemo_curator._compat import DASK_P2P_ERROR from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger +from nemo_curator.modules.base import Module from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix from nemo_curator.utils.gpu_utils import is_cudf_type -class ExactDuplicates: +class ExactDuplicates(Module): """Find exact duplicates in a document corpus""" SUPPORTED_HASHES = {"md5"} @@ -83,6 +84,10 @@ def __init__( else: self._logger = logger + @property + def input_backend(self) -> str: + return "any" + def _exact_dup_ids(self, df: dd.DataFrame): """ Get the id's for text/documents that are exact duplicates diff --git a/nemo_curator/modules/fuzzy_dedup.py b/nemo_curator/modules/fuzzy_dedup.py index 63576516c..1d1f0dafa 100644 --- a/nemo_curator/modules/fuzzy_dedup.py +++ b/nemo_curator/modules/fuzzy_dedup.py @@ -37,6 +37,7 @@ from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger +from nemo_curator.modules.base import Module from nemo_curator.modules.config import FuzzyDuplicatesConfig from nemo_curator.modules.meta import Sequential from nemo_curator.utils.distributed_utils import ( @@ -65,7 +66,7 @@ ) -class MinHash: +class MinHash(Module): """ Computes minhash signatures of a document corpus """ @@ -120,6 +121,10 @@ def __init__( else: self._logger = logger + @property + def input_backend(self) -> str: + return "cudf" + def generate_seeds(self, n_seeds: int = 260, seed: int = 0) -> np.ndarray: """ Generate seeds for all minhash permutations based on the given seed. @@ -149,7 +154,7 @@ def minhash64( seeds = cudf.Series(seeds, dtype="uint64") return ser.str.minhash64(seeds=seeds, width=char_ngram) - def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]: + def call(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]: """ Computes the MinHash Signatures for a given dataset. 
Parameters @@ -187,7 +192,7 @@ def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]: ) -class LSH: +class LSH(Module): """ Performs LSH on a MinhashSignatures """ @@ -245,6 +250,10 @@ def __init__( else: self._logger = logger + @property + def input_backend(self) -> str: + return "cudf" + def _generate_bucket_ranges( self, num_buckets: int, num_hashes: int ) -> List[List[int]]: @@ -366,7 +375,7 @@ def lsh( self._logger.info(f"Wrote data for buckets: {value_vars}") - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: df = dataset.df write_path = os.path.join(self.cache_dir, "_buckets.parquet") @@ -381,7 +390,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: return DocumentDataset(buckets_df) -class FuzzyDuplicates: +class FuzzyDuplicates(Module): def __init__( self, config: FuzzyDuplicatesConfig, @@ -475,7 +484,11 @@ def __init__( profile_dir=self.config.profile_dir, ) - def __call__(self, dataset: DocumentDataset): + @property + def input_backend(self) -> str: + return "cudf" + + def call(self, dataset: DocumentDataset): """ Parameters ---------- diff --git a/nemo_curator/modules/modify.py b/nemo_curator/modules/modify.py index 1307ab177..c9fa59325 100644 --- a/nemo_curator/modules/modify.py +++ b/nemo_curator/modules/modify.py @@ -14,15 +14,20 @@ from nemo_curator.datasets import DocumentDataset from nemo_curator.modifiers import DocumentModifier +from nemo_curator.modules.base import Module from nemo_curator.utils.module_utils import is_batched -class Modify: +class Modify(Module): def __init__(self, modifier: DocumentModifier, text_field="text"): self.modifier = modifier self.text_field = text_field - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + @property + def input_backend(self) -> str: + return "pandas" + + def call(self, dataset: DocumentDataset) -> DocumentDataset: if is_batched(self.modifier.modify_document): dataset.df[self.text_field] = dataset.df[self.text_field].map_partitions( self.modifier.modify_document, meta=(None, str) diff --git a/nemo_curator/modules/semantic_dedup.py b/nemo_curator/modules/semantic_dedup.py index 39f2870ab..8b900c64a 100644 --- a/nemo_curator/modules/semantic_dedup.py +++ b/nemo_curator/modules/semantic_dedup.py @@ -36,6 +36,7 @@ from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger +from nemo_curator.modules.base import Module from nemo_curator.modules.config import SemDedupConfig from nemo_curator.utils.distributed_utils import ( performance_report_if_with_ts_suffix, @@ -573,7 +574,7 @@ def extract_dedup_data(self, eps_to_extract: float) -> DocumentDataset: return DocumentDataset.read_parquet(fps, backend="cudf") -class SemDedup: +class SemDedup(Module): def __init__( self, config: SemDedupConfig, @@ -624,7 +625,11 @@ def __init__( self.eps_thresholds = config.eps_thresholds self.eps_to_extract = config.eps_to_extract - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + @property + def input_backend(self) -> str: + return "cudf" + + def call(self, dataset: DocumentDataset) -> DocumentDataset: """ Execute the SemDedup process. 
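The conversion pattern running through this patch — subclass Module, declare the backend the module expects, and move the body of __call__ into call so the base class can validate the backend first — can be sketched in isolation. The following is a hypothetical subclass written against the interface this patch introduces; the DropEmptyDocuments name and its logic are invented for illustration and are not part of the diff:

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.base import Module


class DropEmptyDocuments(Module):
    """Hypothetical module that removes rows whose text field is empty."""

    def __init__(self, text_field: str = "text") -> None:
        super().__init__()
        self.text_field = text_field

    @property
    def input_backend(self) -> str:
        # Plain string operations only need pandas-backed partitions.
        return "pandas"

    def call(self, dataset: DocumentDataset) -> DocumentDataset:
        # Module.__call__ has already checked the backend before dispatching
        # here, so call() can assume pandas semantics.
        mask = dataset.df[self.text_field].str.len() > 0
        return DocumentDataset(dataset.df[mask])

An instance is then invoked like any other module, e.g. dataset = DropEmptyDocuments()(dataset), with the backend check running before call().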
diff --git a/nemo_curator/modules/task.py b/nemo_curator/modules/task.py index 2571b6a8c..2fe82cd3e 100644 --- a/nemo_curator/modules/task.py +++ b/nemo_curator/modules/task.py @@ -20,12 +20,13 @@ from dask import delayed from nemo_curator.datasets import DocumentDataset +from nemo_curator.modules.base import Module from nemo_curator.tasks.downstream_task import DownstreamTask from nemo_curator.utils.distributed_utils import single_partition_write_with_filename from nemo_curator.utils.text_utils import get_words -class TaskDecontamination: +class TaskDecontamination(Module): def __init__( self, tasks: Union[DownstreamTask, Iterable[DownstreamTask]], @@ -58,7 +59,11 @@ def __init__( self.max_splits = max_splits self.removed_dir = removed_dir - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + @property + def input_backend(self) -> str: + return "pandas" + + def call(self, dataset: DocumentDataset) -> DocumentDataset: # Convert the dataframe to delayed objects for complex operations original_meta = dataset.df.dtypes.to_dict() diff --git a/nemo_curator/modules/to_backend.py b/nemo_curator/modules/to_backend.py new file mode 100644 index 000000000..85c977890 --- /dev/null +++ b/nemo_curator/modules/to_backend.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nemo_curator.datasets.doc_dataset import DocumentDataset +from nemo_curator.modules.base import Module + + +class ToBackend(Module): + def __init__(self, backend: str) -> None: + super().__init__() + self.backend = backend + + @property + def input_backend(self) -> str: + return "any" + + def call(self, dataset: DocumentDataset) -> DocumentDataset: + return DocumentDataset(dataset.df.to_backend(self.backend)) From 5cf68a3ce971e05dbb18abe3b1aff79ce7378c81 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 21 Oct 2024 15:34:16 -0700 Subject: [PATCH 02/11] Change how input backend is specified for modules Signed-off-by: Ryan Wolf --- nemo_curator/modules/add_id.py | 5 +---- nemo_curator/modules/base.py | 15 ++++++++------- nemo_curator/modules/dataset_ops.py | 5 +---- nemo_curator/modules/exact_dedup.py | 6 +----- nemo_curator/modules/fuzzy_dedup.py | 15 +++------------ nemo_curator/modules/modify.py | 5 +---- nemo_curator/modules/semantic_dedup.py | 5 +---- nemo_curator/modules/task.py | 5 +---- nemo_curator/modules/to_backend.py | 6 +----- 9 files changed, 18 insertions(+), 49 deletions(-) diff --git a/nemo_curator/modules/add_id.py b/nemo_curator/modules/add_id.py index 894f25bfe..ad5b593ee 100644 --- a/nemo_curator/modules/add_id.py +++ b/nemo_curator/modules/add_id.py @@ -27,14 +27,11 @@ class AddId(Module): def __init__( self, id_field, id_prefix: str = "doc_id", start_index: Optional[int] = None ) -> None: + super().__init__(input_backend="pandas") self.id_field = id_field self.id_prefix = id_prefix self.start_index = start_index - @property - def input_backend(self) -> str: - return "pandas" - def call(self, dataset: DocumentDataset) -> DocumentDataset: if self.start_index is None: return self._add_id_fast(dataset) diff --git a/nemo_curator/modules/base.py b/nemo_curator/modules/base.py index 260badee1..c1e7a83e9 100644 --- a/nemo_curator/modules/base.py +++ b/nemo_curator/modules/base.py @@ -18,16 +18,17 @@ class Module(ABC): - def __init__(self, name=None) -> None: + SUPPORTED_BACKENDS = ["pandas", "cudf", "any"] + + def __init__(self, input_backend: str, name=None) -> None: super().__init__() self.name = name or self.__class__.__name__ - @abstractmethod - @property - def input_backend(self) -> str: - raise NotImplementedError( - "input_backend method must be implemented by subclasses" - ) + if input_backend not in self.SUPPORTED_BACKENDS: + raise ValueError( + f"{input_backend} not one of the supported backends {self.SUPPORTED_BACKENDS}" + ) + self.input_backend = input_backend @abstractmethod def call(self, dataset: DocumentDataset) -> DocumentDataset: diff --git a/nemo_curator/modules/dataset_ops.py b/nemo_curator/modules/dataset_ops.py index e5979d4c6..c26c6b47b 100644 --- a/nemo_curator/modules/dataset_ops.py +++ b/nemo_curator/modules/dataset_ops.py @@ -32,15 +32,12 @@ def __init__( will look like given the partition number. The default method names the partition f'file_{partition_num:010d}.jsonl' and should be changed if the user is not using a .jsonl format. 
""" + super().__init__(input_backend="pandas") self.seed = seed self.npartitions = npartitions self.partition_to_filename = partition_to_filename self.rand_col = "_shuffle_rand" - @property - def input_backend(self) -> str: - return "pandas" - def call(self, dataset: DocumentDataset) -> DocumentDataset: if self.seed is None: return self.shuffle_nondeterministic(dataset) diff --git a/nemo_curator/modules/exact_dedup.py b/nemo_curator/modules/exact_dedup.py index 121a2a4c1..0df864372 100644 --- a/nemo_curator/modules/exact_dedup.py +++ b/nemo_curator/modules/exact_dedup.py @@ -60,7 +60,7 @@ def __init__( cache_dir: str, Default None If specified, will compute & write duplicate id's to cache directory. """ - + super().__init__(input_backend="any") if hash_method not in self.SUPPORTED_HASHES: raise ValueError( f"{hash_method} not in supported hash_methods. Choose a hash_method from {self.SUPPORTED_HASHES}" @@ -84,10 +84,6 @@ def __init__( else: self._logger = logger - @property - def input_backend(self) -> str: - return "any" - def _exact_dup_ids(self, df: dd.DataFrame): """ Get the id's for text/documents that are exact duplicates diff --git a/nemo_curator/modules/fuzzy_dedup.py b/nemo_curator/modules/fuzzy_dedup.py index 1d1f0dafa..af623a9a2 100644 --- a/nemo_curator/modules/fuzzy_dedup.py +++ b/nemo_curator/modules/fuzzy_dedup.py @@ -98,6 +98,7 @@ def __init__( cache_dir: str, Default None If specified, will compute & write id, minhash pairs to directory """ + super().__init__(input_backend="cudf") self.num_hashes = num_hashes self.char_ngram = char_ngrams self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed) @@ -121,10 +122,6 @@ def __init__( else: self._logger = logger - @property - def input_backend(self) -> str: - return "cudf" - def generate_seeds(self, n_seeds: int = 260, seed: int = 0) -> np.ndarray: """ Generate seeds for all minhash permutations based on the given seed. @@ -225,6 +222,7 @@ def __init__( profile_dir: str, Default None If specified directory to write dask profile """ + super().__init__(input_backend="cudf") self.num_hashes = num_hashes self.num_buckets = num_buckets self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields @@ -250,10 +248,6 @@ def __init__( else: self._logger = logger - @property - def input_backend(self) -> str: - return "cudf" - def _generate_bucket_ranges( self, num_buckets: int, num_hashes: int ) -> List[List[int]]: @@ -408,6 +402,7 @@ def __init__( DocumentDataset containing IDs of all documents and the corresponding duplicate group they belong to. Documents in the same group are near duplicates. 
""" + super().__init__(input_backend="cudf") if isinstance(logger, str): self._logger = create_logger( rank=0, @@ -484,10 +479,6 @@ def __init__( profile_dir=self.config.profile_dir, ) - @property - def input_backend(self) -> str: - return "cudf" - def call(self, dataset: DocumentDataset): """ Parameters diff --git a/nemo_curator/modules/modify.py b/nemo_curator/modules/modify.py index c9fa59325..d1958f9d8 100644 --- a/nemo_curator/modules/modify.py +++ b/nemo_curator/modules/modify.py @@ -20,13 +20,10 @@ class Modify(Module): def __init__(self, modifier: DocumentModifier, text_field="text"): + super().__init__(input_backend="pandas") self.modifier = modifier self.text_field = text_field - @property - def input_backend(self) -> str: - return "pandas" - def call(self, dataset: DocumentDataset) -> DocumentDataset: if is_batched(self.modifier.modify_document): dataset.df[self.text_field] = dataset.df[self.text_field].map_partitions( diff --git a/nemo_curator/modules/semantic_dedup.py b/nemo_curator/modules/semantic_dedup.py index 8b900c64a..a98e04068 100644 --- a/nemo_curator/modules/semantic_dedup.py +++ b/nemo_curator/modules/semantic_dedup.py @@ -587,6 +587,7 @@ def __init__( config (SemDedupConfig): Configuration for SemDedup. logger (Union[logging.Logger, str]): Logger instance or path to the log file directory. """ + super().__init__(input_backend="cudf") self.config = config self.logger = logger cache_dir = config.cache_dir @@ -625,10 +626,6 @@ def __init__( self.eps_thresholds = config.eps_thresholds self.eps_to_extract = config.eps_to_extract - @property - def input_backend(self) -> str: - return "cudf" - def call(self, dataset: DocumentDataset) -> DocumentDataset: """ Execute the SemDedup process. diff --git a/nemo_curator/modules/task.py b/nemo_curator/modules/task.py index 2fe82cd3e..2e94013cd 100644 --- a/nemo_curator/modules/task.py +++ b/nemo_curator/modules/task.py @@ -48,6 +48,7 @@ def __init__( max_splits: The maximum number of times a document may be split before being entirely discarded. removed_dir: If not None, the documents split too many times will be written to this directory using the filename in the dataset. 
""" + super().__init__(input_backend="pandas") if isinstance(tasks, DownstreamTask): tasks = [tasks] self.tasks = tasks @@ -59,10 +60,6 @@ def __init__( self.max_splits = max_splits self.removed_dir = removed_dir - @property - def input_backend(self) -> str: - return "pandas" - def call(self, dataset: DocumentDataset) -> DocumentDataset: # Convert the dataframe to delayed objects for complex operations diff --git a/nemo_curator/modules/to_backend.py b/nemo_curator/modules/to_backend.py index 85c977890..f15619241 100644 --- a/nemo_curator/modules/to_backend.py +++ b/nemo_curator/modules/to_backend.py @@ -18,12 +18,8 @@ class ToBackend(Module): def __init__(self, backend: str) -> None: - super().__init__() + super().__init__(input_backend="any") self.backend = backend - @property - def input_backend(self) -> str: - return "any" - def call(self, dataset: DocumentDataset) -> DocumentDataset: return DocumentDataset(dataset.df.to_backend(self.backend)) From 0c89c70dc8677c493a9be6afc9d5b5fb8cb6f964 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 22 Oct 2024 09:46:31 -0700 Subject: [PATCH 03/11] Refactor out common functions in filter Signed-off-by: Ryan Wolf --- nemo_curator/filters/base.py | 134 +++++++++++++++++------------------ 1 file changed, 64 insertions(+), 70 deletions(-) diff --git a/nemo_curator/filters/base.py b/nemo_curator/filters/base.py index c48d9809d..943dcf1b8 100644 --- a/nemo_curator/filters/base.py +++ b/nemo_curator/filters/base.py @@ -15,7 +15,7 @@ import importlib from abc import ABC, abstractmethod from enum import Enum -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Type, Union from nemo_curator.datasets import DocumentDataset from nemo_curator.modules.base import Module @@ -39,14 +39,25 @@ class DocumentFilter(Module, ABC): def __init__( self, + score_type: Union[str, Type], text_fields: List[str] = ["text"], score_fields: List[str] = ["score"], filter_mode: FilterMode = FilterMode.SCORE_FILTER, removed_path: Optional[str] = None, invert: bool = False, save_score: bool = True, + input_backend: str = "pandas", ): - super().__init__() + """ + text_fields: If len(text_fields) == 1, then score_document will + get a series instead of a dataframe. You may still output + multiple scores in the form of a dataframe. Need to verify if that's possible with Dask though. + score_fields: If len(score_fields) == 1, then score_document + must output a series instead of a dataframe. keep_document + must accept a series instead of a dataframe. keep_document must always return a series. 
+ """ + super().__init__(input_backend=input_backend) + self.score_type = score_type self.text_fields = text_fields self.score_fields = score_fields self.removed_path = removed_path @@ -54,10 +65,6 @@ def __init__( self.filter_mode = filter_mode self.save_score = save_score - @property - def input_backend(self) -> str: - return "pandas" - @abstractmethod def score_document(self, text: str) -> Any: """ @@ -104,78 +111,65 @@ def keep_document(self, scores: Any) -> bool: "keep_document method must be implemented by subclasses" ) - @abstractmethod - @property - def score_type(self): - raise NotImplementedError( - "keep_document method must be implemented by subclasses" + def _score_dataset(self, dataset: DocumentDataset): + meta = (None, self.score_type) + # Get the field name directly if there's only one + text_fields = ( + self.text_fields if len(self.text_fields) > 1 else self.text_fields[0] ) + if is_batched(self.score_document): + scores = dataset.df[text_fields].map_partitions( + self.score_document, meta=meta + ) + else: + scores = dataset.df[text_fields].apply(self.score_document, meta=meta) + + if self.save_score: + score_fields = ( + self.score_fields + if len(self.score_fields) > 1 + else self.score_fields[0] + ) + dataset.df[score_fields] = scores + + return scores + + def _filter_dataset(self, dataset: DocumentDataset, scores): + if is_batched(self.keep_document): + bool_mask = scores.map_partitions(self.keep_document, meta=(None, bool)) + else: + bool_mask = scores.apply(self.keep_document, meta=(None, bool)) + if self.invert: + bool_mask = ~bool_mask + + if self.removed_path: + removed_docs = DocumentDataset(dataset.df[~bool_mask]) + removed_docs.to_parquet(output_file_dir=self.removed_path) + + return bool_mask + + def compute_filter_mask(self, dataset: DocumentDataset): + scores = self._score_dataset(dataset) + return self._filter_dataset(dataset, scores) + def call(self, dataset: DocumentDataset) -> DocumentDataset: match self.filter_mode: case FilterMode.SCORE: - meta = (None, self.score_type) - - if is_batched(self.score_document): - scores = dataset.df[self.text_fields].map_partitions( - self.score_document, meta=meta - ) - else: - scores = dataset.df[self.text_fields].apply( - self.score_document, meta=meta - ) - - if self.save_score: - dataset.df[self.score_fields] = scores - + self._score_dataset(dataset) return dataset case FilterMode.FILTER: - scores = dataset.df[self.score_fields] - - if is_batched(self.keep_document): - bool_mask = scores.map_partitions( - self.keep_document, meta=(None, bool) - ) - else: - bool_mask = scores.apply(self.keep_document, meta=(None, bool)) - if self.invert: - bool_mask = ~bool_mask - - if self.removed_path: - removed_docs = DocumentDataset(dataset.df[~bool_mask]) - removed_docs.to_parquet(output_file_dir=self.removed_path) - - return DocumentDataset(dataset.df[bool_mask]) - + score_fields = ( + self.score_fields + if len(self.score_fields) > 1 + else self.score_fields[0] + ) + scores = dataset.df[score_fields] + mask = self._filter_dataset(dataset, scores) + return DocumentDataset(dataset.df[mask]) case FilterMode.SCORE_FILTER: - meta = (None, self.score_type) - - if is_batched(self.score_document): - scores = dataset.df[self.text_fields].map_partitions( - self.score_document, meta=meta - ) - else: - scores = dataset.df[self.text_fields].apply( - self.score_document, meta=meta - ) - - if self.save_score: - dataset.df[self.score_fields] = scores - - if is_batched(self.keep_document): - bool_mask = scores.map_partitions( - 
self.keep_document, meta=(None, bool)
-                    )
-                else:
-                    bool_mask = scores.apply(self.keep_document, meta=(None, bool))
-                if self.invert:
-                    bool_mask = ~bool_mask
-
-                if self.removed_path:
-                    removed_docs = DocumentDataset(dataset.df[~bool_mask])
-                    removed_docs.to_parquet(output_file_dir=self.removed_path)
-
-                return DocumentDataset(dataset.df[bool_mask])
+                mask = self.compute_filter_mask(dataset)
+                return DocumentDataset(dataset.df[mask])
 
 
 def import_filter(filter_path: str) -> DocumentFilter:

From e365397676d9aaeef5436f6c745eed9cd32c2a57 Mon Sep 17 00:00:00 2001
From: Ryan Wolf
Date: Wed, 23 Oct 2024 15:03:28 -0700
Subject: [PATCH 04/11] Refactor modifier

Signed-off-by: Ryan Wolf
---
 nemo_curator/modifiers/base.py | 54 ++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 nemo_curator/modifiers/base.py

diff --git a/nemo_curator/modifiers/base.py b/nemo_curator/modifiers/base.py
new file mode 100644
index 000000000..45f7c1637
--- /dev/null
+++ b/nemo_curator/modifiers/base.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import List
+
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.modules.base import Module
+from nemo_curator.utils.module_utils import is_batched
+
+
+class DocumentModifier(Module, ABC):
+    def __init__(
+        self,
+        text_fields: List[str] = ["text"],
+        meta=(None, str),
+        input_backend: str = "pandas",
+    ):
+        super().__init__(input_backend=input_backend)
+        self.text_fields = text_fields
+        self.meta = meta
+
+    @abstractmethod
+    def modify_document(self, text):
+        raise NotImplementedError(
+            "modify_document method must be implemented by subclasses"
+        )
+
+    def call(self, dataset: DocumentDataset) -> DocumentDataset:
+        text_fields = (
+            self.text_fields if len(self.text_fields) > 1 else self.text_fields[0]
+        )
+
+        if is_batched(self.modify_document):
+            dataset.df[text_fields] = dataset.df[text_fields].map_partitions(
+                self.modify_document, meta=self.meta
+            )
+        else:
+            dataset.df[text_fields] = dataset.df[text_fields].apply(
+                self.modify_document, meta=self.meta
+            )
+
+        return dataset

From 3ef422395b06e61a575c63626379c74fa8b6762b Mon Sep 17 00:00:00 2001
From: Ryan Wolf
Date: Wed, 23 Oct 2024 15:38:56 -0700
Subject: [PATCH 05/11] Change backend validation

Signed-off-by: Ryan Wolf
---
 nemo_curator/modules/base.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/nemo_curator/modules/base.py b/nemo_curator/modules/base.py
index c1e7a83e9..ccdae9005 100644
--- a/nemo_curator/modules/base.py
+++ b/nemo_curator/modules/base.py
@@ -34,10 +34,18 @@ def call(self, dataset: DocumentDataset) -> DocumentDataset:
         raise NotImplementedError("call method must be implemented by subclasses")
 
-    def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
-        if self.input_backend != "any" and dataset.df.backend !=
self.input_backend: + def _check_backend(self, partition, partition_info=None): + if partition_info is None: + return + + backend = type(partition).__module__.split(".")[0] + if backend != self.input_backend: raise ValueError( - f"Module {self.name} requires dataset to have backend {self.input_backend} but got backend {dataset.df.backend}" + f"Module {self.name} requires dataset to have backend {self.input_backend} but got backend {backend}" ) + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + if self.input_backend != "any": + dataset.df.map_partitions(self._check_backend) + return self.call(dataset) From 6ce4dcdaad878dcf63cb7e163079f164cb1c3b92 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 23 Oct 2024 15:50:07 -0700 Subject: [PATCH 06/11] Add axis arg to dataframe apply Signed-off-by: Ryan Wolf --- nemo_curator/filters/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo_curator/filters/base.py b/nemo_curator/filters/base.py index 943dcf1b8..8e1cb9e4f 100644 --- a/nemo_curator/filters/base.py +++ b/nemo_curator/filters/base.py @@ -123,7 +123,9 @@ def _score_dataset(self, dataset: DocumentDataset): self.score_document, meta=meta ) else: - scores = dataset.df[text_fields].apply(self.score_document, meta=meta) + scores = dataset.df[text_fields].apply( + self.score_document, axis=1, meta=meta + ) if self.save_score: score_fields = ( @@ -139,7 +141,7 @@ def _filter_dataset(self, dataset: DocumentDataset, scores): if is_batched(self.keep_document): bool_mask = scores.map_partitions(self.keep_document, meta=(None, bool)) else: - bool_mask = scores.apply(self.keep_document, meta=(None, bool)) + bool_mask = scores.apply(self.keep_document, axis=1, meta=(None, bool)) if self.invert: bool_mask = ~bool_mask From 2ef203a15f9b6ef6ad2c93110963d1551d21564a Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 23 Oct 2024 16:03:14 -0700 Subject: [PATCH 07/11] Dynamically adapt axis based on data type Signed-off-by: Ryan Wolf --- nemo_curator/filters/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo_curator/filters/base.py b/nemo_curator/filters/base.py index 8e1cb9e4f..89e8e85ea 100644 --- a/nemo_curator/filters/base.py +++ b/nemo_curator/filters/base.py @@ -17,6 +17,8 @@ from enum import Enum from typing import Any, List, Optional, Type, Union +import dask.dataframe as dd + from nemo_curator.datasets import DocumentDataset from nemo_curator.modules.base import Module from nemo_curator.utils.module_utils import is_batched @@ -123,6 +125,7 @@ def _score_dataset(self, dataset: DocumentDataset): self.score_document, meta=meta ) else: + axis = 1 if len(self.text_fields) > 1 else 0 scores = dataset.df[text_fields].apply( self.score_document, axis=1, meta=meta ) @@ -141,7 +144,8 @@ def _filter_dataset(self, dataset: DocumentDataset, scores): if is_batched(self.keep_document): bool_mask = scores.map_partitions(self.keep_document, meta=(None, bool)) else: - bool_mask = scores.apply(self.keep_document, axis=1, meta=(None, bool)) + axis = 1 if isinstance(scores, dd.DataFrame) else 0 + bool_mask = scores.apply(self.keep_document, axis=axis, meta=(None, bool)) if self.invert: bool_mask = ~bool_mask From 6635602c32905f80aff8cba780745f5ec6ac390b Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 24 Oct 2024 14:48:20 -0700 Subject: [PATCH 08/11] Use axis in score apply Signed-off-by: Ryan Wolf --- nemo_curator/filters/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_curator/filters/base.py 
b/nemo_curator/filters/base.py
index 89e8e85ea..0c79a94c4 100644
--- a/nemo_curator/filters/base.py
+++ b/nemo_curator/filters/base.py
@@ -127,7 +127,7 @@ def _score_dataset(self, dataset: DocumentDataset):
         else:
             axis = 1 if len(self.text_fields) > 1 else 0
             scores = dataset.df[text_fields].apply(
-                self.score_document, axis=1, meta=meta
+                self.score_document, axis=axis, meta=meta
             )
 
         if self.save_score:

From 5d4f9ee1c4d0b140e9cb61204781559c4b500ed9 Mon Sep 17 00:00:00 2001
From: Ryan Wolf
Date: Fri, 25 Oct 2024 11:31:20 -0700
Subject: [PATCH 09/11] Add support for multiple scores in doc filter

Signed-off-by: Ryan Wolf
---
 nemo_curator/filters/base.py | 32 +++++++++++++++++++++++---------
 1 file changed, 23 insertions(+), 9 deletions(-)

diff --git a/nemo_curator/filters/base.py b/nemo_curator/filters/base.py
index 0c79a94c4..a1d63768c 100644
--- a/nemo_curator/filters/base.py
+++ b/nemo_curator/filters/base.py
@@ -41,7 +41,7 @@ class DocumentFilter(Module, ABC):
 
     def __init__(
         self,
-        score_type: Union[str, Type],
+        score_types: Union[List[str], List[Type]],
         text_fields: List[str] = ["text"],
         score_fields: List[str] = ["score"],
         filter_mode: FilterMode = FilterMode.SCORE_FILTER,
         removed_path: Optional[str] = None,
         invert: bool = False,
         save_score: bool = True,
@@ -59,7 +59,18 @@ def __init__(
             must accept a series instead of a dataframe. keep_document
             must always return a series.
         """
         super().__init__(input_backend=input_backend)
-        self.score_type = score_type
+
+        if len(score_types) != len(score_fields):
+            raise ValueError(
+                f"score_types must have the same length as score_fields, but got {len(score_types)} and {len(score_fields)}."
+            )
+
+        if len(score_fields) > 1 and not is_batched(self.score_document):
+            raise ValueError(
+                f"When outputting multiple scores ({len(score_fields)} in this case), score_document must be defined in @batched mode."
+            )
+
+        self.score_types = score_types
         self.text_fields = text_fields
         self.score_fields = score_fields
         self.removed_path = removed_path
         self.invert = invert
         self.filter_mode = filter_mode
         self.save_score = save_score
@@ -114,7 +125,11 @@ def keep_document(self, scores: Any) -> bool:
         )
 
     def _score_dataset(self, dataset: DocumentDataset):
-        meta = (None, self.score_type)
+        meta = (
+            list(zip(self.score_fields, self.score_types))
+            if len(self.score_fields) > 1
+            else (None, self.score_types[0])
+        )
         # Get the field name directly if there's only one
         text_fields = (
             self.text_fields if len(self.text_fields) > 1 else self.text_fields[0]
@@ -131,12 +146,11 @@ def _score_dataset(self, dataset: DocumentDataset):
             )
 
         if self.save_score:
-            score_fields = (
-                self.score_fields
-                if len(self.score_fields) > 1
-                else self.score_fields[0]
-            )
-            dataset.df[score_fields] = scores
+            if len(self.score_fields) > 1:
+                dataset.df = dd.concat([dataset.df, scores], axis=1)
+            else:
+                score_fields = self.score_fields[0]
+                dataset.df[score_fields] = scores
 
         return scores

From 513e75f3580ec584ea60c3f83e7aeb914a2a6579 Mon Sep 17 00:00:00 2001
From: Ryan Wolf
Date: Fri, 25 Oct 2024 15:11:49 -0700
Subject: [PATCH 10/11] Change the inheritance of the filters

Signed-off-by: Ryan Wolf
---
 nemo_curator/filters/__init__.py          |   2 +-
 nemo_curator/filters/base.py              |   2 +-
 nemo_curator/filters/classifier_filter.py |  75 +++-
 nemo_curator/filters/code.py              | 162 +++++++-
 nemo_curator/filters/heuristic_filter.py  | 466 +++++++++++++++++++---
 5 files changed, 636 insertions(+), 71 deletions(-)

diff --git a/nemo_curator/filters/__init__.py b/nemo_curator/filters/__init__.py
index 4eb800992..61890983d 100644
--- a/nemo_curator/filters/__init__.py
+++ b/nemo_curator/filters/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .base import DocumentFilter, import_filter
 from .classifier_filter import FastTextLangId, FastTextQualityFilter
 from .code import (
     AlphaFilter,
@@ -23,7 +24,6 @@
     TokenizerFertilityFilter,
     XMLHeaderFilter,
 )
-from .doc_filter import DocumentFilter, import_filter
 from .heuristic_filter import (
     BoilerPlateStringFilter,
     BulletsFilter,
diff --git a/nemo_curator/filters/base.py b/nemo_curator/filters/base.py
index a1d63768c..0f52d303c 100644
--- a/nemo_curator/filters/base.py
+++ b/nemo_curator/filters/base.py
@@ -211,6 +211,6 @@ def import_filter(filter_path: str) -> DocumentFilter:
     if not issubclass(filter_class, DocumentFilter):
         raise ValueError(
             f"Input filter {filter_class.__name__} must be derived "
-            "from DocumentFilter defined in nemo_curator.filters.doc_filter"
+            "from DocumentFilter defined in nemo_curator.filters.base"
         )
     return filter_class
diff --git a/nemo_curator/filters/classifier_filter.py b/nemo_curator/filters/classifier_filter.py
index 741df9640..12fde411d 100644
--- a/nemo_curator/filters/classifier_filter.py
+++ b/nemo_curator/filters/classifier_filter.py
@@ -12,19 +12,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
+
 import dask
 import fasttext
 import numpy as np
 import pandas as pd
 
-from nemo_curator.filters.doc_filter import DocumentFilter
+from nemo_curator.filters.base import DocumentFilter, FilterMode
 from nemo_curator.utils.decorators import batched
 from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker
 
 
 class FastTextQualityFilter(DocumentFilter):
-
-    def __init__(self, model_path=None, label="__label__hq", alpha=3, seed=42):
+    def __init__(
+        self,
+        model_path=None,
+        label="__label__hq",
+        alpha=3,
+        seed=42,
+        text_field: str = "text",
+        score_field: str = "fasttext_quality_score",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         if model_path is None:
             raise ValueError(
                 "Must provide a valid path to a FastText model "
@@ -64,38 +86,65 @@ def _load_model(self):
 
 
 class FastTextLangId(DocumentFilter):
-
-    def __init__(self, model_path=None, min_langid_score=0.3):
+    def __init__(
+        self,
+        model_path=None,
+        min_langid_score=0.3,
+        text_field: str = "text",
+        score_field: str = "language_score",
+        lang_field: str = "language",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float, str],
+            text_fields=[text_field],
+            score_fields=[score_field, lang_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         if model_path is None:
             raise ValueError(
                 "Must provide a valid path to a FastText model "
                 "to identify languages with this filter"
             )
+        self.score_field = score_field
+        self.lang_field = lang_field
         self._model_path = model_path
         self._lang_code = None
         self._cutoff = min_langid_score
         self._name = "lang_id"
 
     @batched
-    def score_document(self, df: pd.Series):
+    def score_document(self, series: pd.Series):
         model_attr = f"{self._name}_{self._model_path}"
         try:
             model = load_object_on_worker(model_attr, self._load_model, {})
         except NoWorkerError:
-            return pd.Series([[1.0, "N/A"] for _ in range(len(df))])
+            return pd.Series([[1.0, "N/A"] for _ in range(len(series))])
 
-        def _score_document(text):
-            pp = text.strip().replace("\n", " ")
-            label, score = model.predict(pp, k=1)
+        processed_series = series.str.strip().str.replace("\n", " ")
+        scores = []
+        lang_codes = []
+        for text in processed_series:
+            label, score = model.predict(text, k=1)
             score = score[0]
             lang_code = label[0][-2:].upper()
-            return [score, lang_code]
+            scores.append(score)
+            lang_codes.append(lang_code)
 
-        return df.apply(_score_document)
+        return pd.DataFrame(
+            {self.score_field: scores, self.lang_field: lang_codes}, index=series.index
+        )
 
+    @batched
     def keep_document(self, score):
-        return score[0] >= self._cutoff
+        return score[self.score_field] >= self._cutoff
 
     def _load_model(self):
         return fasttext.load_model(self._model_path)
diff --git a/nemo_curator/filters/code.py b/nemo_curator/filters/code.py
index a05bb2b6b..342f21930 100644
--- a/nemo_curator/filters/code.py
+++ b/nemo_curator/filters/code.py
@@ -14,22 +14,37 @@
 
 import csv
 import warnings
+from typing import Optional
 
 from bs4 import BeautifulSoup
 from comment_parser import comment_parser
 
-from nemo_curator.filters.doc_filter import DocumentFilter, import_filter
+from
nemo_curator.filters.base import DocumentFilter, FilterMode from nemo_curator.utils.constants import regex_alpha, regex_alphanum from nemo_curator.utils.text_utils import get_comments_and_docstring class PythonCommentToCodeFilter(DocumentFilter): - def __init__( self, min_comment_to_code_ratio=0.01, max_comment_to_code_ratio=0.85, + text_field: str = "text", + score_field: str = "python_comment_ratio", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, ): + super().__init__( + [float], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) self._min_threshold = min_comment_to_code_ratio self._max_threshold = max_comment_to_code_ratio @@ -48,13 +63,27 @@ def keep_document(self, score): class GeneralCommentToCodeFilter(DocumentFilter): - def __init__( self, language, min_comment_to_code_ratio=0.01, max_comment_to_code_ratio=0.85, + text_field: str = "text", + score_field: str = "comment_ratio", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, ): + super().__init__( + [float], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) """ Does not include the comment characters (// or /**/) towards the length of the comment. Args: @@ -85,8 +114,26 @@ def keep_document(self, score): class NumberOfLinesOfCodeFilter(DocumentFilter): - - def __init__(self, min_lines=10, max_lines=20000): + def __init__( + self, + min_lines=10, + max_lines=20000, + text_field: str = "text", + score_field: str = "num_lines", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [int], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) self._min_lines = min_lines self._max_lines = max_lines @@ -100,8 +147,26 @@ def keep_document(self, score): class TokenizerFertilityFilter(DocumentFilter): - - def __init__(self, path_to_tokenizer=None, min_char_to_token_ratio=2.5): + def __init__( + self, + path_to_tokenizer=None, + min_char_to_token_ratio=2.5, + text_field: str = "text", + score_field: str = "tokenizer_fertility", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [float], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) from nemo.collections.common.tokenizers import SentencePieceTokenizer if path_to_tokenizer is None: @@ -133,7 +198,25 @@ class XMLHeaderFilter(DocumentFilter): (Source: Starcoder https://arxiv.org/abs/2305.06161) """ - def __init__(self, char_prefix_search_length=100): + def __init__( + self, + char_prefix_search_length=100, + text_field: str = "text", + score_field: str = "xml_header", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [int], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, 
+ invert=invert, + save_score=save_score, + ) self._char_prefix_search_length = char_prefix_search_length self._name = "xml_header" @@ -156,7 +239,25 @@ class AlphaFilter(DocumentFilter): (Source: Starcoder https://arxiv.org/abs/2305.06161) """ - def __init__(self, min_alpha_ratio=0.25): + def __init__( + self, + min_alpha_ratio=0.25, + text_field: str = "text", + score_field: str = "alpha_filter", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [float], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) self._min_alpha_ratio = min_alpha_ratio self._name = "alpha_filter" @@ -173,7 +274,26 @@ class HTMLBoilerplateFilter(DocumentFilter): This filter tries to identify HTML that is largely boilerplate. """ - def __init__(self, min_lang_content_ratio=0.2, min_lang_content_num_chars=100): + def __init__( + self, + min_lang_content_ratio=0.2, + min_lang_content_num_chars=100, + text_field: str = "text", + score_field: str = "html_boilerplate", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [float], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) self._min_lang_content_ratio = min_lang_content_ratio self._min_lang_content_num_chars = min_lang_content_num_chars @@ -207,7 +327,27 @@ class PerExtensionFilter(DocumentFilter): This filter that has specific conditions depending on the file extension. """ - def __init__(self, lang, extension, metadata_file="code_meta.csv"): + def __init__( + self, + lang, + extension, + metadata_file="code_meta.csv", + text_field: str = "text", + score_field: str = "per_extension_filter", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [float], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) self._metadata_file = metadata_file self._lang = lang self._extension = extension diff --git a/nemo_curator/filters/heuristic_filter.py b/nemo_curator/filters/heuristic_filter.py index a4c876ab1..116dcb3f2 100644 --- a/nemo_curator/filters/heuristic_filter.py +++ b/nemo_curator/filters/heuristic_filter.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import regex +from typing import Optional -from nemo_curator.filters.doc_filter import DocumentFilter, import_filter +from nemo_curator.filters.base import DocumentFilter, FilterMode from nemo_curator.utils.constants import ( bullet_list, common_english_words, @@ -34,7 +34,6 @@ get_paragraphs, get_sentences, get_word_splitter, - is_paragraph_indices_in_top_or_bottom_only, ) @@ -45,8 +44,25 @@ class NonAlphaNumericFilter(DocumentFilter): Source: Adapted from Gopher (Rae et al., 2021) """ - def __init__(self, max_non_alpha_numeric_to_text_ratio: float = 0.25): - super().__init__() + def __init__( + self, + max_non_alpha_numeric_to_text_ratio: float = 0.25, + text_field: str = "text", + score_field: str = "alpha_numeric", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [float], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) self._cutoff = max_non_alpha_numeric_to_text_ratio self._name = "alpha_numeric" @@ -70,8 +86,26 @@ class SymbolsToWordsFilter(DocumentFilter): Source: Gopher (Rae et al., 2021) """ - def __init__(self, max_symbol_to_word_ratio=0.1, lang="en"): - super().__init__() + def __init__( + self, + max_symbol_to_word_ratio=0.1, + lang="en", + text_field: str = "text", + score_field: str = "symbol_to_word", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [float], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) self._cutoff = max_symbol_to_word_ratio self._word_splitter = get_word_splitter(lang) self._name = "symbol_to_word" @@ -96,8 +130,25 @@ class NumbersFilter(DocumentFilter): If more than 15% of the document contains numbers then discard """ - def __init__(self, max_number_to_text_ratio=0.15): - super().__init__() + def __init__( + self, + max_number_to_text_ratio=0.15, + text_field: str = "text", + score_field: str = "numbers_ratio", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [float], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) self._cutoff = max_number_to_text_ratio self._name = "numbers_ratio" @@ -119,8 +170,25 @@ class UrlsFilter(DocumentFilter): If more than 20% of the document is comprised of URLs then discard """ - def __init__(self, max_url_to_text_ratio=0.2): - super().__init__() + def __init__( + self, + max_url_to_text_ratio=0.2, + text_field: str = "text", + score_field: str = "urls_ratio", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [float], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) self._cutoff = max_url_to_text_ratio self._name = "urls_ratio" @@ -145,8 +213,25 @@ class BulletsFilter(DocumentFilter): Source: Gopher (Rae et al., 2021) """ - def __init__(self, max_bullet_lines_ratio=0.9): - super().__init__() + def __init__( + self, 
+        max_bullet_lines_ratio=0.9,
+        text_field: str = "text",
+        score_field: str = "bullet_ratio",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._cutoff = max_bullet_lines_ratio
         self._sentences = None
         self._name = "bullet_ratio"
@@ -174,8 +259,25 @@ class WhiteSpaceFilter(DocumentFilter):
     of white space characters then discard
     """

-    def __init__(self, max_white_space_ratio=0.25):
-        super().__init__()
+    def __init__(
+        self,
+        max_white_space_ratio=0.25,
+        text_field: str = "text",
+        score_field: str = "white_space",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._cutoff = max_white_space_ratio
         self._name = "white_space"
@@ -199,8 +301,25 @@ class ParenthesesFilter(DocumentFilter):
     If more than 10% of the sentence is in parentheses then discard
     """

-    def __init__(self, max_parentheses_ratio=0.1):
-        super().__init__()
+    def __init__(
+        self,
+        max_parentheses_ratio=0.1,
+        text_field: str = "text",
+        score_field: str = "parentheses_ratio",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._max_parentheses_ratio = max_parentheses_ratio
         self._name = "parentheses_ratio"
@@ -225,8 +344,26 @@ class LongWordFilter(DocumentFilter):
     Source: C4 (Google)
     """

-    def __init__(self, max_word_length=1000, lang="en"):
-        super().__init__()
+    def __init__(
+        self,
+        max_word_length=1000,
+        lang="en",
+        text_field: str = "text",
+        score_field: str = "max_word_length",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._max_word_length = max_word_length
         self._word_splitter = get_word_splitter(lang)
         self._name = "max_word_length"
@@ -244,8 +381,27 @@ class WordCountFilter(DocumentFilter):
     within a specified range then discard
     """

-    def __init__(self, min_words=50, max_words=100000, lang="en"):
-        super().__init__()
+    def __init__(
+        self,
+        min_words=50,
+        max_words=100000,
+        lang="en",
+        text_field: str = "text",
+        score_field: str = "word_count",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._min_words = min_words
         self._max_words = max_words
         self._word_splitter = get_word_splitter(lang)
@@ -269,8 +425,22 @@ def __init__(
         self,
         remove_if_at_top_or_bottom=True,
         max_boilerplate_string_ratio=0.4,
+        text_field: str = "text",
+        score_field: str = "boilerplate_string_ratio",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
     ):
-        super().__init__()
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._remove_if_at_top_or_bottom = remove_if_at_top_or_bottom
         self._max_boilerplate_string_ratio = max_boilerplate_string_ratio
         self._boilerplate_paragraph_indices = []
@@ -308,8 +478,22 @@ def __init__(
         min_mean_word_length=3,
         max_mean_word_length=10,
         lang="en",
+        text_field: str = "text",
+        score_field: str = "mean_word_length",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
     ):
-        super().__init__()
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._min_cutoff = min_mean_word_length
         self._max_cutoff = max_mean_word_length
         self._word_splitter = get_word_splitter(lang)
@@ -330,8 +514,25 @@ class RepeatedLinesFilter(DocumentFilter):
     Source: Gopher (Rae et al., 2021)
     """

-    def __init__(self, max_repeated_line_fraction=0.7):
-        super().__init__()
+    def __init__(
+        self,
+        max_repeated_line_fraction=0.7,
+        text_field: str = "text",
+        score_field: str = "repeated_lines",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._cutoff = max_repeated_line_fraction
         self._name = "repeated_lines"
@@ -352,8 +553,25 @@ class RepeatedParagraphsFilter(DocumentFilter):
     Source: Gopher (Rae et al., 2021)
     """

-    def __init__(self, max_repeated_paragraphs_ratio=0.7):
-        super().__init__()
+    def __init__(
+        self,
+        max_repeated_paragraphs_ratio=0.7,
+        text_field: str = "text",
+        score_field: str = "repeated_paragraphs",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._max_repeated_paragraphs_ratio = max_repeated_paragraphs_ratio
         self._name = "repeated_paragraphs"
@@ -374,8 +592,25 @@ class RepeatedLinesByCharFilter(DocumentFilter):
     Source: Gopher (Rae et al., 2021)
     """

-    def __init__(self, max_repeated_lines_char_ratio=0.8):
-        super().__init__()
+    def __init__(
+        self,
+        max_repeated_lines_char_ratio=0.8,
+        text_field: str = "text",
+        score_field: str = "repeated_lines_char",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._cutoff = max_repeated_lines_char_ratio
         self._name = "repeated_lines_char"
@@ -397,8 +632,25 @@ class RepeatedParagraphsByCharFilter(DocumentFilter):
     Source: Gopher (Rae et al., 2021)
     """

-    def __init__(self, max_repeated_paragraphs_char_ratio=0.8):
-        super().__init__()
+    def __init__(
+        self,
+        max_repeated_paragraphs_char_ratio=0.8,
+        text_field: str = "text",
+        score_field: str = "repeated_paragraphs_char",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._cutoff = max_repeated_paragraphs_char_ratio
         self._name = "repeated_paragraphs_char"
@@ -420,8 +672,27 @@ class RepeatingTopNGramsFilter(DocumentFilter):
     Source: Gopher (Rae et al., 2021)
     """

-    def __init__(self, n=2, max_repeating_ngram_ratio=0.2, lang="en"):
-        super().__init__()
+    def __init__(
+        self,
+        n=2,
+        max_repeating_ngram_ratio=0.2,
+        lang="en",
+        text_field: str = "text",
+        score_field: str = "repeating_top_ngrams",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._n = n
         self._cutoff = max_repeating_ngram_ratio
         self._max_ratio = 1.0
@@ -466,8 +737,27 @@ class RepeatingDuplicateNGramsFilter(DocumentFilter):
     Source: Gopher (Rae et al., 2021)
     """

-    def __init__(self, n=2, max_repeating_duplicate_ngram_ratio=0.2, lang="en"):
-        super().__init__()
+    def __init__(
+        self,
+        n=2,
+        max_repeating_duplicate_ngram_ratio=0.2,
+        lang="en",
+        text_field: str = "text",
+        score_field: str = "repeating_dup_ngram",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._n = n
         self._cutoff = max_repeating_duplicate_ngram_ratio
         self._max_ratio = 1.0
@@ -517,8 +807,25 @@ class PunctuationFilter(DocumentFilter):
     Source: Google C4 processing
     """

-    def __init__(self, max_num_sentences_without_endmark_ratio=0.85):
-        super().__init__()
+    def __init__(
+        self,
+        max_num_sentences_without_endmark_ratio=0.85,
+        text_field: str = "text",
+        score_field: str = "punctuation",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._cutoff = max_num_sentences_without_endmark_ratio
         self._name = "punctuation"
@@ -541,8 +848,25 @@ class EllipsisFilter(DocumentFilter):
     Source: Google C4 processing
     """

-    def __init__(self, max_num_lines_ending_with_ellipsis_ratio=0.3):
-        super().__init__()
+    def __init__(
+        self,
+        max_num_lines_ending_with_ellipsis_ratio=0.3,
+        text_field: str = "text",
+        score_field: str = "ellipsis",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._cutoff = max_num_lines_ending_with_ellipsis_ratio
         self._name = "ellipsis"
@@ -569,8 +893,26 @@ class CommonEnglishWordsFilter(DocumentFilter):
     to remove documents with over-capitalization.
     """

-    def __init__(self, min_num_common_words=2, stop_at_false=True):
-        super().__init__()
+    def __init__(
+        self,
+        min_num_common_words=2,
+        stop_at_false=True,
+        text_field: str = "text",
+        score_field: str = "common_english_words",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._cutoff = min_num_common_words
         self._stop_at_false = stop_at_false
         self._word_splitter = get_word_splitter("en")
@@ -596,8 +938,26 @@ class WordsWithoutAlphabetsFilter(DocumentFilter):
     Source: Gopher (Rae et al., 2021)
     """

-    def __init__(self, min_words_with_alphabets=0.8, lang="en"):
-        super().__init__()
+    def __init__(
+        self,
+        min_words_with_alphabets=0.8,
+        lang="en",
+        text_field: str = "text",
+        score_field: str = "words_without_alphabets",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
         self._cutoff = min_words_with_alphabets
         self._word_splitter = get_word_splitter(lang)
         self._name = "words_without_alphabets"
@@ -620,8 +980,24 @@ class PornographicUrlsFilter(DocumentFilter):
     Check if any of the urls within the document point to porn
     """

-    def __init__(self):
-        super().__init__()
+    def __init__(
+        self,
+        text_field: str = "text",
+        score_field: str = "unsafe_url",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [float],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )

     def score_document(self, text):
         all_urls = regex_url.findall(text)
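
Taken together, the hunks above give every heuristic filter the same constructor shape: a positional list of score types followed by field and mode keywords that are forwarded to the DocumentFilter base. The sketch below illustrates how a refactored filter could then be applied to a dataset on its own, without a ScoreFilter wrapper. It is a usage sketch, not part of the patch: it assumes the Module base dispatches __call__ to the filter's call() method, and the dataset path and field names are illustrative.

    from nemo_curator.datasets import DocumentDataset
    from nemo_curator.filters import FilterMode, WordCountFilter

    # Hypothetical input; "books_dataset/" is a placeholder path.
    books = DocumentDataset.read_json("books_dataset/", add_filename=True)

    # SCORE_FILTER (the default) both annotates and prunes: word counts are
    # written to score_field because save_score=True, and documents outside
    # [80, 100000] words are dropped. Setting removed_path would additionally
    # persist the rejected documents as Parquet for inspection.
    word_count = WordCountFilter(
        min_words=80,
        text_field="text",
        score_field="word_count",
        filter_mode=FilterMode.SCORE_FILTER,
        removed_path=None,
        save_score=True,
    )
    long_enough = word_count(books)
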
From ba561a355df2ef0065c478f8455a86517b40b862 Mon Sep 17 00:00:00 2001
From: Ryan Wolf
Date: Mon, 28 Oct 2024 11:42:52 -0700
Subject: [PATCH 11/11] Remove private variables from filters and add tests for filters

Signed-off-by: Ryan Wolf
---
 nemo_curator/filters/__init__.py         |   3 +-
 nemo_curator/filters/heuristic_filter.py |  49 ++---
 tests/test_filters.py                    | 267 ++++++++++++++++++++++-
 3 files changed, 283 insertions(+), 36 deletions(-)

diff --git a/nemo_curator/filters/__init__.py b/nemo_curator/filters/__init__.py
index 61890983d..ac90d254e 100644
--- a/nemo_curator/filters/__init__.py
+++ b/nemo_curator/filters/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .base import DocumentFilter, import_filter
+from .base import DocumentFilter, FilterMode, import_filter
 from .classifier_filter import FastTextLangId, FastTextQualityFilter
 from .code import (
     AlphaFilter,
@@ -52,6 +52,7 @@
 __all__ = [
     "DocumentFilter",
     "import_filter",
+    "FilterMode",
     "FastTextLangId",
     "FastTextQualityFilter",
     "NonAlphaNumericFilter",
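
Exporting FilterMode at the package root makes the mode switch part of the public API. The sketch below shows the two-stage pattern this enables: score once, inspect or reuse the column, then filter on the precomputed scores. It is an assumption-laden illustration, not code from this patch; `dataset` stands in for any DocumentDataset, and the constructor keywords are those introduced in the first patch of this series.

    from nemo_curator.filters import FilterMode, NonAlphaNumericFilter

    # Stage 1: annotate only; no documents are dropped in SCORE mode.
    scorer = NonAlphaNumericFilter(
        filter_mode=FilterMode.SCORE, score_field="alpha_numeric"
    )
    scored = scorer(dataset)

    # Stage 2: prune using the precomputed column. FILTER mode reads
    # score_field instead of recomputing; invert=True would instead keep
    # the documents the filter normally removes.
    pruner = NonAlphaNumericFilter(
        filter_mode=FilterMode.FILTER, score_field="alpha_numeric", invert=False
    )
    pruned = pruner(scored)
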
= "ellipsis" def score_document(self, text): - sentences = self._sentences - if sentences is None: - sentences = get_sentences(text) + sentences = get_sentences(text) num_lines_ending_with_ellipsis = 0 for sentence in sentences: for ellipsis in ellipsis_marks: diff --git a/tests/test_filters.py b/tests/test_filters.py index 791b176b6..81d10d70b 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from typing import Optional import dask import numpy as np @@ -54,6 +55,8 @@ WordsWithoutAlphabetsFilter, XMLHeaderFilter, ) +from nemo_curator.filters.base import FilterMode +from nemo_curator.filters.doc_filter import DocumentFilter as LegacyDocumentFilter from nemo_curator.modules import Filter, Score, ScoreFilter, Sequential from nemo_curator.utils.decorators import batched @@ -63,6 +66,80 @@ class LetterCountFilter(DocumentFilter): Keeps documents that have at least some number of a given letter """ + def __init__( + self, + letter="a", + min_count=5, + text_field: str = "text", + score_field: str = "letter_count", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [int], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) + self.letter = letter + self.min_count = min_count + + def score_document(self, text): + return text.count(self.letter) + + def keep_document(self, score): + return score >= self.min_count + + +class BatchedLengthFilter(DocumentFilter): + """ + Keeps documents of a given length + """ + + def __init__( + self, + min_length=5, + max_length=10, + text_field: str = "text", + score_field: str = "length", + filter_mode: FilterMode = FilterMode.SCORE_FILTER, + removed_path: Optional[str] = None, + invert: bool = False, + save_score: bool = True, + ): + super().__init__( + [int], + text_fields=[text_field], + score_fields=[score_field], + filter_mode=filter_mode, + removed_path=removed_path, + invert=invert, + save_score=save_score, + ) + self.min_length = min_length + self.max_length = max_length + + @batched + def score_document(self, df): + return df.str.len() + + @batched + def keep_document(self, scores): + min_threshold = self.min_length <= scores + max_threshold = scores <= self.max_length + return min_threshold & max_threshold + + +class LegacyLetterCountFilter(LegacyDocumentFilter): + """ + Keeps documents that have at least some number of a given letter + """ + def __init__(self, letter="a", min_count=5): super().__init__() self.letter = letter @@ -75,7 +152,7 @@ def keep_document(self, score): return score >= self.min_count -class BatchedLengthFilter(DocumentFilter): +class LegacyBatchedLengthFilter(LegacyDocumentFilter): """ Keeps documents of a given length """ @@ -300,6 +377,194 @@ def test_chain_filter(self, letter_count_data): ), f"Expected {expected_data} but got {filtered_data}" +class TestLegacyFilterModule: + def test_score_filter(self, letter_count_data): + letter_filter = LegacyLetterCountFilter() + filter_step = ScoreFilter(letter_filter, text_field="documents") + filtered_data = filter_step(letter_count_data) + + expected_indices = [2, 3] + expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices]) + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + + def test_score(self, letter_count_data): 
diff --git a/tests/test_filters.py b/tests/test_filters.py
index 791b176b6..81d10d70b 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import os
+from typing import Optional

 import dask
 import numpy as np
@@ -54,6 +55,8 @@
     WordsWithoutAlphabetsFilter,
     XMLHeaderFilter,
 )
+from nemo_curator.filters.base import FilterMode
+from nemo_curator.filters.doc_filter import DocumentFilter as LegacyDocumentFilter
 from nemo_curator.modules import Filter, Score, ScoreFilter, Sequential
 from nemo_curator.utils.decorators import batched
@@ -63,6 +66,80 @@ class LetterCountFilter(DocumentFilter):
     Keeps documents that have at least some number of a given letter
     """

+    def __init__(
+        self,
+        letter="a",
+        min_count=5,
+        text_field: str = "text",
+        score_field: str = "letter_count",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [int],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
+        self.letter = letter
+        self.min_count = min_count
+
+    def score_document(self, text):
+        return text.count(self.letter)
+
+    def keep_document(self, score):
+        return score >= self.min_count
+
+
+class BatchedLengthFilter(DocumentFilter):
+    """
+    Keeps documents of a given length
+    """
+
+    def __init__(
+        self,
+        min_length=5,
+        max_length=10,
+        text_field: str = "text",
+        score_field: str = "length",
+        filter_mode: FilterMode = FilterMode.SCORE_FILTER,
+        removed_path: Optional[str] = None,
+        invert: bool = False,
+        save_score: bool = True,
+    ):
+        super().__init__(
+            [int],
+            text_fields=[text_field],
+            score_fields=[score_field],
+            filter_mode=filter_mode,
+            removed_path=removed_path,
+            invert=invert,
+            save_score=save_score,
+        )
+        self.min_length = min_length
+        self.max_length = max_length
+
+    @batched
+    def score_document(self, df):
+        return df.str.len()
+
+    @batched
+    def keep_document(self, scores):
+        min_threshold = self.min_length <= scores
+        max_threshold = scores <= self.max_length
+        return min_threshold & max_threshold
+
+
+class LegacyLetterCountFilter(LegacyDocumentFilter):
+    """
+    Keeps documents that have at least some number of a given letter
+    """
+
     def __init__(self, letter="a", min_count=5):
         super().__init__()
         self.letter = letter
@@ -75,7 +152,7 @@ def keep_document(self, score):
         return score >= self.min_count


-class BatchedLengthFilter(DocumentFilter):
+class LegacyBatchedLengthFilter(LegacyDocumentFilter):
     """
     Keeps documents of a given length
     """
@@ -300,6 +377,194 @@ def test_chain_filter(self, letter_count_data):
         ), f"Expected {expected_data} but got {filtered_data}"


+class TestLegacyFilterModule:
+    def test_score_filter(self, letter_count_data):
+        letter_filter = LegacyLetterCountFilter()
+        filter_step = ScoreFilter(letter_filter, text_field="documents")
+        filtered_data = filter_step(letter_count_data)
+
+        expected_indices = [2, 3]
+        expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices])
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+
+    def test_score(self, letter_count_data):
+        letter_filter = LegacyLetterCountFilter()
+        score_field = "a_count"
+        score_step = Score(
+            letter_filter.score_document,
+            text_field="documents",
+            score_field=score_field,
+        )
+        scored_data = score_step(letter_count_data)
+
+        expected_scores = pd.Series([2, 3, 5, 7])
+        scores = scored_data.df[score_field]
+        assert all(
+            expected_scores == scores.compute()
+        ), f"Expected {expected_scores} but got {scores}"
+
+    def test_retain_score_filter(self, letter_count_data):
+        letter_filter = LegacyLetterCountFilter()
+        score_field = "count_a"
+        filter_step = ScoreFilter(
+            letter_filter, text_field="documents", score_field=score_field
+        )
+        filtered_data = filter_step(letter_count_data)
+
+        expected_indices = [2, 3]
+        # Compute before loc due to https://github.com/dask/dask-expr/issues/1036
+        expected_data = letter_count_data.df.compute().loc[expected_indices]
+        expected_data = DocumentDataset(dd.from_pandas(expected_data, 2))
+        expected_data.df[score_field] = pd.Series([5, 7], index=expected_data.df.index)
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+
+    def test_filter(self, letter_count_data):
+        letter_filter = LegacyLetterCountFilter()
+        score_field = "a_count"
+        score_step = Score(
+            letter_filter.score_document,
+            text_field="documents",
+            score_field=score_field,
+        )
+        scored_data = score_step(letter_count_data)
+        filter_step = Filter(letter_filter.keep_document, score_field)
+        filtered_data = filter_step(scored_data)
+
+        expected_indices = [2, 3]
+        # Compute before loc due to https://github.com/dask/dask-expr/issues/1036
+        expected_data = letter_count_data.df.compute().loc[expected_indices]
+        expected_data = dd.from_pandas(expected_data, 2)
+        expected_data[score_field] = pd.Series([5, 7], index=expected_data.index)
+        expected_data = DocumentDataset(expected_data)
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+
+    def test_invert(self, letter_count_data):
+        letter_filter = LegacyLetterCountFilter()
+        filter_step = ScoreFilter(letter_filter, text_field="documents", invert=True)
+        filtered_data = filter_step(letter_count_data)
+
+        expected_indices = [0, 1]
+        expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices])
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+
+    def test_sequential_filter(self, letter_count_data):
+        filters = Sequential(
+            [
+                ScoreFilter(LegacyLetterCountFilter(), text_field="documents"),
+                ScoreFilter(
+                    LegacyLetterCountFilter(min_count=6), text_field="documents"
+                ),
+            ]
+        )
+        filtered_data = filters(letter_count_data)
+
+        expected_indices = [3]
+        expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices])
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+
+    def test_batch_score_filter(self, letter_count_data):
+        length_filter = LegacyBatchedLengthFilter(min_length=8, max_length=11)
+        filter_step = ScoreFilter(length_filter, text_field="documents")
+        filtered_data = filter_step(letter_count_data)
+
+        expected_indices = [1, 2]
+        expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices])
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+
+    def test_batch_score(self, letter_count_data):
+        length_filter = LegacyBatchedLengthFilter(min_length=8, max_length=11)
+        score_field = "lengths"
+        score_step = Score(
+            length_filter.score_document,
+            text_field="documents",
+            score_field=score_field,
+        )
+        scored_data = score_step(letter_count_data)
+
+        expected_scores = pd.Series([6, 11, 11, 13])
+        scores = scored_data.df[score_field]
+        assert all(
+            expected_scores == scores.compute()
+        ), f"Expected {expected_scores} but got {scores}"
+
+    def test_batch_filter(self, letter_count_data):
+        length_filter = LegacyBatchedLengthFilter(min_length=8, max_length=11)
+        score_field = "lengths"
+        score_step = Score(
+            length_filter.score_document,
+            text_field="documents",
+            score_field=score_field,
+        )
+        scored_data = score_step(letter_count_data)
+        filter_step = Filter(length_filter.keep_document, score_field)
+        filtered_data = filter_step(scored_data)
+
+        expected_indices = [1, 2]
+        expected_data = letter_count_data.df.loc[expected_indices]
+        expected_data[score_field] = pd.Series([11, 11], index=expected_data.index)
+        expected_data = DocumentDataset(expected_data)
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+
+    def test_score_filter_type(self, letter_count_data):
+        letter_filter = LegacyLetterCountFilter()
+        filter_step = ScoreFilter(letter_filter, text_field="documents", score_type=int)
+        filtered_data = filter_step(letter_count_data)
+
+        expected_indices = [2, 3]
+        expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices])
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+
+    def test_score_type(self, letter_count_data):
+        letter_filter = LegacyLetterCountFilter()
+        score_field = "a_count"
+        score_step = Score(
+            letter_filter.score_document,
+            text_field="documents",
+            score_field=score_field,
+            score_type=int,
+        )
+        scored_data = score_step(letter_count_data)
+
+        expected_scores = pd.Series([2, 3, 5, 7])
+        scores = scored_data.df[score_field]
+        assert all(
+            expected_scores == scores.compute()
+        ), f"Expected {expected_scores} but got {scores}"
+
+    def test_chain_filter(self, letter_count_data):
+        letter_count_filter = LegacyLetterCountFilter(min_count=4)
+        length_filter = LegacyBatchedLengthFilter(min_length=8, max_length=11)
+        filters = Sequential(
+            [
+                ScoreFilter(letter_count_filter, text_field="documents"),
+                ScoreFilter(length_filter, text_field="documents"),
+            ]
+        )
+        filtered_data = filters(letter_count_data)
+
+        expected_indices = [2]
+        expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices])
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+
+
 class TestHeuristicFilters:
     def test_nonalpha(self):
         dataset = list_to_dataset(