Commit
[JAX] Migrating from Xmap to Custom Partitioning for All Custom Calls (#472)

Squashed commits:

* Refactor sharding.py in preparation for the custom_partitioning migration.
* Migrate both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning.
* Migrate both FWD and BWD of all softmax variants from xmap to custom_partitioning.
* Fix the wrong parameter order passed to the LN/RMSN BWD in ln_mlp_fp8.
* Add a workaround (WAR) for LN/RMSN_fp8 before migrating to custom_partitioning.
* Fix the wrong parameter order in the BWD of LN/RMSN_fp8.
* Address review feedback.
* Force the hidden dimension in norm ops to be unsharded and add a warning message.
* Reuse fwd_rule in VJP functions.
* Migrate both FWD and BWD of self-fused-attn from xmap to custom_partitioning.
* Migrate both FWD and BWD of cross-fused-attn from xmap to custom_partitioning.
* Add gelu and dgelu.
* Reuse fwd_rule in VJP functions for attention.
* Apply native FP8 dtypes to fp8.py.
* Migrate cast_and_transpose from xmap to custom_partitioning.
* Migrate transpose from xmap to custom_partitioning.
* Apply XLA pattern matching to perform FP8 GEMM.
* Migrate layernorm_fp8 to custom_partitioning.
* Unify code style.
* Extend Transpose support to FP8.
* Implement layernorm_fp8_dot based on the migrated custom calls.
* Rename variables and publish NVTE_FP8_COLLECTION_NAME.
* Migrate gelu_fp to custom_partitioning.
* Replace Q/DQ custom calls with native XLA implementations.
* Minor fix.
* Support custom calls with multiple dimensions.
* Support general dot indices in _fp8_dot_impl.
* Implement layernorm_geglu_fp8_mlp.
* Remove GEMM custom calls.
* Remove xmap-related code.
* Fix a typo and add a query function to FP8MetaPackage.
* Fix several bugs in custom calls.
* Fix CT bugs.
* Update UTs/examples to adapt to the API changes.
* Unify kernel initialization in MLP.
* Modify per code review feedback.
* Update README and add a deprecation warning to *ShardingType.
* Canonicalize the dtype.
* Add an assertion for unsupported batch dims.
* Add docs/examples to _multidim_transpose.
* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules.
* Apply dtype-based rtol/atol to UTs.
* Deprecate the QKV_INTERLEAVED enum.
* Skip test_distributed_custom_ops.py.
* Fix the wrong sharding of bias in SelfAttn.
* WAR to fix the wrong cu_seqlen of MHA when DP/FSDP is enabled.
* Add distributed ops unit tests.
* Add license headers to test_distributed_*.
* Address review feedback.
* Use the total bytes involved in collective ops as the criterion.

---------

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming-Xu Huang <[email protected]>
Co-authored-by: Donglin Yang <[email protected]>
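The recurring pattern in the commits above, moving a custom call from xmap onto `jax.experimental.custom_partitioning`, can be sketched as follows. This is a minimal, hypothetical example: `scaled_gelu` and its sharding rules are illustrative stand-ins, not Transformer Engine's actual ops or partitioning logic.

```python
import jax
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning

# Hypothetical elementwise op standing in for one of the migrated custom calls.
@custom_partitioning
def scaled_gelu(x):
    return jax.nn.gelu(x) * 2.0

def _infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    # Elementwise op: the output simply inherits the operand's sharding.
    return arg_shapes[0].sharding

def _partition(mesh, arg_shapes, result_shape):
    # Keep the operand shardings unchanged and run the same computation
    # on each shard; an elementwise op needs no cross-device communication.
    arg_shardings = tuple(a.sharding for a in arg_shapes)
    out_sharding = arg_shapes[0].sharding

    def lower_fn(x):
        return jax.nn.gelu(x) * 2.0

    return mesh, lower_fn, out_sharding, arg_shardings

scaled_gelu.def_partition(
    infer_sharding_from_operands=_infer_sharding_from_operands,
    partition=_partition)
```

Unlike the removed xmap path, the partitioning rules live next to the op itself: XLA's SPMD partitioner calls `_infer_sharding_from_operands` during sharding propagation and `_partition` at lowering time, so the op composes with `jit` and mesh shardings without any named-axis plumbing.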