-
-
Notifications
You must be signed in to change notification settings - Fork 620
/
test_image.py
65 lines (53 loc) · 2.18 KB
/
test_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
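#
# Tests:
# For all images:
#   - torch can be imported and its version matches the one in the image tag
#   - ignite can be imported and its version matches the one in the image tag
# For all "vision" images:
#   - opencv (cv2) can be imported without driver issues
# For all "hvd" (Horovod) images:
#   - horovod can be imported and its version matches HVD_VERSION
# For all "msdp" (DeepSpeed) images:
#   - deepspeed can be imported and its version matches MSDP_VERSION
# For all "nlp" images:
#   - transformers can be imported
# For all "apex" images:
#   - apex can be imported
#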
import argparse
import importlib
import os
def check_package(package_name, expected_version=None):
    """Import *package_name* and optionally verify its ``__version__``.

    Any local build suffix (everything from the first ``+``, e.g. the
    ``+cu124`` in ``2.5.1+cu124``) is stripped before comparing.

    Args:
        package_name: importable module name, e.g. ``"torch"`` or ``"cv2"``.
        expected_version: exact version string to compare against; when
            ``None`` only the import itself is checked.

    Returns:
        The imported module.

    Raises:
        ImportError: if the package cannot be imported.
        AssertionError: if *expected_version* is given and the package has no
            ``__version__`` attribute or its version does not match.
    """
    mod = importlib.import_module(package_name)
    if expected_version is None:
        return mod

    # Explicit raises instead of ``assert`` so the checks still run under ``python -O``.
    if not hasattr(mod, "__version__"):
        raise AssertionError(f"Imported package {package_name} does not have __version__ attribute")

    version = mod.__version__
    # Remove all +something from the version name: e.g torch 2.5.1+cu124
    if "+" in version:
        old_version = version
        version = version.split("+")[0]
        print(f"Transformed version: {old_version} -> {version}")

    if version != expected_version:
        raise AssertionError(
            f"Version mismatch for package {package_name}: got {version} but expected {expected_version}"
        )
    return mod
if __name__ == "__main__":
parser = argparse.ArgumentParser("Check docker image script")
parser.add_argument("image", type=str, help="Docker image to check")
args = parser.parse_args()
docker_image_name = args.image
name, version = docker_image_name.split(":")
assert version != "latest", version
torch_version, ignite_version = version.split("-")
_, image_type = name.split("/")
check_package("torch", expected_version=torch_version)
check_package("ignite", expected_version=ignite_version)
if "hvd" in image_type:
assert "HVD_VERSION" in os.environ
val = os.environ["HVD_VERSION"]
hvd_version = val if val[0] != "v" else val[1:]
check_package("horovod", expected_version=hvd_version)
if "msdp" in image_type:
assert "MSDP_VERSION" in os.environ
val = os.environ["MSDP_VERSION"]
hvd_version = val if val[0] != "v" else val[1:]
check_package("deepspeed", expected_version=hvd_version)
if "vision" in image_type:
check_package("cv2")
if "nlp" in image_type:
check_package("transformers")
if "apex" in image_type:
check_package("apex")