From bc2bf6b17955bf93256c656b181824069afcaade Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Thu, 23 May 2024 15:55:11 -0700 Subject: [PATCH] Replace enums to be extendable (#151) --- modelscan/error.py | 1 - modelscan/issues.py | 12 ++++++---- modelscan/scanners/h5/scan.py | 6 +++-- modelscan/scanners/keras/scan.py | 6 +++-- modelscan/scanners/pickle/scan.py | 14 +++++++---- modelscan/scanners/saved_model/scan.py | 6 +++-- modelscan/settings.py | 33 +++++++++++++++----------- modelscan/skip.py | 17 ++++++------- 8 files changed, 57 insertions(+), 38 deletions(-) diff --git a/modelscan/error.py b/modelscan/error.py index 7e227c4..d471169 100644 --- a/modelscan/error.py +++ b/modelscan/error.py @@ -1,4 +1,3 @@ -from enum import Enum from modelscan.model import Model import abc from pathlib import Path diff --git a/modelscan/issues.py b/modelscan/issues.py index 130318d..16bfb51 100644 --- a/modelscan/issues.py +++ b/modelscan/issues.py @@ -6,6 +6,8 @@ from collections import defaultdict +from modelscan.settings import Property + logger = logging.getLogger("modelscan") @@ -16,8 +18,8 @@ class IssueSeverity(Enum): CRITICAL = 4 -class IssueCode(Enum): - UNSAFE_OPERATOR = 1 +class IssueCode: + UNSAFE_OPERATOR = Property("UNSAFE_OPERATOR", 1) class IssueDetails(metaclass=abc.ABCMeta): @@ -40,14 +42,14 @@ class Issue: def __init__( self, - code: IssueCode, + code: Property, severity: IssueSeverity, details: IssueDetails, ) -> None: """ Create a issue with given information - :param code: Code of the issue from the issue code enum. + :param code: Code of the issue from the issue code class. :param severity: The severity level of the issue from Severity enum. :param details: An implementation of the IssueDetails object. """ @@ -82,7 +84,7 @@ def __hash__(self) -> int: def print(self) -> None: issue_description = self.code.name - if self.code == IssueCode.UNSAFE_OPERATOR: + if self.code.value == IssueCode.UNSAFE_OPERATOR.value: issue_description = "Unsafe operator" else: logger.error("No issue description for issue code %s", self.code) diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index af2da51..bd088f6 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -18,7 +18,7 @@ from modelscan.scanners.scan import ScanResults from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan from modelscan.model import Model -from modelscan.settings import DefaultModelFormats +from modelscan.settings import SupportedModelFormats logger = logging.getLogger("modelscan") @@ -28,7 +28,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if DefaultModelFormats.KERAS_H5 not in model.get_context("formats"): + if SupportedModelFormats.KERAS_H5.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index c4a7727..1e88c38 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -9,7 +9,7 @@ from modelscan.scanners.scan import ScanResults from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan from modelscan.model import Model -from modelscan.settings import DefaultModelFormats +from modelscan.settings import SupportedModelFormats logger = logging.getLogger("modelscan") @@ -17,7 +17,9 @@ class KerasLambdaDetectScan(SavedModelLambdaDetectScan): def scan(self, model: Model) -> Optional[ScanResults]: - if DefaultModelFormats.KERAS not in model.get_context("formats"): + if SupportedModelFormats.KERAS.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index 77c9445..3ece571 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -9,7 +9,7 @@ scan_pytorch, ) from modelscan.model import Model -from modelscan.settings import DefaultModelFormats +from modelscan.settings import SupportedModelFormats logger = logging.getLogger("modelscan") @@ -19,7 +19,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if DefaultModelFormats.PYTORCH not in model.get_context("formats"): + if SupportedModelFormats.PYTORCH.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None if _is_zipfile(model.get_source(), model.get_stream()): @@ -46,7 +48,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if DefaultModelFormats.NUMPY not in model.get_context("formats"): + if SupportedModelFormats.NUMPY.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None results = scan_numpy( @@ -70,7 +74,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if DefaultModelFormats.PICKLE not in model.get_context("formats"): + if SupportedModelFormats.PICKLE.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None results = scan_pickle_bytes( diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index d7af13e..4c8f6f6 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -22,7 +22,7 @@ from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails from modelscan.scanners.scan import ScanBase, ScanResults from modelscan.model import Model -from modelscan.settings import DefaultModelFormats +from modelscan.settings import SupportedModelFormats logger = logging.getLogger("modelscan") @@ -32,7 +32,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if DefaultModelFormats.TENSORFLOW not in model.get_context("formats"): + if SupportedModelFormats.TENSORFLOW.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/settings.py b/modelscan/settings.py index 21500cb..395dfbe 100644 --- a/modelscan/settings.py +++ b/modelscan/settings.py @@ -1,18 +1,23 @@ import tomlkit -from enum import Enum from typing import Any from modelscan._version import __version__ -class DefaultModelFormats(Enum): - TENSORFLOW = "tensorflow" - KERAS_H5 = "keras_h5" - KERAS = "keras" - NUMPY = "numpy" - PYTORCH = "pytorch" - PICKLE = "pickle" +class Property: + def __init__(self, name: str, value: Any) -> None: + self.name = name + self.value = value + + +class SupportedModelFormats: + TENSORFLOW = Property("TENSORFLOW", "tensorflow") + KERAS_H5 = Property("KERAS_H5", "keras_h5") + KERAS = Property("KERAS", "keras") + NUMPY = Property("NUMPY", "numpy") + PYTORCH = Property("PYTORCH", "pytorch") + PICKLE = Property("PICKLE", "pickle") DEFAULT_REPORTING_MODULES = { @@ -70,12 +75,12 @@ class DefaultModelFormats(Enum): "middlewares": { "modelscan.middlewares.FormatViaExtensionMiddleware": { "formats": { - DefaultModelFormats.TENSORFLOW: [".pb"], - DefaultModelFormats.KERAS_H5: [".h5"], - DefaultModelFormats.KERAS: [".keras"], - DefaultModelFormats.NUMPY: [".npy"], - DefaultModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"], - DefaultModelFormats.PICKLE: [ + SupportedModelFormats.TENSORFLOW: [".pb"], + SupportedModelFormats.KERAS_H5: [".h5"], + SupportedModelFormats.KERAS: [".keras"], + SupportedModelFormats.NUMPY: [".npy"], + SupportedModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"], + SupportedModelFormats.PICKLE: [ ".pkl", ".pickle", ".joblib", diff --git a/modelscan/skip.py b/modelscan/skip.py index 27272e6..2f83b75 100644 --- a/modelscan/skip.py +++ b/modelscan/skip.py @@ -1,17 +1,18 @@ import logging from enum import Enum +from modelscan.settings import Property logger = logging.getLogger("modelscan") -class SkipCategories(Enum): - SCAN_NOT_SUPPORTED = 1 - BAD_ZIP = 2 - MODEL_CONFIG = 3 - H5_DATA = 4 - NOT_IMPLEMENTED = 5 - MAGIC_NUMBER = 6 +class SkipCategories: + SCAN_NOT_SUPPORTED = Property("SCAN_NOT_SUPPORTED", 1) + BAD_ZIP = Property("BAD_ZIP", 2) + MODEL_CONFIG = Property("MODEL_CONFIG", 3) + H5_DATA = Property("H5_DATA", 4) + NOT_IMPLEMENTED = Property("NOT_IMPLEMENTED", 5) + MAGIC_NUMBER = Property("MAGIC_NUMBER", 6) class Skip: @@ -31,7 +32,7 @@ class ModelScanSkipped: def __init__( self, scan_name: str, - category: SkipCategories, + category: Property, message: str, source: str, ) -> None: