Skip to content

Commit

Permalink
add file extension generalization
Browse files Browse the repository at this point in the history
  • Loading branch information
zhmiao committed Dec 18, 2024
1 parent 67e4ebd commit 8db978f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
1 change: 1 addition & 0 deletions PW_FT_classification/src/algorithms/plain.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import numpy as np
import json
from datetime import datetime
from tqdm import tqdm
import random
Expand Down
16 changes: 14 additions & 2 deletions PW_FT_classification/src/datasets/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@
'Custom_Crop'
]

# Define the allowed image extensions
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

def has_file_allowed_extension(filename: str, extensions: tuple) -> bool:
"""Checks if a file is an allowed extension."""
return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))

def is_image_file(filename: str) -> bool:
"""Checks if a file is an allowed image extension."""
return has_file_allowed_extension(filename, IMG_EXTENSIONS)

# Define normalization mean and standard deviation for image preprocessing
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
Expand Down Expand Up @@ -45,7 +56,7 @@ class Custom_Base_DS(Dataset):
predict (bool): Flag to indicate if the dataset is used for prediction.
"""

def __init__(self, rootdir, transform=None, predict=False, extension="jpg"):
def __init__(self, rootdir, transform=None, predict=False):
"""
Initialize the Custom_Base_DS with the directory, transformations, and mode.
Expand All @@ -68,7 +79,8 @@ def load_data(self):
"""
if self.predict:
# Load data for prediction
self.data = glob(os.path.join(self.img_root,"*.{}".format(self.extension)))
# self.data = glob(os.path.join(self.img_root,"*.{}".format(self.extension)))
self.data = [os.path.join(dp, f) for dp, dn, filenames in os.walk(self.img_root) for f in filenames if is_image_file(f)] # dp: directory path, dn: directory name, f: filename
else:
# Load data for training/validation
self.data = list(self.ann['path'])
Expand Down

0 comments on commit 8db978f

Please sign in to comment.