From 28e0e2705b9d03798a7c8c1e8fc8e92d298525af Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Thu, 23 May 2024 10:16:45 -0700 Subject: [PATCH] Standardize model format names (#150) --- modelscan/scanners/h5/scan.py | 3 ++- modelscan/scanners/keras/scan.py | 3 ++- modelscan/scanners/pickle/scan.py | 7 ++++--- modelscan/scanners/saved_model/scan.py | 3 ++- modelscan/settings.py | 24 +++++++++++++++++------- 5 files changed, 27 insertions(+), 13 deletions(-) diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index c398535..af2da51 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -18,6 +18,7 @@ from modelscan.scanners.scan import ScanResults from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan from modelscan.model import Model +from modelscan.settings import DefaultModelFormats logger = logging.getLogger("modelscan") @@ -27,7 +28,7 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if "keras_h5" not in model.get_context("formats"): + if DefaultModelFormats.KERAS_H5 not in model.get_context("formats"): return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index 2a7fb5e..c4a7727 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -9,6 +9,7 @@ from modelscan.scanners.scan import ScanResults from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan from modelscan.model import Model +from modelscan.settings import DefaultModelFormats logger = logging.getLogger("modelscan") @@ -16,7 +17,7 @@ class KerasLambdaDetectScan(SavedModelLambdaDetectScan): def scan(self, model: Model) -> Optional[ScanResults]: - if "keras" not in model.get_context("formats"): + if DefaultModelFormats.KERAS not in model.get_context("formats"): return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index d138202..77c9445 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -9,6 +9,7 @@ scan_pytorch, ) from modelscan.model import Model +from modelscan.settings import DefaultModelFormats logger = logging.getLogger("modelscan") @@ -18,7 +19,7 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if "pytorch" not in model.get_context("formats"): + if DefaultModelFormats.PYTORCH not in model.get_context("formats"): return None if _is_zipfile(model.get_source(), model.get_stream()): @@ -45,7 +46,7 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if "numpy" not in model.get_context("formats"): + if DefaultModelFormats.NUMPY not in model.get_context("formats"): return None results = scan_numpy( @@ -69,7 +70,7 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if "pickle" not in model.get_context("formats"): + if DefaultModelFormats.PICKLE not in model.get_context("formats"): return None results = scan_pickle_bytes( diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index 74e8fb8..d7af13e 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -22,6 +22,7 @@ from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails from modelscan.scanners.scan import ScanBase, ScanResults from modelscan.model import Model +from modelscan.settings import DefaultModelFormats logger = logging.getLogger("modelscan") @@ -31,7 +32,7 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if "tf_saved_model" not in model.get_context("formats"): + if DefaultModelFormats.TENSORFLOW not in model.get_context("formats"): return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/settings.py b/modelscan/settings.py index 5f4e6ed..21500cb 100644 --- a/modelscan/settings.py +++ b/modelscan/settings.py @@ -1,9 +1,20 @@ import tomlkit +from enum import Enum from typing import Any from modelscan._version import __version__ + +class DefaultModelFormats(Enum): + TENSORFLOW = "tensorflow" + KERAS_H5 = "keras_h5" + KERAS = "keras" + NUMPY = "numpy" + PYTORCH = "pytorch" + PICKLE = "pickle" + + DEFAULT_REPORTING_MODULES = { "console": "modelscan.reports.ConsoleReport", "json": "modelscan.reports.JSONReport", @@ -59,13 +70,12 @@ "middlewares": { "modelscan.middlewares.FormatViaExtensionMiddleware": { "formats": { - "tf": [".pb"], - "tf_saved_model": [".pb"], - "keras_h5": [".h5"], - "keras": [".keras"], - "numpy": [".npy"], - "pytorch": [".bin", ".pt", ".pth", ".ckpt"], - "pickle": [ + DefaultModelFormats.TENSORFLOW: [".pb"], + DefaultModelFormats.KERAS_H5: [".h5"], + DefaultModelFormats.KERAS: [".keras"], + DefaultModelFormats.NUMPY: [".npy"], + DefaultModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"], + DefaultModelFormats.PICKLE: [ ".pkl", ".pickle", ".joblib",