Skip to content

Commit

Permalink
set obs dtype to float64
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadene committed Apr 5, 2024
1 parent 0fe4449 commit 6283bc7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions gym_pusht/envs/pusht.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,17 @@ def __init__(
if self.obs_type == "state":
# [agent_x, agent_y, block_x, block_y, block_angle]
self.observation_space = spaces.Box(
low=np.array([0, 0, 0, 0, 0], dtype=np.float32),
high=np.array([512, 512, 512, 512, 2 * np.pi], dtype=np.float32),
low=np.array([0, 0, 0, 0, 0], dtype=np.float64),
high=np.array([512, 512, 512, 512, 2 * np.pi], dtype=np.float64),
)
elif self.obs_type == "pixels":
self.observation_space = spaces.Box(low=0, high=255, shape=(render_size, render_size, 3), dtype=np.uint8)
elif self.obs_type == "pixels_agent_pos":
self.observation_space = spaces.Dict({
"pixels": spaces.Box(low=0, high=255, shape=(render_size, render_size, 3), dtype=np.uint8),
"agent_pos": spaces.Box(
low=np.array([0, 0], dtype=np.float32),
high=np.array([512, 512], dtype=np.float32),
low=np.array([0, 0], dtype=np.float64),
high=np.array([512, 512], dtype=np.float64),
),
})

Expand Down

0 comments on commit 6283bc7

Please sign in to comment.