diff --git a/stable_baselines/a2c/a2c.py b/stable_baselines/a2c/a2c.py index 1806ce03e6..692b01e049 100644 --- a/stable_baselines/a2c/a2c.py +++ b/stable_baselines/a2c/a2c.py @@ -87,9 +87,20 @@ def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0. def _get_pretrain_placeholders(self): policy = self.train_model + + if self.policy.recurrent: + states_ph = policy.states_ph + snew_ph = policy.snew + dones_ph = policy.dones_ph + else: + states_ph = None + snew_ph = None + dones_ph = None + if isinstance(self.action_space, gym.spaces.Discrete): - return policy.obs_ph, self.actions_ph, policy.policy - return policy.obs_ph, self.actions_ph, policy.deterministic_action + return policy.obs_ph, self.actions_ph, states_ph, snew_ph, dones_ph, policy.policy + return policy.obs_ph, self.actions_ph, states_ph, snew_ph, dones_ph,\ + policy.deterministic_action def setup_model(self): with SetVerbosity(self.verbose): diff --git a/stable_baselines/acer/acer_simple.py b/stable_baselines/acer/acer_simple.py index 3f4af89432..034e1b62cf 100644 --- a/stable_baselines/acer/acer_simple.py +++ b/stable_baselines/acer/acer_simple.py @@ -152,8 +152,18 @@ def __init__(self, policy, env, gamma=0.99, n_steps=20, num_procs=1, q_coef=0.5, def _get_pretrain_placeholders(self): policy = self.step_model action_ph = policy.pdtype.sample_placeholder([None]) + + if self.policy.recurrent: + states_ph = policy.states_ph + snew_ph = policy.snew + dones_ph = policy.dones_ph + else: + states_ph = None + snew_ph = None + dones_ph = None + if isinstance(self.action_space, Discrete): - return policy.obs_ph, action_ph, policy.policy + return policy.obs_ph, action_ph, states_ph, snew_ph, dones_ph, policy.policy raise NotImplementedError('Only discrete actions are supported for ACER for now') def set_env(self, env): diff --git a/stable_baselines/acktr/acktr_disc.py b/stable_baselines/acktr/acktr_disc.py index e2f4327387..20def5c10f 100644 --- a/stable_baselines/acktr/acktr_disc.py +++ b/stable_baselines/acktr/acktr_disc.py @@ -104,8 +104,18 @@ def __init__(self, policy, env, gamma=0.99, nprocs=1, n_steps=20, ent_coef=0.01, def _get_pretrain_placeholders(self): policy = self.train_model + + if self.initial_state is None: + states_ph = None + snew_ph = None + dones_ph = None + else: + states_ph = policy.states_ph + snew_ph = policy.snew + dones_ph = policy.dones_ph + if isinstance(self.action_space, Discrete): - return policy.obs_ph, self.action_ph, policy.policy + return policy.obs_ph, self.action_ph, states_ph, snew_ph, dones_ph, policy.policy raise NotImplementedError("WIP: ACKTR does not support Continuous actions yet.") def setup_model(self): diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index 2f58ce776b..e22877c89a 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -50,6 +50,10 @@ def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base, pol self.sess = None self.params = None self._param_load_ops = None + self.initial_state = None + self.n_batch = None + self.nminibatches = None + self.n_steps = None if env is not None: if isinstance(env, str): @@ -246,13 +250,22 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, else: val_interval = int(n_epochs / 10) + if self.policy.recurrent: + if self.nminibatches is None: + envs_per_batch = self.n_envs * self.n_steps + else: + batch_size = self.n_batch // self.nminibatches + envs_per_batch = batch_size // self.n_steps + with 
self.graph.as_default():
             with tf.variable_scope('pretrain'):
                 if continuous_actions:
-                    obs_ph, actions_ph, deterministic_actions_ph = self._get_pretrain_placeholders()
+                    obs_ph, actions_ph, states_ph, snew_ph, dones_ph, \
+                        deterministic_actions_ph = self._get_pretrain_placeholders()
                     loss = tf.reduce_mean(tf.square(actions_ph - deterministic_actions_ph))
                 else:
-                    obs_ph, actions_ph, actions_logits_ph = self._get_pretrain_placeholders()
+                    obs_ph, actions_ph, states_ph, snew_ph, dones_ph, \
+                        actions_logits_ph = self._get_pretrain_placeholders()
                     # actions_ph has a shape if (n_batch,), we reshape it to (n_batch, 1)
                     # so no additional changes is needed in the dataloader
                     actions_ph = tf.expand_dims(actions_ph, axis=1)
@@ -272,13 +285,23 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
         for epoch_idx in range(int(n_epochs)):
             train_loss = 0.0
+            if self.policy.recurrent:
+                state = self.initial_state[:envs_per_batch]
+
             # Full pass on the training set
             for _ in range(len(dataset.train_loader)):
-                expert_obs, expert_actions = dataset.get_next_batch('train')
+                expert_obs, expert_actions, expert_mask = dataset.get_next_batch('train')
                 feed_dict = {
                     obs_ph: expert_obs,
                     actions_ph: expert_actions,
                 }
+
+                if self.policy.recurrent:
+                    feed_dict.update({states_ph: state, dones_ph: expert_mask})
+                    state, train_loss_, _ = self.sess.run([snew_ph, loss, optim_op], feed_dict)
+                else:
+                    train_loss_, _ = self.sess.run([loss, optim_op], feed_dict)
+
-                train_loss_, _ = self.sess.run([loss, optim_op], feed_dict)
                 train_loss += train_loss_
 
             train_loss /= len(dataset.train_loader)
@@ -288,9 +311,18 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
                 val_loss = 0.0
                 # Full pass on the validation set
                 for _ in range(len(dataset.val_loader)):
-                    expert_obs, expert_actions = dataset.get_next_batch('val')
-                    val_loss_, = self.sess.run([loss], {obs_ph: expert_obs,
-                                                        actions_ph: expert_actions})
+                    expert_obs, expert_actions, expert_mask = dataset.get_next_batch('val')
+
+                    feed_dict = {
+                        obs_ph: expert_obs,
+                        actions_ph: expert_actions,
+                    }
+
+                    if self.policy.recurrent:
+                        feed_dict.update({states_ph: state, dones_ph: expert_mask})
+
+                    val_loss_, = self.sess.run([loss], feed_dict)
+
                     val_loss += val_loss_
 
                 val_loss /= len(dataset.val_loader)
diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py
index 7036271960..4a2f7657b3 100644
--- a/stable_baselines/ddpg/ddpg.py
+++ b/stable_baselines/ddpg/ddpg.py
@@ -308,7 +308,7 @@ def _get_pretrain_placeholders(self):
         policy = self.policy_tf
         # Rescale
         deterministic_action = self.actor_tf * np.abs(self.action_space.low)
-        return policy.obs_ph, self.actions, deterministic_action
+        return policy.obs_ph, self.actions, None, None, None, deterministic_action
 
     def setup_model(self):
         with SetVerbosity(self.verbose):
diff --git a/stable_baselines/deepq/dqn.py b/stable_baselines/deepq/dqn.py
index 637c9462c7..f12180eb5b 100644
--- a/stable_baselines/deepq/dqn.py
+++ b/stable_baselines/deepq/dqn.py
@@ -98,7 +98,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000
 
     def _get_pretrain_placeholders(self):
         policy = self.step_model
-        return policy.obs_ph, tf.placeholder(tf.int32, [None]), policy.q_values
+        return policy.obs_ph, tf.placeholder(tf.int32, [None]), None, None, None, policy.q_values
 
     def setup_model(self):
 
diff --git a/stable_baselines/gail/__init__.py b/stable_baselines/gail/__init__.py
index 0a4c7ac4ea..0cd9db8609 100644
--- a/stable_baselines/gail/__init__.py
+++ b/stable_baselines/gail/__init__.py
@@ -1,3 +1,3 @@
 from stable_baselines.gail.model import GAIL
-from 
stable_baselines.gail.dataset.dataset import ExpertDataset, DataLoader +from stable_baselines.gail.dataset.dataset import ExpertDataset, ExpertDatasetLSTM, DataLoader from stable_baselines.gail.dataset.record_expert import generate_expert_traj diff --git a/stable_baselines/gail/dataset/dataset.py b/stable_baselines/gail/dataset/dataset.py index 64b55dbb3e..260151c64b 100644 --- a/stable_baselines/gail/dataset/dataset.py +++ b/stable_baselines/gail/dataset/dataset.py @@ -1,15 +1,86 @@ import queue import time +import warnings from multiprocessing import Queue, Process import cv2 import numpy as np from joblib import Parallel, delayed +from itertools import cycle, islice from stable_baselines import logger +class Dataset(object): + + + def __del__(self): + del self.dataloader, self.train_loader, self.val_loader + + def prepare_pickling(self): + """ + Exit processes in order to pickle the dataset. + """ + self.dataloader, self.train_loader, self.val_loader = None, None, None + + def log_info(self): + """ + Log the information of the dataset. + """ + logger.log("Total trajectories: {}".format(self.num_traj)) + logger.log("Total transitions: {}".format(self.num_transition)) + logger.log("Average returns: {}".format(self.avg_ret)) + logger.log("Std for returns: {}".format(self.std_ret)) + + def get_next_batch(self, split=None): + """ + Get the batch from the dataset. + :param split: (str) the type of data split (can be None, 'train', 'val') + :return: (np.ndarray, np.ndarray) inputs and labels + """ + dataloader = { + None: self.dataloader, + 'train': self.train_loader, + 'val': self.val_loader + }[split] + + if dataloader.process is None: + dataloader.start_process() + try: + return next(dataloader) + except StopIteration: + dataloader = iter(dataloader) + return next(dataloader) + + def plot(self): + """ + Show histogram plotting of the episode returns + """ + # Isolate dependency since it is only used for plotting and also since + # different matplotlib backends have further dependencies themselves. + import matplotlib.pyplot as plt + plt.hist(self.returns) + plt.show() + +def check_traj_data(expert_path=None, traj_data=None): + """ + Sanity check expert_path and load traj_data. + + :param expert_path: (str) The path to trajectory data (.npz file). Mutually exclusive with traj_data. + :param traj_data: (dict) Trajectory data, in format described above. Mutually exclusive with expert_path. + :return: traj_data + """ + + if traj_data is not None and expert_path is not None: + raise ValueError("Cannot specify both 'traj_data' and 'expert_path'") + if traj_data is None and expert_path is None: + raise ValueError("Must specify one of 'traj_data' or 'expert_path'") + if traj_data is None: + traj_data = np.load(expert_path, allow_pickle=True) + + return traj_data + +class ExpertDataset(Dataset): -class ExpertDataset(object): """ Dataset for using behavior cloning or GAIL. @@ -22,23 +93,20 @@ class ExpertDataset(object): :param expert_path: (str) The path to trajectory data (.npz file). Mutually exclusive with traj_data. :param traj_data: (dict) Trajectory data, in format described above. Mutually exclusive with expert_path. :param train_fraction: (float) the train validation split (0 to 1) - for pre-training using behavior cloning (BC) + for pre-training using behavior cloning (BC). 
:param batch_size: (int) the minibatch size for behavior cloning :param traj_limitation: (int) the number of trajectory to use (if -1, load all) - :param randomize: (bool) if the dataset should be shuffled + :param randomize: (bool) if the dataset should be shuffled. :param verbose: (int) Verbosity :param sequential_preprocessing: (bool) Do not use subprocess to preprocess the data (slower but use less memory for the CI) """ - def __init__(self, expert_path=None, traj_data=None, train_fraction=0.7, batch_size=64, - traj_limitation=-1, randomize=True, verbose=1, sequential_preprocessing=False): - if traj_data is not None and expert_path is not None: - raise ValueError("Cannot specify both 'traj_data' and 'expert_path'") - if traj_data is None and expert_path is None: - raise ValueError("Must specify one of 'traj_data' or 'expert_path'") - if traj_data is None: - traj_data = np.load(expert_path, allow_pickle=True) + def __init__(self, expert_path=None, traj_data=None, train_fraction=0.7, + batch_size=64, traj_limitation=-1, randomize=True, verbose=1, + sequential_preprocessing=False): + + traj_data = check_traj_data(expert_path=expert_path, traj_data=traj_data) if verbose > 0: for key, val in traj_data.items(): @@ -60,6 +128,7 @@ def __init__(self, expert_path=None, traj_data=None, train_fraction=0.7, batch_s observations = traj_data['obs'][:traj_limit_idx] actions = traj_data['actions'][:traj_limit_idx] + mask = episode_starts[:traj_limit_idx] # obs, actions: shape (N * L, ) + S # where N = # episodes, L = episode length @@ -70,6 +139,8 @@ def __init__(self, expert_path=None, traj_data=None, train_fraction=0.7, batch_s observations = np.reshape(observations, [-1, np.prod(observations.shape[1:])]) if len(actions.shape) > 2: actions = np.reshape(actions, [-1, np.prod(actions.shape[1:])]) + if len(mask.shape) > 2: + mask = np.reshape(mask, [-1, np.prod(mask.shape[1:])]) indices = np.random.permutation(len(observations)).astype(np.int64) @@ -77,11 +148,15 @@ def __init__(self, expert_path=None, traj_data=None, train_fraction=0.7, batch_s train_indices = indices[:int(train_fraction * len(indices))] val_indices = indices[int(train_fraction * len(indices)):] + # Set randomize. 
+ self.randomize = randomize + assert len(train_indices) > 0, "No sample for the training set" assert len(val_indices) > 0, "No sample for the validation set" self.observations = observations self.actions = actions + self.mask = mask self.returns = traj_data['episode_returns'][:traj_limit_idx] self.avg_ret = sum(self.returns) / len(self.returns) @@ -92,15 +167,24 @@ def __init__(self, expert_path=None, traj_data=None, train_fraction=0.7, batch_s "please check your expert dataset" self.num_traj = min(traj_limitation, np.sum(episode_starts)) self.num_transition = len(self.observations) - self.randomize = randomize self.sequential_preprocessing = sequential_preprocessing self.dataloader = None - self.train_loader = DataLoader(train_indices, self.observations, self.actions, batch_size, - shuffle=self.randomize, start_process=False, + self.train_loader = DataLoader(train_indices, + self.observations, + self.actions, + self.mask, batch_size, + shuffle=self.randomize, + start_process=False, sequential=sequential_preprocessing) - self.val_loader = DataLoader(val_indices, self.observations, self.actions, batch_size, - shuffle=self.randomize, start_process=False, + + self.val_loader = DataLoader(val_indices, + self.observations, + self.actions, + self.mask, + batch_size, + shuffle=self.randomize, + start_process=False, sequential=sequential_preprocessing) if self.verbose >= 1: @@ -109,62 +193,217 @@ def __init__(self, expert_path=None, traj_data=None, train_fraction=0.7, batch_s def init_dataloader(self, batch_size): """ Initialize the dataloader used by GAIL. - :param batch_size: (int) """ indices = np.random.permutation(len(self.observations)).astype(np.int64) - self.dataloader = DataLoader(indices, self.observations, self.actions, batch_size, - shuffle=self.randomize, start_process=False, + self.dataloader = DataLoader(indices, + self.observations, + self.actions, + self.mask, + batch_size, + shuffle=self.randomize, + start_process=False, sequential=self.sequential_preprocessing) - def __del__(self): - del self.dataloader, self.train_loader, self.val_loader - def prepare_pickling(self): - """ - Exit processes in order to pickle the dataset. - """ - self.dataloader, self.train_loader, self.val_loader = None, None, None - def log_info(self): - """ - Log the information of the dataset. - """ - logger.log("Total trajectories: {}".format(self.num_traj)) - logger.log("Total transitions: {}".format(self.num_transition)) - logger.log("Average returns: {}".format(self.avg_ret)) - logger.log("Std for returns: {}".format(self.std_ret)) +class ExpertDatasetLSTM(Dataset): + """ + Dataset for using behavior cloning or GAIL. - def get_next_batch(self, split=None): - """ - Get the batch from the dataset. + The structure of the expert dataset is a dict, saved as an ".npz" archive. + The dictionary contains the keys 'actions', 'episode_returns', 'rewards', + 'obs' and 'episode_starts'. The corresponding values have data + concatenated across episode: the first axis is the timestep, + the remaining axes index into the data. In case of images, + 'obs' contains the relative path to the images, to enable space + saving from image compression. + + :param expert_path: (str) The path to trajectory data (.npz file). + Mutually exclusive with traj_data. + :param traj_data: (dict) Trajectory data, in format described above. + Mutually exclusive with expert_path. 
+    :param train_fraction: (float) the train validation split (0 to 1)
+        for pre-training using behavior cloning (BC)
+    :param batch_size: (int) the minibatch size for behavior cloning
+    :param traj_limitation: (int) the number of trajectories to use (if -1, load all)
+    :param verbose: (int) Verbosity
+    :param sequential_preprocessing: (bool) Do not use subprocess to preprocess
+        the data (slower but use less memory for the CI)
+    :param envs_per_batch: (int) Number of environments that are processed per batch.
+        Must match the number of environments the recurrent model expects per minibatch.
+    """
+
-        :param split: (str) the type of data split (can be None, 'train', 'val')
-        :return: (np.ndarray, np.ndarray) inputs and labels
-        """
-        dataloader = {
-            None: self.dataloader,
-            'train': self.train_loader,
-            'val': self.val_loader
-        }[split]
+    def __init__(self, expert_path=None, traj_data=None, train_fraction=0.7,
+                 batch_size=64, traj_limitation=-1, verbose=1, envs_per_batch=1,
+                 sequential_preprocessing=False):
-        if dataloader.process is None:
-            dataloader.start_process()
-        try:
-            return next(dataloader)
-        except StopIteration:
-            dataloader = iter(dataloader)
-            return next(dataloader)
+        traj_data = check_traj_data(expert_path=expert_path, traj_data=traj_data)
-    def plot(self):
+        if verbose > 0:
+            for key, val in traj_data.items():
+                print(key, val.shape)
+
+        envs_per_batch = int(envs_per_batch)
+        use_batch_size = batch_size * envs_per_batch
+
+        # Array of bool where episode_starts[i] = True for each new episode
+        episode_starts = traj_data['episode_starts']
+
+        traj_limit_idx = len(traj_data['obs'])
+
+        if traj_limitation > 0:
+            n_episodes = 0
+            # Retrieve the index corresponding
+            # to the traj_limitation trajectory
+            for idx, episode_start in enumerate(episode_starts):
+                n_episodes += int(episode_start)
+                if n_episodes == (traj_limitation + 1):
+                    traj_limit_idx = idx - 1
+
+        observations = traj_data['obs'][:traj_limit_idx]
+        actions = traj_data['actions'][:traj_limit_idx]
+        mask = episode_starts[:traj_limit_idx]
+
+        start_index_list = []
+        for idx, episode_start in enumerate(mask):
+            if episode_start:
+                start_index_list.append(idx)
+
+        start_index_list += [traj_limit_idx]
+
+        # obs, actions: shape (N * L, ) + S
+        # where N = # episodes, L = episode length
+        # and S is the environment observation/action space.
+        # S = (1, ) for discrete space
+        # Flatten to (N * L, prod(S))
+        if len(observations.shape) > 2:
+            observations = np.reshape(observations, [-1, np.prod(observations.shape[1:])])
+        if len(actions.shape) > 2:
+            actions = np.reshape(actions, [-1, np.prod(actions.shape[1:])])
+        if len(mask.shape) > 2:
+            mask = np.reshape(mask, [-1, np.prod(mask.shape[1:])])
+
+        # Create the list of indices and split it per episode.
+        indices = np.arange(start=0, stop=len(observations)).astype(np.int64)
+        split_indices = [indices[start_index_list[i]:start_index_list[i+1]]\
+                         .tolist() for i in range(0, len(start_index_list)-1)]
+
+        # Create list with episode lengths.
+        len_list = [len(s_i) for s_i in split_indices]
+
+        assert len(len_list) >= envs_per_batch, "Not enough saved " \
+                                                "episodes for this number " \
+                                                "of workers and nminibatches."
+
+        # Sort episode positions by length (longest first).
+        sort_buffer = np.argsort(len_list).tolist()[::-1]
+
+        # Create the stack lists and pre-fill them with the longest episodes.
+        stack_indices = []
+        for _ in range(envs_per_batch):
+            stack_indices.append(split_indices[sort_buffer[0]])
+            sort_buffer.pop(0)
+
+        # Add each remaining episode to the currently smallest stack.
+        for s_b in sort_buffer:
+            current_stack_lengths = [len(st_i) for st_i in stack_indices]
+            smallest_stack_pos = np.argmin(current_stack_lengths)
+            stack_indices[smallest_stack_pos] += split_indices[s_b]
+
+        # Create helper variables used for the data cycling.
+        pre_cycle_len = [len(st_i) for st_i in stack_indices]
+        max_len = max(pre_cycle_len)
+        min_len = min(pre_cycle_len)
+        mod_max_len = max_len % batch_size
+        final_stack_len = max_len + (batch_size - mod_max_len)
+
+        # Calculate split point for Train/Validation split.
+        split_point = int(train_fraction * final_stack_len * envs_per_batch)
+        split_point = split_point - (split_point % use_batch_size)
+
+        if mod_max_len > (min_len - (final_stack_len * envs_per_batch - split_point)) > 0:
+            warnings.warn('The episodes are divided unevenly; the validation set will '
+                          'be polluted with training data.')
+
+        # Cycle each stack onto itself to create enough data to split it into
+        # chunks of length batch_size.
+        cycle_indices = [list(islice(cycle(st_i), None, final_stack_len)) for st_i in stack_indices]
+
+        # Move the cycled data to the beginning so that it does not affect the validation set.
+        cycle_indices = [cycle_indices[i][pre_cycle_len[i]:] + cycle_indices[i][:pre_cycle_len[i]]\
+                         for i in range(len(pre_cycle_len))]
+
+        # Flatten the stacked cycle lists into a single list.
+        indices = []
+        for i in range(0, len(cycle_indices[0]), batch_size):
+            for c_i in cycle_indices:
+                indices += c_i[i:i+batch_size]
+
+        # Free memory
+        del split_indices, len_list, sort_buffer, stack_indices, max_len, mod_max_len,\
+            final_stack_len, cycle_indices
+
+        # Train/Validation split when using behavior cloning
+        train_indices = indices[:split_point]
+        val_indices = indices[split_point:]
+
+        assert len(train_indices) > 0, "No sample for the training set"
+        assert len(val_indices) > 0, "No sample for the validation set"
+
+        self.observations = observations
+        self.actions = actions
+        self.mask = mask
+
+        self.returns = traj_data['episode_returns'][:traj_limit_idx]
+        self.avg_ret = sum(self.returns) / len(self.returns)
+        self.std_ret = np.std(np.array(self.returns))
+        self.verbose = verbose
+
+        assert len(self.observations) == len(self.actions), "The number of actions and " \
+                                                            "observations differ, " \
+                                                            "please check your expert " \
+                                                            "dataset"
+        self.num_traj = min(traj_limitation, np.sum(episode_starts))
+        self.num_transition = len(self.observations)
+        self.sequential_preprocessing = sequential_preprocessing
+
+        self.dataloader = None
+        self.train_loader = DataLoader(train_indices,
+                                       self.observations,
+                                       self.actions,
+                                       self.mask,
+                                       use_batch_size,
+                                       start_process=False,
+                                       sequential=sequential_preprocessing,
+                                       partial_minibatch=False)
+
+        self.val_loader = DataLoader(val_indices,
+                                     self.observations,
+                                     self.actions,
+                                     self.mask,
+                                     use_batch_size,
+                                     start_process=False,
+                                     sequential=sequential_preprocessing,
+                                     partial_minibatch=False)
+
+        if self.verbose >= 1:
+            self.log_info()
+
+    def init_dataloader(self, batch_size):
         """
-        Show histogram plotting of the episode returns
+        Initialize the dataloader used by GAIL.
+
+        :param batch_size: (int)
         """
-        # Isolate dependency since it is only used for plotting and also since
-        # different matplotlib backends have further dependencies themselves. 
- import matplotlib.pyplot as plt - plt.hist(self.returns) - plt.show() + indices = np.random.permutation(len(self.observations)).astype(np.int64) + self.dataloader = DataLoader(indices, + self.observations, + self.actions, + self.mask, + batch_size, + start_process=False, + sequential=self.sequential_preprocessing) class DataLoader(object): @@ -193,7 +432,7 @@ class DataLoader(object): lesser than the batch_size) """ - def __init__(self, indices, observations, actions, batch_size, n_workers=1, + def __init__(self, indices, observations, actions, mask, batch_size, n_workers=1, infinite_loop=True, max_queue_len=1, shuffle=False, start_process=True, backend='threading', sequential=False, partial_minibatch=True): super(DataLoader, self).__init__() @@ -209,6 +448,7 @@ def __init__(self, indices, observations, actions, batch_size, n_workers=1, self.batch_size = batch_size self.observations = observations self.actions = actions + self.mask = mask self.shuffle = shuffle self.queue = Queue(max_queue_len) self.process = None @@ -243,7 +483,7 @@ def sequential_next(self): """ Sequential version of the pre-processing. """ - if self.start_idx > len(self.indices): + if self.start_idx >= len(self.indices): raise StopIteration if self.start_idx == 0: @@ -257,8 +497,10 @@ def sequential_next(self): axis=0) actions = self.actions[self._minibatch_indices] + mask = self.mask[self._minibatch_indices] self.start_idx += self.batch_size - return obs, actions + + return obs, actions, mask def _run(self): start = True @@ -286,8 +528,9 @@ def _run(self): obs = np.concatenate(obs, axis=0) actions = self.actions[self._minibatch_indices] + mask = self.mask[self._minibatch_indices] - self.queue.put((obs, actions)) + self.queue.put((obs, actions, mask)) # Free memory del obs diff --git a/stable_baselines/ppo1/pposgd_simple.py b/stable_baselines/ppo1/pposgd_simple.py index 9154e74db0..3e0421f8a2 100644 --- a/stable_baselines/ppo1/pposgd_simple.py +++ b/stable_baselines/ppo1/pposgd_simple.py @@ -85,9 +85,19 @@ def __init__(self, policy, env, gamma=0.99, timesteps_per_actorbatch=256, clip_p def _get_pretrain_placeholders(self): policy = self.policy_pi action_ph = policy.pdtype.sample_placeholder([None]) + + if self.policy.recurrent: + states_ph = policy.states_ph + snew_ph = policy.snew + dones_ph = policy.dones_ph + else: + states_ph = None + snew_ph = None + dones_ph = None + if isinstance(self.action_space, gym.spaces.Discrete): - return policy.obs_ph, action_ph, policy.policy - return policy.obs_ph, action_ph, policy.deterministic_action + return policy.obs_ph, action_ph, states_ph, snew_ph, dones_ph, policy.policy + return policy.obs_ph, action_ph, states_ph, snew_ph, dones_ph, policy.deterministic_action def setup_model(self): with SetVerbosity(self.verbose): diff --git a/stable_baselines/ppo2/ppo2.py b/stable_baselines/ppo2/ppo2.py index 86b69d9cd8..73ebc8e891 100644 --- a/stable_baselines/ppo2/ppo2.py +++ b/stable_baselines/ppo2/ppo2.py @@ -100,10 +100,21 @@ def __init__(self, policy, env, gamma=0.99, n_steps=128, ent_coef=0.01, learning self.setup_model() def _get_pretrain_placeholders(self): - policy = self.act_model + + if self.policy.recurrent: + policy = self.train_model + states_ph = policy.states_ph + snew_ph = policy.snew + dones_ph = policy.dones_ph + else: + policy = self.act_model + states_ph = None + snew_ph = None + dones_ph = None + if isinstance(self.action_space, gym.spaces.Discrete): - return policy.obs_ph, self.action_ph, policy.policy - return policy.obs_ph, self.action_ph, 
policy.deterministic_action + return policy.obs_ph, self.action_ph, states_ph, snew_ph, dones_ph, policy.policy + return policy.obs_ph, self.action_ph, states_ph, snew_ph, dones_ph, policy.deterministic_action def setup_model(self): with SetVerbosity(self.verbose): diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py index 88779d1ffb..99c2bd67a6 100644 --- a/stable_baselines/sac/sac.py +++ b/stable_baselines/sac/sac.py @@ -134,7 +134,7 @@ def _get_pretrain_placeholders(self): policy = self.policy_tf # Rescale deterministic_action = self.deterministic_action * np.abs(self.action_space.low) - return policy.obs_ph, self.actions_ph, deterministic_action + return policy.obs_ph, self.actions_ph, None, None, None, deterministic_action def setup_model(self): with SetVerbosity(self.verbose): diff --git a/stable_baselines/td3/td3.py b/stable_baselines/td3/td3.py index 307c76bc24..22ed948979 100644 --- a/stable_baselines/td3/td3.py +++ b/stable_baselines/td3/td3.py @@ -116,7 +116,7 @@ def _get_pretrain_placeholders(self): policy = self.policy_tf # Rescale policy_out = self.policy_out * np.abs(self.action_space.low) - return policy.obs_ph, self.actions_ph, policy_out + return policy.obs_ph, self.actions_ph, None, None, None, policy_out def setup_model(self): with SetVerbosity(self.verbose): diff --git a/stable_baselines/trpo_mpi/trpo_mpi.py b/stable_baselines/trpo_mpi/trpo_mpi.py index 1463030642..b8cb3e7966 100644 --- a/stable_baselines/trpo_mpi/trpo_mpi.py +++ b/stable_baselines/trpo_mpi/trpo_mpi.py @@ -99,9 +99,19 @@ def __init__(self, policy, env, gamma=0.99, timesteps_per_batch=1024, max_kl=0.0 def _get_pretrain_placeholders(self): policy = self.policy_pi action_ph = policy.pdtype.sample_placeholder([None]) + + if self.initial_state is None: + states_ph = None + snew_ph = None + dones_ph = None + else: + states_ph = policy.states_ph + snew_ph = policy.snew + dones_ph = policy.dones_ph + if isinstance(self.action_space, gym.spaces.Discrete): - return policy.obs_ph, action_ph, policy.policy - return policy.obs_ph, action_ph, policy.deterministic_action + return policy.obs_ph, action_ph, states_ph, snew_ph, dones_ph, policy.policy + return policy.obs_ph, action_ph, states_ph, snew_ph, dones_ph, policy.deterministic_action def setup_model(self): # prevent import loops @@ -436,7 +446,8 @@ def fisher_vector_product(vec): include_final_partial_batch=False, batch_size=batch_size, shuffle=True): - ob_expert, ac_expert = self.expert_dataset.get_next_batch() + ob_expert, ac_expert, _ = self.expert_dataset.get_next_batch() + # update running mean/std for reward_giver if self.reward_giver.normalize: self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0)) diff --git a/tests/test_gail.py b/tests/test_gail.py index 3d4a1608d8..b5fbce4316 100644 --- a/tests/test_gail.py +++ b/tests/test_gail.py @@ -6,9 +6,10 @@ from stable_baselines import A2C, ACER, ACKTR, GAIL, DDPG, DQN, PPO1, PPO2,\ TD3, TRPO, SAC + from stable_baselines.common.cmd_util import make_atari_env -from stable_baselines.common.vec_env import VecFrameStack -from stable_baselines.gail import ExpertDataset, generate_expert_traj +from stable_baselines.common.vec_env import VecFrameStack, DummyVecEnv +from stable_baselines.gail import ExpertDataset, ExpertDatasetLSTM, generate_expert_traj EXPERT_PATH_PENDULUM = "stable_baselines/gail/dataset/expert_pendulum.npz" EXPERT_PATH_DISCRETE = "stable_baselines/gail/dataset/expert_cartpole.npz" @@ -106,28 +107,60 @@ def test_pretrain_images(): del dataset, model, env 
-@pytest.mark.parametrize("model_class", [A2C, GAIL, DDPG, PPO1, PPO2, SAC, TD3, TRPO]) -def test_behavior_cloning_box(model_class): - """ - Behavior cloning with continuous actions. - """ - dataset = ExpertDataset(expert_path=EXPERT_PATH_PENDULUM, traj_limitation=10, - sequential_preprocessing=True, verbose=0) - model = model_class("MlpPolicy", "Pendulum-v0") - model.pretrain(dataset, n_epochs=20) - model.save("test-pretrain") - del dataset, model +@pytest.mark.parametrize("model_class_data", [[A2C, 4, True, "CartPole-v1", 32, 4], + [ACER, 4, True, "CartPole-v1", 1, 4], + [ACKTR, 4, True, "CartPole-v1", 16, 4], + [PPO2, 8, True, "CartPole-v1", 16, 2], + [A2C, 1, False, "CartPole-v1", 32, 1], + [ACER, 1, False, "CartPole-v1", 32, 1], + [ACKTR, 1, False, "CartPole-v1", 32, 1], + [DQN, 1, False, "CartPole-v1", 32, 1], + [GAIL, 1, False, "CartPole-v1", 32, 1], + [PPO1, 1, False, "CartPole-v1", 32, 1], + [PPO2, 1, False, "CartPole-v1", 32, 1], + [TRPO, 1, False, "CartPole-v1", 32, 1], + [A2C, 4, True, "Pendulum-v0", 32, 4], + [PPO2, 8, True, "Pendulum-v0", 16, 2], + [A2C, 1, False, "Pendulum-v0", 32, 1], + [GAIL, 1, False, "Pendulum-v0", 32, 1], + [PPO1, 1, False, "Pendulum-v0", 32, 1], + [PPO2, 1, False, "Pendulum-v0", 32, 1], + [TRPO, 1, False, "Pendulum-v0", 32, 1], + [TD3, 1, False, "Pendulum-v0", 32, 1] + ]) + +def test_behavior_cloning(model_class_data): + + model_class, num_env, lstm, game, batch_size, envs_per_batch = model_class_data + + if game == "CartPole-v1": + load_data = EXPERT_PATH_DISCRETE + else: + load_data = EXPERT_PATH_PENDULUM + if lstm: + dataset = ExpertDatasetLSTM(expert_path=load_data, traj_limitation=3, + sequential_preprocessing=True, verbose=0, + batch_size=batch_size, envs_per_batch=envs_per_batch) + policy = "MlpLstmPolicy" -@pytest.mark.parametrize("model_class", [A2C, ACER, ACKTR, DQN, GAIL, PPO1, PPO2, TRPO]) -def test_behavior_cloning_discrete(model_class): - dataset = ExpertDataset(expert_path=EXPERT_PATH_DISCRETE, traj_limitation=10, - sequential_preprocessing=True, verbose=0) - model = model_class("MlpPolicy", "CartPole-v1") - model.pretrain(dataset, n_epochs=10) + else: + dataset = ExpertDataset(expert_path=load_data, traj_limitation=3, + sequential_preprocessing=True, verbose=0, + batch_size=batch_size) + policy = "MlpPolicy" + + env = DummyVecEnv([lambda: gym.make(game) for i in range(num_env)]) + + try: + model = model_class(policy, env, n_steps=batch_size) + except TypeError: + model = model_class(policy, env) + + model.pretrain(dataset, n_epochs=3) model.save("test-pretrain") - del dataset, model + del dataset, model, env def test_dataset_param_validation(): with pytest.raises(ValueError):
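A minimal usage sketch for the recurrent pretraining path introduced by this patch, similar in spirit to the parametrised test above. It is not part of the patch: the environment, expert file and hyper-parameters are illustrative. The dataset's batch_size should equal the model's n_steps, and envs_per_batch should equal the number of environments the recurrent policy processes per minibatch (n_envs // nminibatches for PPO2).

import gym

from stable_baselines import PPO2
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.gail import ExpertDatasetLSTM

n_envs = 4
env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(n_envs)])

# batch_size matches the model's n_steps; envs_per_batch matches n_envs // nminibatches.
dataset = ExpertDatasetLSTM(expert_path="stable_baselines/gail/dataset/expert_cartpole.npz",
                            traj_limitation=10, batch_size=16, envs_per_batch=2,
                            sequential_preprocessing=True, verbose=0)

model = PPO2("MlpLstmPolicy", env, n_steps=16, nminibatches=2)
model.pretrain(dataset, n_epochs=10)
model.learn(total_timesteps=10000)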