Counts the number of nucleotides for each class and uses these counts to estimate the class distribution. Outputs a list of class weights that can be used as input for the class_weight argument of the train_model function.
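The weights printed in the example at the end of this page (0.75 for class A, which has twice as much data, and 1.5 for class B) are consistent with balanced weighting, w_c = N / (K * n_c), where n_c is the nucleotide count of class c, N the total count over all classes and K the number of classes. The sketch below applies that formula to given counts. The formula is inferred from the example output rather than taken from the package source, and the helper name `class_weight_from_counts` is hypothetical.

```r
# Hypothetical helper (not part of the package): balanced class weights
# computed from per-class nucleotide counts, w_c = N / (K * n_c).
class_weight_from_counts <- function(counts) {
  stopifnot(is.numeric(counts), all(counts > 0))
  total <- sum(counts)              # N: nucleotides over all classes
  num_classes <- length(counts)     # K: number of classes
  total / (num_classes * counts)    # rarer classes get larger weights
}

# Counts matching the example: class A has 6 files, class B has 3 files,
# each file holds 5 sequences of length 10.
counts <- c(A = 6 * 5 * 10, B = 3 * 5 * 10)
class_weight_from_counts(counts)  # A = 0.75, B = 1.5
```

A class with twice as many nucleotides receives half the weight, so both classes contribute equally to the weighted loss.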
Usage
get_class_weight(
  path,
  vocabulary_label = NULL,
  format = "fasta",
  file_proportion = 1,
  train_type = "label_folder",
  named_list = FALSE,
  csv_path = NULL
)
Arguments
- path
Path to training data. If train_type is "label_folder", path should be a vector or list where each entry corresponds to a class (list elements can be directories and/or individual files). If train_type is not "label_folder", path can be a single directory or file, or a list of directories and/or files.
- vocabulary_label
Character vector of possible targets. Targets outside vocabulary_label will be discarded.
- format
File format, either "fasta" or "fastq".
- file_proportion
Proportion of files to randomly sample for estimating class distributions.
- train_type
Either "lm", "lm_rds" or "masked_lm" for a language model; "label_header", "label_folder", "label_csv" or "label_rds" for classification; or "dummy_gen".
A language model is trained to predict character(s) in a sequence. "label_header", "label_folder" and "label_csv" models are trained to predict a corresponding class given a sequence as input.
If "label_header", the class is read from the fasta headers.
If "label_folder", the class is read from the folder, i.e. all files in one folder must belong to the same class.
If "label_csv", targets are read from a csv file. This file should have one column named "file"; the targets then correspond to the entries in that file's row (excluding the "file" column). For example, if we are working with a file called "a.fasta" whose label is "label_1", the csv file should contain the rows
    file       label_1   label_2
    "a.fasta"  1         0
If "label_rds", the generator iterates over a set of .rds files, each containing a list of input and target tensors. Not implemented for models with multiple inputs.
If "lm_rds", the generator iterates over a set of .rds files and splits each tensor according to the target_len argument (the targets are the last target_len nucleotides of each sequence).
If "dummy_gen", the generator creates random data once and repeatedly feeds it to the model.
If "masked_lm", the generator masks some parts of the input. See the masked_lm argument for details.
- named_list
Whether to name the elements of the class weight list "0", "1", ... or not.
- csv_path
If train_type = "label_csv", path to the csv file containing the labels.
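For train_type = "label_csv", a label file in the layout described under train_type (a "file" column plus one column per label, one row per file) can be written with base R as sketched below. The file names and label columns are placeholders, and the `label_of` lookup is a hypothetical illustration of how a one-hot row maps back to a label, not the package's own parsing code.

```r
# Build a label table in the layout described for train_type = "label_csv":
# one row per file, a "file" column, and one column per label.
labels <- data.frame(
  file    = c("a.fasta", "b.fasta"),  # placeholder file names
  label_1 = c(1, 0),
  label_2 = c(0, 1)
)
csv_path <- tempfile(fileext = ".csv")
write.csv(labels, csv_path, row.names = FALSE)

# Read the table back and return the label column set to 1 for a given file
# (illustration only).
label_table <- read.csv(csv_path)
label_of <- function(table, file_name) {
  row <- table[table$file == file_name, setdiff(names(table), "file")]
  names(row)[which.max(unlist(row))]
}
label_of(label_table, "a.fasta")  # "label_1"
```

The resulting csv_path is the kind of value the csv_path argument expects when train_type = "label_csv".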
Examples
# create dummy data
path_1 <- tempfile()
path_2 <- tempfile()
for (current_path in c(path_1, path_2)) {
  dir.create(current_path)
  # create twice as much data for first class
  num_files <- ifelse(current_path == path_1, 6, 3)
  create_dummy_data(file_path = current_path,
                    num_files = num_files,
                    seq_length = 10,
                    num_seq = 5,
                    vocabulary = c("a", "c", "g", "t"))
}
class_weight <- get_class_weight(
  path = c(path_1, path_2),
  vocabulary_label = c("A", "B"),
  format = "fasta",
  file_proportion = 1,
  train_type = "label_folder",
  csv_path = NULL)
class_weight
#> $`0`
#> A
#> 0.75
#>
#> $`1`
#> B
#> 1.5
#>
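The following base-R sketch reproduces the weights of the example above without the package, to make the counting step concrete. It assumes that each folder is one class, that every non-header line of a fasta file is sequence, and that weights follow N / (K * n_c); the last assumption matches the printed values but is inferred, not documented. All function names here (`count_fasta_nucleotides`, `class_weight_from_folders`, `write_dummy_fasta`) are hypothetical.

```r
# Hypothetical base-R sketch (not the package implementation) of the
# "label_folder" case: every folder is one class, nucleotides are counted
# in a random subset of its fasta files, and counts become weights N / (K * n_c).
count_fasta_nucleotides <- function(file) {
  lines <- readLines(file)
  sum(nchar(lines[!startsWith(lines, ">")]))  # skip header lines
}

class_weight_from_folders <- function(dirs, class_names, file_proportion = 1) {
  counts <- vapply(dirs, function(dir) {
    files <- list.files(dir, full.names = TRUE)
    n <- max(1, round(file_proportion * length(files)))  # files to sample
    sum(vapply(sample(files, n), count_fasta_nucleotides, numeric(1)))
  }, numeric(1))
  names(counts) <- class_names
  sum(counts) / (length(counts) * counts)
}

# Dummy data mirroring the example: 6 vs. 3 files, 5 sequences of length 10.
write_dummy_fasta <- function(dir, num_files) {
  dir.create(dir)
  for (i in seq_len(num_files)) {
    seqs <- replicate(5, paste(sample(c("A", "C", "G", "T"), 10, TRUE), collapse = ""))
    # interleave headers and sequences: >seq1, seq, >seq2, seq, ...
    writeLines(c(rbind(paste0(">seq", 1:5), seqs)),
               file.path(dir, paste0("file_", i, ".fasta")))
  }
}
dir_a <- tempfile()
dir_b <- tempfile()
write_dummy_fasta(dir_a, 6)
write_dummy_fasta(dir_b, 3)
class_weight_from_folders(c(dir_a, dir_b), c("A", "B"))  # A = 0.75, B = 1.5
```

With file_proportion below 1, this sketch counts only a random subset of each folder's files, so its estimate can differ between runs.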