More information on how to convert / how the generated MLX models are done. #21
Hey @raoulritter,

Thanks for reaching out. I'll go through the steps to convert MLX weights. It may not be the best practice, but that's how I did it for both SD3 and flux-schnell.

Conversion Steps
```python
import mlx.core as mx
from mlx.utils import tree_flatten

# MMDiT, config, flux_weights_ckpt and flux_state_dict_adjustments come from
# the DiffusionKit codebase; their imports are omitted here.

# Instantiate the MLX model and collect its parameters as {dotted_name: array}.
model = MMDiT(config)
mlx_model = tree_flatten(model)
mlx_dict = {m[0]: m[1] for m in mlx_model if isinstance(m[1], mx.array)}  # This is the dict format we want

# Load the original checkpoint and remap its keys/tensors to the MLX layout.
weights = mx.load(flux_weights_ckpt)
weights = flux_state_dict_adjustments(
    weights, prefix="", hidden_size=config.hidden_size, mlp_ratio=config.mlp_ratio
)  # This is the dict we create from our dict_adjustments() function

weights_set = set(weights.keys())
mlx_dict_set = set(mlx_dict.keys())

# Keys in the converted checkpoint that the model does not expect.
for k in weights_set - mlx_dict_set:
    print(k)
print(len(weights_set - mlx_dict_set))  # This should print 0 if the conversion is correct

# Model parameters that the converted checkpoint does not provide.
for k in mlx_dict_set - weights_set:
    print(k)
print(len(mlx_dict_set - weights_set))  # This should print 0 if the conversion is correct

# Check the shapes of the converted weights
count = 0
for k in weights_set & mlx_dict_set:
    if weights[k].shape != mlx_dict[k].shape:
        print(k, weights[k].shape, mlx_dict[k].shape)
        count += 1
print(count)  # This should print 0 if the conversion is correct
```

If we can load the model checkpoint, we can wire it up in our diffusion pipeline here. We also need to update this function as we need the negative text prompt for the

I think this covers the gist of it. Let me know if you have any questions.

Best,
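For anyone following along without an MLX install, the key and shape checks in the conversion steps above can be sketched in plain Python. This is a minimal illustration, not DiffusionKit code: `diff_state_dicts` is a hypothetical helper, and it assumes you first reduce both dicts to `{name: shape}` mappings (for example `{k: v.shape for k, v in weights.items()}` and the same for `mlx_dict`). The key names and shapes in the example are made up.

```python
def diff_state_dicts(converted, expected):
    """Compare two {name: shape} mappings (hypothetical helper).

    Returns (unexpected, missing, mismatched):
      unexpected: keys in the converted checkpoint that the model does not have
      missing:    model keys that the converted checkpoint does not provide
      mismatched: (key, converted_shape, expected_shape) for shared keys
    """
    conv_keys = set(converted)
    exp_keys = set(expected)
    unexpected = sorted(conv_keys - exp_keys)
    missing = sorted(exp_keys - conv_keys)
    # Normalize shapes to tuples so list- and tuple-valued shapes compare equal.
    mismatched = [
        (k, tuple(converted[k]), tuple(expected[k]))
        for k in sorted(conv_keys & exp_keys)
        if tuple(converted[k]) != tuple(expected[k])
    ]
    return unexpected, missing, mismatched


# Made-up key names and shapes, for illustration only.
model_shapes = {"x_embedder.proj.weight": (3072, 64), "x_embedder.proj.bias": (3072,)}
ckpt_shapes = {"x_embedder.proj.weight": (64, 3072), "img_in.bias": (3072,)}

unexpected, missing, mismatched = diff_state_dicts(ckpt_shapes, model_shapes)
print(unexpected)  # ['img_in.bias']
print(missing)     # ['x_embedder.proj.bias']
print(mismatched)  # [('x_embedder.proj.weight', (64, 3072), (3072, 64))]
```

A conversion is clean when all three lists are empty. Note that a transposed weight, as in the made-up example, shows up as a shape mismatch rather than a missing key, which is why the shape check is worth running even after the key sets agree.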
Hey @arda-argmax,

Thanks so much for the detailed information.

Let me know if you think I missed anything. I would like to continue working on this.

Best,
Raoul
Hey Raoul,
In this blog post, Here, FLUX developers also explicitly stated that only the This is how we handled prompt and negative text with guidance (
Nice! We don't actually use a saving function. After using your modified `model.update(tree_unflatten(tree_flatten(weights)))` We used the You can open a pull request to add the tensors to our organization.
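The `model.update(tree_unflatten(tree_flatten(weights)))` call works because `model.update` expects parameters nested the way the module tree is nested, while the checkpoint uses flat dotted keys. Below is a simplified pure-Python approximation of that round trip, assuming dotted string keys and contiguous all-digit segments for list indices. `flatten` and `unflatten` are hypothetical stand-ins, not the actual `mlx.utils.tree_flatten`/`tree_unflatten` implementations. In this sketch, flattening an already-flat dict just returns its `(key, value)` pairs, so the nesting comes entirely from the unflatten step.

```python
def flatten(tree, prefix=""):
    """Flatten nested dicts/lists into (dotted_key, leaf) pairs."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, list):
        items = enumerate(tree)
    else:
        return [(prefix, tree)]
    pairs = []
    for key, value in items:
        full_key = f"{prefix}.{key}" if prefix else str(key)
        pairs.extend(flatten(value, full_key))
    return pairs


def _digit_dicts_to_lists(node):
    """Turn dicts whose keys are all digits ("0", "1", ...) back into lists."""
    if not isinstance(node, dict):
        return node
    node = {k: _digit_dicts_to_lists(v) for k, v in node.items()}
    if node and all(k.isdigit() for k in node):
        # Assumes the indices are contiguous and start at 0.
        return [node[str(i)] for i in range(len(node))]
    return node


def unflatten(pairs):
    """Rebuild a nested tree from (dotted_key, leaf) pairs."""
    root = {}
    for key, value in pairs:
        parts = key.split(".")
        node = root
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return _digit_dicts_to_lists(root)


# Made-up flat checkpoint with dotted keys (strings stand in for arrays).
weights = {"blocks.0.attn.weight": "W0", "blocks.1.attn.weight": "W1", "final.bias": "b"}
print(flatten(weights))  # a flat dict flattens to its own items
print(unflatten(flatten(weights)))
# {'blocks': [{'attn': {'weight': 'W0'}}, {'attn': {'weight': 'W1'}}], 'final': {'bias': 'b'}}
```

Numbered segments such as `blocks.0` become list entries, which matches how repeated transformer blocks are usually stored as a list of submodules.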
When I was implementing the SD3 pipeline for DiffusionKit, I used mlx-examples as a reference, which was very helpful. I would first look at their LoRA implementation example in the same repo. I would also look at how others have implemented LoRA for FLUX in PyTorch.
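As rough orientation before reading those implementations: LoRA keeps the original weight frozen and learns a low-rank update that can be folded back in as W' = W + (alpha / r) * B @ A. The sketch below merges such an update using plain nested lists. It assumes the common alpha/rank scaling and an (out, in) weight layout, and `merge_lora` and `matmul` are hypothetical helpers. mlx-examples and the PyTorch FLUX LoRA projects may parameterize the scale differently or store A and B transposed, so check their conventions before reusing this.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def merge_lora(weight, lora_a, lora_b, alpha, rank):
    """Return W + (alpha / rank) * B @ A.

    Assumed shapes: weight (out, in), lora_a (rank, in), lora_b (out, rank).
    """
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)  # (out, in) low-rank update
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Toy 2x2 weight with a rank-1 update.
merged = merge_lora([[1, 0], [0, 1]], [[1, 2]], [[1], [0]], alpha=2, rank=1)
print(merged)  # [[3.0, 4.0], [0.0, 1.0]]
```

Merging is only one option; during training the update is usually kept separate and applied in the forward pass so that only A and B receive gradients.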
It seems that you can retrieve the

Thank you for your efforts.

Best,
Upvoting since I'm interested in running it.
Hey @raoulritter, how is this going? We would like to publish FLUX.1-dev this week. We wanted to give you a chance to PR your work if it is ready to go this week. If not, we will go ahead and publish it next week. Please let us know, thanks!
Hey @atiorh, my implementation is working. I hope to be able to make the PR this weekend. Apologies for the delay.
Awesome to hear! Looking forward to it :)
Can't wait to try!
Hey @atiorh or @arda-argmax,
First off, I just want to say that I love your projects. The mlx-FLUX.1-schnell is awesome!
I'm reaching out because I'm interested in contributing to the community by helping create an MLX version for the dev model. However, I'm a bit stuck on the conversion process.
What I'm Looking For:
This information would be invaluable not just for me, but potentially for others in the community who want to contribute to MLX conversions of other models.
Thanks in advance for any help or pointers you can provide! Keep up the great work!