DPO trainer example #172
Another issue, with the ORPO trainer.
So I tried the DPO trainer test, python_test/trainers/dpo_test.py, but it also gets stuck at … How can I initialize from initial weights? I already have an SFT model; is there a way to transfer the parameters from the trained model to the initialized model? I also couldn't understand the logic for this: why is it necessary?
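For context on why DPO needs an initial (SFT) model at all: in the standard DPO objective, the loss compares the policy's log-probabilities on the chosen and rejected answers against those of a frozen reference model, which is usually the SFT checkpoint. Below is a minimal JAX sketch of that standard sigmoid loss, assuming per-example summed log-probs have already been computed. The function name `dpo_sigmoid_loss` is hypothetical and is not EasyDeL's API.

```python
import jax
import jax.numpy as jnp


def dpo_sigmoid_loss(
    policy_chosen_logps: jnp.ndarray,    # [batch] summed log-probs, policy, chosen
    policy_rejected_logps: jnp.ndarray,  # [batch] summed log-probs, policy, rejected
    ref_chosen_logps: jnp.ndarray,       # [batch] same, frozen reference (SFT) model
    ref_rejected_logps: jnp.ndarray,     # [batch] same, frozen reference (SFT) model
    beta: float = 0.1,
) -> jnp.ndarray:
    """Hypothetical helper: standard DPO sigmoid loss, averaged over the batch."""
    # The reference model is never trained, so its log-probs are constants.
    ref_chosen_logps = jax.lax.stop_gradient(ref_chosen_logps)
    ref_rejected_logps = jax.lax.stop_gradient(ref_rejected_logps)
    # Implicit rewards: how far the policy has moved from the reference.
    chosen_rewards = policy_chosen_logps - ref_chosen_logps
    rejected_rewards = policy_rejected_logps - ref_rejected_logps
    logits = chosen_rewards - rejected_rewards
    # -log(sigmoid(beta * logits)), computed stably via log_sigmoid.
    return -jnp.mean(jax.nn.log_sigmoid(beta * logits))


# Right after copying the SFT weights into both models, policy == reference.
zeros = jnp.zeros((4,))
print(dpo_sigmoid_loss(zeros, zeros, zeros, zeros))  # ≈ 0.693 (= log 2)
```

Because every logit is 0 when the policy still equals the reference, training starts at a loss of log 2; without the reference term, the objective would not be anchored to the SFT model.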
Hi @sparsh35, I have added a tutorial for the DPO trainer. Sorry I was slow on that; I was checking some other things. https://github.com/erfanzar/EasyDeL/blob/main/notebooks/dpo-trainer.ipynb
@erfanzar No problem at all. I know you are the sole maintainer of this project; things do get hectic.
Still getting an issue with dtypes, using the same code as given in the example.
@sparsh35 Hello, and thanks for re-opening the issue. I have that fixed now.
@erfanzar Thanks. I also modified the code for hinge loss and KTO to check them. I think there is another bug related to JAX array sharding; this code may not be compatible with multi-host setups. I can't debug this error, can you help? The above exception was the direct cause of the following exception: Traceback (most recent call last):
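For reference, one common hinge-loss variant of the DPO objective (used, for example, by the `hinge` loss type in Hugging Face TRL) replaces the log-sigmoid with max(0, 1 - β·logits), where the logits are the same policy-minus-reference reward margin as in standard DPO. This is a minimal sketch under that assumption; `dpo_hinge_loss` is a hypothetical name, and it is not necessarily how the modified code mentioned above computes it.

```python
import jax
import jax.numpy as jnp


def dpo_hinge_loss(
    policy_chosen_logps: jnp.ndarray,
    policy_rejected_logps: jnp.ndarray,
    ref_chosen_logps: jnp.ndarray,
    ref_rejected_logps: jnp.ndarray,
    beta: float = 0.1,
) -> jnp.ndarray:
    """Hypothetical helper: hinge variant of the DPO loss, averaged over the batch."""
    # Same reward margin as standard DPO: (chosen shift) - (rejected shift).
    logits = (policy_chosen_logps - ref_chosen_logps) - (
        policy_rejected_logps - ref_rejected_logps
    )
    # Zero loss once beta * margin exceeds 1; linear penalty below that.
    return jnp.mean(jax.nn.relu(1.0 - beta * logits))


z = jnp.zeros((2,))
print(dpo_hinge_loss(z, z, z, z))  # 1.0 when policy == reference
```

Unlike the sigmoid loss, the hinge variant stops pushing on pairs whose scaled margin is already above 1.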
@sparsh35 Thanks, I'll fix that one ASAP.
Can you try again and tell me if you are still facing this issue?
No, this is the error: Traceback (most recent call last):
I haven't fixed that one yet... give me one week.
There is an error in the ORPO trainer as well: I'm getting NaN losses. Maybe my data is in the wrong format. Looking at the code of the test and the trainer, I think one problem is that not all tokenizers have a BOS token ID, and this case is not handled. Even after fixing that I am still getting NaN losses and can't find the cause; please check it if you can.
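The BOS point above can be handled by prepending a BOS token only when the tokenizer actually defines one. Hugging Face tokenizers report `tokenizer.bos_token_id` as `None` when no BOS token is set. This is a minimal sketch of such a guard; `maybe_prepend_bos` is a hypothetical helper, not part of EasyDeL, and it does not address the remaining NaN losses.

```python
from typing import List, Optional


def maybe_prepend_bos(token_ids: List[int], bos_token_id: Optional[int]) -> List[int]:
    """Hypothetical helper: add BOS only if the tokenizer has one and it is missing."""
    if bos_token_id is None:
        # Tokenizer defines no BOS token: leave the sequence unchanged.
        return list(token_ids)
    if token_ids and token_ids[0] == bos_token_id:
        # BOS already present: avoid a duplicated BOS token.
        return list(token_ids)
    return [bos_token_id] + list(token_ids)


print(maybe_prepend_bos([5, 6], None))  # [5, 6]
print(maybe_prepend_bos([5, 6], 1))     # [1, 5, 6]
print(maybe_prepend_bos([1, 5], 1))     # [1, 5]
```

In a preprocessing step this could be called as `maybe_prepend_bos(tokenizer(prompt, add_special_tokens=False)["input_ids"], tokenizer.bos_token_id)`, so prompt, chosen and rejected sequences are built the same way whether or not the tokenizer has a BOS token.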
@sparsh35 Sure, I'll check and update the ORPO trainer.
Describe the bug
When trying the DPO trainer example, I get an error related to batch size and sharding. Maybe the shard axes are not set properly, or it could be a JAX error. The system used is a TPU v3-32 with 4 hosts.
To Reproduce
Steps to reproduce the behavior:
Just run the DPO trainer examples.
The error is:
File "/home/spars/.local/lib/python3.10/site-packages/easydel/modules/qwen2/modeling_qwen_flax.py", line 862, in __call__
output = layer(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/modules/qwen2/modeling_qwen_flax.py", line 502, in __call__
attn_outputs = self.self_attn(
File "/home/spars/.local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 568, in inner
return rematted(variable_groups, rng_groups, *dyn_args)
File "/home/spars/.local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 565, in rematted
y = fn(scope, *args)
File "/home/spars/.local/lib/python3.10/site-packages/easydel/modules/qwen2/modeling_qwen_flax.py", line 382, in __call__
attentions = self.attention_performer(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/modules/attention_module.py", line 529, in __call__
return self.ring_attention(
File "/home/spars/.local/lib/python3.10/site-packages/easydel/modules/attention_module.py", line 649, in ring_attention
attn_output = shard_map(
ValueError: shard_map applied to the function 'functools.partial(<function ring_attention at 0x7f9c76b396c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:
The mesh given has shape (1, 32, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').
args[0] of shape float32[1,1,28,128], where args[0] is bound to functools.partial(<function ring_attention at 0x7f9c76b396c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)'s parameter 'query', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 32), but 32 does not evenly divide 1
args[1] of shape float32[1,1,28,128], where args[1] is bound to functools.partial(<function ring_attention at 0x7f9c76b396c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)'s parameter 'key', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 32), but 32 does not evenly divide 1
args[2] of shape float32[1,1,28,128], where args[2] is bound to functools.partial(<function ring_attention at 0x7f9c76b396c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)'s parameter 'value', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 32), but 32 does not evenly divide 1
args[3] of shape float32[1,1,1,1], where args[3] is bound to functools.partial(<function ring_attention at 0x7f9c76b396c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)'s parameter 'bias', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), 'tp', 'sp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 32), but 32 does not evenly divide 1
… in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<function ring_attention at 0x7f2e71b196c0>, axis_name='sp', float32_logits=True, platform=None, backend=None, autocheck=True, blocksize_c=None, blocksize_k=128, blocksize_q=128, dtype=<class 'jax.numpy.float32'>, softmax_scale=0.08838834764831843, deterministic=True, dropout_rng=None)' appropriately.
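The error says the batch axis (size 1) of `query`, `key`, `value` and `bias` is mapped by PartitionSpec(('dp', 'fsdp'), ...) onto 1 × 32 = 32 devices, and shard_map requires every sharded array axis to be evenly divisible by the product of the mesh axes it is mapped to. The sketch below reproduces that arithmetic on a CPU, without a real `jax.sharding.Mesh`: the mesh is described as a plain dict and the spec as a tuple. `shards_for` and `pad_batch_to_multiple` are hypothetical helpers, not EasyDeL functions.

```python
import math
from typing import Dict, Tuple, Union

import jax.numpy as jnp

# One PartitionSpec entry: unsharded (None), one mesh axis, or several.
AxisSpec = Union[None, str, Tuple[str, ...]]


def shards_for(spec_entry: AxisSpec, mesh_shape: Dict[str, int]) -> int:
    """Hypothetical helper: how many pieces one array axis is split into."""
    if spec_entry is None:
        return 1
    names = (spec_entry,) if isinstance(spec_entry, str) else spec_entry
    return math.prod(mesh_shape[name] for name in names)


def pad_batch_to_multiple(x: jnp.ndarray, multiple: int) -> jnp.ndarray:
    """Hypothetical helper: zero-pad axis 0 up to the next multiple of `multiple`."""
    remainder = x.shape[0] % multiple
    if remainder == 0:
        return x
    pad_width = [(0, multiple - remainder)] + [(0, 0)] * (x.ndim - 1)
    return jnp.pad(x, pad_width)


# Values taken from the traceback above.
mesh_shape = {"dp": 1, "fsdp": 32, "tp": 1, "sp": 1}
query_spec = (("dp", "fsdp"), "sp", "tp", None)
batch_shards = shards_for(query_spec[0], mesh_shape)  # 1 * 32 = 32

query = jnp.zeros((1, 1, 28, 128), dtype=jnp.float32)
print(query.shape[0] % batch_shards == 0)  # False: the condition behind the ValueError
padded = pad_batch_to_multiple(query, batch_shards)
print(padded.shape)  # (32, 1, 28, 128)
```

Padding only makes the shapes legal; the padded rows would still have to be masked out of the loss. The simpler fix is usually a global batch size that is a multiple of dp × fsdp (32 here, which is 8 examples per host if the 32 devices are split evenly over the 4 hosts), or a mesh whose dp × fsdp product divides the batch size.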