Skip to content

Commit

Permalink
Add turbo mode (#402)
Browse files Browse the repository at this point in the history
* enhance ckpt logic

* fix tests

* add turbo mode

* fix tests
  • Loading branch information
drcege authored Sep 10, 2024
1 parent b3559af commit 7954241
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 13 deletions.
26 changes: 18 additions & 8 deletions data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,12 @@ def init_configs(args=None):
help='Suffixes of files that will be find and loaded. If not set, we '
'will find all suffix files, and select a suitable formatter '
'with the most files as default.')
parser.add_argument(
'--turbo',
type=bool,
default=False,
help='Enable Turbo mode to maximize processing speed. Stability '
'features like fault tolerance will be disabled.')
parser.add_argument(
'--use_cache',
type=bool,
Expand Down Expand Up @@ -470,6 +476,8 @@ def init_setup_from_cfg(cfg):
'image_key': cfg.image_key,
'audio_key': cfg.audio_key,
'video_key': cfg.video_key,
'num_proc': cfg.np,
'turbo': cfg.turbo,
}
else:
if 'text_key' not in args or args['text_key'] is None:
Expand All @@ -480,6 +488,10 @@ def init_setup_from_cfg(cfg):
args['audio_key'] = cfg.audio_key
if 'video_key' not in args or args['video_key'] is None:
args['video_key'] = cfg.video_key
if 'num_proc' not in args or args['num_proc'] is None:
args['num_proc'] = cfg.np
if 'turbo' not in args or args['turbo'] is None:
args['turbo'] = cfg.turbo
op[op_name] = args

return cfg
Expand Down Expand Up @@ -574,14 +586,12 @@ def update_op_process(cfg, parser):

# update op params of cfg.process
internal_op_para = temp_cfg.get(op_in_process_name)
if internal_op_para is not None:
num_proc = internal_op_para.get('num_proc')
if 'num_proc' in internal_op_para:
internal_op_para['num_proc'] = num_proc or cfg.np
internal_op_para = namespace_to_dict(internal_op_para)
else:
internal_op_para = None
cfg.process[i] = {op_in_process_name: internal_op_para}

cfg.process[i] = {
op_in_process_name:
None if internal_op_para is None else
namespace_to_dict(internal_op_para)
}

# check the op params via type hint
temp_parser = copy.deepcopy(parser)
Expand Down
13 changes: 8 additions & 5 deletions data_juicer/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,15 @@ def map(self, *args, **kargs):
called_func, '__wrapped__'):
called_func = called_func.__wrapped__

# Batched is always required for fault tolerance
if inspect.ismethod(called_func):
kargs['batched'] = True
kargs['batch_size'] = kargs.pop('batch_size', 1) if hasattr(
called_func.__self__, 'is_batched_op'
) and called_func.__self__.is_batched_op() else 1
# batched is required for fault-tolerant or batched OP
if not called_func.__self__.turbo or hasattr(
called_func.__self__,
'is_batched_op') and called_func.__self__.is_batched_op():
kargs['batched'] = True
kargs['batch_size'] = kargs.pop('batch_size', 1)
else:
kargs['batched'] = False

if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None:
new_fingerprint = generate_fingerprint(self, *args, **kargs)
Expand Down
2 changes: 2 additions & 0 deletions data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def __init__(self, *args, **kwargs):
if isinstance(self.mem_required, str):
self.mem_required = size_to_bytes(self.mem_required) / 1024**3

self.turbo = kwargs.get('turbo', False)

# nested wrappers
from data_juicer.core.data import wrap_func_with_nested_access
for name in ['process', 'compute_stats', 'compute_hash']:
Expand Down
7 changes: 7 additions & 0 deletions tests/config/test_config_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_yaml_cfg_file(self):
'num_proc': 4,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
}, 'nested dict load fail, for nonparametric op')
self.assertDictEqual(
Expand All @@ -65,6 +66,7 @@ def test_yaml_cfg_file(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
}, 'nested dict load fail, un-expected internal value')

Expand Down Expand Up @@ -131,6 +133,7 @@ def test_mixture_cfg(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
})
self.assertDictEqual(
Expand All @@ -147,6 +150,7 @@ def test_mixture_cfg(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
})
self.assertDictEqual(
Expand All @@ -163,6 +167,7 @@ def test_mixture_cfg(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
})
self.assertDictEqual(
Expand All @@ -179,6 +184,7 @@ def test_mixture_cfg(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
})
self.assertDictEqual(
Expand All @@ -195,6 +201,7 @@ def test_mixture_cfg(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
})

Expand Down

0 comments on commit 7954241

Please sign in to comment.