Skip to content

Commit

Permalink
add convnext
Browse files Browse the repository at this point in the history
  • Loading branch information
nmcardoso committed Aug 17, 2023
1 parent ac5b82b commit 8845a49
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions mergernet/estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,13 @@ def get_conv_arch(self, pretrained_arch: str) -> Tuple[Callable, Callable]:
elif pretrained_arch == 'efficientnetb7':
preprocess_input = tf.keras.applications.efficientnet.preprocess_input
base_model = tf.keras.applications.EfficientNetB7
elif pretrained_arch == 'convnext_tiny':
preprocess_input = tf.keras.applications.convnext.preprocess_input
base_model = tf.keras.applications.ConvNeXtTiny
elif pretrained_arch == 'convnext_small':
preprocess_input = tf.keras.applications.convnext.preprocess_input
base_model = tf.keras.applications.ConvNeXtSmall
elif pretrained_arch == 'convnext_base':
preprocess_input = tf.keras.applications.convnext.preprocess_input
base_model = tf.keras.applications.ConvNeXtBase
return base_model, preprocess_input

0 comments on commit 8845a49

Please sign in to comment.