|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +import sys |
| 4 | +import argparse |
| 5 | +import os |
| 6 | +import socket |
| 7 | +import ssl |
| 8 | +import io |
| 9 | +import zipfile |
| 10 | +from getpass import getpass |
| 11 | + |
| 12 | + |
| 13 | +def get_artiq_cert(): |
| 14 | + try: |
| 15 | + import artiq |
| 16 | + except ImportError: |
| 17 | + return None |
| 18 | + filename = os.path.join(os.path.dirname(artiq.__file__), "afws.pem") |
| 19 | + if not os.path.isfile(filename): |
| 20 | + return None |
| 21 | + return filename |
| 22 | + |
| 23 | + |
| 24 | +def get_artiq_rev(): |
| 25 | + try: |
| 26 | + import artiq |
| 27 | + except ImportError: |
| 28 | + return None |
| 29 | + version = artiq.__version__ |
| 30 | + if version.endswith(".beta"): |
| 31 | + version = version[:-5] |
| 32 | + version = version.split(".") |
| 33 | + if len(version) != 3: |
| 34 | + return None |
| 35 | + major, minor, rev = version |
| 36 | + return rev |
| 37 | + |
| 38 | + |
| 39 | +def zip_unarchive(data, directory): |
| 40 | + buf = io.BytesIO(data) |
| 41 | + with zipfile.ZipFile(buf) as archive: |
| 42 | + archive.extractall(directory) |
| 43 | + |
| 44 | + |
| 45 | +class Client: |
| 46 | + def __init__(self, server, port, cafile): |
| 47 | + self.ssl_context = ssl.create_default_context(cafile=cafile) |
| 48 | + self.raw_socket = socket.create_connection((server, port)) |
| 49 | + try: |
| 50 | + self.socket = self.ssl_context.wrap_socket(self.raw_socket, server_hostname=server) |
| 51 | + except: |
| 52 | + self.raw_socket.close() |
| 53 | + raise |
| 54 | + self.fsocket = self.socket.makefile("rwb") |
| 55 | + |
| 56 | + def close(self): |
| 57 | + self.socket.close() |
| 58 | + self.raw_socket.close() |
| 59 | + |
| 60 | + def send_command(self, *command): |
| 61 | + self.fsocket.write((" ".join(command) + "\n").encode()) |
| 62 | + self.fsocket.flush() |
| 63 | + |
| 64 | + def read_reply(self): |
| 65 | + return self.fsocket.readline().decode("ascii").split() |
| 66 | + |
| 67 | + def login(self, username, password): |
| 68 | + self.send_command("LOGIN", username, password) |
| 69 | + return self.read_reply() == ["HELLO"] |
| 70 | + |
| 71 | + def build(self, rev, variant): |
| 72 | + self.send_command("BUILD", rev, variant) |
| 73 | + reply = self.read_reply()[0] |
| 74 | + if reply != "BUILDING": |
| 75 | + return reply, None |
| 76 | + print("Build in progress. This may take 10-15 minutes.") |
| 77 | + reply, status = self.read_reply() |
| 78 | + if reply != "DONE": |
| 79 | + raise ValueError("Unexpected server reply: expected 'DONE', got '{}'".format(reply)) |
| 80 | + if status != "done": |
| 81 | + return status, None |
| 82 | + print("Build completed. Downloading...") |
| 83 | + reply, length = self.read_reply() |
| 84 | + if reply != "PRODUCT": |
| 85 | + raise ValueError("Unexpected server reply: expected 'PRODUCT', got '{}'".format(reply)) |
| 86 | + contents = self.fsocket.read(int(length)) |
| 87 | + print("Download completed.") |
| 88 | + return "OK", contents |
| 89 | + |
| 90 | + def passwd(self, password): |
| 91 | + self.send_command("PASSWD", password) |
| 92 | + return self.read_reply() == ["OK"] |
| 93 | + |
| 94 | + |
| 95 | +def main(): |
| 96 | + parser = argparse.ArgumentParser() |
| 97 | + parser.add_argument("--server", default="nixbld.m-labs.hk", help="server to connect to (default: %(default)s)") |
| 98 | + parser.add_argument("--port", default=7402, type=int, help="port to connect to (default: %(default)d)") |
| 99 | + parser.add_argument("--cert", default=None, help="SSL certificate file used to authenticate server (default: afws.pem in ARTIQ)") |
| 100 | + parser.add_argument("username", help="user name for logging into AFWS") |
| 101 | + action = parser.add_subparsers(dest="action") |
| 102 | + action.required = True |
| 103 | + act_build = action.add_parser("build", help="build and download firmware") |
| 104 | + act_build.add_argument("--rev", default=None, help="revision to build (default: currently installed ARTIQ revision)") |
| 105 | + act_build.add_argument("variant", help="variant to build") |
| 106 | + act_build.add_argument("directory", help="output directory") |
| 107 | + act_passwd = action.add_parser("passwd", help="change password") |
| 108 | + args = parser.parse_args() |
| 109 | + |
| 110 | + cert = args.cert |
| 111 | + if cert is None: |
| 112 | + cert = get_artiq_cert() |
| 113 | + if cert is None: |
| 114 | + print("SSL certificate not found in ARTIQ. Specify manually using --cert.") |
| 115 | + sys.exit(1) |
| 116 | + |
| 117 | + if args.action == "passwd": |
| 118 | + password = getpass("Current password: ") |
| 119 | + else: |
| 120 | + password = getpass() |
| 121 | + |
| 122 | + client = Client(args.server, args.port, cert) |
| 123 | + try: |
| 124 | + if not client.login(args.username, password): |
| 125 | + print("Login failed") |
| 126 | + sys.exit(1) |
| 127 | + if args.action == "passwd": |
| 128 | + print("Password must made of alphanumeric characters (a-z, A-Z, 0-9) and be at least 8 characters long.") |
| 129 | + password = getpass("New password: ") |
| 130 | + password_confirm = getpass("New password (again): ") |
| 131 | + while password != password_confirm: |
| 132 | + print("Passwords do not match") |
| 133 | + password = getpass("New password: ") |
| 134 | + password_confirm = getpass("New password (again): ") |
| 135 | + if not client.passwd(password): |
| 136 | + print("Failed to change password") |
| 137 | + sys.exit(1) |
| 138 | + elif args.action == "build": |
| 139 | + try: |
| 140 | + os.mkdir(args.directory) |
| 141 | + except FileExistsError: |
| 142 | + if any(os.scandir(args.directory)): |
| 143 | + print("Output directory already exists and is not empty. Please remove it and try again.") |
| 144 | + sys.exit(1) |
| 145 | + rev = args.rev |
| 146 | + if rev is None: |
| 147 | + rev = get_artiq_rev() |
| 148 | + if rev is None: |
| 149 | + print("Unable to determine currently installed ARTIQ revision. Specify manually using --rev.") |
| 150 | + sys.exit(1) |
| 151 | + result, contents = client.build(rev, args.variant) |
| 152 | + if result != "OK": |
| 153 | + if result == "UNAUTHORIZED": |
| 154 | + print("You are not authorized to build this variant. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.") |
| 155 | + else: |
| 156 | + print("Build failed: {}".format(result)) |
| 157 | + sys.exit(1) |
| 158 | + zip_unarchive(contents, args.directory) |
| 159 | + else: |
| 160 | + raise ValueError |
| 161 | + finally: |
| 162 | + client.close() |
| 163 | + |
| 164 | + |
| 165 | +if __name__ == "__main__": |
| 166 | + main() |
0 commit comments