4 changes: 3 additions & 1 deletion configs/local_setup.yml
@@ -26,5 +26,7 @@
"log-dir": "logs",
"use_wandb": True,
"wandb_host": "https://api.wandb.ai",
"wandb_project": "neox"
"wandb_project": "neox",
"num_gpus": 1,
"ia3_prompt_tuning": True
}
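
For reference, `ia3_prompt_tuning` is the only switch a user has to flip: `get_model()` (see `megatron/training.py` below) then swaps in the (IA)^3 attention and MLP classes and freezes everything except the new scale vectors. A minimal, framework-free sketch of the rescaling the method adds — shapes and names here are illustrative, not the NeoX implementation:

```python
import torch

# Hypothetical sizes, just to show the idea: one learned scale per key/value
# channel and one per MLP hidden unit, initialised to ones (identity) and
# trained while the rest of the network stays frozen.
d, ff = 8, 32
l_k = torch.ones(d, requires_grad=True)
l_v = torch.ones(d, requires_grad=True)
l_ff = torch.ones(ff, requires_grad=True)

keys = torch.randn(5, d)        # [tokens, d]
values = torch.randn(5, d)
ff_hidden = torch.randn(5, ff)  # MLP hidden activations

keys = keys * l_k               # elementwise rescaling -- the whole of (IA)^3
values = values * l_v
ff_hidden = ff_hidden * l_ff
```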
3 changes: 3 additions & 0 deletions megatron/checkpointing.py
@@ -227,6 +227,8 @@ def load_checkpoint(
):
"""Load a model checkpoint and return the iteration."""
if neox_args.deepspeed:
if neox_args.ia3_prompt_tuning:
neox_args.load_module_strict = False
load_optim_and_scheduler = (
not neox_args.no_load_optim
) # TODO: These should be configured by separate args
@@ -241,6 +243,7 @@ def load_checkpoint(
load_optimizer_states=load_optim_and_scheduler,
load_lr_scheduler_states=load_optim_and_scheduler,
tag=tag,
            load_module_strict=neox_args.load_module_strict,
)

if checkpoint_name is None:
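
The non-strict load matters because the pretrained checkpoint has no `l_k` / `l_v` / `l_ff` keys; DeepSpeed's `load_checkpoint` forwards `load_module_strict` down to the module's state-dict loading. A plain-PyTorch sketch of the same idea (names here are illustrative):

```python
import torch
import torch.nn as nn

base = nn.Linear(4, 4)                     # stands in for the pretrained module
tuned = nn.Linear(4, 4)
tuned.l_ff = nn.Parameter(torch.ones(4))   # new (IA)^3 parameter, absent from the checkpoint

result = tuned.load_state_dict(base.state_dict(), strict=False)
print(result.missing_keys)                 # ['l_ff'] -- tolerated instead of raising
```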
168 changes: 166 additions & 2 deletions megatron/model/transformer.py
@@ -18,6 +18,7 @@
"""Transformer."""

import math
import sys
import torch
import torch.nn.functional as F
import torch.nn as nn
@@ -93,7 +94,9 @@ def __init__(
if self.activation_type == "geglu"
else ff_mult * neox_args.hidden_size
)
self.dense_h_to_4h = mpu.ColumnParallelLinear(
mlp_column_parallel_cls = getattr(mpu, neox_args.mlp_column_parallel_cls)

self.dense_h_to_4h = mlp_column_parallel_cls(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=ff_dim,
@@ -590,6 +593,166 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
return output, bias


class ParallelSelfAttentionIA3(ParallelSelfAttention):
    """Self-attention with (IA)^3 rescaling: learned vectors l_k and l_v elementwise rescale the key and value projections."""

def __init__(
self,
neox_args,
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
rpe=None,
rotary=False,
use_cache=False,
parallel_output=False,
):
super().__init__(
neox_args,
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
rpe=rpe,
rotary=rotary,
use_cache=use_cache,
parallel_output=parallel_output,
)
self.l_k = self._create_ia3_parameter(neox_args)
self.l_v = self._create_ia3_parameter(neox_args)

def _create_ia3_parameter(self, neox_args):
if neox_args.use_cpu_initialization:
param = torch.nn.Parameter(
torch.empty(
self.hidden_size_per_partition, dtype=neox_args.params_dtype
)
)
else:
param = torch.nn.Parameter(
torch.empty(
self.hidden_size_per_partition,
device=torch.cuda.current_device(),
dtype=neox_args.params_dtype,
)
)
param.model_parallel = True
param.partition_dim = 0
        # Always initialize the (IA)^3 scale vector to ones so it starts as an identity rescaling.
with torch.no_grad():
torch.nn.init.ones_(param)
return param

def forward(self, hidden_states, attention_mask, layer_past=None):

# hidden_states: [sq, b, h]

# =====================
# Query, Key, and Value
# =====================

# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(
mixed_x_layer, 3
)
# Apply IA3 rescaling to keys & values:
def _apply_ia3_rescaling(layer, scale_vector):
layer_size = layer.shape
layer = layer.reshape(layer_size[0], layer_size[1], -1)
            layer = layer * scale_vector  # out-of-place: autograd needs the pre-scale values for scale_vector's gradient
return layer.reshape(layer_size)

key_layer = _apply_ia3_rescaling(key_layer, self.l_k)
value_layer = _apply_ia3_rescaling(value_layer, self.l_v)

if exists(self.rotary_emb):
if exists(self.rotary_ndims):
# partial rotary
query_rot, query_pass = (
query_layer[..., : self.rotary_ndims],
query_layer[..., self.rotary_ndims :],
)
key_rot, key_pass = (
key_layer[..., : self.rotary_ndims],
key_layer[..., self.rotary_ndims :],
)
else:
# full rotary
query_rot, key_rot = query_layer, key_layer
apply_rotary_fn = (
apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
)

seq_len = key_layer.shape[0]
offset = 0
if exists(layer_past) and layer_past.numel() > 0:
offset = layer_past[0].shape[0]
seq_len += offset
cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
query_layer, key_layer = apply_rotary_fn(
query_rot, key_rot, cos, sin, offset=offset
)

if exists(self.rotary_ndims):
query_layer = torch.cat((query_layer, query_pass), dim=-1)
key_layer = torch.cat((key_layer, key_pass), dim=-1)

# ==================================
# Cache key and value for inference
# ==================================

if exists(layer_past) and layer_past.numel() > 0:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
value_layer = torch.cat(
(past_value.type_as(value_layer), value_layer), dim=0
)

if self.use_cache:
present = torch.stack((key_layer, value_layer))

if self.use_flash_attention:
context_layer = self.flash_attention(query_layer, key_layer, value_layer)
elif not self.sparse:
context_layer = self.attention(
query_layer, key_layer, value_layer, layer_past, attention_mask
)
else:
context_layer = self.sparse_attention(
query_layer, key_layer, value_layer, attention_mask
)

# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size_per_partition,
)
context_layer = context_layer.view(*new_context_layer_shape)

# =================
# Output. [sq, b, h]
# =================

output, bias = self.dense(context_layer)

if self.use_cache:
output = [output, present]

return output, bias


class ParallelTransformerLayer(nn.Module):
"""A single transformer layer.

@@ -625,9 +788,10 @@ def __init__(

if self.gpt_j_residual:
self.reduce = mpu.mappings.reduce_from_model_parallel_region
self_attention_cls = getattr(sys.modules[__name__], neox_args.self_attention_cls)

# Self attention.
self.attention = ParallelSelfAttention(
self.attention = self_attention_cls(
neox_args=neox_args,
attention_mask_func=attention_mask_func,
init_method=init_method,
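
To make the reshape in `_apply_ia3_rescaling` concrete: the per-rank scale vector has length `np * hn` (i.e. `hidden_size_per_partition`), so flattening the last two dimensions, scaling, and reshaping back applies one learned scalar per (head, head-dim) element. An illustrative sketch with made-up shapes:

```python
import torch

sq, b, heads, hn = 5, 2, 4, 8               # seq, batch, heads per partition, head dim
key_layer = torch.randn(sq, b, heads, hn)
l_k = torch.rand(heads * hn)                # hidden_size_per_partition

flat = key_layer.reshape(sq, b, -1) * l_k   # broadcast over [sq, b, heads * hn]
rescaled = flat.reshape(sq, b, heads, hn)

# Same result without flattening: one scale per (head, head-dim) slot.
assert torch.allclose(rescaled, key_layer * l_k.view(heads, hn))
```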
4 changes: 2 additions & 2 deletions megatron/model/utils.py
@@ -48,14 +48,14 @@ def get_params_for_weight_decay_optimization(module, neox_args):
[
p
for n, p in list(module_._parameters.items())
if p is not None and n != "bias"
if p is not None and n not in neox_args.no_weight_decay_params
]
)
no_weight_decay_params["params"].extend(
[
p
for n, p in list(module_._parameters.items())
if p is not None and n == "bias"
if p is not None and n in neox_args.no_weight_decay_params
]
)
if neox_args.weight_decay == 0.0:
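
With the hard-coded `"bias"` check replaced by `no_weight_decay_params`, the (IA)^3 scale vectors stay out of the weight-decay group as well. A simplified sketch of the grouping (the real helper walks the module tree; this just matches on the final attribute name):

```python
import torch
import torch.nn as nn

no_weight_decay_params = ["bias", "l_ff", "l_v", "l_k"]

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
decay, no_decay = [], []
for name, param in model.named_parameters():
    target = no_decay if name.split(".")[-1] in no_weight_decay_params else decay
    target.append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ]
)
```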
1 change: 1 addition & 0 deletions megatron/mpu/__init__.py
@@ -36,6 +36,7 @@
from .initialize import model_parallel_is_initialized

from .layers import ColumnParallelLinear
from .layers import ColumnParallelLinearIA3
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .layers import ParallelRelativePositionBias
65 changes: 65 additions & 0 deletions megatron/mpu/layers.py
@@ -564,6 +564,71 @@ def forward(self, input_):
return output, output_bias


class ColumnParallelLinearIA3(ColumnParallelLinear):
    """Column-parallel linear layer with an (IA)^3 scale vector l_ff applied elementwise to the per-partition output."""

def __init__(
self,
neox_args,
input_size,
output_size,
bias=True,
gather_output=True,
init_method=init.xavier_normal_,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
mup_rescale_parameters=False,
):
super().__init__(
neox_args,
input_size,
output_size,
bias=bias,
gather_output=gather_output,
init_method=init_method,
stride=stride,
keep_master_weight_for_test=keep_master_weight_for_test,
skip_bias_add=skip_bias_add,
mup_rescale_parameters=mup_rescale_parameters
)
if neox_args.use_cpu_initialization:
self.l_ff = Parameter(
torch.empty(
self.output_size_per_partition, dtype=neox_args.params_dtype
)
)
else:
self.l_ff = Parameter(
torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=neox_args.params_dtype,
)
)
self.l_ff.model_parallel = True
self.l_ff.partition_dim = 0
self.l_ff.stride = stride
# Always initialize l_ff to ones.
with torch.no_grad():
torch.nn.init.ones_(self.l_ff)

def forward(self, input_):
if self.use_mup and self.mup_rescale_parameters:
input_ /= self.width_mult()
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.

bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
        output_parallel = output_parallel * self.l_ff  # apply IA3 rescaling (out-of-place so autograd can form l_ff's gradient)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias

class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.

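
Note that `l_ff` is sized `output_size_per_partition`: the column-parallel linear splits its output features across model-parallel ranks, so each rank owns and rescales only its slice of the MLP hidden dimension before the (optional) gather. An illustrative single-process sketch with two simulated ranks:

```python
import torch
import torch.nn.functional as F

hidden, ff_dim, world_size = 8, 32, 2
x = torch.randn(3, hidden)
weight = torch.randn(ff_dim, hidden)   # full h -> 4h weight
l_ff = torch.rand(ff_dim)              # full scale vector

full = F.linear(x, weight) * l_ff      # non-parallel reference

parts = []
per_rank = ff_dim // world_size
for rank in range(world_size):
    rows = slice(rank * per_rank, (rank + 1) * per_rank)
    # each "rank" holds its own weight rows and the matching slice of l_ff
    parts.append(F.linear(x, weight[rows]) * l_ff[rows])

assert torch.allclose(full, torch.cat(parts, dim=-1))
```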
27 changes: 26 additions & 1 deletion megatron/neox_arguments/neox_args.py
@@ -13,7 +13,7 @@
# limitations under the License.

import subprocess
from dataclasses import dataclass
from dataclasses import dataclass, field

try:
from .template import NeoXArgsTemplate
@@ -355,11 +355,36 @@ class NeoXArgsModel(NeoXArgsTemplate):
"""

    output_layer_parallelism: Literal["row", "column"] = "row"
    """
    Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
    """

    ia3_prompt_tuning: bool = False
    """
    Run (IA)^3 parameter-efficient fine-tuning, based on:
    "Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning",
    https://arxiv.org/pdf/2205.05638.pdf
    """

    self_attention_cls: str = "ParallelSelfAttention"
    """
    Class to use for self-attention; set to "ParallelSelfAttentionIA3" automatically when ia3_prompt_tuning is enabled.
    """

    mlp_column_parallel_cls: str = "ColumnParallelLinear"
    """
    Class to use for the MLP's column-parallel linear layer; set to "ColumnParallelLinearIA3" automatically when ia3_prompt_tuning is enabled.
    """

    no_weight_decay_params: list = field(default_factory=lambda: ["bias", "l_ff", "l_v", "l_k"])
    """
    Parameter names that weight decay is not applied to.
    """

load_module_strict: bool = True
"""
Whether to strictly enforce that the keys in state_dict of module & checkpoint match.
"""

@dataclass
class NeoXArgsOptimizer(NeoXArgsTemplate):
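
A side note on the new `field` import: dataclasses reject mutable defaults, so `no_weight_decay_params` needs `default_factory` to give each instance its own list. A standalone illustration (not the NeoX class itself):

```python
from dataclasses import dataclass, field

@dataclass
class ArgsSketch:
    no_weight_decay_params: list = field(default_factory=lambda: ["bias", "l_ff", "l_v", "l_k"])
    load_module_strict: bool = True

a, b = ArgsSketch(), ArgsSketch()
a.no_weight_decay_params.append("l_extra")
assert b.no_weight_decay_params == ["bias", "l_ff", "l_v", "l_k"]  # defaults are not shared
```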
14 changes: 14 additions & 0 deletions megatron/training.py
@@ -385,6 +385,10 @@ def get_model(neox_args, use_cache=False):
# If mup isn't being used anyways, this has no effect.
old_use_mup = neox_args.use_mup
neox_args.use_mup = False
if neox_args.ia3_prompt_tuning:
neox_args.mlp_column_parallel_cls = "ColumnParallelLinearIA3"
neox_args.self_attention_cls = "ParallelSelfAttentionIA3"

model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
@@ -412,6 +416,16 @@
for name, param in model.named_parameters():
if not "soft_embedding" in name:
param.requires_grad = False
elif neox_args.ia3_prompt_tuning:
layers_to_train = ["l_ff", "l_k", "l_v"]
for name, param in model.named_parameters():
if not any([x in name for x in layers_to_train]):
param.requires_grad = False

trainable_params = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
print(f"Number of trainable parameters: {trainable_params}")

if not neox_args.is_pipe_parallel:
# Export PipeParallel model to nn.Sequential model to avoid the overhead of deepspeed's pipe parallel training
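
Putting the freezing logic in one place: when `ia3_prompt_tuning` is enabled, everything except the `l_ff` / `l_k` / `l_v` vectors is frozen, so the printed trainable-parameter count is just the sum of the scale-vector sizes. A toy illustration of the same pattern (the module here is made up):

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, d=8):
        super().__init__()
        self.dense = nn.Linear(d, d)            # stands in for the frozen pretrained weights
        self.l_k = nn.Parameter(torch.ones(d))
        self.l_v = nn.Parameter(torch.ones(d))
        self.l_ff = nn.Parameter(torch.ones(4 * d))

model = ToyBlock()
layers_to_train = ["l_ff", "l_k", "l_v"]
for name, param in model.named_parameters():
    if not any(x in name for x in layers_to_train):
        param.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Number of trainable parameters: {trainable} / {total}")  # 48 / 120
```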