-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
182 lines (148 loc) · 6.76 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os # Import the os module for interacting with the operating system
from PIL import Image # Import the Image module from PIL for image processing
import torch # Import the torch module for working with PyTorch
import torchvision.models as models # Import the models module from torchvision for pre-trained models
import torchvision.transforms as transforms # Import the transforms module from torchvision for data transformations
import joblib # Import the joblib module for model serialization
import io # Import the io module for handling byte streams
import ssl # SSL certificate handling
from defaults import (
MODEL_NAME,
OBJECT_NAME,
PKL_FILE_NAME,
NOT_OBJECT_NAME,
TEST_IMAGES_PATH,
) # Import the MODEL_NAME, OBJECT_NAME, PKL_FILE_NAME, NOT_OBJECT_NAME, and TEST_IMAGES_PATH variables from the defaults module.
# Fix SSL certificate verification issues.
# NOTE(review): this replaces the process-wide default HTTPS context with one
# that skips certificate verification entirely, so any HTTPS request that uses
# the default context (presumably including the pretrained-weight downloads
# below) can be intercepted/tampered with. Prefer repairing the local CA store
# (e.g. installing certificates for this Python build) and removing this line.
ssl._create_default_https_context = ssl._create_unverified_context
# Define the target image size for resizing. Because a (height, width) pair is
# given, every image is resized to exactly 256x256 regardless of aspect ratio.
imageSize = (256, 256)
# Define the data preprocessing transformations (without augmentation for testing).
# These must match the preprocessing used when the classifier stored in
# PKL_FILE_NAME was trained; otherwise the extracted features are not comparable.
dataTransforms = transforms.Compose(
    [
        transforms.Resize(imageSize),  # Resize the PIL image to 256x256
        transforms.ToTensor(),  # Convert the PIL image to a float tensor in [0, 1]
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.256, 0.225]
        ),  # Normalize each channel with the given mean and standard deviation
        # NOTE(review): the usual ImageNet std is [0.229, 0.224, 0.225]; the
        # 0.256 above looks like a typo. Only change it together with the
        # training script, since the saved classifier was fitted on features
        # produced with whatever values training used.
    ]
)
print(f"Testing the {MODEL_NAME} model...")  # Inform the user which model is being tested.

# Map each supported MODEL_NAME to a zero-argument factory that builds the
# matching torchvision backbone with its default pretrained weights. Factories
# (rather than ready-built models) keep loading lazy, so only the selected
# model's weights are fetched.
_MODEL_FACTORIES = {
    "resnet": lambda: models.resnet18(
        weights=models.ResNet18_Weights.DEFAULT
    ),
    "efficientnet": lambda: models.efficientnet_b0(
        weights=models.EfficientNet_B0_Weights.DEFAULT
    ),
    "vgg": lambda: models.vgg16(
        weights=models.VGG16_Weights.DEFAULT
    ),
    "densenet": lambda: models.densenet201(
        weights=models.DenseNet201_Weights.DEFAULT
    ),
    "mobilenet": lambda: models.mobilenet_v2(
        weights=models.MobileNet_V2_Weights.DEFAULT
    ),
}

if MODEL_NAME not in _MODEL_FACTORIES:
    # Name the offending value and the valid choices so a typo in defaults
    # is easy to spot.
    raise ValueError(
        f"Model not found! Unknown MODEL_NAME {MODEL_NAME!r}; "
        f"expected one of {sorted(_MODEL_FACTORIES)}."
    )

model = _MODEL_FACTORIES[MODEL_NAME]()
model.eval()  # Inference mode: fixes BatchNorm statistics, disables Dropout

# Drop the last child module (the classification head) so the network acts as
# a feature extractor.
# NOTE(review): for "densenet" and "mobilenet" the global pooling (and, for
# DenseNet, the final ReLU) is applied inside forward() rather than by a child
# module, so this Sequential yields the un-pooled spatial feature map, which
# extractFeatures then flattens into a much larger vector. That is fine only
# if training built the extractor the same way — confirm before changing it.
featureExtractor = torch.nn.Sequential(*list(model.children())[:-1])
featureExtractor.eval()  # Explicit, although the reused children are already in eval mode
# Function to extract features from a single image using the pre-trained model
def extractFeatures(image):
    """
    Extract a flat feature vector from an image with the pretrained backbone.

    Args:
        image (PIL.Image.Image): The input image (RGB). It is passed through
            ``dataTransforms``, which applies ``Resize`` and then ``ToTensor``,
            so it must be a PIL image — an already-converted ``torch.Tensor``
            is not accepted.

    Returns:
        numpy.ndarray: The feature extractor's output for this single image,
        flattened to a 1-D array.
    """
    # Preprocess, then add a leading batch dimension of size 1.
    batch = dataTransforms(image).unsqueeze(0)
    with torch.no_grad():  # Inference only; no autograd graph is needed
        features = featureExtractor(batch)
    # Move to host memory (a no-op on CPU) before converting, then flatten to
    # the 1-D vector the classifier pipeline is fed with.
    return features.cpu().numpy().flatten()
# Load the trained classifier pipeline from the .pkl file.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load PKL_FILE_NAME files that you produced or otherwise trust.
pipeline = joblib.load(PKL_FILE_NAME)
# Function to predict if an image is an object or not using the extracted features
def prediction(image):
    """
    Classify one image as object / not-object with the trained pipeline.

    Args:
        image (str | os.PathLike | file-like): A filesystem path, or a binary
            file-like object such as ``io.BytesIO``, accepted by
            ``PIL.Image.open``.

    Returns:
        tuple: ``(isObject, probability)``. ``isObject`` is True when the
        pipeline predicts label 0. ``probability`` is column 0 of
        ``pipeline.predict_proba`` — presumably the probability of the first
        entry of the estimator's ``classes_`` (label 0 when classes are
        ``[0, 1]``) — and is returned whichever class was predicted.
    """
    # Load and convert inside the context so the underlying file handle is
    # released promptly (Pillow may otherwise keep it open, e.g. for
    # multi-frame files). convert() returns a new, independent image.
    with Image.open(image) as source:
        rgbImage = source.convert("RGB")

    # The pipeline expects a 2-D batch; build it once and reuse it.
    featureBatch = [extractFeatures(rgbImage)]
    predictedLabel = pipeline.predict(featureBatch)[0]
    # NOTE(review): this is the class-column-0 probability, not the confidence
    # of predictedLabel; kept as-is because callers rely on this value.
    probability = pipeline.predict_proba(featureBatch)[0][0]
    return bool(predictedLabel == 0), probability
# Function to predict if an image is an object or not given its file path
def predictImage(imagePath):
    """
    Classify the image stored at a filesystem path.

    Args:
        imagePath (str | os.PathLike): Path of the image file to classify.

    Returns:
        tuple: ``(isObject, probability)`` exactly as returned by
        ``prediction`` — a bool (True when the pipeline predicts label 0)
        and the column-0 value of ``predict_proba``.
    """
    return prediction(imagePath)
# Function to predict if an image is an object or not given its byte content
def predictImageFromBytes(imageBytes):
    """
    Classify an image supplied as encoded bytes (e.g. an uploaded file body).

    Args:
        imageBytes (bytes | bytearray): Encoded image data in any format
            PIL can open.

    Returns:
        tuple: ``(isObject, probability)`` exactly as returned by
        ``prediction`` — a bool (True when the pipeline predicts label 0)
        and the column-0 value of ``predict_proba``.
    """
    # Wrap the bytes in an in-memory binary stream so PIL can read it as a file.
    return prediction(io.BytesIO(imageBytes))
# Function to predict if images in a directory are an object or not
def predictImagesInDirectory(directoryPath):
    """
    Classify every file directly inside a directory.

    Entries are visited in sorted filename order so the output is
    deterministic (``os.listdir`` order is arbitrary). Subdirectories are not
    descended into. Files that cannot be opened or decoded as images (for
    example a stray ``.DS_Store``) are reported and skipped instead of
    aborting the whole run.

    Args:
        directoryPath (str | os.PathLike): Directory containing the images.

    Returns:
        list: ``(filename, result)`` tuples, where ``result`` is the
        ``(isObject, probability)`` tuple returned by ``predictImage``.
    """
    results = []
    for filename in sorted(os.listdir(directoryPath)):
        filepath = os.path.join(directoryPath, filename)
        if not os.path.isfile(filepath):
            continue  # Skip subdirectories and anything else that is not a file
        try:
            result = predictImage(filepath)
        except OSError as error:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # non-image files, and OSError for unreadable/truncated ones.
            print(f"Skipping {filename}: {error}")
            continue
        results.append((filename, result))
    return results
# Script entry: classify every file in TEST_IMAGES_PATH and report each result.
# NOTE(review): this runs at import time, so importing this module (e.g. to use
# predictImageFromBytes) also triggers the whole directory run.
predictions = predictImagesInDirectory(TEST_IMAGES_PATH)
for filename, result in predictions:
    imageName = filename
    # result[0] is the is-object flag; translate it to the configured label.
    if result[0]:
        isObject = OBJECT_NAME
    else:
        isObject = NOT_OBJECT_NAME
    # result[1] is the class-column-0 probability from prediction(); it is shown
    # as a percentage even when the image is labelled NOT_OBJECT_NAME.
    percentage = float(result[1] * 100)
    print(f"Image: {imageName}, It is {isObject}, Percentage: {percentage}")