diff --git a/deepface/commons/image_utils.py b/deepface/commons/image_utils.py index b72ce0b4..868eaf25 100644 --- a/deepface/commons/image_utils.py +++ b/deepface/commons/image_utils.py @@ -1,7 +1,7 @@ # built-in dependencies import os import io -from typing import List, Union, Tuple +from typing import Generator, List, Union, Tuple import hashlib import base64 from pathlib import Path @@ -14,6 +14,10 @@ from werkzeug.datastructures import FileStorage +IMAGE_EXTS = {".jpg", ".jpeg", ".png"} +PIL_EXTS = {"jpeg", "png"} + + def list_images(path: str) -> List[str]: """ List images in a given path @@ -25,17 +29,29 @@ def list_images(path: str) -> List[str]: images = [] for r, _, f in os.walk(path): for file in f: - exact_path = os.path.join(r, file) - - ext_lower = os.path.splitext(exact_path)[-1].lower() + if os.path.splitext(file)[1].lower() in IMAGE_EXTS: + exact_path = os.path.join(r, file) + with Image.open(exact_path) as img: # lazy + if img.format.lower() in PIL_EXTS: + images.append(exact_path) + return images - if ext_lower not in {".jpg", ".jpeg", ".png"}: - continue - with Image.open(exact_path) as img: # lazy - if img.format.lower() in {"jpeg", "png"}: - images.append(exact_path) - return images +def yield_images(path: str) -> Generator[str, None, None]: + """ + Yield images in a given path + Args: + path (str): path's location + Yields: + image (str): image path + """ + for r, _, f in os.walk(path): + for file in f: + if os.path.splitext(file)[1].lower() in IMAGE_EXTS: + exact_path = os.path.join(r, file) + with Image.open(exact_path) as img: # lazy + if img.format.lower() in PIL_EXTS: + yield exact_path def find_image_hash(file_path: str) -> str: diff --git a/deepface/modules/recognition.py b/deepface/modules/recognition.py index f1531324..90e8c297 100644 --- a/deepface/modules/recognition.py +++ b/deepface/modules/recognition.py @@ -136,7 +136,7 @@ def find( representations = [] # required columns for representations - df_cols = [ + df_cols = { "identity", "hash", "embedding", @@ -144,7 +144,7 @@ def find( "target_y", "target_w", "target_h", - ] + } # Ensure the proper pickle file exists if not os.path.exists(datastore_path): @@ -157,18 +157,15 @@ def find( # check each item of representations list has required keys for i, current_representation in enumerate(representations): - missing_keys = set(df_cols) - set(current_representation.keys()) + missing_keys = df_cols - set(current_representation.keys()) if len(missing_keys) > 0: raise ValueError( f"{i}-th item does not have some required keys - {missing_keys}." f"Consider to delete {datastore_path}" ) - # embedded images - pickled_images = [representation["identity"] for representation in representations] - # Get the list of images on storage - storage_images = image_utils.list_images(path=db_path) + storage_images = set(image_utils.yield_images(path=db_path)) if len(storage_images) == 0 and refresh_database is True: raise ValueError(f"No item found in {db_path}") @@ -186,8 +183,13 @@ def find( # Enforce data consistency amongst on disk images and pickle file if refresh_database: - new_images = set(storage_images) - set(pickled_images) # images added to storage - old_images = set(pickled_images) - set(storage_images) # images removed from storage + # embedded images + pickled_images = { + representation["identity"] for representation in representations + } + + new_images = storage_images - pickled_images # images added to storage + old_images = pickled_images - storage_images # images removed from storage # detect replaced images for current_representation in representations: diff --git a/tests/test_find.py b/tests/test_find.py index ffea91b8..de8956da 100644 --- a/tests/test_find.py +++ b/tests/test_find.py @@ -95,12 +95,23 @@ def test_filetype_for_find(): def test_filetype_for_find_bulk_embeddings(): - imgs = image_utils.list_images("dataset") + # List + list_imgs = image_utils.list_images("dataset") - assert len(imgs) > 0 + assert len(list_imgs) > 0 # img47 is webp even though its extension is jpg - assert "dataset/img47.jpg" not in imgs + assert "dataset/img47.jpg" not in list_imgs + + # Generator + gen_imgs = list(image_utils.yield_images("dataset")) + + assert len(gen_imgs) > 0 + + # img47 is webp even though its extension is jpg + assert "dataset/img47.jpg" not in gen_imgs + + assert gen_imgs == list_imgs def test_find_without_refresh_database():