Commit bece0d1

* fix skip_op_error & update_sampling_params
1 parent 684666e commit bece0d1

4 files changed: +10 -13 lines changed


data_juicer/ops/base_op.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def wrapper(sample, *args, **kwargs):
             else:
                 return [res]
         except Exception as e:
-            if skip_op_error:
+            if not skip_op_error:
                 raise
             from loguru import logger
             logger.error(f'An error occurred in {op_name} when processing '
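This is the skip_op_error fix from the commit message: the old guard re-raised exactly when skip_op_error was set, so errors could never actually be skipped. A minimal, self-contained sketch of the corrected control flow; the decorator name catch_op_error and the empty-list return for a dropped sample are illustrative assumptions, not data_juicer's exact wrapper:

    from loguru import logger

    def catch_op_error(op_name, skip_op_error=False):
        # Hypothetical decorator mirroring the fixed logic: re-raise when
        # skip_op_error is False, otherwise log and drop the sample.
        def decorator(fn):
            def wrapper(sample, *args, **kwargs):
                try:
                    res = fn(sample, *args, **kwargs)
                    return res if isinstance(res, list) else [res]
                except Exception as e:
                    if not skip_op_error:  # the fixed condition
                        raise
                    logger.error(f'An error occurred in {op_name} when '
                                 f'processing a sample: {e}. Skipping it.')
                    return []
            return wrapper
        return decorator

    @catch_op_error('demo_mapper', skip_op_error=True)
    def process(sample):
        return {'text': sample['text'].upper()}

    process({'text': 'ok'})   # -> [{'text': 'OK'}]
    process({'oops': 1})      # logs the KeyError and returns []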

data_juicer/ops/mapper/generate_qa_from_examples_mapper.py

Lines changed: 3 additions & 4 deletions
@@ -118,6 +118,9 @@ def __init__(self,
         model_params = model_params or {}
         sampling_params = sampling_params or {}

+        sampling_params = update_sampling_params(sampling_params, hf_model,
+                                                 self.enable_vllm)
+
         if enable_vllm:
             assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
             # cannot initialize vllm replicas on different GPUs
@@ -140,10 +143,6 @@ def __init__(self,
                 **model_params)
         self.sampling_params = sampling_params

-        self.sampling_params = update_sampling_params(sampling_params,
-                                                      hf_model,
-                                                      self.enable_vllm)
-
         self.seed_qa_samples = self._load_seed_qa_samples()
         if len(self.seed_qa_samples) == 0:
             raise ValueError('No QA data was parsed from the seed file!')
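This mapper and the two below receive the identical change: update_sampling_params now runs once on the plain dict, before the backend branch, and the later re-assignment of self.sampling_params is removed. The point of the move is sketched below, under the assumption that the vLLM branch (not shown in these hunks) converts the dict into a vllm.SamplingParams object that the old post-hoc update would have clobbered; the helper body here is a stand-in, not the real update_sampling_params:

    def init_sampling_params(sampling_params, hf_model, enable_vllm):
        def update_sampling_params(params, model, vllm_enabled):
            # stand-in for data_juicer's helper: normalize key names, etc.
            params = dict(params)
            key = 'max_tokens' if vllm_enabled else 'max_new_tokens'
            params.setdefault(key, 256)
            return params

        sampling_params = sampling_params or {}

        # fixed order: adjust the dict first ...
        sampling_params = update_sampling_params(sampling_params, hf_model,
                                                 enable_vllm)

        if enable_vllm:
            import vllm

            # ... so the object built from it is final; the old code updated
            # *after* this point and overwrote it with a plain dict.
            return vllm.SamplingParams(**sampling_params)
        return sampling_params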

data_juicer/ops/mapper/generate_qa_from_text_mapper.py

Lines changed: 3 additions & 4 deletions
@@ -85,6 +85,9 @@ def __init__(self,
         model_params = model_params or {}
         sampling_params = sampling_params or {}

+        sampling_params = update_sampling_params(sampling_params, hf_model,
+                                                 self.enable_vllm)
+
         if enable_vllm:
             assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
             # cannot initialize vllm replicas on different GPUs
@@ -107,10 +110,6 @@ def __init__(self,
                 **model_params)
         self.sampling_params = sampling_params

-        self.sampling_params = update_sampling_params(sampling_params,
-                                                      hf_model,
-                                                      self.enable_vllm)
-
     def parse_output(self, raw_output):
         logger.debug(raw_output)
         qa_list = []

data_juicer/ops/mapper/optimize_qa_mapper.py

Lines changed: 3 additions & 4 deletions
@@ -77,6 +77,9 @@ def __init__(self,
         model_params = model_params or {}
         sampling_params = sampling_params or {}

+        sampling_params = update_sampling_params(sampling_params, hf_model,
+                                                 self.enable_vllm)
+
         if enable_vllm:
             assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
             # cannot initialize vllm replicas on different GPUs
@@ -99,10 +102,6 @@ def __init__(self,
                 **model_params)
         self.sampling_params = sampling_params

-        self.sampling_params = update_sampling_params(sampling_params,
-                                                      hf_model,
-                                                      self.enable_vllm)
-
     def build_input(self, sample):
         qa_pair = self.qa_pair_template.format(sample[self.query_key],
                                                sample[self.response_key])
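For completeness, a caller's view of the fixed behavior; the constructor keywords follow the diff context, while the model choice and parameter values are purely illustrative:

    from data_juicer.ops.mapper.optimize_qa_mapper import OptimizeQAMapper

    # sampling_params is now normalized once, up front, for either backend
    op = OptimizeQAMapper(
        hf_model='Qwen/Qwen2.5-7B-Instruct',  # illustrative model choice
        enable_vllm=False,
        sampling_params={'temperature': 0.7},
    )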
