[MJX] When using jax.sharding over batch axis, getting different results than on single device #1454
Replies: 3 comments
-
Hey @areiner222, can you send a minimal repro? I haven't used jax sharding, so I haven't encountered this error before.
-
Hey @btaba, I don't have a simple repro on hand right now, but I was able to fix the NaN issue. When initializing mjx.Data, I was originally calling make_data + forward with some initial qpos/qvel inside a single vmap along the sharded batch dimension of the mjx.Model. Splitting make_data and forward into two separate vmap calls seems to have fixed the NaNs. I still see small divergences in state versus single-device simulation, which start very small and accumulate over time (but remain minimal). Happy to close for now since the issue is resolved on my end, and happy to engage again if this comes up for anyone else.
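For readers landing here later, a sketch of the split described above. The MJX/JAX calls are replaced by stand-ins (`make_data`, `forward`, `batched` are placeholders, not the real `mjx.make_data`, `mjx.forward`, or `jax.vmap`) so the structure is runnable without MuJoCo installed; the shape of the fix is what matters:

```python
# Stand-in for mjx.make_data: builds fresh state from a model spec.
def make_data(model):
    return {"qpos": list(model["qpos0"]), "qvel": [0.0] * model["nv"]}

# Stand-in for mjx.forward: computes derived quantities from state.
def forward(model, data):
    data = dict(data)
    data["qacc"] = [0.0] * model["nv"]
    return data

# Stand-in for jax.vmap over the batch (axis 0) of models/datas.
def batched(fn):
    def wrapper(models, *rest):
        if rest:
            return [fn(m, d) for m, d in zip(models, *rest)]
        return [fn(m) for m in models]
    return wrapper

models = [{"qpos0": [0.0, 0.1], "nv": 2} for _ in range(4)]

# Problematic pattern: make_data + forward fused in ONE vmapped call.
def init_fused(model):
    return forward(model, make_data(model))

# Workaround from this thread: TWO separate vmapped calls instead.
datas = batched(make_data)(models)
datas = batched(forward)(models, datas)
```

In the real code, each `batched(...)` call would be a separate `jax.vmap` over the sharded batch axis of the `mjx.Model`.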
-
Thanks @areiner222, glad you found a workaround. In the past I have seen numerical precision differences depending on how ops get called/compiled. You may want to experiment with this flag.
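The precision point can be seen without JAX at all: IEEE 754 addition is not associative, so any change in how the compiler orders or fuses ops can shift a result in its last bits, and stepping a simulator compounds that shift over time. A minimal illustration:

```python
# Two mathematically identical sums evaluated in different orders.
a = (0.1 + 0.2) + 0.3   # one evaluation order
b = 0.1 + (0.2 + 0.3)   # another order a compiler might pick

print(a == b)        # False: they differ in the last bit
print(abs(a - b))    # ~1e-16 here; repeated stepping accumulates this
```

A sharded/compiled program can legitimately pick a different order per device than the single-device program does, which would match the tiny-but-growing divergence described above.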
-
Hi,
Has anyone successfully used the jax sharding API along a batch dimension for mjx? When I batch my mjx.Model, I get slightly different outputs when I run mjx.forward and subsequently step forward. With no ctrl, the qpos results differ slightly and diverge over the course of stepping. When I use ctrl (position actuators), I get NaN values in qfrc_actuator and therefore NaNs in qpos.
Happy to provide further context (I don't have a repro on hand), but was wondering if anyone has encountered this already?
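A side note on the NaN symptom (not from the thread, just a plain-Python illustration with made-up variable names): once one quantity goes NaN, ordinary arithmetic propagates it into everything downstream, so the interesting question is where the *first* NaN appears, not where it shows up:

```python
import math

qfrc_actuator = float("nan")      # stand-in for the first bad value
qacc = qfrc_actuator / 1.0        # any arithmetic preserves the NaN
qpos = 0.5 + 0.002 * qacc         # the integration step is now NaN too

print(math.isnan(qpos))           # True
```

In JAX specifically, enabling the `jax_debug_nans` config option can help pinpoint the op that first produces a NaN.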
Thanks!