Train a neural network on genomic data. Data can be supplied as fasta/fastq files, rds files or a prepared data set. If the data is given as a collection of fasta, fastq or rds files, the function creates a data generator that extracts training and validation batches from the files. Several options control the sampling strategy of the generator and the preprocessing of the data. Training progress can be visualized in tensorboard, and model weights can be stored during training using checkpoints.
Usage
train_model(
  model = NULL,
  dataset = NULL,
  dataset_val = NULL,
  train_val_ratio = 0.2,
  run_name = "run_1",
  initial_epoch = 0,
  class_weight = NULL,
  print_scores = TRUE,
  epochs = 10,
  max_queue_size = 100,
  steps_per_epoch = 1000,
  path_checkpoint = NULL,
  path_tensorboard = NULL,
  path_log = NULL,
  save_best_only = NULL,
  save_weights_only = FALSE,
  tb_images = FALSE,
  path_file_log = NULL,
  reset_states = FALSE,
  early_stopping_time = NULL,
  validation_only_after_training = FALSE,
  train_val_split_csv = NULL,
  reduce_lr_on_plateau = TRUE,
  lr_plateau_factor = 0.9,
  patience = 20,
  cooldown = 1,
  model_card = NULL,
  callback_list = NULL,
  train_type = "label_folder",
  path = NULL,
  path_val = NULL,
  batch_size = 64,
  step = NULL,
  shuffle_file_order = TRUE,
  vocabulary = c("a", "c", "g", "t"),
  format = "fasta",
  ambiguous_nuc = "zero",
  seed = c(1234, 4321),
  file_limit = NULL,
  use_coverage = NULL,
  set_learning = NULL,
  proportion_entries = NULL,
  sample_by_file_size = FALSE,
  n_gram = NULL,
  n_gram_stride = 1,
  masked_lm = NULL,
  random_sampling = FALSE,
  add_noise = NULL,
  return_int = FALSE,
  maxlen = NULL,
  reverse_complement = FALSE,
  reverse_complement_encoding = FALSE,
  output_format = "target_right",
  proportion_per_seq = NULL,
  read_data = FALSE,
  use_quality_score = FALSE,
  padding = FALSE,
  concat_seq = NULL,
  target_len = 1,
  skip_amb_nuc = NULL,
  max_samples = NULL,
  added_label_path = NULL,
  add_input_as_seq = NULL,
  target_from_csv = NULL,
  target_split = NULL,
  shuffle_input = TRUE,
  vocabulary_label = NULL,
  delete_used_files = FALSE,
  reshape_xy = NULL,
  return_gen = FALSE
)
Arguments
- model
A keras model.
- dataset
List of training data holding training samples in RAM instead of using generator. Should be list with two entries called "X" and "Y".
- dataset_val
List of validation data. Should have two entries called "X" and "Y".
- train_val_ratio
For a generator, defines the fraction of batches that will be used for validation (compared to the size of the training data), i.e. one validation iteration processes batch_size * steps_per_epoch * train_val_ratio samples. If you use dataset instead of a generator and dataset_val is NULL, splits dataset into train/validation data.
- run_name
Name of the run. The name is used to identify output from callbacks. If NULL, the date is used as run name. If the name is already present, will add "_2" to the name, or "_{x+1}" if the name ends with _x, where x is some integer.
- initial_epoch
Epoch at which to start training. Note that the network will run for (epochs - initial_epoch) rounds and not epochs rounds.
- class_weight
List of weights for output. Order should correspond to vocabulary_label. You can use the get_class_weight function to estimate class weights:
class_weights <- get_class_weight(path = path, train_type = train_type)
If train_type = "label_csv" you need to add the path to the csv file:
class_weights <- get_class_weight(path = path, train_type = train_type, csv_path = target_from_csv)
- print_scores
Whether to print train/validation scores during training.
- epochs
Number of iterations.
- max_queue_size
Maximum size for the generator queue.
- steps_per_epoch
Number of training batches per epoch.
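As a rough illustration of how steps_per_epoch interacts with batch_size, train_val_ratio, epochs and initial_epoch, the arithmetic from the descriptions above can be written out in base R. This is only a sketch using the default values from the Usage section; initial_epoch = 4 is a made-up value for illustration.

```r
# Defaults from the Usage section.
batch_size <- 64
steps_per_epoch <- 1000
train_val_ratio <- 0.2
epochs <- 10

# Hypothetical value, e.g. when resuming from a checkpoint.
initial_epoch <- 4

# Samples seen in one training epoch.
train_samples_per_epoch <- batch_size * steps_per_epoch        # 64000

# One validation iteration, per the train_val_ratio description.
val_batches <- steps_per_epoch * train_val_ratio               # 200
val_samples <- batch_size * steps_per_epoch * train_val_ratio  # 12800

# Rounds actually trained when starting at initial_epoch.
rounds <- epochs - initial_epoch                               # 6
```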
- path_checkpoint
Path to checkpoints folder or NULL. If NULL, checkpoints don't get stored.
- path_tensorboard
Path to tensorboard directory or NULL. If NULL, training is not tracked on tensorboard.
- path_log
Path to directory to write training scores. File name is run_name + ".csv". No output if NULL.
- save_best_only
Only save model that improved on some score. Not applied if argument is NULL. Otherwise must be a list with argument monitor or save_freq (can only use one option). monitor specifies what metric to use. save_freq is an integer specifying how often to store a checkpoint (in epochs).
- save_weights_only
Whether to save weights only.
- tb_images
Whether to show custom images (confusion matrix) in tensorboard "IMAGES" tab.
- path_file_log
Write name of files used for training to csv file if path is specified.
- reset_states
Whether to reset hidden states of RNN layer at every new input file and before/after validation.
- early_stopping_time
Time in seconds after which to stop training.
- validation_only_after_training
Whether to skip validation during training and only do one validation iteration after training.
- train_val_split_csv
A csv file specifying the train/validation split. The csv file should contain one column named "file" and one column named "type". The "file" column contains names of fasta/fastq files and the "type" column specifies if a file is used for training or validation. Entries in "type" must be named "train" or "val", otherwise the file will not be used for either. path and path_val arguments should be the same. Not implemented for train_type = "label_folder".
- reduce_lr_on_plateau
Whether to use learning rate scheduler.
- lr_plateau_factor
Factor of decreasing learning rate when plateau is reached.
- patience
Number of epochs waiting for decrease in validation loss before reducing learning rate.
- cooldown
Number of epochs without changing learning rate.
- model_card
List of arguments for training parameters of training run. Must contain at least an entry path_model_card, i.e. the directory where parameters are stored. List can contain additional (optional) arguments, for example
model_card = list(path_model_card = "/path/to/logs", description = "transfer learning with BERT model on virus data", ...)
- callback_list
Add additional callbacks to keras::fit call.
- train_type
Either "lm", "lm_rds", "masked_lm" for language model; "label_header", "label_folder", "label_csv", "label_rds" for classification or "dummy_gen".
  - Language model is trained to predict character(s) in a sequence.
  - "label_header"/"label_folder"/"label_csv" are trained to predict a corresponding class given a sequence as input.
  - If "label_header", class will be read from fasta headers.
  - If "label_folder", class will be read from folder, i.e. all files in one folder must belong to the same class.
  - If "label_csv", targets are read from a csv file. This file should have one column named "file". The targets then correspond to entries in that row (except "file" column). Example: if we are currently working with a file called "a.fasta" and corresponding label is "label_1", there should be a row in our csv file
      file       label_1  label_2
      "a.fasta"  1        0
  - If "label_rds", generator will iterate over set of .rds files containing each a list of input and target tensors. Not implemented for model with multiple inputs.
  - If "lm_rds", generator will iterate over set of .rds files and will split tensor according to target_len argument (targets are last target_len nucleotides of each sequence).
  - If "dummy_gen", generator creates random data once and repeatedly feeds these to model.
  - If "masked_lm", generator masks some parts of the input. See masked_lm argument for details.
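A minimal sketch of a target file for train_type = "label_csv", following the layout described above (one "file" column, remaining columns as targets). The file names, label values and the temporary path are made up for illustration; only base R is used.

```r
# Hypothetical file names and labels, for illustration only.
targets <- data.frame(
  file = c("a.fasta", "b.fasta"),
  label_1 = c(1, 0),
  label_2 = c(0, 1)
)

# Write the table without row names so the first column is "file".
csv_path <- tempfile(fileext = ".csv")
write.csv(targets, csv_path, row.names = FALSE)

# Read it back to check the layout.
targets_check <- read.csv(csv_path)
print(targets_check)
```

Under these assumptions, csv_path is the kind of value the target_from_csv argument expects.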
- path
Path to training data. If train_type is label_folder, should be a vector or list where each entry corresponds to a class (list elements can be directories and/or individual files). If train_type is not label_folder, can be a single directory or file or a list of directories and/or files.
- path_val
Path to validation data. See path argument for details.
- batch_size
Number of samples used for one network update.
- step
Frequency of sampling steps.
- shuffle_file_order
Boolean, whether to go through files sequentially or shuffle beforehand.
- vocabulary
Vector of allowed characters. Characters outside vocabulary get encoded as specified in ambiguous_nuc.
- format
File format, "fasta", "fastq", "rds" or "fasta.tar.gz", "fastq.tar.gz" for tar.gz files.
- ambiguous_nuc
How to handle nucleotides outside vocabulary, either "zero", "discard", "empirical" or "equal".
  - If "zero", input gets encoded as zero vector.
  - If "equal", input is repetition of 1/length(vocabulary).
  - If "discard", samples containing nucleotides outside vocabulary get discarded.
  - If "empirical", use nucleotide distribution of current file.
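To make the "zero" and "equal" options concrete, here is a small base R sketch of one-hot encoding with a character outside the vocabulary. encode_sketch is a hypothetical helper written for this illustration, not the package's encoder, and it does not cover "discard" or "empirical".

```r
# Hypothetical helper: one-hot encode a lowercase sequence and handle
# characters outside the vocabulary as "zero" or "equal".
encode_sketch <- function(s, vocabulary = c("a", "c", "g", "t"),
                          ambiguous_nuc = c("zero", "equal")) {
  ambiguous_nuc <- match.arg(ambiguous_nuc)
  chars <- strsplit(s, "")[[1]]
  idx <- match(chars, vocabulary)  # NA for characters outside vocabulary
  x <- matrix(0, nrow = length(chars), ncol = length(vocabulary),
              dimnames = list(chars, vocabulary))
  known <- !is.na(idx)
  # Set a 1 in the matching column for every known character.
  x[cbind(which(known), idx[known])] <- 1
  if (ambiguous_nuc == "equal") {
    # Unknown characters get 1/length(vocabulary) in every column.
    x[!known, ] <- 1 / length(vocabulary)
  }
  x
}

encode_sketch("acn", ambiguous_nuc = "zero")   # row "n" is all zeros
encode_sketch("acn", ambiguous_nuc = "equal")  # row "n" is 0.25 everywhere
```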
- seed
Sets seed for reproducible results.
- file_limit
Integer or NULL. If integer, use only specified number of randomly sampled files for training. Ignored if greater than number of files in path.
- use_coverage
Integer or NULL. If not NULL, use coverage as encoding rather than one-hot encoding and normalize. Coverage information must be contained in fasta header: there must be a string "cov_n" in the header, where n is some integer.
- set_learning
When you want to assign one label to set of samples. Only implemented for train_type = "label_folder". Input is a list with the following parameters:
  - samples_per_target: how many samples to use for one target.
  - maxlen: length of one sample.
  - reshape_mode: "time_dist", "multi_input" or "concat".
    - If reshape_mode is "multi_input", generator will produce samples_per_target separate inputs, each of length maxlen (model should have samples_per_target input layers).
    - If reshape_mode is "time_dist", generator will produce a 4D input array. The dimensions correspond to (batch_size, samples_per_target, maxlen, length(vocabulary)).
    - If reshape_mode is "concat", generator will concatenate samples_per_target sequences of length maxlen to one long sequence.
  - If reshape_mode is "concat", there is an additional buffer_len argument. If buffer_len is an integer, the subsequences are interspaced with buffer_len rows. The input length is (maxlen * samples_per_target) + buffer_len * (samples_per_target - 1).
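The input sizes implied by the set_learning description can be checked with a few lines of base R. The list below uses only the parameter names given above; all numeric values are illustrative.

```r
# Illustrative values; only the names come from the description above.
set_learning <- list(
  reshape_mode = "concat",
  samples_per_target = 4,
  maxlen = 50,
  buffer_len = 10
)

# Input length for reshape_mode = "concat".
input_len <- with(
  set_learning,
  maxlen * samples_per_target + buffer_len * (samples_per_target - 1)
)
input_len  # 230

# Input dimensions the description gives for reshape_mode = "time_dist",
# using the default batch_size and vocabulary from the Usage section.
batch_size <- 64
vocabulary <- c("a", "c", "g", "t")
time_dist_dim <- c(batch_size, set_learning$samples_per_target,
                   set_learning$maxlen, length(vocabulary))
time_dist_dim  # 64 4 50 4
```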
- proportion_entries
Proportion of fasta entries to keep. For example, if fasta file has 50 entries and proportion_entries = 0.1, will randomly select 5 entries.
- sample_by_file_size
Sample new file weighted by file size (bigger files more likely).
- n_gram
Integer, encode target not nucleotide wise but combine n nucleotides at once. For example for n=2, "AA" -> (1, 0,..., 0), "AC" -> (0, 1, 0,..., 0), "TT" -> (0,..., 0, 1), where the one-hot vectors have length length(vocabulary)^n.
- n_gram_stride
Step size for n-gram encoding. For AACCGGTT with n_gram = 4 and n_gram_stride = 2, generator encodes (AACC), (CCGG), (GGTT); for n_gram_stride = 4 generator encodes (AACC), (GGTT).
- masked_lm
If not NULL, input and target are equal except some parts of the input are masked or random. Must be list with the following arguments:
  - mask_rate: Rate of input to mask (rate of input to replace with mask token).
  - random_rate: Rate of input to set to random token.
  - identity_rate: Rate of input where sample weights are applied but input and output are identical.
  - include_sw: Whether to include sample weights.
  - block_len (optional): Masked/random/identity regions appear in blocks of size block_len.
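A masked_lm list with the entries listed above might look as follows. The rates and block length are illustrative values, not recommendations. The final check assumes that the three rates describe disjoint parts of the input, which the description suggests but does not state.

```r
# Illustrative values; only the names come from the list above.
masked_lm <- list(
  mask_rate = 0.10,
  random_rate = 0.025,
  identity_rate = 0.05,
  include_sw = TRUE,
  block_len = 3
)

# Assumption: the three rates cover disjoint positions, so together
# they cannot exceed the whole input.
total_rate <- with(masked_lm, mask_rate + random_rate + identity_rate)
stopifnot(total_rate <= 1)
```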
- random_sampling
Whether samples should be taken from random positions when using max_samples argument. If FALSE random samples are taken from a consecutive subsequence.
- add_noise
NULL or list of arguments. If not NULL, list must contain the following arguments: noise_type can be "normal" or "uniform"; optional arguments sd or mean if noise_type is "normal" (default is sd=1 and mean=0) or min, max if noise_type is "uniform" (default is min=0, max=1).
- return_int
Whether to return integer encoding or one-hot encoding.
- maxlen
Length of predictor sequence.
- reverse_complement
Boolean, for every new file decide randomly to use original data or its reverse complement.
- reverse_complement_encoding
Whether to use both original sequence and reverse complement as two input sequences.
- output_format
Determines shape of output tensor for language model. Either "target_right", "target_middle_lstm", "target_middle_cnn" or "wavenet". Assume a sequence "AACCGTA". Outputs correspond as follows:
  - "target_right": X = "AACCGT", Y = "A"
  - "target_middle_lstm": X = (X_1 = "AAC", X_2 = "ATG"), Y = "C" (note reversed order of X_2)
  - "target_middle_cnn": X = "AACGTA", Y = "C"
  - "wavenet": X = "AACCGT", Y = "ACCGTA"
- proportion_per_seq
Numerical value between 0 and 1. Proportion of sequence to take samples from (use random subsequence).
- read_data
If TRUE the first element of output is a list of length 2, each containing one part of paired read. maxlen should be 2*length of one read.
- use_quality_score
Whether to use fastq quality scores. If TRUE input is not one-hot-encoding but corresponds to probabilities. For example (0.97, 0.01, 0.01, 0.01) instead of (1, 0, 0, 0).
- padding
Whether to pad sequences too short for one sample with zeros.
- concat_seq
Character string or NULL. If not NULL all entries from file get concatenated to one sequence with concat_seq string between them. Example: if the first entry is AACC, the second entry is TTTG and concat_seq = "ZZZ", this becomes AACCZZZTTTG.
- target_len
Number of nucleotides to predict at once for language model.
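For a language model, target_len and output_format together determine how a sequence is split into input and target. The base R sketch below reproduces the four output_format layouts for the example sequence "AACCGTA" with target_len = 1. split_sketch is a hypothetical helper for illustration only, and placing the middle target at position ceiling(n / 2) is an assumption that matches the documented example for an odd-length sequence.

```r
# Hypothetical helper, not the package's generator.
# Assumes target_len = 1 and an odd-length sequence.
split_sketch <- function(s, output_format) {
  n <- nchar(s)
  mid <- ceiling(n / 2)  # assumed position of the middle target
  rev_string <- function(x) paste(rev(strsplit(x, "")[[1]]), collapse = "")
  switch(output_format,
    target_right = list(X = substr(s, 1, n - 1), Y = substr(s, n, n)),
    target_middle_lstm = list(
      # X_2 is the part after the target, in reversed order.
      X = list(X_1 = substr(s, 1, mid - 1),
               X_2 = rev_string(substr(s, mid + 1, n))),
      Y = substr(s, mid, mid)
    ),
    target_middle_cnn = list(
      X = paste0(substr(s, 1, mid - 1), substr(s, mid + 1, n)),
      Y = substr(s, mid, mid)
    ),
    wavenet = list(X = substr(s, 1, n - 1), Y = substr(s, 2, n)),
    stop("unknown output_format")
  )
}

split_sketch("AACCGTA", "target_right")        # X = "AACCGT", Y = "A"
split_sketch("AACCGTA", "target_middle_lstm")  # X_1 = "AAC", X_2 = "ATG", Y = "C"
split_sketch("AACCGTA", "target_middle_cnn")   # X = "AACGTA", Y = "C"
split_sketch("AACCGTA", "wavenet")             # X = "AACCGT", Y = "ACCGTA"
```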
- skip_amb_nuc
Threshold of ambiguous nucleotides to accept in fasta entry. Complete entry will get discarded otherwise.
- max_samples
Maximum number of samples to use from one file. If not NULL and file has more than max_samples samples, will randomly choose a subset of max_samples samples.
- added_label_path
Path to file with additional input labels. Should be a csv file with one column named "file". Other columns should correspond to labels.
- add_input_as_seq
Boolean vector specifying for each entry in added_label_path if rows from csv should be encoded as a sequence or used directly. If a row in your csv file is a sequence this should be TRUE. For example you may want to add another sequence, say ACCGT. Then this would correspond to 1,2,2,3,4 in csv file (if vocabulary = c("A", "C", "G", "T")). If add_input_as_seq is TRUE, 12234 gets one-hot encoded, so added input is a 3D tensor. If add_input_as_seq is FALSE this will feed network just raw data (a 2D tensor).
- target_from_csv
Path to csv file with target mapping. One column should be called "file" and other entries in row are the targets.
- target_split
If target gets read from csv file, list of names to divide target tensor into list of tensors. Example: if csv file has header names "file", "label_1", "label_2", "label_3" and target_split = list(c("label_1", "label_2"), "label_3"), this will divide target matrix to list of length 2, where the first element contains columns named "label_1" and "label_2" and the second entry contains the column named "label_3".
- shuffle_input
Whether to shuffle entries in file.
- vocabulary_label
Character vector of possible targets. Targets outside vocabulary_label will get discarded if train_type = "label_header".
- delete_used_files
Whether to delete file once used. Only applies for rds files.
- reshape_xy
Can be a list of functions to apply to input and/or target. List elements (containing the reshape functions) must be called x for input or y for target and each have arguments called x and y. For example: reshape_xy = list(x = function(x, y) {return(x+1)}, y = function(x, y) {return(x+y)}). For rds generator needs to have an additional argument called sw.
- return_gen
Whether to return the train and validation generators (instead of training).
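The n_gram and n_gram_stride arguments described earlier in this list can be illustrated with a short base R sketch. ngram_split and ngram_index are hypothetical helpers, not package functions. The index ordering is an assumption (lexicographic in vocabulary order), chosen because it reproduces the documented positions of "AA", "AC" and "TT".

```r
# Hypothetical helper: split a sequence into n-grams with a given stride.
ngram_split <- function(s, n_gram, n_gram_stride = 1) {
  starts <- seq(1, nchar(s) - n_gram + 1, by = n_gram_stride)
  substring(s, starts, starts + n_gram - 1)
}

# Hypothetical helper: position of the 1 in the one-hot vector of an
# n-gram, assuming lexicographic ordering over the vocabulary.
ngram_index <- function(gram, vocabulary) {
  chars <- strsplit(gram, "")[[1]]
  digits <- match(chars, vocabulary) - 1
  sum(digits * length(vocabulary)^((length(chars) - 1):0)) + 1
}

ngram_split("AACCGGTT", n_gram = 4, n_gram_stride = 2)  # "AACC" "CCGG" "GGTT"
ngram_split("AACCGGTT", n_gram = 4, n_gram_stride = 4)  # "AACC" "GGTT"

vocabulary <- c("A", "C", "G", "T")
ngram_index("AA", vocabulary)  # 1
ngram_index("AC", vocabulary)  # 2
ngram_index("TT", vocabulary)  # 16, i.e. length(vocabulary)^2
```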
Examples
if (FALSE) { # reticulate::py_module_available("tensorflow")
  # create dummy data
  path_train_1 <- tempfile()
  path_train_2 <- tempfile()
  path_val_1 <- tempfile()
  path_val_2 <- tempfile()
  for (current_path in c(path_train_1, path_train_2,
                         path_val_1, path_val_2)) {
    dir.create(current_path)
    create_dummy_data(file_path = current_path,
                      num_files = 3,
                      seq_length = 10,
                      num_seq = 5,
                      vocabulary = c("a", "c", "g", "t"))
  }
  # create model
  model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 2, maxlen = 5)
  # train model
  hist <- train_model(train_type = "label_folder",
                      model = model,
                      path = c(path_train_1, path_train_2),
                      path_val = c(path_val_1, path_val_2),
                      batch_size = 8,
                      epochs = 3,
                      steps_per_epoch = 6,
                      step = 5,
                      format = "fasta",
                      vocabulary_label = c("label_1", "label_2"))
}