adding LSTM support to pretrain #315
base: master
Conversation
add frame index alignment to ExpertDataset
update fork
# Conflicts:
#   stable_baselines/gail/dataset/dataset.py
#   stable_baselines/trpo_mpi/trpo_mpi.py
#   tests/test_gail.py
The problem is that …
convert float envs_per_batch into int envs_per_batch
Hello, do you consider this PR ready for review? (After a quick look, I saw that a saved file was still there (nano) and there seems to be some code duplication that can be improved ;))
o/
I removed nano.
I am not quite sure, but I think you are referring to the … The only thing I still plan to do is to add a bit more functionality to it. The code that is already there is finalized so far and can be reviewed.
The test has failed after updating my branch.
You can ignore this, I've attempted a fix in #467
Now that my PR finally passes all the unit tests again, could you start reviewing it, so that I can change it if necessary?
araffin left a comment
Please remove the test_recorded_images folder too.
stable_baselines/a2c/a2c.py
Outdated
    def _get_pretrain_placeholders(self):
        policy = self.train_model
        ...
        if self.initial_state is None:
You can do: states_ph, snew_ph, dones_ph = None, None, None so it's more compact; same for the else case.
I think having the variable declarations both vertical and horizontal interrupts the reading flow. Yes, it would make the code shorter, but also less readable in my opinion. But I will change it if you really want it that way.
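For reference, a minimal sketch of the two styles being discussed (placeholder names taken from the comment above; both behave identically, so this is purely a readability question):

    # compact, single-line form suggested by the reviewer
    states_ph, snew_ph, dones_ph = None, None, None

    # vertical form currently in the PR, one assignment per line
    states_ph = None
    snew_ph = None
    dones_ph = None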
stable_baselines/acer/acer_simple.py
Outdated
    policy = self.step_model
    action_ph = policy.pdtype.sample_placeholder([None])
    ...
    if self.initial_state is None:
same remark as before
stable_baselines/a2c/a2c.py
Outdated
    def _get_pretrain_placeholders(self):
        policy = self.train_model
        ...
        if self.initial_state is None:
You should rather check the recurrent attribute of the policy; it is in the base policy class.
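A rough sketch of what that could look like, assuming the policy class exposes a recurrent flag (as the comment says); the placeholder attribute names for the recurrent inputs and the exact tuple returned in this PR are assumptions on my part:

    def _get_pretrain_placeholders(self):
        policy = self.train_model
        action_ph = policy.pdtype.sample_placeholder([None])
        # branch on the policy's recurrent attribute instead of self.initial_state
        if self.policy.recurrent:
            # states_ph / masks_ph are assumed names for the recurrent placeholders
            return policy.obs_ph, action_ph, policy.policy, policy.states_ph, policy.masks_ph
        return policy.obs_ph, action_ph, policy.policy, None, None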
    self.sess = None
    self.params = None
    self._param_load_ops = None
    self.initial_state = None
You don't need that variable; there is the recurrent attribute for that.
    else:
        val_interval = int(n_epochs / 10)
    ...
    use_lstm = self.initial_state is not None
same remark, you can use the recurrent attribute
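So the flag in the pretrain loop could simply be derived from that attribute, e.g. (sketch, assuming the same recurrent flag as above):

    # instead of: use_lstm = self.initial_state is not None
    use_lstm = self.policy.recurrent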
    for epoch_idx in range(int(n_epochs)):
        train_loss = 0.0
        if use_lstm:
            state = self.initial_state[:envs_per_batch]
initial state is an attribute of the policy
Yes and no.
All the models which can use LSTM policies have the variable self.initial_state, which gets set to the initial state from the policy. It is this self.initial_state that gets used, not the one in the policy. It is also not that easy to access the initial state from BaseRLModel. It was much simpler to add the self.initial_state variable to the base model and then let it be overwritten later at model initialization.
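A simplified mock of the pattern being described (not the actual stable-baselines signatures, just the idea of a default on the base class that recurrent models overwrite once their policy instance exists):

    class BaseRLModel:
        def __init__(self):
            # default for non-recurrent models; recurrent models overwrite it at setup
            self.initial_state = None

    class A2C(BaseRLModel):
        def setup_model(self, step_model):
            # once the step model (and its policy) exists, copy its initial state
            # so pretrain()/predict() can use self.initial_state directly
            self.initial_state = step_model.initial_state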
    if use_lstm:
        feed_dict.update({states_ph: state, dones_ph: expert_mask})
        val_loss_, = self.sess.run([loss], feed_dict)
You only need to update the feed_dict; self.sess.run can be called outside, so you avoid code duplication.
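That is, something along these lines (a fragment of the pretrain loop; states_ph, dones_ph, state and expert_mask follow the diff above, while obs_ph, actions_ph, expert_obs and expert_actions are how I recall the existing pretrain code naming things, so treat them as assumptions):

    feed_dict = {obs_ph: expert_obs, actions_ph: expert_actions}
    if use_lstm:
        # only the recurrent placeholders differ between the two cases
        feed_dict.update({states_ph: state, dones_ph: expert_mask})
    # single sess.run shared by both cases, so nothing is duplicated
    val_loss_, = self.sess.run([loss], feed_dict)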
    :param traj_limitation: (int) the number of trajectory to use (if -1, load all)
    :param randomize: (bool) if the dataset should be shuffled
    :param randomize: (bool) if the dataset should be shuffled, this will be overwritten to False
        if LSTM is True.
? where is the LSTM variable?
    except StopIteration:
        dataloader = iter(dataloader)
        return next(dataloader)
    if traj_data is not None and expert_path is not None:
Can't this be checked in the base class?
Looks like duplicated code.
Also, I'm not sure if two classes are needed...
I originally had it in one class, but someone who used the PR suggested splitting it into two classes. I think that was a good idea, because it clearly improved the user-friendliness of the ExpertDataset class.
tests/test_gail.py
Outdated
    model.pretrain(dataset, n_epochs=20)
    model.save("test-pretrain")
    del dataset, model
    @pytest.mark.parametrize("model_class_data", [[A2C, 4, True, "MlpLstmPolicy",
Looks like duplicated code; I think you can handle the different types of policy and expert path in the if.
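One possible shape for that, with a single parametrized test that picks the expert data inside the body (the parametrization and file names below are made up for illustration, not the PR's actual test):

    import pytest
    from stable_baselines import A2C

    @pytest.mark.parametrize("policy", ["MlpPolicy", "MlpLstmPolicy"])
    def test_pretrain(policy):
        # choose the expert data matching the policy type inside the test,
        # instead of duplicating the whole test body (paths are hypothetical)
        expert_path = "expert_lstm.npz" if "Lstm" in policy else "expert.npz"
        model = A2C(policy, "CartPole-v1")
        # ... build the ExpertDataset from expert_path, call model.pretrain(), assert as before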
…aselines into LSTM-pretrain
# Conflicts:
#   stable_baselines/gail/dataset/dataset.py
Hi everyone, I would like to know whether there is still active work on "LSTM support to pretrain". I've seen that this feature was removed from the v2.8.0 milestone more than a month ago. Is the work on hold? Kind regards!
@skervim |
Hello @skervim, as @XMaster96 says, you can use his fork for now if you want to try the feature.
Are you referring to the requested changes, which I have partially implemented and partially commented on to explain why I think I shouldn't change them? Or are you referring to future change requests? I am also aware that I haven't yet written documentation for the website; I was planning on doing that once everything is OK and merge-ready.
This PR adds LSTM support to pretrain. I am not quite done yet, but there are some implementation matters that I need to discuss first.
Personal edit: I finally found the time to work more on this PR. The problem is that I took so long that I forgot half of what I did, so if there is some rough code in there, that is why. I still do not have that much time, so expect me not to answer immediately.
closes #253