
Commit 3c265f5

Allow setting the Markov state (augmented_state and other variables as needed) of the env - use this in render() to allow imaginary rollouts from a custom starting state there.
1 parent 119f740 commit 3c265f5

1 file changed

mdp_playground/envs/rl_toy_env.py

Lines changed: 78 additions & 9 deletions
@@ -202,7 +202,9 @@ class RLToyEnv(gym.Env):
     R(state, action)
         defined as a lambda function in the call to init_reward_function() and is equivalent to calling reward_function()
     get_augmented_state()
-        gets underlying Markovian state of the MDP
+        gets underlying Markovian state of the MDP as a dictionary
+    set_augmented_state(augmented_state_dict)
+        sets underlying Markovian state of the MDP, by default, using a dictionary in the same format as returned by get_augmented_state()
     reset()
         Resets environment state
     seed()
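
A quick sketch of how the two methods added to this listing pair up. This is hypothetical usage, not code from the repo: `env` is assumed to be an already-constructed RLToyEnv instance, and only the state-related member variables covered by the dict are restored (the RNG state is not part of it, as the get_augmented_state() docstring further down notes).

    # Hypothetical sketch: snapshot the Markov state, take some real steps, then restore the snapshot.
    snapshot = env.get_augmented_state()          # dict, per the updated docstring
    for _ in range(5):
        env.step(env.action_space.sample())
    env.set_augmented_state(snapshot)             # restores curr_state, curr_obs, augmented_state, etc.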
@@ -1575,8 +1577,14 @@ def get_rews(rng, r_dict):
     def transition_function(self, state, action):
         """The transition function, P.
 
-        Performs a transition according to the initialised P for discrete environments (with dynamics independent for relevant vs irrelevant dimension sub-spaces). For continuous environments, we have a fixed available option for the dynamics (which is the same for relevant or irrelevant dimensions):
-        The order of the system decides the dynamics. For an nth order system, the nth order derivative of the state is set to the action value / inertia for time_unit seconds. And then the dynamics are integrated over the time_unit to obtain the next state.
+        Performs a transition according to the initialised P for discrete environments (the independent dynamics for the irrelevant
+        dimension sub-spaces are handled in step() for discrete envs while they are handled here for grid envs).
+        For continuous environments, we have a fixed available option for the dynamics (which is the same for relevant
+        or irrelevant dimensions):
+        The order of the system decides the dynamics. For an nth order system, the nth order derivative of the state is set to the
+        action value / inertia for time_unit seconds. And then the dynamics are integrated over the time_unit to obtain the next state.
+
+        ###TODO Make this function use Markov state also for continuous envs.
 
         Parameters
         ----------
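
The integration scheme the docstring describes can be pictured with a small standalone sketch. This is not the library's implementation: the function name, the inertia and time_unit defaults, and the Euler-style update order are all assumptions about one way to realise "set the nth derivative to action / inertia and integrate over time_unit".

    import numpy as np

    def integrate_nth_order(state_derivatives, action, inertia=1.0, time_unit=1.0):
        # Standalone illustration only: set the highest-order derivative to action / inertia,
        # then integrate each lower-order derivative over time_unit (Euler-style).
        derivs = [np.array(d, dtype=float) for d in state_derivatives]  # [state, 1st, ..., nth derivative]
        derivs[-1] = np.asarray(action, dtype=float) / inertia
        for i in range(len(derivs) - 2, -1, -1):
            derivs[i] = derivs[i] + derivs[i + 1] * time_unit
        return derivs[0], derivs

    # Example: a 2nd-order (double-integrator) system in 1-D, at rest, given action 1.0 for one time unit.
    next_state, _ = integrate_nth_order([np.zeros(1), np.zeros(1), np.zeros(1)], np.array([1.0]))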
@@ -1989,7 +1997,7 @@ def step(self, action, imaginary_rollout=False):
         action : int or np.array
             The action that the environment will use to perform a transition.
         imaginary_rollout: boolean
-            Option for the user to perform "imaginary" transitions, e.g., for model-based RL. If set to true, underlying augmented state of the MDP is not changed and user is responsible to maintain and provide a list of states to this function to be able to perform a rollout.
+            Unsupported at the moment. Option for the user to perform "imaginary" transitions, e.g., for model-based RL. If set to true, underlying augmented state of the MDP is not changed and user is responsible to maintain and provide a list of states to this function to be able to perform a rollout.
 
         Returns
         -------
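
Since the updated docstring marks imaginary_rollout as unsupported, one hedged workaround is to step a deep copy of the env, which is the same pattern this commit uses inside render() (see the last hunk below). The snippet assumes an existing RLToyEnv instance named `env`:

    import copy

    # Sketch: an "imaginary" transition that leaves the real env untouched.
    env_imagined = copy.deepcopy(env)
    action = env_imagined.action_space.sample()
    next_obs, reward, done, truncated, aug_state = env_imagined.step(action)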
@@ -2117,12 +2125,21 @@ def step(self, action, imaginary_rollout=False):
         return self.curr_obs, self.reward, self.done, False, self.get_augmented_state()
 
     def get_augmented_state(self):
-        """Intended to return the full augmented state which would be Markovian. (However, it's not Markovian wrt the noise in P and R because we're not returning the underlying RNG.) Currently, returns the augmented state which is the sequence of length "delay + sequence_length + 1" of past states for both discrete and continuous environments. Additonally, the current state derivatives are also returned for continuous environments.
+        """Intended to return the full augmented state which would be Markovian. (However, it's not Markovian wrt the noise in P and R
+        because we're not returning the underlying RNG.)
 
-        Returns
+        Returns a dictionary with the following keys:
         -------
-        dict
-            Contains at the end of the current transition
+
+        augmented_state contains the sequence / list of past states of length "delay + sequence_length + 1". Each element in this list contains
+        only the relevant parts for discrete envs, continuous (only 0th order info, i.e., position) and grid envs iirc.
+        state_derivatives contains the list of state derivatives - only present for continuous envs.
+        curr_state contains the relevant and irrelevant parts (if any) for discrete, continuous and grid envs.
+        curr_obs contains the same unless image_representations is True, in which case it contains the image representation of curr_state.
+
+        Remark: relevant_indices for cont. envs can be figured out using curr_state and augmented_state. So, all the info needed to make the
+        state Markov is present in the returned dict (except for the RNG state if P and R are noisy). This could be improved but this is the
+        current implementation.
 
         """
         # #TODO For noisy processes, this would need the noise distribution and random seed too. Also add the irrelevant state parts, etc.? We don't need the irrelevant parts for the state to be Markovian.
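
A small sketch of inspecting the returned dict. The key names are taken from the docstring above; the snippet assumes `env` is a continuous-space RLToyEnv so that state_derivatives is present:

    aug = env.get_augmented_state()
    print(sorted(aug.keys()))               # expect: augmented_state, curr_obs, curr_state, state_derivatives
    print(len(aug["augmented_state"]))      # delay + sequence_length + 1 past states
    print(aug["state_derivatives"][0])      # 0th-order derivative, i.e., the current position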
@@ -2148,6 +2165,55 @@ def get_augmented_state(self):
 
         return augmented_state_dict
 
+    def set_augmented_state(self, augmented_state_dict):
+        """Sets the underlying Markov state of the environment to the one specified in the argument. This is useful for
+        setting a custom state from which to rollout, e.g., for model-based RL imaginary rollouts.
+
+        Parameters
+        ----------
+        augmented_state_dict : dict or state
+            If it's a dictionary, it should be in the format returned by get_augmented_state().
+            If it's a state, all of the elements of the member variables curr_state, curr_obs,
+            augmented_state and state_derivatives (for continuous envs) are set to this state.
+
+        """
+
+        if type(augmented_state_dict) is not dict:
+            warnings.warn(
+                "Warning: When setting the Markov state of the env, the passed state dictionary " \
+                "was not a dict in the expected format (i.e. the one returned by get_augmented_state()). " \
+                "Setting all relevant member variables of the env to be the value that was passed in. " \
+                "If you see any errors, you will need to dig deeper into the code to see how to set the state properly."
+            )
+
+            if self.config["state_space_type"] == "continuous":
+                # Create copies of np.arrays to avoid modifying the original state which may be used by external code:
+                augmented_state_dict = {
+                    "curr_state": augmented_state_dict.copy(),
+                    "curr_obs": augmented_state_dict.copy(),
+                    "augmented_state": [[np.nan] * self.state_space_dim] * (self.augmented_state_length - 1) + [augmented_state_dict.copy()],
+                }
+
+                # If continuous env, also set state_derivatives to 0, except the 0th order one which is the state itself:
+                augmented_state_dict["state_derivatives"] = [
+                    np.zeros(self.state_space_dim, dtype=self.dtype_s)
+                ] * (self.dynamics_order + 1)
+                augmented_state_dict["state_derivatives"][0] = augmented_state_dict["curr_state"].copy()
+
+            else:  # discrete or grid env
+                augmented_state_dict = {
+                    "curr_state": augmented_state_dict,
+                    "curr_obs": augmented_state_dict,
+                    "augmented_state": [np.nan] * (self.augmented_state_length - 1) + [augmented_state_dict],
+                }
+
+        self.curr_state = augmented_state_dict["curr_state"]
+        self.curr_obs = augmented_state_dict["curr_obs"]
+        self.augmented_state = augmented_state_dict["augmented_state"]
+
+        if self.config["state_space_type"] == "continuous":
+            self.state_derivatives = augmented_state_dict["state_derivatives"]
+
     def reset(self, seed=None, options=None):
         """Resets the environment for the beginning of an episode and samples a start state from rho_0. For discrete environments uses the defined rho_0 directly. For continuous environments, samples a state and resamples until a non-terminal state is sampled.
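
A hedged usage sketch of the non-dict branch of the new method. The 2-dimensional continuous env and the state values are made up for illustration:

    import numpy as np

    custom_state = np.array([0.5, -0.3])
    env.set_augmented_state(custom_state)    # not a dict, so the warning fires and the raw state is wrapped
    # After the call, per the code above:
    #   env.curr_state and env.curr_obs are copies of custom_state,
    #   env.augmented_state is padded with NaNs except for the last entry,
    #   env.state_derivatives[0] is the state itself and the higher orders are zeroed.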
@@ -2339,7 +2405,7 @@ def seed(self, seed=None):
         )
         return self.seed_
 
-    def render(self, actions=None, render_mode=None):
+    def render(self, actions=None, state=None, render_mode=None):
         '''
         Renders the environment using pygame if render_mode is "human" and returns the rendered
         image if render_mode is "rgb_array".
@@ -2419,6 +2485,9 @@ def render(self, actions=None, render_mode=None):
         # Make a copy of the environment to perform the rollout:
         env_copy = copy.deepcopy(self)
         env_copy.render_mode = "rgb_array" # Set render_mode to rgb_array for the copy #hardcoded
+        # Allow rolling out from a custom state:
+        if state is not None:
+            env_copy.set_augmented_state(state)
         if render_mode is not None and render_mode != "rgb_array":
             raise NotImplementedError(
                 "Currently, only render_mode 'rgb_array' is supported for action sequences."
