-
Notifications
You must be signed in to change notification settings - Fork 20
/
pickle_scan.py
79 lines (68 loc) · 2.26 KB
/
pickle_scan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# copyright zxix 2022
# https://creativecommons.org/licenses/by-nc-sa/4.0/
import torch
import pickle_inspector
import sys
from pathlib import Path
# File suffixes that are scanned (compared case-insensitively).
EXTENSIONS = {'.pt', '.bin', '.ckpt'}
# Module names whose "<name>." prefix on a recorded call is reported as a
# library call. Tuples rather than sets: they are iterated for matching and
# reporting, and set order changes between runs under hash randomisation.
BAD_CALLS = ('os', 'shutil', 'sys', 'requests', 'net')
# Substrings reported as a malicious signal wherever they occur in a call.
# When a call contains several, only the first one listed here is counted.
BAD_SIGNAL = ('rm ', 'cat ', 'nc ', '/bin/sh ')
# Prefixes of calls that are NOT reported as non-standard.
ALLOWED_PREFIXES = ('numpy.', '_codecs.', 'collections.', 'torch.')


def tally_calls(calls):
    """Classify recorded call strings and count the suspicious ones.

    Each call is checked three ways, independently of one another:
      * does it start with "<name>." for a name in BAD_CALLS,
      * does it contain a BAD_SIGNAL substring anywhere,
      * does it start with none of ALLOWED_PREFIXES (non-standard).
    A single call may therefore be counted in more than one category (for
    example "os.system" is both a library call and non-standard).

    Args:
        calls: iterable of call strings.

    Returns:
        A tuple (call_counts, signal_counts, other, total, report):
        call_counts maps each BAD_CALLS name to its hit count,
        signal_counts maps each BAD_SIGNAL substring to its hit count,
        other is the number of non-standard calls, total is the sum of all
        of the above, and report is the human-readable detail text.
    """
    call_counts = dict.fromkeys(BAD_CALLS, 0)
    signal_counts = dict.fromkeys(BAD_SIGNAL, 0)
    other = 0
    sections = []  # joined once at the end instead of repeated str +=

    for c in calls:
        for call in BAD_CALLS:
            if c.startswith(call + "."):
                call_counts[call] += 1
                sections.append("\n--- found lib call (" + call + ") ---\n")
                sections.append(c)
                sections.append("\n---------------\n")
                break
        for signal in BAD_SIGNAL:
            if signal in c:
                signal_counts[signal] += 1
                sections.append("\n--- found malicious signal (" + signal + ") ---\n")
                sections.append(c)
                sections.append("\n---------------\n")
                break
        if not c.startswith(ALLOWED_PREFIXES):
            other += 1
            sections.append("\n--- found non-standard lib call ---\n")
            sections.append(c)
            sections.append("\n---------------\n")

    # One count per category hit, so total equals the sum of the printed lines.
    total = sum(call_counts.values()) + sum(signal_counts.values()) + other
    return call_counts, signal_counts, other, total, "".join(sections)


def scan_file(path, debug=False):
    """Load one checkpoint through pickle_inspector and print its verdict.

    Args:
        path: pathlib.Path of the file to scan.
        debug: when true, also print the detail text for every finding.

    Returns:
        True when nothing was flagged ("SCAN PASSED!"), False otherwise.
    """
    print("")
    print("..." + path.as_posix())
    # NOTE(security): this unpickles a file that may be hostile. Safety rests
    # entirely on pickle_inspector.pickle; presumably it records calls rather
    # than executing them -- verify against that module before trusting this.
    result = torch.load(path.as_posix(), pickle_module=pickle_inspector.pickle)
    call_counts, signal_counts, other, total, report = tally_calls(result.calls)

    if total > 0:
        for call in BAD_CALLS:
            print("library call (" + call + ".): " + str(call_counts[call]))
        for signal in BAD_SIGNAL:
            print("malicious signal (" + signal + "): " + str(signal_counts[signal]))
        print("non-standard calls: " + str(other))
        print("total: " + str(total))
        print("")
        print("SCAN FAILED")
        if debug:
            print(report)
        return False

    print("SCAN PASSED!")
    return True


def main(argv=None):
    """Scan every checkpoint file found under the directory given in argv[1].

    Args:
        argv: argument list shaped like sys.argv; defaults to sys.argv.
            argv[1] is the directory to walk recursively. Supplying exactly
            one extra argument (any value) turns on debug output.
    """
    argv = sys.argv if argv is None else argv
    sys.stdout.reconfigure(encoding='utf-8')
    if len(argv) < 2:
        # Previously this surfaced as a bare IndexError traceback.
        sys.exit("usage: python pickle_scan.py <directory> [debug]")

    debug = len(argv) == 3
    target = argv[1]
    print("checking dir: " + target)

    for path in Path(target).glob('**/*'):
        # Lower-case the suffix so e.g. "model.CKPT" is not silently skipped,
        # and ignore directories that merely end in a scanned suffix.
        if path.suffix.lower() in EXTENSIONS and path.is_file():
            scan_file(path, debug)


if __name__ == "__main__":
    main()