
Load a model from a checkpoint. If the checkpoint path is a directory, the checkpoint is chosen by a condition: best validation accuracy, best validation loss, last epoch, or a specified epoch number.
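The selection rule can be illustrated with a small base-R sketch. It assumes checkpoint file names follow the "Ep.<epoch>-val_loss<loss>-val_acc<acc>.hdf5" pattern used in the Examples section. The helper pick_checkpoint() is hypothetical and is not how load_cp is implemented internally. It only mirrors the documented behaviour: cp_filter picks by last epoch, accuracy or loss, and ep_index, when given, takes priority.

```r
# Hypothetical helper (not part of deepG): choose a checkpoint by file name,
# assuming names like "Ep.007-val_loss11.07-val_acc0.6.hdf5".
pick_checkpoint <- function(files, cp_filter = "last_ep", ep_index = NULL) {
  base <- basename(files)
  # Parse epoch number, validation loss and validation accuracy from each name
  epoch    <- as.integer(sub("^Ep\\.([0-9]+)-.*$", "\\1", base))
  val_loss <- as.numeric(sub("^.*-val_loss([0-9.]+)-.*$", "\\1", base))
  val_acc  <- as.numeric(sub("^.*-val_acc([0-9.]+)\\.hdf5$", "\\1", base))

  if (!is.null(ep_index)) {
    # A specific epoch takes priority over cp_filter
    idx <- which(epoch == ep_index)
    if (length(idx) == 0) stop("No checkpoint found for epoch ", ep_index)
  } else {
    idx <- switch(cp_filter,
      last_ep = which.max(epoch),     # highest epoch number
      acc     = which.max(val_acc),   # best validation accuracy
      loss    = which.min(val_loss),  # lowest validation loss
      stop("Unknown cp_filter: ", cp_filter)
    )
  }
  files[idx[1]]
}

files <- c("Ep.007-val_loss11.07-val_acc0.6.hdf5",
           "Ep.019-val_loss8.74-val_acc0.7.hdf5",
           "Ep.025-val_loss0.03-val_acc0.8.hdf5")
pick_checkpoint(files, cp_filter = "loss")  # "Ep.025-val_loss0.03-val_acc0.8.hdf5"
pick_checkpoint(files, ep_index = 19)       # "Ep.019-val_loss8.74-val_acc0.7.hdf5"
```

With these three file names every filter selects the epoch-25 checkpoint, because it has both the highest accuracy and the lowest loss; passing ep_index = 19 overrides the filter.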

Usage

load_cp(
  cp_path,
  cp_filter = "last_ep",
  ep_index = NULL,
  compile = FALSE,
  learning_rate = 0.01,
  solver = "adam",
  re_compile = FALSE,
  loss = "categorical_crossentropy",
  add_custom_object = NULL,
  margin = 1,
  verbose = TRUE,
  mirrored_strategy = FALSE
)

Arguments

cp_path

A directory containing checkpoints, or a single checkpoint file. If a directory, the checkpoint is chosen based on cp_filter or ep_index.

cp_filter

Condition to choose checkpoint if cp_path is a directory. Either "acc" for best validation accuracy, "loss" for best validation loss or "last_ep" for last epoch.

ep_index

Epoch number of the checkpoint to load. If not NULL, it takes priority over cp_filter.

compile

Whether to load the model in compiled form.

learning_rate

Learning rate for the optimizer.

solver

Optimization method; options are "adam", "adagrad", "rmsprop" or "sgd".

re_compile

Whether to compile the model with the parameters given in learning_rate, solver and loss.

loss

Loss function. Only used if the model gets compiled.

add_custom_object

Named list of custom objects.

margin

Margin for contrastive loss, see loss_cl.

verbose

Whether to print the path of the chosen checkpoint.

mirrored_strategy

Whether to use a distributed mirrored strategy. If NULL, the mirrored strategy is used only if more than one GPU is available.
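The three accepted values of mirrored_strategy (TRUE, FALSE, NULL) can be summarized as a small decision function. This is a hypothetical sketch, not deepG code: resolve_mirrored_strategy() and its n_gpu input are assumptions for illustration, and how load_cp actually counts available GPUs is not shown.

```r
# Hypothetical helper (not part of deepG): resolve the mirrored_strategy
# argument, where NULL means "decide from the number of available GPUs".
resolve_mirrored_strategy <- function(mirrored_strategy, n_gpu) {
  if (is.null(mirrored_strategy)) {
    return(n_gpu > 1)  # distribute only when more than one GPU is present
  }
  isTRUE(mirrored_strategy)  # an explicit TRUE/FALSE is used as given
}

resolve_mirrored_strategy(NULL, n_gpu = 1)   # FALSE
resolve_mirrored_strategy(NULL, n_gpu = 2)   # TRUE
resolve_mirrored_strategy(FALSE, n_gpu = 4)  # FALSE
```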

Value

A keras model loaded from a checkpoint.

Examples

if (FALSE) { # reticulate::py_module_available("tensorflow")
model <- create_model_lstm_cnn(layer_lstm = 8)
checkpoint_folder <- tempfile()
dir.create(checkpoint_folder)
keras::save_model_hdf5(model, file.path(checkpoint_folder, 'Ep.007-val_loss11.07-val_acc0.6.hdf5'))
keras::save_model_hdf5(model, file.path(checkpoint_folder, 'Ep.019-val_loss8.74-val_acc0.7.hdf5'))
keras::save_model_hdf5(model, file.path(checkpoint_folder, 'Ep.025-val_loss0.03-val_acc0.8.hdf5'))
model <- load_cp(cp_path = checkpoint_folder, cp_filter = "last_ep")
}