How to continue training a fitted model? #154
I suspect the solution will involve some kind of callback. The checkpoint documentation seems to deal with recovering from a failure, but in my case there is no failure: each run of the concomitant model counts as a full, successful run. I just want to use the model weights to train a new model with a slightly different set of inputs.
I think the best way here is to create a wrapper model, e.g. as described in #113 (comment). In this case you train a model that wraps the 'base model' and re-uses its weights.
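A minimal sketch of that wrapper idea (hedged: it assumes the trained torch module from a previous run is reachable as fitted$model, and uses set_hparams() to hand it to the wrapper's constructor):

# Sketch only: wrap an already-trained module and keep training its weights.
# `fitted` is assumed to be the luz fitted object from a previous run, with
# its trained nn_module exposed as fitted$model.
base_module <- fitted$model

wrapper <- nn_module(
  initialize = function(base) {
    self$base <- base            # re-uses the trained parameters
  },
  forward = function(x) {
    x |> self$base()
  }
)

refit <- wrapper |>
  setup(
    loss = function(y_hat, y_true) nnf_cross_entropy(y_hat, y_true),
    optimizer = optim_adam
  ) |>
  set_hparams(base = base_module) |>   # constructor argument of the wrapper
  fit(train_ds, epochs = 5)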
@dfalbel what does the next comment (#113 (comment)) refer to?
I think the next comment refers to continuing training from a checkpoint, e.g. what's described here: https://mlverse.github.io/luz/articles/checkpoints.html#resuming-training-runs-that-crashed which seems to be what the OP initially wanted.
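For reference, a short sketch of that resume-from-checkpoint route (hedged: luz_callback_auto_resume() and its path argument reflect my reading of the linked article and may differ between luz versions); net and train_ds are the module and dataset used elsewhere in this thread:

# If a fit() run is interrupted, re-running the same code resumes from the
# state saved on disk instead of starting over.
autoresume <- luz_callback_auto_resume(path = "./luz_state.pt")

fitted <- net |>
  setup(
    loss = function(y_hat, y_true) nnf_cross_entropy(y_hat, y_true),
    optimizer = optim_adam
  ) |>
  fit(train_ds, epochs = 5, callbacks = list(autoresume))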
This is a snippet of my attempt at achieving this, but it does not work.

net <-
  if ("torch_model_cache" %in% ls(envir = .GlobalEnv)) {
    message("retraining a model")
    nn_module(
      initialize = function() {
        # reuse the previously fitted model stored in the global environment
        self$model <- torch_model_cache
      },
      forward = function(x) {
        x |> self$model()
      }
    )
  } else {
    message("training a fresh model")
    nn_module(
      initialize = function() {
        self$fc1 <- nn_linear(in_dimension, 10)
        self$fc2 <- nn_linear(10, out_dimension)
      },
      forward = function(x) {
        x |>
          self$fc1() |>
          nnf_relu() |>
          self$fc2() |>
          torch_squeeze()
      }
    )
  }

fitted <- net |>
  setup(
    loss = function(y_hat, y_true) nnf_cross_entropy(y_hat, y_true),
    optimizer = optim_adam
  ) |>
  # set_opt_hparams(weight_decay = 0.001) |>
  fit(train_ds, epochs = 5)

torch_model_cache <<- fitted

I get an error.
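One possible explanation (an assumption, since the error itself is not shown): torch_model_cache holds the luz fitted object, not a torch nn_module, so it cannot be called like a module inside forward(). If the trained module really is stored in the fitted object's model field, wrapping that instead might work:

# Hypothetical adjustment: wrap the underlying trained nn_module rather than
# the luz fitted object itself (assumes it is exposed as $model).
net <- nn_module(
  initialize = function() {
    self$model <- torch_model_cache$model
  },
  forward = function(x) {
    x |> self$model()
  }
)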
I will see if I can do it using the checkpoint concept. Maybe if I can get it to always use the same file name, instead of something related to the epoch number, then it can be done.
The following seems to work quite well:

resume_training_callback <- luz_callback(
  initialize = function(cache_file_name) {
    self$cache_file_name <- cache_file_name
  },
  on_fit_begin = function() {
    # load previously saved weights, if a cache file exists
    if (file.exists(self$cache_file_name)) {
      luz_load_model_weights(self$ctx, self$cache_file_name)
    }
  },
  on_fit_end = function() {
    # save the model weights so the next run can resume from them
    luz_save_model_weights(self$ctx, self$cache_file_name)
  }
)

net <-
  nn_module(
    initialize = function() {
      self$fc1 <- nn_linear(in_dimension, 10)
      self$fc2 <- nn_linear(10, out_dimension)
    },
    forward = function(x) {
      x |>
        self$fc1() |>
        nnf_relu() |>
        self$fc2() |>
        torch_squeeze()
    }
  )

fitted <- net |>
  setup(
    loss = function(y_hat, y_true) nnf_cross_entropy(y_hat, y_true),
    optimizer = optim_adam
  ) |>
  # set_opt_hparams(weight_decay = 0.001) |>
  fit(
    train_ds,
    epochs = 5,
    callbacks = list(resume_training_callback("model_weights.rds"))
  )

These callbacks are quite elegant. @dfalbel if you are interested, I can raise a PR to include this example in https://github.com/mlverse/luz/blob/main/vignettes/articles/checkpoints.Rmd.
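To illustrate how this resumes rather than restarts, a short usage sketch (new_train_ds is a hypothetical second dataset with the same input and output dimensions):

# Running fit() again with the same callback picks up the weights saved by
# on_fit_end, so training continues from the previous run instead of from a
# fresh initialization. `new_train_ds` is a placeholder dataset.
fitted2 <- net |>
  setup(
    loss = function(y_hat, y_true) nnf_cross_entropy(y_hat, y_true),
    optimizer = optim_adam
  ) |>
  fit(
    new_train_ds,
    epochs = 5,
    callbacks = list(resume_training_callback("model_weights.rds"))
  )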
Ohh, that looks really nice! It would be awesome to include in the docs!
I want to use torch to create a custom driver for a concomitant model for flexmix.
The way a flexmix model works is that the individual cluster models are fit, then the concomitant model is fit using that cluster fit, then the clusters are refit using the concomitant fit, and so on. Eventually this stabilizes and the whole model can be considered fitted.
I want to try to make the concomitant model stateful. Instead of refitting the concomitant model from scratch at each step, I would like to resume training from the previous concomitant step.
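As a purely illustrative outline of that alternation (fit_clusters(), fit_concomitant(), and converged() are placeholders, not flexmix's actual driver interface):

# Schematic EM-style loop: the concomitant fit should resume from its previous
# weights rather than start from scratch on every iteration.
concomitant_fit <- NULL   # no concomitant fit before the first iteration
repeat {
  cluster_fit     <- fit_clusters(data, concomitant_fit)
  concomitant_fit <- fit_concomitant(data, cluster_fit)
  if (converged(cluster_fit, concomitant_fit)) break
}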
If I have a fitted model as in the snippets above, how do I pick up from fitted and continue training?