Refine/llm api op unittest (#528)
* update unittests

* tags specified field

* doc done

* + add reference

* move mm tags

* move meta key

* done

* test done

* rm nested set

* enable op error for unittest

* enhance api unittest

* expose skip_op_error

* fix typo

---------

Co-authored-by: null <[email protected]>
Co-authored-by: gece.gc <[email protected]>
Co-authored-by: lielin.hyl <[email protected]>
4 people authored Jan 17, 2025
1 parent 129c75a commit 0815c29
Showing 18 changed files with 103 additions and 30 deletions.
2 changes: 2 additions & 0 deletions configs/config_all.yaml
@@ -13,6 +13,8 @@ np: 4 # number of subproce
text_keys: 'text' # the key name of the field containing the sample texts to be processed, e.g., `text`, `instruction`, `output`, ...
# Note: currently, we support specifying only ONE key for each op; for cases requiring multiple keys, users can specify the op multiple times. Only the first key of `text_keys` is used when multiple keys are set.
suffixes: [] # the suffix of files that will be read. For example: '.txt', 'txt' or ['txt', '.pdf', 'docx']
turbo: false # Enable Turbo mode to maximize processing speed when batch size is 1.
skip_op_error: true # Skip errors in OPs caused by unexpected invalid samples.
use_cache: true # whether to use the cache management of Hugging Face datasets. It might take up lots of disk space when using cache
ds_cache_dir: null # cache dir for Hugging Face datasets. By default, it's the same as the environment variable `HF_DATASETS_CACHE`, whose default value is usually "~/.cache/huggingface/datasets". If this argument is set to a valid path by users, it will override the default cache dir
open_monitor: true # Whether to open the monitor to trace resource utilization for each OP during data processing. It's True by default.
10 changes: 8 additions & 2 deletions data_juicer/config/config.py
@@ -219,8 +219,13 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None):
'--turbo',
type=bool,
default=False,
help='Enable Turbo mode to maximize processing speed. Stability '
'features like fault tolerance will be disabled.')
help='Enable Turbo mode to maximize processing speed when batch size '
'is 1.')
parser.add_argument(
'--skip_op_error',
type=bool,
default=True,
help='Skip errors in OPs caused by unexpected invalid samples.')
parser.add_argument(
'--use_cache',
type=bool,
@@ -550,6 +555,7 @@ def init_setup_from_cfg(cfg: Namespace):
'video_key': cfg.video_key,
'num_proc': cfg.np,
'turbo': cfg.turbo,
'skip_op_error': cfg.skip_op_error,
'work_dir': cfg.work_dir,
}
cfg.process = update_op_attr(cfg.process, op_attrs)
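
For context, a hedged usage sketch (not part of this diff): the new flag can also be flipped programmatically through `init_configs`, equivalent to passing `--skip_op_error false` on the command line, and `init_setup_from_cfg` then propagates it into every OP's attrs via `update_op_attr` as shown above.

```python
# Minimal sketch, assuming the repo's configs/config_all.yaml is reachable
# from the working directory; flag names follow the argparse definitions above.
from data_juicer.config import init_configs

cfg = init_configs(args=[
    '--config', 'configs/config_all.yaml',
    '--skip_op_error', 'false',  # surface OP errors instead of skipping them
])
print(cfg.skip_op_error)  # expected: False
```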
67 changes: 48 additions & 19 deletions data_juicer/ops/base_op.py
@@ -47,7 +47,7 @@ def wrapper(sample, *args, **kwargs):
return wrapper


def catch_map_batches_exception(method, op_name=None):
def catch_map_batches_exception(method, skip_op_error=False, op_name=None):
"""
For batched-map sample-level fault tolerance.
"""
@@ -61,6 +61,8 @@ def wrapper(samples, *args, **kwargs):
try:
return method(samples, *args, **kwargs)
except Exception as e:
if not skip_op_error:
raise
from loguru import logger
logger.error(f'An error occurred in {op_name} when processing '
f'samples "{samples}" -- {type(e)}: {e}')
@@ -72,7 +74,10 @@ def wrapper(samples, *args, **kwargs):
return wrapper


def catch_map_single_exception(method, return_sample=True, op_name=None):
def catch_map_single_exception(method,
return_sample=True,
skip_op_error=False,
op_name=None):
"""
For single-map sample-level fault tolerance.
The input sample is expected batch_size = 1.
@@ -103,6 +108,8 @@ def wrapper(sample, *args, **kwargs):
else:
return [res]
except Exception as e:
if not skip_op_error:
raise
from loguru import logger
logger.error(f'An error occurred in {op_name} when processing '
f'sample "{sample}" -- {type(e)}: {e}')
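
To make the control flow concrete, here is a simplified, self-contained sketch of the branching both wrappers above share; it is not the library's exact code (the real wrappers also handle batched-sample bookkeeping and recovery values), only the `skip_op_error` behavior mirrors this diff.

```python
from functools import wraps

from loguru import logger


def catch_exception_sketch(method, skip_op_error=False, op_name=None):
    """Simplified sample-level fault tolerance: re-raise when skip_op_error
    is False (the unittest default), otherwise log the failing sample and
    let the pipeline continue."""

    @wraps(method)
    def wrapper(sample, *args, **kwargs):
        try:
            return method(sample, *args, **kwargs)
        except Exception as e:
            if not skip_op_error:
                raise  # surface the error, e.g. inside unittests
            logger.error(f'An error occurred in {op_name} when processing '
                         f'sample "{sample}" -- {type(e)}: {e}')
            return sample  # placeholder recovery; the real wrappers differ here

    return wrapper
```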
@@ -157,6 +164,10 @@ def __init__(self, *args, **kwargs):
self.batch_size = kwargs.get('batch_size', 1000)
self.work_dir = kwargs.get('work_dir', None)

# For unittests, do not skip errors by default.
# It is set to True during config init.
self.skip_op_error = kwargs.get('skip_op_error', False)

# whether the model can be accelerated using cuda
_accelerator = kwargs.get('accelerator', None)
if _accelerator is not None:
@@ -278,11 +289,15 @@ def __init__(self, *args, **kwargs):

# runtime wrappers
if self.is_batched_op():
self.process = catch_map_batches_exception(self.process_batched,
op_name=self._name)
self.process = catch_map_batches_exception(
self.process_batched,
skip_op_error=self.skip_op_error,
op_name=self._name)
else:
self.process = catch_map_single_exception(self.process_single,
op_name=self._name)
self.process = catch_map_single_exception(
self.process_single,
skip_op_error=self.skip_op_error,
op_name=self._name)

# set the process method is not allowed to be overridden
def __init_subclass__(cls, **kwargs):
@@ -369,15 +384,23 @@ def __init__(self, *args, **kwargs):
# runtime wrappers
if self.is_batched_op():
self.compute_stats = catch_map_batches_exception(
self.compute_stats_batched, op_name=self._name)
self.process = catch_map_batches_exception(self.process_batched,
op_name=self._name)
self.compute_stats_batched,
skip_op_error=self.skip_op_error,
op_name=self._name)
self.process = catch_map_batches_exception(
self.process_batched,
skip_op_error=self.skip_op_error,
op_name=self._name)
else:
self.compute_stats = catch_map_single_exception(
self.compute_stats_single, op_name=self._name)
self.process = catch_map_single_exception(self.process_single,
return_sample=False,
op_name=self._name)
self.compute_stats_single,
skip_op_error=self.skip_op_error,
op_name=self._name)
self.process = catch_map_single_exception(
self.process_single,
return_sample=False,
skip_op_error=self.skip_op_error,
op_name=self._name)

# set the process method is not allowed to be overridden
def __init_subclass__(cls, **kwargs):
@@ -486,11 +509,15 @@ def __init__(self, *args, **kwargs):

# runtime wrappers
if self.is_batched_op():
self.compute_hash = catch_map_batches_exception(self.compute_hash,
op_name=self._name)
self.compute_hash = catch_map_batches_exception(
self.compute_hash,
skip_op_error=self.skip_op_error,
op_name=self._name)
else:
self.compute_hash = catch_map_single_exception(self.compute_hash,
op_name=self._name)
self.compute_hash = catch_map_single_exception(
self.compute_hash,
skip_op_error=self.skip_op_error,
op_name=self._name)

def compute_hash(self, sample):
"""
@@ -626,8 +653,10 @@ def __init__(self, *args, **kwargs):
queries and responses
"""
super(Aggregator, self).__init__(*args, **kwargs)
self.process = catch_map_single_exception(self.process_single,
op_name=self._name)
self.process = catch_map_single_exception(
self.process_single,
skip_op_error=self.skip_op_error,
op_name=self._name)

def process_single(self, sample):
"""
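Because the base `__init__` above reads the flag from kwargs with a `False` default, it can also be set per OP instance; a hedged sketch with a hypothetical no-op mapper (only the kwarg name and its default come from this diff):

```python
# Hypothetical subclass for illustration; not part of this commit.
from data_juicer.ops.base_op import Mapper


class NoopMapper(Mapper):
    """Returns samples unchanged; exists only to show the kwarg."""

    def process_single(self, sample):
        return sample


strict_op = NoopMapper()                      # unittest default: errors raise
tolerant_op = NoopMapper(skip_op_error=True)  # production default via config
print(strict_op.skip_op_error, tolerant_op.skip_op_error)  # False True
```
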
7 changes: 7 additions & 0 deletions tests/config/test_config_funcs.py
@@ -56,6 +56,7 @@ def test_yaml_cfg_file(self):
'turbo': False,
'batch_size': 1000,
'index_key': None,
'skip_op_error': True,
'work_dir': WORKDIR,
}
}, 'nested dict load fail, for nonparametric op')
@@ -79,6 +80,7 @@ def test_yaml_cfg_file(self):
'turbo': False,
'batch_size': 1000,
'index_key': None,
'skip_op_error': True,
'work_dir': WORKDIR,
}
}, 'nested dict load fail, un-expected internal value')
@@ -151,6 +153,7 @@ def test_mixture_cfg(self):
'turbo': False,
'batch_size': 1000,
'index_key': None,
'skip_op_error': True,
'work_dir': WORKDIR,
}
})
@@ -174,6 +177,7 @@ def test_mixture_cfg(self):
'turbo': False,
'batch_size': 1000,
'index_key': None,
'skip_op_error': True,
'work_dir': WORKDIR,
}
})
@@ -197,6 +201,7 @@ def test_mixture_cfg(self):
'turbo': False,
'batch_size': 1000,
'index_key': None,
'skip_op_error': True,
'work_dir': WORKDIR,
}
})
@@ -220,6 +225,7 @@ def test_mixture_cfg(self):
'turbo': False,
'batch_size': 1000,
'index_key': None,
'skip_op_error': True,
'work_dir': WORKDIR,
}
})
@@ -243,6 +249,7 @@ def test_mixture_cfg(self):
'turbo': False,
'batch_size': 1000,
'index_key': None,
'skip_op_error': True,
'work_dir': WORKDIR,
}
})
8 changes: 5 additions & 3 deletions tests/ops/aggregator/test_entity_attribute_aggregator.py
@@ -5,13 +5,13 @@
from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.aggregator import EntityAttributeAggregator
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.constant import Fields, BatchMetaKeys, MetaKeys


@SKIPPED_TESTS.register_module()
class EntityAttributeAggregatorTest(DataJuicerTestCaseBase):

def _run_helper(self, op, samples):
def _run_helper(self, op, samples, output_key=BatchMetaKeys.entity_attribute):

# before running this test, set the environment variables below:
# export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
@@ -23,6 +23,8 @@ def _run_helper(self, op, samples):
for data in new_dataset:
for k in data:
logger.info(f"{k}: {data[k]}")
self.assertIn(output_key, data[Fields.batch_meta])
self.assertNotEqual(data[Fields.batch_meta][output_key], '')

self.assertEqual(len(new_dataset), len(samples))

@@ -64,7 +66,7 @@ def test_input_output(self):
input_key='sub_docs',
output_key='text'
)
self._run_helper(op, samples)
self._run_helper(op, samples, output_key='text')

def test_max_token_num(self):
samples = [
tests/ops/aggregator/test_most_relavant_entities_aggregator.py
@@ -6,13 +6,13 @@
from data_juicer.ops.aggregator import MostRelavantEntitiesAggregator
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS

from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.constant import Fields, BatchMetaKeys, MetaKeys


@SKIPPED_TESTS.register_module()
class MostRelavantEntitiesAggregatorTest(DataJuicerTestCaseBase):

def _run_helper(self, op, samples):
def _run_helper(self, op, samples, output_key=BatchMetaKeys.most_relavant_entities):

# before running this test, set the environment variables below:
# export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
@@ -24,6 +24,8 @@ def _run_helper(self, op, samples):
for data in new_dataset:
for k in data:
logger.info(f"{k}: {data[k]}")
self.assertIn(output_key, data[Fields.batch_meta])
self.assertNotEqual(data[Fields.batch_meta][output_key], '')

self.assertEqual(len(new_dataset), len(samples))

@@ -67,7 +69,7 @@ def test_input_output(self):
input_key='events',
output_key='relavant_roles'
)
self._run_helper(op, samples)
self._run_helper(op, samples, output_key='relavant_roles')

def test_max_token_num(self):
samples = [
6 changes: 4 additions & 2 deletions tests/ops/aggregator/test_nested_aggregator.py
@@ -12,7 +12,7 @@
@SKIPPED_TESTS.register_module()
class NestedAggregatorTest(DataJuicerTestCaseBase):

def _run_helper(self, op, samples):
def _run_helper(self, op, samples, output_key=MetaKeys.event_description):

# before running this test, set the environment variables below:
# export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
@@ -24,6 +24,8 @@ def _run_helper(self, op, samples):
for data in new_dataset:
for k in data:
logger.info(f"{k}: {data[k]}")
self.assertIn(output_key, data[Fields.batch_meta])
self.assertNotEqual(data[Fields.batch_meta][output_key], '')

self.assertEqual(len(new_dataset), len(samples))

@@ -61,7 +63,7 @@ def test_input_output(self):
input_key='sub_docs',
output_key='text'
)
self._run_helper(op, samples)
self._run_helper(op, samples, output_key='text')

def test_max_token_num_1(self):
samples = [
2 changes: 2 additions & 0 deletions tests/ops/mapper/test_dialog_intent_detection_mapper.py
@@ -27,6 +27,8 @@ def _run_op(self, op, samples, target_len, labels_key=None, analysis_key=None):
for analysis, labels in zip(analysis_list, labels_list):
logger.info(f'分析:{analysis}')
logger.info(f'意图:{labels}')
self.assertNotEqual(analysis, '')
self.assertNotEqual(labels, '')

self.assertEqual(len(analysis_list), target_len)
self.assertEqual(len(labels_list), target_len)
2 changes: 2 additions & 0 deletions tests/ops/mapper/test_dialog_sentiment_detection_mapper.py
@@ -28,6 +28,8 @@ def _run_op(self, op, samples, target_len, labels_key=None, analysis_key=None):
for analysis, labels in zip(analysis_list, labels_list):
logger.info(f'分析:{analysis}')
logger.info(f'情绪:{labels}')
self.assertNotEqual(analysis, '')
self.assertNotEqual(labels, '')

self.assertEqual(len(analysis_list), target_len)
self.assertEqual(len(labels_list), target_len)
1 change: 1 addition & 0 deletions tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py
@@ -28,6 +28,7 @@ def _run_op(self, op, samples, target_len, intensities_key=None, analysis_key=No
for analysis, intensity in zip(analysis_list, intensity_list):
logger.info(f'分析:{analysis}')
logger.info(f'情绪:{intensity}')
self.assertNotEqual(analysis, '')

self.assertEqual(len(analysis_list), target_len)
self.assertEqual(len(intensity_list), target_len)
2 changes: 2 additions & 0 deletions tests/ops/mapper/test_dialog_topic_detection_mapper.py
@@ -29,6 +29,8 @@ def _run_op(self, op, samples, target_len, labels_key=None, analysis_key=None):
for analysis, labels in zip(analysis_list, labels_list):
logger.info(f'分析:{analysis}')
logger.info(f'话题:{labels}')
self.assertNotEqual(analysis, '')
self.assertNotEqual(labels, '')

self.assertEqual(len(analysis_list), target_len)
self.assertEqual(len(labels_list), target_len)
4 changes: 4 additions & 0 deletions tests/ops/mapper/test_extract_entity_attribute_mapper.py
@@ -49,6 +49,10 @@ def _run_op(self, api_model, response_path=None):
dataset = Dataset.from_list(samples)
dataset = op.run(dataset)
for sample in dataset:
self.assertIn(MetaKeys.main_entities, sample[Fields.meta])
self.assertIn(MetaKeys.attributes, sample[Fields.meta])
self.assertIn(MetaKeys.attribute_descriptions, sample[Fields.meta])
self.assertIn(MetaKeys.attribute_support_texts, sample[Fields.meta])
ents = sample[Fields.meta][MetaKeys.main_entities]
attrs = sample[Fields.meta][MetaKeys.attributes]
descs = sample[Fields.meta][MetaKeys.attribute_descriptions]
4 changes: 4 additions & 0 deletions tests/ops/mapper/test_extract_entity_relation_mapper.py
@@ -56,6 +56,10 @@ def _run_op(self, op):
dataset = Dataset.from_list(samples)
dataset = op.run(dataset)
sample = dataset[0]
self.assertIn(MetaKeys.entity, sample[Fields.meta])
self.assertIn(MetaKeys.relation, sample[Fields.meta])
self.assertNotEqual(len(sample[Fields.meta][MetaKeys.entity]), 0)
self.assertNotEqual(len(sample[Fields.meta][MetaKeys.relation]), 0)
logger.info(f"entities: {sample[Fields.meta][MetaKeys.entity]}")
logger.info(f"relations: {sample[Fields.meta][MetaKeys.relation]}")

3 changes: 2 additions & 1 deletion tests/ops/mapper/test_extract_event_mapper.py
@@ -59,8 +59,9 @@ def _run_op(self, api_model, response_path=None):

dataset = Dataset.from_list(samples)
dataset = op.run(dataset)
self.assertNotEqual(len(dataset), 0)
for sample in dataset:
self.assertIn(MetaKeys.event_description, sample[Fields.meta])
self.assertIn(MetaKeys.relevant_characters, sample[Fields.meta])
logger.info(f"chunk_id: {sample['chunk_id']}")
self.assertEqual(sample['chunk_id'], 0)
logger.info(f"event: {sample[Fields.meta][MetaKeys.event_description]}")
2 changes: 2 additions & 0 deletions tests/ops/mapper/test_extract_keyword_mapper.py
@@ -59,6 +59,8 @@ def _run_op(self, api_model, response_path=None):
dataset = Dataset.from_list(samples)
dataset = op.run(dataset)
sample = dataset[0]
self.assertIn(MetaKeys.keyword, sample[Fields.meta])
self.assertNotEqual(len(sample[Fields.meta][MetaKeys.keyword]), 0)
logger.info(f"keywords: {sample[Fields.meta][MetaKeys.keyword]}")

def test(self):
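The test changes above all add the same kind of check: the LLM-backed OP must have written a non-empty result under `Fields.meta` or `Fields.batch_meta`. A hedged distillation of that pattern (the key names come from the diffs above; the helper itself is not part of this commit):

```python
from data_juicer.utils.constant import Fields


def assert_output_filled(test_case, dataset, output_key, batched=False):
    """Assert every sample carries a non-empty value for output_key."""
    field = Fields.batch_meta if batched else Fields.meta
    for sample in dataset:
        test_case.assertIn(output_key, sample[field])
        test_case.assertNotEqual(sample[field][output_key], '')
```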
