forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hf_lora_convert.py
208 lines (184 loc) · 7.35 KB
/
hf_lora_convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#! /usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import datetime
import json
import logging
import re
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.models.convert_utils import get_model_path, load_state_dict
# Root log-line layout: timestamp, logger name, bracketed level, message.
log_format = " ".join(
    ("%(asctime)s", "%(name)s", "[%(levelname)s]", "%(message)s"))
logging.basicConfig(format=log_format)
# Module logger; the CLI entry point raises it to DEBUG when --verbose is set.
LOGGER = logging.getLogger(name=__name__)
def save_val(val, dir, key, tp_num=None, write_npy=False):
    """Write array ``val`` into directory ``dir``.

    The file is named ``model.<key>.<ext>``, or ``model.<key>.<tp_num>.<ext>``
    when ``tp_num`` is given. With ``write_npy`` the ``.npy`` format is used
    (``np.save``); otherwise the raw buffer is dumped to a ``.bin`` file with
    ``ndarray.tofile``.
    """
    extension = "npy" if write_npy else "bin"
    if tp_num is not None:
        extension = f"{tp_num}.{extension}"
    target = dir / f"model.{key}.{extension}"
    if not write_npy:
        val.tofile(target)
        return
    np.save(target, val)
def get_all_lora_weights(lora_weights):
    """Group LoRA A/B tensors by layer index and HF module name.

    Keys are expected to look like
    ``...layers.<idx>.<self_attn|mlp>.<module>.lora_<A|B>.weight...``. Keys
    that do not match are reported with a print and skipped.

    Returns:
        A nested defaultdict ``grouped[layer_idx][module]["in" | "out"]``,
        where ``lora_A`` is stored under ``"in"`` and ``lora_B`` under ``"out"``.
    """
    lora_key_re = re.compile(
        r'.*\.layers\.([0-9]+)\.(self_attn|mlp)\.([a-z_]+)\.lora_(A|B)\.weight.*'
    )
    direction_of = {"A": "in", "B": "out"}
    grouped = defaultdict(lambda: defaultdict(dict))
    for name, tensor in lora_weights.items():
        match = lora_key_re.match(name)
        if match is None:
            print(f"no match {name}")
            continue
        layer_str, _block, module_name, ab = match.groups()
        grouped[int(layer_str)][module_name][direction_of[ab]] = tensor
    return grouped
def preprocess_lora_weights(lora_model):
    """Swap the two halves of every fused ``gate_up_proj`` lora_B tensor.

    For each key containing ``gate_up_proj.lora_B.weight``, the tensor is
    split at ``shape[0] // 2`` along dim 0 and rebuilt as
    ``[second_half; first_half]``. The dict is updated in place and also
    returned. Other keys are left untouched.
    """
    fused_keys = [k for k in lora_model if "gate_up_proj.lora_B.weight" in k]
    for key in fused_keys:
        print("Swap {}".format(key))
        stacked = lora_model[key].contiguous().clone()
        mid = stacked.shape[0] // 2
        lora_model[key] = torch.cat((stacked[mid:, :], stacked[:mid, :]),
                                    dim=0)
    return lora_model
# Map Hugging Face LoRA target-module names to TensorRT-LLM LoRA module
# names. Different HF architectures name the same projection differently
# (e.g. "o_proj" vs "dense", "qkv_proj" vs "query_key_value", "down_proj" vs
# "c_proj"), so several HF names collapse onto one TRT-LLM name.
# NOTE(review): "gate_proj" -> "mlp_h_to_4h" and "up_proj" -> "mlp_gate" looks
# swapped relative to the HF names; presumably this follows TensorRT-LLM's
# gated-MLP naming, which is also why preprocess_lora_weights swaps the
# halves of the fused gate_up_proj lora_B. Confirm before changing.
hf_modules_to_trtllm_modules = {
    "q_proj": "attn_q",
    "v_proj": "attn_v",
    "k_proj": "attn_k",
    "qkv_proj": "attn_qkv",
    "query_key_value": "attn_qkv",
    "o_proj": "attn_dense",
    "dense": "attn_dense",
    "gate_proj": "mlp_h_to_4h",
    "down_proj": "mlp_4h_to_h",
    "up_proj": "mlp_gate",
    "gate_up_proj": "mlp_h_to_4h",
    "c_fc": "mlp_h_to_4h",
    "c_proj": "mlp_4h_to_h",
}  # LLaMA-style names plus alternate/fused names used by other HF models
# HF module name -> LoRA module id, looked up in LoraManager.LORA_MODULE_IDS.
# The id becomes the first column of each lora_config row written by
# convert_hf_model. Any TRT-LLM name missing from LORA_MODULE_IDS raises
# KeyError at import time.
hf_modules_to_module_id = {
    k: LoraManager.LORA_MODULE_IDS[v]
    for k, v in hf_modules_to_trtllm_modules.items()
}
def _lora_scale(config):
    """Return the LoRA output scaling factor described by ``config``.

    Standard LoRA scales by ``lora_alpha / r``. Rank-stabilized LoRA
    (``use_rslora: true``) scales by ``lora_alpha / sqrt(r)``.

    Raises:
        ValueError: if ``r`` or ``lora_alpha`` is missing, or ``r`` is not
            positive.
    """
    rank = config.get("r")
    alpha = config.get("lora_alpha")
    if rank is None or alpha is None:
        raise ValueError(
            "adapter_config.json must define both 'r' and 'lora_alpha'")
    if rank <= 0:
        raise ValueError(f"LoRA rank 'r' must be positive, got {rank}")
    if config.get("use_rslora", False):
        return alpha / np.sqrt(rank)
    return alpha / rank


def _pack_module_weights(in_weights, out_weights, scale, torch_dtype):
    """Flatten one module's LoRA A/B pair into a single 1-D tensor.

    ``in_weights`` (lora_A) is laid out as [adapter_size, hidden] and
    ``out_weights`` (lora_B) as [hidden, adapter_size]. Either one is
    transposed if it arrives the other way round, on the assumption that the
    hidden dimension is the larger of the two. Only the out weights are
    multiplied by ``scale``.

    Returns:
        ``(packed, adapter_size)``: the flattened in weights followed by the
        flattened out weights, cast to ``torch_dtype``, and the adapter rank.

    Raises:
        ValueError: if either tensor is not 2-D, or the ranks of A and B
            disagree.
    """
    for name, w in (("in", in_weights), ("out", out_weights)):
        if len(w.shape) != 2:
            raise ValueError(
                f"LoRA '{name}' weight must be 2-D, got shape {tuple(w.shape)}")
    # in_weights should have shape [adapter_size, hidden]
    if in_weights.shape[1] < in_weights.shape[0]:
        in_weights = in_weights.transpose(1, 0)
    # out_weights should have shape [hidden, adapter_size]
    if out_weights.shape[0] < out_weights.shape[1]:
        out_weights = out_weights.transpose(1, 0)
    in_rank = min(in_weights.shape)
    adapter_size = min(out_weights.shape)
    if in_rank != adapter_size:
        # A mismatched pair would be silently misread by the consumer of
        # the packed row, which only records a single adapter_size.
        raise ValueError(
            f"LoRA A/B rank mismatch: A has rank {in_rank}, "
            f"B has rank {adapter_size}")
    out_weights = out_weights * scale
    packed = torch.cat([
        in_weights.contiguous().flatten().to(dtype=torch_dtype),
        out_weights.contiguous().flatten().to(dtype=torch_dtype),
    ])
    return packed, adapter_size


def convert_hf_model(model_dir, dtype, out_dir):
    """Convert a Hugging Face PEFT LoRA adapter into TensorRT-LLM .npy files.

    Reads ``<model_dir>/adapter_config.json`` and the ``adapter_model``
    checkpoint. It then writes two files into ``out_dir``, which is created
    if needed:

    * ``model.lora_weights.npy`` has shape [1, num_modules, max_row_size].
      Each row is one (layer, module) pair's packed A then scaled-B weights,
      zero-padded on the right.
    * ``model.lora_config.npy`` is int32 with shape [1, num_modules, 3].
      Each row is ``[module_id, layer_idx, adapter_size]``.

    Args:
        model_dir: directory holding the PEFT adapter.
        dtype: storage dtype name understood by ``str_dtype_to_torch``.
        out_dir: output directory.

    Raises:
        ValueError: on an invalid adapter config, malformed tensors, or when
            no LoRA weights match the expected key pattern.
        KeyError: when a module is missing its A or B half, or the module
            name has no TensorRT-LLM mapping.
    """
    saved_dir = Path(out_dir)
    saved_dir.mkdir(parents=True, exist_ok=True)
    with open(Path(model_dir) / "adapter_config.json", "r",
              encoding="utf-8") as f:
        config = json.load(f)
    scale = _lora_scale(config)
    torch_dtype = str_dtype_to_torch(dtype)

    lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
    lora_model = preprocess_lora_weights(lora_model)
    all_weights = get_all_lora_weights(lora_model)

    converted_weights = []
    converted_config = []
    for layer_idx, layer_weights in all_weights.items():
        for hf_module, module_weights in layer_weights.items():
            if hf_module not in hf_modules_to_module_id:
                raise KeyError(
                    f"Unsupported LoRA module '{hf_module}' in layer "
                    f"{layer_idx}; supported: {sorted(hf_modules_to_module_id)}")
            missing = [k for k in ("in", "out") if k not in module_weights]
            if missing:
                raise KeyError(
                    f"Layer {layer_idx} module '{hf_module}' is missing its "
                    f"{'/'.join(missing)} weight (lora_A='in', lora_B='out')")
            packed, adapter_size = _pack_module_weights(
                module_weights["in"], module_weights["out"], scale,
                torch_dtype)
            converted_weights.append(packed)
            converted_config.append(
                [hf_modules_to_module_id[hf_module], layer_idx, adapter_size])

    if not converted_weights:
        raise ValueError(
            f"No LoRA weights matching '...layers.<i>.<self_attn|mlp>."
            f"<module>.lora_<A|B>.weight' were found in {model_dir}")

    # Right-pad every packed row to the longest one so they stack into a
    # single [1, num_modules, max_row_size] array.
    max_row_size = max(t.shape[0] for t in converted_weights)
    padded = [
        torch.nn.functional.pad(t, (0, max_row_size - t.shape[0]))
        for t in converted_weights
    ]
    weights_array = torch_to_numpy(
        torch.stack(padded).unsqueeze(0).to(dtype=torch_dtype).cpu())
    config_array = torch.tensor(converted_config,
                                dtype=torch.int32,
                                device='cpu').unsqueeze(0).numpy()
    save_val(weights_array,
             saved_dir,
             "lora_weights",
             tp_num=None,
             write_npy=True)
    save_val(config_array,
             saved_dir,
             "lora_config",
             tp_num=None,
             write_npy=True)
def main(args):
    """Run the conversion described by the parsed CLI ``args`` and log its duration.

    Reads ``args.in_file`` (adapter directory), ``args.storage_type`` and
    ``args.out_dir``, and passes them to ``convert_hf_model``.
    """
    start_time = datetime.datetime.now()
    convert_hf_model(args.in_file, args.storage_type, args.out_dir)
    # The previous message said "prompt model"; this script converts LoRA adapters.
    LOGGER.info("Spent %s (h:m:s) to convert the LoRA model",
                datetime.datetime.now() - start_time)
if __name__ == "__main__":
    # CLI entry point: parse arguments, echo them, then run the conversion.
    parser = argparse.ArgumentParser(
        description="Convert a Hugging Face PEFT LoRA adapter into "
        "TensorRT-LLM LoRA weight/config .npy files")
    parser.add_argument(
        '--out-dir',
        '-o',
        type=Path,
        help='output directory for model.lora_weights.npy and '
        'model.lora_config.npy (created if missing)',
        required=True)
    parser.add_argument(
        '--in-file',
        '-i',
        type=Path,
        help='path to the input LoRA adapter directory (containing '
        'adapter_config.json and the adapter_model checkpoint)',
        required=True)
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Provide verbose messages")
    parser.add_argument("--storage-type",
                        type=str,
                        default="float16",
                        choices=["float32", "float16", "bfloat16"],
                        help="dtype used to store the converted LoRA weights")
    args = parser.parse_args()
    LOGGER.setLevel(logging.DEBUG if args.verbose else logging.INFO)
    print("\n=============== Argument ===============")
    for key in vars(args):
        print(f"{key}: {vars(args)[key]}")
    print("========================================")
    main(args)