Skip to content

Commit

Permalink
Update FineWebEduClassifier identifier (NVIDIA#403)
Browse files Browse the repository at this point in the history
* update identifier

Signed-off-by: Sarah Yurick <[email protected]>

* edit example

Signed-off-by: Sarah Yurick <[email protected]>

* fix bug

Signed-off-by: Sarah Yurick <[email protected]>

* run black

Signed-off-by: Sarah Yurick <[email protected]>

---------

Signed-off-by: Sarah Yurick <[email protected]>
  • Loading branch information
sarahyurick authored Dec 3, 2024
1 parent d1f52f6 commit edd6262
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
6 changes: 3 additions & 3 deletions examples/classifiers/fineweb_edu_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def main(args):
input_file_path, backend="cudf", add_filename=True
)

- fineweb_classifier = FineWebEduClassifier()
- result_dataset = fineweb_classifier(dataset=input_dataset)
+ fineweb_edu_classifier = FineWebEduClassifier()
+ result_dataset = fineweb_edu_classifier(dataset=input_dataset)
  result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)

  global_et = time.time()
  print(
- f"Total time taken for fineweb classifier inference: {global_et-global_st} s",
+ f"Total time taken for FineWeb-Edu classifier inference: {global_et-global_st} s",
flush=True,
)

Expand Down
12 changes: 7 additions & 5 deletions nemo_curator/classifiers/fineweb_edu.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from nemo_curator.datasets import DocumentDataset

- FINEWEB_EDU_IDENTIFIER = "HuggingFaceTB/fineweb-edu-classifier"
+ FINEWEB_EDU_IDENTIFIER = "HuggingFaceFW/fineweb-edu-classifier"


class FinewebEduModel(HFModel):
Expand Down Expand Up @@ -138,9 +138,11 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
keep_cols=ddf.columns.tolist(),
)
ddf = pipe(ddf)
- # Go from list to scalar
- ddf[self.pred_column] = ddf[self.pred_column].list.get(0)
- ddf[self.int_column] = (
- ddf[self.pred_column].clip(lower=0, upper=5).round().astype(int)
+ ddf[self.pred_column] = ddf[self.pred_column].where(
+ ddf[self.pred_column] >= 0, 0
  )
+ ddf[self.pred_column] = ddf[self.pred_column].where(
+ ddf[self.pred_column] <= 5, 5
+ )
+ ddf[self.int_column] = ddf[self.pred_column].round().astype(int)
return DocumentDataset(ddf)

0 comments on commit edd6262

Please sign in to comment.