@@ -176,6 +176,8 @@ class RLToyEnv(gym.Env):
         The externally visible observation space for the environment.
     action_space : Gym.Space
         The externally visible action space for the environment.
+    feature_space : Gym.Space
+        In case of continuous and grid environments, this is the underlying state space. ##TODO Unify this across all types of environments.
     rewardable_sequences : dict
         Holds the rewardable sequences. The keys are tuples of rewardable sequences and the values are the rewards handed out. When make_denser is True for discrete environments, this dict also holds the rewardable partial sequences.
 
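To make the rewardable_sequences structure documented above concrete, here is a hedged sketch (the states and reward values are invented for illustration, not taken from the repository):

    # Keys are tuples of states forming a rewardable sequence; values are the rewards.
    # With make_denser=True in discrete environments, partial sequences also appear
    # as keys (the fractional values below are hypothetical).
    rewardable_sequences = {
        (0, 4, 7): 1.0,
        (0,): 1.0 / 3,
        (0, 4): 2.0 / 3,
    }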
@@ -519,7 +521,6 @@ def __init__(self, **config):
         elif config["state_space_type"] == "grid":
             assert "grid_shape" in config
             self.grid_shape = config["grid_shape"]
-            self.grid_np_data_type = np.int64
         else:
             raise ValueError("Unknown state_space_type")
 
@@ -546,9 +547,9 @@ def __init__(self, **config):
         else:
             self.repeats_in_sequences = config["repeats_in_sequences"]
 
-        self.dtype = np.float32 if "dtype" not in config else config["dtype"]
 
         if config["state_space_type"] == "discrete":
+            self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"]
             if self.irrelevant_features:
                 assert (
                     len(config["action_space_size"]) == 2
@@ -570,6 +571,7 @@ def __init__(self, **config):
             )
             # assert (np.array(self.state_space_size) % np.array(self.diameter) == 0).all(), "state_space_size should be a multiple of the diameter to allow for the generation of regularly connected MDPs."
         elif config["state_space_type"] == "continuous":
+            self.dtype_s = np.float32 if "dtype_s" not in config else config["dtype_s"]
             self.action_space_dim = self.state_space_dim
             if self.irrelevant_features:
                 assert (
@@ -580,10 +582,18 @@ def __init__(self, **config):
             config["relevant_indices"] = range(self.state_space_dim)
             # config["irrelevant_indices"] = list(set(range(len(config["state_space_dim"]))) - set(config["relevant_indices"]))
         elif config["state_space_type"] == "grid":
+            self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"]
             # Repeat the grid for the irrelevant part as well
             if self.irrelevant_features:
                 self.grid_shape = self.grid_shape * 2
 
+        # Set the dtype for the observation space:
+        if self.image_representations:
+            self.dtype_o = np.float32 if "dtype_o" not in config else config["dtype_o"]
+        else:
+            self.dtype_o = self.dtype_s if "dtype_o" not in config else config["dtype_o"]
+
+
         if ("init_state_dist" in config) and ("relevant_init_state_dist" not in config):
             config["relevant_init_state_dist"] = config["init_state_dist"]
 
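A minimal usage sketch for the new dtype_s/dtype_o options introduced above (the config keys besides dtype_s and dtype_o are illustrative and not a complete RLToyEnv config):

    import numpy as np
    from mdp_playground.envs import RLToyEnv  # assumed import path

    config = {
        "state_space_type": "grid",
        "grid_shape": (8, 8),
        "image_representations": True,  # external observations become images
        "dtype_s": np.int64,    # dtype of the underlying state
        "dtype_o": np.float32,  # dtype of the external observation
    }
    env = RLToyEnv(**config)  # remaining required keys omitted for brevity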
@@ -614,7 +624,7 @@ def __init__(self, **config):
             assert self.sequence_length == 1
             if "target_point" in config:
                 self.target_point = np.array(
-                    config["target_point"], dtype=self.dtype
+                    config["target_point"], dtype=self.dtype_s
                 )
                 assert self.target_point.shape == (
                     len(config["relevant_indices"]),
@@ -640,6 +650,7 @@ def __init__(self, **config):
                 DiscreteExtended(
                     self.state_space_size[0],
                     seed=self.seed_dict["relevant_state_space"],
+                    # dtype=self.dtype_o, # Gymnasium seems to hardcode it as np.int64
                 )
             ]  # #seed #hardcoded, many times below as well
             self.action_spaces = [
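An aside to illustrate the commented-out dtype kwarg above: plain Gymnasium fixes the dtype of Discrete spaces, at least in recent versions, so passing a custom one is not an option here.

    import gymnasium as gym

    space = gym.spaces.Discrete(5)
    print(space.dtype)           # int64, fixed by Gymnasium
    print(type(space.sample()))  # <class 'numpy.int64'>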
@@ -671,7 +682,7 @@ def __init__(self, **config):
             # self.action_spaces[i] = DiscreteExtended(self.action_space_size[i],
             #     seed=self.seed_dict["irrelevant_action_space"]) #seed
 
-        if self.image_representations:
+        if self.image_representations:  # for discrete envs
             # underlying_obs_space = MultiDiscreteExtended(self.state_space_size, seed=self.seed_dict["state_space"]) #seed
             self.observation_space = ImageMultiDiscrete(
                 self.state_space_size,
@@ -714,7 +725,7 @@ def __init__(self, **config):
                 self.state_space_max,
                 shape=(self.state_space_dim,),
                 seed=self.seed_dict["state_space"],
-                dtype=self.dtype,
+                dtype=self.dtype_s,
             )  # #seed
             # hack #TODO # low and high are 1st 2 and required arguments
             # for instantiating BoxExtended
@@ -729,7 +740,7 @@ def __init__(self, **config):
                 self.action_space_max,
                 shape=(self.action_space_dim,),
                 seed=self.seed_dict["action_space"],
-                dtype=self.dtype,
+                dtype=self.dtype_s,
             )  # #seed
             # hack #TODO
 
@@ -754,7 +765,7 @@ def __init__(self, **config):
                 0 * underlying_space_maxes,
                 underlying_space_maxes,
                 seed=self.seed_dict["state_space"],
-                dtype=self.dtype,
+                dtype=self.dtype_s,
             )  # #seed
 
             lows = np.array([-1] * len(self.grid_shape))
@@ -893,7 +904,7 @@ def init_terminal_states(self):
             # print("Term state lows, highs:", lows, highs)
             self.term_spaces.append(
                 BoxExtended(
-                    low=lows, high=highs, seed=self.seed_, dtype=self.dtype
+                    low=lows, high=highs, seed=self.seed_, dtype=self.dtype_s
                 )
             )  # #seed #hack #TODO
             self.logger.debug(
@@ -931,7 +942,7 @@ def init_terminal_states(self):
             highs = term_state  # #hardcoded
             self.term_spaces.append(
                 BoxExtended(
-                    low=lows, high=highs, seed=self.seed_, dtype=self.grid_np_data_type
+                    low=lows, high=highs, seed=self.seed_, dtype=self.dtype_s
                 )
             )  # #seed #hack #TODO
 
@@ -1657,7 +1668,7 @@ def transition_function(self, state, action):
             # for a "wall", but would need to take care of multiple
             # reflections near a corner/edge.
             # Resets all higher order derivatives to 0
-            zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype)
+            zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype_s)
             # #####IMP to have copy() otherwise it's the same array
             # (in memory) at every position in the list:
             self.state_derivatives = [
@@ -1666,7 +1677,7 @@ def transition_function(self, state, action):
             self.state_derivatives[0] = next_state
 
         if self.config["reward_function"] == "move_to_a_point":
-            next_state_rel = np.array(next_state, dtype=self.dtype)[
+            next_state_rel = np.array(next_state, dtype=self.dtype_s)[
                 self.config["relevant_indices"]
             ]
             dist_ = np.linalg.norm(next_state_rel - self.target_point)
@@ -1678,7 +1689,7 @@ def transition_function(self, state, action):
             # Need to check that the dtype is int because Gym doesn't check it
             if (
                 self.action_space.contains(action)
-                and np.array(action).dtype == self.grid_np_data_type
+                and np.array(action).dtype == self.dtype_s
             ):
                 if self.transition_noise:
                     # self._np_random.choice only works for 1-D arrays
@@ -1820,7 +1831,7 @@ def reward_function(self, state, action):
             # of the formulae and see that programmatic results match: should
             # also have a unit version of 4. for dist_of_pt_from_line() and
             # an integration version here for total_deviation calc.?.
-            data_ = np.array(state_considered, dtype=self.dtype)[
+            data_ = np.array(state_considered, dtype=self.dtype_s)[
                 1 + delay : self.augmented_state_length,
                 self.config["relevant_indices"],
             ]
@@ -1863,10 +1874,10 @@ def reward_function(self, state, action):
             # that. #TODO Generate it randomly to have random Rs?
             if self.make_denser:
                 old_relevant_state = np.array(
-                    state_considered, dtype=self.dtype
+                    state_considered, dtype=self.dtype_s
                 )[-2, self.config["relevant_indices"]]
                 new_relevant_state = np.array(
-                    state_considered, dtype=self.dtype
+                    state_considered, dtype=self.dtype_s
                 )[-1, self.config["relevant_indices"]]
                 reward = -np.linalg.norm(new_relevant_state - self.target_point)
                 # Should allow other powers of the distance from target_point,
@@ -1879,7 +1890,7 @@ def reward_function(self, state, action):
             # TODO also make_denser, sparse rewards only at target
             else:  # sparse reward
                 new_relevant_state = np.array(
-                    state_considered, dtype=self.dtype
+                    state_considered, dtype=self.dtype_s
                 )[-1, self.config["relevant_indices"]]
                 if (
                     np.linalg.norm(new_relevant_state - self.target_point)
@@ -1890,7 +1901,7 @@ def reward_function(self, state, action):
                 # stay in the radius and earn more reward.
 
             reward -= self.action_loss_weight * np.linalg.norm(
-                np.array(action, dtype=self.dtype)
+                np.array(action, dtype=self.dtype_s)
             )
 
         elif self.config["state_space_type"] == "grid":
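A worked sketch of the move_to_a_point arithmetic above, for the dense (make_denser) branch; all numbers and the action_loss_weight value are invented:

    import numpy as np

    target_point = np.array([1.0, 1.0], dtype=np.float32)
    new_relevant_state = np.array([0.5, 0.5], dtype=np.float32)
    action = np.array([0.1, -0.2], dtype=np.float32)

    reward = -np.linalg.norm(new_relevant_state - target_point)  # ~ -0.707
    reward -= 0.5 * np.linalg.norm(action)  # action_loss_weight=0.5 -> ~ -0.819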
@@ -2044,8 +2055,8 @@ def step(self, action, imaginary_rollout=False):
         if self.image_representations:
             next_obs = self.observation_space.get_concatenated_image(next_state)
 
-        self.curr_state = next_state
-        self.curr_obs = next_obs
+        self.curr_state = self.dtype_s(next_state)
+        self.curr_obs = self.dtype_o(next_obs)
 
         # #### TODO curr_state is external state, while we need to check relevant state for terminality! Done - by using augmented_state now instead of curr_state!
         self.done = (
@@ -2199,7 +2210,7 @@ def reset(self, seed=None):
 
         # if not self.use_custom_mdp:
         # init the state derivatives needed for continuous spaces
-        zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype)
+        zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype_s)
         self.state_derivatives = [
             zero_state.copy() for i in range(self.dynamics_order + 1)
         ]  # #####IMP to have copy()
@@ -2217,7 +2228,7 @@ def reset(self, seed=None):
         while True:  # Be careful about infinite loops
             term_space_was_sampled = False
             # curr_state is an np.array while curr_state_relevant is a list
-            self.curr_state = self.feature_space.sample().astype(int)  # #random
+            self.curr_state = self.feature_space.sample().astype(self.dtype_s)  # #random
             self.curr_state_relevant = list(self.curr_state[[0, 1]])  # #hardcoded
             if self.is_terminal_state(self.curr_state_relevant):
                 self.logger.debug(
@@ -2241,6 +2252,9 @@ def reset(self, seed=None):
         else:
             self.curr_obs = self.curr_state
 
+        self.curr_state = self.dtype_s(self.curr_state)
+        self.curr_obs = self.dtype_o(self.curr_obs)
+
         self.logger.info("RESET called. curr_state reset to: " + str(self.curr_state))
         self.reached_terminal = False
 
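Downstream, callers should now see consistently typed values out of reset and step; a rough sketch, assuming the Gymnasium-style API that the signatures above suggest and a config like the earlier dtype sketch:

    env = RLToyEnv(**config)
    obs, info = env.reset(seed=0)   # obs has dtype dtype_o
    next_obs, reward, done, truncated, info = env.step(env.action_space.sample())
    # env.curr_state is cast with dtype_s; env.curr_obs / next_obs with dtype_o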