[BUG] grad is not computed when calling LoRA-like training target in mx.custom_function #1669

Open
kaeru-shigure opened this issue Dec 7, 2024 · 0 comments

Describe the bug
The gradient is not computed (and the custom vjp is never called) when a LoRA-like training target is invoked inside an mx.custom_function.

To Reproduce

import mlx.core as mx
import mlx.nn as nn

class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
    def __call__(self, x):
        return self.linear(x)

def create_lora(original_module):
    class LoRA(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(2, 3)
            self.linear2 = nn.Linear(3, 2)
        def __call__(self, x):
            nonlocal original_module
            print("LORA_called")
            y = self.linear1(x)
            y = self.linear2(y)
            return original_module(x) + y
    return LoRA()

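# Take value_and_grad w.r.t. the LoRA parameters; the (possibly wrapped) model is called inside the loss.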
def test_run(model, lora_module, x):
    def loss_fn_wrapped(params, x):
        lora_module.update(params)
        x = model(x)
        x = model(x)
        return x.mean()
    loss, grad = mx.value_and_grad(loss_fn_wrapped)(lora_module.trainable_parameters(), x)

    mx.eval(loss, grad)
    print("loss=",loss)
    print("grad(lora).linear1.weight=",grad["linear1"]["weight"])

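# Wrap a callable in mx.custom_function; the vjp raises so we can tell whether the backward pass reaches it.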
def with_cfn(original_fn):
    @mx.custom_function
    def cfn(*args):
        return original_fn(*args)
    @cfn.vjp
    def cfn_vjp(primals, cotangent, output):
        raise RuntimeError("vjp called correctly.")
    return cfn

base_model = BaseModel()
lora_module = create_lora(base_model.linear)
base_model.linear = lora_module

x = mx.random.normal([1, 2])
print("without_custom_fn-------------")
test_run(base_model, lora_module, x)
print("with_custom_fn----------------")
test_run(with_cfn(base_model), lora_module, x)
# mx.value_and_grad is called with the LoRA params,
# and the LoRA module is called inside mx.value_and_grad,
# so the correct loss is returned,
# but the grad is all zeros and the custom vjp is never called

Result:

without_custom_fn-------------
LORA_called
LORA_called
loss= array(0.494176, dtype=float32)
grad(lora).linear1.weight= array([[-0.61214, -0.00502913],
       [-0.721728, -0.0861656],
       [-0.574084, -0.0290771]], dtype=float32)
with_custom_fn----------------
LORA_called
LORA_called
loss= array(0.494176, dtype=float32)
grad(lora).linear1.weight= array([[0, 0],
       [0, 0],
       [0, 0]], dtype=float32)

Expected behavior
The gradient for the LoRA-like training target should be computed through the custom function; in this repro that means the custom vjp should be reached, e.g.:

without_custom_fn-------------
LORA_called
LORA_called
loss= array(0.494176, dtype=float32)
grad(lora).linear1.weight= array([[-0.61214, -0.00502913],
       [-0.721728, -0.0861656],
       [-0.574084, -0.0290771]], dtype=float32)
with_custom_fn----------------
LORA_called
LORA_called
Traceback (most recent call last):
  File "/.../mx.py", line 53, in <module>
    test_run(with_cfn(base_model), lora_module, x)
  File "/.../mx.py", line 30, in test_run
    loss, grad = mx.value_and_grad(loss_fn_wrapped)(lora_module.trainable_parameters(), x)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../mx.py", line 42, in cfn_vjp
    raise RuntimeError("vjp called correctly.")
RuntimeError: vjp called correctly.
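
Additional context

As far as I understand, the vjp registered on an mx.custom_function is invoked whenever gradients flow through the function's explicit arguments. Below is a minimal sketch of that pattern for comparison (my own example, not part of the repro; f, f_vjp and the input values are made up). The repro above differs only in that the differentiated arrays are the LoRA parameters injected via lora_module.update(...) inside the loss, rather than arguments of the custom function, and that seems to be where the gradient path is lost.

import mlx.core as mx

@mx.custom_function
def f(x, y):
    return mx.sin(x) * y

@f.vjp
def f_vjp(primals, cotangent, output):
    # custom vjp for f: d/dx = cos(x) * y, d/dy = sin(x)
    x, y = primals
    print("vjp called")
    return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)

x, y = mx.array(0.5), mx.array(2.0)
# grad w.r.t. x, an explicit argument of the custom function;
# the registered vjp should run as part of this grad call
dx = mx.grad(lambda a, b: f(a, b))(x, y)
print(dx)  # expected: cos(0.5) * 2.0, produced by the custom vjp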

Desktop (please complete the following information):

  • OS Version: macOS 15.1.1
  • MLX version: 0.21.1