Memory optimization for loading weights for no_std mode #2871

Open
HerrMuellerluedenscheid opened this issue Mar 5, 2025 · 3 comments
Labels
enhancement Enhance existing features

Comments

@HerrMuellerluedenscheid

Hey folks,

Thanks for building a genius framework! From a recent issue, you probably remember that I tried running the SqueezeNet example on an ESP32. I switched to an ESP32-S3 with 8 MB PSRAM. After failing to run it there due to allocation failures, I started a discussion in the #esp-rs:matrix.org chat, which turned out to be super fruitful (BIG shoutout!).

A few key findings and questions that I will try to summarize from the thread:

1. Burn starts by taking the neural network, pushing it into a Vec, and then cloning it, which leads to 5 MB of RAM usage for data that is already readable in flash. More specifically, the first copy happens in the generated `squeezenet1.rs` (generated in `target//debug/build/squeezenet-burn-whatever/out/model`):
   ```rust
   impl<B: Backend> Model<B> {
       pub fn from_embedded(device: &B::Device) -> Self {
           let record = BinBytesRecorder::<HalfPrecisionSettings>::default()
               .load(EMBEDDED_STATES.to_vec(), device) // <-- here is an allocation that shouldn't be needed
               .expect("Should decode state successfully");
           Self::new(device).load_record(record)
       }
   }
   ```
2. In burn-core, `Recorder::load` clones the load arguments (for `BinBytesRecorder`, the whole byte buffer) again via `args.clone()`. Is that really necessary?
   ```rust
   fn load<R>(&self, args: Self::LoadArgs, device: &B::Device) -> Result<R, RecorderError>
   where
       R: Record<B>,
   {
       let item: BurnRecord<R::Item<Self::Settings>, B> =
           self.load_item(args.clone()).map_err(|err| { // <-- here
           // ... (rest of the function omitted)
   ```
3. Can Burn operate directly on `EMBEDDED_STATES`, i.e. without copying the model into RAM?

4. Ideally, Burn should split the model into "read-only" and "read-write" parts and use the read-only part as-is, i.e. flashed without being decoded.
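To make the cost of the `to_vec()` and `args.clone()` calls concrete, here is a minimal std-only sketch. `EMBEDDED_STATES` is a tiny stand-in for the generated static, and `load_copying` / `load_borrowed` are hypothetical helpers, not Burn APIs: the first mirrors the current copy-then-clone flow, the second simply borrows the static bytes in place.

```rust
// Hypothetical stand-in for the generated EMBEDDED_STATES: a static byte
// slice (on an MCU this is the data that already sits in flash).
static EMBEDDED_STATES: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8];

/// Mirrors the current flow: copy the static bytes into a Vec, then clone
/// that Vec. Both buffers are returned so they stay alive at the same time.
fn load_copying(bytes: &'static [u8]) -> (Vec<u8>, Vec<u8>) {
    let args = bytes.to_vec(); // heap copy #1 (from_embedded)
    let item = args.clone(); // heap copy #2 (Recorder::load)
    (args, item)
}

/// Hypothetical zero-copy alternative: hand out the static bytes directly.
fn load_borrowed(bytes: &'static [u8]) -> &'static [u8] {
    bytes
}

fn main() {
    let (a, b) = load_copying(EMBEDDED_STATES);
    let c = load_borrowed(EMBEDDED_STATES);
    // Two separate heap buffers, each as large as the weights.
    println!("copies: {} bytes on the heap", a.len() + b.len()); // prints "copies: 16 bytes on the heap"
    // The borrowed slice points at the original static data.
    println!("borrowed in place: {}", c.as_ptr() == EMBEDDED_STATES.as_ptr()); // prints "borrowed in place: true"
}
```

In this sketch the copying path keeps two full-size heap buffers alive at once, while the borrowed path allocates nothing.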

Apparently there is room for improvement when running models in no_std, which I guess would also benefit std builds.
Looking forward to hearing your thoughts.
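The suggestion to split the model into read-only and read-write parts could look roughly like this for a single dense layer. `DenseParams`, `Scratch`, `WEIGHTS`, and `BIAS` are hypothetical names, not Burn types. The idea is that parameters are only ever borrowed as `&'static` data (which, depending on the target and linker, may stay in flash), and only the activation buffer needs RAM.

```rust
/// Read-only part: borrowed parameters that never need to be copied.
struct DenseParams {
    weights: &'static [f32], // row-major, out_dim * in_dim
    bias: &'static [f32],    // out_dim
}

/// Read-write part: the only memory that has to live in RAM.
struct Scratch {
    output: Vec<f32>,
}

// Hypothetical 2x3 weight matrix and bias, stored as immutable statics.
static WEIGHTS: [f32; 6] = [1.0, 0.0, 0.0, 0.0, 1.0, 1.0];
static BIAS: [f32; 2] = [0.5, -1.0];

impl DenseParams {
    /// Computes output = weights * input + bias, writing into the scratch buffer.
    fn forward(&self, input: &[f32], scratch: &mut Scratch) {
        let in_dim = input.len();
        debug_assert_eq!(self.weights.len(), scratch.output.len() * in_dim);
        for (o, out) in scratch.output.iter_mut().enumerate() {
            let row = &self.weights[o * in_dim..(o + 1) * in_dim];
            *out = row.iter().zip(input).map(|(w, x)| w * x).sum::<f32>() + self.bias[o];
        }
    }
}

fn main() {
    let params = DenseParams { weights: &WEIGHTS, bias: &BIAS };
    let mut scratch = Scratch { output: vec![0.0; 2] };
    params.forward(&[1.0, 2.0, 3.0], &mut scratch);
    println!("{:?}", scratch.output); // prints [1.5, 4.0]
}
```

Because `DenseParams` only holds shared references, any number of inferences can reuse it while the only mutable state is the small `Scratch` buffer.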

@antimora
Collaborator

antimora commented Mar 5, 2025

Yes, I agree there is a lot of room for improvement in loading weights. We didn't focus on this initially, but I think now is the perfect time. Since you're in the weeds of it, it would be great if you tried using more efficient Rust APIs that consume the existing preallocated memory without duplicating it (even temporarily). I know there are some. We will review your PR.
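As one example of such APIs, `std::borrow::Cow` (also available as `alloc::borrow::Cow` in no_std builds with `alloc`) lets a loader accept either borrowed embedded bytes or an owned `Vec<u8>` without copying in either case, and `chunks_exact` plus `f32::from_le_bytes` can decode values lazily from a borrowed slice. The `load_item` and `f32_le_iter` functions below are hypothetical stand-ins, not Burn's recorder API; a real fix would also depend on whether the deserializer can borrow from its input.

```rust
use std::borrow::Cow;

/// Hypothetical stand-in for a recorder's `load_item`: it takes the bytes
/// either borrowed or owned and hands them back without copying.
/// A real implementation would deserialize the record here.
fn load_item(args: Cow<'static, [u8]>) -> Cow<'static, [u8]> {
    args
}

/// Decode little-endian f32 values lazily from a borrowed byte slice,
/// without collecting the raw bytes into an intermediate buffer.
fn f32_le_iter(bytes: &[u8]) -> impl Iterator<Item = f32> + '_ {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
}

// Stand-in for embedded weights: 1.0f32 and 2.0f32 in little-endian order.
static EMBEDDED_STATES: [u8; 8] = [0, 0, 128, 63, 0, 0, 0, 64];

fn main() {
    // Borrowed: still points at the static data, nothing is allocated.
    let borrowed = load_item(Cow::Borrowed(&EMBEDDED_STATES[..]));
    assert_eq!(borrowed.as_ptr(), EMBEDDED_STATES.as_ptr());

    // Owned: the Vec is moved in and back out, so its heap buffer is reused.
    let owned_vec = EMBEDDED_STATES.to_vec();
    let ptr = owned_vec.as_ptr();
    let owned = load_item(Cow::Owned(owned_vec));
    assert_eq!(owned.as_ptr(), ptr);

    let values: Vec<f32> = f32_le_iter(&borrowed).collect();
    println!("{values:?}"); // prints [1.0, 2.0]
}
```

Passing ownership by value (a move) instead of calling `.clone()` keeps the same heap buffer, and the borrowed variant avoids the heap entirely.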

@antimora antimora added the enhancement Enhance existing features label Mar 5, 2025
@antimora antimora changed the title optimize burn for no_std Memory optimization for loading weights for no_std mode Mar 5, 2025
@BjornTheProgrammer
Contributor

I have exactly the same issues and noticed the same things on the Raspberry Pi Pico. I would be willing to tackle this issue together with a team I'm working with. Is there an active PR for this, or should I create one?

@HerrMuellerluedenscheid
Author

HerrMuellerluedenscheid commented Mar 9, 2025

@BjornTheProgrammer I just pushed some experiments to #2881. On my ESP32-S3 it no longer panics due to allocation failures.
I'm getting this error now instead:
/crates/burn-core/src/record/memory.rs:39:85: called `Result::unwrap()` on an `Err` value: InvalidIntegerType { expected: U32, found: Reserved }
But I already consider this a success with regard to memory. So feel free to take a look, modify it, and be inspired :) I would love to have this working on some MCUs.
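The error name matches the `InvalidIntegerType` variant of bincode 2's `DecodeError`, so one possible reading, assuming the record bytes are decoded with bincode's variable-length integer encoding, is that the decoder hit a byte it cannot interpret as an integer prefix. In that encoding, values up to 250 take one byte, the bytes 251 to 254 announce a following u16/u32/u64/u128, and 255 is unused. The hypothetical `read_varint_u32` below mimics that scheme (it is not bincode's code) and shows that a 0xFF byte where a length is expected lands in the "reserved" case. Such a byte could come from reading at the wrong offset or from data that is not the serialized record, but that is a guess, not a diagnosis.

```rust
// Discriminant bytes of the varint scheme described above.
const SINGLE_BYTE_MAX: u8 = 250;
const U16_BYTE: u8 = 251;
const U32_BYTE: u8 = 252;

#[derive(Debug, PartialEq)]
enum VarintError {
    UnexpectedEnd,
    TooWide(u8), // 253 or 254: a u64/u128 where a u32 was expected
    Reserved,    // 255: not a valid discriminant at all
}

/// Hypothetical little-endian varint reader for a u32. Returns the value
/// and the number of bytes consumed.
fn read_varint_u32(buf: &[u8]) -> Result<(u32, usize), VarintError> {
    let (&tag, rest) = buf.split_first().ok_or(VarintError::UnexpectedEnd)?;
    match tag {
        0..=SINGLE_BYTE_MAX => Ok((tag as u32, 1)),
        U16_BYTE => {
            let b = rest.get(..2).ok_or(VarintError::UnexpectedEnd)?;
            Ok((u16::from_le_bytes([b[0], b[1]]) as u32, 3))
        }
        U32_BYTE => {
            let b = rest.get(..4).ok_or(VarintError::UnexpectedEnd)?;
            Ok((u32::from_le_bytes([b[0], b[1], b[2], b[3]]), 5))
        }
        253 | 254 => Err(VarintError::TooWide(tag)),
        _ => Err(VarintError::Reserved),
    }
}

fn main() {
    assert_eq!(read_varint_u32(&[7]), Ok((7, 1)));
    assert_eq!(read_varint_u32(&[251, 0x34, 0x12]), Ok((0x1234, 3)));
    // A 0xFF byte where a length prefix is expected yields the reserved case.
    assert_eq!(read_varint_u32(&[0xFF, 0, 0, 0]), Err(VarintError::Reserved));
    println!("ok");
}
```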
