Skip to content
This repository has been archived by the owner on Sep 26, 2020. It is now read-only.

Commit

Permalink
Add create and remove heartbeat
Browse files Browse the repository at this point in the history
  • Loading branch information
Octogonapus committed Jan 24, 2020
1 parent fc12020 commit ccf9232
Showing 1 changed file with 68 additions and 2 deletions.
70 changes: 68 additions & 2 deletions axon/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,14 +689,48 @@ def impl_update_training_progress(model_name, dataset_name, progress_text, bucke
with open(local_file, "w") as f:
f.write(progress_text)
client = make_client("s3", region)
remote_path = "axon-training-progress/" + os.path.basename(model_name) + "/" + \
os.path.basename(dataset_name) + "/progress.txt"
remote_path = create_progress_prefix(model_name, dataset_name) + "/progress.txt"
client.upload_file(path, bucket_name, remote_path)
print("Updated progress in: {}\n".format(remote_path))
finally:
os.remove(path)


def impl_create_heartbeat(model_name, dataset_name, bucket_name, region):
"""
Creates a heartbeat that Axon uses to check if the training script is running properly.
:param model_name: The filename of the model.
:param dataset_name: The filename of the dataset.
:param bucket_name: The S3 bucket name.
:param region: The region, or `None` to pull the region from the environment.
"""
client = make_client("s3", region)
remote_path = create_progress_prefix(model_name, dataset_name) + "/heartbeat.txt"
client.put_object(Body="1", Bucket=bucket_name, Key=remote_path)
print("Created heartbeat file in: {}\n".format(remote_path))


def impl_remove_heartbeat(model_name, dataset_name, bucket_name, region):
"""
Removes a heartbeat that Axon uses to check if the training script is running properly.
:param model_name: The filename of the model.
:param dataset_name: The filename of the dataset.
:param bucket_name: The S3 bucket name.
:param region: The region, or `None` to pull the region from the environment.
"""
client = make_client("s3", region)
remote_path = create_progress_prefix(model_name, dataset_name) + "/heartbeat.txt"
client.put_object(Body="0", Bucket=bucket_name, Key=remote_path)
print("Removed heartbeat file in: {}\n".format(remote_path))


def create_progress_prefix(model_name, dataset_name):
return "axon-training-progress/" + os.path.basename(model_name) + "/" + \
os.path.basename(dataset_name)


@click.group()
def cli():
return
Expand Down Expand Up @@ -902,3 +936,35 @@ def update_training_progress(model_name, dataset_name, progress_text, region):
"""
impl_update_training_progress(model_name, dataset_name, progress_text, ensure_s3_bucket(region),
region)


@cli.command(name="create-heartbeat")
@click.argument("model-name")
@click.argument("dataset-name")
@click.option("--region", help="The region to connect to.",
type=click.Choice(region_choices))
def create_heartbeat(model_name, dataset_name, region):
"""
Creates a heartbeat that Axon uses to check if the training script is running properly.
MODEL_NAME The filename of the model currently being trained.
DATASET_NAME The name of the dataset currently being trained on.
"""
impl_create_heartbeat(model_name, dataset_name, ensure_s3_bucket(region), region)


@cli.command(name="remove-heartbeat")
@click.argument("model-name")
@click.argument("dataset-name")
@click.option("--region", help="The region to connect to.",
type=click.Choice(region_choices))
def remove_heartbeat(model_name, dataset_name, region):
"""
Removes a heartbeat that Axon uses to check if the training script is running properly.
MODEL_NAME The filename of the model currently being trained.
DATASET_NAME The name of the dataset currently being trained on.
"""
impl_remove_heartbeat(model_name, dataset_name, ensure_s3_bucket(region), region)

0 comments on commit ccf9232

Please sign in to comment.