Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tensor parallelism for RWKV #1237

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions configs/local_setup.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
"load": "checkpoints",
"checkpoint_validation_with_forward_pass": False,


# "launcher": "openmpi",
#"deepspeed_mpi": true,

"tensorboard_dir": "tensorboard",
"log_dir": "logs",
}
103 changes: 103 additions & 0 deletions configs/rwkv/1.5B.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
{
# Parallelism is not yet supported for rwkv
"pipe_parallel_size": 1,
"model_parallel_size": 2,

"num_layers": 24,
"hidden_size": 2048,
"num_attention_heads": 32, # head_size = dim_att / num_attention_heads.
# head_size is 64 for all rwkv models
"seq_length": 4096,
"max_position_embeddings": 4096,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,
"train_micro_batch_size_per_gpu": 1,

"attention_config": [[["rwkv"], 24]],

"activation": "silu",

# model settings

#"pos_emb": "rotary",
"rotary_pct": 0.25,
"no_weight_tying": true,
"gpt_j_residual": true,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0008,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00008,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"data_impl": "mmap",
"num_workers": 1,

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"bf16": {
"bf16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 12,
"hysteresis": 2,
"min_loss_scale": 1,
},

# misc. training settings
"train_iters": 1,
"lr_decay_iters": 1,
"distributed_backend": "nccl",
"lr_decay_style": "constant",
"warmup": 0.01,
"checkpoint_factor": 100,
"eval_interval": 100000,
"eval_iters": 10,
"seed": 1234,

# logging
"log_interval": 10,
"steps_per_print": 10,
"wall_clock_breakdown": true,
}
102 changes: 102 additions & 0 deletions configs/rwkv/7B.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
{
# Parallelism is not yet supported for rwkv
"pipe_parallel_size": 1,
"model_parallel_size": 1,

"num_layers": 32,
"hidden_size": 4096,
"num_attention_heads": 64, # head_size = dim_att / num_attention_heads.
# head_size is 64 for all rwkv models
"seq_length": 4096,
"max_position_embeddings": 4096,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,
"train_micro_batch_size_per_gpu": 8,

"attention_config": [[["rwkv"], 32]],

"activation": "silu",

# model settings

#"pos_emb": "rotary",
"rotary_pct": 0.25,
"no_weight_tying": true,
"gpt_j_residual": true,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0008,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00008,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"data_impl": "mmap",
"num_workers": 1,

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"bf16": {
"bf16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 12,
"hysteresis": 2,
"min_loss_scale": 1,
},

# misc. training settings
"train_iters": 500,
"lr_decay_iters": 500,
"distributed_backend": "nccl",
"lr_decay_style": "constant",
"warmup": 0.01,
"checkpoint_factor": 100,
"eval_interval": 100000,
"eval_iters": 10,

# logging
"log_interval": 10,
"steps_per_print": 10,
"wall_clock_breakdown": true,
}
2 changes: 2 additions & 0 deletions megatron/model/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def cross_entropy(output, labels, _fp16=False):
else:
losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels)
loss_mask = loss_mask.view(-1)
print(f"model output shape: {output.size()}, loss shape: {losses.size()}")
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
return loss

Expand Down Expand Up @@ -258,6 +259,7 @@ def init_specs(self):
LayerSpec(
RWKVResidualLayerPipe,
neox_args=self.neox_args,
init_method=self.init_method,
layer_number=i,
)
)
Expand Down
Loading
Loading