From edd62625233627fba289c4e8347bddbaea07d02b Mon Sep 17 00:00:00 2001 From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:31:57 -0800 Subject: [PATCH] Update `FineWebEduClassifier` identifier (#403) * update identifier Signed-off-by: Sarah Yurick * edit example Signed-off-by: Sarah Yurick * fix bug Signed-off-by: Sarah Yurick * run black Signed-off-by: Sarah Yurick --------- Signed-off-by: Sarah Yurick --- examples/classifiers/fineweb_edu_example.py | 6 +++--- nemo_curator/classifiers/fineweb_edu.py | 12 +++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/classifiers/fineweb_edu_example.py b/examples/classifiers/fineweb_edu_example.py index b31715161..d5bb374d8 100644 --- a/examples/classifiers/fineweb_edu_example.py +++ b/examples/classifiers/fineweb_edu_example.py @@ -36,13 +36,13 @@ def main(args): input_file_path, backend="cudf", add_filename=True ) - fineweb_classifier = FineWebEduClassifier() - result_dataset = fineweb_classifier(dataset=input_dataset) + fineweb_edu_classifier = FineWebEduClassifier() + result_dataset = fineweb_edu_classifier(dataset=input_dataset) result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True) global_et = time.time() print( - f"Total time taken for fineweb classifier inference: {global_et-global_st} s", + f"Total time taken for FineWeb-Edu classifier inference: {global_et-global_st} s", flush=True, ) diff --git a/nemo_curator/classifiers/fineweb_edu.py b/nemo_curator/classifiers/fineweb_edu.py index 75050b22a..932bfc8c8 100644 --- a/nemo_curator/classifiers/fineweb_edu.py +++ b/nemo_curator/classifiers/fineweb_edu.py @@ -26,7 +26,7 @@ ) from nemo_curator.datasets import DocumentDataset -FINEWEB_EDU_IDENTIFIER = "HuggingFaceTB/fineweb-edu-classifier" +FINEWEB_EDU_IDENTIFIER = "HuggingFaceFW/fineweb-edu-classifier" class FinewebEduModel(HFModel): @@ -138,9 +138,11 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset: keep_cols=ddf.columns.tolist(), ) ddf = pipe(ddf) - # Go from list to scalar - ddf[self.pred_column] = ddf[self.pred_column].list.get(0) - ddf[self.int_column] = ( - ddf[self.pred_column].clip(lower=0, upper=5).round().astype(int) + ddf[self.pred_column] = ddf[self.pred_column].where( + ddf[self.pred_column] >= 0, 0 ) + ddf[self.pred_column] = ddf[self.pred_column].where( + ddf[self.pred_column] <= 5, 5 + ) + ddf[self.int_column] = ddf[self.pred_column].round().astype(int) return DocumentDataset(ddf)