Skip to content

Commit

Permalink
Replace enums to be extendable (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
swashko authored May 23, 2024
1 parent 28e0e27 commit bc2bf6b
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 38 deletions.
1 change: 0 additions & 1 deletion modelscan/error.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from enum import Enum
from modelscan.model import Model
import abc
from pathlib import Path
Expand Down
12 changes: 7 additions & 5 deletions modelscan/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from collections import defaultdict

from modelscan.settings import Property

logger = logging.getLogger("modelscan")


Expand All @@ -16,8 +18,8 @@ class IssueSeverity(Enum):
CRITICAL = 4


class IssueCode(Enum):
UNSAFE_OPERATOR = 1
class IssueCode:
UNSAFE_OPERATOR = Property("UNSAFE_OPERATOR", 1)


class IssueDetails(metaclass=abc.ABCMeta):
Expand All @@ -40,14 +42,14 @@ class Issue:

def __init__(
self,
code: IssueCode,
code: Property,
severity: IssueSeverity,
details: IssueDetails,
) -> None:
"""
Create a issue with given information
:param code: Code of the issue from the issue code enum.
:param code: Code of the issue from the issue code class.
:param severity: The severity level of the issue from Severity enum.
:param details: An implementation of the IssueDetails object.
"""
Expand Down Expand Up @@ -82,7 +84,7 @@ def __hash__(self) -> int:

def print(self) -> None:
issue_description = self.code.name
if self.code == IssueCode.UNSAFE_OPERATOR:
if self.code.value == IssueCode.UNSAFE_OPERATOR.value:
issue_description = "Unsafe operator"
else:
logger.error("No issue description for issue code %s", self.code)
Expand Down
6 changes: 4 additions & 2 deletions modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats
from modelscan.settings import SupportedModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -28,7 +28,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if DefaultModelFormats.KERAS_H5 not in model.get_context("formats"):
if SupportedModelFormats.KERAS_H5.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

dep_error = self.handle_binary_dependencies()
Expand Down
6 changes: 4 additions & 2 deletions modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats
from modelscan.settings import SupportedModelFormats


logger = logging.getLogger("modelscan")


class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
def scan(self, model: Model) -> Optional[ScanResults]:
if DefaultModelFormats.KERAS not in model.get_context("formats"):
if SupportedModelFormats.KERAS.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

dep_error = self.handle_binary_dependencies()
Expand Down
14 changes: 10 additions & 4 deletions modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
scan_pytorch,
)
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats
from modelscan.settings import SupportedModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -19,7 +19,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if DefaultModelFormats.PYTORCH not in model.get_context("formats"):
if SupportedModelFormats.PYTORCH.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

if _is_zipfile(model.get_source(), model.get_stream()):
Expand All @@ -46,7 +48,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if DefaultModelFormats.NUMPY not in model.get_context("formats"):
if SupportedModelFormats.NUMPY.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

results = scan_numpy(
Expand All @@ -70,7 +74,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if DefaultModelFormats.PICKLE not in model.get_context("formats"):
if SupportedModelFormats.PICKLE.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

results = scan_pickle_bytes(
Expand Down
6 changes: 4 additions & 2 deletions modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails
from modelscan.scanners.scan import ScanBase, ScanResults
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats
from modelscan.settings import SupportedModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -32,7 +32,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if DefaultModelFormats.TENSORFLOW not in model.get_context("formats"):
if SupportedModelFormats.TENSORFLOW.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

dep_error = self.handle_binary_dependencies()
Expand Down
33 changes: 19 additions & 14 deletions modelscan/settings.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import tomlkit

from enum import Enum
from typing import Any

from modelscan._version import __version__


class DefaultModelFormats(Enum):
TENSORFLOW = "tensorflow"
KERAS_H5 = "keras_h5"
KERAS = "keras"
NUMPY = "numpy"
PYTORCH = "pytorch"
PICKLE = "pickle"
class Property:
def __init__(self, name: str, value: Any) -> None:
self.name = name
self.value = value


class SupportedModelFormats:
TENSORFLOW = Property("TENSORFLOW", "tensorflow")
KERAS_H5 = Property("KERAS_H5", "keras_h5")
KERAS = Property("KERAS", "keras")
NUMPY = Property("NUMPY", "numpy")
PYTORCH = Property("PYTORCH", "pytorch")
PICKLE = Property("PICKLE", "pickle")


DEFAULT_REPORTING_MODULES = {
Expand Down Expand Up @@ -70,12 +75,12 @@ class DefaultModelFormats(Enum):
"middlewares": {
"modelscan.middlewares.FormatViaExtensionMiddleware": {
"formats": {
DefaultModelFormats.TENSORFLOW: [".pb"],
DefaultModelFormats.KERAS_H5: [".h5"],
DefaultModelFormats.KERAS: [".keras"],
DefaultModelFormats.NUMPY: [".npy"],
DefaultModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"],
DefaultModelFormats.PICKLE: [
SupportedModelFormats.TENSORFLOW: [".pb"],
SupportedModelFormats.KERAS_H5: [".h5"],
SupportedModelFormats.KERAS: [".keras"],
SupportedModelFormats.NUMPY: [".npy"],
SupportedModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"],
SupportedModelFormats.PICKLE: [
".pkl",
".pickle",
".joblib",
Expand Down
17 changes: 9 additions & 8 deletions modelscan/skip.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import logging
from enum import Enum

from modelscan.settings import Property

logger = logging.getLogger("modelscan")


class SkipCategories(Enum):
SCAN_NOT_SUPPORTED = 1
BAD_ZIP = 2
MODEL_CONFIG = 3
H5_DATA = 4
NOT_IMPLEMENTED = 5
MAGIC_NUMBER = 6
class SkipCategories:
SCAN_NOT_SUPPORTED = Property("SCAN_NOT_SUPPORTED", 1)
BAD_ZIP = Property("BAD_ZIP", 2)
MODEL_CONFIG = Property("MODEL_CONFIG", 3)
H5_DATA = Property("H5_DATA", 4)
NOT_IMPLEMENTED = Property("NOT_IMPLEMENTED", 5)
MAGIC_NUMBER = Property("MAGIC_NUMBER", 6)


class Skip:
Expand All @@ -31,7 +32,7 @@ class ModelScanSkipped:
def __init__(
self,
scan_name: str,
category: SkipCategories,
category: Property,
message: str,
source: str,
) -> None:
Expand Down

0 comments on commit bc2bf6b

Please sign in to comment.