draft for rollin implementation #12
base: development
File 1 (training script):

@@ -4,6 +4,7 @@
 import os
 import random
 import sys
+import dataclasses

 import progressbar

@@ -305,7 +306,7 @@ def train_worker_init_fn(worker_id):
     train_subset_loader = utils.Dataset(
        random.sample(training_data.samples, int(len(training_data.samples) * args.train_subset_eval_size / 100))) \
        .get_data_loader(batch_size=eval_batch_size, device=args.device)
-    # rollin_schedule = inverse_sigmoid_schedule(args.k)
+    rollin_schedule = inverse_sigmoid_schedule(args.k)
     max_patience = args.patience

     if args.loss_reduction == "sum":
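For context, `inverse_sigmoid_schedule` itself is not part of this diff. A minimal sketch of what such a schedule usually looks like (the inverse-sigmoid decay familiar from scheduled sampling; the exact form in this repository may differ), returning the model roll-in probability for a given epoch:

```python
import math


def inverse_sigmoid_schedule(k: int):
    """Assumed form: the expert share k / (k + exp(epoch / k)) decays
    from near 1 towards 0, so the model roll-in probability grows."""
    return lambda epoch: 1.0 - k / (k + math.exp(epoch / k))


# e.g. k=12: epoch 0 -> ~0.08, epoch 30 -> ~0.50, epoch 60 -> ~0.93
```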
@@ -322,15 +323,31 @@ def train_worker_init_fn(worker_id):
     best_epoch = 0
     patience = 0

+    sample_ids = list(range(len(training_data.samples)))
+    sample_stack = []
+
     for epoch in range(args.epochs):

         logging.info("Training...")
         transducer_.train()
         transducer_.zero_grad()
         with utils.Timer():
             train_loss = 0.
-            # rollin not implemented at the moment
-            # rollin = rollin_schedule(epoch)
+            # rollin
+            if args.rollin is not None and args.rollin:
+                rollin = rollin_schedule(epoch)
+                nr_samples = int(len(training_data.samples) * rollin)
+                rollin_samples = random.sample(sample_ids, nr_samples)
+                with torch.no_grad():
+                    # restore
+                    for id_, sample in sample_stack:
+                        training_data.samples[id_] = sample
+                    sample_stack = []
+                    for id_ in rollin_samples:
+                        sample_stack.append((id_, dataclasses.replace(training_data.samples[id_])))
+                        transducer_.roll_in(training_data.samples[id_], rollin)

             j = 0
             for j, batch in enumerate(training_data_loader):
                 losses = transducer_.training_step(encoded_input=batch.encoded_input,

Review comment: With the current setup, rollin is performed before each epoch. A problem with this approach could be that the model performs updates during the epoch, so the rolled-in target sequences are no longer representative of the model. Maybe it would make more sense to roll in after every training step, with some probability.

Reply: Agreed, it would be more sound and more useful to sample for each batch (this will address the errors of the current model checkpoint, not the errors that may already have resolved themselves due to the recent parameter updates).
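A minimal sketch of the per-batch variant discussed in this thread, reusing the draft's names (`sample_stack`, `sample_ids`, `roll_in`); this illustrates the reviewers' suggestion and is not code from the PR:

```python
import dataclasses
import random

import torch


def roll_in_per_step(transducer_, training_data, sample_ids, sample_stack, rollin):
    """Hypothetical per-step roll-in: restore the previously rolled-in
    subset, then roll in a fresh subset with the current parameters."""
    with torch.no_grad():
        # undo the previous roll-in (back to teacher-forcing targets)
        for id_, sample in sample_stack:
            training_data.samples[id_] = sample
        sample_stack.clear()
        # re-sample so the targets reflect the latest parameter updates
        nr_samples = int(len(training_data.samples) * rollin)
        for id_ in random.sample(sample_ids, nr_samples):
            sample_stack.append((id_, dataclasses.replace(training_data.samples[id_])))
            transducer_.roll_in(training_data.samples[id_], rollin)
```

Calling this once per batch instead of once per epoch keeps the rolled-in sequences aligned with the current checkpoint, at the cost of extra decoding time.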
@@ -478,8 +495,10 @@ def cli_main():
                         help="Number of decoder LSTM layers.")
     parser.add_argument("--beam-width", type=int, default=4,
                         help="Beam width for beam search decoding. A value < 1 will disable beam search decoding.")
-    # parser.add_argument("--k", type=int, default=1,
-    #                     help="k for inverse sigmoid rollin schedule.")
+    parser.add_argument("--rollin", action="store_true", default=False,
+                        help="Rollin.")
+    parser.add_argument("--k", type=int, default=1,
+                        help="k for inverse sigmoid rollin schedule.")
     parser.add_argument("--patience", type=int, default=12,
                         help="Maximal patience for early stopping.")
     parser.add_argument("--epochs", type=int, default=60,
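With these flags, a roll-in training run might be invoked along these lines (the entry point `trans.train` is an assumption; substitute the repository's actual script):

```
python -m trans.train --rollin --k 12 --epochs 60 --patience 12
```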
File 2 (transducer module):

@@ -15,7 +15,7 @@
 from trans.vocabulary import BEGIN_WORD, COPY, DELETE, END_WORD, PAD, \
     FeatureVocabularies
 from trans import ENCODER_MAPPING
-
+from trans import utils

 MAX_ACTION_SEQ_LEN = 150
 MAX_INPUT_SEQ_LEN = 100
@@ -539,6 +539,99 @@ def continue_decoding():
         return Output(action_history, self.decode_encoded_output(input_, action_history),
                       log_p, None)

+    def roll_in(self, sample: utils.Sample, rollin: float) -> None:
+        """Runs the transducer over a single sample, mixing sampled model
+        actions with expert actions, and stores the resulting histories
+        and masks on the sample.
+
+        Args:
+            sample: A single sample from the training data.
+            rollin: Probability of executing the model's sampled action
+                instead of an expert action at each step.
+
+        Returns:
+            None."""

Review comment: The implementation is for a single sample at the moment; the decoder steps could probably be batched (with sampling / exploration as a loop). This would likely increase speed.

+        # adjust initial decoder states if batch_size has changed
+        self.h0_c0 = 1  # batch size is always one
+
+        # initialize state variables
+        alignment_history = torch.tensor([[0]], device=self.device)
+        action_history = torch.tensor([[[BEGIN_WORD]]],
+                                      device=self.device, dtype=torch.int)
+        input_length = torch.tensor([len(sample.input) + 1], device=self.device)
+        output = []
+        optimal_actions = [[BEGIN_WORD]]
+
+        # run encoder
+        bidirectional_emb = self.encoder_step(sample.encoded_input.unsqueeze(dim=0))
+
+        # compute feature embedding
+        feature_emb = self.feature_embedding(sample.encoded_features)
+
+        # initial cell state for decoder
+        decoder = self.h0_c0
+
+        stop = False
+        while not stop and action_history.size(2) <= MAX_ACTION_SEQ_LEN:
+            current_alignment = alignment_history[:, -1]
+            valid_actions_mask = self.valid_actions_lookup[:, input_length - current_alignment]
+
+            # run decoder
+            decoder_output, decoder = self.decoder_step(
+                bidirectional_emb, feature_emb, decoder,
+                current_alignment, action_history[:, :, -1])
+
+            # model's predictions
+            actions, log_probs = self.calculate_actions(decoder_output, valid_actions_mask)
+            log_probs_np = log_probs.squeeze().cpu().detach().numpy()
+            sample_action = self.sample(log_probs_np)
+
+            # expert prediction
+            expert_actions = self.expert_rollout(sample.input, sample.target,
+                                                 current_alignment.item(), output)
+            optimal_actions.append(expert_actions)

Review comment: The targets are always generated by the expert.
+            # update states
+            if np.random.rand() <= rollin:
+                action = sample_action
+            else:
+                action = expert_actions[
+                    int(np.argmax([log_probs_np[a] for a in expert_actions]))
+                ]

Review comment: The next state is either sampled from the expert or the model itself. Is this the idea? Alternatively, one could always execute the model (and only rely on the expert for the targets).

Reply: Yes, that's the idea!

Review comment: So this does not over-correct the model if it already predicts an optimal action, in case multiple actions are optimal.
+            action_history = torch.cat(
+                (action_history, torch.tensor([[[action]]], device=self.device)),
+                dim=2
+            )

Review comment: Same comment regarding concatenation in the loop.

+            char, current_alignment, stop = self.decode_single_action(sample.input, action, current_alignment.item())
+            if char != "":
+                output.append(char)
+
+            alignment_history = torch.cat(
+                (alignment_history, torch.tensor([[current_alignment]], device=self.device)),
+                dim=1
+            )

Review comment: It doesn't seem like this has to be done via concatenation in a loop. Can this not be rewritten using a list append?
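To illustrate the suggested rewrite: appending to a Python list and materializing the tensor once after the loop avoids the repeated allocation and copying that `torch.cat` incurs at every step. A small self-contained comparison:

```python
import torch

# torch.cat in a loop: allocates and copies a growing tensor each step
hist = torch.tensor([[0]])
for step in range(5):
    hist = torch.cat((hist, torch.tensor([[step]])), dim=1)

# list append + a single tensor construction after the loop
steps = [0]
for step in range(5):
    steps.append(step)
hist2 = torch.tensor([steps])

assert hist.tolist() == hist2.tolist()
```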
+        # build masks
+        action_history = action_history[0, 0, :].tolist()
+        alignment_history = alignment_history[0, :].tolist()
+
+        optimal_actions_mask = torch.full(
+            (len(optimal_actions) - 1, self.number_actions),
+            False, dtype=torch.bool, device=self.device)
+        seq_pos, emb_pos = zip(*[(s - 1, a) for s in range(1, len(optimal_actions))
+                                 for a in optimal_actions[s]])
+        optimal_actions_mask[seq_pos, emb_pos] = True
+        sample.optimal_actions_mask = optimal_actions_mask
+
+        sample.alignment_history = torch.tensor(alignment_history[:-1], device=self.device)
+        sample.action_history = torch.tensor(action_history[:-1], device=self.device)
+
+        valid_actions_mask = torch.stack(
+            [self.compute_valid_actions(len(sample.input) + 1 - a) for a in alignment_history[:-1]], dim=0)
+        sample.valid_actions_mask = valid_actions_mask
+
+        return None
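The mask construction above scatters each step's set of optimal expert actions into a boolean matrix of shape (sequence length, number of actions), skipping entry 0 of `optimal_actions` (the BEGIN_WORD placeholder). A small worked example of the same idiom, with made-up action ids:

```python
import torch

number_actions = 5
optimal_actions = [[1], [2, 4], [0]]  # index 0 is the BEGIN_WORD placeholder

mask = torch.full((len(optimal_actions) - 1, number_actions),
                  False, dtype=torch.bool)
seq_pos, emb_pos = zip(*[(s - 1, a) for s in range(1, len(optimal_actions))
                         for a in optimal_actions[s]])
mask[seq_pos, emb_pos] = True
# mask[0] now flags actions 2 and 4; mask[1] flags action 0
```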

     def decode_encoded_output(self, input_: List[List[str]], encoded_output: List[List[int]]) -> List[str]:
         """Decode a list of encoded output sequences given their string input.
Review comment: @peter-makarov Here my idea was to sample an increasing number of samples each epoch, while drawing a different subset each epoch (and resetting the previous subset to its earlier state, i.e. teacher forcing).

Reply: Makes sense.