Add hf llama to neox conversion #1247

Merged
2 changes: 1 addition & 1 deletion .github/workflows/pull_request.yml
@@ -52,7 +52,7 @@ jobs:
- name: install pytest
run: python3 -m pip install pytest pytest-forked pyyaml requests wandb
- name: install torch
run: python3 -m pip install torch
run: python3 -m pip install torch
- name: install requirements
run: pip install -r requirements/requirements.txt
- name: Run Tests
12 changes: 6 additions & 6 deletions megatron/data/helpers.cpp
@@ -428,9 +428,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
}

} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
@@ -660,9 +660,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
1 change: 1 addition & 0 deletions megatron/model/norms.py
@@ -24,6 +24,7 @@ def get_norm(neox_args):
eps = neox_args.layernorm_epsilon
if neox_args.layernorm_fusion:
from .fused_layer_norm import MixedFusedLayerNorm

norm = MixedFusedLayerNorm
else:
norm = LayerNorm
8 changes: 5 additions & 3 deletions megatron/neox_arguments/arguments.py
@@ -794,7 +794,9 @@ def calculate_batch_parameters(

# either none of the three parameters are provided or just gradient_accumulation_step is provided
else:
assert False, "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided"
assert (
False
), "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided"
return int(train_batch), int(micro_batch), int(grad_acc)

@staticmethod
@@ -1098,8 +1100,8 @@ def calculate_derived(self):
if "flash" in self.attention_config:
_flash_version = packaging.version.Version(version("flash-attn"))
if self.sliding_window_width is not None:
assert (
_flash_version >= packaging.version.Version("2.3.0")
assert _flash_version >= packaging.version.Version(
"2.3.0"
), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention."
if self.pos_emb == "alibi":
if not _flash_version >= packaging.version.Version("2.4.0.post1"):
17 changes: 17 additions & 0 deletions tools/ckpts/README.md
@@ -131,3 +131,20 @@ options:
--num_output_shards NUM_OUTPUT_SHARDS
--pipeline_parallel Only use if PP>1
```

### `convert_hf_llama_to_neox.py`
Converts a Hugging Face (HF) Llama checkpoint into a GPT-NeoX-compatible checkpoint; an example invocation is shown after the usage listing below.

Note that this does not support pipeline parallelism!

```
usage: convert_hf_llama_to_neox.py [-h] [--tp TP] [--pp PP] [--model MODEL] [--model_path MODEL_PATH]

options:
-h, --help show this help message and exit
--tp TP Number of tensor parallelism ranks
--pp PP Number of pipeline parallelism stages
--model MODEL HF model name
--model_path MODEL_PATH
Path to save model
```
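
For example, converting a Llama 2 checkpoint from the Hugging Face Hub into a NeoX checkpoint split across two tensor-parallel ranks could look like the following (the model name and output path here are illustrative only):

```
python tools/ckpts/convert_hf_llama_to_neox.py \
    --model meta-llama/Llama-2-7b-hf \
    --model_path ./checkpoints/llama2-7b-neox \
    --tp 2
```

The HF tokenizer is saved as well, in a `tokenizer/` subdirectory of `--model_path`.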
219 changes: 219 additions & 0 deletions tools/ckpts/convert_hf_llama_to_neox.py
@@ -0,0 +1,219 @@
import torch
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import tqdm


def convert_model(hf_state_dict, hf_config, tp_ranks):
    conv_state_dicts = [{} for _ in range(tp_ranks)]
    # get embeddings...
    for i, chunk in enumerate(
        torch.chunk(hf_state_dict["model.embed_tokens.weight"], tp_ranks, dim=0)
    ):
        conv_state_dicts[i][
            "sequential.0.word_embeddings.weight"
        ] = chunk.clone().detach()
    print(
        "model.embed_tokens.weight",
        hf_state_dict["model.embed_tokens.weight"].shape,
        "sequential.0.word_embeddings.weight",
        conv_state_dicts[0]["sequential.0.word_embeddings.weight"].shape,
    )
    # Get config data...
    num_kv_heads = hf_config.num_key_value_heads
    num_q_heads = hf_config.num_attention_heads
    head_dim = hf_config.hidden_size // num_q_heads
    # do layers...
    for layer_num in tqdm.tqdm(range(hf_config.num_hidden_layers)):
        # --- attention ---
        # Output first since it's a simple row parallel...
        for i, chunk in enumerate(
            torch.chunk(
                hf_state_dict[f"model.layers.{layer_num}.self_attn.o_proj.weight"],
                tp_ranks,
                dim=1,
            )
        ):
            conv_state_dicts[i][
                f"sequential.{layer_num+2}.attention.dense.weight"
            ] = chunk.clone().detach()
        print(
            f"model.layers.{layer_num}.self_attn.o_proj.weight",
            hf_state_dict[f"model.layers.{layer_num}.self_attn.o_proj.weight"].shape,
            f"sequential.{layer_num+2}.attention.dense.weight",
            conv_state_dicts[0][
                f"sequential.{layer_num+2}.attention.dense.weight"
            ].shape,
        )
        # Now for attention...
        # Split into heads...
        q = hf_state_dict[f"model.layers.{layer_num}.self_attn.q_proj.weight"]
        k = hf_state_dict[f"model.layers.{layer_num}.self_attn.k_proj.weight"]
        v = hf_state_dict[f"model.layers.{layer_num}.self_attn.v_proj.weight"]
        # The GQA code splits the heads by the num_q_heads so we also do that
        # here to ensure it matches...
        q = q.view(num_q_heads, -1, q.shape[-1])
        k = k.view(num_q_heads, -1, q.shape[-1])
        v = v.view(num_q_heads, -1, q.shape[-1])
        # Chunk for tensor parallelism...
        for i, q_chunk, k_chunk, v_chunk in zip(
            range(tp_ranks),
            torch.chunk(q, tp_ranks, dim=0),
            torch.chunk(k, tp_ranks, dim=0),
            torch.chunk(v, tp_ranks, dim=0),
        ):
            # Need to join the heads across q, k, v...
            conv_state_dicts[i][
                f"sequential.{layer_num+2}.attention.query_key_value.weight"
            ] = (
                torch.cat([q_chunk, k_chunk, v_chunk], dim=1)
                .view(-1, q.shape[-1])
                .clone()
                .detach()
            )
        print(
            f"model.layers.{layer_num}.self_attn.(q/k/v)_proj.weight",
            hf_state_dict[f"model.layers.{layer_num}.self_attn.q_proj.weight"].shape,
            hf_state_dict[f"model.layers.{layer_num}.self_attn.k_proj.weight"].shape,
            hf_state_dict[f"model.layers.{layer_num}.self_attn.v_proj.weight"].shape,
            f"sequential.{layer_num+2}.attention.query_key_value.weight",
            conv_state_dicts[0][
                f"sequential.{layer_num+2}.attention.query_key_value.weight"
            ].shape,
        )
        # --- mlp ---
        # Do SwiGLU weights...
        # w1...
        for i, chunk in enumerate(
            torch.chunk(
                hf_state_dict[f"model.layers.{layer_num}.mlp.gate_proj.weight"],
                tp_ranks,
                dim=0,
            )
        ):
            conv_state_dicts[i][
                f"sequential.{layer_num+2}.mlp.w1.weight"
            ] = chunk.clone().detach()
        print(
            f"model.layers.{layer_num}.mlp.gate_proj.weight",
            hf_state_dict[f"model.layers.{layer_num}.mlp.gate_proj.weight"].shape,
            f"sequential.{layer_num+2}.mlp.w1.weight",
            conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w1.weight"].shape,
        )
        # w3...
        for i, chunk in enumerate(
            torch.chunk(
                hf_state_dict[f"model.layers.{layer_num}.mlp.up_proj.weight"],
                tp_ranks,
                dim=0,
            )
        ):
            conv_state_dicts[i][
                f"sequential.{layer_num+2}.mlp.w3.weight"
            ] = chunk.clone().detach()
        print(
            f"model.layers.{layer_num}.mlp.up_proj.weight",
            hf_state_dict[f"model.layers.{layer_num}.mlp.up_proj.weight"].shape,
            f"sequential.{layer_num+2}.mlp.w3.weight",
            conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w3.weight"].shape,
        )
        # w2 (output)...
        for i, chunk in enumerate(
            torch.chunk(
                hf_state_dict[f"model.layers.{layer_num}.mlp.down_proj.weight"],
                tp_ranks,
                dim=1,
            )
        ):
            conv_state_dicts[i][
                f"sequential.{layer_num+2}.mlp.w2.weight"
            ] = chunk.clone().detach()
        print(
            f"model.layers.{layer_num}.mlp.down_proj.weight",
            hf_state_dict[f"model.layers.{layer_num}.mlp.down_proj.weight"].shape,
            f"sequential.{layer_num+2}.mlp.w2.weight",
            conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w2.weight"].shape,
        )
        # --- norm ---
        for i in range(tp_ranks):
            conv_state_dicts[i][f"sequential.{layer_num+2}.input_layernorm.scale"] = (
                hf_state_dict[f"model.layers.{layer_num}.input_layernorm.weight"]
                .clone()
                .detach()
            )
            conv_state_dicts[i][
                f"sequential.{layer_num+2}.post_attention_layernorm.scale"
            ] = (
                hf_state_dict[
                    f"model.layers.{layer_num}.post_attention_layernorm.weight"
                ]
                .clone()
                .detach()
            )

    # Get final ln/linear....
    index = hf_config.num_hidden_layers + 3
    for i in range(tp_ranks):
        conv_state_dicts[i][f"sequential.{index}.norm.scale"] = (
            hf_state_dict["model.norm.weight"].clone().detach()
        )
    index += 1
    # do output...
    for i, chunk in enumerate(
        torch.chunk(hf_state_dict["lm_head.weight"], tp_ranks, dim=0)
    ):
        conv_state_dicts[i][
            f"sequential.{index}.final_linear.weight"
        ] = chunk.clone().detach()
    print(
        "lm_head.weight",
        hf_state_dict["lm_head.weight"].shape,
        f"sequential.{index}.final_linear.weight",
        conv_state_dicts[0][f"sequential.{index}.final_linear.weight"].shape,
    )
    return conv_state_dicts


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tp", type=int, default=1, help="Number of tensor parallelism ranks"
    )
    parser.add_argument(
        "--pp", type=int, default=0, help="Number of pipeline parallelism stages"
    )
    parser.add_argument("--model", type=str, default="gpt2", help="HF model name")
    parser.add_argument(
        "--model_path", type=str, default=None, help="Path to save model"
    )
    args = parser.parse_args()
    assert args.pp == 0, "Pipeline parallelism not supported yet"
    tokenizer = AutoTokenizer.from_pretrained(args.model).save_pretrained(
        args.model_path + "/tokenizer"
    )
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype="auto")
    state_dict = model.state_dict()
    for key in state_dict.keys():
        print(key, state_dict[key].shape)
    os.makedirs(args.model_path, exist_ok=True)
    # Setup model directory...
    os.makedirs(f"{args.model_path}/0", exist_ok=True)
    # Save the latest file so neox can figure out where to grab the weights...
    with open(f"{args.model_path}/latest", "w") as f:
        f.write("0")
    # Convert the model...
    tp_state_dicts = convert_model(state_dict, model.model.config, args.tp)
    for i in range(args.tp):
        torch.save(
            {
                "dp_world_size": 1,
                "mp_world_size": args.tp,
                "optimizer": {},
                "global_steps": 1,
                "skipped_steps": 1,
                "iteration": 1,
                "module": tp_state_dicts[i],
            },
            f"{args.model_path}/0/mp_rank_{i:02d}_model_states.pt",
        )