Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Export / import functions to / from a file #1642

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft

Conversation

awni
Copy link
Member

@awni awni commented Dec 3, 2024

Adds export_function and import_function so that we can save and load functions from a file. Makes it possible to use functions written in one language from another language (e.g. Python -> C++).

Basically works like so:

In Python:

# Note, the model parameters are saved in the export function
# An alternative is to make them inputs to forward
def forward(x):
    return model(x)

example_x = mx.zeros(shape=(batch_size, input_dim))

# Export to file using example input
mx.export_function("model.mlxfn", forward, example_x)

Then in C++, for example:

  auto example_x = random::uniform({batch_size, input_dim});

  // Import the function
  auto forward = import_function("model.mlxfn");

  // Call the imported function
  auto out = forward({example_x})[0];

Some notes on the implementation:

  • Reuses a lot of the compile infrastructure which simplifies things dramatically
  • The serialization of everything is mostly decoupled from the rest of the code and kept in the export.cpp
  • Serializing primitives that have member variables requires some way of accessing them. The API is not opinionated about this (so the primitive interface didn't change at all).. but the convention I'm using is to have a state which returns the data to save
  • Likely can use templates / preprocessor to reduce more boiler-plate from some of serialization code in export.cpp. But didn't want to obfuscate / over engineer it much yet until getting some input.

@angeloskath
Copy link
Member

This is massively cool. I 'll get to reviewing asap!

// - constants, which can be used directly
// - a load primitive which has no inputs and will become a constant
// after the first eval
if (!a.has_primitive() || is_load(a.primitive())) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is worth commenting on:

  • Previously if you loaded arrays from a file inside a compiled function then every call to the function would reload from the file.
  • Now only the first call to the function loads and after that the loaded arrays become constants in the tape

This seems better to me.. though perhaps that is debatable. It is also used by import_function which makes Load primitives for constants and so we get lazy loading even with import_function which is pretty nice.

Lmk thoughts.. I can switch it so compile doesn't force load (more flexible but more dangerous).

Comment on lines +21 to +22
std::function<std::vector<array>(const std::vector<array>&)> import_function(
std::string path);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another question: should import_function return metadata? I can see how it would be useful to get say the shapes and/or dtypes of the inputs, maybe the MLX version, etc in a dict of metadata. Can also wait and see and provide an overload / a return_metadata flag in the future.

@@ -2098,21 +2147,6 @@ class Tanh : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};

class Uniform : public UnaryPrimitive {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused 🤷‍♂️ ..

@awni awni force-pushed the export_import branch 2 times, most recently from f255768 to fd6520c Compare December 9, 2024 20:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants