
Fixed save/load problem on dqn.py #184

Merged: 2 commits merged into pfnet:master on Jul 14, 2023

Conversation

jmribeiro (Contributor) commented Jun 19, 2023

Saving and loading the DQN agent would not save/load four needed attributes:

  • self.t
  • self.optim_t
  • self._cumulative_steps
  • self.replay_buffer

This caused the agent to perform differently when evaluated in a single uninterrupted run versus when the agent was saved, the program killed, and the agent reloaded from disk before resuming.

Fig 1 - Training without checkpoints (i.e., the same program run from start to finish)

Fig 2 - Training with checkpoints (i.e., the program killed every t steps and the agent reloaded from disk)

My proposed solution (working, but applied only to the DQN agent) was to add new save_snapshot and load_snapshot methods to the agent's class, without overwriting the original save and load methods, so that the replay buffer is not saved on every call to the original save:

    def save_snapshot(self, dirname: str) -> None:
        # Save the networks and optimizer via the original save(),
        # then the training counters and the replay buffer.
        self.save(dirname)
        torch.save(self.t, os.path.join(dirname, "t.pt"))
        torch.save(self.optim_t, os.path.join(dirname, "optim_t.pt"))
        torch.save(self._cumulative_steps, os.path.join(dirname, "_cumulative_steps.pt"))
        self.replay_buffer.save(os.path.join(dirname, "replay_buffer.pkl"))

    def load_snapshot(self, dirname: str) -> None:
        # Restore the networks and optimizer via the original load(),
        # then the training counters and the replay buffer.
        self.load(dirname)
        self.t = torch.load(os.path.join(dirname, "t.pt"))
        self.optim_t = torch.load(os.path.join(dirname, "optim_t.pt"))
        self._cumulative_steps = torch.load(os.path.join(dirname, "_cumulative_steps.pt"))
        self.replay_buffer.load(os.path.join(dirname, "replay_buffer.pkl"))

This change works as intended; training resumes properly after reloading the agent from disk:

Fig 3 - Training with checkpoints, new patch (i.e., the program killed every t steps and the agent reloaded from disk)
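
For reference, a checkpointed training loop using these methods could look roughly like the sketch below. This is only an illustration, not part of the PR: make_agent, env, and the loop structure are assumptions, and only save_snapshot/load_snapshot come from the patch; act and observe are the usual pfrl agent interface.

    import os

    def train_with_snapshots(make_agent, env, total_steps, snapshot_dir, snapshot_interval):
        # Hypothetical helper: make_agent() builds a DQN agent that has the
        # patched save_snapshot/load_snapshot methods; env is a Gym-style env.
        agent = make_agent()
        start_step = 0
        if os.path.isdir(snapshot_dir):
            # Resume: restores the networks/optimizer plus t, optim_t,
            # _cumulative_steps, and the replay buffer.
            agent.load_snapshot(snapshot_dir)
            # Assumes agent.t counts act/observe steps, so training
            # continues from where the previous run stopped.
            start_step = agent.t

        obs = env.reset()
        for step in range(start_step, total_steps):
            action = agent.act(obs)
            obs, reward, done, _ = env.step(action)
            agent.observe(obs, reward, done, reset=False)
            if done:
                obs = env.reset()
            if (step + 1) % snapshot_interval == 0:
                os.makedirs(snapshot_dir, exist_ok=True)
                agent.save_snapshot(snapshot_dir)
        return agent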

Commits (2):

  • Overwrote the save() and load() methods on dqn.py to save four attributes needed for stopping/resuming training when saving the agent to disk:
      - self.t
      - self.optim_t
      - self._cumulative_steps
      - self.replay_buffer
  • Renamed the overwritten save and load methods to save_snapshot and load_snapshot, to avoid saving the replay buffer in existing calls to the original methods.
muupan (Member) commented Jun 26, 2023

/test

pfn-ci-bot commented
Successfully created a job for commit 8fc26f4.

muupan self-requested a review on June 26, 2023 03:37
jmribeiro (Contributor, Author) commented
/test

@muupan There seems to be a problem with the tests in test_acer.py (unrelated to these changes).


muupan (Member) left a comment

Sorry for my delayed response! Some of the CI, including code lint, is broken and should be fixed later. I see the changes in this PR are harmless apart from potential lint warnings, which should also be fixed later if any. Thanks for your contribution!

muupan merged commit ee0f363 into pfnet:master on Jul 14, 2023
muupan added this to the v0.4.0 milestone on Jul 16, 2023
muupan added the "enhancement" (New feature or request) label on Jul 16, 2023