Releases: RulinShao/FastCkpt
Pre-release
In this pre-release, FastCkpt supports importing FlashAttention and rematerialization-aware gradient checkpointing in one line. Check README for detailed usage.
Use FastCkpt and FlashAttention
To use `fastckpt` with `flash_attn`, import and run `replace_hf_ckpt_with_fast_ckpt` before importing `transformers`:
```python
# add monkey patch for fastckpt
from fastckpt.llama_flash_attn_ckpt_monkey_patch import replace_hf_ckpt_with_fast_ckpt
replace_hf_ckpt_with_fast_ckpt()

# import transformers and other packages
import transformers
...
```
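The reason the patch must run before `import transformers` is a general property of Python monkey patching: `from module import name` binds the object at import time, so a consumer that imports first keeps a reference to the unpatched function. A minimal sketch of the correct ordering, using hypothetical stand-in modules (`lib` for the patched library, `consumer` for `transformers`):

```python
import sys
import types

# Hypothetical "library" module whose function the patch will replace,
# standing in for the HF modeling code.
lib = types.ModuleType("lib")
lib.attn = lambda: "slow attention"
sys.modules["lib"] = lib

def patch():
    # Swap the function on the module object, analogous to what
    # replace_hf_ckpt_with_fast_ckpt() does to the HF internals.
    sys.modules["lib"].attn = lambda: "flash attention"

# Patch BEFORE the consumer imports, so its `from lib import attn`
# binds the patched function.
patch()
consumer = types.ModuleType("consumer")
exec("from lib import attn", consumer.__dict__)
print(consumer.attn())  # -> flash attention
```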
Use FlashAttention only
To replace only `LlamaAttention` with `flash_attn` without changing the checkpointing strategy, import and run `replace_llama_attn_with_flash_attn`:
```python
# add monkey patch for fastckpt
from fastckpt.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()

# import transformers and other packages
import transformers
...
```
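As in the combined patch above, ordering matters: if `transformers` were imported first, names already bound via `from ... import` would keep pointing at the original attention. A minimal sketch of that pitfall, again with hypothetical stand-in modules (`lib`, `consumer`) rather than the real packages:

```python
import sys
import types

# Hypothetical "library" module, standing in for the HF modeling code.
lib = types.ModuleType("lib")
lib.attn = lambda: "slow attention"
sys.modules["lib"] = lib

# The consumer imports FIRST (like importing transformers before patching),
# binding the original function into its own namespace...
consumer = types.ModuleType("consumer")
exec("from lib import attn", consumer.__dict__)

# ...so patching afterwards does NOT affect the already-bound reference.
sys.modules["lib"].attn = lambda: "flash attention"
print(consumer.attn())  # -> slow attention
```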