Merge pull request #3 from freewym/vimal_raw_python_script
added the option trainer.deriv-truncate-margin to train_rnn.py and tr…
vimalmanohar committed Nov 29, 2016
2 parents 32d4167 + fc16bde commit 48fd6ab
Showing 2 changed files with 50 additions and 24 deletions.
37 changes: 25 additions & 12 deletions egs/wsj/s5/steps/nnet3/train_raw_rnn.py
@@ -118,9 +118,14 @@ def get_args():
                                 "parallel every minibatch")
     parser.add_argument("--trainer.rnn.num-bptt-steps", type=int,
                         dest='num_bptt_steps', default=None,
-                        help="""The number of time steps to back-propagate from
-                        the last label in the chunk. By default it is same as
-                        the (chunk-width + 10).""")
+                        help="""Deprecated. Kept for back compatibility.""")
+    parser.add_argument("--trainer.deriv-truncate-margin", type=int,
+                        dest='deriv_truncate_margin', default=8,
+                        help="""Margin (in input frames) around the 'required'
+                        part of each chunk that the derivatives are
+                        backpropagated to. E.g., 8 is a reasonable setting.
+                        Note: the 'required' part of the chunk is defined by
+                        the model's {left,right}-context.""")

     # General options
     parser.add_argument("--nj", type=int, default=4,
@@ -161,6 +166,17 @@ def process_args(args):
     if args.chunk_right_context < 0:
         raise Exception("--egs.chunk-right-context should be non-negative")

+    if args.num_bptt_steps is not None:
+        # -2 is used to compensate for the splicing of the input frame, assuming
+        # that splicing spans from -2 to 2
+        args.deriv_truncate_margin = args.num_bptt_steps - args.chunk_width - 2
+        logger.warning(
+            "--trainer.rnn.num-bptt-steps (deprecated) is set by user, and "
+            "--trainer.deriv-truncate-margin is set to (num-bptt-steps - "
+            "chunk-width - 2) = {0}. We recommend using the option "
+            "--trainer.deriv-truncate-margin.".format(
+                args.deriv_truncate_margin))
+
     if (not os.path.exists(args.dir)
             or not os.path.exists(args.dir+"/configs")):
         raise Exception("This scripts expects {0} to exist and have a configs "
@@ -344,15 +360,12 @@ def learning_rate(iter, current_num_jobs, num_archives_processed):
                                   args.initial_effective_lrate,
                                   args.final_effective_lrate)

-    if args.num_bptt_steps is None:
-        # num_bptt_steps is set to (chunk_width + 10) by default
-        num_bptt_steps = args.chunk_width + min(10, args.chunk_left_context,
-                                                args.chunk_right_context)
-    else:
-        num_bptt_steps = args.num_bptt_steps
-
-    min_deriv_time = args.chunk_width - num_bptt_steps
-    max_deriv_time = num_bptt_steps - 1
+    min_deriv_time = None
+    max_deriv_time = None
+    if args.deriv_truncate_margin is not None:
+        min_deriv_time = -args.deriv_truncate_margin - model_left_context
+        max_deriv_time = (args.chunk_width - 1 + args.deriv_truncate_margin
+                          + model_right_context)

     logger.info("Training will run for {0} epochs = "
                 "{1} iterations".format(args.num_epochs, num_iters))
37 changes: 25 additions & 12 deletions egs/wsj/s5/steps/nnet3/train_rnn.py
@@ -117,9 +117,14 @@ def get_args():
                                 "parallel every minibatch")
     parser.add_argument("--trainer.rnn.num-bptt-steps", type=int,
                         dest='num_bptt_steps', default=None,
-                        help="""The number of time steps to back-propagate from
-                        the last label in the chunk. By default it is same as
-                        the (chunk-width + 10).""")
+                        help="""Deprecated. Kept for back compatibility.""")
+    parser.add_argument("--trainer.deriv-truncate-margin", type=int,
+                        dest='deriv_truncate_margin', default=8,
+                        help="""Margin (in input frames) around the 'required'
+                        part of each chunk that the derivatives are
+                        backpropagated to. E.g., 8 is a reasonable setting.
+                        Note: the 'required' part of the chunk is defined by
+                        the model's {left,right}-context.""")

     # General options
     parser.add_argument("--feat-dir", type=str, required=True,
@@ -157,6 +162,17 @@ def process_args(args):
     if args.chunk_right_context < 0:
         raise Exception("--egs.chunk-right-context should be non-negative")

+    if args.num_bptt_steps is not None:
+        # -2 is used to compensate for the splicing of the input frame, assuming
+        # that splicing spans from -2 to 2
+        args.deriv_truncate_margin = args.num_bptt_steps - args.chunk_width - 2
+        logger.warning(
+            "--trainer.rnn.num-bptt-steps (deprecated) is set by user, and "
+            "--trainer.deriv-truncate-margin is set to (num-bptt-steps - "
+            "chunk-width - 2) = {0}. We recommend using the option "
+            "--trainer.deriv-truncate-margin.".format(
+                args.deriv_truncate_margin))
+
     if (not os.path.exists(args.dir)
             or not os.path.exists(args.dir+"/configs")):
         raise Exception("This scripts expects {0} to exist and have a configs "
@@ -343,15 +359,12 @@ def learning_rate(iter, current_num_jobs, num_archives_processed):
                                   args.initial_effective_lrate,
                                   args.final_effective_lrate)

-    if args.num_bptt_steps is None:
-        # num_bptt_steps is set to (chunk_width + 10) by default
-        num_bptt_steps = args.chunk_width + min(10, args.chunk_left_context,
-                                                args.chunk_right_context)
-    else:
-        num_bptt_steps = args.num_bptt_steps
-
-    min_deriv_time = args.chunk_width - num_bptt_steps
-    max_deriv_time = num_bptt_steps - 1
+    min_deriv_time = None
+    max_deriv_time = None
+    if args.deriv_truncate_margin is not None:
+        min_deriv_time = -args.deriv_truncate_margin - model_left_context
+        max_deriv_time = (args.chunk_width - 1 + args.deriv_truncate_margin
+                          + model_right_context)

     logger.info("Training will run for {0} epochs = "
                 "{1} iterations".format(args.num_epochs, num_iters))
