
Bug in the Replay Buffer: end of episodes is not correctly handled #228

theovincent opened this issue Feb 18, 2025 · 1 comment

theovincent commented Feb 18, 2025

Hi,

Thank you for your great work. It is really cool to open source such an amazing code base!

TL;DR

@yogesh1q2w and I noticed that the last transitions of a trajectory are not handled properly: multiple ReplayElements with a terminal flag are stored even though only one terminal TransitionElement is given to the accumulator.

This is problematic because the additional terminal states do not correspond to states that can ever be observed in the environment, which matters since we rely on function approximation.

How to reproduce?

After forking the repo and running

python3.11 -m venv env_cpu
source env_cpu/bin/activate
pip install --upgrade pip setuptools wheel
pip install -e .

I ran

import numpy as np
from dopamine.jax.replay_memory import accumulator, samplers, replay_buffer, elements

transition_accumulator = accumulator.TransitionAccumulator(stack_size=4, update_horizon=1, gamma=0.99)
sampling_distribution = samplers.UniformSamplingDistribution(seed=1)
rb = replay_buffer.ReplayBuffer(
	transition_accumulator=transition_accumulator,
	sampling_distribution=sampling_distribution,
	batch_size=1,
	max_capacity=50,
	compress=False
)

for i in range(8):
	rb.add(elements.TransitionElement(i * np.ones(1), i, i, False if i < 7 else True, False))

print(rb._memory)
OrderedDict([(0,
              ReplayElement(state=array([[0., 0., 0., 0.]]), action=0, reward=0.0, next_state=array([[0., 0., 0., 1.]]), is_terminal=False, episode_end=False)),
             (1,
              ReplayElement(state=array([[0., 0., 0., 1.]]), action=1, reward=1.0, next_state=array([[0., 0., 1., 2.]]), is_terminal=False, episode_end=False)),
             (2,
              ReplayElement(state=array([[0., 0., 1., 2.]]), action=2, reward=2.0, next_state=array([[0., 1., 2., 3.]]), is_terminal=False, episode_end=False)),
             (3,
              ReplayElement(state=array([[0., 1., 2., 3.]]), action=3, reward=3.0, next_state=array([[1., 2., 3., 4.]]), is_terminal=False, episode_end=False)),
             (4,
              ReplayElement(state=array([[1., 2., 3., 4.]]), action=4, reward=4.0, next_state=array([[2., 3., 4., 5.]]), is_terminal=False, episode_end=False)),
             (5,
              ReplayElement(state=array([[2., 3., 4., 5.]]), action=5, reward=5.0, next_state=array([[3., 4., 5., 6.]]), is_terminal=False, episode_end=False)),
             (6,
              ReplayElement(state=array([[3., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
             (7,
              ReplayElement(state=array([[0., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
             (8,
              ReplayElement(state=array([[0., 0., 5., 6.]]), action=6, reward=6.0, next_state=array([[0., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
             (9,
              ReplayElement(state=array([[0., 0., 0., 6.]]), action=6, reward=6.0, next_state=array([[0., 0., 6., 7.]]), is_terminal=True, episode_end=True))])

The last 3 ReplayElements (keys 7, 8 and 9) are incorrect and should not have been added: their frame stacks are padded with zeros (e.g. [0., 4., 5., 6.]), which the environment can never produce.
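To make this concrete, here is a minimal check (a sketch, assuming the reproduction script above has just been run; the variable name n_terminal is only illustrative):

# Only one terminal TransitionElement was added to the buffer,
# so at most one stored ReplayElement should carry is_terminal=True.
n_terminal = sum(element.is_terminal for element in rb._memory.values())
print(n_terminal)  # prints 4 with the current code: 3 of the terminal elements are spurious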

How to fix the bug?

Replacing the following lines of the TransitionAccumulator

    # Check if we have a valid transition, i.e. we either
    #   1) have accumulated more transitions than the update horizon
    #   2) have a trajectory shorter than the update horizon, but the
    #      last element is terminal
    if not (
        trajectory_len > self._update_horizon
        or (trajectory_len > 1 and last_transition.is_terminal)
    ):
        return None

by

    # Check if we have a valid transition, i.e. we either
    #   1) have accumulated more transitions than the update horizon and the
    #      last element is not terminal
    #   2) have a trajectory shorter than the update horizon, but the
    #      last element is terminal and we have enough frames to stack
    if not (
        (trajectory_len > self._update_horizon and not last_transition.is_terminal)
        or (trajectory_len > self._stack_size and last_transition.is_terminal)
    ):
        return None

solves the issue. Indeed, by running the same code again, we obtain:

OrderedDict([(0,
              ReplayElement(state=array([[0., 0., 0., 0.]]), action=0, reward=0.0, next_state=array([[0., 0., 0., 1.]]), is_terminal=False, episode_end=False)),
             (1,
              ReplayElement(state=array([[0., 0., 0., 1.]]), action=1, reward=1.0, next_state=array([[0., 0., 1., 2.]]), is_terminal=False, episode_end=False)),
             (2,
              ReplayElement(state=array([[0., 0., 1., 2.]]), action=2, reward=2.0, next_state=array([[0., 1., 2., 3.]]), is_terminal=False, episode_end=False)),
             (3,
              ReplayElement(state=array([[0., 1., 2., 3.]]), action=3, reward=3.0, next_state=array([[1., 2., 3., 4.]]), is_terminal=False, episode_end=False)),
             (4,
              ReplayElement(state=array([[1., 2., 3., 4.]]), action=4, reward=4.0, next_state=array([[2., 3., 4., 5.]]), is_terminal=False, episode_end=False)),
             (5,
              ReplayElement(state=array([[2., 3., 4., 5.]]), action=5, reward=5.0, next_state=array([[3., 4., 5., 6.]]), is_terminal=False, episode_end=False)),
             (6,
              ReplayElement(state=array([[3., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True))])

The spurious terminal ReplayElements are no longer added 🎉
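For completeness, here is a sketch of a regression test for this behaviour (the test name is only illustrative; it assumes the same API as the reproduction script above):

import numpy as np
from dopamine.jax.replay_memory import accumulator, samplers, replay_buffer, elements

def test_single_terminal_element_per_episode():
    """After a terminal transition, only one terminal ReplayElement should be stored."""
    rb = replay_buffer.ReplayBuffer(
        transition_accumulator=accumulator.TransitionAccumulator(
            stack_size=4, update_horizon=1, gamma=0.99
        ),
        sampling_distribution=samplers.UniformSamplingDistribution(seed=1),
        batch_size=1,
        max_capacity=50,
        compress=False,
    )
    # Same 8-step trajectory as above; only the last transition is terminal.
    for i in range(8):
        rb.add(elements.TransitionElement(i * np.ones(1), i, i, i == 7, False))

    assert sum(element.is_terminal for element in rb._memory.values()) == 1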

theovincent changed the title from "Bug in the Replay Buffer: end of episodes are not correctly handled" to "Bug in the Replay Buffer: end of episodes is not correctly handled" on Feb 18, 2025
theovincent added a commit to theovincent/dopamine that referenced this issue Feb 18, 2025
theovincent (Author) commented:

I made the change on this fork: https://github.com/theovincent/dopamine

Let me know if you would like me to make a PR 🙂
