Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Refactor modules to be in a single class and dataframe backend checking #379

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion nemo_curator/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import DocumentFilter, FilterMode, import_filter
from .classifier_filter import FastTextLangId, FastTextQualityFilter
from .code import (
AlphaFilter,
Expand All @@ -23,7 +24,6 @@
TokenizerFertilityFilter,
XMLHeaderFilter,
)
from .doc_filter import DocumentFilter, import_filter
from .heuristic_filter import (
BoilerPlateStringFilter,
BulletsFilter,
Expand Down Expand Up @@ -52,6 +52,7 @@
__all__ = [
"DocumentFilter",
"import_filter",
"FilterMode",
"FastTextLangId",
"FastTextQualityFilter",
"NonAlphaNumericFilter",
Expand Down
216 changes: 216 additions & 0 deletions nemo_curator/filters/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Optional, Type, Union

import dask.dataframe as dd

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.base import Module
from nemo_curator.utils.module_utils import is_batched


class FilterMode(Enum):
    """Which stage(s) DocumentFilter.call runs on a dataset."""

    SCORE_FILTER = "score_filter"  # score documents, then filter on those scores
    SCORE = "score"  # only compute (and optionally save) scores; keep all documents
    FILTER = "filter"  # only filter, using score fields already present on the dataset


class DocumentFilter(Module, ABC):
"""
An abstract base class for text-based document filters.

This class serves as a template for creating specific document filters
in the library. Subclasses should implement the abstract methods to
define custom filtering behavior.
"""

def __init__(
self,
score_types: Union[List[str], List[Type]],
text_fields: List[str] = ["text"],
score_fields: List[str] = ["score"],
filter_mode: FilterMode = FilterMode.SCORE_FILTER,
removed_path: Optional[str] = None,
invert: bool = False,
save_score: bool = True,
input_backend: str = "pandas",
):
"""
text_fields: If len(text_fields) == 1, then score_document will
get a series instead of a dataframe. You may still output
multiple scores in the form of a dataframe. Need to verify if that's possible with Dask though.
score_fields: If len(score_fields) == 1, then score_document
must output a series instead of a dataframe. keep_document
must accept a series instead of a dataframe. keep_document must always return a series.
"""
super().__init__(input_backend=input_backend)

if len(score_types) != len(score_fields):
raise ValueError(
f"score_types must have the same length as score_fields, but got {len(score_types)} and {len(score_fields)}."
)

if len(score_fields) > 1 and not is_batched(self.score_document):
raise ValueError(
f"When outputing multiple scores ({len(score_fields)} in this case), score_document must be defined in @batched mode."
)

self.score_types = score_types
self.text_fields = text_fields
self.score_fields = score_fields
self.removed_path = removed_path
self.invert = invert
self.filter_mode = filter_mode
self.save_score = save_score

@abstractmethod
def score_document(self, text: str) -> Any:
"""
Calculate a score for the given document text.

This method should be implemented by subclasses to define how
a document's text is evaluated and scored.

Args:
text (str): The text content of the document to be scored.

Returns:
Any: A score or set of scores representing the document's
relevance or quality. The type and structure of the
return value should be consistent for each subclass.

Raises:
NotImplementedError: If the method is not implemented in a subclass.
"""
raise NotImplementedError(
"score_document method must be implemented by subclasses"
)

@abstractmethod
def keep_document(self, scores: Any) -> bool:
"""
Determine whether to keep a document based on its scores.

This method should be implemented by subclasses to define the
criteria for keeping or discarding a document based on the
scores calculated by score_document().

Args:
scores (Any): The score or set of scores returned by score_document().
The type should match what is returned by score_document().

Returns:
bool: True if the document should be kept, False otherwise.

Raises:
NotImplementedError: If the method is not implemented in a subclass.
"""
raise NotImplementedError(
"keep_document method must be implemented by subclasses"
)

def _score_dataset(self, dataset: DocumentDataset):
meta = (
list(zip(self.score_fields, self.score_types))
if len(self.score_fields) > 1
else (None, self.score_types[0])
)
# Get the field name directly if there's only one
text_fields = (
self.text_fields if len(self.text_fields) > 1 else self.text_fields[0]
)

if is_batched(self.score_document):
scores = dataset.df[text_fields].map_partitions(
self.score_document, meta=meta
)
else:
axis = 1 if len(self.text_fields) > 1 else 0
scores = dataset.df[text_fields].apply(
self.score_document, axis=axis, meta=meta
)

if self.save_score:
if len(self.score_fields) > 1:
dataset.df = dd.concat([dataset.df, scores], axis=1)
else:
score_fields = self.score_fields[0]
dataset.df[score_fields] = scores

return scores

def _filter_dataset(self, dataset: DocumentDataset, scores):
if is_batched(self.keep_document):
bool_mask = scores.map_partitions(self.keep_document, meta=(None, bool))
else:
axis = 1 if isinstance(scores, dd.DataFrame) else 0
bool_mask = scores.apply(self.keep_document, axis=axis, meta=(None, bool))
if self.invert:
bool_mask = ~bool_mask

if self.removed_path:
removed_docs = DocumentDataset(dataset.df[~bool_mask])
removed_docs.to_parquet(output_file_dir=self.removed_path)

return bool_mask

def compute_filter_mask(self, dataset: DocumentDataset):
scores = self._score_dataset(dataset)
return self._filter_dataset(dataset, scores)

def call(self, dataset: DocumentDataset) -> DocumentDataset:
match self.filter_mode:
case FilterMode.SCORE:
self._score_dataset(dataset)
return dataset
case FilterMode.FILTER:
score_fields = (
self.score_fields
if len(self.score_fields) > 1
else self.score_fields[0]
)
scores = dataset.df[score_fields]
mask = self._filter_dataset(dataset, scores)
return DocumentDataset(dataset.df[mask])
case FilterMode.SCORE_FILTER:
mask = self.compute_filter_mask(dataset)
return DocumentDataset(dataset.df[mask])


def import_filter(filter_path: str) -> DocumentFilter:
    """
    Imports a filter under nemo_curator.filters given the module path

    Args:
        filter_path (str): The path to the filter in the form of "nemo_curator.filters.filter_name"

    Returns:
        DocumentFilter: The filter that is at the given path

    Raises:
        ValueError: If the filter_path does not point to a DocumentFilter
    """
    module_path, filter_name = filter_path.rsplit(".", 1)
    filter_module = importlib.import_module(module_path)
    filter_class = getattr(filter_module, filter_name)
    # Guard with isinstance(..., type): issubclass raises TypeError when the
    # resolved attribute is not a class (e.g. a function or constant), which
    # would break the documented ValueError contract. Fall back to the
    # attribute name for objects without __name__.
    if not isinstance(filter_class, type) or not issubclass(
        filter_class, DocumentFilter
    ):
        name = getattr(filter_class, "__name__", filter_name)
        raise ValueError(
            f"Input filter {name} must be derived "
            "from DocumentFilter defined in nemo_curator.filters.base"
        )
    return filter_class
75 changes: 62 additions & 13 deletions nemo_curator/filters/classifier_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import dask
import fasttext
import numpy as np
import pandas as pd

from nemo_curator.filters.doc_filter import DocumentFilter
from nemo_curator.filters.base import DocumentFilter, FilterMode
from nemo_curator.utils.decorators import batched
from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker


class FastTextQualityFilter(DocumentFilter):

def __init__(self, model_path=None, label="__label__hq", alpha=3, seed=42):
def __init__(
self,
model_path=None,
label="__label__hq",
alpha=3,
seed=42,
text_field: str = "text",
score_field: str = "fasttext_quality_score",
filter_mode: FilterMode = FilterMode.SCORE_FILTER,
removed_path: Optional[str] = None,
invert: bool = False,
save_score: bool = True,
):
super().__init__(
[float],
text_fields=[text_field],
score_fields=[score_field],
filter_mode=filter_mode,
removed_path=removed_path,
invert=invert,
save_score=save_score,
)
if model_path is None:
raise ValueError(
"Must provide a valid path to a FastText model "
Expand Down Expand Up @@ -64,38 +86,65 @@ def _load_model(self):


class FastTextLangId(DocumentFilter):

def __init__(self, model_path=None, min_langid_score=0.3):
def __init__(
self,
model_path=None,
min_langid_score=0.3,
text_field: str = "text",
score_field: str = "language_score",
lang_field: str = "language",
filter_mode: FilterMode = FilterMode.SCORE_FILTER,
removed_path: Optional[str] = None,
invert: bool = False,
save_score: bool = True,
):
super().__init__(
[float],
text_fields=[text_field],
score_fields=[score_field, lang_field],
filter_mode=filter_mode,
removed_path=removed_path,
invert=invert,
save_score=save_score,
)
if model_path is None:
raise ValueError(
"Must provide a valid path to a FastText model "
"to identify languages with this filter"
)
self.score_field = score_field
self.lang_field = lang_field
self._model_path = model_path
self._lang_code = None
self._cutoff = min_langid_score
self._name = "lang_id"

@batched
def score_document(self, df: pd.Series):
def score_document(self, series: pd.Series):
model_attr = f"{self._name}_{self._model_path}"
try:
model = load_object_on_worker(model_attr, self._load_model, {})
except NoWorkerError:
return pd.Series([[1.0, "N/A"] for _ in range(len(df))])
return pd.Series([[1.0, "N/A"] for _ in range(len(series))])

def _score_document(text):
pp = text.strip().replace("\n", " ")
label, score = model.predict(pp, k=1)
processed_series = series.str.strip().str.replace("\n", " ")
scores = []
lang_codes = []
for text in processed_series:
label, score = model.predict(text, k=1)
score = score[0]
lang_code = label[0][-2:].upper()

return [score, lang_code]
scores.append(score)
lang_codes.append(lang_code)

return df.apply(_score_document)
return pd.DataFrame(
{self.score_field: scores, self.lang_field: lang_codes}, index=series.index
)

@batched
def keep_document(self, score):
return score[0] >= self._cutoff
return score[self.score_field] >= self._cutoff

def _load_model(self):
return fasttext.load_model(self._model_path)
Loading