LoRA that doesn't require memory for zero gradients of the underlying matrices #28
Actually, I think JAX is exactly that clever :) That said, I'd be happy to adjust Quax to avoid ever emitting the zero gradients.
Hmm. That optimization does not seem to be happening in my test case with LoRA. Both full training and LoRA have a peak memory usage that's basically double the model size but this optimization does seem to fire when we take gradients of a trivial constant function.

from __future__ import annotations

import functools as ft
from typing import Any, TypeVar, TypeVarTuple

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import quax
from numpy import ndarray

BatchLen = TypeVar("BatchLen", bound=int)
Dim1 = TypeVar("Dim1", bound=int)
Dim2 = TypeVar("Dim2", bound=int)
Dim3 = TypeVar("Dim3", bound=int)
Rank = TypeVar("Rank", bound=int)
Float = TypeVar("Float", bound=float)
Shape = TypeVarTuple("Shape")
A = TypeVar("A")
Opt = TypeVar("Opt")


def tree_size(tree: Any) -> int:
    return sum(x.nbytes for x in jax.tree_util.tree_leaves(tree) if eqx.is_array(x))


def human_bytes(size: float, decimal_places: int = 2) -> str:
    unit = "B"
    for unit in ["B", "KB", "MB", "GB", "TB"]:  # noqa: B007
        if size < 1024.0:  # noqa: PLR2004
            break
        size /= 1024.0
    formatted_num = f"{size:.{decimal_places}f}".rstrip("0").rstrip(".")
    return f"{formatted_num:>4} {unit}"


@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def full_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    return jnp.square(jnp.mean(model.__call__(input_)) - 1)


@eqx.filter_jit(donate="all")
def full_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    _, grads = full_grad_fn(model, input_)
    lr = 1e-3
    return jax.tree.map(lambda o, c: o - lr * c, model, grads)  # type: ignore


@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def lora_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    model = quax.quaxify(model)
    return jnp.square(jnp.mean(model.__call__(input_)) - 1)


@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def no_op_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    return jnp.array(0, input_.dtype)


@eqx.filter_jit(donate="all")
def lora_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    _, grads = lora_grad_fn(model, input_)
    lr = 1e-3
    return jax.tree.map(lambda o, c: o - lr * c, model, grads)  # type: ignore


@eqx.filter_jit(donate="all")
def no_op_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    _, grads = no_op_grad_fn(model, input_)
    lr = 1e-3
    return jax.tree.map(lambda o, c: o - lr * c, model, grads)  # type: ignore


dim1 = 65536
dim2 = 1024


def print_live_buffer_total():
    print(human_bytes(sum([x.nbytes for x in jax.live_arrays()])))


def full_prim_main():
    # OOMs on 75_000 but not 70_000
    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))
    print("model size", human_bytes(tree_size(model)))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
    for _ in range(40):
        model = full_prim_step(model, np.random.rand(dim1).astype(np.float32))
        print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))


def lora_prim_main():
    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))
    model = quax.examples.lora.loraify(model, rank=64, scale=0.1, key=jax.random.PRNGKey(1))
    print_live_buffer_total()
    print("model size", human_bytes(tree_size(model)))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
    for _ in range(40):
        model = lora_prim_step(model, np.random.rand(dim1).astype(np.float32))
        print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))


def no_op_main():
    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))
    print("model size", human_bytes(tree_size(model)))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
    for _ in range(40):
        model = no_op_step(model, np.random.rand(dim1).astype(np.float32))
        print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))

(If this counts as a JAX bug and/or is out of scope, I'm happy to move it over to the JAX repo.)
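As a back-of-the-envelope check on those peak numbers (using the dim1/dim2 values from the script above): the float32 weight matrix alone is 256 MiB, so a peak of roughly twice the model size is exactly what one extra full-size buffer, such as a dense gradient for the underlying matrix, would produce.

```python
# Back-of-the-envelope check of the "roughly double the model size" observation.
dim1, dim2 = 65536, 1024
weight_bytes = dim1 * dim2 * 4          # float32 weight matrix
print(weight_bytes / 2**20)             # 256.0 -> the weight is 256 MiB
print(2 * weight_bytes / 2**20)         # 512.0 -> weight plus one dense gradient
```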
Hmm, that's unfortunate if so. Quax is still a fairly experimental library, so I'd be happy to take suggestions on how we might adjust the internals to work around this. For example this could be accomplished by
Yeah, I'll think about the problem. I already tried a version of the
(I opened an issue on this optimization at jax-ml/jax#23316.)
Actually, I think those peak usages may be misleading. The problem may be something else. Even with an explicitly split model we get very similar behavior:

import jax.tree_util as jtu  # needed for tree_map_with_path / GetAttrKey below; not imported in the earlier listing

# Note: `PartOf` is a typing helper used only in the annotations below; it is not defined in these snippets.


@jax.value_and_grad
@jax.jit
@ft.partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
def split_lora_grad_fn(
    malleable: PartOf[eqx.nn.Linear[Dim1, Dim2, Float]], frozen: PartOf[eqx.nn.Linear[Dim1, Dim2, Float]], input_: ndarray[Dim1, Float]
) -> ndarray[Float]:
    model = quax.quaxify(eqx.combine(malleable, frozen))
    return jnp.square(jnp.mean(model.__call__(input_)) - 1)


@ft.partial(jax.jit, donate_argnums=0)
def split_lora_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    loraed = jtu.tree_map_with_path(lambda path, _: path[-2:] != (jtu.GetAttrKey("weight"), jtu.GetAttrKey("_w")), model)  # type: ignore
    malleable, frozen = eqx.partition(model, loraed)
    del loraed, model
    _, grads = split_lora_grad_fn(malleable, frozen, input_)
    print("grad size", human_bytes(tree_size(grads)))
    lr = 1e-3
    malleable = jax.tree.map(lambda o, c: o - lr * c, malleable, grads)  # type: ignore
    return eqx.combine(malleable, frozen)


def split_lora_prim_main():
    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))
    model = quax.examples.lora.loraify(model, rank=64, scale=0.1, key=jax.random.PRNGKey(1))
    # ir = split_lora_prim_step.lower(model, np.random.rand(dim1).astype(np.float32)).compiler_ir()
    # ir.dump()
    print("model size", human_bytes(tree_size(model)))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
    for _ in range(1):
        model = split_lora_prim_step(model, np.random.rand(dim1).astype(np.float32))
        print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))

And they OOM in the same way.
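A sketch of how one might dig deeper than `peak_bytes_in_use`, assuming a JAX version recent enough that the compiled object exposes `memory_analysis()` and `as_text()`; it reuses `split_lora_prim_step`, `model`, `dim1`, and `np` from the listing above:

```python
# Ask XLA what the compiled step actually allocates, rather than relying on the
# allocator's high-water mark.
lowered = split_lora_prim_step.lower(model, np.random.rand(dim1).astype(np.float32))
compiled = lowered.compile()

print(compiled.memory_analysis())  # backend-reported argument/output/temp buffer sizes
print(compiled.as_text()[:4000])   # skim the optimized HLO for full-size zero buffers
```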
I think one of the main motives for LoRA is to reduce memory consumption; certainly that's my motive. I'm already using gradient checkpointing and Adafactor, so the main thing I want from LoRA is to reduce the size of the gradient pytree itself. However, unless I'm quite confused, in a trivial setup like the one sketched below, the returned grads include a full `Dim1 x Dim2` array of zeros for `_w`. Almost all the values in the gradient pytree are zero (for typical LoRAs), and this is wasted memory. I thought perhaps I could get around this by replacing the `jax.lax.stop_gradient` in `LoraArray` with something else, but that produces an error.
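For concreteness, a minimal sketch of that kind of trivial setup, assuming small illustrative shapes and reusing `loraify` and `quaxify` exactly as in the scripts above:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import quax

model = eqx.nn.Linear(512, 256, key=jax.random.PRNGKey(0))
model = quax.examples.lora.loraify(model, rank=4, scale=0.1, key=jax.random.PRNGKey(1))

@eqx.filter_value_and_grad
def loss_fn(model, x):
    return jnp.square(jnp.mean(quax.quaxify(model)(x)) - 1)

_, grads = loss_fn(model, jnp.ones(512))
# The gradient pytree mirrors the model, so alongside the small LoRA factors it
# carries a full (256, 512)-shaped leaf of zeros for the frozen `_w`.
print([g.shape for g in jax.tree_util.tree_leaves(grads)])
```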
Is there a reasonable way to use quax to implement LoRA in a way that doesn't allocate tons of space for zeros?
(I guess it's mildly possible that JAX optimizes out this allocation behind the scenes if the gradient pytree is "consumed" inside the same JIT where the gradients are produced, but I assume it's not quite that clever.)
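To make the "consumed inside the same JIT" idea concrete, here is a small illustration with hypothetical names, independent of Quax: when the gradient pytree is an output of the jitted function, the zero cotangent has to be materialized as a real output buffer, whereas if the update happens inside the same jit, the frozen parameter's update folds to `w - lr * 0`, which the compiler is at least in principle free to simplify without ever allocating the zeros. Whether it actually does so in the LoRA case is what the measurements above probe.

```python
import jax
import jax.numpy as jnp

def loss(w_frozen, w_trainable, x):
    # `w_frozen` only enters through stop_gradient, so its cotangent is all zeros.
    return jnp.mean((x @ jax.lax.stop_gradient(w_frozen) + x @ w_trainable) ** 2)

# Variant 1: grads are returned from the jit, so the dense zero gradient for
# `w_frozen` is an output of the compiled computation and must be materialized.
grads_out = jax.jit(jax.grad(loss, argnums=(0, 1)))

# Variant 2: grads are consumed inside the same jit; the frozen parameter's
# update is `w_frozen - lr * 0`, which the compiler may simplify away.
@jax.jit
def step(w_frozen, w_trainable, x, lr=1e-3):
    g_frozen, g_trainable = jax.grad(loss, argnums=(0, 1))(w_frozen, w_trainable, x)
    return w_frozen - lr * g_frozen, w_trainable - lr * g_trainable

w_f = jnp.zeros((512, 512))
w_t = jnp.zeros((512, 512))
x = jnp.ones((8, 512))
w_f, w_t = step(w_f, w_t, x)
```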
Thanks.