Commit
Merge branch 'jbarker/non_determinism_redux' into 'main'
Fix embedding layer non-determinism again

See merge request ADLR/megatron-lm!751
jon-barker committed Sep 29, 2023
2 parents 9b8643c + bbc6dc1 commit 8737bc1
Showing 22 changed files with 25 additions and 343 deletions.
7 changes: 3 additions & 4 deletions README.md
@@ -519,9 +519,8 @@ We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/o
# Reproducibility
Megatron training is intended to be bitwise reproducible. This means that the same training config run twice in the same HW and SW environment should produce identical model checkpoints, losses and accuracy metric values (iteration time metrics may vary).

There are currently three known Megatron optimizations that break reproducibility whilst still producing almost identical training runs. They are only applicable when using NGC containers >=22.05. The following workarounds should be applied in cases where reproducibility is required:
1. When training using the `--bf16` option the backward pass of `torch.nn.functional.embedding` is non-deterministic. If reproducibility is required you should also use the option `--embedding-weights-in-fp32`. The speed and memory impact of this change is negligible.
2. Also when training using `--bf16`, reproducibility is only obtained when the checkpointing and resume schedule of training is identical. If the checkpointing schedule will change, i.e. checkpointing and resume will occur at different iterations, the option `--no-bias-gelu-fusion` should be used.
3. Flash attention is non-deterministic. If reproducibility is required do not use `--use-flash-attn`.
There are currently two known Megatron optimizations that break reproducibility whilst still producing almost identical training runs. The following workarounds should be applied in cases where reproducibility is required:
1. When training using `--bf16`, reproducibility is only obtained when the checkpointing and resume schedule of training is identical. If the checkpointing schedule will change, i.e. checkpointing and resume will occur at different iterations, the option `--no-bias-gelu-fusion` should be used.
2. Flash attention is non-deterministic. If reproducibility is required do not use `--use-flash-attn`.

These sources of non-determinism are under active investigation. If you observe non-determinism in Megatron training under other circumstances please open an issue.
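The first workaround in the old list is the one this commit retires: the diffs below make the embedding lookup itself deterministic, so the fp32 cast is no longer needed. A minimal repro sketch of the underlying issue (an illustration, not code from this repo; requires a CUDA device):

```python
# Minimal repro sketch of the non-deterministic bf16 embedding backward.
# Illustration only, not code from this repo; requires a CUDA device.
import torch
import torch.nn.functional as F

def embedding_grad():
    torch.manual_seed(0)
    weight = torch.randn(50257, 1024, device="cuda", dtype=torch.bfloat16,
                         requires_grad=True)
    # Duplicate ids force the backward kernel to accumulate gradients
    # atomically, which is where run-to-run variation comes from.
    ids = torch.randint(0, 50257, (8, 2048), device="cuda")
    F.embedding(ids, weight).sum().backward()
    return weight.grad.clone()

# Identical inputs and seed, yet the comparison may print False on
# affected containers because the accumulation order is not fixed.
print(torch.equal(embedding_grad(), embedding_grad()))
```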
2 changes: 0 additions & 2 deletions megatron/arguments.py
@@ -599,8 +599,6 @@ def _add_network_size_args(parser):
                       help='Number of Experts in Switch Transformer (None means no Switch)')
    group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
                       help='Untie embeddings and output weights.'),
    group.add_argument('--embedding-weights-in-fp32', action='store_true',
                       help='Cast word embedding weights to fp32 before embedding fwd.'),
    return parser


19 changes: 2 additions & 17 deletions megatron/core/tensor_parallel/layers.py
@@ -156,13 +156,6 @@ def __init__(
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.0
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        (
@@ -211,16 +204,8 @@ def forward(self, input_):
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(
            masked_input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        # Get the embeddings.
        output_parallel = self.weight[masked_input]
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
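The one-line substitution above is behavior-preserving in the forward pass: with `padding_idx`, `max_norm`, `scale_grad_by_freq`, and `sparse` all at their defaults (as the deleted attributes show), `F.embedding` is a plain row lookup, which advanced indexing expresses directly while routing the backward through a different, deterministic kernel. A quick equivalence check (a sketch, not from this repo):

```python
# Sketch (not from this repo) of the forward equivalence behind the change.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(32, 8)          # (vocab_size, hidden_size)
ids = torch.randint(0, 32, (4, 6))   # (batch, seq_len)

# With all optional arguments at their defaults, F.embedding reduces to a
# row lookup, which advanced indexing expresses directly.
assert torch.equal(F.embedding(ids, weight), weight[ids])
```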
16 changes: 2 additions & 14 deletions megatron/model/language_model.py
@@ -129,10 +129,6 @@ class Embedding(MegatronModule):
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        embedding_weights_in_fp32: casts word embedding weights to
                                   fp32 before sampling. Required to
                                   maintain reproducibility when
                                   training in bf16.
    """

    def __init__(self,
@@ -141,8 +137,7 @@ def __init__(self,
                 max_sequence_length,
                 embedding_dropout_prob,
                 config,
                 num_tokentypes=0,
                 embedding_weights_in_fp32=False):
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
@@ -152,7 +147,6 @@ def __init__(self,
        args = get_args()

        # Word embeddings (parallel).
        self.embedding_weights_in_fp32 = embedding_weights_in_fp32
        self.params_dtype = args.params_dtype
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, self.hidden_size, config=config, init_method=config.init_method)
@@ -217,12 +211,7 @@ def add_tokentype_embeddings(self, num_tokentypes):

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        if self.embedding_weights_in_fp32:
            self.word_embeddings = self.word_embeddings.to(torch.float32)
        words_embeddings = self.word_embeddings(input_ids)
        if self.embedding_weights_in_fp32:
            words_embeddings = words_embeddings.to(self.params_dtype)
            self.word_embeddings = self.word_embeddings.to(self.params_dtype)
        if self.add_position_embedding:
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = words_embeddings + position_embeddings
@@ -366,8 +355,7 @@ def __init__(self,
            args.max_position_embeddings,
            args.hidden_dropout,
            config,
            self.num_tokentypes,
            args.embedding_weights_in_fp32)
            self.num_tokentypes)
        self._embedding_key = 'embedding'

        # Rotary positional embeddings
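For reference, the retired `--embedding-weights-in-fp32` path reduces to the following standalone pattern, reconstructed from the deleted lines above with made-up sizes: cast the weights to fp32 for the lookup, then cast the output and the weights back to the training dtype.

```python
# The retired workaround, reduced to a standalone sketch with made-up sizes.
import torch
import torch.nn as nn

embedding = nn.Embedding(1000, 512).to(torch.bfloat16)
ids = torch.randint(0, 1000, (8, 128))

embedding = embedding.to(torch.float32)    # fp32 weights: deterministic backward
out = embedding(ids).to(torch.bfloat16)    # restore the activation dtype
embedding = embedding.to(torch.bfloat16)   # restore the weight dtype
```

With the deterministic row lookup in VocabParallelEmbedding, this round trip, and the flag that enabled it, are no longer needed.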
@@ -1 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.50685, 10.49816, 10.47982, 10.48566, 10.49533, 10.46662, 10.42393, 10.30694, 10.1598, 9.96959]}, "num-zeros": {"start_step": 0, "end_step": 34, "step_interval": 5, "values": [18771.0, 19036.0, 22186.0, 18552.0, 21033.0, 23314.0, 22529.0]}, "iteration_timing_avg": 0.44337617647058825}
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.50685, 10.49816, 10.47982, 10.48566, 10.49533, 10.46662, 10.42394, 10.30694, 10.15979, 9.96957]}, "num-zeros": {"start_step": 0, "end_step": 34, "step_interval": 5, "values": [18772.0, 19035.0, 22296.0, 18412.0, 20887.0, 23006.0, 22439.0]}, "iteration_timing_avg": 0.4169808823529412}
@@ -1 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.54837, 10.54636, 10.55694, 10.54151, 10.53088, 10.48503, 10.46272, 10.31499, 10.1712, 9.97326]}, "num-zeros": {"start_step": 0, "end_step": 34, "step_interval": 5, "values": [22603.0, 20620.0, 26075.0, 23583.0, 21709.0, 21601.0, 23088.0]}, "iteration_timing_avg": 0.9086541176470588}
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.54837, 10.54636, 10.55694, 10.54151, 10.53088, 10.48503, 10.46275, 10.31499, 10.17122, 9.97326]}, "num-zeros": {"start_step": 0, "end_step": 34, "step_interval": 5, "values": [22606.0, 20619.0, 26292.0, 23607.0, 21666.0, 21672.0, 23313.0]}, "iteration_timing_avg": 0.9262994117647059}
@@ -1 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.44877, 10.43852, 10.44018, 10.44113, 10.45623, 10.44141, 10.39044, 10.25681, 10.133, 9.95745]}, "num-zeros": {"start_step": 0, "end_step": 34, "step_interval": 5, "values": [27843.0, 20675.0, 28449.0, 26397.0, 24158.0, 21043.0, 21057.0]}, "iteration_timing_avg": 0.8035391176470587}
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.44877, 10.43852, 10.44018, 10.44113, 10.45623, 10.44143, 10.39045, 10.25681, 10.13301, 9.95744]}, "num-zeros": {"start_step": 0, "end_step": 34, "step_interval": 5, "values": [27844.0, 20265.0, 28481.0, 26139.0, 24126.0, 21087.0, 21026.0]}, "iteration_timing_avg": 0.7951058823529413}
@@ -1 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.48681, 10.48784, 10.4873, 10.50416, 10.49442, 10.47818, 10.41362, 10.28136, 10.14424, 9.94147]}, "num-zeros": {"start_step": 0, "end_step": 34, "step_interval": 5, "values": [27199.0, 19944.0, 25298.0, 24277.0, 21516.0, 19536.0, 20924.0]}, "iteration_timing_avg": 1.3894499999999999}
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.48681, 10.48784, 10.4873, 10.50416, 10.49442, 10.47817, 10.41358, 10.28136, 10.14425, 9.94147]}, "num-zeros": {"start_step": 0, "end_step": 34, "step_interval": 5, "values": [27195.0, 19616.0, 25279.0, 24916.0, 21579.0, 19699.0, 20897.0]}, "iteration_timing_avg": 1.4259938235294118}
@@ -1 +1 @@
{"lm loss": {"start_step": 0, "end_step": 44, "step_interval": 5, "values": [10.84008, 10.89053, 10.90905, 10.87934, 10.86562, 10.83752, 10.64582, 10.62397, 10.53554]}, "num-zeros": {"start_step": 0, "end_step": 27, "step_interval": 5, "values": [2078.0, 2320.0, 2519.0, 2248.0, 2127.0, 1987.0]}, "iteration_timing_avg": 0.09863333333333332}
{"lm loss": {"start_step": 0, "end_step": 42, "step_interval": 5, "values": [10.84008, 10.89053, 10.90905, 10.87934, 10.86562, 10.83752, 10.64582, 10.62396, 10.53554]}, "num-zeros": {"start_step": 0, "end_step": 25, "step_interval": 5, "values": [2078.0, 2328.0, 2420.0, 2256.0, 2180.0]}, "iteration_timing_avg": 0.09522035714285715}
@@ -1,33 +1 @@
{
"lm loss": {
"start_step": 0,
"end_step": 36,
"step_interval": 5,
"values": [
10.83273,
10.86937,
10.89188,
10.80831,
10.68615,
10.6145,
10.09491,
10.21578
]
},
"num-zeros": {
"start_step": 0,
"end_step": 36,
"step_interval": 5,
"values": [
1548.0,
1851.0,
1858.0,
1845.0,
1768.0,
1715.0,
1526.0,
1917.0
]
},
"iteration_timing_avg": 0.09456208333333331
}
{"lm loss": {"start_step": 0, "end_step": 36, "step_interval": 5, "values": [10.83273, 10.86936, 10.89186, 10.80832, 10.68611, 10.61451, 10.09495, 10.21575]}, "num-zeros": {"start_step": 0, "end_step": 36, "step_interval": 5, "values": [1551.0, 1779.0, 1907.0, 1882.0, 1871.0, 1667.0, 1501.0, 1933.0]}, "iteration_timing_avg": 0.09391500000000001}
@@ -1,29 +1 @@
{
"lm loss": {
"start_step": 0,
"end_step": 28,
"step_interval": 5,
"values": [
10.84609,
10.87725,
10.90506,
10.81872,
10.67719,
10.60489
]
},
"num-zeros": {
"start_step": 0,
"end_step": 28,
"step_interval": 5,
"values": [
1743.0,
2097.0,
1981.0,
1981.0,
2013.0,
1896.0
]
},
"iteration_timing_avg": 0.10225333333333335
}
{"lm loss": {"start_step": 0, "end_step": 42, "step_interval": 5, "values": [10.84609, 10.87727, 10.90506, 10.81871, 10.67715, 10.60493, 10.06861, 10.1946, 10.11546]}, "num-zeros": {"start_step": 0, "end_step": 42, "step_interval": 5, "values": [1744.0, 2089.0, 2023.0, 2009.0, 2130.0, 1933.0, 1666.0, 2033.0, 2223.0]}, "iteration_timing_avg": 0.10196714285714288}
@@ -1 +1 @@
{"lm loss": {"start_step": 0, "end_step": 47, "step_interval": 5, "values": [10.81244, 10.87437, 10.90132, 10.84946, 10.84826, 10.81872, 10.61811, 10.61053, 10.52823, 10.22961]}, "num-zeros": {"start_step": 0, "end_step": 30, "step_interval": 5, "values": [2356.0, 2601.0, 2778.0, 2282.0, 2350.0, 2782.0]}, "iteration_timing_avg": 0.12793593749999999}
{"lm loss": {"start_step": 0, "end_step": 45, "step_interval": 5, "values": [10.81244, 10.87437, 10.90132, 10.84946, 10.84826, 10.81873, 10.61811, 10.61052, 10.52823]}, "num-zeros": {"start_step": 0, "end_step": 27, "step_interval": 5, "values": [2365.0, 2535.0, 2707.0, 2210.0, 2411.0, 2781.0]}, "iteration_timing_avg": 0.13055}
@@ -1,29 +1 @@
{
"lm loss": {
"start_step": 0,
"end_step": 27,
"step_interval": 5,
"values": [
10.79373,
10.86736,
10.89174,
10.78285,
10.66227,
10.58291
]
},
"num-zeros": {
"start_step": 0,
"end_step": 27,
"step_interval": 5,
"values": [
1670.0,
1914.0,
1868.0,
1951.0,
1846.0,
1709.0
]
},
"iteration_timing_avg": 0.12781055555555554
}
{"lm loss": {"start_step": 0, "end_step": 29, "step_interval": 5, "values": [10.79373, 10.86739, 10.89171, 10.78289, 10.66227, 10.58291]}, "num-zeros": {"start_step": 0, "end_step": 29, "step_interval": 5, "values": [1670.0, 1836.0, 1842.0, 1890.0, 1795.0, 1705.0]}, "iteration_timing_avg": 0.12559400000000004}
@@ -1,33 +1 @@
{
"lm loss": {
"start_step": 0,
"end_step": 36,
"step_interval": 5,
"values": [
10.79374,
10.86741,
10.89181,
10.78307,
10.66263,
10.58358,
10.08691,
10.19344
]
},
"num-zeros": {
"start_step": 0,
"end_step": 36,
"step_interval": 5,
"values": [
1568.0,
1829.0,
1883.0,
1921.0,
1839.0,
1701.0,
1580.0,
1954.0
]
},
"iteration_timing_avg": 0.12052666666666663
}
{"lm loss": {"start_step": 0, "end_step": 38, "step_interval": 5, "values": [10.79374, 10.86745, 10.89179, 10.78304, 10.66262, 10.58362, 10.08688, 10.19342]}, "num-zeros": {"start_step": 0, "end_step": 38, "step_interval": 5, "values": [1567.0, 1904.0, 1912.0, 1931.0, 1799.0, 1722.0, 1591.0, 1950.0]}, "iteration_timing_avg": 0.12253038461538461}
@@ -1,33 +1 @@
{
"lm loss": {
"start_step": 0,
"end_step": 40,
"step_interval": 5,
"values": [
10.79373,
10.86736,
10.89174,
10.78285,
10.66227,
10.58291,
10.08584,
10.1921
]
},
"num-zeros": {
"start_step": 0,
"end_step": 40,
"step_interval": 5,
"values": [
1670.0,
1914.0,
1868.0,
1951.0,
1846.0,
1709.0,
1557.0,
1942.0
]
},
"iteration_timing_avg": 0.12695888888888887
}
{"lm loss": {"start_step": 0, "end_step": 42, "step_interval": 5, "values": [10.79373, 10.86739, 10.89171, 10.78289, 10.66227, 10.58291, 10.08584, 10.19211, 10.13576]}, "num-zeros": {"start_step": 0, "end_step": 42, "step_interval": 5, "values": [1670.0, 1836.0, 1842.0, 1890.0, 1795.0, 1705.0, 1516.0, 1968.0, 2356.0]}, "iteration_timing_avg": 0.12682214285714286}
@@ -1,33 +1 @@
{
"lm loss": {
"start_step": 0,
"end_step": 40,
"step_interval": 5,
"values": [
10.73353,
10.81785,
10.84054,
10.76024,
10.70354,
10.63165,
10.21176,
10.37203
]
},
"num-zeros": {
"start_step": 0,
"end_step": 40,
"step_interval": 5,
"values": [
2536.0,
2967.0,
2881.0,
2747.0,
2639.0,
2566.0,
2367.0,
2701.0
]
},
"iteration_timing_avg": 0.12756653846153845
}
{"lm loss": {"start_step": 0, "end_step": 29, "step_interval": 5, "values": [10.73353, 10.81786, 10.84052, 10.76021, 10.70355, 10.63168]}, "num-zeros": {"start_step": 0, "end_step": 28, "step_interval": 5, "values": [2536.0, 3043.0, 2818.0, 2790.0, 2582.0, 2459.0]}, "iteration_timing_avg": 0.1284436842105263}
@@ -1,33 +1 @@
{
"lm loss": {
"start_step": 0,
"end_step": 39,
"step_interval": 5,
"values": [
10.8968,
10.90832,
10.91767,
10.84824,
10.70838,
10.63459,
10.15693,
10.26264
]
},
"num-zeros": {
"start_step": 0,
"end_step": 39,
"step_interval": 5,
"values": [
22727758.0,
23021490.0,
22500312.0,
22830774.0,
22739320.0,
22546524.0,
22955648.0,
22588796.0
]
},
"iteration_timing_avg": 0.12539576923076923
}
{"lm loss": {"start_step": 0, "end_step": 28, "step_interval": 5, "values": [10.8968, 10.9083, 10.91766, 10.84824, 10.70841, 10.63455]}, "num-zeros": {"start_step": 0, "end_step": 28, "step_interval": 5, "values": [22727842.0, 23021604.0, 22500412.0, 22830772.0, 22739552.0, 22546566.0]}, "iteration_timing_avg": 0.12624631578947368}
@@ -1 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.85543, 10.89355, 10.87608, 10.87365, 10.88042, 10.84182, 10.67177, 10.62854, 10.52511, 10.25229]}, "num-zeros": {"start_step": 0, "end_step": 33, "step_interval": 5, "values": [2470.0, 2444.0, 2570.0, 2192.0, 2241.0, 2574.0, 2476.0]}, "iteration_timing_avg": 0.14008088235294117}
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.85543, 10.89355, 10.87608, 10.87365, 10.88042, 10.84182, 10.67177, 10.62853, 10.52511, 10.2523]}, "num-zeros": {"start_step": 0, "end_step": 33, "step_interval": 5, "values": [2472.0, 2462.0, 2480.0, 2235.0, 2268.0, 2619.0, 2429.0]}, "iteration_timing_avg": 0.14355058823529418}