Skip to content

Commit 119f740

Browse files
MAJOR: Made the render function of RLToyEnv more general to allow custom trajectory rollouts from the current state; made default continuous toy env have dense reward; improved error message.
1 parent ea0fc1f commit 119f740

File tree

1 file changed

+55
-13
lines changed

1 file changed

+55
-13
lines changed

mdp_playground/envs/rl_toy_env.py

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,10 @@ def __init__(self, **config):
381381
self.reward_density = config["reward_density"]
382382

383383
if "make_denser" not in config:
384-
self.make_denser = False
384+
if config["state_space_type"] == "discrete":
385+
self.make_denser = False
386+
elif config["state_space_type"] == "continuous":
387+
self.make_denser = True
385388
else:
386389
self.make_denser = config["make_denser"]
387390

@@ -1661,8 +1664,10 @@ def transition_function(self, state, action):
16611664
next_state = state
16621665
warnings.warn(
16631666
"WARNING: Action "
1664-
+ str(action)
1665-
+ " out of range of action space. Applying 0 action!!"
1667+
+ str(action) + " " + str(action.dtype) + " " + str(type(action))
1668+
+ " out of range of action space. Applying 0 action!!\n"
1669+
+ "If the action seems to be within range, please check the dtype. "
1670+
+ "We need to have the action dtype be " + str(self.dtype_s) + "."
16661671
)
16671672

16681673
# if "transition_noise" in self.config:
@@ -2334,17 +2339,25 @@ def seed(self, seed=None):
23342339
)
23352340
return self.seed_
23362341

2337-
def render(self,):
2342+
def render(self, actions=None, render_mode=None):
23382343
'''
23392344
Renders the environment using pygame if render_mode is "human" and returns the rendered
23402345
image if render_mode is "rgb_array".
23412346
23422347
Based on https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/
2348+
2349+
If actions and render_mode are None, it does the default Gymnasium render for the current state.
2350+
If an array or list of actions is provided, it performs steps in the environment with that action
2351+
sequence and then renders the resulting trajectory and returns the rendered RGB images.
2352+
If render_mode is provided, it overrides the default render_mode set in the environment's config. Currently, only
2353+
rgb_array is supported for this. Would need to look deeper into pygame, e.g. for how to instantiate mutliple windows
2354+
to support "human" mode overriding as well.
23432355
'''
23442356

23452357
import pygame
23462358

2347-
# Init stuff on first call. For non-image_representations based envs, it makes sense
2359+
# Init stuff on first call to this function now.
2360+
# For non-image_representations based envs, it makes sense
23482361
# to only instantiate the render_space here and not in __init__ because it's only needed
23492362
# if render() is called.
23502363
if self.window is None:
@@ -2384,16 +2397,45 @@ def render(self,):
23842397
seed=self.seed_dict["image_representations"],
23852398
) # #seed
23862399

2400+
# Also init pygame stuff on first call to render() if render_mode is "human":
2401+
if self.render_mode == "human":
2402+
pygame.init()
2403+
pygame.display.init()
2404+
self.window = pygame.display.set_mode(
2405+
(self.image_width, self.image_height)
2406+
)
2407+
self.clock = pygame.time.Clock()
2408+
2409+
# If actions is not None, we need to perform 1 or more steps in the environment.
2410+
# Create a deepcopy of the environment and roll it out with the actions provided.
2411+
if actions is not None:
2412+
if not isinstance(actions, (list, tuple, np.ndarray)):
2413+
raise TypeError(
2414+
"actions should be a list or numpy array of actions, not "
2415+
+ str(type(actions))
2416+
)
2417+
if len(actions) == 0:
2418+
raise ValueError("actions cannot be an empty list or array.")
2419+
# Make a copy of the environment to perform the rollout:
2420+
env_copy = copy.deepcopy(self)
2421+
env_copy.render_mode = "rgb_array" # Set render_mode to rgb_array for the copy #hardcoded
2422+
if render_mode is not None and render_mode != "rgb_array":
2423+
raise NotImplementedError(
2424+
"Currently, only render_mode 'rgb_array' is supported for action sequences."
2425+
"render_mode should be None or 'rgb_array' when such sequences are provided, not "
2426+
+ str(render_mode)
2427+
)
2428+
2429+
# Perform the rollout with the actions provided:
2430+
rgb_arrays = []
2431+
for action in actions:
2432+
obs, reward, done, truncated, _ = env_copy.step(action)
2433+
rgb_array = env_copy.render()
2434+
rgb_arrays.append(rgb_array)
23872435

2388-
if self.window is None and self.render_mode == "human":
2389-
pygame.init()
2390-
pygame.display.init()
2391-
self.window = pygame.display.set_mode(
2392-
(self.image_width, self.image_height)
2393-
)
2394-
if self.clock is None and self.render_mode == "human":
2395-
self.clock = pygame.time.Clock()
2436+
return rgb_arrays
23962437

2438+
# General render logic for every call to render():
23972439
if self.render_mode == "human":
23982440
if not self.image_representations:
23992441
rgb_array = self.render_space.get_concatenated_image(self.curr_state)

0 commit comments

Comments
 (0)