add lit lama impl
samsja committed Aug 13, 2024
1 parent 2fb22e9 commit 3dfe62f
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions open_diloco/train_fsdp.py
@@ -57,6 +57,7 @@
     get_compression_kwargs,
     get_sharding_strategy,
 )
+from open_diloco.model import ModelArgs, TransformerHF
 
 
 TIMEOUT_NCCL_MINUTES = os.environ.get("TIMEOUT_NCCL_MINUTES", 120)
@@ -119,6 +120,7 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
 
 class Config(BaseConfig):
     path_model: str = "PrimeIntellect/llama-150m-fresh"
+    torch_titan_llama: bool = False
     torch_compile: bool = True
     attn_implementation: str = "sdpa"
     # Data
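The new flag defaults to False, so existing configurations keep the Hugging Face code path untouched; setting it to True routes model construction through the TorchTitan-style implementation added in this commit. A minimal, illustrative sketch of toggling it (other required Config fields are omitted for brevity, and whether path_model is a valid name for ModelArgs.from_name is an assumption):

# Illustrative only, not part of this commit.
hf_config = Config(path_model="PrimeIntellect/llama-150m-fresh")                              # -> LlamaForCausalLM
titan_config = Config(path_model="PrimeIntellect/llama-150m-fresh", torch_titan_llama=True)   # -> TransformerHF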
@@ -189,8 +191,12 @@ def tokenize_function(data):
 
 def get_model(config: Config) -> LlamaForCausalLM:
     # Load model
-    config_model = LlamaConfig.from_pretrained(config.path_model, attn_implementation=config.attn_implementation)
-    return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
+    if config.torch_titan_llama:
+        config_model = ModelArgs.from_name(config.path_model)
+        return TransformerHF(config=config_model)
+    else:
+        config_model = LlamaConfig.from_pretrained(config.path_model, attn_implementation=config.attn_implementation)
+        return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
 
 
 def train(config: Config):
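One behavioral difference worth noting: the Hugging Face branch loads pretrained weights via from_pretrained, while the TorchTitan branch builds a freshly initialized TransformerHF from a named ModelArgs preset. The rest of the script can stay agnostic to the branch as long as both models share the interface used in the training loop; a sketch of that assumed contract (TransformerHF internals are not shown in this diff):

# Sketch of the minimal interface train() relies on, assumed common to both branches:
#   model(input_ids=...) returns an object with a .logits tensor of shape (batch, seq_len, vocab_size).
model = get_model(config)                        # TransformerHF or LlamaForCausalLM, depending on the flag
outputs = model(input_ids=batch["input_ids"])
logits = outputs.logits                          # consumed identically further down in train()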
@@ -398,9 +404,6 @@ def scheduler_fn(opt):
             batch[key] = batch[key].to("cuda")
 
         with model.no_sync() if is_accumulating else nullcontext():
-            log(batch.keys())
-            log(f"input_ids shape: {batch['input_ids'].shape}")
-
             logits = model(input_ids=batch["input_ids"]).logits.contiguous()
             labels = batch["labels"].contiguous()
 
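Past the end of this hunk (not shown in the diff) the logits and labels feed a causal language-modeling loss; the .contiguous() calls ensure the flattening .view(...) that such a loss typically needs is valid. A sketch of the usual pattern, given as an assumption rather than the script's exact code:

import torch.nn.functional as F

# Assumed continuation, not part of this commit.
# logits: (batch, seq_len, vocab_size), labels: (batch, seq_len)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
loss = loss / gradient_accumulation_steps   # hypothetical name for the accumulation factor
loss.backward()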
