forked from aws/sagemaker-tensorflow-extensions
-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_integ_test_docker_images.py
58 lines (49 loc) · 1.93 KB
/
create_integ_test_docker_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import absolute_import
import argparse
import base64
import subprocess
import docker
import boto3
import botocore
import glob
import sys
TF_VERSION = "2.14.0"
REGION = "us-west-2"
REPOSITORY_NAME = "sagemaker-tensorflow-extensions-test"


def _ecr_credentials(ecr_client):
    """Fetch ECR login credentials and the registry host.

    Returns:
        A ``(username, password, registry)`` tuple. ``registry`` is the
        ``proxyEndpoint`` with its URL scheme (e.g. ``https://``) removed,
        because Docker image references must not carry a scheme.

    Raises:
        ValueError: if the decoded authorization token has no ``:`` separator.
    """
    auth_data = ecr_client.get_authorization_token()['authorizationData'][0]
    decoded = base64.b64decode(auth_data['authorizationToken']).decode()
    # Split on the first ':' only, so a password that contains ':' stays intact.
    username, password = decoded.split(':', 1)
    # Strip the scheme explicitly instead of slicing a fixed 8 characters.
    registry = auth_data['proxyEndpoint'].split('://', 1)[-1]
    return username, password, registry


def _build_sdist():
    """Run ``setup.py sdist`` and return the path of the resulting archive.

    Raises:
        subprocess.CalledProcessError: if the sdist build fails.
        ValueError: if ``dist/`` does not contain exactly one matching sdist
            (e.g. stale archives from earlier builds).
    """
    subprocess.check_call([sys.executable, 'setup.py', 'sdist'])
    pattern = 'dist/sagemaker_tensorflow-{}*'.format(TF_VERSION)
    matches = glob.glob(pattern)
    if len(matches) != 1:
        raise ValueError('Expected exactly one sdist matching {!r}, found {}: {}'.format(
            pattern, len(matches), matches))
    return matches[0]


def _ensure_repository(ecr_client):
    """Create the ECR repository ``REPOSITORY_NAME``; an existing repository is not an error."""
    try:
        ecr_client.create_repository(repositoryName=REPOSITORY_NAME)
    except botocore.exceptions.ClientError as e:
        if e.response['Error']['Code'] != 'RepositoryAlreadyExistsException':
            raise


def _push(client, tag, auth_config):
    """Push ``tag`` to its registry, raising if the daemon reports a failure.

    docker-py's ``images.push`` reports push failures (auth, permissions, ...)
    as ``error`` entries in the response stream rather than raising, so the
    decoded stream is inspected entry by entry.

    Raises:
        RuntimeError: if any stream entry carries an ``error`` field.
    """
    for entry in client.images.push(tag, auth_config=auth_config, stream=True, decode=True):
        if 'error' in entry:
            detail = entry.get('errorDetail') or {}
            raise RuntimeError('Failed to push {}: {}'.format(tag, detail.get('message', entry['error'])))


def main():
    """Build the integration-test image, push it to ECR, and print its tag.

    Positional CLI argument ``device`` (default ``cpu``) is forwarded to the
    Dockerfile as a build arg and embedded in the image tag.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('device', nargs='?', default='cpu')
    args = parser.parse_args()

    client = docker.from_env()
    ecr_client = boto3.client('ecr', region_name=REGION)
    username, password, registry = _ecr_credentials(ecr_client)
    auth_config = {'username': username, 'password': password}

    sdist_path = _build_sdist()
    _ensure_repository(ecr_client)

    # NOTE(review): the tag suffix is the *host* interpreter's major version,
    # while the image is built against /usr/bin/python3 — confirm this is intended.
    python_version = str(sys.version_info[0])
    tag = '{}/{}:{}-{}-{}'.format(registry, REPOSITORY_NAME, TF_VERSION, args.device, python_version)

    # Pull the previously pushed image, if any, so its layers seed the build cache.
    try:
        client.images.pull(tag, auth_config=auth_config)
    except docker.errors.NotFound:
        pass

    client.images.build(
        path='.',
        dockerfile='test/integ/Dockerfile',
        tag=tag,
        cache_from=[tag],
        buildargs={'sagemaker_tensorflow': sdist_path,
                   'device': args.device,
                   'python': '/usr/bin/python3',
                   'tensorflow_version': TF_VERSION,
                   'script': 'test/integ/scripts/estimator_script.py'})
    _push(client, tag, auth_config)
    print(tag)


if __name__ == '__main__':
    main()