diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 79a225232e..05a0b8a360 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -651,10 +651,31 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            quantization_mode=args.quantization_mode,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_shared_embedding=args.use_shared_embedding,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )
 
@@ -1145,23 +1166,65 @@ def _load_llama_model(
 
 
 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    quantization_mode: Optional[str] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_shared_embedding: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: int = 32,
+    preq_embedding_quantize: str = "8,0",
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.
 
     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
            it means that you want to run quantize transformations on the weights represented
            in their original dtype, while the overall dtype of the model maybe something
            different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.
 
     Returns:
         A list of transformation functions.
@@ -1172,21 +1235,21 @@ def _get_source_transforms(  # noqa
 
     transforms = []
 
-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1196,12 +1259,25 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
 
     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1215,7 +1291,19 @@ def _get_source_transforms(  # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = preq_group_size  # Using preq_group_size as group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1224,15 +1312,12 @@ def _get_source_transforms(  # noqa
             )
         )
 
-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask
 
-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1244,23 +1329,23 @@ def _get_source_transforms(  # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)
 
-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)
 
-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )
 
-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
@@ -1272,29 +1357,29 @@ def _get_source_transforms(  # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)
 
-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)
-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)
 
-    if args.vulkan:
+    if vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
     return transforms