diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 9de809e78..49dacb6fb 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -255,6 +255,7 @@ process:
       sampling_params: {}  # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95}
   - generate_qa_from_text_mapper:  # mapper to generate question and answer pairs from text.
       hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa'  # Model name on huggingface to generate question and answer pairs.
+      max_num: null  # The maximum number of QA samples to return for each text. No limit if it is None.
       output_pattern: null  # Regular expression pattern to extract questions and answers from model response.
       enable_vllm: false  # Whether to use vllm for inference acceleration.
       model_params: {}  # Parameters for initializing the model.
diff --git a/data_juicer/ops/mapper/generate_qa_from_text_mapper.py b/data_juicer/ops/mapper/generate_qa_from_text_mapper.py
index 0f3a1cfef..3113f0f95 100644
--- a/data_juicer/ops/mapper/generate_qa_from_text_mapper.py
+++ b/data_juicer/ops/mapper/generate_qa_from_text_mapper.py
@@ -2,6 +2,7 @@
 from typing import Dict, Optional
 
 from loguru import logger
+from pydantic import PositiveInt
 
 from data_juicer.ops.base_op import OPERATORS, Mapper
 from data_juicer.utils.lazy_loader import LazyLoader
@@ -35,6 +36,7 @@ class GenerateQAFromTextMapper(Mapper):
 
     def __init__(self,
                  hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa',
+                 max_num: Optional[PositiveInt] = None,
                  *,
                  output_pattern: Optional[str] = None,
                  enable_vllm: bool = False,
@@ -45,6 +47,8 @@ def __init__(self,
         Initialization method.
 
         :param hf_model: Hugginface model ID.
+        :param max_num: The maximum number of QA samples to return for
+            each text. No limit if it is None.
         :param output_pattern: Regular expression pattern to extract
             questions and answers from model response.
         :param enable_vllm: Whether to use vllm for inference acceleration.
@@ -69,6 +73,8 @@ def __init__(self,
 
         super().__init__(**kwargs)
 
+        self.max_num = max_num
+
         if output_pattern is None:
             self.output_pattern = r'Human:(.*?)Assistant:(.*?)(?=Human|$)'  # noqa: E501
         else:
@@ -131,6 +137,10 @@ def process_batched(self, samples, rank=None):
                 output = response[0]['generated_text']
 
             qa_list = self.parse_output(output)
+
+            if self.max_num is not None:
+                qa_list = qa_list[:self.max_num]
+
             if len(qa_list) > 0:
                 for q, a in qa_list:
                     for input_k in input_keys:
diff --git a/tests/ops/mapper/test_generate_qa_from_text_mapper.py b/tests/ops/mapper/test_generate_qa_from_text_mapper.py
index e67285b18..7b3131fd3 100644
--- a/tests/ops/mapper/test_generate_qa_from_text_mapper.py
+++ b/tests/ops/mapper/test_generate_qa_from_text_mapper.py
@@ -19,11 +19,13 @@ def _run_op(self,
                 enable_vllm=False,
                 model_params=None,
                 sampling_params=None,
-                num_proc=1):
+                num_proc=1,
+                max_num=None):
 
         op = GenerateQAFromTextMapper(enable_vllm=enable_vllm,
                                       model_params=model_params,
-                                      sampling_params=sampling_params)
+                                      sampling_params=sampling_params,
+                                      max_num=max_num)
 
         samples = [{
             self.text_key:
@@ -36,6 +38,9 @@ def _run_op(self,
         dataset = Dataset.from_list(samples)
         results = dataset.map(op.process, num_proc=num_proc, with_rank=True)
 
+        if max_num is not None:
+            self.assertLessEqual(len(results), len(samples)*max_num)
+
         for row in results:
             logger.info(row)
             self.assertIn(op.query_key, row)
@@ -45,6 +50,10 @@ def test(self):
         sampling_params = {'max_new_tokens': 200}
         self._run_op(sampling_params=sampling_params)
 
+    def test_max_num(self):
+        sampling_params = {'max_new_tokens': 200}
+        self._run_op(sampling_params=sampling_params, max_num=1)
+
     def test_multi_process(self):
         sampling_params = {'max_new_tokens': 200}
         self._run_op(sampling_params=sampling_params, num_proc=2)
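
For context, a minimal sketch of how the new max_num cap behaves when the operator is driven directly from Python, mirroring the updated test. It assumes a local data-juicer install and the default 'text' sample key; the sample string is illustrative only, and running it end-to-end downloads the default doc2qa model from the diff.

from datasets import Dataset

from data_juicer.ops.mapper.generate_qa_from_text_mapper import \
    GenerateQAFromTextMapper

# max_num=1 keeps at most one QA pair per input text; with the default
# max_num=None, every pair parsed from the model response is kept.
max_num = 1
op = GenerateQAFromTextMapper(max_num=max_num)

# Illustrative input; real pipelines read texts from a dataset file.
samples = [{'text': 'A short document to turn into question-answer pairs.'}]

dataset = Dataset.from_list(samples)
results = dataset.map(op.process, with_rank=True)

# Each input text expands into one row per retained QA pair, so the
# mapped dataset holds at most len(samples) * max_num rows: the same
# bound asserted by the new test_max_num case.
assert len(results) <= len(samples) * max_num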