Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyushuo committed Jan 20, 2025
2 parents faec9f2 + 030e786 commit 46ba4dd
Show file tree
Hide file tree
Showing 23 changed files with 269 additions and 201 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ repos:
language: python
require_serial: true
additional_dependencies:
- googletrans==4.0.2
- translators==5.9.3

exclude: |
(?x)^(
Expand Down
27 changes: 11 additions & 16 deletions .pre-commit-hooks/build_op_doc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import ast
import asyncio
import json
import os
import re
from typing import Any, List

from googletrans import Translator
import translators as ts

DOC_PATH = 'docs/Operators.md'

Expand Down Expand Up @@ -152,7 +151,7 @@ def analyze_resource_tag(code):
def analyze_model_tags(code):
"""
Analyze the model tag for the given code content string. SHOULD be one of
the "Modal Tags" in `tagging_mappings.json`. It makes the choice by finding
the "Model Tags" in `tagging_mappings.json`. It makes the choice by finding
the `model_type` arg in `prepare_model` method invocation.
"""
pattern = r'model_type=[\'|\"](.*?)[\'|\"]'
Expand Down Expand Up @@ -431,20 +430,16 @@ def generate_op_table_section(op_type, op_record_list):
return '\n\n'.join(doc)


async def translate_text(text, dest='zh'):
async with Translator() as translator:
res = await translator.translate(text, src='en', dest=dest)
return res


def get_op_desc_in_en_zh_batched(descs):
zhs = asyncio.run(translate_text(descs, dest='zh'))
return [desc + ' ' + zh.text for desc, zh in zip(descs, zhs)]


def get_op_desc_in_en_zh(desc):
zh = asyncio.run(translate_text(desc, dest='zh')).text
return desc + ' ' + zh
separator = '\n'
batch = separator.join(descs)
res = ts.translate_text(batch,
translator='alibaba',
from_language='en',
to_language='zh')
zhs = res.split(separator)
assert len(zhs) == len(descs)
return [desc + ' ' + zh.strip() for desc, zh in zip(descs, zhs)]


def parse_op_record_from_current_doc():
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-hooks/tag_mappings.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"desc": "stable version OP. Based on the beta version, OP optimizations related to DJ (e.g. model management, batched processing, OP fusion, ...) are added to this OP. 表示 stable 版本算子。基于 beta 版本,完善了DJ相关的算子优化项(如模型管理,批处理,算子融合等)。"
}
},
"Modal Tags": {
"Model Tags": {
"api": {
"icon": "🔗API",
"desc": "equipped with API-based models. (e.g. ChatGPT, GPT-4o). 支持基于 API 调用模型(如 ChatGPT,GPT-4o)。"
Expand Down
2 changes: 2 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ np: 4 # number of subproce
text_keys: 'text' # the key name of field where the sample texts to be processed, e.g., `text`, `instruction`, `output`, ...
# Note: currently, we support specify only ONE key for each op, for cases requiring multiple keys, users can specify the op multiple times. We will only use the first key of `text_keys` when you set multiple keys.
suffixes: [] # the suffix of files that will be read. For example: '.txt', 'txt' or ['txt', '.pdf', 'docx']
turbo: false # Enable Turbo mode to maximize processing speed when batch size is 1.
skip_op_error: true # Skip errors in OPs caused by unexpected invalid samples.
use_cache: true # whether to use the cache management of Hugging Face datasets. It might take up lots of disk space when using cache
ds_cache_dir: null # cache dir for Hugging Face datasets. In default, it's the same as the environment variable `HF_DATASETS_CACHE`, whose default value is usually "~/.cache/huggingface/datasets". If this argument is set to a valid path by users, it will override the default cache dir
open_monitor: true # Whether to open the monitor to trace resource utilization for each OP during data processing. It's True in default.
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '1.0.3'
__version__ = '1.1.0'

import os
import subprocess
Expand Down
10 changes: 8 additions & 2 deletions data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,13 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None):
'--turbo',
type=bool,
default=False,
help='Enable Turbo mode to maximize processing speed. Stability '
'features like fault tolerance will be disabled.')
help='Enable Turbo mode to maximize processing speed when batch size '
'is 1.')
parser.add_argument(
'--skip_op_error',
type=bool,
default=True,
help='Skip errors in OPs caused by unexpected invalid samples.')
parser.add_argument(
'--use_cache',
type=bool,
Expand Down Expand Up @@ -550,6 +555,7 @@ def init_setup_from_cfg(cfg: Namespace):
'video_key': cfg.video_key,
'num_proc': cfg.np,
'turbo': cfg.turbo,
'skip_op_error': cfg.skip_op_error,
'work_dir': cfg.work_dir,
}
cfg.process = update_op_attr(cfg.process, op_attrs)
Expand Down
67 changes: 48 additions & 19 deletions data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def wrapper(sample, *args, **kwargs):
return wrapper


def catch_map_batches_exception(method, op_name=None):
def catch_map_batches_exception(method, skip_op_error=False, op_name=None):
"""
For batched-map sample-level fault tolerance.
"""
Expand All @@ -61,6 +61,8 @@ def wrapper(samples, *args, **kwargs):
try:
return method(samples, *args, **kwargs)
except Exception as e:
if not skip_op_error:
raise
from loguru import logger
logger.error(f'An error occurred in {op_name} when processing '
f'samples "{samples}" -- {type(e)}: {e}')
Expand All @@ -72,7 +74,10 @@ def wrapper(samples, *args, **kwargs):
return wrapper


def catch_map_single_exception(method, return_sample=True, op_name=None):
def catch_map_single_exception(method,
return_sample=True,
skip_op_error=False,
op_name=None):
"""
For single-map sample-level fault tolerance.
The input sample is expected batch_size = 1.
Expand Down Expand Up @@ -103,6 +108,8 @@ def wrapper(sample, *args, **kwargs):
else:
return [res]
except Exception as e:
if not skip_op_error:
raise
from loguru import logger
logger.error(f'An error occurred in {op_name} when processing '
f'sample "{sample}" -- {type(e)}: {e}')
Expand Down Expand Up @@ -157,6 +164,10 @@ def __init__(self, *args, **kwargs):
self.batch_size = kwargs.get('batch_size', 1000)
self.work_dir = kwargs.get('work_dir', None)

# for unittest, do not skip the error.
# It would be set to be True in config init.
self.skip_op_error = kwargs.get('skip_op_error', False)

# whether the model can be accelerated using cuda
_accelerator = kwargs.get('accelerator', None)
if _accelerator is not None:
Expand Down Expand Up @@ -278,11 +289,15 @@ def __init__(self, *args, **kwargs):

# runtime wrappers
if self.is_batched_op():
self.process = catch_map_batches_exception(self.process_batched,
op_name=self._name)
self.process = catch_map_batches_exception(
self.process_batched,
skip_op_error=self.skip_op_error,
op_name=self._name)
else:
self.process = catch_map_single_exception(self.process_single,
op_name=self._name)
self.process = catch_map_single_exception(
self.process_single,
skip_op_error=self.skip_op_error,
op_name=self._name)

# set the process method is not allowed to be overridden
def __init_subclass__(cls, **kwargs):
Expand Down Expand Up @@ -369,15 +384,23 @@ def __init__(self, *args, **kwargs):
# runtime wrappers
if self.is_batched_op():
self.compute_stats = catch_map_batches_exception(
self.compute_stats_batched, op_name=self._name)
self.process = catch_map_batches_exception(self.process_batched,
op_name=self._name)
self.compute_stats_batched,
skip_op_error=self.skip_op_error,
op_name=self._name)
self.process = catch_map_batches_exception(
self.process_batched,
skip_op_error=self.skip_op_error,
op_name=self._name)
else:
self.compute_stats = catch_map_single_exception(
self.compute_stats_single, op_name=self._name)
self.process = catch_map_single_exception(self.process_single,
return_sample=False,
op_name=self._name)
self.compute_stats_single,
skip_op_error=self.skip_op_error,
op_name=self._name)
self.process = catch_map_single_exception(
self.process_single,
return_sample=False,
skip_op_error=self.skip_op_error,
op_name=self._name)

# set the process method is not allowed to be overridden
def __init_subclass__(cls, **kwargs):
Expand Down Expand Up @@ -486,11 +509,15 @@ def __init__(self, *args, **kwargs):

# runtime wrappers
if self.is_batched_op():
self.compute_hash = catch_map_batches_exception(self.compute_hash,
op_name=self._name)
self.compute_hash = catch_map_batches_exception(
self.compute_hash,
skip_op_error=self.skip_op_error,
op_name=self._name)
else:
self.compute_hash = catch_map_single_exception(self.compute_hash,
op_name=self._name)
self.compute_hash = catch_map_single_exception(
self.compute_hash,
skip_op_error=self.skip_op_error,
op_name=self._name)

def compute_hash(self, sample):
"""
Expand Down Expand Up @@ -626,8 +653,10 @@ def __init__(self, *args, **kwargs):
queries and responses
"""
super(Aggregator, self).__init__(*args, **kwargs)
self.process = catch_map_single_exception(self.process_single,
op_name=self._name)
self.process = catch_map_single_exception(
self.process_single,
skip_op_error=self.skip_op_error,
op_name=self._name)

def process_single(self, sample):
"""
Expand Down
Loading

0 comments on commit 46ba4dd

Please sign in to comment.