diff --git a/rocket_learn/ppo.py b/rocket_learn/ppo.py index cad2d24..be634e6 100644 --- a/rocket_learn/ppo.py +++ b/rocket_learn/ppo.py @@ -171,7 +171,7 @@ def _iter(): iteration += 1 if save_dir: - self.save(os.path.join(save_dir, self.logger.project + "_" + "latest"), -1, save_jit) + self.save(os.path.join(save_dir, self.logger.project + "_" + "latest"), iteration, save_actor_jit=save_jit, is_latest=True) if iteration % iterations_per_save == 0: self.save(current_run_dir, iteration, save_jit) # noqa @@ -485,18 +485,19 @@ def load(self, load_location, continue_iterations=True): self.total_steps = checkpoint["total_steps"] print("Continuing training at iteration " + str(self.starting_iteration)) - def save(self, save_location, current_step, save_actor_jit=False): + def save(self, save_location, current_step, save_actor_jit=False, is_latest=False,): """ Save the model weights, optimizer values, and metadata :param save_location: where to save :param current_step: the current iteration when saved. Use to later continue training :param save_actor_jit: save the policy network as a torch jit file for rlbot use + :param is_latest: if this file is the "latest" checkpoint, used to decide if real checkpoint number should be used in filename """ - version_str = str(self.logger.project) + "_" + str(current_step) + version_str = str(self.logger.project) + "_" + (str(current_step) if not is_latest else "c") version_dir = save_location + "\\" + version_str - os.makedirs(version_dir, exist_ok=current_step == -1) + os.makedirs(version_dir, exist_ok=is_latest) torch.save({ 'epoch': current_step,