Describe the bug
Gradients are not computed for a LoRA-like training target when the model call is wrapped in mx.custom_function: the loss is computed correctly, but the gradients of the LoRA parameters come back as all zeros and the custom vjp is never called.
To Reproduce
```python
import mlx.core as mx
import mlx.nn as nn


class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def __call__(self, x):
        return self.linear(x)


def create_lora(original_module):
    class LoRA(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(2, 3)
            self.linear2 = nn.Linear(3, 2)

        def __call__(self, x):
            nonlocal original_module
            print("LORA_called")
            y = self.linear1(x)
            y = self.linear2(y)
            return original_module(x) + y

    return LoRA()


def test_run(model, lora_module, x):
    def loss_fn_wrapped(params, x):
        lora_module.update(params)
        x = model(x)
        x = model(x)
        return x.mean()

    loss, grad = mx.value_and_grad(loss_fn_wrapped)(lora_module.trainable_parameters(), x)
    mx.eval(loss, grad)
    print("loss=", loss)
    print("grad(lora).linear1.weight=", grad["linear1"]["weight"])


def with_cfn(original_fn):
    @mx.custom_function
    def cfn(*args):
        return original_fn(*args)

    @cfn.vjp
    def cfn_vjp(primals, cotangent, output):
        raise RuntimeError("vjp called correctly.")

    return cfn


base_model = BaseModel()
lora_module = create_lora(base_model.linear)
base_model.linear = lora_module

x = mx.random.normal([1, 2])

print("without_custom_fn-------------")
test_run(base_model, lora_module, x)
print("with_custom_fn----------------")
test_run(with_cfn(base_model), lora_module, x)

# mx.value_and_grad is called with the lora params,
# and lora is called inside mx.value_and_grad,
# returning the correct loss,
# but grad is zero, and vjp is not called
```
result:
```
without_custom_fn-------------
LORA_called
LORA_called
loss= array(0.494176, dtype=float32)
grad(lora).linear1.weight= array([[-0.61214, -0.00502913],
       [-0.721728, -0.0861656],
       [-0.574084, -0.0290771]], dtype=float32)
with_custom_fn----------------
LORA_called
LORA_called
loss= array(0.494176, dtype=float32)
grad(lora).linear1.weight= array([[0, 0],
       [0, 0],
       [0, 0]], dtype=float32)
```
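If it helps with triage, I believe the same pattern can be isolated without any nn.Module. Below is a minimal, unverified sketch (`state`, `f`, `f_vjp`, and `loss_fn` are just illustration names); it assumes that arrays which reach a custom_function only through captured state, rather than as explicit arguments, are treated as constants, which would explain the zero gradients above:

```python
import mlx.core as mx

# "Parameter" that reaches the custom function only via captured state,
# mirroring lora_module.update(params) plus the closure over the module.
state = {"w": mx.array([2.0, 2.0])}

@mx.custom_function
def f(x):
    return (state["w"] * x).sum()

@f.vjp
def f_vjp(primals, cotangent, output):
    raise RuntimeError("vjp called")

def loss_fn(w, x):
    state["w"] = w  # inject the traced parameter, like module.update(params)
    return f(x)

x = mx.array([3.0, 4.0])
loss, grad = mx.value_and_grad(loss_fn)(state["w"], x)
mx.eval(loss, grad)
print(grad)  # analytically [3, 4]; presumably zeros (and no vjp call) if
             # arrays captured by the custom function are treated as constants
```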
Expected behavior
Gradients should be computed for the LoRA-like training target. With the custom vjp defined above, the vjp should be invoked (raising the RuntimeError), i.e.:
```
without_custom_fn-------------
LORA_called
LORA_called
loss= array(0.494176, dtype=float32)
grad(lora).linear1.weight= array([[-0.61214, -0.00502913],
       [-0.721728, -0.0861656],
       [-0.574084, -0.0290771]], dtype=float32)
with_custom_fn----------------
LORA_called
LORA_called
Traceback (most recent call last):
  File "/.../mx.py", line 53, in <module>
    test_run(with_cfn(base_model), lora_module, x)
  File "/.../mx.py", line 30, in test_run
    loss, grad = mx.value_and_grad(loss_fn_wrapped)(lora_module.trainable_parameters(), x)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../mx.py", line 42, in cfn_vjp
    raise RuntimeError("vjp called correctly.")
RuntimeError: vjp called correctly.
```
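For what it's worth, a possible (unverified) workaround sketch: flatten the LoRA parameters and pass them to the custom function as explicit array arguments, so they are inputs of the transform rather than captured state. `with_cfn_explicit` is just an illustration name, and this assumes `mlx.utils.tree_flatten` / `tree_unflatten` round-trip the parameter dict in a stable order:

```python
from mlx.utils import tree_flatten, tree_unflatten

def with_cfn_explicit(original_fn, lora_module):
    # Fix the parameter ordering once so that flatten/unflatten agree.
    keys = [k for k, _ in tree_flatten(lora_module.trainable_parameters())]

    @mx.custom_function
    def cfn(x, *flat_params):
        # Re-attach the (possibly traced) parameter arrays before the forward pass.
        lora_module.update(tree_unflatten(list(zip(keys, flat_params))))
        return original_fn(x)

    def wrapped(x):
        flat = [v for _, v in tree_flatten(lora_module.trainable_parameters())]
        return cfn(x, *flat)

    return wrapped
```

In test_run above, `with_cfn_explicit(base_model, lora_module)` would stand in for `with_cfn(base_model)`; with the parameters as explicit inputs, a custom vjp on cfn should presumably also see them in primals and be able to return their cotangents.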