-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
draft for rollin implementation #12
base: development
Are you sure you want to change the base?
Conversation
rollin_samples = random.sample(sample_ids, nr_samples) | ||
with torch.no_grad(): | ||
# restore | ||
for id_, sample in sample_stack: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@peter-makarov Here my idea was to sample an increasing size of samples each epoch while sampling a different subset each epoch (and setting the previous subset to the previous status (teacher forcing))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense.
trans/transducer.py
Outdated
@@ -539,6 +539,99 @@ def continue_decoding(): | |||
return Output(action_history, self.decode_encoded_output(input_, action_history), | |||
log_p, None) | |||
|
|||
def roll_in(self, sample: utils.Sample, rollin: int) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation is for a single sample at the moment, the decoder steps could probably be batched (and sampling / exploration as loop). This would likely increase speed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rollin
should probably be a float
representing probability?
# expert prediction | ||
expert_actions = self.expert_rollout(sample.input, sample.target, | ||
current_alignment.item(), output) | ||
optimal_actions.append(expert_actions) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The targets are always generated by expert.
optimal_actions.append(expert_actions) | ||
|
||
# update states | ||
if np.random.rand() <= rollin: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The next state is either sampled from the expert or the model itself. Is this the idea? Alternatively, one could always execute the model (and only rely on the expert for the targets).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's the idea!
@peter-makarov So I've drafted an implementation. I've tested it a bit and it seems to work somehow, however, I think I am still missing something... |
for id_ in rollin_samples: | ||
sample_stack.append((id_, dataclasses.replace(training_data.samples[id_]))) | ||
transducer_.roll_in(training_data.samples[id_], rollin) | ||
|
||
j = 0 | ||
for j, batch in enumerate(training_data_loader): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the current setup, rollin is performed before each epoch. I think a problem with this approach could be that the model performs update during the epoch and the rolled in target sequences are not representative anymore for the model. So maybe it would make more sense to rollin after every training step with some probability.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, it would be more sound and more useful to sample for each batch (this will address the errors of the current model checkpoint, not the errors that may have resolved themselves anyway already due to the recent parameter updates).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
for id_ in rollin_samples: | ||
sample_stack.append((id_, dataclasses.replace(training_data.samples[id_]))) | ||
transducer_.roll_in(training_data.samples[id_], rollin) | ||
|
||
j = 0 | ||
for j, batch in enumerate(training_data_loader): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, it would be more sound and more useful to sample for each batch (this will address the errors of the current model checkpoint, not the errors that may have resolved themselves anyway already due to the recent parameter updates).
trans/transducer.py
Outdated
@@ -539,6 +539,99 @@ def continue_decoding(): | |||
return Output(action_history, self.decode_encoded_output(input_, action_history), | |||
log_p, None) | |||
|
|||
def roll_in(self, sample: utils.Sample, rollin: int) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rollin
should probably be a float
representing probability?
optimal_actions.append(expert_actions) | ||
|
||
# update states | ||
if np.random.rand() <= rollin: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's the idea!
action = sample_action | ||
else: | ||
action = expert_actions[ | ||
int(np.argmax([log_probs_np[a] for a in expert_actions])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this does not over-corrects the model if it already predicts an optimal action, in case multiple actions are optimal.
if char != "": | ||
output.append(char) | ||
|
||
alignment_history = torch.cat( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't seem like this has to be done via concatenation in a loop. Can this not be re-written using list append?
] | ||
|
||
action_history = torch.cat( | ||
(action_history, torch.tensor([[[action]]], device=self.device)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment regarding concatenation in the loop.
Is there no improvement? Test it on some little data (e.g. morphological inflection 100 samples) and batch size 1. Try with roll-in and without roll-in. You should be seeing consistent improvement. |
No description provided.