Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draft for rollin implementation #12

Draft
wants to merge 2 commits into
base: development
Choose a base branch
from
Draft

Conversation

slvnwhrl
Copy link
Owner

No description provided.

@slvnwhrl slvnwhrl self-assigned this Jun 16, 2022
rollin_samples = random.sample(sample_ids, nr_samples)
with torch.no_grad():
# restore
for id_, sample in sample_stack:
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@peter-makarov Here my idea was to sample an increasing size of samples each epoch while sampling a different subset each epoch (and setting the previous subset to the previous status (teacher forcing))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense.

@@ -539,6 +539,99 @@ def continue_decoding():
return Output(action_history, self.decode_encoded_output(input_, action_history),
log_p, None)

def roll_in(self, sample: utils.Sample, rollin: int) -> None:
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation is for a single sample at the moment, the decoder steps could probably be batched (and sampling / exploration as loop). This would likely increase speed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rollin should probably be a float representing probability?

# expert prediction
expert_actions = self.expert_rollout(sample.input, sample.target,
current_alignment.item(), output)
optimal_actions.append(expert_actions)
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The targets are always generated by expert.

optimal_actions.append(expert_actions)

# update states
if np.random.rand() <= rollin:
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The next state is either sampled from the expert or the model itself. Is this the idea? Alternatively, one could always execute the model (and only rely on the expert for the targets).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's the idea!

@slvnwhrl
Copy link
Owner Author

@peter-makarov So I've drafted an implementation. I've tested it a bit and it seems to work somehow, however, I think I am still missing something...

for id_ in rollin_samples:
sample_stack.append((id_, dataclasses.replace(training_data.samples[id_])))
transducer_.roll_in(training_data.samples[id_], rollin)

j = 0
for j, batch in enumerate(training_data_loader):
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the current setup, rollin is performed before each epoch. I think a problem with this approach could be that the model performs update during the epoch and the rolled in target sequences are not representative anymore for the model. So maybe it would make more sense to rollin after every training step with some probability.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, it would be more sound and more useful to sample for each batch (this will address the errors of the current model checkpoint, not the errors that may have resolved themselves anyway already due to the recent parameter updates).

Copy link
Collaborator

@peter-makarov peter-makarov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

for id_ in rollin_samples:
sample_stack.append((id_, dataclasses.replace(training_data.samples[id_])))
transducer_.roll_in(training_data.samples[id_], rollin)

j = 0
for j, batch in enumerate(training_data_loader):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, it would be more sound and more useful to sample for each batch (this will address the errors of the current model checkpoint, not the errors that may have resolved themselves anyway already due to the recent parameter updates).

@@ -539,6 +539,99 @@ def continue_decoding():
return Output(action_history, self.decode_encoded_output(input_, action_history),
log_p, None)

def roll_in(self, sample: utils.Sample, rollin: int) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rollin should probably be a float representing probability?

optimal_actions.append(expert_actions)

# update states
if np.random.rand() <= rollin:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's the idea!

action = sample_action
else:
action = expert_actions[
int(np.argmax([log_probs_np[a] for a in expert_actions]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this does not over-corrects the model if it already predicts an optimal action, in case multiple actions are optimal.

if char != "":
output.append(char)

alignment_history = torch.cat(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't seem like this has to be done via concatenation in a loop. Can this not be re-written using list append?

]

action_history = torch.cat(
(action_history, torch.tensor([[[action]]], device=self.device)),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment regarding concatenation in the loop.

@peter-makarov
Copy link
Collaborator

@peter-makarov So I've drafted an implementation. I've tested it a bit and it seems to work somehow, however, I think I am still missing something...

Is there no improvement?

Test it on some little data (e.g. morphological inflection 100 samples) and batch size 1. Try with roll-in and without roll-in. You should be seeing consistent improvement.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants