diff --git a/README-MUP.md b/README-MUP.md
index b2cb01a55..c2be40299 100644
--- a/README-MUP.md
+++ b/README-MUP.md
@@ -26,6 +26,13 @@
   "mup-rp-embedding-mult": 1.0,
 ```
 
+## Install package
+
+```
+cd mup
+pip install -e .
+```
+
 ## Generate base shapes
 
 1. Set use-mup to true
@@ -33,12 +40,22 @@
 3. Run once. gpt-neox will instantiate a base model and a delta model, then save one file per rank named .. gpt-neox will exit immediately.
 4. Set save-base-shapes to false
 
-## Generate coord check plots (optional)
+## Testing the implementation
+The simplest test is the coordinate check:
 
 1. Keep use-mup true
 2. Set coord-check to true
-3. Run once. gpt-neox will output jpg images similar to https://github.com/microsoft/mutransformers/blob/main/README.md#coord-check. gpt-neox will exit immediately
+3. Run once. gpt-neox will output jpg images similar to those below and exit immediately.
 4. Set coord-check to false
 
+The plots show statistics of the pre-activations for models that differ only in width. If the implementation is correct, the curves should be approximately horizontal.
+![](mup/figures/coord_check_up.0.jpg)
+ *Healthy coordinate check*
+![](mup/figures/coord_check_sp.0.jpg)
+ *Unhealthy coordinate check*
+
+A second test is to pick any configuration and learning rate (one that does not cause training to diverge) and run a few experiments that fix everything except the width. Since wider is always better under mup, the results should look like the figure below.
+![](mup/figures/width_check.png)
+ *Healthy training*
 ## Tune mup hyperparameters and LR
 
@@ -47,3 +64,10 @@ The values under `mup hp search` were added and correspond to appendix F.4 from
 ## Transfer
 
 With the best LR set and the best mup HPs set, revert the value of hidden-size in the scaled-up config and run again.
+
+## Usage under distributed setting
+
+The code is set up so that each rank takes care of its own piece of the model and dumps its own shape file to be picked up for training. The safest approach is to generate the base shapes with the same number of devices and the same tensor/pipeline parallelism that will be used for training. Also consider the following:
+- Data parallelism: nothing changes for mup; the same base_shape file can be copied N times, once per data-parallel rank.
+- Pipeline parallelism: still nothing changes for mup, but different ranks own different layers, so follow the advice above.
+- **Tensor parallelism has a huge effect on mup.** Column-parallel layers get chopped on the input dimension, which changes the actual width of the parameter. Think carefully about what you are doing if you deviate from the procedure above.
\ No newline at end of file
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 63f4122e2..939927767 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -301,8 +301,14 @@ def __init__(
             coeff = max(1, self.layer_number)
             self.norm_factor *= coeff
 
-        if neox_args.use_mup:
-            self.norm_factor = self.hidden_size_per_attention_head
+        # TODO
+        # Right now there is no way to correctly set use_mup here; possible options:
+        # - refactor model init (hard)
+        # - do this via another config argument, e.g. "mup_norm_factor" (probably easy)
"mup_norm_factor" (probably easy) + #- ignore, this never changed anything in my experiments + # + #if neox_args.use_mup: + # self.norm_factor = self.hidden_size_per_attention_head self.rpe = rpe diff --git a/megatron/model/word_embeddings.py b/megatron/model/word_embeddings.py index 488baf042..a0bbe4e55 100644 --- a/megatron/model/word_embeddings.py +++ b/megatron/model/word_embeddings.py @@ -50,7 +50,7 @@ def __init__( self.hidden_size = hidden_size self.init_method = init_method self.num_tokentypes = num_tokentypes - self.use_mup = neox_args.use_mup + self.use_mup = neox_args.use_mup # TODO: as of now this will always be false self.mup_embedding_mult = neox_args.mup_embedding_mult self.mup_rp_embedding_mult = neox_args.mup_rp_embedding_mult @@ -155,9 +155,11 @@ def forward(self, input_ids, position_ids, tokentype_ids=None): # Dropout. embeddings = self.embedding_dropout(embeddings) - if self.use_mup: - with torch.no_grad(): - embeddings.mul_(self.mup_embedding_mult) + # TODO: + # not only this always false because of the way the model is initialized, but this also throws an error + # if self.use_mup: + # with torch.no_grad(): + # embeddings.mul_(self.mup_embedding_mult) return embeddings diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 92edbd6eb..3bcea1462 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -428,7 +428,7 @@ def __init__( self.init_method = init_method self.stride = stride self.mup_rescale_parameters = mup_rescale_parameters - self.use_mup = neox_args.use_mup + self.use_mup = False # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result @@ -539,6 +539,7 @@ def mup_reinitialize_weights(self, neox_args): partition_dim=0, stride=self.stride, ) + self.use_mup = True def set_parallel_output(self, value: bool): assert isinstance(value, bool) @@ -547,8 +548,9 @@ def set_parallel_output(self, value: bool): ) # if gather_output is True, parallel output is False, so we set the opposite def forward(self, input_): - if self.use_mup and self.mup_rescale_parameters: - input_ /= self.width_mult() + if self.mup_rescale_parameters: + if hasattr(self.weight, "infshape"): + input_ /= self.weight.infshape.width_mult() # Set up backprop all-reduce. input_parallel = copy_to_model_parallel_region(input_) # Matrix multiply. @@ -623,7 +625,7 @@ def __init__( self.stride = stride self.keep_master_weight_for_test = keep_master_weight_for_test self.mup_rescale_parameters = mup_rescale_parameters - self.use_mup = neox_args.use_mup + self.use_mup = False # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result @@ -728,13 +730,14 @@ def mup_reinitialize_weights(self, neox_args): partition_dim=1, stride=self.stride, ) + self.use_mup = True def set_parallel_output(self, parallel_output: bool): assert isinstance(parallel_output, bool) self.parallel_output = parallel_output def forward(self, input_): - if self.use_mup and self.mup_rescale_parameters: + if self.mup_rescale_parameters: input_ /= self.width_mult() # Set up backprop all-reduce. 
         if self.input_is_parallel:
diff --git a/megatron/training.py b/megatron/training.py
index 548f81cb0..64c59f36f 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -62,13 +62,18 @@ def mup_weights_reinit(neox_args, model):
     def has_method(o, name):
         return callable(getattr(o, name, None))
 
+    # HACK: relies on the class name of the previously visited module (the parent ParallelLinearPipe) to avoid re-initializing the output layer; highly prone to future bugs
+    # HACK: only works when the input and output embeddings are untied
+
+    previous = ""
     for layer in model.modules():
         # This normally would happen in set_base_shapes if we actually were able to use the MuReadout class
         if hasattr(layer, "mup_rescale_parameters") and layer.mup_rescale_parameters:
             layer._rescale_parameters()
-
-        if has_method(layer, "mup_reinitialize_weights"):
-            layer.mup_reinitialize_weights(neox_args)
+        if previous != "ParallelLinearPipe":
+            if has_method(layer, "mup_reinitialize_weights"):
+                layer.mup_reinitialize_weights(neox_args)
+        previous = layer.__class__.__name__
 
 
 def save_base_shapes(neox_args, base_shapes, use_cache):
@@ -530,9 +535,9 @@ def get_optimizer(model, neox_args):
         # Use Adam
         if neox_args.use_mup:
             try:
-                from mup import MuAdam
+                from mup import MuAdamW  # TODO: was there any particular reason for not using MuAdamW originally?
 
-                adam_optimizer = MuAdam
+                adam_optimizer = MuAdamW
             except ModuleNotFoundError:
                 print("Please install mup https://github.com/microsoft/mup")
                 raise Exception
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 88e49f073..d94ee5710 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -5,6 +5,7 @@ git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b
 huggingface_hub>=0.11.0
 lm_eval>=0.3.0
 mpi4py>=3.0.3
+git+https://github.com/EleutherAI/mup.git#egg=mup
 numpy>=1.22.0
 pybind11>=2.6.2
 regex
diff --git a/test.png b/test.png
new file mode 100644
index 000000000..91b13026b
Binary files /dev/null and b/test.png differ
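For orientation, the base-shape workflow that the README steps above wire into gpt-neox comes from the upstream `mup` package (microsoft/mup). Below is a minimal standalone sketch of that workflow; the toy MLP, the widths, and the file name are illustrative stand-ins, not gpt-neox code.

```python
# Standalone sketch of the upstream microsoft/mup workflow (not gpt-neox code);
# the MLP, widths, and file name are illustrative only.
import torch.nn as nn
from mup import MuAdamW, MuReadout, make_base_shapes, set_base_shapes


def build_model(width):
    # MuReadout replaces the final nn.Linear so mup can handle the output layer.
    return nn.Sequential(nn.Linear(64, width), nn.ReLU(), MuReadout(width, 10))


# 1. Save base shapes once, from two models that differ only in width.
make_base_shapes(build_model(width=128), build_model(width=256), savefile="base_shapes.bsh")

# 2. At training time, load the shapes into the target model (this attaches an
#    `infshape` to every parameter) and use a mup-aware optimizer.
model = set_base_shapes(build_model(width=1024), "base_shapes.bsh")
optimizer = MuAdamW(model.parameters(), lr=1e-3)
```

The `infshape` attribute that `set_base_shapes` attaches to each parameter is what the `ColumnParallelLinear.forward` change above reads via `self.weight.infshape.width_mult()`.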