draft for rollin implementation #12

Draft: wants to merge 2 commits into base: development
29 changes: 24 additions & 5 deletions trans/train.py
@@ -4,6 +4,7 @@
import os
import random
import sys
import dataclasses

import progressbar

@@ -305,7 +306,7 @@ def train_worker_init_fn(worker_id):
train_subset_loader = utils.Dataset(
random.sample(training_data.samples, int(len(training_data.samples) * args.train_subset_eval_size / 100))) \
.get_data_loader(batch_size=eval_batch_size, device=args.device)
rollin_schedule = inverse_sigmoid_schedule(args.k)
max_patience = args.patience

if args.loss_reduction == "sum":
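inverse_sigmoid_schedule itself is not shown in this diff. A plausible sketch, assuming it follows the inverse sigmoid decay of scheduled sampling (Bengio et al., 2015) with the roll-in probability growing towards 1 over the epochs (the repo's actual definition may differ):

import math

def inverse_sigmoid_schedule(k: int):
    # hypothetical sketch: maps an epoch to a roll-in probability that rises
    # from near 0 towards 1; the complement of the inverse sigmoid decay
    return lambda epoch: 1.0 - k / (k + math.exp(epoch / k))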
@@ -322,15 +323,31 @@ def train_worker_init_fn(worker_id):
best_epoch = 0
patience = 0

sample_ids = list(range(len(training_data.samples)))
sample_stack = []

for epoch in range(args.epochs):

logging.info("Training...")
transducer_.train()
transducer_.zero_grad()
with utils.Timer():
train_loss = 0.
# roll-in: re-decode an increasing, randomly drawn subset of the training samples
if args.rollin:
rollin = rollin_schedule(epoch)
nr_samples = int(len(training_data.samples) * rollin)
rollin_samples = random.sample(sample_ids, nr_samples)
with torch.no_grad():
# restore samples rolled in during the previous epoch to their original (teacher-forcing) state
for id_, sample in sample_stack:
Owner Author: @peter-makarov Here my idea was to sample an increasingly large subset of the samples each epoch, drawing a different subset each time (and restoring the previous subset to its original, teacher-forcing state).

Collaborator: Makes sense.

training_data.samples[id_] = sample
sample_stack = []
for id_ in rollin_samples:
sample_stack.append((id_, dataclasses.replace(training_data.samples[id_])))
transducer_.roll_in(training_data.samples[id_], rollin)

j = 0
for j, batch in enumerate(training_data_loader):
Owner Author: With the current setup, roll-in is performed before each epoch. A problem with this approach could be that the model performs updates during the epoch, so the rolled-in target sequences are no longer representative of the model. Maybe it would make more sense to roll in after every training step with some probability.

Collaborator: Agreed, it would be more sound and more useful to sample for each batch (this addresses the errors of the current model checkpoint, not errors that may already have resolved themselves due to the recent parameter updates).

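A minimal sketch of that per-batch variant (hypothetical, not part of this diff; it assumes batches expose the ids of their samples, e.g. a batch.sample_ids field, which the current data loader may not provide):

for j, batch in enumerate(training_data_loader):
    if args.rollin and random.random() <= rollin:
        with torch.no_grad():
            for id_ in batch.sample_ids:  # hypothetical field
                transducer_.roll_in(training_data.samples[id_], rollin)
    # ...continue with the usual training_step on the freshly rolled-in batch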
losses = transducer_.training_step(encoded_input=batch.encoded_input,
@@ -478,8 +495,10 @@ def cli_main():
help="Number of decoder LSTM layers.")
parser.add_argument("--beam-width", type=int, default=4,
help="Beam width for beam search decoding. A value < 1 will disable beam search decoding.")
parser.add_argument("--rollin", action="store_true", default=False,
                    help="Enable roll-in: mix model-sampled and expert actions when building training targets.")
parser.add_argument("--k", type=int, default=1,
help="k for inverse sigmoid rollin schedule.")
parser.add_argument("--patience", type=int, default=12,
help="Maximal patience for early stopping.")
parser.add_argument("--epochs", type=int, default=60,
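With both flags in place, a hypothetical invocation would be python trans/train.py --rollin --k 5 together with the usual data and model arguments (elided here).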
95 changes: 94 additions & 1 deletion trans/transducer.py
@@ -15,7 +15,7 @@
from trans.vocabulary import BEGIN_WORD, COPY, DELETE, END_WORD, PAD, \
FeatureVocabularies
from trans import ENCODER_MAPPING

from trans import utils

MAX_ACTION_SEQ_LEN = 150
MAX_INPUT_SEQ_LEN = 100
@@ -539,6 +539,99 @@ def continue_decoding():
return Output(action_history, self.decode_encoded_output(input_, action_history),
log_p, None)

def roll_in(self, sample: utils.Sample, rollin: float) -> None:
"""Runs the transducer for greedy decoding.

Args:
sample: A single sample from the training data.
rollin: tbd.

Returns:
None."""

# adjust initial decoder states if batch_size has changed
self.h0_c0 = 1  # batch size is always one here

# initialize state variables
alignment_history = torch.tensor([[0]], device=self.device)
action_history = torch.tensor([[[BEGIN_WORD]]],
device=self.device, dtype=torch.int)
input_length = torch.tensor([len(sample.input) + 1], device=self.device)
output = []
optimal_actions = [[BEGIN_WORD]]

# run encoder
bidirectional_emb = self.encoder_step(sample.encoded_input.unsqueeze(dim=0))

# compute feature embedding
feature_emb = self.feature_embedding(sample.encoded_features)

# initial cell state for decoder
decoder = self.h0_c0

stop = False
while not stop and action_history.size(2) <= MAX_ACTION_SEQ_LEN:
current_alignment = alignment_history[:, -1]
valid_actions_mask = self.valid_actions_lookup[:, input_length - current_alignment]

# run decoder
decoder_output, decoder = self.decoder_step(
bidirectional_emb, feature_emb, decoder,
current_alignment, action_history[:, :, -1])

# model's predictions
actions, log_probs = self.calculate_actions(decoder_output, valid_actions_mask)
log_probs_np = log_probs.squeeze().cpu().detach().numpy()
sample_action = self.sample(log_probs_np)

# expert prediction
expert_actions = self.expert_rollout(sample.input, sample.target,
current_alignment.item(), output)
optimal_actions.append(expert_actions)
Owner Author: The targets are always generated by the expert.

# update states
if np.random.rand() <= rollin:
Owner Author: The next state is either sampled from the expert or from the model itself. Is this the idea? Alternatively, one could always execute the model (and rely on the expert only for the targets).

Collaborator: Yes, that's the idea!

action = sample_action
else:
action = expert_actions[
int(np.argmax([log_probs_np[a] for a in expert_actions]))
Collaborator: So this does not over-correct the model if it already predicts an optimal action, in case multiple actions are optimal.

]

action_history = torch.cat(
(action_history, torch.tensor([[[action]]], device=self.device)),
Collaborator: Same comment regarding concatenation in the loop.

dim=2
)

char, current_alignment, stop = self.decode_single_action(sample.input, action, current_alignment.item())
if char != "":
output.append(char)

alignment_history = torch.cat(
Collaborator: Doesn't seem like this has to be done via concatenation in a loop. Can this not be re-written using list append?

(alignment_history, torch.tensor([[current_alignment]], device=self.device)),
dim=1
)

# build masks
action_history = action_history[0, 0, :].tolist()
alignment_history = alignment_history[0, :].tolist()

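# row s-1 flags every expert-optimal action at decoding step s (step 0 is BEGIN_WORD and has no row)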
optimal_actions_mask = torch.full(
(len(optimal_actions) - 1, self.number_actions),
False, dtype=torch.bool, device=self.device)
seq_pos, emb_pos = zip(*[(s - 1, a) for s in range(1, len(optimal_actions))
for a in optimal_actions[s]])
optimal_actions_mask[seq_pos, emb_pos] = True
sample.optimal_actions_mask = optimal_actions_mask

sample.alignment_history = torch.tensor(alignment_history[:-1], device=self.device)
sample.action_history = torch.tensor(action_history[:-1], device=self.device)

valid_actions_mask = torch.stack(
[self.compute_valid_actions(len(sample.input) + 1 - a) for a in alignment_history[:-1]], dim=0)
sample.valid_actions_mask = valid_actions_mask

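A minimal, self-contained illustration of the list-append pattern suggested in review (hypothetical sketch, not part of this diff):

import torch

# append per-step values to a plain Python list (O(1) each) and build the
# tensor once after the loop, instead of torch.cat copying the whole
# history on every iteration
alignments = [0]
for step in range(5):
    next_alignment = alignments[-1] + 1  # stand-in for decode_single_action's output
    alignments.append(next_alignment)
alignment_history = torch.tensor(alignments)
print(alignment_history)  # tensor([0, 1, 2, 3, 4, 5])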

def decode_encoded_output(self, input_: List[List[str]], encoded_output: List[List[int]]) -> List[str]:
"""Decode a list of encoded output sequences given their string input.
