Takes a model as input and removes all layers after the layer specified in the layer_name argument. Optionally adds dense layers on top of the pruned model. The model can have multiple output layers with separate loss/activation functions. You can freeze all weights of the pruned model by setting freeze_base_model = TRUE.

Usage

remove_add_layers(
  model = NULL,
  layer_name = NULL,
  dense_layers = NULL,
  shared_dense_layers = NULL,
  last_activation = list("softmax"),
  output_names = NULL,
  losses = NULL,
  verbose = TRUE,
  dropout = NULL,
  dropout_shared = NULL,
  freeze_base_model = FALSE,
  compile = FALSE,
  learning_rate = 0.001,
  solver = "adam",
  flatten = FALSE,
  global_pooling = NULL,
  model_seed = NULL,
  mixed_precision = FALSE,
  mirrored_strategy = NULL
)

Arguments

model

A keras model.

layer_name

Name of the last layer to keep from the old model.

dense_layers

List of vectors specifying the number of units for each dense layer. If the list has more than one entry, the model gets multiple output layers.

shared_dense_layers

Vector with the number of units for each shared dense layer. These layers are connected on top of the layer given in layer_name and are shared by all outputs, i.e. they come before the model splits into multiple output layers (see Examples). Don't use this if the model has just one output layer (use only dense_layers).

last_activation

List of activation functions, one for the last layer of each dense_layers entry. Each must be "softmax", "sigmoid" or "linear".

output_names

List of names for each output layer.

losses

List of loss functions, one for each output.

verbose

Boolean.

dropout

List of vectors with dropout rates for each new dense layer.

dropout_shared

Vector of dropout rates for the dense layers from shared_dense_layers.

freeze_base_model

Whether to freeze all weights before the new dense layers.

compile

Boolean, whether to compile the new model.

learning_rate

Learning rate if compile = TRUE; defaults to the learning rate of the old model.

solver

Optimization method, options are "adam", "adagrad", "rmsprop" or "sgd".

flatten

Whether to add a flatten layer before the new dense layers.

global_pooling

"max_ch_first" for global max pooling with channel first (keras docs), "max_ch_last" for global max pooling with channel last, "average_ch_first" for global average pooling with channel first, "average_ch_last" for global average pooling with channel last or NULL for no global pooling. "both_ch_first" or "both_ch_last" to combine average and max pooling. "all" for all 4 options at once.

model_seed

Set a seed for the model parameters in TensorFlow if not NULL.

mixed_precision

Whether to use mixed precision (https://www.tensorflow.org/guide/mixed_precision).

mirrored_strategy

Whether to use TensorFlow's distributed mirrored strategy. If NULL, the mirrored strategy is used only if more than one GPU is available.

Value

A keras model with layers added to and/or removed from the given base model.

Examples

if (FALSE) { # reticulate::py_module_available("tensorflow")
model_1 <- create_model_lstm_cnn(layer_lstm = c(64, 64),
                                 maxlen = 50,
                                 layer_dense = c(32, 4), 
                                 verbose = FALSE)
# get name of second to last layer 
num_layers <- length(model_1$get_config()$layers)
layer_name <- model_1$get_config()$layers[[num_layers-1]]$name
# add dense layer with multi outputs and separate loss/activation functions
model_2 <- remove_add_layers(model = model_1,
                             layer_name = layer_name,
                             dense_layers = list(c(32, 16, 1), c(8, 1), c(12, 5)),
                             losses = list("binary_crossentropy", "mae",
                                           "categorical_crossentropy"),
                             last_activation = list("sigmoid", "linear", "softmax"),
                             freeze_base_model = TRUE,
                             output_names = list("out_1_binary_classsification", 
                                                 "out_2_regression", 
                                                 "out_3_classification")
) 
}
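
The most common transfer-learning pattern keeps a single output: prune the base model, attach one new dense head, freeze the base weights and compile. A minimal sketch reusing model_1 and layer_name from above (the learning rate and layer sizes are illustrative, not defaults):

if (FALSE) { # reticulate::py_module_available("tensorflow")
# single output: prune model_1, add one dense head and freeze the base weights
model_3 <- remove_add_layers(model = model_1,
                             layer_name = layer_name,
                             dense_layers = list(c(16, 2)),
                             last_activation = list("softmax"),
                             losses = list("categorical_crossentropy"),
                             freeze_base_model = TRUE,
                             compile = TRUE,
                             learning_rate = 0.0005,
                             solver = "adam")
}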
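
As noted for shared_dense_layers, shared layers sit between the pruned base and the per-output heads. A sketch with two outputs sharing two dense layers (all layer sizes and dropout rates are illustrative):

if (FALSE) { # reticulate::py_module_available("tensorflow")
# two outputs that share two dense layers (with dropout) before splitting
model_4 <- remove_add_layers(model = model_1,
                             layer_name = layer_name,
                             shared_dense_layers = c(64, 32),
                             dropout_shared = c(0.2, 0.2),
                             dense_layers = list(c(16, 1), c(8, 3)),
                             last_activation = list("sigmoid", "softmax"),
                             losses = list("binary_crossentropy",
                                           "categorical_crossentropy"),
                             output_names = list("out_binary", "out_multi"),
                             freeze_base_model = TRUE)
}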
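
If the layer chosen in layer_name outputs a sequence (a 3D tensor), global_pooling can collapse the time dimension before the new dense layers. A sketch, assuming layer 1 of model_1 is the input layer and layer 2 is a sequence-returning LSTM layer (an assumption about the layer order, not a guarantee):

if (FALSE) { # reticulate::py_module_available("tensorflow")
# prune after a sequence-returning layer and pool over the time dimension;
# the layer index 2 is an assumption (layer 1 is typically the input layer)
seq_layer_name <- model_1$get_config()$layers[[2]]$name
model_5 <- remove_add_layers(model = model_1,
                             layer_name = seq_layer_name,
                             global_pooling = "max_ch_last",
                             dense_layers = list(c(8, 2)),
                             last_activation = list("softmax"))
}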