Skip to content

Commit

Permalink
add gcs upload for composer train
Browse files Browse the repository at this point in the history
  • Loading branch information
lruizcalico committed Jun 19, 2024
1 parent b6f5c3c commit a4ce96c
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion src/baskerville/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

import numpy as np
import tensorflow as tf

import tempfile
from baskerville.helpers.gcs_utils import is_gcs_path, upload_folder_gcs
from baskerville import metrics


Expand Down Expand Up @@ -119,6 +120,15 @@ def __init__(
self.batch_size = self.train_data[0].batch_size
self.compiled = False

# if log_dir is in gcs then create a local temp dir
if is_gcs_path(self.log_dir):
folder_name = self.log_dir.split("/")[-1]
self.log_dir = tempfile.mkdtemp() + folder_name
self.gcs_log_dir = log_dir
self.gcs = True
else:
self.gcs = False

# early stopping
self.patience = self.params.get("patience", 20)

Expand Down Expand Up @@ -498,6 +508,10 @@ def eval_step1_distr(xd, yd):
print(" - valid_r2: %.4f" % valid_r2[di].result().numpy(), end="")
early_stop_stat = valid_r[di].result().numpy()

# upload to gcs
if self.gcs:
upload_folder_gcs(self.log_dir, self.gcs_log_dir)

# checkpoint
managers[di].save()
model.save(
Expand Down Expand Up @@ -697,6 +711,10 @@ def eval_step_distr(xd, yd):
end="",
)

# upload to gcs
if self.gcs:
upload_folder_gcs(self.log_dir, self.gcs_log_dir)

# checkpoint
manager.save()
seqnn_model.save("%s/model_check.h5" % self.out_dir)
Expand Down

0 comments on commit a4ce96c

Please sign in to comment.