forked from openai/baselines
Closed
Labels
question (Further information is requested)
Description
I am trying to train a GAIL model with MlpLstmPolicy in Stable_Baselines2, but the training process fails to run even though I put the assertion
assert issubclass(self.policy, LstmPolicy)
in the TRPO part. Are there any other changes I should make? If there is no other solution, how can I customize an LSTM policy myself so that it is compatible with GAIL?
Looking forward to your reply!
The error is shown here:
Traceback (most recent call last):
  File "/home/.../train-recurrentGail.py", line 18, in <module>
    model.learn(total_timesteps=100000)
  File "/home/.../model/gail/model.py", line 57, in learn
    return super().learn(total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps)
  File "/home/.../model/gail/myTrpo.py", line 364, in learn
    seg = seg_gen.__next__()
  File "/home/.../common/runners.py", line 118, in traj_segment_generator
    action, vpred, states, _ = policy.step(observation.reshape(-1, *observation.shape), states, done)
  File "/home/.../stable_baselines/common/policies.py", line 508, in step
    {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask})
  File "/home/.../python/client/session.py", line 956, in run
    run_metadata_ptr)
  File "/home/.../python/client/session.py", line 1156, in _run
    (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape () for Tensor 'input_1/dones_ph:0', which has shape '(1,)'
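The last line of the traceback says that the value fed to dones_ph is a scalar (shape ()), while the placeholder expects a one-element vector (shape (1,)). The sketch below only illustrates that shape mismatch with NumPy. The helper as_done_mask is a hypothetical name, not part of Stable Baselines, and it is an assumption that the generator in runners.py runs a single environment and passes a plain Python bool as done. Whether fixing the shape alone is enough to make TRPO or GAIL train correctly with a recurrent policy is not something the traceback establishes.

```python
import numpy as np


def as_done_mask(done, n_envs=1):
    """Return `done` as a 1-D float array of shape (n_envs,).

    Hypothetical helper for illustration only: it wraps a scalar done
    flag so that its shape matches a placeholder of shape (n_envs,).
    """
    mask = np.asarray(done, dtype=np.float32).reshape(-1)
    if mask.shape != (n_envs,):
        raise ValueError(
            "expected {} done flag(s), got shape {}".format(n_envs, mask.shape)
        )
    return mask


scalar_done = False                       # a plain bool, as a single env returns it
print(np.asarray(scalar_done).shape)      # the 0-d shape the error complains about
print(as_done_mask(scalar_done).shape)    # the 1-d shape the placeholder expects
```

Under these assumptions, the call shown in the traceback could pass as_done_mask(done) instead of done as the third argument to policy.step.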