From ccf9232256e883117f68fbbc2b84bd2709e7cb96 Mon Sep 17 00:00:00 2001
From: Ryan Benasutti
Date: Thu, 23 Jan 2020 19:36:36 -0500
Subject: [PATCH] Add create and remove heartbeat

---
Reviewer notes (text between "---" and the diff is not part of the commit):
* create-heartbeat writes "1" to <progress prefix>/heartbeat.txt with
  put_object; remove-heartbeat overwrites the same key with "0" and does
  not delete the object.
* create_progress_prefix now builds the shared
  "axon-training-progress/<model>/<dataset>" prefix that
  impl_update_training_progress previously assembled inline.
* NOTE(review): this patch was rebuilt from a whitespace-collapsed copy.
  Hunk line counts match the @@ headers, but the indentation of context and
  removed lines is reconstructed -- confirm with `git apply --check`.

 axon/client.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 68 insertions(+), 2 deletions(-)

diff --git a/axon/client.py b/axon/client.py
index 9f2fd7e..1a2c66b 100644
--- a/axon/client.py
+++ b/axon/client.py
@@ -689,14 +689,48 @@ def impl_update_training_progress(model_name, dataset_name, progress_text, bucke
         with open(local_file, "w") as f:
             f.write(progress_text)
         client = make_client("s3", region)
-        remote_path = "axon-training-progress/" + os.path.basename(model_name) + "/" + \
-                      os.path.basename(dataset_name) + "/progress.txt"
+        remote_path = create_progress_prefix(model_name, dataset_name) + "/progress.txt"
         client.upload_file(path, bucket_name, remote_path)
         print("Updated progress in: {}\n".format(remote_path))
     finally:
         os.remove(path)
 
 
+def impl_create_heartbeat(model_name, dataset_name, bucket_name, region):
+    """
+    Creates a heartbeat that Axon uses to check if the training script is running properly.
+
+    :param model_name: The filename of the model.
+    :param dataset_name: The filename of the dataset.
+    :param bucket_name: The S3 bucket name.
+    :param region: The region, or `None` to pull the region from the environment.
+    """
+    client = make_client("s3", region)
+    remote_path = create_progress_prefix(model_name, dataset_name) + "/heartbeat.txt"
+    client.put_object(Body="1", Bucket=bucket_name, Key=remote_path)
+    print("Created heartbeat file in: {}\n".format(remote_path))
+
+
+def impl_remove_heartbeat(model_name, dataset_name, bucket_name, region):
+    """
+    Signals a stopped heartbeat by overwriting heartbeat.txt with "0" (the object is not deleted).
+
+    :param model_name: The filename of the model.
+    :param dataset_name: The filename of the dataset.
+    :param bucket_name: The S3 bucket name.
+    :param region: The region, or `None` to pull the region from the environment.
+    """
+    client = make_client("s3", region)
+    remote_path = create_progress_prefix(model_name, dataset_name) + "/heartbeat.txt"
+    client.put_object(Body="0", Bucket=bucket_name, Key=remote_path)
+    print("Removed heartbeat file in: {}\n".format(remote_path))
+
+
+def create_progress_prefix(model_name, dataset_name):
+    return "axon-training-progress/" + os.path.basename(model_name) + "/" + \
+           os.path.basename(dataset_name)
+
+
 @click.group()
 def cli():
     return
@@ -902,3 +936,35 @@ def update_training_progress(model_name, dataset_name, progress_text, region):
     """
     impl_update_training_progress(model_name, dataset_name, progress_text,
                                   ensure_s3_bucket(region), region)
+
+
+@cli.command(name="create-heartbeat")
+@click.argument("model-name")
+@click.argument("dataset-name")
+@click.option("--region", help="The region to connect to.",
+              type=click.Choice(region_choices))
+def create_heartbeat(model_name, dataset_name, region):
+    """
+    Creates a heartbeat that Axon uses to check if the training script is running properly.
+
+    MODEL_NAME The filename of the model currently being trained.
+
+    DATASET_NAME The name of the dataset currently being trained on.
+    """
+    impl_create_heartbeat(model_name, dataset_name, ensure_s3_bucket(region), region)
+
+
+@cli.command(name="remove-heartbeat")
+@click.argument("model-name")
+@click.argument("dataset-name")
+@click.option("--region", help="The region to connect to.",
+              type=click.Choice(region_choices))
+def remove_heartbeat(model_name, dataset_name, region):
+    """
+    Removes a heartbeat that Axon uses to check if the training script is running properly.
+
+    MODEL_NAME The filename of the model currently being trained.
+
+    DATASET_NAME The name of the dataset currently being trained on.
+    """
+    impl_remove_heartbeat(model_name, dataset_name, ensure_s3_bucket(region), region)