Skip to content

Commit fdec546

Browse files
committed
Improvements and bugfixes:
- Fixed a bug in PyTorch-model loading - Improved loading and handling of TensorFlow models - Added support for Metal Performance Shaders (MPS) - Added requirement file for virtual python environment on mac os
1 parent 590e068 commit fdec546

File tree

8 files changed

+69
-42
lines changed

8 files changed

+69
-42
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ source env/bin/activate
7272

7373
**Note**: For further information on how to set up a virtual python environment (also on **Windows**) refer to https://docs.python.org/3/library/venv.html .
7474

75-
When successfully installed, the software outputs the line : "Successfully installed NNC-0.2.2"
75+
When successfully installed, the software outputs the line : "Successfully installed NNC-0.3.0"
7676

7777
### Importing the main module
7878

create_venv_macos.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/bin/bash
2+
3+
python3 -m venv env
4+
source env/bin/activate
5+
pip install --upgrade pip
6+
pip install -r requirements_macos.txt
7+
pip install -e .
8+
deactivate

framework/pytorch_model/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,14 @@ def __init__(self,
482482
lr=1e-4,
483483
):
484484

485-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
485+
if torch.cuda.is_available():
486+
device = "cuda"
487+
elif torch.backends.mps.is_available():
488+
device = "mps"
489+
else:
490+
device = "cpu"
491+
492+
self.device = torch.device(device)
486493

487494
torch.manual_seed(451)
488495
torch.backends.cudnn.deterministic = True
@@ -491,7 +498,6 @@ def __init__(self,
491498
self.learning_rate = lr
492499
self.epochs = epochs
493500
self.max_batches = max_batches
494-
495501
self.handle = handler
496502
if test_set:
497503
self.test_set = test_set
@@ -517,7 +523,6 @@ def test_model(self,
517523
verbose=False
518524
):
519525

520-
torch.set_num_threads(1)
521526
Model = copy.deepcopy(self.model)
522527

523528
base_model_arch = Model.state_dict()
@@ -556,7 +561,6 @@ def eval_model(self,
556561
verbose=False
557562
):
558563

559-
torch.set_num_threads(1)
560564

561565
Model = copy.deepcopy(self.model)
562566

@@ -595,7 +599,7 @@ def tune_model(
595599
ft_flag=False,
596600
verbose=False,
597601
):
598-
torch.set_num_threads(1)
602+
599603
verbose = 1 if (verbose & 1) else 0
600604

601605
base_model_arch = self.model.state_dict()

framework/tensorflow_model/__init__.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def create_NNC_model_instance_from_file(
131131
model_struct = loaded_model_struct
132132

133133
if dataset_path and model_struct:
134+
if model_name == None and hasattr(model_struct, 'name'):
135+
model_name=model_struct.name
134136
TEFModelExecuter = create_imagenet_model_executer(model_struct=model_struct,
135137
dataset_path=dataset_path,
136138
lr=lr,
@@ -161,6 +163,8 @@ def create_NNC_model_instance_from_object(
161163
model_struct = loaded_model_struct
162164

163165
if dataset_path and model_struct:
166+
if model_name == None and hasattr(model_struct, 'name'):
167+
model_name=model_struct.name
164168
TEFModelExecuter = create_imagenet_model_executer(model_struct=model_struct,
165169
dataset_path=dataset_path,
166170
lr=lr,
@@ -230,6 +234,10 @@ def __init__(self, model_dict=None):
230234
def load_model( self,
231235
model_path
232236
):
237+
238+
try:
239+
model_file = tf.keras.models.load_model(model_path)
240+
except:
233241
model_file = h5py.File(model_path, 'r')
234242

235243
try:
@@ -262,26 +270,19 @@ def init_model_from_model_object( self,
262270
model_object,
263271
):
264272
self.model = model_object
265-
266-
h5_model_path = './temp.h5'
267-
model_object.save_weights(h5_model_path)
268-
model = h5py.File(h5_model_path, 'r')
269-
os.remove(h5_model_path)
270273

271-
if 'layer_names' in model.attrs:
272-
module_names = [n for n in model.attrs['layer_names']]
273-
274+
weights = model_object.get_weights()
274275
layer_names = []
275-
for mod_name in module_names:
276-
layer = model[mod_name]
277-
if 'weight_names' in layer.attrs:
278-
weight_names = [mod_name+'/'+n for n in layer.attrs['weight_names']]
279-
if weight_names:
280-
layer_names += weight_names
276+
277+
for layer in model_object.layers:
278+
mod_name = layer.name
279+
if layer.weights != []:
280+
for weight in layer.weights:
281+
layer_names.append(mod_name+"/"+weight.name)
281282

282283
model_parameter_dict = {}
283-
for name in layer_names:
284-
model_parameter_dict[name] = model[name]
284+
for i, name in enumerate(layer_names):
285+
model_parameter_dict[name] = weights[i]
285286

286287
return self.init_model_from_dict( model_parameter_dict ), model_object
287288

framework/use_case_init/__init__.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -148,47 +148,49 @@ def preprocess(
148148
):
149149
image_size = 224
150150

151-
if self.__model_name == 'EfficientNetB1':
151+
if self.__model_name == 'EfficientNetB1' or self.__model_name == 'efficientnetb1':
152152
image_size = 240
153-
elif self.__model_name == 'EfficientNetB2':
153+
elif self.__model_name == 'EfficientNetB2' or self.__model_name == 'efficientnetb2':
154154
image_size = 260
155-
elif self.__model_name == 'EfficientNetB3':
155+
elif self.__model_name == 'EfficientNetB3' or self.__model_name == 'efficientnetb3':
156156
image_size = 300
157-
elif self.__model_name == 'EfficientNetB4':
157+
elif self.__model_name == 'EfficientNetB4' or self.__model_name == 'efficientnetb4':
158158
image_size = 380
159-
elif self.__model_name == 'EfficientNetB5':
159+
elif self.__model_name == 'EfficientNetB5' or self.__model_name == 'efficientnetb5':
160160
image_size = 456
161-
elif self.__model_name == 'EfficientNetB6':
161+
elif self.__model_name == 'EfficientNetB6' or self.__model_name == 'efficientnetb6':
162162
image_size = 528
163-
elif self.__model_name == 'EfficientNetB7':
163+
elif self.__model_name == 'EfficientNetB7' or self.__model_name == 'efficientnetb7':
164164
image_size = 600
165165

166166
image, label = self.model_transform(image, label, image_size=image_size)
167167

168-
if 'DenseNet' in self.__model_name:
168+
if 'DenseNet' in self.__model_name or 'densenet' in self.__model_name:
169169
return tf.keras.applications.densenet.preprocess_input(image), label
170-
elif 'EfficientNet' in self.__model_name:
170+
elif 'EfficientNet' in self.__model_name or 'efficientnet' in self.__model_name:
171171
return tf.keras.applications.efficientnet.preprocess_input(image), label
172-
elif self.__model_name == 'InceptionResNetV2':
172+
elif self.__model_name == 'InceptionResNetV2' or self.__model_name == 'inception_resnet_v2':
173173
return tf.keras.applications.inception_resnet_v2.preprocess_input(image), label
174-
elif self.__model_name == 'InceptionV3':
174+
elif self.__model_name == 'InceptionV3' or self.__model_name == "inception_v3":
175175
return tf.keras.applications.inception_v3.preprocess_input(image), label
176-
elif self.__model_name == 'MobileNet':
176+
elif self.__model_name == 'MobileNet' or ( 'mobilenet' in self.__model_name and 'v2' not in self.__model_name ):
177177
return tf.keras.applications.mobilenet.preprocess_input(image), label
178-
elif self.__model_name == 'MobileNetV2':
178+
elif self.__model_name == 'MobileNetV2' or 'mobilenetv2' in self.__model_name:
179179
return tf.keras.applications.mobilenet_v2.preprocess_input(image), label
180180
elif 'NASNet' in self.__model_name:
181181
return tf.keras.applications.nasnet.preprocess_input(image), label
182-
elif 'ResNet' in self.__model_name and 'V2' not in self.__model_name:
182+
elif ('ResNet' in self.__model_name and 'V2' not in self.__model_name) or ('resnet' in self.__model_name and 'v2' not in self.__model_name):
183183
return tf.keras.applications.resnet.preprocess_input(image), label
184-
elif 'ResNet' in self.__model_name and 'V2' in self.__model_name:
184+
elif ('ResNet' in self.__model_name and 'V2' in self.__model_name) or ('resnet' in self.__model_name and 'v2' in self.__model_name):
185185
return tf.keras.applications.resnet_v2.preprocess_input(image), label
186-
elif self.__model_name == 'VGG16':
186+
elif self.__model_name == 'VGG16' or self.__model_name == 'vgg16':
187187
return tf.keras.applications.vgg16.preprocess_input(image), label
188-
elif self.__model_name == 'VGG19':
188+
elif self.__model_name == 'VGG19' or self.__model_name == 'vgg19':
189189
return tf.keras.applications.vgg19.preprocess_input(image), label
190-
elif self.__model_name == 'Xception':
190+
elif self.__model_name == 'Xception' or self.__model_name == 'xception':
191191
return tf.keras.applications.xception.preprocess_input(image), label
192+
elif 'RegNet' in self.__model_name or 'regnet' in self.__model_name:
193+
return tf.keras.applications.regnet.preprocess_input(image), label
192194

193195

194196
# supported use cases

requirements_cu11.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ scikit-learn>=0.23.1
44
tqdm>=4.32.2
55
h5py>=3.1.0
66
pybind11>=2.6.2
7-
tensorflow>=2.6.0
7+
tensorflow[and-cuda]>=2.6.0
88
pandas>=1.0.5
99
opencv-python>=4.4.0.46
1010
torch>=1.8.1

requirements_macos.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
urllib3<2.0
2+
Click>=7.0
3+
scikit-learn>=0.23.1
4+
tqdm>=4.32.2
5+
h5py>=3.1.0
6+
pybind11>=2.6.2
7+
tensorflow>=2.13.0
8+
tensorflow-metal>=1.0.0
9+
pandas>=1.0.5
10+
opencv-python>=4.4.0.46
11+
torch>=1.12.0
12+
torchvision>=0.13.1

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from setuptools.command.build_ext import build_ext
5050
import setuptools
5151

52-
__version__ = '0.2.2'
52+
__version__ = '0.3.0'
5353

5454

5555
class get_pybind_include(object):

0 commit comments

Comments
 (0)