Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix blockwise sharding #149

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions default_shardings/llama-blockwise-quant.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,24 @@


freqs_cis : -1 # torch.complex64 (2048, 64)
tok_embeddings.weight : 1 # torch.int8 (32000, 4096)
tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wo.weight : 2 # torch.int8 (32, 128, 4096)
layers.*.attention.wo.weight_scaler : 1 # torch.bfloat16 (32, 4096)
layers.*.attention.wq.weight : 0 # torch.int8 (32, 128, 4096)
layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (32, 4096)
layers.*.attention.wk.weight : 0 # torch.int8 (32, 128, 4096)
layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (32, 4096)
layers.*.attention.wv.weight : 0 # torch.int8 (32, 128, 4096)
layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (32, 4096)
layers.*.feed_forward.w1.weight : 0 # torch.int8 (32, 128, 11008)
layers.*.feed_forward.w1.weight_scaler : 0 # torch.bfloat16 (32, 11008)
layers.*.feed_forward.w2.weight : 2 # torch.int8 (86, 128, 4096)
layers.*.feed_forward.w2.weight_scaler : 1 # torch.bfloat16 (86, 4096)
layers.*.feed_forward.w3.weight : 0 # torch.int8 (32, 128, 11008)
layers.*.feed_forward.w3.weight_scaler : 0 # torch.bfloat16 (32, 11008)
tok_embeddings.weight : -1 # torch.int8 (32000, 4096)
tok_embeddings.weight_scaler : -1 # torch.bfloat16 (4096,)
layers.*.attention.wo.weight : 0 # torch.int8 (32, 128, 4096)
layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (32, 4096)
layers.*.attention.wq.weight : 2 # torch.int8 (32, 128, 4096)
layers.*.attention.wq.weight_scaler : 1 # torch.bfloat16 (32, 4096)
layers.*.attention.wk.weight : 2 # torch.int8 (32, 128, 4096)
layers.*.attention.wk.weight_scaler : 1 # torch.bfloat16 (32, 4096)
layers.*.attention.wv.weight : 2 # torch.int8 (32, 128, 4096)
layers.*.attention.wv.weight_scaler : 1 # torch.bfloat16 (32, 4096)
layers.*.feed_forward.w1.weight : 2 # torch.int8 (32, 128, 11008)
layers.*.feed_forward.w1.weight_scaler : 1 # torch.bfloat16 (32, 11008)
layers.*.feed_forward.w2.weight : 0 # torch.int8 (86, 128, 4096)
layers.*.feed_forward.w2.weight_scaler : 0 # torch.bfloat16 (86, 4096)
layers.*.feed_forward.w3.weight : 2 # torch.int8 (32, 128, 11008)
layers.*.feed_forward.w3.weight_scaler : 1 # torch.bfloat16 (32, 11008)
layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
norm.weight : -1 # torch.float32 (4096,)
output.weight : 0 # torch.int8 (32, 128, 32000)
output.weight_scaler : 0 # torch.float32 (32, 32000)
output.weight : 2 # torch.int8 (32, 128, 32000)
output.weight_scaler : 1 # torch.float32 (32, 32000)
7 changes: 4 additions & 3 deletions jetstream_pt/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,9 +689,9 @@ def _load_from_safetensors(self, path):
if key == "freqs_cis":
continue
weights[key] = f.get_tensor(key)
assert tuple(model_weights.shape) == tuple(
weights[key].shape
), f"key: {key} error: {model_weights.shape} != {weights[key].shape}"
# assert tuple(model_weights.shape) == tuple(
# weights[key].shape
# ), f"key: {key} error: {model_weights.shape} != {weights[key].shape}"
weights["freqs_cis"] = torch_xla2.tensor.t2j(self.pt_model.freqs_cis)
return weights

Expand Down Expand Up @@ -730,6 +730,7 @@ def load_params(self) -> Params:
quantize_linear_weights_scaler_map = (
self.pt_model.get_quantized_linear_weight_to_scaler_map()
)
self.pt_model.process_weight_hook(jax_weights, env=self.env)
with jax.default_device(jax.devices("cpu")[0]):
for key, val in jax_weights.items():
for qname in quantize_linear_weights_scaler_map.keys():
Expand Down
29 changes: 19 additions & 10 deletions jetstream_pt/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ def __init__(
out_features,
bias=False,
device=None,
quant_config=QuantizationConfig(),
env=None,
):
super().__init__()
quant_config = env.quant_config
self.in_features = in_features
self.out_features = out_features

Expand Down Expand Up @@ -175,26 +176,34 @@ def __init__(
out_features,
bias=False,
device=None,
quant_config=QuantizationConfig(),
round_out_features=False,
env=None,
):
super().__init__()
quant_config = env.quant_config
assert (
not quant_config.enable_activation_quantization
), "Activation quantization not supported for blockwise quantized matmul."

self.block_size = quant_config.block_size_weight
self.in_features = in_features
num_partitions = env.mesh.size
if round_out_features and (out_features % (self.block_size * num_partitions)) != 0:
# Make sure out_features is multiple of 128 * num_partitions.
out_features = ((out_features // (self.block_size * num_partitions)) + 1) * (num_partitions * self.block_size)
self.out_features = out_features

n_blocks = in_features // self.block_size
if n_blocks % num_partitions != 0:
n_blocks = ((n_blocks // num_partitions) + 1) * num_partitions

# Use dot general instead of einsum
# Use dot general is slow now.
self.use_dot_general = False
# Flatten einsum operands to 3D. XLA was slow if operands are 4D. But it's fixed now.
# Same perf as non flattened one now.
self.flatten = False

self.block_size = quant_config.block_size_weight
n_blocks = in_features // self.block_size

assert (
not quant_config.enable_activation_quantization
), "Activation quantization not supported for blockwise quantized matmul."

if self.use_dot_general:
weight = torch.ones(
(n_blocks, out_features, self.block_size),
Expand Down Expand Up @@ -516,7 +525,7 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
LinearLayer = get_quantized_linear_layer(env.quant_config)
linear_kwargs = {}
if LinearLayer != torch.nn.Linear:
linear_kwargs = {"quant_config": env.quant_config}
linear_kwargs = {"env": env}

self.wo = LinearLayer(
n_heads * self.head_dim,
Expand Down
4 changes: 2 additions & 2 deletions jetstream_pt/third_party/gemma/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
)
linear_kwargs = {}
if Linear != torch.nn.Linear:
linear_kwargs = {"quant_config": env.quant_config}
linear_kwargs = {"env": env}

self.wq = Linear(
hidden_size,
Expand Down Expand Up @@ -237,7 +237,7 @@ def __init__(
)
linear_kwargs = {}
if Linear != torch.nn.Linear:
linear_kwargs = {"quant_config": env.quant_config}
linear_kwargs = {"env": env}

self.gate_proj = Linear(
hidden_size,
Expand Down
49 changes: 43 additions & 6 deletions jetstream_pt/third_party/llama/model_exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, List, Optional

import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
from jetstream_pt.layers import (
Expand Down Expand Up @@ -41,30 +42,37 @@ def __init__(
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

LinearLayer = get_quantized_linear_layer(env.quant_config)
linear_kwargs = {}
w1_w3_linear_kwargs = {}
w2_linear_kwargs = {}
if LinearLayer != torch.nn.Linear:
linear_kwargs["quant_config"] = env.quant_config
w1_w3_linear_kwargs["env"] = env
w2_linear_kwargs["env"] = env
if env.quant_config.is_blockwise_weight:
# To make w2's n_block is divisible by the number of partitions,
# The out_features of w1/w3 need to round up.
w1_w3_linear_kwargs["round_out_features"] = True


self.w1 = LinearLayer(
dim,
hidden_dim,
bias=False,
device=device,
**linear_kwargs,
**w1_w3_linear_kwargs,
)
self.w2 = LinearLayer(
hidden_dim,
dim,
bias=False,
device=device,
**linear_kwargs,
**w2_linear_kwargs,
)
self.w3 = LinearLayer(
dim,
hidden_dim,
bias=False,
device=device,
**linear_kwargs,
**w1_w3_linear_kwargs,
)

def forward(self, x):
Expand Down Expand Up @@ -179,7 +187,7 @@ def __init__(
LinearLayer = get_quantized_linear_layer(env.quant_config)
linear_kwargs = {}
if LinearLayer != torch.nn.Linear:
linear_kwargs["quant_config"] = env.quant_config
linear_kwargs["env"] = env

self.output = LinearLayer(
params.dim,
Expand Down Expand Up @@ -267,6 +275,35 @@ def get_quantized_embedding_weight_to_scaler_map():
return {
"tok_embeddings.weight": "tok_embeddings.weight_scaler",
}

@staticmethod
def process_weight_hook(jax_weights, env=None):
# Right now we only process weights for blockwise quantization.
# We pad the weights so that the sharded dimension size is divisible by the number of partitions.
quant_config = env.quant_config
num_partitions = env.mesh.size
if quant_config.enable_weight_quantization and quant_config.is_blockwise_weight:
block_size = quant_config.block_size_weight
for k, v in jax_weights.items():
if "w1" in k or "w3" in k:
# Pad w1/w3 to make n_out_channel divisible by num_partitions * block_size.
# This is to make w2's n_block is divisible by the number of partitions.
n_out_channel = v.shape[-1]
multiple_of = block_size * num_partitions
if n_out_channel % (multiple_of) != 0:
n_pad = multiple_of - n_out_channel % (multiple_of)
pad = jnp.zeros(v.shape[:-1] + (n_pad,)).astype(v.dtype)
padded = jnp.concatenate([v, pad], axis=-1)
jax_weights[k] = padded
if "w2" in k:
# Pad w2 to make n_block is divisible by the number of partitions.
n_blocks = v.shape[0]
if n_blocks % num_partitions != 0:
n_pad = num_partitions - n_blocks % (num_partitions)
pad = jnp.zeros((n_pad,) + v.shape[1:]).astype(v.dtype)
padded = jnp.concatenate([v, pad], axis=0)
jax_weights[k] = padded


@staticmethod
def get_weight_sharding_type(model_name: str = ""):
Expand Down
Loading