Skip to content

Commit

Permalink
move custom methods to processors
Browse files Browse the repository at this point in the history
  • Loading branch information
VladOS95-cyber committed Nov 1, 2024
1 parent 45b5cfe commit 4b412ab
Showing 1 changed file with 64 additions and 66 deletions.
130 changes: 64 additions & 66 deletions src/transformers/modeling_gguf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,31 @@ def process(self, weights, name, **kwargs):
if None in (num_heads, num_kv_heads):
return weights, name
if ".attn_q." in name:
weights = _reverse_permute_weights(weights, num_heads, num_heads)
weights = self._reverse_permute_weights(weights, num_heads, num_heads)
elif ".attn_k." in name:
weights = _reverse_permute_weights(weights, num_heads, num_kv_heads)
weights = self._reverse_permute_weights(weights, num_heads, num_kv_heads)
return {"weights": weights, "name": name, "metadata": {}}

def _reverse_permute_weights(
self, weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None
) -> np.ndarray:
# Original permutation implementation
# https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
if num_kv_heads is not None and n_head != num_kv_heads:
n_head = num_kv_heads

dim = weights.shape[0] // n_head // 2
w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
return w.swapaxes(2, 1).reshape(weights.shape)


class Qwen2MoeTensorProcessor(TensorProcessor):
def process(self, weights, name, **kwargs):
if "_exp" in name:
tensor_key_mapping = kwargs.get("tensor_key_mapping")
config = kwargs.get("config", {})
if tensor_key_mapping:
_split_moe_expert_tensor(weights, config, name, tensor_key_mapping)
self._split_moe_expert_tensor(weights, config, name, tensor_key_mapping)
return {
"weights": weights,
"name": None, # Signal to skip further processing
Expand All @@ -94,6 +106,29 @@ def process(self, weights, name, **kwargs):
weights = np.expand_dims(weights, axis=0)
return {"weights": weights, "name": name, "metadata": {}}

def _split_moe_expert_tensor(
self, weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
):
# Original merge implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
exp_name = ""
if "ffn_gate_exps" in name:
exp_name = "gate_proj"
elif "ffn_down_exps" in name:
exp_name = "down_proj"
elif "ffn_up_exps" in name:
exp_name = "up_proj"
else:
raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")
for tensor_name in tensor_key_mapping:
if tensor_name in name:
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
w_counter = parsed_parameters["config"].get("num_experts", 60)
for i in range(0, w_counter):
temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
exp_weight = weights[i]
parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))


class BloomTensorProcessor(TensorProcessor):
def process(self, weights, name, **kwargs):
Expand All @@ -102,11 +137,35 @@ def process(self, weights, name, **kwargs):
num_heads = config["n_head"]
n_embed = config["hidden_size"]
if "weight" in name:
weights = _reverse_reshape_weights(weights, num_heads, n_embed)
weights = self._reverse_reshape_weights(weights, num_heads, n_embed)
else:
weights = _reverse_reshape_bias(weights, num_heads, n_embed)
weights = self._reverse_reshape_bias(weights, num_heads, n_embed)
return {"weights": weights, "name": name, "metadata": {}}

def _reverse_reshape_weights(self, weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
q, k, v = np.array_split(weights, 3, axis=0)

q = q.reshape(n_head, n_embed // n_head, n_embed)
k = k.reshape(n_head, n_embed // n_head, n_embed)
v = v.reshape(n_head, n_embed // n_head, n_embed)
qkv_weights = np.stack([q, k, v], axis=1)

return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)

def _reverse_reshape_bias(self, weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
q_bias, k_bias, v_bias = np.array_split(weights, 3)

q_bias = q_bias.reshape(n_head, n_embed // n_head)
k_bias = k_bias.reshape(n_head, n_embed // n_head)
v_bias = v_bias.reshape(n_head, n_embed // n_head)

qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
return qkv_bias


class T5TensorProcessor(TensorProcessor):
def process(self, weights, name, **kwargs):
Expand Down Expand Up @@ -316,64 +375,3 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")

return parsed_parameters


def _reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
# Original permutation implementation
# https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
if num_kv_heads is not None and n_head != num_kv_heads:
n_head = num_kv_heads

dim = weights.shape[0] // n_head // 2
w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
return w.swapaxes(2, 1).reshape(weights.shape)


def _reverse_reshape_weights(weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
q, k, v = np.array_split(weights, 3, axis=0)

q = q.reshape(n_head, n_embed // n_head, n_embed)
k = k.reshape(n_head, n_embed // n_head, n_embed)
v = v.reshape(n_head, n_embed // n_head, n_embed)
qkv_weights = np.stack([q, k, v], axis=1)

return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)


def _reverse_reshape_bias(weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
q_bias, k_bias, v_bias = np.array_split(weights, 3)

q_bias = q_bias.reshape(n_head, n_embed // n_head)
k_bias = k_bias.reshape(n_head, n_embed // n_head)
v_bias = v_bias.reshape(n_head, n_embed // n_head)

qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
return qkv_bias


def _split_moe_expert_tensor(
weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
):
# Original merge implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
exp_name = ""
if "ffn_gate_exps" in name:
exp_name = "gate_proj"
elif "ffn_down_exps" in name:
exp_name = "down_proj"
elif "ffn_up_exps" in name:
exp_name = "up_proj"
else:
raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")
for tensor_name in tensor_key_mapping:
if tensor_name in name:
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
w_counter = parsed_parameters["config"].get("num_experts", 60)
for i in range(0, w_counter):
temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
exp_weight = weights[i]
parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))

0 comments on commit 4b412ab

Please sign in to comment.