Refactor _get_source_transforms to remove args #10519

Open · wants to merge 1 commit into main
149 changes: 117 additions & 32 deletions examples/models/llama/export_llama_lib.py
@@ -651,10 +651,31 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
_get_source_transforms(
modelname=args.model,
dtype_override=dtype_override,
checkpoint=args.checkpoint,
checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype), # type: ignore
args=args,
tokenizer_path=args.tokenizer_path,
use_spin_quant=args.use_spin_quant,
embedding_quantize=args.embedding_quantize,
quantization_mode=args.quantization_mode,
expand_rope_table=args.expand_rope_table,
use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
quantize_kv_cache=args.quantize_kv_cache,
use_kv_cache=args.use_kv_cache,
qnn=args.qnn,
use_qnn_sha=args.use_qnn_sha,
optimized_rotation_path=args.optimized_rotation_path,
mps=args.mps,
coreml=args.coreml,
coreml_ios=args.coreml_ios,
vulkan=args.vulkan,
use_shared_embedding=args.use_shared_embedding,
use_qat=args.use_qat,
use_lora=args.use_lora,
preq_mode=args.preq_mode,
preq_group_size=args.preq_group_size,
preq_embedding_quantize=args.preq_embedding_quantize,
)
)

@@ -1145,23 +1166,65 @@ def _load_llama_model(


def _get_source_transforms( # noqa
modelname: str,
dtype_override: DType,
*,
checkpoint: Optional[str] = None,
checkpoint_dtype: Optional[DType] = None,
args,
tokenizer_path: Optional[str] = None,
use_spin_quant: Optional[str] = None,
embedding_quantize: Optional[str] = None,
quantization_mode: Optional[str] = None,
expand_rope_table: bool = False,
use_custom_sdpa_with_attention_mask: bool = False,
use_sdpa_with_kv_cache: bool = False,
quantize_kv_cache: bool = False,
use_kv_cache: bool = False,
qnn: bool = False,
use_qnn_sha: bool = False,
optimized_rotation_path: Optional[str] = None,
mps: bool = False,
coreml: bool = False,
coreml_ios: int = 15,
vulkan: bool = False,
use_shared_embedding: bool = False,
use_qat: bool = False,
use_lora: int = 0,
preq_mode: Optional[str] = None,
preq_group_size: int = 32,
preq_embedding_quantize: str = "8,0",
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
"""
Return a list of functions that transform a graph.

Args:
modelname: The name of the model.
dtype_override: The dtype to use for the model.
checkpoint: Path to the checkpoint file.
checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
it means that you want to run quantize transformations on the weights represented
in their original dtype, while the overall dtype of the model may be something
different. If not specified, defaults to dtype_override.
args: The arguments passed to the script.
tokenizer_path: Path to the tokenizer file.
use_spin_quant: Type of spin quant to use ("cuda" or "native").
embedding_quantize: Type of embedding quantization.
quantization_mode: Type of quantization mode.
expand_rope_table: Whether to expand rope table.
use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
quantize_kv_cache: Whether to quantize KV cache.
use_kv_cache: Whether to use KV cache.
qnn: Whether to use QNN.
use_qnn_sha: Whether to use QNN SHA.
optimized_rotation_path: Path to optimized rotation.
mps: Whether to use MPS.
coreml: Whether to use CoreML.
coreml_ios: CoreML iOS version.
vulkan: Whether to use Vulkan.
use_shared_embedding: Whether to use shared embedding.
use_qat: Whether to use QAT.
use_lora: LoRA rank (0 means no LoRA).
preq_mode: Pre-quantization mode.
preq_group_size: Pre-quantization group size.
preq_embedding_quantize: Pre-quantization embedding quantize.

Returns:
A list of transformation functions.
@@ -1172,21 +1235,21 @@ def _get_source_transforms( # noqa

transforms = []

if args.use_spin_quant:
if args.use_spin_quant == "cuda":
if use_spin_quant:
if use_spin_quant == "cuda":
from .source_transformation.spin_quant import (
inject_fast_hadamard_transform_cuda_for_spin_quant,
)

transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
elif args.use_spin_quant == "native":
elif use_spin_quant == "native":
from .source_transformation.spin_quant import (
inject_fast_hadamard_transform_native_for_spin_quant,
)

transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

if args.embedding_quantize:
if embedding_quantize:
"""
When this option is selected, it finds all embedding layers and transforms
them into the quantized embedding equivalent module.
Expand All @@ -1196,12 +1259,25 @@ def _get_source_transforms( # noqa
transformations based on the given checkpoint first. In those cases,
this will be a no-op.
"""
modelname = f"{modelname}_e"
# Create a mock args object with the necessary attributes
Contributor:
Is it to limit the extent of the changes? It's better than Namespace, but I'm wondering whether we eventually need to group these args, since the entire args object can still be passed around even when not all of its fields are needed.
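
A minimal sketch of the grouping idea the reviewer raises, purely illustrative and not part of this PR: a hypothetical QuantizationOptions dataclass could replace the ad-hoc Args shim below, assuming get_quant_embedding_transform and get_quant_weight_transform were updated to accept such an object.

# Hypothetical sketch only: group the quantization-related options in one typed
# object instead of rebuilding an args-like namespace at each call site.
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantizationOptions:
    checkpoint: Optional[str] = None
    tokenizer_path: Optional[str] = None
    embedding_quantize: Optional[str] = None
    quantization_mode: Optional[str] = None
    group_size: Optional[int] = None
    use_shared_embedding: bool = False
    use_qat: bool = False
    use_lora: int = 0
    preq_mode: Optional[str] = None
    preq_group_size: int = 32
    preq_embedding_quantize: str = "8,0"


# _get_source_transforms could then build this object once and hand it to both
# quantization helpers (assuming they were changed to take the dataclass rather
# than an args namespace).
quant_options = QuantizationOptions(
    embedding_quantize="4,32",
    use_shared_embedding=False,
)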

class Args:
pass
args = Args()
args.checkpoint = checkpoint
args.tokenizer_path = tokenizer_path
args.embedding_quantize = embedding_quantize
args.use_shared_embedding = use_shared_embedding
args.use_qat = use_qat
args.use_lora = use_lora
args.preq_mode = preq_mode
args.preq_group_size = preq_group_size
args.preq_embedding_quantize = preq_embedding_quantize

transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

# quantization_mode should be applied after embedding_quantize
# to support shared_embedding
if args.quantization_mode:
if quantization_mode:
"""
When this option is selected, it finds all linear layers and transforms
them into the quantized linear equivalent module.
@@ -1215,7 +1291,19 @@ def _get_source_transforms( # noqa
There are cases where this may be a no-op, namely, if all linears are
quantized in the checkpoint.
"""
modelname = f"{modelname}_q"
# Create a mock args object with the necessary attributes
class Args:
pass
args = Args()
args.checkpoint = checkpoint
args.tokenizer_path = tokenizer_path
args.quantization_mode = quantization_mode
args.group_size = preq_group_size # Using preq_group_size as group_size
args.use_shared_embedding = use_shared_embedding
args.use_qat = use_qat
args.use_lora = use_lora
args.preq_mode = preq_mode

transforms.append(
get_quant_weight_transform(
args=args,
@@ -1224,15 +1312,12 @@ def _get_source_transforms( # noqa
)
)

if args.expand_rope_table:
if expand_rope_table:
transforms.append(materialze_broadcast_of_rope_freq_cis)

use_attention_mask_for_custom_sdpa = False
if isinstance(args, argparse.Namespace):
if getattr(args, "use_custom_sdpa_with_attention_mask", None):
use_attention_mask_for_custom_sdpa = True
use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

if args.use_sdpa_with_kv_cache:
if use_sdpa_with_kv_cache:
transforms.append(replace_kv_cache_with_custom_kv_cache)
# todo: do this optionally
# if use attention mask instead of causal attention
@@ -1244,23 +1329,23 @@ def _get_source_transforms( # noqa
else:
transforms.append(replace_sdpa_with_custom_op)

if args.quantize_kv_cache:
assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
if quantize_kv_cache:
assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
transforms.append(replace_kv_cache_with_quantized_kv_cache)
# Right now
transforms.append(replace_sdpa_with_quantized_sdpa)

if args.use_kv_cache:
if args.qnn:
if use_kv_cache:
if qnn:
from executorch.backends.qualcomm.utils.utils import (
convert_linear_to_conv2d,
)

if args.use_qnn_sha:
if args.optimized_rotation_path:
if use_qnn_sha:
if optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(
get_model_with_r1_r2(args.optimized_rotation_path)
get_model_with_r1_r2(optimized_rotation_path)
)
transforms.append(replace_attention_to_attention_sha)
transforms.append(replace_causal_mask)
@@ -1272,29 +1357,29 @@ def _get_source_transforms( # noqa
transforms.append(replace_sdpa_with_flex_sdpa)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
if args.optimized_rotation_path:
if optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(
get_model_with_r1_r2(args.optimized_rotation_path)
get_model_with_r1_r2(optimized_rotation_path)
)
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
transforms.append(convert_linear_to_conv2d)

elif args.mps:
elif mps:
# Currently mps doesn't support sdpa op, use the simpler decomposition
# to get free perf gain.
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_causal_mask)

elif args.coreml:
elif coreml:
# iOS 18 introduced fused sdpa op
if args.coreml_ios >= 18:
if coreml_ios >= 18:
transforms.append(replace_sdpa_with_coreml_sdpa)
else:
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_kv_cache_with_coreml_kv_cache)

if args.vulkan:
if vulkan:
transforms.append(replace_with_vulkan_rotary_emb)

return transforms
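
For reference, a minimal usage sketch of the refactored function under its new keyword-only signature. The call-site values here are assumptions based on the first hunk, and the loop only approximates what LLMEdgeManager.source_transform presumably does with the returned list.

# Illustrative only: obtain the transforms with explicit keyword arguments and
# apply them in order to a model.
transforms = _get_source_transforms(
    dtype_override,
    checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
    use_kv_cache=True,
    use_sdpa_with_kv_cache=True,
)
for transform in transforms:
    model = transform(model)  # each callable returns the modified nn.Module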