Introduction
The Integrated Gradient (IG) method can be used to determine which parts of an input sequence are important for the model's decision. We start by training a model that can differentiate sequences based on their GC content (as described in the Getting started tutorial).
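For reference, integrated gradients attribute the prediction score F(x) to each input feature by accumulating gradients along a straight path from a baseline input x' (for one-hot encoded sequences this is often the all-zero encoding) to the actual input x:

$$\mathrm{IG}_i(x) = (x_i - x'_i) \int_0^1 \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\, d\alpha$$

In practice the integral is approximated by a sum of gradients evaluated at a fixed number of interpolation steps between x' and x.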
Model Training
We create two simple dummy training and validation data sets. Both consist of random ACGT sequences, but in the first category G and C are each drawn with probability 40%, while in the second category all four nucleotides are equally likely (so the first category has around 80% GC content and the second around 50%).
# load the package providing create_dummy_data(), train_model(), etc.
# (assumed here to be deepG) plus the pipe and plotting helpers used below
library(deepG)
library(magrittr)
library(ggplot2)

set.seed(123)

# Create data
vocabulary <- c("A", "C", "G", "T")
data_type <- c("train_1", "train_2", "val_1", "val_2")

for (i in seq_along(data_type)) {
  temp_file <- tempfile()
  assign(paste0(data_type[i], "_dir"), temp_file)
  dir.create(temp_file)
  if (i %% 2 == 1) {
    header <- "label_1"
    prob <- c(0.1, 0.4, 0.4, 0.1) # order: A, C, G, T
  } else {
    header <- "label_2"
    prob <- rep(0.25, 4)
  }
  fasta_name_start <- paste0(header, "_", data_type[i], "file")
  create_dummy_data(file_path = temp_file,
                    num_files = 1,
                    seq_length = 20000,
                    num_seq = 1,
                    header = header,
                    prob = prob,
                    fasta_name_start = fasta_name_start,
                    vocabulary = vocabulary)
}
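As a quick sanity check of these percentages, we can sum the C and G sampling probabilities directly (the nucleotide order is A, C, G, T):

# expected GC fraction implied by the sampling probabilities
sum(c(0.1, 0.4, 0.4, 0.1)[c(2, 3)]) # first category: 0.8
sum(rep(0.25, 4)[c(2, 3)])          # second category: 0.5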
# Create model
maxlen <- 50
model <- create_model_lstm_cnn(maxlen = maxlen,
                               filters = c(8, 16),
                               kernel_size = c(8, 8),
                               pool_size = c(3, 3),
                               layer_lstm = 8,
                               layer_dense = c(4, 2),
                               model_seed = 3)
# Train model
hist <- train_model(model,
                    train_type = "label_folder",
                    run_name = "gc_model_1",
                    path = c(train_1_dir, train_2_dir),
                    path_val = c(val_1_dir, val_2_dir),
                    epochs = 6,
                    batch_size = 64,
                    steps_per_epoch = 50,
                    step = 50,
                    vocabulary_label = c("high_gc", "equal_dist"))

plot(hist)
Integrated Gradient
We can try to visualize which parts of an input sequence are important for the model's decision, using Integrated Gradient. Let's create a sequence with a high GC content. We use the same number of Cs as Gs and the same number of As as Ts.
set.seed(321)
g_count <- 17
stopifnot(g_count < 25)
a_count <- (50 - (2*g_count))/2
high_gc_seq <- c(rep("G", g_count), rep("C", g_count), rep("A", a_count), rep("T", a_count))
high_gc_seq <- high_gc_seq[sample(maxlen)] %>% paste(collapse = "") # shuffle nt order
high_gc_seq
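As a small check, we can confirm the GC content of the constructed sequence; with g_count = 17 and maxlen = 50 it should be 2 * 17 / 50 = 0.68:

# fraction of G/C characters in the shuffled sequence
nts <- strsplit(high_gc_seq, "")[[1]]
mean(nts %in% c("G", "C"))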
We need to one-hot encode the sequence before applying Integrated Gradient.
high_gc_seq_one_hot <- seq_encoding_label(char_sequence = high_gc_seq,
                                          maxlen = 50,
                                          start_ind = 1,
                                          vocabulary = vocabulary)

head(high_gc_seq_one_hot[1, , ])
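For illustration, the same encoding can be reproduced with base R. This is a minimal sketch that assumes the one-hot columns follow the order of vocabulary:

# manual one-hot encoding for comparison with seq_encoding_label()
nts <- strsplit(high_gc_seq, "")[[1]]
manual_one_hot <- t(sapply(nts, function(nt) as.integer(vocabulary == nt)))
all(manual_one_hot == high_gc_seq_one_hot[1, , ]) # TRUE if the column order matches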
Our model should be confident that this sequence belongs to the first class.
pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
colnames(pred) <- c("high_gc", "equal_dist")
pred
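For comparison, we can also score a random sequence with roughly equal nucleotide frequencies, which the model should assign to the second class:

# random sequence with ~25% of each nucleotide
random_seq <- paste(sample(vocabulary, maxlen, replace = TRUE), collapse = "")
random_seq_one_hot <- seq_encoding_label(char_sequence = random_seq,
                                         maxlen = 50,
                                         start_ind = 1,
                                         vocabulary = vocabulary)
pred_random <- predict(model, random_seq_one_hot, verbose = 0)
colnames(pred_random) <- c("high_gc", "equal_dist")
pred_random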
We can visualize which parts were important for the prediction.
ig <- integrated_gradients(
  input_seq = high_gc_seq_one_hot,
  target_class_idx = 1,
  model = model)

if (requireNamespace("ComplexHeatmap", quietly = TRUE)) {
  heatmaps_integrated_grad(integrated_grads = ig,
                           input_seq = high_gc_seq_one_hot)
} else {
  message("Skipping ComplexHeatmap-related code because the package is not installed.")
}
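If ComplexHeatmap is not installed, a simple per-position summary of the attributions can be plotted with ggplot2 instead. This sketch assumes ig can be converted with as.array() to a matrix with one row per position and one column per nucleotide (the same layout used below):

ig_mat <- as.array(ig)
pos_score <- rowSums(ig_mat) # aggregate attribution per sequence position
ggplot(data.frame(position = seq_len(nrow(ig_mat)), score = pos_score),
       aes(x = position, y = score)) +
  geom_col() +
  ylab("integrated gradient (summed over nucleotides)")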
We may test how our model's prediction changes if we exchange certain nucleotides in the input sequence. First, we look for the position with the smallest IG score.
We may change the nucleotide with the lowest score and observe the change in prediction confidence.
# copy original sequence
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot

# prediction for original sequence
predict(model, high_gc_seq_one_hot, verbose = 0)

# change nt
ig <- as.array(ig) # ensure we work with a plain array for indexing
smallest_index <- which(ig == min(ig), arr.ind = TRUE)
smallest_index
# in case of ties, use the first match
row_index <- smallest_index[ , "row"][1]
col_index <- smallest_index[ , "col"][1]
new_row <- rep(0, 4)
nt_index_old <- col_index
nt_index_new <- which.max(ig[row_index, ])
new_row[nt_index_new] <- 1
high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
cat("At position", row_index, "changing", vocabulary[nt_index_old], "to", vocabulary[nt_index_new], "\n")
pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
print(pred)
Let’s repeatedly apply the previous step and change the sequence after each iteration.
# copy original sequence
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot
pred_list <- list()
pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
pred_list[[1]] <- pred

# change nts
for (i in 1:20) {
  # update ig scores for changed input
  ig <- integrated_gradients(
    input_seq = high_gc_seq_one_hot_changed,
    target_class_idx = 1,
    model = model) %>% as.array()
  smallest_index <- which(ig == min(ig), arr.ind = TRUE)
  # in case of ties, use the first match
  row_index <- smallest_index[ , "row"][1]
  col_index <- smallest_index[ , "col"][1]
  new_row <- rep(0, 4)
  nt_index_old <- col_index
  nt_index_new <- which.max(ig[row_index, ])
  new_row[nt_index_new] <- 1
  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
  cat("At position", row_index, "changing", vocabulary[nt_index_old],
      "to", vocabulary[nt_index_new], "\n")
  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
  pred_list[[i + 1]] <- pred
}
pred_df <- do.call(rbind, pred_list)
pred_df <- data.frame(pred_df, iteration = 0:(nrow(pred_df) - 1))
names(pred_df) <- c("high_gc", "equal_dist", "iteration")
ggplot(pred_df, aes(x = iteration, y = high_gc)) + geom_line() + ylab("high GC confidence")
We can try the same in the opposite direction, i.e. replace nucleotides with big IG scores.
# copy original sequence
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot
pred_list <- list()
pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
pred_list[[1]] <- pred

# change nts
for (i in 1:20) {
  # update ig scores for changed input
  ig <- integrated_gradients(
    input_seq = high_gc_seq_one_hot_changed,
    target_class_idx = 1,
    model = model) %>% as.array()
  biggest_index <- which(ig == max(ig), arr.ind = TRUE)
  # in case of ties, use the first match
  row_index <- biggest_index[ , "row"][1]
  col_index <- biggest_index[ , "col"][1]
  new_row <- rep(0, 4)
  nt_index_old <- col_index
  nt_index_new <- which.min(ig[row_index, ])
  new_row[nt_index_new] <- 1
  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
  cat("At position", row_index, "changing", vocabulary[nt_index_old], "to", vocabulary[nt_index_new], "\n")
  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
  pred_list[[i + 1]] <- pred
}
pred_df <- do.call(rbind, pred_list)
pred_df <- data.frame(pred_df, iteration = 0:(nrow(pred_df) - 1))
names(pred_df) <- c("high_gc", "equal_dist", "iteration")
ggplot(pred_df, aes(x = iteration, y = high_gc)) + geom_line() + ylab("high GC confidence")