Skip to content

Commit 655a507

Browse files
Uncertainty Baselines Team and copybara-github
authored and committed
Paracrawl baseline task
PiperOrigin-RevId: 521695791
1 parent f5f6f50 commit 655a507

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
# Copyright 2023 The Uncertainty Baselines Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Task-specific configurations for Paracrawl training and evaluation.
from __gin__ import dynamic_registration

import seqio
import __main__ as train_script
from t5x import checkpoints
from t5x import utils
from t5x import decoding

# Register necessary SeqIO Tasks/Mixtures.
import data.mixtures  # local file import from baselines.t5

# Training mixture. Change to 'paracrawl_wmt_deen_precleaned' for the
# baseline on precleaned data.
MIXTURE_OR_TASK_NAME = 'paracrawl_wmt_deen'

# Eval mixture with evaluation on both dev and test sets.
EVAL_MIXTURE_OR_TASK_NAME = 'paracrawl_wmt_deen_eval_mixture'

# Disable caching since ub tasks are not cached in the official directory.
USE_CACHED_TASKS = False

BASE_LEARNING_RATE = 2.0

# Adjust checkpoint saving: keep only the best checkpoints as ranked by the
# monitored metric (configured on SaveBestCheckpointer below).
utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = 8  # Keep the 8 best checkpoints.
  save_dataset = False  # Don't checkpoint dataset state.
  checkpointer_cls = @checkpoints.SaveBestCheckpointer

# Rank checkpoints by BLEU on the paracrawl inference-eval task.
checkpoints.SaveBestCheckpointer:
  metric_name_to_monitor = 'inference_eval/paracrawl/eval/bleu'
  metric_mode = 'max'
  keep_checkpoints_without_metrics = False

# Run inference evaluation during training on the eval mixture.
train_script.train.infer_eval_dataset_cfg = @train_infer/utils.DatasetConfig()
train_script.train.inference_evaluator_cls = @seqio.Evaluator

train_infer/utils.DatasetConfig:
  mixture_or_task_name = %EVAL_MIXTURE_OR_TASK_NAME
  task_feature_lengths = None
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False
  module = None  # %MIXTURE_OR_TASK_MODULE

decoding.beam_search.max_decode_len = 256

# NOTE(review): an earlier comment here said "Disable JSON logger", but
# @seqio.JSONLogger is still listed in logger_cls and write_n_results = None
# writes all inferences. If JSON output should be disabled to reduce storage,
# drop @seqio.JSONLogger from this list.
seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = None  # Use all examples in the infer_eval dataset.

seqio.JSONLogger.write_n_results = None  # Write all inferences.

# Constant LR with inverse-sqrt decay after warmup.
utils.create_learning_rate_scheduler:
  factors = 'constant * rsqrt_decay'
  base_learning_rate = %BASE_LEARNING_RATE
  warmup_steps = 5000

0 commit comments

Comments
 (0)