
More information on how to convert / how the generated MLX models are done. #21

raoulritter opened this issue Aug 18, 2024 · 8 comments

raoulritter commented Aug 18, 2024

Hey @atiorh or @arda-argmax ,

First off, I just want to say that I love your projects. The mlx-FLUX.1-schnell is awesome!
I'm reaching out because I'm interested in contributing to the community by helping create an MLX version for the dev model. However, I'm a bit stuck on the conversion process.
What I'm Looking For:

  1. Conversion Steps: Could you provide some guidance on how to convert FLUX models to MLX? Are the steps similar to the CoreML conversion process you've documented elsewhere?
  2. MLX Generation Process: I'd love to understand more about how the generated MLX models are created. Any insights into your workflow would be super helpful.
  3. Specific to FLUX.1-schnell: If there were any unique considerations or steps for converting FLUX.1-schnell to MLX, it'd be great to know about those.

This information would be invaluable not just for me, but potentially for others in the community who want to contribute to MLX conversions of other models.
Thanks in advance for any help or pointers you can provide! Keep up the great work!


arda-argmax commented Aug 19, 2024

Hey @raoulritter,

Thanks for reaching out. I'll go through the steps I used to convert the weights to MLX. It may not be best practice, but that's how I did it for both SD3 and flux-schnell.

Conversion Steps
Here is how we load the flux model.

  1. Initialize the config for the MMDiT module. We hardcode the config for flux-schnell here according to the black-forest-labs repo. You can do a similar thing for the flux-dev version and initialize the config according to the inputs of the load_flux() function.
  2. Load the weights. Here you are downloading and loading the weights as a dict. We must adjust this dict since the layer names differ between our MLX model and black-forest-labs checkpoint.
  3. Adjust the weight dict. You can check whether the adjustment is correct by running code similar to this in a local Jupyter notebook:
import mlx.core as mx
from mlx.utils import tree_flatten

model = MMDiT(config)
mlx_model = tree_flatten(model)
mlx_dict = {m[0]: m[1] for m in mlx_model if isinstance(m[1], mx.array)}  # This is the dict format we want

weights = mx.load(flux_weights_ckpt)
weights = flux_state_dict_adjustments(
    weights, prefix="", hidden_size=config.hidden_size, mlp_ratio=config.mlp_ratio
)  # This is the dict we create from our dict_adjustments() function

weights_set = set(weights.keys())
mlx_dict_set = set(mlx_dict.keys())

# Keys in the checkpoint that the MLX model does not expect
for k in weights_set - mlx_dict_set:
    print(k)
print(len(weights_set - mlx_dict_set))  # This should print 0 if the conversion is correct

# Keys the MLX model expects that are missing from the checkpoint
for k in mlx_dict_set - weights_set:
    print(k)
print(len(mlx_dict_set - weights_set))  # This should print 0 if the conversion is correct

# Check the shapes of the converted weights
count = 0
for k in weights_set & mlx_dict_set:
    if weights[k].shape != mlx_dict[k].shape:
        print(k, weights[k].shape, mlx_dict[k].shape)
        count += 1
print(count)  # This should print 0 if the conversion is correct

If we can load the model checkpoint, we can wire it up in our diffusion pipeline here. We also need to update this function as we need the negative text prompt for the flux-dev version.
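To give a feel for what the weight-dict adjustment in step 3 involves, here is a minimal, illustrative renaming pass over checkpoint keys. This is not the actual flux_state_dict_adjustments function (which also handles reshapes and splits), and the prefix mapping below is made up for the example:

```python
def rename_keys(weights, mapping):
    """Return a new state dict with checkpoint key prefixes renamed.

    `mapping` maps old key prefixes to the names the MLX model expects.
    Keys that match no prefix are kept unchanged.
    """
    out = {}
    for key, value in weights.items():
        for old, new in mapping.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        out[key] = value
    return out

# Hypothetical example: map one naming scheme onto another
adjusted = rename_keys(
    {"double_blocks.0.img_attn.qkv.weight": 1},
    {"double_blocks.": "transformer_blocks."},
)
print(sorted(adjusted))  # ['transformer_blocks.0.img_attn.qkv.weight']
```

The set-difference checks above then tell you which prefixes are still unmapped.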

I think this covers the gist of it. Let me know if you have any questions.

Best,
Arda


raoulritter commented Aug 22, 2024

Hey @arda-argmax,

Thanks so much for the detailed information.

  1. Could you clarify why we need to add a negative text prompt for the flux-dev version?

  2. I've been able to convert the model, but I need some more help with how to save / load the model correctly in MLX. What is your saving function? Also, once I'm done, would it be possible to add the tensors to the argmaxinc Hugging Face organization?

  3. After this, I'd love pointers on how you would implement LoRA for flux. Thanks in advance.

Keys in weights but not in model:
Count: 0

Keys in model but not in weights:
Count: 0

Shape mismatches:
Total mismatches: 0

Let me know if you think I missed anything. Would like to continue working on this.

Best,

Raoul


arda-argmax commented Aug 25, 2024

Hey Raoul,
Sorry for the late answer.

  1. Could you clarify why we need to add a negative text prompt for the flux-dev version?

In this blog post, FLUX.1 [dev] is introduced as a guidance-distilled model. This means we can give both our prompt and negative text to the pipeline as inputs for our image generation. We can use a guidance parameter to adjust the strength of the negative text. sd3-medium can also use guidance for the prompt and negative text. However, FLUX.1 [schnell] does not have this capability.

Here, FLUX developers also explicitly stated that only the [dev] model can use guidance.

This is how we handled prompt and negative text with guidance (cfg_weight in the code) in DiffusionKit for SD3.
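For intuition, classifier-free guidance runs the denoiser twice per step, once with the prompt and once with the negative text, and blends the two predictions. A toy numpy sketch of the blend (the function name and shapes are illustrative, not DiffusionKit's actual code):

```python
import numpy as np

def apply_cfg(cond_pred, uncond_pred, cfg_weight):
    # Extrapolate from the unconditional (negative-prompt) prediction
    # toward the conditional one; cfg_weight controls the strength.
    return uncond_pred + cfg_weight * (cond_pred - uncond_pred)

cond = np.array([1.0, 2.0])
uncond = np.array([0.0, 0.0])
print(apply_cfg(cond, uncond, 1.0))  # [1. 2.] -- equals cond when cfg_weight == 1
```

With cfg_weight above 1 the negative text is pushed away from more strongly, which is why schnell (with no guidance input) cannot use it.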

  2. I've been able to convert the model, but need some more help with how to save / load the model correctly in MLX. What is your saving function? Also, once I'm done, would it be possible to add the tensors to the argmaxinc Hugging Face organization?

Nice! We don't actually use a saving function. After using your modified adjust_dict function to retrieve the weights dict, you can load the model with the following code, as we did for the load_flux() function:

model.update(tree_unflatten(tree_flatten(weights)))

We used the load_flux() function here to wire up our diffusion pipeline.
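If you do want to persist the adjusted dict (for example, to upload the converted tensors), MLX can write safetensors files directly. A minimal sketch, assuming `weights` is the flat `{name: mx.array}` dict from the adjustment step; the filename is illustrative:

```python
import mlx.core as mx

# Save the adjusted flat dict as a .safetensors file
mx.save_safetensors("flux-dev-mlx.safetensors", weights)

# Later, load it back and apply it to the model as before
weights = mx.load("flux-dev-mlx.safetensors")
```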

You can open a pull request to add the tensors to our organization.

  3. After this, I'd love pointers on how you would implement LoRA for flux. Thanks in advance.

When I was implementing the SD3 pipeline for DiffusionKit, I used mlx-examples as a reference, which was very helpful. I would first look at their LoRA implementation example in the same repo. I would also look at how others have implemented LoRA for FLUX in PyTorch.
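For intuition on what a FLUX LoRA would add, the core mechanism is just a low-rank, additive update to a frozen linear weight. A minimal numpy sketch (illustrative only, not DiffusionKit or mlx-examples code; all names and sizes are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank = 8, 8, 2

W = rng.standard_normal((d_out, d_in))        # frozen pretrained weight
A = rng.standard_normal((rank, d_in)) * 0.01  # trainable down-projection
B = np.zeros((d_out, rank))                   # trainable up-projection (zero init)
scale = 1.0                                   # often alpha / rank in LoRA setups

def lora_forward(x):
    # Base layer output plus the low-rank correction scale * B @ A @ x
    return W @ x + scale * (B @ (A @ x))

x = rng.standard_normal(d_in)
# Because B starts at zero, the adapted layer is initially an exact no-op:
print(np.allclose(lora_forward(x), W @ x))  # True
```

Only A and B are trained, so a FLUX LoRA would attach such pairs to the attention and MLP linears of the MMDiT blocks while the converted weights stay frozen.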

Let me know if you think I missed anything. Would like to continue working on this.

It seems that you can retrieve the weights dict correctly. You can update the load_flux() function and try to generate an image to see how it performs.

Thank you for your efforts.

Best,
Arda

@mgierschdev
Upvoting since I'm interested in running it.


atiorh commented Aug 29, 2024

Hey @raoulritter, how is this going? We would like to publish FLUX.1-dev this week. We wanted to give you a chance to PR your work if it is ready to go this week. If not, we will go ahead and publish it next week. Please let us know, thanks!

@raoulritter
Contributor Author

Hey @atiorh,

My implementation is working. I hope to be able to make the PR this weekend. Apologies for the delay.


atiorh commented Aug 29, 2024

Awesome to hear! Looking forward to it :)

@arda-argmax
Collaborator

Can't wait to try!
