Skip to content

Commit

Permalink
Fix newline terminator parsing (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanpmorgan authored Mar 29, 2024
1 parent e21035e commit 15df39e
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 5 deletions.
28 changes: 23 additions & 5 deletions modelscan/tools/picklescanner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import pickletools # nosec
from tarfile import TarError
from typing import IO, Any, Dict, List, Set, Tuple, Union
from typing import IO, Any, Dict, List, Set, Tuple, Union, Optional

import numpy as np

Expand All @@ -17,15 +17,15 @@


class GenOpsError(Exception):
    """Error raised when walking a pickle stream's opcodes fails.

    Besides the human-readable message, the error can carry the set of
    globals that were collected before the failure, so that a caller can
    still act on them.
    """

    def __init__(
        self, msg: str, globals: Optional[Set[Tuple[str, str]]] = None
    ) -> None:
        """Create the error.

        Args:
            msg: Description of the failure; returned by ``str(error)``.
            globals: Pairs of strings gathered before the failure, or
                ``None`` when nothing was gathered. Defaults to ``None`` so
                that one-argument construction keeps working. (The name
                shadows the builtin but is part of the public signature.)
        """
        self.msg = msg
        self.globals = globals
        # Forward the message so ``args`` is populated; an empty ``args``
        # makes copy/pickle reconstruction call ``GenOpsError()`` and fail.
        super().__init__(msg)

    def __str__(self) -> str:
        """Return the failure message."""
        return self.msg


#
# TODO: handle methods loading other Pickle files (either mark as suspicious, or follow calls to scan other files [preventing infinite loops])
#
# pickle.loads()
Expand Down Expand Up @@ -62,7 +62,11 @@ def _list_globals(
pickletools.genops(data)
)
except Exception as e:
raise GenOpsError(str(e))
# Given we can have multiple pickles in a file, we may have already successfully extracted globals from a valid pickle.
# Thus return the already found globals in the error & let the caller decide what to do.
globals_opt = globals if len(globals) > 0 else None
raise GenOpsError(str(e), globals_opt)

last_byte = data.read(1)
data.seek(-1, 1)

Expand Down Expand Up @@ -126,6 +130,12 @@ def scan_pickle_bytes(
try:
raw_globals = _list_globals(model.get_stream(), multiple_pickles)
except GenOpsError as e:
if e.globals is not None:
return _build_scan_result_from_raw_globals(
e.globals,
model,
settings,
)
return ScanResults(
issues,
[
Expand All @@ -138,8 +148,16 @@ def scan_pickle_bytes(
],
[],
)
logger.debug("Global imports in %s: %s", model, raw_globals, settings)
return _build_scan_result_from_raw_globals(raw_globals, model, settings)


logger.debug("Global imports in %s: %s", model.get_source(), raw_globals)
def _build_scan_result_from_raw_globals(
raw_globals: Set[Tuple[str, str]],
model: Model,
settings: Dict[str, Any],
) -> ScanResults:
issues: List[Issue] = []
severities = {
"CRITICAL": IssueSeverity.CRITICAL,
"HIGH": IssueSeverity.HIGH,
Expand Down
59 changes: 59 additions & 0 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,36 @@ def malicious13_gen() -> bytes:
return p


def malicious14_gen() -> bytes:
    """Assemble a raw pickle payload that has extra bytes after ``STOP``.

    The stream stores four strings in memo slots, reloads slots 2 and 3
    ("os" and "system"), emits ``STACK_GLOBAL``, builds a one-element tuple
    holding a command string, emits ``REDUCE`` and ``STOP``, and finally
    appends newline and tab bytes after the terminator.

    Returns:
        The assembled payload as ``bytes``.
    """

    def memoize(text: bytes, slot: bytes) -> bytes:
        # Push a unicode string, store it in the given memo slot, pop it.
        return pickle.UNICODE + text + pickle.PUT + slot + pickle.POP

    payload = bytearray()

    # Memo layout: 2 -> "os", 3 -> "system", 0 -> "torch", 1 -> "LongStorage".
    for text, slot in (
        (b"os\n", b"2\n"),
        (b"system\n", b"3\n"),
        (b"torch\n", b"0\n"),
        (b"LongStorage\n", b"1\n"),
    ):
        payload += memoize(text, slot)

    # Reload the module and attribute names, then resolve them as a global.
    payload += pickle.GET + b"2\n"
    payload += pickle.GET + b"3\n"
    payload += pickle.STACK_GLOBAL

    # Argument tuple followed by the call and the stream terminator.
    payload += pickle.MARK
    payload += pickle.UNICODE + b"cat flag.txt\n"
    payload += pickle.TUPLE
    payload += pickle.REDUCE
    payload += pickle.STOP

    # Bytes deliberately placed after STOP.
    payload += b"\n\n\t\t"

    return bytes(payload)


def initialize_pickle_file(path: str, obj: Any, version: int) -> None:
if not os.path.exists(path):
with open(path, "wb") as file:
Expand Down Expand Up @@ -288,6 +318,8 @@ def file_path(tmp_path_factory: Any) -> Any:

initialize_data_file(f"{tmp}/data/malicious13.pkl", malicious13_gen())

initialize_data_file(f"{tmp}/data/malicious14.pkl", malicious14_gen())

return tmp


Expand Down Expand Up @@ -950,6 +982,22 @@ def test_scan_pickle_operators(file_path: Any) -> None:
malicious13.scan(Path(f"{file_path}/data/malicious13.pkl"))
assert malicious13.issues.all_issues == expected_malicious13

expected_malicious14 = [
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.CRITICAL,
OperatorIssueDetails(
"os",
"system",
IssueSeverity.CRITICAL,
f"{file_path}/data/malicious14.pkl",
),
)
]
malicious14 = ModelScan()
malicious14.scan(Path(f"{file_path}/data/malicious14.pkl"))
assert malicious14.issues.all_issues == expected_malicious14


def test_scan_directory_path(file_path: str) -> None:
expected = {
Expand Down Expand Up @@ -1204,6 +1252,16 @@ def test_scan_directory_path(file_path: str) -> None:
f"{file_path}/data/malicious13.pkl",
),
),
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.CRITICAL,
OperatorIssueDetails(
"os",
"system",
IssueSeverity.CRITICAL,
f"{file_path}/data/malicious14.pkl",
),
),
}
ms = ModelScan()
p = Path(f"{file_path}/data/")
Expand All @@ -1221,6 +1279,7 @@ def test_scan_directory_path(file_path: str) -> None:
f"malicious11.pkl",
f"malicious12.pkl",
f"malicious13.pkl",
f"malicious14.pkl",
f"malicious1_v0.dill",
f"malicious1_v3.dill",
f"malicious1_v4.dill",
Expand Down

0 comments on commit 15df39e

Please sign in to comment.