
[question] How to use MlpLstmPolicy in GAIL training? #1148

@LongchaoDa

I was training a GAIL model with MlpLstmPolicy in Stable Baselines 2. However, I could not run the training successfully, even though I made the assertion in the TRPO part read:

assert issubclass(self.policy, LstmPolicy)

Are there any other changes I should make? If there is no other solution, how can I customize an LSTM policy myself so that it is compatible with GAIL?

Looking forward to your reply!

The error is:

Traceback (most recent call last):
  File "/home/.../train-recurrentGail.py", line 18, in <module>
    model.learn(total_timesteps=100000)
  File "/home/.../model/gail/model.py", line 57, in learn
    return super().learn(total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps)
  File "/home/.../model/gail/myTrpo.py", line 364, in learn
    seg = seg_gen.__next__()
  File "/home/.../common/runners.py", line 118, in traj_segment_generator
    action, vpred, states, _ = policy.step(observation.reshape(-1, *observation.shape), states, done)
  File "/home/.../stable_baselines/common/policies.py", line 508, in step
    {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask})
  File "/home/.../python/client/session.py", line 956, in run
    run_metadata_ptr)
  File "/home/.../python/client/session.py", line 1156, in _run
    (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape () for Tensor 'input_1/dones_ph:0', which has shape '(1,)'
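
The last frames of the traceback indicate that the `done` value passed as the third argument of `policy.step` reaches `dones_ph` as a scalar (shape `()`), while the placeholder expects one flag per environment (shape `(1,)`). A minimal sketch of one possible workaround, assuming a single environment and a scalar `done` coming from the runner loop, is to wrap the flag in a 1-D array before the call. The helper name `as_batch` and its `n_envs` parameter are hypothetical and not part of Stable Baselines.

```python
import numpy as np


def as_batch(value, n_envs=1):
    """Return `value` as a 1-D array with one entry per environment.

    Hypothetical helper: a scalar such as `done=False` has shape (),
    which cannot be fed to a placeholder of shape (n_envs,).
    """
    arr = np.asarray(value).reshape(-1)  # () -> (1,), (n,) stays (n,)
    if arr.shape != (n_envs,):
        raise ValueError(
            "expected {} flag(s), got shape {}".format(n_envs, arr.shape)
        )
    return arr


done = False           # scalar flag, as assumed for a single environment
mask = as_batch(done)  # 1-D array of shape (1,)
print(np.asarray(done).shape, mask.shape)
```

With this helper, the call in the runner could read `policy.step(observation.reshape(-1, *observation.shape), states, as_batch(done))`. This only addresses the shape mismatch reported in the traceback; whether the GAIL/TRPO code needs further changes for recurrent policies is not confirmed here.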
