diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 8c1762086..efaf1c565 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -585,7 +585,8 @@ def config_backup(cfg): target_path = os.path.join(work_dir, os.path.basename(cfg_path)) logger.info(f'Back up the input config file [{cfg_path}] into the ' f'work_dir [{work_dir}]') - shutil.copyfile(cfg_path, target_path) + if not os.path.exists(target_path): + shutil.copyfile(cfg_path, target_path) def display_config(cfg): diff --git a/demos/tool_quality_classifier/quality_classifier/eval.py b/demos/tool_quality_classifier/quality_classifier/eval.py index 06eb72069..19fe321f0 100644 --- a/demos/tool_quality_classifier/quality_classifier/eval.py +++ b/demos/tool_quality_classifier/quality_classifier/eval.py @@ -27,7 +27,7 @@ from qc_utils import eval, init_spark, load_datasets -@logger.catch +@logger.catch(reraise=True) def main(positive_datasets=None, negative_datasets=None, model='my_quality_model', diff --git a/demos/tool_quality_classifier/quality_classifier/predict.py b/demos/tool_quality_classifier/quality_classifier/predict.py index ddbb084b7..ecdbc182d 100644 --- a/demos/tool_quality_classifier/quality_classifier/predict.py +++ b/demos/tool_quality_classifier/quality_classifier/predict.py @@ -64,7 +64,7 @@ prepare_model) -@logger.catch +@logger.catch(reraise=True) def main(dataset_path, result_path, model='gpt3', diff --git a/demos/tool_quality_classifier/quality_classifier/train.py b/demos/tool_quality_classifier/quality_classifier/train.py index ea4459c69..c1dcdcb39 100644 --- a/demos/tool_quality_classifier/quality_classifier/train.py +++ b/demos/tool_quality_classifier/quality_classifier/train.py @@ -33,7 +33,7 @@ from qc_utils import eval, init_spark, load_datasets, shuffle, train -@logger.catch +@logger.catch(reraise=True) def main(positive_datasets, negative_datasets, output_model_path='my_quality_model', diff --git a/tests/tools/__init__.py 
b/tests/tools/__init__.py new file mode 100644 index 000000000..e69de29bb
diff --git a/tests/tools/test_process_data.py b/tests/tools/test_process_data.py
new file mode 100644
index 000000000..1c923a87b
--- /dev/null
+++ b/tests/tools/test_process_data.py
@@ -0,0 +1,86 @@
+import os.path as osp
+import shutil
+import subprocess
+import sys
+import tempfile
+import unittest
+
+import yaml
+
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
+
+# Repository root: this file lives in <root>/tests/tools/.
+REPO_ROOT = osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__))))
+
+
+class ProcessDataTest(DataJuicerTestCaseBase):
+    """Check the exit status of ``tools/process_data.py``."""
+
+    def setUp(self):
+        super().setUp()
+
+        # mkdtemp() leaves the directory in place until tearDown removes
+        # it; an unreferenced TemporaryDirectory() may be deleted as soon
+        # as the object is garbage collected.
+        self.tmp_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        super().tearDown()
+        shutil.rmtree(self.tmp_dir)
+
+    def _test_status_code(self, yaml_file, output_path, text_keys):
+        """Write a one-op config to ``yaml_file`` and run the tool on it.
+
+        :param yaml_file: path the generated config is written to.
+        :param output_path: value stored under ``export_path``.
+        :param text_keys: value stored under ``text_keys``.
+        :return: exit status of the ``process_data.py`` subprocess.
+        """
+        data_path = osp.join(REPO_ROOT, 'demos', 'data',
+                             'demo-dataset.jsonl')
+        yaml_config = {
+            'dataset_path': data_path,
+            'text_keys': text_keys,
+            'np': 2,
+            'export_path': output_path,
+            'process': [
+                {
+                    'clean_copyright_mapper': None
+                }
+            ]
+        }
+
+        with open(yaml_file, 'w', encoding='utf-8') as file:
+            yaml.dump(yaml_config, file)
+
+        # An argument list with sys.executable avoids shell parsing of
+        # the temp path and reuses the interpreter running the tests.
+        script_path = osp.join(REPO_ROOT, 'tools', 'process_data.py')
+        status_code = subprocess.call(
+            [sys.executable, script_path, '--config', yaml_file])
+
+        return status_code
+
+    def test_status_code_0(self):
+        tmp_yaml_file = osp.join(self.tmp_dir, 'config_0.yaml')
+        tmp_out_path = osp.join(self.tmp_dir, 'output_0.json')
+        text_keys = 'text'
+
+        status_code = self._test_status_code(tmp_yaml_file, tmp_out_path, text_keys)
+
+        self.assertEqual(status_code, 0)
+        self.assertTrue(osp.exists(tmp_out_path))
+
+    def test_status_code_1(self):
+        tmp_yaml_file = osp.join(self.tmp_dir, 'config_1.yaml')
+        tmp_out_path = osp.join(self.tmp_dir, 'output_1.json')
+        text_keys = 'keys_not_exists'
+
+        status_code = self._test_status_code(tmp_yaml_file, tmp_out_path, text_keys)
+
+        self.assertEqual(status_code, 1)
+        self.assertFalse(osp.exists(tmp_out_path))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tools/analyze_data.py b/tools/analyze_data.py index 5d8db6e54..a6487131b 100644 --- a/tools/analyze_data.py +++ b/tools/analyze_data.py @@ -3,7 +3,7 @@ from data_juicer.core import Analyser -@logger.catch +@logger.catch(reraise=True) def main(): analyser = Analyser() analyser.run() diff --git a/tools/hpo/execute_hpo_3sigma.py b/tools/hpo/execute_hpo_3sigma.py index 2a1d5e819..975f94114 100644 --- a/tools/hpo/execute_hpo_3sigma.py +++ b/tools/hpo/execute_hpo_3sigma.py @@ -8,7 +8,7 @@ from data_juicer.utils.constant import StatsKeys -@logger.catch +@logger.catch(reraise=True) def main(): path_k_sigma_recipe = None diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_llava.py b/tools/multimodal/data_juicer_format_to_target_format/dj_to_llava.py index 8a58bfc4b..d9be7e7f7 100644 --- a/tools/multimodal/data_juicer_format_to_target_format/dj_to_llava.py +++ b/tools/multimodal/data_juicer_format_to_target_format/dj_to_llava.py @@ -68,7 +68,7 @@ from data_juicer.utils.mm_utils import SpecialTokens -@logger.catch +@logger.catch(reraise=True) def main( dj_ds_path: str, target_llava_ds_path: str, diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py b/tools/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py index 3c4238ba5..bd03ff5d5 100644 --- a/tools/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py +++ b/tools/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py @@ -94,7 +94,7 @@ from data_juicer.utils.mm_utils import SpecialTokens -@logger.catch +@logger.catch(reraise=True) def main( dj_ds_path: str, target_mmc4_ds_path: str, diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_wavcaps.py b/tools/multimodal/data_juicer_format_to_target_format/dj_to_wavcaps.py index b7cf268e1..79834019e 100644 --- a/tools/multimodal/data_juicer_format_to_target_format/dj_to_wavcaps.py +++ 
b/tools/multimodal/data_juicer_format_to_target_format/dj_to_wavcaps.py @@ -75,7 +75,7 @@ from data_juicer.utils.mm_utils import SpecialTokens -@logger.catch +@logger.catch(reraise=True) def main( dj_ds_path: str, target_wavcaps_ds_path: str, diff --git a/tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py b/tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py index 0020e0484..90c176bce 100644 --- a/tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py +++ b/tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py @@ -69,7 +69,7 @@ from data_juicer.utils.mm_utils import SpecialTokens -@logger.catch +@logger.catch(reraise=True) def main( llava_ds_path: str, target_ds_path: str, diff --git a/tools/multimodal/source_format_to_data_juicer_format/mmc4_to_dj.py b/tools/multimodal/source_format_to_data_juicer_format/mmc4_to_dj.py index 9a045aa68..891ff4ff9 100644 --- a/tools/multimodal/source_format_to_data_juicer_format/mmc4_to_dj.py +++ b/tools/multimodal/source_format_to_data_juicer_format/mmc4_to_dj.py @@ -88,7 +88,7 @@ from data_juicer.utils.mm_utils import SpecialTokens -@logger.catch +@logger.catch(reraise=True) def main( mmc4_ds_path: str, target_ds_path: str, diff --git a/tools/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py b/tools/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py index 6aa1db738..118859a4e 100644 --- a/tools/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py +++ b/tools/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py @@ -41,7 +41,7 @@ convert_text_to_dj) -@logger.catch +@logger.catch(reraise=True) def main( video_chatgpt_ds_path: str, target_ds_dj_path: str, diff --git a/tools/multimodal/source_format_to_data_juicer_format/wavcaps_to_dj.py b/tools/multimodal/source_format_to_data_juicer_format/wavcaps_to_dj.py index 7cb9470a2..786024ff0 100644 --- 
a/tools/multimodal/source_format_to_data_juicer_format/wavcaps_to_dj.py +++ b/tools/multimodal/source_format_to_data_juicer_format/wavcaps_to_dj.py @@ -104,7 +104,7 @@ def get_all_files(dirname): return result -@logger.catch +@logger.catch(reraise=True) def main( wavcaps_json_path: str, wavcaps_audio_path: str, diff --git a/tools/multimodal/source_format_to_data_juicer_format/youku_to_dj.py b/tools/multimodal/source_format_to_data_juicer_format/youku_to_dj.py index f347dd090..e3864c6c2 100644 --- a/tools/multimodal/source_format_to_data_juicer_format/youku_to_dj.py +++ b/tools/multimodal/source_format_to_data_juicer_format/youku_to_dj.py @@ -62,7 +62,7 @@ convert_text_to_dj) -@logger.catch +@logger.catch(reraise=True) def main( youku_ds_path: str, target_ds_path: str, diff --git a/tools/preprocess/raw_arxiv_to_jsonl.py b/tools/preprocess/raw_arxiv_to_jsonl.py index d92efd235..866b24b34 100644 --- a/tools/preprocess/raw_arxiv_to_jsonl.py +++ b/tools/preprocess/raw_arxiv_to_jsonl.py @@ -25,7 +25,7 @@ from loguru import logger -@logger.catch +@logger.catch(reraise=True) def tex_proj_loader(file_or_dir_path: pathlib.Path): """ Load the tex files from a tar file or a gzip file. 
@@ -69,7 +69,7 @@ def tex_proj_loader(file_or_dir_path: pathlib.Path): return files_and_content -@logger.catch +@logger.catch(reraise=True) def convert_tar_to_jsonl(tar_fp, jsonl_fp, tmp_dir): """ Extract the contents of tex files from tar file, convert and diff --git a/tools/preprocess/raw_stackexchange_to_jsonl.py b/tools/preprocess/raw_stackexchange_to_jsonl.py index ad1a0bfe4..cd4c90aef 100644 --- a/tools/preprocess/raw_stackexchange_to_jsonl.py +++ b/tools/preprocess/raw_stackexchange_to_jsonl.py @@ -23,7 +23,7 @@ from tqdm import tqdm -@logger.catch +@logger.catch(reraise=True) def get_sites_count(path, topk=28): """ Take top-K sites(`.xml`) by its size of content @@ -57,7 +57,7 @@ def get_sites_count(path, topk=28): return counts, sites -@logger.catch +@logger.catch(reraise=True) def get_parents(site, counts): """ Find all answers's parent id, and groups by parent id @@ -90,7 +90,7 @@ def get_parents(site, counts): return parents -@logger.catch +@logger.catch(reraise=True) def get_qapairs(site, counts, parents): """ Find and group all matched pairs of question and answer in site file @@ -140,7 +140,7 @@ def get_qapairs(site, counts, parents): return qa_pairs -@logger.catch +@logger.catch(reraise=True) def process_qa_pair(pair, site_name, site_count): """ Sort answers by their score for question in qa pair sample, @@ -171,7 +171,7 @@ def process_qa_pair(pair, site_name, site_count): } -@logger.catch +@logger.catch(reraise=True) def process_site(site, counts, src_dir, target_dir, num_proc=24): """ Convert one raw Stack Exchange site data to jsonl file. 
@@ -207,7 +207,7 @@ def process_site(site, counts, src_dir, target_dir, num_proc=24): f.write(json.dumps(result) + '\n') -@logger.catch +@logger.catch(reraise=True) def main(src_dir, target_dir, topk=28, num_proc=1): """ Convert the raw Stack Exchange data downloaded from from Archive diff --git a/tools/process_data.py b/tools/process_data.py index f92d4ac5a..a97ef9a40 100644 --- a/tools/process_data.py +++ b/tools/process_data.py @@ -4,7 +4,7 @@ from data_juicer.core import Executor -@logger.catch +@logger.catch(reraise=True) def main(): cfg = init_configs() if cfg.executor_type == 'default': diff --git a/tools/quality_classifier/eval.py b/tools/quality_classifier/eval.py index be1bfa622..de043e562 100644 --- a/tools/quality_classifier/eval.py +++ b/tools/quality_classifier/eval.py @@ -27,7 +27,7 @@ from tools.quality_classifier.qc_utils import eval, init_spark, load_datasets -@logger.catch +@logger.catch(reraise=True) def main(positive_datasets=None, negative_datasets=None, model='my_quality_model', diff --git a/tools/quality_classifier/predict.py b/tools/quality_classifier/predict.py index 93bfe5b24..488f78ee7 100644 --- a/tools/quality_classifier/predict.py +++ b/tools/quality_classifier/predict.py @@ -65,7 +65,7 @@ prepare_model) -@logger.catch +@logger.catch(reraise=True) def predict_score(dataset_path, result_path, model='gpt3', diff --git a/tools/quality_classifier/train.py b/tools/quality_classifier/train.py index e0f7fd5aa..14f11972e 100644 --- a/tools/quality_classifier/train.py +++ b/tools/quality_classifier/train.py @@ -34,7 +34,7 @@ shuffle, train) -@logger.catch +@logger.catch(reraise=True) def main(positive_datasets, negative_datasets, output_model_path='my_quality_model',