Skip to content

Commit

Permalink
Enhance/ckpt (#399)
Browse files Browse the repository at this point in the history
  • Loading branch information
drcege authored Aug 22, 2024
1 parent 213f7f8 commit 69e199e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions data_juicer/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def process(self,
dataset = op.run(dataset, exporter=exporter, tracer=tracer)
# record processed ops
if checkpointer is not None:
checkpointer.record(op._of_cfg)
checkpointer.record(op._op_cfg)
end = time()
logger.info(f'OP [{op._name}] Done in {end - start:.3f}s. '
f'Left {len(dataset)} samples.')
Expand All @@ -196,7 +196,7 @@ def process(self,
traceback.print_exc()
exit(1)
finally:
if checkpointer:
if checkpointer and dataset is not self:
logger.info('Writing checkpoint of dataset processed by '
'last op...')
dataset.cleanup_cache_files()
Expand Down Expand Up @@ -334,6 +334,10 @@ def cleanup_cache_files(self):
cleanup_compressed_cache_files(self)
return super().cleanup_cache_files()

@staticmethod
def load_from_disk(*args, **kargs):
return NestedDataset(Dataset.load_from_disk(*args, **kargs))


def nested_query(root_obj: Union[NestedDatasetDict, NestedDataset,
NestedQueryDict], key):
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/utils/ckpt_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import os

from datasets import Dataset
from loguru import logger


Expand Down Expand Up @@ -133,5 +132,6 @@ def load_ckpt(self):
:return: a dataset stored in checkpoint file.
"""
ds = Dataset.load_from_disk(self.ckpt_ds_dir)
from data_juicer.core.data import NestedDataset
ds = NestedDataset.load_from_disk(self.ckpt_ds_dir)
return ds

0 comments on commit 69e199e

Please sign in to comment.