Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
vijayaditya committed Jan 24, 2017
1 parent 2d599fe commit 83a9bab
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 17 deletions.
4 changes: 2 additions & 2 deletions egs/swbd/s5c/local/chain/tuning/run_mtdnn_1a.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ set -e
# configs for 'chain'
affix=
stage=12
train_stage=-3
train_stage=-10
get_egs_stage=-10
speed_perturb=true
dir=exp/chain/mtdnn_1a # Note: _sp will get added to this if $speed_perturb == true.
dir=exp/chain/mtdnn_1a_new # Note: _sp will get added to this if $speed_perturb == true.
decode_iter=
lattice_beam=

Expand Down
4 changes: 2 additions & 2 deletions egs/swbd/s5c/local/chain/tuning/run_mtdnn_1b.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ set -e

# configs for 'chain'
affix=
stage=12
train_stage=-10
stage=13
train_stage=0
get_egs_stage=-10
speed_perturb=true
dir=exp/chain/mtdnn_1b # Note: _sp will get added to this if $speed_perturb == true.
Expand Down
35 changes: 33 additions & 2 deletions egs/wsj/s5/steps/libs/nnet3/report/log_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

class KaldiLogParseException(Exception):
    """Exception raised when there is an issue in parsing the log files.

    Extend this class if more granularity is needed.

    Args:
        message: optional details about the failure. ``None`` (the default)
            or a whitespace-only string is treated as "no details", in
            which case the formatted message shows ``None`` as the details.
    """

    def __init__(self, message=None):
        # The default argument is None, so check for it before calling
        # strip(); otherwise constructing the exception without a message
        # would itself fail with AttributeError.
        if message is not None and message.strip() == "":
            message = None

        Exception.__init__(
            self,
            "There was an error while trying to parse the logs."
            " Details : \n{0}\n".format(message))


def parse_progress_logs_for_nonlinearity_stats(exp_dir):
""" Parse progress logs for mean and std stats for non-linearities.
Expand Down Expand Up @@ -279,7 +291,7 @@ def parse_prob_logs(exp_dir, key='accuracy', output="output"):

parse_regex = re.compile(
".*compute_prob_.*\.([0-9]+).log:LOG "
".nnet3.*compute-prob:PrintTotalStats..:"
".nnet3.*compute-prob.*:PrintTotalStats..:"
"nnet.*diagnostics.cc:[0-9]+. Overall ([a-zA-Z\-]+) for "
"'{output}'.*is ([0-9.\-e]+) .*per frame".format(output=output))

Expand All @@ -292,19 +304,33 @@ def parse_prob_logs(exp_dir, key='accuracy', output="output"):
groups = mat_obj.groups()
if groups[1] == key:
train_loss[int(groups[0])] = groups[2]
if not train_loss:
raise KaldiLogParseException("Could not find any lines with {k} in "
" {l}".format(k=key, l=train_prob_files))

for line in valid_prob_strings.split('\n'):
mat_obj = parse_regex.search(line)
if mat_obj is not None:
groups = mat_obj.groups()
if groups[1] == key:
valid_loss[int(groups[0])] = groups[2]

if not valid_loss:
raise KaldiLogParseException("Could not find any lines with {k} in "
" {l}".format(k=key, l=valid_prob_files))

iters = list(set(valid_loss.keys()).intersection(train_loss.keys()))
if not iters:
raise KaldiLogParseException("Could not any common iterations with"
" key {k} in both {tl} and {vl}".format(
k=key, tl=train_prob_files, vl=valid_prob_files))
iters.sort()
return map(lambda x: (int(x), float(train_loss[x]),
float(valid_loss[x])), iters)


def generate_accuracy_report(exp_dir, key="accuracy", output="output"):

def generate_acc_logprob_report(exp_dir, key="accuracy", output="output"):
times = parse_train_logs(exp_dir)
data = parse_prob_logs(exp_dir, key, output)
report = []
Expand All @@ -315,6 +341,11 @@ def generate_accuracy_report(exp_dir, key="accuracy", output="output"):
x[1], x[2], x[2]-x[1]))
except KeyError:
continue
if len(report) - 1 == 0:
raise KaldiLogParseException("Could not find any lines with {k} in "
" {e}/log/compute_prob_train.*.log or "
" {e}/log/compute_prob_valid.*.log or both".format(
k=key, e=exp_dir))

total_time = 0
for iter in times.keys():
Expand Down
22 changes: 11 additions & 11 deletions egs/wsj/s5/steps/nnet3/report/generate_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def latex_compliant_name(name_string):
return node_name_string


def generate_accuracy_plots(exp_dir, output_dir, plot, key='accuracy',
def generate_acc_logprob_plots(exp_dir, output_dir, plot, key='accuracy',
file_basename='accuracy', comparison_dir=None,
start_iter=1,
latex_report=None, output_name='output'):
Expand All @@ -170,20 +170,20 @@ def generate_accuracy_plots(exp_dir, output_dir, plot, key='accuracy',
dirs = [exp_dir] + comparison_dir
index = 0
for dir in dirs:
[accuracy_report, accuracy_times,
accuracy_data] = log_parse.generate_accuracy_report(dir, key,
output_name)
[report, times, data] = log_parse.generate_acc_logprob_report(dir, key,
output_name)
if index == 0:
# this is the main experiment directory
with open("{0}/{1}.log".format(output_dir,
file_basename), "w") as f:
f.write(accuracy_report)
f.write(report)

if plot:
color_val = g_plot_colors[index]
data = np.array(accuracy_data)
data = np.array(data)
if data.shape[0] == 0:
raise Exception("Couldn't find any rows for the accuracy plot")
raise Exception("Couldn't find any rows for the"
"accuracy/log-probability plot")
data = data[data[:, 0] >= start_iter, :]
plot_handle, = plt.plot(data[:, 0], data[:, 1], color=color_val,
linestyle="--",
Expand Down Expand Up @@ -594,28 +594,28 @@ def generate_plots(exp_dir, output_dir, output_names, comparison_dir=None,
for (output_name, objective_type) in output_names:
if objective_type == "linear":
logger.info("Generating accuracy plots")
generate_accuracy_plots(
generate_acc_logprob_plots(
exp_dir, output_dir, g_plot, key='accuracy',
file_basename='accuracy', comparison_dir=comparison_dir,
start_iter=start_iter,
latex_report=latex_report, output_name=output_name)

logger.info("Generating log-likelihood plots")
generate_accuracy_plots(
generate_acc_logprob_plots(
exp_dir, output_dir, g_plot, key='log-likelihood',
file_basename='loglikelihood', comparison_dir=comparison_dir,
start_iter=start_iter,
latex_report=latex_report, output_name=output_name)
elif objective_type == "chain":
logger.info("Generating log-probability plots")
generate_accuracy_plots(
generate_acc_logprob_plots(
exp_dir, output_dir, g_plot,
key='log-probability', file_basename='log_probability',
comparison_dir=comparison_dir, start_iter=start_iter,
latex_report=latex_report, output_name=output_name)
else:
logger.info("Generating " + objective_type + " objective plots")
generate_accuracy_plots(
generate_acc_logprob_plots(
exp_dir, output_dir, g_plot, key='objective',
file_basename='objective', comparison_dir=comparison_dir,
start_iter=start_iter,
Expand Down

0 comments on commit 83a9bab

Please sign in to comment.