
Train a neural network on genomic data. Data can be fasta/fastq files, rds files or a prepared data set. If the data is given as a collection of fasta, fastq or rds files, the function will create a data generator that extracts training and validation batches from the files. The function includes several options to determine the sampling strategy of the generator and the preprocessing of the data. Training progress can be visualized in TensorBoard. Model weights can be stored during training using checkpoints.

Usage

train_model(
  model = NULL,
  dataset = NULL,
  dataset_val = NULL,
  train_val_ratio = 0.2,
  run_name = "run_1",
  initial_epoch = 0,
  class_weight = NULL,
  print_scores = TRUE,
  epochs = 10,
  max_queue_size = 100,
  steps_per_epoch = 1000,
  path_checkpoint = NULL,
  path_tensorboard = NULL,
  path_log = NULL,
  save_best_only = NULL,
  save_weights_only = FALSE,
  tb_images = FALSE,
  path_file_log = NULL,
  reset_states = FALSE,
  early_stopping_time = NULL,
  validation_only_after_training = FALSE,
  train_val_split_csv = NULL,
  reduce_lr_on_plateau = TRUE,
  lr_plateau_factor = 0.9,
  patience = 20,
  cooldown = 1,
  model_card = NULL,
  callback_list = NULL,
  train_type = "label_folder",
  path = NULL,
  path_val = NULL,
  batch_size = 64,
  step = NULL,
  shuffle_file_order = TRUE,
  vocabulary = c("a", "c", "g", "t"),
  format = "fasta",
  ambiguous_nuc = "zero",
  seed = c(1234, 4321),
  file_limit = NULL,
  use_coverage = NULL,
  set_learning = NULL,
  proportion_entries = NULL,
  sample_by_file_size = FALSE,
  n_gram = NULL,
  n_gram_stride = 1,
  masked_lm = NULL,
  random_sampling = FALSE,
  add_noise = NULL,
  return_int = FALSE,
  maxlen = NULL,
  reverse_complement = FALSE,
  reverse_complement_encoding = FALSE,
  output_format = "target_right",
  proportion_per_seq = NULL,
  read_data = FALSE,
  use_quality_score = FALSE,
  padding = FALSE,
  concat_seq = NULL,
  target_len = 1,
  skip_amb_nuc = NULL,
  max_samples = NULL,
  added_label_path = NULL,
  add_input_as_seq = NULL,
  target_from_csv = NULL,
  target_split = NULL,
  shuffle_input = TRUE,
  vocabulary_label = NULL,
  delete_used_files = FALSE,
  reshape_xy = NULL,
  return_gen = FALSE
)

Arguments

model

A keras model.

dataset

List of training data holding training samples in RAM instead of using a generator. Should be a list with two entries called "X" and "Y".

dataset_val

List of validation data. Should have two entries called "X" and "Y".
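The in-memory format can be sketched in base R. This is a minimal sketch under stated assumptions: X is assumed to be an array of shape (samples, maxlen, length(vocabulary)) and Y a one-hot matrix of shape (samples, number of classes), following the one-hot layout described for the generator. The data are random and the 80/20 hold-out split is only illustrative, not necessarily how train_model splits dataset internally.

```r
set.seed(1)
n <- 100          # number of samples
maxlen <- 5       # sequence length
vocab_size <- 4   # length(vocabulary)
n_classes <- 2    # length(vocabulary_label)

# X[sample, position, nucleotide]: one-hot encoded random sequences.
nuc_idx <- matrix(sample(vocab_size, n * maxlen, replace = TRUE), nrow = n)
X <- array(0, dim = c(n, maxlen, vocab_size))
for (i in seq_len(n)) {
  X[cbind(i, seq_len(maxlen), nuc_idx[i, ])] <- 1
}

# Y[sample, class]: one-hot encoded random class labels.
labels <- sample(n_classes, n, replace = TRUE)
Y <- matrix(0, nrow = n, ncol = n_classes)
Y[cbind(seq_len(n), labels)] <- 1

# Hold out 20% of the samples as validation data.
val_idx <- sample(n, size = round(0.2 * n))
dataset <- list(X = X[-val_idx, , , drop = FALSE], Y = Y[-val_idx, , drop = FALSE])
dataset_val <- list(X = X[val_idx, , , drop = FALSE], Y = Y[val_idx, , drop = FALSE])
```

The two lists could then be passed as train_model(model = model, dataset = dataset, dataset_val = dataset_val, ...), with a model whose input and output shapes match X and Y.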

train_val_ratio

When training from a generator, defines the number of validation batches as a fraction of the number of training batches, i.e. one validation iteration processes batch_size * steps_per_epoch * train_val_ratio samples. If you use dataset instead of a generator and dataset_val is NULL, dataset is split into training and validation data.
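With the default values of batch_size, steps_per_epoch and train_val_ratio, the formula above works out as follows. This is plain arithmetic on the documented formula, not a call into the package.

```r
batch_size <- 64
steps_per_epoch <- 1000
train_val_ratio <- 0.2

# Batches and samples processed in one validation iteration (generator case).
val_batches <- steps_per_epoch * train_val_ratio               # 200
val_samples <- batch_size * steps_per_epoch * train_val_ratio  # 12800
```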

run_name

Name of the run. The name will be used to identify output from callbacks. If NULL, the date will be used as run name. If the name is already present, "_2" will be appended to it, or, if the name ends with "_x" where x is some integer, the suffix will be replaced by "_{x+1}".
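The renaming rule can be sketched as follows. next_run_name() is a hypothetical helper, not part of the package; repeating the rule until an unused name is found is an assumption, and the date default for run_name = NULL is not covered.

```r
# Hypothetical helper mirroring the documented renaming rule.
next_run_name <- function(run_name, existing) {
  if (!(run_name %in% existing)) return(run_name)
  if (grepl("_[0-9]+$", run_name)) {
    # Name ends with "_x": replace the suffix by "_{x+1}".
    x <- as.integer(sub("^.*_([0-9]+)$", "\\1", run_name))
    candidate <- paste0(sub("_[0-9]+$", "", run_name), "_", x + 1)
  } else {
    # Otherwise append "_2".
    candidate <- paste0(run_name, "_2")
  }
  # Assumption: keep applying the rule until the name is unused.
  next_run_name(candidate, existing)
}

next_run_name("run_1", existing = c("run_1", "run_2"))  # "run_3"
next_run_name("bert", existing = "bert")                # "bert_2"
```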

initial_epoch

Epoch at which to start training. Note that the network will run for (epochs - initial_epoch) rounds and not epochs rounds.

class_weight

List of weights for the output classes. The order should correspond to vocabulary_label. You can use the get_class_weights function to estimate class weights:

class_weights <- get_class_weights(path = path, train_type = train_type)

If train_type = "label_csv", you also need to add the path to the csv file:

class_weights <- get_class_weights(path = path, train_type = train_type, csv_path = target_from_csv)
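A common heuristic for class weights is inverse class frequency, so that every class contributes the same total weight to the loss. This sketch is not necessarily what get_class_weights computes, and the class counts are made up for illustration.

```r
# Made-up number of training samples per class, ordered like vocabulary_label.
counts <- c(label_1 = 300, label_2 = 100)

# Inverse-frequency weighting: weight_c = total / (number of classes * count_c).
weights <- sum(counts) / (length(counts) * counts)
class_weight <- as.list(weights)  # label_1 = 2/3, label_2 = 2
```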

print_scores

Whether to print train/validation scores during training.

epochs

Number of training epochs.

max_queue_size

Maximum size for the generator queue.

steps_per_epoch

Number of training batches per epoch.

path_checkpoint

Path to the checkpoints folder or NULL. If NULL, checkpoints are not stored.

path_tensorboard

Path to the TensorBoard directory or NULL. If NULL, training is not tracked in TensorBoard.

path_log

Path to a directory for writing training scores. The file name is run_name + ".csv". No output is written if NULL.

save_best_only

Only save a model that improved on some score. Not applied if NULL. Otherwise, must be a list containing either monitor or save_freq (only one of the two options can be used). monitor specifies which metric to use; save_freq is an integer specifying how often to store a checkpoint (in epochs).

save_weights_only

Whether to save weights only.

tb_images

Whether to show custom images (confusion matrix) in the TensorBoard "IMAGES" tab.

path_file_log

If a path is specified, write the names of the files used for training to this csv file.

reset_states

Whether to reset the hidden states of RNN layers at every new input file and before/after validation.

early_stopping_time

Time in seconds after which to stop training.

validation_only_after_training

Whether to skip validation during training and only do one validation iteration after training.

train_val_split_csv

A csv file specifying the train/validation split. The csv file should contain one column named "file" and one column named "type". The "file" column contains the names of fasta/fastq files and the "type" column specifies whether a file is used for training or validation. Entries in "type" must be "train" or "val"; files with any other entry are used for neither. The path and path_val arguments should be the same. Not implemented for train_type = "label_folder".
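A split file with the required columns can be written from base R. The file names are hypothetical, and whether the "file" column should hold bare file names or full paths is an assumption here (bare names are used).

```r
# Hypothetical file names; path and path_val would point to the same location.
split_df <- data.frame(
  file = c("genome_1.fasta", "genome_2.fasta", "genome_3.fasta", "genome_4.fasta"),
  type = c("train", "train", "val", "test")  # "test" is neither, so the file is unused
)
csv_path <- tempfile(fileext = ".csv")
write.csv(split_df, csv_path, row.names = FALSE)

# Files that would end up in the training and validation sets.
split_read <- read.csv(csv_path)
train_files <- as.character(split_read$file[split_read$type == "train"])
val_files <- as.character(split_read$file[split_read$type == "val"])
```

csv_path would then be passed as train_val_split_csv.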

reduce_lr_on_plateau

Whether to use a learning rate scheduler that reduces the learning rate when a plateau is reached.

lr_plateau_factor

Factor by which the learning rate is multiplied when a plateau is reached.

patience

Number of epochs to wait for a decrease in validation loss before reducing the learning rate.

cooldown

Number of epochs after a learning rate reduction during which the learning rate is not changed.
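To see how lr_plateau_factor, patience and cooldown interact, here is a simplified simulation modeled on the logic of the Keras ReduceLROnPlateau callback (min_delta and min_lr are ignored). It is meant for intuition only and is not the callback train_model actually builds.

```r
# Simplified plateau schedule; returns the learning rate after each epoch.
simulate_plateau_lr <- function(val_loss, lr = 1e-3, factor = 0.9,
                                patience = 20, cooldown = 1) {
  best <- Inf
  wait <- 0
  cooldown_counter <- 0
  lrs <- numeric(length(val_loss))
  for (epoch in seq_along(val_loss)) {
    if (cooldown_counter > 0) {
      # Cooling down after a reduction: count down and reset the wait counter.
      cooldown_counter <- cooldown_counter - 1
      wait <- 0
    }
    if (val_loss[epoch] < best) {
      best <- val_loss[epoch]  # improvement resets the patience counter
      wait <- 0
    } else if (cooldown_counter == 0) {
      wait <- wait + 1
      if (wait >= patience) {
        lr <- lr * factor       # reduce the learning rate
        cooldown_counter <- cooldown
        wait <- 0
      }
    }
    lrs[epoch] <- lr
  }
  lrs
}

simulate_plateau_lr(c(1, 0.9, 0.9, 0.9, 0.9, 0.9),
                    lr = 1, factor = 0.5, patience = 2, cooldown = 1)
# 1.00 1.00 1.00 0.50 0.50 0.25
```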

model_card

List of arguments describing the training parameters of the run. Must contain at least an entry path_model_card, i.e. the directory where the parameters are stored. The list can contain additional (optional) arguments, for example model_card = list(path_model_card = "/path/to/logs", description = "transfer learning with BERT model on virus data", ...)

callback_list

Additional callbacks to add to the keras::fit call.

train_type

Either "lm", "lm_rds", "masked_lm" for language model; "label_header", "label_folder", "label_csv", "label_rds" for classification or "dummy_gen".

  • Language model is trained to predict character(s) in a sequence.

  • "label_header"/"label_folder"/"label_csv" are trained to predict a corresponding class given a sequence as input.

  • If "label_header", class will be read from fasta headers.

  • If "label_folder", class will be read from folder, i.e. all files in one folder must belong to the same class.

  • If "label_csv", targets are read from a csv file. This file should have one column named "file"; the targets are the remaining entries (all columns except "file") in the row of the current file. Example: if we are currently working with a file called "a.fasta" and the corresponding label is "label_1", the csv file should contain a row like this:

    file        label_1   label_2
    "a.fasta"   1         0

  • If "label_rds", the generator will iterate over a set of .rds files, each containing a list of input and target tensors. Not implemented for models with multiple inputs.

  • If "lm_rds", the generator will iterate over a set of .rds files and split each tensor according to the target_len argument (targets are the last target_len nucleotides of each sequence).

  • If "dummy_gen", the generator creates random data once and repeatedly feeds it to the model.

  • If "masked_lm", the generator masks some parts of the input. See the masked_lm argument for details.
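For train_type = "label_csv" (and likewise for target_from_csv), the target file and the per-file lookup can be sketched as below. The file names are hypothetical, and get_targets() is an illustrative helper, not a package function.

```r
# Target csv in the layout described for "label_csv" (hypothetical file names).
target_df <- data.frame(file = c("a.fasta", "b.fasta"),
                        label_1 = c(1, 0),
                        label_2 = c(0, 1))
csv_path <- tempfile(fileext = ".csv")
write.csv(target_df, csv_path, row.names = FALSE)

# Hypothetical helper: return the targets stored in the row of one file.
get_targets <- function(csv_path, file_name) {
  df <- read.csv(csv_path)
  row <- df[df$file == file_name, setdiff(names(df), "file"), drop = FALSE]
  unlist(row)
}

get_targets(csv_path, "a.fasta")  # label_1 = 1, label_2 = 0
```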

path

Path to training data. If train_type is label_folder, should be a vector or list where each entry corresponds to a class (list elements can be directories and/or individual files). If train_type is not label_folder, can be a single directory or file or a list of directories and/or files.

path_val

Path to validation data. See path argument for details.

batch_size

Number of samples used for one network update.

step

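Frequency of sampling steps, i.e. the distance (in nucleotides) between the start positions of two consecutive samples taken from a sequence.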
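Assuming step is the shift between the start positions of consecutive samples (so samples overlap when step < maxlen), the sample windows of one sequence can be computed as below. sample_starts() is a hypothetical helper and ignores details such as the extra target nucleotides needed for language models.

```r
# Hypothetical helper: start positions of all full-length samples.
sample_starts <- function(seq_length, maxlen, step) {
  if (seq_length < maxlen) return(integer(0))
  seq(1, seq_length - maxlen + 1, by = step)
}

sequence <- "ACGTACGTACGTACGTACGT"  # length 20
starts <- sample_starts(nchar(sequence), maxlen = 5, step = 5)  # 1 6 11 16
windows <- substring(sequence, starts, starts + 4)
# "ACGTA" "CGTAC" "GTACG" "TACGT"
```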

shuffle_file_order

Boolean, whether to go through files sequentially or shuffle beforehand.

vocabulary

Vector of allowed characters. Characters outside vocabulary get encoded as specified in ambiguous_nuc.

format

File format: "fasta", "fastq", "rds", or "fasta.tar.gz"/"fastq.tar.gz" for tar.gz files.

ambiguous_nuc

How to handle nucleotides outside vocabulary, either "zero", "discard", "empirical" or "equal".

  • If "zero", input gets encoded as zero vector.

  • If "equal", input is repetition of 1/length(vocabulary).

  • If "discard", samples containing nucleotides outside vocabulary get discarded.

  • If "empirical", use nucleotide distribution of current file.
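The four ambiguous_nuc options can be illustrated with a small one-hot encoder. encode_seq() is a hypothetical sketch, not the package encoder; in particular, "empirical" here uses the distribution of the current sequence rather than of the whole file.

```r
# Hypothetical one-hot encoder illustrating the ambiguous_nuc options.
encode_seq <- function(sequence, vocabulary = c("a", "c", "g", "t"),
                       ambiguous_nuc = c("zero", "equal", "empirical", "discard")) {
  ambiguous_nuc <- match.arg(ambiguous_nuc)
  chars <- strsplit(tolower(sequence), "")[[1]]
  idx <- match(chars, vocabulary)  # NA for characters outside vocabulary
  if (ambiguous_nuc == "discard" && anyNA(idx)) return(NULL)
  x <- matrix(0, nrow = length(chars), ncol = length(vocabulary),
              dimnames = list(NULL, vocabulary))
  known <- which(!is.na(idx))
  x[cbind(known, idx[known])] <- 1
  amb <- which(is.na(idx))
  if (length(amb) > 0) {
    k <- length(vocabulary)
    fill <- switch(ambiguous_nuc,
                   zero = rep(0, k),
                   equal = rep(1 / k, k),
                   # nucleotide frequencies of the known positions (sequence-level)
                   empirical = colSums(x) / max(length(known), 1))
    x[amb, ] <- matrix(fill, nrow = length(amb), ncol = k, byrow = TRUE)
  }
  x
}

encode_seq("ACNT", ambiguous_nuc = "equal")[3, ]      # 0.25 0.25 0.25 0.25
encode_seq("AANT", ambiguous_nuc = "empirical")[3, ]  # 2/3 0 0 1/3
```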

seed

Sets seed for reproducible results.

file_limit

Integer or NULL. If integer, use only specified number of randomly sampled files for training. Ignored if greater than number of files in path.

use_coverage

Integer or NULL. If not NULL, use coverage as encoding rather than one-hot encoding, and normalize it. Coverage information must be contained in the fasta header: the header must contain a string "cov_n", where n is some integer.
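The required header tag can be parsed with base R regular expressions. get_coverage() is a hypothetical helper; how the package turns the coverage value into the encoding and how it normalizes it is not shown here.

```r
# Hypothetical helper: extract n from a "cov_n" tag in a fasta header.
get_coverage <- function(header) {
  tag <- regmatches(header, regexpr("cov_[0-9]+", header))
  if (length(tag) == 0) return(NA_integer_)
  as.integer(sub("cov_", "", tag))
}

get_coverage(">contig_7 len_1500 cov_12")  # 12
get_coverage(">contig_8")                  # NA
```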

set_learning

Use this to assign one label to a set of samples. Only implemented for train_type = "label_folder". Input is a list with the following parameters:

  • samples_per_target: how many samples to use for one target.

  • maxlen: length of one sample.

  • reshape_mode: "time_dist", "multi_input" or "concat".

    • If reshape_mode is "multi_input", generator will produce samples_per_target separate inputs, each of length maxlen (model should have samples_per_target input layers).

    • If reshape_mode is "time_dist", generator will produce a 4D input array. The dimensions correspond to (batch_size, samples_per_target, maxlen, length(vocabulary)).

    • If reshape_mode is "concat", generator will concatenate samples_per_target sequences of length maxlen to one long sequence.

  • If reshape_mode is "concat", there is an additional buffer_len argument. If buffer_len is an integer, the subsequences are separated by buffer_len rows. The input length is then (maxlen * samples_per_target) + buffer_len * (samples_per_target - 1).
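The three reshape modes can be sketched for one batch entry with random one-hot samples. This assumes that the buffer rows are all-zero rows, and the batch dimension is omitted, so time_dist_input has shape (samples_per_target, maxlen, length(vocabulary)). It is an illustration, not the generator's code.

```r
samples_per_target <- 3
maxlen <- 5
vocab_size <- 4
buffer_len <- 2

# Toy one-hot samples (maxlen x vocab_size) that share one target.
set.seed(42)
one_hot <- function(n, k) {
  m <- matrix(0, nrow = n, ncol = k)
  m[cbind(seq_len(n), sample(k, n, replace = TRUE))] <- 1
  m
}
samples <- lapply(seq_len(samples_per_target), function(i) one_hot(maxlen, vocab_size))

# "multi_input": each sample is fed to its own input layer.
multi_input <- samples

# "time_dist": stack samples along a new leading axis.
time_dist_input <- aperm(array(unlist(samples),
                               dim = c(maxlen, vocab_size, samples_per_target)),
                         c(3, 1, 2))

# "concat": join samples row-wise, separated by buffer_len zero rows (assumed).
buffer <- matrix(0, nrow = buffer_len, ncol = vocab_size)
concat_input <- samples[[1]]
for (i in seq_len(samples_per_target)[-1]) {
  concat_input <- rbind(concat_input, buffer, samples[[i]])
}
nrow(concat_input)  # (5 * 3) + 2 * (3 - 1) = 19
```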

proportion_entries

Proportion of fasta entries to keep. For example, if fasta file has 50 entries and proportion_entries = 0.1, will randomly select 5 entries.

sample_by_file_size

Whether to sample the next file weighted by file size (bigger files are more likely to be chosen).

n_gram

Integer. Encode the target not nucleotide-wise but combine n nucleotides at once. For example, for n = 2, "AA" -> (1, 0, ..., 0), "AC" -> (0, 1, 0, ..., 0), "TT" -> (0, ..., 0, 1), where the one-hot vectors have length length(vocabulary)^n.

n_gram_stride

Step size for n-gram encoding. For AACCGGTT with n_gram = 4 and n_gram_stride = 2, generator encodes (AACC), (CCGG), (GGTT); for n_gram_stride = 4 generator encodes (AACC), (GGTT).
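The n-gram splitting and the position of the 1 in each one-hot vector can be sketched as below. Both helpers are hypothetical. Reading an n-gram as a base-length(vocabulary) number (first character most significant) is an assumption, although it reproduces the ordering in the n_gram example (AA first, AC second, TT last).

```r
# Hypothetical helper: split a sequence into n-grams with a given stride.
ngram_strings <- function(sequence, n, stride = 1) {
  starts <- seq(1, nchar(sequence) - n + 1, by = stride)
  substring(sequence, starts, starts + n - 1)
}

# Hypothetical helper: 1-based position of the 1 in a one-hot vector
# of length length(vocabulary)^n.
ngram_index <- function(ngram, vocabulary = c("A", "C", "G", "T")) {
  digits <- match(strsplit(ngram, "")[[1]], vocabulary) - 1
  k <- length(vocabulary)
  sum(digits * k^(rev(seq_along(digits)) - 1)) + 1
}

ngram_strings("AACCGGTT", n = 4, stride = 2)  # "AACC" "CCGG" "GGTT"
ngram_strings("AACCGGTT", n = 4, stride = 4)  # "AACC" "GGTT"
sapply(c("AA", "AC", "TT"), ngram_index)      # AA = 1, AC = 2, TT = 16
```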

masked_lm

If not NULL, input and target are equal, except that some parts of the input are masked or replaced by random tokens. Must be a list with the following arguments:

  • mask_rate: Rate of input to mask (rate of input to replace with mask token).

  • random_rate: Rate of input to set to random token.

  • identity_rate: Rate of input where sample weights are applied but input and output are identical.

  • include_sw: Whether to include sample weights.

  • block_len (optional): Masked/random/identity regions appear in blocks of size block_len.
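A minimal sketch of the masking scheme on an integer-encoded sequence is shown below. It assumes that the mask token is the index length(vocabulary) + 1 and that sample weights are 1 at masked, random and identity positions and 0 elsewhere. block_len is not handled. mask_sequence() is hypothetical and not the package's implementation.

```r
# Hypothetical masking for one integer-encoded sequence x (values 1..vocab_size).
mask_sequence <- function(x, vocab_size, mask_rate = 0.1, random_rate = 0.05,
                          identity_rate = 0.05, mask_token = vocab_size + 1) {
  n <- length(x)
  n_mask <- round(n * mask_rate)
  n_random <- round(n * random_rate)
  n_identity <- round(n * identity_rate)
  pos <- sample(n, n_mask + n_random + n_identity)  # distinct positions
  mask_pos <- pos[seq_len(n_mask)]
  random_pos <- pos[n_mask + seq_len(n_random)]
  # Remaining positions in pos are "identity": unchanged input, but weighted.
  input <- x
  input[mask_pos] <- mask_token
  input[random_pos] <- sample(vocab_size, n_random, replace = TRUE)
  sample_weight <- numeric(n)
  sample_weight[pos] <- 1
  list(input = input, target = x, sample_weight = sample_weight)
}

set.seed(1)
x <- sample(4, 100, replace = TRUE)
out <- mask_sequence(x, vocab_size = 4)
```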

random_sampling

Whether samples should be taken from random positions when using the max_samples argument. If FALSE, the random samples are taken from a consecutive subsequence.

add_noise

NULL or list of arguments. If not NULL, the list must contain the argument noise_type, which can be "normal" or "uniform". Optional arguments are sd and mean if noise_type is "normal" (default is sd = 1 and mean = 0), or min and max if noise_type is "uniform" (default is min = 0, max = 1).
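The add_noise list maps naturally onto rnorm/runif arguments. add_noise_to_input() is a hypothetical sketch that assumes the noise is simply added element-wise to the encoded input. Its defaults match the documented ones.

```r
# Hypothetical sketch: perturb an encoded input with normal or uniform noise.
add_noise_to_input <- function(x, noise_type = c("normal", "uniform"),
                               mean = 0, sd = 1, min = 0, max = 1) {
  noise_type <- match.arg(noise_type)
  noise <- switch(noise_type,
                  normal = rnorm(length(x), mean = mean, sd = sd),
                  uniform = runif(length(x), min = min, max = max))
  x + noise  # a vector of length(x) added to x keeps the dimensions of x
}

set.seed(3)
x <- diag(4)  # toy one-hot encoding of "ACGT"
add_noise <- list(noise_type = "uniform", max = 0.1)
noisy <- do.call(add_noise_to_input, c(list(x = x), add_noise))
```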

return_int

Whether to return integer encoding or one-hot encoding.

maxlen

Length of predictor sequence.

reverse_complement

Boolean. For every new file, decide randomly whether to use the original data or its reverse complement.

reverse_complement_encoding

Whether to use both original sequence and reverse complement as two input sequences.
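The difference between reverse_complement and reverse_complement_encoding can be illustrated on a character sequence. This is a sketch only: the package works on encoded tensors rather than strings, and the 50/50 choice per file is an assumption.

```r
# Reverse complement of a DNA string (upper- or lowercase).
reverse_complement <- function(sequence) {
  complement <- chartr("ACGTacgt", "TGCAtgca", sequence)
  paste(rev(strsplit(complement, "")[[1]]), collapse = "")
}

reverse_complement("AACCGTA")  # "TACGGTT"

# reverse_complement = TRUE: pick one orientation per file at random (assumed 50/50).
set.seed(7)
file_seq <- "AACCGTA"
used_seq <- if (runif(1) < 0.5) reverse_complement(file_seq) else file_seq

# reverse_complement_encoding = TRUE: feed both orientations as two inputs.
two_inputs <- list(file_seq, reverse_complement(file_seq))
```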

output_format

Determines the shape of the output tensor for language models. Either "target_right", "target_middle_lstm", "target_middle_cnn" or "wavenet". Assume a sequence "AACCGTA". The output then corresponds to:

  • "target_right": X = "AACCGT", Y = "A"

  • "target_middle_lstm": X = (X_1 = "AAC", X_2 = "ATG"), Y = "C" (note reversed order of X_2)

  • "target_middle_cnn": X = "AACGTA", Y = "C"

  • "wavenet": X = "AACCGT", Y = "ACCGTA"
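The four layouts can be reproduced on strings with a small helper. split_lm_sample() is hypothetical and assumes target_len = 1. For even-length sequences, picking position ceiling(n / 2) as the middle is an assumption.

```r
# Hypothetical helper reproducing the output_format examples on strings.
split_lm_sample <- function(sequence, output_format) {
  chars <- strsplit(sequence, "")[[1]]
  n <- length(chars)
  mid <- ceiling(n / 2)
  collapse <- function(x) paste(x, collapse = "")
  switch(output_format,
    target_right = list(X = collapse(chars[-n]), Y = chars[n]),
    target_middle_lstm = list(
      X = list(X_1 = collapse(chars[seq_len(mid - 1)]),
               X_2 = collapse(rev(chars[(mid + 1):n]))),  # reversed right part
      Y = chars[mid]),
    target_middle_cnn = list(X = collapse(chars[-mid]), Y = chars[mid]),
    wavenet = list(X = collapse(chars[-n]), Y = collapse(chars[-1])),
    stop("unknown output_format")
  )
}

split_lm_sample("AACCGTA", "target_middle_lstm")
# X = list(X_1 = "AAC", X_2 = "ATG"), Y = "C"
```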

proportion_per_seq

Numerical value between 0 and 1. Proportion of sequence to take samples from (use random subsequence).

read_data

If TRUE, the first element of the output is a list of length 2, each entry containing one part of a paired read. maxlen should be 2 * the length of one read.

use_quality_score

Whether to use fastq quality scores. If TRUE, the input is not one-hot encoded but contains probabilities, for example (0.97, 0.01, 0.01, 0.01) instead of (1, 0, 0, 0).
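One plausible conversion from fastq quality characters to such probability vectors is sketched below. It assumes Phred+33 encoding, error probability 10^(-Q/10), and the error mass spread evenly over the other nucleotides. The package's exact conversion may differ. With quality character "0" (Q = 15), this gives roughly the documented example.

```r
# Hypothetical conversion of a read plus its quality string into probabilities.
quality_to_probs <- function(read, quality, vocabulary = c("A", "C", "G", "T")) {
  bases <- strsplit(read, "")[[1]]
  q <- utf8ToInt(quality) - 33  # Phred+33 quality scores
  p_error <- 10^(-q / 10)
  k <- length(vocabulary)
  # Each row starts with the error mass split over all k columns ...
  probs <- matrix(p_error / (k - 1), nrow = length(bases), ncol = k,
                  dimnames = list(NULL, vocabulary))
  # ... and the called base gets 1 - p_error, so each row sums to 1.
  probs[cbind(seq_along(bases), match(bases, vocabulary))] <- 1 - p_error
  probs
}

probs <- quality_to_probs("ACGT", "0000")
round(probs[1, ], 2)  # A = 0.97, C = 0.01, G = 0.01, T = 0.01
```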

padding

Whether to pad sequences too short for one sample with zeros.

concat_seq

Character string or NULL. If not NULL, all entries from a file get concatenated to one sequence with the concat_seq string between them. Example: if the first entry is AACC, the second entry is TTTG and concat_seq = "ZZZ", this becomes AACCZZZTTTG.

target_len

Number of nucleotides to predict at once for language model.

skip_amb_nuc

Threshold of ambiguous nucleotides to accept in a fasta entry. Entries exceeding this threshold are discarded completely.

max_samples

Maximum number of samples to use from one file. If not NULL and file has more than max_samples samples, will randomly choose a subset of max_samples samples.

added_label_path

Path to file with additional input labels. Should be a csv file with one column named "file". Other columns should correspond to labels.

add_input_as_seq

Boolean vector specifying for each entry in added_label_path if rows from csv should be encoded as a sequence or used directly. If a row in your csv file is a sequence this should be TRUE. For example you may want to add another sequence, say ACCGT. Then this would correspond to 1,2,2,3,4 in csv file (if vocabulary = c("A", "C", "G", "T")). If add_input_as_seq is TRUE, 12234 gets one-hot encoded, so added input is a 3D tensor. If add_input_as_seq is FALSE this will feed network just raw data (a 2D tensor).

target_from_csv

Path to csv file with target mapping. One column should be called "file" and other entries in row are the targets.

target_split

If the target is read from a csv file, a list of names used to divide the target tensor into a list of tensors. Example: if the csv file has header names "file", "label_1", "label_2", "label_3" and target_split = list(c("label_1", "label_2"), "label_3"), the target matrix is divided into a list of length 2, where the first element contains the columns named "label_1" and "label_2" and the second element contains the column named "label_3".
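The split described above can be reproduced on a plain target matrix with named columns. This sketch only shows the column grouping. Inside train_model, the split is applied to the generator's target tensor.

```r
# Toy target matrix with the column names from the example (2 samples).
target <- matrix(c(1, 0, 0,
                   0, 1, 1), nrow = 2, byrow = TRUE,
                 dimnames = list(NULL, c("label_1", "label_2", "label_3")))

target_split <- list(c("label_1", "label_2"), "label_3")

# One matrix per group of column names.
split_targets <- lapply(target_split, function(cols) target[, cols, drop = FALSE])
```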

shuffle_input

Whether to shuffle entries in file.

vocabulary_label

Character vector of possible targets. Targets outside vocabulary_label will get discarded if train_type = "label_header".

delete_used_files

Whether to delete file once used. Only applies for rds files.

reshape_xy

Can be a list of functions to apply to input and/or target. The list elements (containing the reshape functions) must be called x for input or y for target, and each function must have arguments called x and y. For example: reshape_xy = list(x = function(x, y) {return(x+1)}, y = function(x, y) {return(x+y)}). For the rds generator, the functions need an additional argument called sw.

return_gen

Whether to return the train and validation generators (instead of training).

Value

A list of training metrics.

Examples

if (FALSE) { # reticulate::py_module_available("tensorflow")
# create dummy data
path_train_1 <- tempfile()
path_train_2 <- tempfile()
path_val_1 <- tempfile()
path_val_2 <- tempfile()

for (current_path in c(path_train_1, path_train_2,
                       path_val_1, path_val_2)) {
  dir.create(current_path)
  create_dummy_data(file_path = current_path,
                    num_files = 3,
                    seq_length = 10,
                    num_seq = 5,
                    vocabulary = c("a", "c", "g", "t"))
}

# create model
model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 2, maxlen = 5)

# train model
hist <- train_model(train_type = "label_folder",
                    model = model,
                    path = c(path_train_1, path_train_2),
                    path_val = c(path_val_1, path_val_2),
                    batch_size = 8,
                    epochs = 3,
                    steps_per_epoch = 6,
                    step = 5,
                    format = "fasta",
                    vocabulary_label = c("label_1", "label_2"))
 
}