Hi,

My RAM fills up until it overflows when I use the `scale_by_radam` gradient transform (or, equivalently, `optax.radam`) without JIT-compiling the code. The problem appears on both CPU and GPU, but it does not appear when I use JIT compilation, and it does not seem to exist with `optax.adam`.

Here is an MWE derived from the optax quick start tutorial:
```python
import random
from typing import Tuple
import os

os.environ['CUDA_VISIBLE_DEVICES'] = ''  # uncomment to force CPU

import optax
import jax.numpy as jnp
import jax
import numpy as np

BATCH_SIZE = 500
NUM_TRAIN_STEPS = 10000

RAW_TRAINING_DATA = np.random.randint(255, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))
TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(
    NUM_TRAIN_STEPS, BATCH_SIZE, 2)

initial_params = {
    'hidden': jax.random.normal(shape=[8, 200], key=jax.random.PRNGKey(0)),
    'hidden2': jax.random.normal(shape=[200, 100], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[100, 2], key=jax.random.PRNGKey(1)),
}


def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['hidden2'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['output'])
  return x


def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)
  # optax also provides a number of common loss functions.
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)
  return loss_value.mean()


def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  # @jax.jit  # the problem disappears when this line is uncommented
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params


# Finally, we can fit our parametrized function using the RAdam optimizer
# provided by optax.
optimizer = optax.radam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
```
Of course, this example is simple enough that it would take a long time to saturate the RAM, but the issue is a real problem in another research project of mine.

The problem seems to be linked to this computation specific to RAdam: https://github.com/deepmind/optax/blob/fc5de3d3951c4dfd87513c6426e30baf505d89ae/optax/_src/transform.py#L685C7-L685C7, but I do not know how to investigate further.
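For context, and as far as I can tell from reading the code (this is my own interpretation, not something confirmed by the maintainers), the linked region computes the length of the approximated simple moving average and the rectification factor from the RAdam paper (Liu et al., 2019):

$$\rho_\infty = \frac{2}{1-\beta_2} - 1, \qquad \rho_t = \rho_\infty - \frac{2\,t\,\beta_2^t}{1-\beta_2^t},$$

with the rectified step, used when $\rho_t$ exceeds optax's threshold parameter, scaled by

$$r_t = \sqrt{\frac{(\rho_t-4)(\rho_t-2)\,\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\,\rho_t}}.$$

Thanks for your feedback.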
Use JIT compilation:
Compile your code with `jax.jit` whenever possible to benefit from JAX's optimizations and potentially avoid the RAM issue; a sketch follows.
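Concretely, this is a one-line change in the MWE above. A minimal sketch, assuming the `loss` function and `optimizer` from the issue are in scope:

```python
import jax
import optax

# Jitting the whole update step compiles it once and reuses the compiled
# computation on every iteration; this is the configuration in which the
# reporter observes no memory growth.
@jax.jit
def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value
```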
Investigate the RAdam implementation:
If RAdam's performance is crucial for your research, consider:
- Modifying RAdam's implementation to reduce its memory footprint (if feasible).
- Exploring alternative optimizers like Yogi, which share similarities with RAdam but might have different memory characteristics; see the sketch after this list.
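For instance, trying Yogi in the MWE above is a one-line swap. This is a sketch only; whether `optax.yogi` actually avoids the memory growth is untested here:

```python
import optax

# Drop-in replacement for optax.radam in the MWE's fit() call; Yogi is an
# Adam-family optimizer with a different second-moment update rule.
optimizer = optax.yogi(learning_rate=1e-2)
params = fit(initial_params, optimizer)
```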
Report to Optax Maintainers:
Share your findings and code examples with the Optax maintainers to bring attention to the issue and potentially contribute to a fix.
Additional considerations:
- Memory profiling: use tools like `jax.profiler` or external profilers to track memory usage and identify bottlenecks; see the sketch after this list.
- Batch size adjustment: experiment with smaller batch sizes to reduce the memory required per step.
- Hardware constraints: keep the available RAM and other hardware limitations in mind.
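For the memory-profiling point, a minimal sketch of what that could look like in the un-jitted training loop, reusing the `step`, `TRAINING_DATA`, and `LABELS` names from the MWE (the resulting `.prof` files can be inspected with pprof):

```python
import jax.profiler

# Periodically dump a device-memory profile while training; comparing
# successive snapshots should show which buffers accumulate over time.
for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 1000 == 0:
        jax.profiler.save_device_memory_profile(f'memory_step_{i}.prof')
```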
Happy to help further if you have more questions or need additional guidance.