Skip to content

Commit

Permalink
Standardize model format names (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
swashko authored May 23, 2024
1 parent 02fdd4a commit 28e0e27
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 13 deletions.
3 changes: 2 additions & 1 deletion modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -27,7 +28,7 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if "keras_h5" not in model.get_context("formats"):
if DefaultModelFormats.KERAS_H5 not in model.get_context("formats"):
return None

dep_error = self.handle_binary_dependencies()
Expand Down
3 changes: 2 additions & 1 deletion modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats


logger = logging.getLogger("modelscan")


class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
def scan(self, model: Model) -> Optional[ScanResults]:
if "keras" not in model.get_context("formats"):
if DefaultModelFormats.KERAS not in model.get_context("formats"):
return None

dep_error = self.handle_binary_dependencies()
Expand Down
7 changes: 4 additions & 3 deletions modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
scan_pytorch,
)
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -18,7 +19,7 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if "pytorch" not in model.get_context("formats"):
if DefaultModelFormats.PYTORCH not in model.get_context("formats"):
return None

if _is_zipfile(model.get_source(), model.get_stream()):
Expand All @@ -45,7 +46,7 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if "numpy" not in model.get_context("formats"):
if DefaultModelFormats.NUMPY not in model.get_context("formats"):
return None

results = scan_numpy(
Expand All @@ -69,7 +70,7 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if "pickle" not in model.get_context("formats"):
if DefaultModelFormats.PICKLE not in model.get_context("formats"):
return None

results = scan_pickle_bytes(
Expand Down
3 changes: 2 additions & 1 deletion modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails
from modelscan.scanners.scan import ScanBase, ScanResults
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -31,7 +32,7 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if "tf_saved_model" not in model.get_context("formats"):
if DefaultModelFormats.TENSORFLOW not in model.get_context("formats"):
return None

dep_error = self.handle_binary_dependencies()
Expand Down
24 changes: 17 additions & 7 deletions modelscan/settings.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import tomlkit

from enum import Enum
from typing import Any

from modelscan._version import __version__


class DefaultModelFormats(Enum):
TENSORFLOW = "tensorflow"
KERAS_H5 = "keras_h5"
KERAS = "keras"
NUMPY = "numpy"
PYTORCH = "pytorch"
PICKLE = "pickle"


DEFAULT_REPORTING_MODULES = {
"console": "modelscan.reports.ConsoleReport",
"json": "modelscan.reports.JSONReport",
Expand Down Expand Up @@ -59,13 +70,12 @@
"middlewares": {
"modelscan.middlewares.FormatViaExtensionMiddleware": {
"formats": {
"tf": [".pb"],
"tf_saved_model": [".pb"],
"keras_h5": [".h5"],
"keras": [".keras"],
"numpy": [".npy"],
"pytorch": [".bin", ".pt", ".pth", ".ckpt"],
"pickle": [
DefaultModelFormats.TENSORFLOW: [".pb"],
DefaultModelFormats.KERAS_H5: [".h5"],
DefaultModelFormats.KERAS: [".keras"],
DefaultModelFormats.NUMPY: [".npy"],
DefaultModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"],
DefaultModelFormats.PICKLE: [
".pkl",
".pickle",
".joblib",
Expand Down

0 comments on commit 28e0e27

Please sign in to comment.