
How to continue training a fitted model? #154

travis-leith opened this issue Jan 31, 2025 · 7 comments

@travis-leith

I want to use torch to create a custom driver for a concomitant model for flexmix.

The way a flexmix model works is that the individual cluster models are fit, then the concomitant model is fit using that cluster fit, then the cluster models are refit using the concomitant fit, and so on. Eventually this alternation stabilizes and the whole thing can be considered fitted.

I want to try to make the concomitant model stateful: instead of refitting the concomitant model from scratch at each step, I would like to resume training from the previous concomitant step.

If I have a fitted model as in

fitted <- net %>%
  setup(
    loss = function(y_hat, y_true) nnf_cross_entropy(y_hat, y_true),
    optimizer = optim_adam
  ) %>%
  fit(train_ds, epochs = 10)

How do I pick up from fitted and continue training?

@travis-leith

I suspect the solution will involve some kind of callback. The checkpoint documentation seems to deal with recovering from a failure, but in my case there is no failure: each run of the concomitant model is a complete, successful run. I just want to use the model weights to train a new model with a slightly different set of inputs.

@dfalbel

dfalbel commented Jan 31, 2025

I think the best way here is to create a wrapper model, e.g. as described in #113 (comment)

In this case you train a model that wraps the 'base model' and reuses its weights.
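
Roughly something like this (just a sketch, assuming fitted is the object returned by fit() above, so the trained module is available as fitted$model; the name wrapper_net is only illustrative):

library(torch)
library(luz)

# sketch: wrap the already-trained module so that a new fit() keeps
# updating its weights instead of starting from a fresh initialization
wrapper_net <- nn_module(
  initialize = function(base_model) {
    # registering the trained module here exposes its parameters
    # to the new optimizer
    self$base <- base_model
  },
  forward = function(x) {
    self$base(x)
  }
)

refitted <- wrapper_net |>
  setup(
    loss = function(y_hat, y_true) nnf_cross_entropy(y_hat, y_true),
    optimizer = optim_adam
  ) |>
  set_hparams(base_model = fitted$model) |>
  fit(train_ds, epochs = 10)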

@travis-leith

@dfalbel what does the next comment (#113 (comment)) refer to?

@dfalbel

dfalbel commented Feb 3, 2025

I think the next comment refers to continuing training from a checkpoint, e.g. what's described here: https://mlverse.github.io/luz/articles/checkpoints.html#resuming-training-runs-that-crashed. That seems to be what the OP initially wanted.
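
In luz that mostly amounts to adding luz_callback_auto_resume() to the fit() call, roughly like this (just a sketch; the state file path below is arbitrary):

# sketch: the callback snapshots training state to `path`, and if that file
# already exists, the next call to fit() resumes from it
fitted <- net |>
  setup(
    loss = function(y_hat, y_true) nnf_cross_entropy(y_hat, y_true),
    optimizer = optim_adam
  ) |>
  fit(
    train_ds,
    epochs = 10,
    callbacks = list(luz_callback_auto_resume(path = "./state.pt"))
  )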

@travis-leith

This is a snippet of my attempt at achieving this, but it does not work.

net <- 
    if ("torch_model_cache" %in% ls(envir = .GlobalEnv)) {
      message("retraining a model")
      nn_module(
        initialize = function() {
          self$model <- torch_model_cache
        },
        forward = function(x) {
          x |> self$model()
        }
      )
    } else {
      message("training a fresh model")
      nn_module(
        initialize = function() {
          self$fc1 <- nn_linear(in_dimension, 10)
          self$fc2 <- nn_linear(10, out_dimension)
        },
        forward = function(x) {
          x |> 
            self$fc1() |>
            nnf_relu() |>
            self$fc2() |>
            torch_squeeze()
        }
      )
    }
  
fitted <- net |>
  setup(
    loss = function(y_hat, y_true) nnf_cross_entropy(y_hat, y_true),
    optimizer = optim_adam
  ) |>
  # set_opt_hparams(weight_decay = 0.001) |>
  fit(train_ds, epochs = 5)

torch_model_cache <<- fitted

I get an error

Error in `value_error()`:
! Wrong parameters specification.

I will see if I can do it using the checkpoint concept. Maybe if I can get it to always use the same file name, instead of something related to the epoch number, then it can be done.
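
What I have in mind is roughly this (untested sketch; I'm assuming the path argument of luz_callback_model_checkpoint() accepts a plain file name, so every epoch overwrites the same file, and the saved weights would still need to be loaded back before the next run):

# untested sketch: save to one fixed file instead of per-epoch files
fitted <- net |>
  setup(
    loss = function(y_hat, y_true) nnf_cross_entropy(y_hat, y_true),
    optimizer = optim_adam
  ) |>
  fit(
    train_ds,
    epochs = 5,
    callbacks = list(luz_callback_model_checkpoint(path = "model_weights.pt"))
  )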

@travis-leith

The following seems to work quite well

resume_training_callback <- luz_callback(
  initialize = function(cache_file_name) {
    self$cache_file_name <- cache_file_name
  },
  on_fit_begin = function() {
    # load the weights saved by the previous run, if any
    if (file.exists(self$cache_file_name)) {
      luz_load_model_weights(self$ctx, self$cache_file_name)
    }
  },
  on_fit_end = function() {
    # save the model weights
    luz_save_model_weights(self$ctx, self$cache_file_name)
  }
)

net <- nn_module(
  initialize = function() {
    self$fc1 <- nn_linear(in_dimension, 10)
    self$fc2 <- nn_linear(10, out_dimension)
  },
  forward = function(x) {
    x |>
      self$fc1() |>
      nnf_relu() |>
      self$fc2() |>
      torch_squeeze()
  }
)

fitted <- net |>
  setup(
    loss = function(y_hat, y_true) nnf_cross_entropy(y_hat, y_true),
    optimizer = optim_adam
  ) |>
  # set_opt_hparams(weight_decay = 0.001) |>
  fit(
    train_ds,
    epochs = 5,
    callbacks = list(resume_training_callback("model_weights.rds"))
  )
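
With this in place, re-running the same pipeline loads the weights saved by the previous run in on_fit_begin, so each concomitant step continues from where the last one left off rather than starting from a fresh initialization.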

These callbacks are quite elegant. @dfalbel if you are interested, I can raise a PR to include this example in https://github.com/mlverse/luz/blob/main/vignettes/articles/checkpoints.Rmd.

@dfalbel

dfalbel commented Feb 4, 2025

Ohh, that looks really nice! It would be awesome to include this in the docs!
