LoRA that doesn't require memory for zero gradients of the underlying matrices #28

Open
colehaus opened this issue Aug 28, 2024 · 6 comments
Labels: question (User queries)

colehaus commented Aug 28, 2024

I think one of the main motives for LoRA is to reduce memory consumption—certainly that's my motive. I'm already using gradient checkpointing and AdaFactor so the main thing I want from LoRA is to reduce the size of the gradient pytree itself. However, unless I'm quite confused, in a trivial setup like:

class DummyModel(eqx.Module, Generic[Dim1, Dim2, Float]):
    tmp: eqx.nn.Linear[Dim1, Dim2, Float]

    def __init__(self, dim1: Dim1, dim2: Dim2, dtype: type[Float], key: jax.Array) -> None:
        self.tmp = eqx.nn.Linear(dim1, dim2, dtype=dtype, key=key)

    def __call__(self, ndarray: ndarray[Dim1, Float]) -> ndarray[Dim2, Float]:
        return self.tmp(ndarray)

@eqx.filter_value_and_grad
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def grad_fn(m: DummyModel[Dim1, Dim2, Float], y: ndarray[Dim1, Float]) -> ndarray[Float]:
    m = quax.quaxify(m)
    return jnp.square(jnp.mean(m(y)) - 0)
    
def main():
    x = DummyModel[Dim1T, Dim2T, np.float32](4096, 4096, np.float32, jax.random.PRNGKey(0))
    loraed = loraify(x, rank=64, scale=0.1, key=jax.random.PRNGKey(1))
    return grad_fn(loraed, np.random.rand(4096))

the returned grads include a full Dim1 x Dim2 array of zeros for _w. For typical LoRA ranks, almost all of the values in the gradient pytree are zeros, and this is wasted memory.
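For concreteness, here's an illustrative way to see where the bytes go (assuming the definitions above; the exact leaf paths depend on LoraArray's internals):

import jax.tree_util as jtu

_, grads = grad_fn(loraed, np.random.rand(4096))
# Print the size of every gradient leaf; the dense _w cotangent dominates.
for path, leaf in jtu.tree_leaves_with_path(grads):
    print(jtu.keystr(path), leaf.shape, f"{leaf.nbytes / 2**20:.2f} MB")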

I thought perhaps I could get around this by replacing jax.lax.stop_gradient in LoraArray with something like:

@jax.custom_jvp
def symbolic_stop_gradient(x: A) -> A:
    return x


@symbolic_stop_gradient.defjvp
def symbolic_stop_gradient_jvp(primals: tuple[ndarray[*Shape, Float]], tangents: tuple[ndarray[*Shape, Float]]):
    return primals[0], Zero(primals[0].shape, primals[0].dtype)

but that produces the following error:

TypeError: Custom JVP rule symbolic_stop_gradient_jvp for function symbolic_stop_gradient must produce primal and tangent outputs with equal container (pytree) structures, but got PyTreeDef(*) and PyTreeDef(CustomNode(Zero[(), ('_shape', '_dtype'), ((4096, 4096), dtype('float32'))], [])) respectively.

Is there a reasonable way to use quax to implement LoRA in a way that doesn't allocate tons of space for zeros?

(I guess it's mildly possible that JAX optimizes out this allocation behind the scenes if the gradient pytree is "consumed" inside the same JIT where the gradients are produced, but I assume it's not quite that clever.)

Thanks.

patrick-kidger (Owner) commented:

Actually, I think JAX is exactly that clever :)

Optimizing x+0 to just x is a simple optimization that XLA should perform for us.

That said, I'd be happy to adjust Quax to avoid ever emitting the +0 in the first place, but I'm not immediately sure how.
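As a quick standalone sanity check of the x+0 claim, independent of Quax (the optimized-HLO text is backend- and JAX-version-dependent, and may be unavailable on some backends):

import jax
import jax.numpy as jnp

def f(x):
    return x + jnp.zeros_like(x)

compiled = jax.jit(f).lower(jnp.ones((4096, 4096))).compile()
# If XLA folds x + 0 away, the optimized module should contain no add op.
print(compiled.as_text())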

patrick-kidger added the question label Aug 28, 2024

colehaus commented Aug 28, 2024

Hmm. That optimization does not seem to be happening in my LoRA test case. Both full training and LoRA have a peak memory usage that's basically double the model size, but the optimization does seem to fire when we take gradients of a trivial constant function.

from __future__ import annotations

import functools as ft
from typing import Any, TypeVar, TypeVarTuple

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import quax
from numpy import ndarray

BatchLen = TypeVar("BatchLen", bound=int)
Dim1 = TypeVar("Dim1", bound=int)
Dim2 = TypeVar("Dim2", bound=int)
Dim3 = TypeVar("Dim3", bound=int)
Rank = TypeVar("Rank", bound=int)
Float = TypeVar("Float", bound=float)
Shape = TypeVarTuple("Shape")
A = TypeVar("A")
Opt = TypeVar("Opt")


def tree_size(tree: Any) -> int:
    return sum(x.nbytes for x in jax.tree_util.tree_leaves(tree) if eqx.is_array(x))


def human_bytes(size: float, decimal_places: int = 2) -> str:
    unit = "B"
    for unit in ["B", "KB", "MB", "GB", "TB"]:  # noqa: B007
        if size < 1024.0:  # noqa: PLR2004
            break
        size /= 1024.0

    formatted_num = f"{size:.{decimal_places}f}".rstrip("0").rstrip(".")
    return f"{formatted_num:>4} {unit}"


@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def full_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    return jnp.square(jnp.mean(model.__call__(input_)) - 1)


@eqx.filter_jit(donate="all")
def full_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    _, grads = full_grad_fn(model, input_)
    lr = 1e-3
    return jax.tree.map(lambda o, c: o - lr * c, model, grads)  # type: ignore


@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def lora_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    model = quax.quaxify(model)
    return jnp.square(jnp.mean(model.__call__(input_)) - 1)


@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def no_op_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    return jnp.array(0, input_.dtype)

@eqx.filter_jit(donate="all")
def lora_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    _, grads = lora_grad_fn(model, input_)
    lr = 1e-3
    return jax.tree.map(lambda o, c: o - lr * c, model, grads)  # type: ignore


@eqx.filter_jit(donate="all")
def no_op_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    _, grads = no_op_grad_fn(model, input_)
    lr = 1e-3
    return jax.tree.map(lambda o, c: o - lr * c, model, grads)  # type: ignore

dim1 = 65536
dim2 = 1024

def print_live_buffer_total():
    print(human_bytes(sum([x.nbytes for x in jax.live_arrays()])))

def full_prim_main():
    # OOMs on 75_000 but not 70_000
    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))

    print("model size", human_bytes(tree_size(model)))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
    for _ in range(40):
        model = full_prim_step(model, np.random.rand(dim1).astype(np.float32))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))

def lora_prim_main():
    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))
    model = quax.examples.lora.loraify(model, rank=64, scale=0.1, key=jax.random.PRNGKey(1))

    print_live_buffer_total()
    print("model size", human_bytes(tree_size(model)))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
    for _ in range(40):
        model = lora_prim_step(model, np.random.rand(dim1).astype(np.float32))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))


def no_op_main():
    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))

    print("model size", human_bytes(tree_size(model)))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
    for _ in range(40):
        model = no_op_step(model, np.random.rand(dim1).astype(np.float32))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))

[screenshot: printed model size and peak memory usage for the full, LoRA, and no-op runs]

(If this counts as a JAX bug and/or is out of scope, I'm happy to move it over to the JAX repo.)

patrick-kidger (Owner) commented:

Hmm, that's unfortunate if so.

Quax is still a fairly experimental library, so I'd be happy to take suggestions on how we might adjust the internals to work around this.

For example, this could be accomplished by partition/combine-ing on either side of the grad. Maybe there's a way to enable that more easily.
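Roughly, something like this sketch (untested; it assumes the frozen dense weight inside quax.examples.lora's LoraArray is the field named _w):

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
import quax

def split_lora(model):
    # True for every leaf except the dense _w inside each LoraArray.
    trainable_spec = jtu.tree_map_with_path(
        lambda path, _: path[-1:] != (jtu.GetAttrKey("_w"),), model
    )
    return eqx.partition(model, trainable_spec)

@eqx.filter_value_and_grad
def loss_fn(trainable, frozen, x):
    model = quax.quaxify(eqx.combine(trainable, frozen))
    return jnp.mean(model(x)) ** 2

# The gradient pytree then only contains the small LoRA factors (and biases):
#   trainable, frozen = split_lora(loraed_model)
#   _, grads = loss_fn(trainable, frozen, x)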

colehaus commented:

Yeah, I'll think about the problem.

I already tried a version of the partition/combine approach for a different problem (not LoRA, but a whole chunk of the model frozen), and the memory usage there didn't work out as hoped. I'll see if I can reproduce that problem; if not, maybe something along those lines is the right thing to aim for.

colehaus commented:

(I opened an issue on this optimization at jax-ml/jax#23316.)


colehaus commented Aug 29, 2024

Actually, I think those peak usages may be misleading. The problem may be something else. Even with an explicitly split model we get very similar behavior:

import jax.tree_util as jtu  # needed below; PartOf is a type annotation from my codebase

@jax.value_and_grad
@jax.jit
@ft.partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
def split_lora_grad_fn(
    malleable: PartOf[eqx.nn.Linear[Dim1, Dim2, Float]], frozen: PartOf[eqx.nn.Linear[Dim1, Dim2, Float]], input_: ndarray[Dim1, Float]
) -> ndarray[Float]:
    model = quax.quaxify(eqx.combine(malleable, frozen))
    return jnp.square(jnp.mean(model.__call__(input_)) - 1)

@ft.partial(jax.jit, donate_argnums=0)
def split_lora_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    loraed = jtu.tree_map_with_path(lambda path, _: path[-2:] != (jtu.GetAttrKey("weight"), jtu.GetAttrKey("_w")), model)  # type: ignore
    malleable, frozen = eqx.partition(model, loraed)
    del loraed, model
    _, grads = split_lora_grad_fn(malleable, frozen, input_)
    print("grad size", human_bytes(tree_size(grads)))
    lr = 1e-3
    malleable = jax.tree.map(lambda o, c: o - lr * c, malleable, grads)  # type: ignore
    return eqx.combine(malleable, frozen)

def split_lora_prim_main():
    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))
    model = quax.examples.lora.loraify(model, rank=64, scale=0.1, key=jax.random.PRNGKey(1))

    # ir = split_lora_prim_step.lower(model, np.random.rand(dim1).astype(np.float32)).compiler_ir()
    # ir.dump()

    print("model size", human_bytes(tree_size(model)))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
    for _ in range(1):
        model = split_lora_prim_step(model, np.random.rand(dim1).astype(np.float32))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
model size 272.25 MB
peak usage 272.25 MB
grad size 16.25 MB
peak usage 628.51 MB

And they OOM in the same way.
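A possible next diagnostic (reusing human_bytes from above): ask XLA for the compiled step's memory breakdown rather than inferring it from peak_bytes_in_use. The reported fields are backend-dependent and may be unavailable:

lowered = split_lora_prim_step.lower(model, np.random.rand(dim1).astype(np.float32))
stats = lowered.compile().memory_analysis()  # may be None on some backends
if stats is not None:
    print("arguments  ", human_bytes(stats.argument_size_in_bytes))
    print("outputs    ", human_bytes(stats.output_size_in_bytes))
    print("temporaries", human_bytes(stats.temp_size_in_bytes))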
