diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 9cb64fa30..de53e0420 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -316,6 +316,11 @@ process: score_threshold: 0.5 # the nsfw score threshold for samples, range from 0 to 1. Samples with nsfw score less than this threshold will be kept. any_or_all: any # keep this sample when any/all images meet the filter condition mem_required: '1GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched + - image_pair_similarity_filter: # filter samples according to the similarity score between the image pair. + hf_clip: 'openai/clip-vit-base-patch32' # model name of the CLIP model on huggingface + min_score: 0.1 # the min similarity score of filter range + max_score: 1.0 # the max similarity score of filter range + any_or_all: "any" # keep this sample when any/all images meet the filter condition - image_shape_filter: # filter samples according to the widths and heights of images in them min_width: 200 # the min width of width filter range max_width: 5000 # the max width of width filter range diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index abce40a5b..68e9ba521 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -4,7 +4,8 @@ average_line_length_filter, character_repetition_filter, flagged_words_filter, image_aesthetics_filter, image_aspect_ratio_filter, image_face_ratio_filter, - image_nsfw_filter, image_shape_filter, image_size_filter, + image_nsfw_filter, image_pair_similarity_filter, + image_shape_filter, image_size_filter, image_text_matching_filter, image_text_similarity_filter, image_watermark_filter, language_id_score_filter, maximum_line_length_filter, perplexity_filter, @@ -30,6 +31,7 @@ from .image_aspect_ratio_filter import ImageAspectRatioFilter from .image_face_ratio_filter import ImageFaceRatioFilter from .image_nsfw_filter import ImageNSFWFilter +from .image_pair_similarity_filter import ImagePairSimilarityFilter from .image_shape_filter import ImageShapeFilter from .image_size_filter import ImageSizeFilter from .image_text_matching_filter import ImageTextMatchingFilter @@ -104,6 +106,7 @@ 'FlaggedWordFilter', 'WordRepetitionFilter', 'VideoMotionScoreFilter', + 'ImagePairSimilarityFilter' ] # yapf: enable diff --git a/data_juicer/ops/filter/image_pair_similarity_filter.py b/data_juicer/ops/filter/image_pair_similarity_filter.py new file mode 100644 index 000000000..dcb1b7059 --- /dev/null +++ b/data_juicer/ops/filter/image_pair_similarity_filter.py @@ -0,0 +1,114 @@ +import numpy as np +from jsonargparse.typing import ClosedUnitInterval + +from data_juicer.ops.base_op import OPERATORS, Filter +from data_juicer.ops.op_fusion import LOADED_IMAGES +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import load_data_with_context, load_image +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'image_pair_similarity_filter' + +with AvailabilityChecking(['torch', 'transformers'], OP_NAME): + + import torch + import transformers # noqa: F401 + + # avoid hanging when calling clip in multiprocessing + torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +@LOADED_IMAGES.register_module(OP_NAME) +class ImagePairSimilarityFilter(Filter): + """Filter to keep image pairs with similarities between images + within a specific range.""" + + _accelerator = 'cuda' + + def __init__(self, + hf_clip='openai/clip-vit-base-patch32', + trust_remote_code=False, + min_score: ClosedUnitInterval = 0.1, + max_score: ClosedUnitInterval = 1.0, + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param hf_clip: clip model name on huggingface to compute + the similarity between image and text. + :param min_score: The min similarity to keep samples. + :param max_score: The max similarity to keep samples. + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all images. 'any': keep this sample if any images meet the + condition. 'all': keep this sample only if all images meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_score = min_score + self.max_score = max_score + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. ' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_clip, + trust_remote_code=trust_remote_code) + + def compute_stats(self, sample, rank=None, context=False): + + # check if it's computed already + if StatsKeys.image_pair_similarity in sample[Fields.stats]: + return sample + + # there is no image in this sample + if (self.image_key not in sample + or not len(sample[self.image_key]) == 2 + or sample[self.image_key][0] == sample[self.image_key][1]): + raise ValueError('Each sample must include two images.') + + # load images + loaded_image_keys = sample[self.image_key] + sample, images = load_data_with_context(sample, context, + loaded_image_keys, load_image) + + similarity = [] + model, processor = get_model(self.model_key, rank, self.use_cuda()) + + image_list = [] + for temp_key in images.keys(): + image_list.append(images[temp_key]) + image_tensors = processor.image_processor( + image_list, return_tensors='pt')['pixel_values'] + image1_batch_feature = model.get_image_features( + image_tensors[0].unsqueeze(0).to(model.device)) + image2_batch_feature = model.get_image_features( + image_tensors[1].unsqueeze(0).to(model.device)) + + similarity = torch.cosine_similarity(image1_batch_feature, + image2_batch_feature, + dim=1) + sample[Fields.stats][StatsKeys.image_pair_similarity] = similarity + + return sample + + def process(self, sample, rank=None): + similarity = sample[Fields.stats][StatsKeys.image_pair_similarity] + if len(similarity) <= 0: + return True + + keep_bools = np.array([ + self.min_score <= sim_value <= self.max_score + for sim_value in similarity + ]) + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 13bddb687..205685c48 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -142,6 +142,7 @@ class StatsKeysConstant(object): image_aesthetics_scores = 'image_aesthetics_scores' image_nsfw_score = 'image_nsfw_score' image_watermark_prob = 'image_watermark_prob' + image_pair_similarity = 'image_pair_similarity' # audios audio_duration = 'audio_duration' diff --git a/docs/Operators.md b/docs/Operators.md index 144550790..7b1f8f0f3 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -12,7 +12,7 @@ The operators in Data-Juicer are categorized into 5 types. |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | | [ Mapper ]( #mapper ) | 46 | Edits and transforms samples | -| [ Filter ]( #filter ) | 41 | Filters out low-quality samples | +| [ Filter ]( #filter ) | 42 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -113,6 +113,7 @@ All the specific operators are listed below, each featured with several capabili | image_aspect_ratio_filter | Image | - | Keeps samples containing images with aspect ratios within the specified range | | image_face_ratio_filter | Image | - | Keeps samples containing images with face area ratios within the specified range | | image_nsfw_filter | Image | - | Keeps samples containing images with NSFW scores below the threshold | +| image_pair_similarity_filter | Image | - | Keeps image pairs with image feature cosine similarity within the specified range based on a CLIP model | | image_shape_filter | Image | - | Keeps samples containing images with widths and heights within the specified range | | image_size_filter | Image | - | Keeps samples containing images whose size in bytes are within the specified range | | image_text_matching_filter | Multimodal | - | Keeps samples with image-text classification matching score within the specified range based on a BLIP model | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 3d0e33df3..7ee0bda66 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -12,7 +12,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | | [ Mapper ]( #mapper ) | 46 | 对数据样本进行编辑和转换 | -| [ Filter ]( #filter ) | 41 | 过滤低质量样本 | +| [ Filter ]( #filter ) | 42 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -111,6 +111,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 | | image_face_ratio_filter | Image | - | 保留样本中包含的图片的最大脸部区域在指定范围内的样本 | | image_nsfw_filter | Image | - | 保留包含NSFW分数在指定阈值之下的图像的样本 | +| image_pair_similarity_filter | Image | - | 保留图像特征余弦相似度(基于CLIP模型)在指定范围内的样本 | | image_shape_filter | Image | - | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 | | image_size_filter | Image | - | 保留样本中包含的图片的大小(bytes)在指定范围内的样本 | | image_text_matching_filter | Multimodal | - | 保留图像-文本的分类匹配分(基于BLIP模型)在指定范围内的样本 | diff --git a/tests/ops/filter/test_image_pair_similarity_filter.py b/tests/ops/filter/test_image_pair_similarity_filter.py new file mode 100644 index 000000000..590889db6 --- /dev/null +++ b/tests/ops/filter/test_image_pair_similarity_filter.py @@ -0,0 +1,67 @@ +import os +import unittest + +from data_juicer.core.data import NestedDataset as Dataset + +from data_juicer.ops.filter.image_pair_similarity_filter import ImagePairSimilarityFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class ImagePairSimilarityFilterTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + cat_path = os.path.join(data_path, 'cat.jpg') + img2_path = os.path.join(data_path, 'img2.jpg') + img3_path = os.path.join(data_path, 'img3.jpg') + img5_path = os.path.join(data_path, 'img5.jpg') + img7_path = os.path.join(data_path, 'img7.jpg') + hf_clip = 'openai/clip-vit-base-patch32' + + @classmethod + def tearDownClass(cls) -> None: + super().tearDownClass(cls.hf_clip) + + def _run_filter(self, dataset: Dataset, op, num_proc=1): + + if Fields.stats not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + + dataset = dataset.map(op.compute_stats, + num_proc=num_proc, + with_rank=True) + dataset = dataset.filter(op.process, num_proc=num_proc) + dataset = dataset.select_columns(column_names=['text', 'images']) + res_list = dataset.to_list() + print(res_list) + + def test_no_eoc_special_token(self): + + ds_list = [{ + 'text': 'image pair 1', + 'images': [self.cat_path, self.img3_path] + }, { + 'text': 'image pair 2', + 'images': [self.img3_path, self.img7_path] + }, { + 'text': 'image pair 3', + 'images': [self.img2_path, self.img5_path] + }] + + + dataset = Dataset.from_list(ds_list) + op = ImagePairSimilarityFilter(hf_clip=self.hf_clip, + any_or_all='any', + min_score=0.85, + max_score=1) + self._run_filter(dataset, op) + + +if __name__ == '__main__': + unittest.main()