
Add CAME Optimizer #2385

Draft · xzuyn wants to merge 5 commits into main
Conversation

@xzuyn (Contributor) commented Mar 5, 2025

Description

This PR is set as a draft because I'm unsure how betas (tuple[float, float, float]) and eps (tuple[float, float]) should be exposed. Everything works fine; I just don't know the best way to make those options configurable.

https://arxiv.org/abs/2307.02047
https://github.com/yangluo7/CAME

Requires you to pip install came_pytorch.

optimizer: came_pytorch
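
For reference, a minimal sketch of constructing CAME directly from the came_pytorch package. The betas/eps defaults shown are my understanding of the upstream repo's defaults and are illustrative only; model is assumed to be an already-loaded module, nothing here is copied from this PR's diff.

# Minimal sketch, not part of this PR's diff: using came_pytorch directly.
from came_pytorch import CAME

optimizer = CAME(
    model.parameters(),          # assumes `model` is an already-instantiated nn.Module
    lr=1e-5,
    betas=(0.9, 0.999, 0.9999),  # three betas: grad EMA, squared-grad EMA, residual EMA
    eps=(1e-30, 1e-16),          # two eps terms, one per factored statistic in the CAME paper
    weight_decay=0.1,
)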

Screenshots (if appropriate)

[Two W&B screenshots: LLaMa-3.2-1B-Instruct workspace (Weights & Biases), 2025-03-05]

Axolotl Config
# Weights and Biases logging config
wandb_project: LLaMa-3.2-1B-Instruct
wandb_entity:
wandb_watch:
wandb_name: came_pytorch
wandb_log_model:

# Model checkpointing config
output_dir: ./Outputs/came_pytorch
resume_from_checkpoint:
save_steps: 50
save_safetensors: true
save_total_limit: 3
save_only_model: false

# Model architecture config
base_model: meta-llama/Llama-3.2-1B-Instruct
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

# Mixed precision training config
bf16: true
fp16: false
tf32: false

# Model loading config
load_in_8bit: false
load_in_4bit: true
strict: false

# Sequence config
sequence_len: 4096
min_sample_len: 256
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
train_on_inputs: false
group_by_length: false

# LoRA adapter config
adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 32
lora_dropout: 0.125
peft_layers_to_transform:
peft_use_dora:
peft_use_rslora:
peft_layer_replication:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
lora_modules_to_save:

# Fix uninitialized tokens (such as <|start_header_id|> on the base L3 models)
fix_untrained_tokens:

# Dataset config
# Completion: https://github.com/xzuyn/axolotl/blob/prompt_formats/src/axolotl/prompt_strategies/customcompletion-regex.py
datasets:
# Completion
  - path: PJMixers-Dev/fundusnewsonly
    split: train[:1024]  # Only the first 1024
    type: customcompletion-regex
test_datasets:
# Completion
  - path: PJMixers-Dev/fundusnewsonly
    split: train[-32:]  # Only the last 32
    type: customcompletion-regex
val_set_size: 0
eval_strategy: steps
eval_steps: 1
dataset_prepared_path: ./00-Tokenized-Datasets/testt
shuffle_merged_datasets: true
dataset_processes:

# Training hyperparameters
num_epochs: 1
gradient_accumulation_steps: 1
micro_batch_size: 8
eval_batch_size: 8
warmup_steps: 5
optimizer: came_pytorch
optim_args:
optim_target_modules:
lr_scheduler: rex
learning_rate: 1e-5
cosine_min_lr_ratio:
loraplus_lr_ratio:
loraplus_lr_embedding:
weight_decay: 0.1
max_grad_norm: 1
logging_steps: 1

# Model optimization
gradient_checkpointing: unsloth
sdp_attention: true
plugins:
  - axolotl.integrations.liger.LigerPlugin
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
liger_rope: true
liger_rms_norm: true
liger_layer_norm: true
liger_glu_activation: true
liger_cross_entropy: false
liger_fused_linear_cross_entropy: false
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

# DeepSpeed
deepspeed:

# Garbage Collection
gc_steps: 1

# Debug config
debug: true
seed: 42

# Token config
special_tokens:
  bos_token: "<|begin_of_text|>"
  eos_token: "<|end_of_text|>"
  pad_token: "<|finetune_right_pad_id|>"
tokens:

@winglian (Collaborator) commented Mar 5, 2025

We're doing a refactor of how we handle custom optimizers. See #2367

@xzuyn (Contributor, Author) commented Mar 5, 2025

I will keep an eye on that.

@winglian (Collaborator) commented:

You could probably use optim_args like:

optim_args:
  betas: [0.999, 0.999, 0.999]
  eps: 0.00000001
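
If optim_args is used that way, the YAML values would presumably need to be coerced into the tuples came_pytorch expects (a 3-tuple for betas and, as far as I can tell, a 2-tuple for eps). A rough sketch, with illustrative names only:

# Rough sketch of coercing YAML-style optim_args into CAME's expected types.
# The optim_args dict mirrors the YAML above; nothing here is taken from axolotl's code.
from came_pytorch import CAME

optim_args = {"betas": [0.999, 0.999, 0.999], "eps": 0.00000001}

betas = tuple(optim_args["betas"])      # CAME expects a 3-tuple of floats
eps = optim_args["eps"]
if not isinstance(eps, (list, tuple)):  # upstream CAME appears to take a 2-tuple for eps,
    eps = (eps, eps)                    # so a lone scalar is duplicated here (an assumption)
else:
    eps = tuple(eps)

optimizer = CAME(model.parameters(), lr=1e-5, betas=betas, eps=eps, weight_decay=0.1)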

@@ -691,6 +691,22 @@ def build(self, total_num_steps):
        # Set default so transformers doesn't throw
        training_arguments_kwargs["optim"] = "adamw_hf"

        if self.cfg.optimizer == "came_pytorch":
            from came_pytorch import CAME
A Collaborator commented on the diff above:

This should probably be wrapped in a try/except ImportError with a message telling the user to install it from pip.
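
Something along these lines, as a sketch of what that guard could look like (error message wording is mine, not final):

# Sketch of the suggested import guard; the message text is illustrative.
if self.cfg.optimizer == "came_pytorch":
    try:
        from came_pytorch import CAME
    except ImportError as exc:
        raise ImportError(
            "came_pytorch is required for optimizer: came_pytorch. "
            "Install it with pip install came_pytorch."
        ) from exc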

@NanoCode012 (Collaborator) commented:

The linked PR was merged. If you need an example of how to add optimizers now, it's as simple as this: https://github.com/axolotl-ai-cloud/axolotl/pull/2367/files#diff-5edc13801ecfdd108e81872527bdc78c6d24a73833968147b0a9ecb8452996f4R693-R698
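
For anyone updating this PR, a rough sketch of how the CAME branch might look under that refactor. The optimizer_cls / optimizer_kwargs names below are assumptions on my part and are not copied from the linked #2367 diff:

# Unverified sketch of registering CAME under the refactored custom-optimizer path.
# optimizer_cls and optimizer_kwargs are assumed names, not taken from #2367.
if self.cfg.optimizer == "came_pytorch":
    try:
        from came_pytorch import CAME
    except ImportError as exc:
        raise ImportError("Install CAME with pip install came_pytorch.") from exc
    optimizer_cls = CAME
    optimizer_kwargs = {
        "lr": self.cfg.learning_rate,
        "weight_decay": self.cfg.weight_decay,
    }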
