
Compute balanced accuracy as an additional score. Balanced accuracy is the mean recall over all classes, which makes it useful for imbalanced data. Only implemented for models with mutually exclusive targets.
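The base R sketch below illustrates the definition. It is not the package's implementation, and `balanced_accuracy_sketch()` is a hypothetical helper. It assumes one-hot targets and class-probability predictions, takes the row-wise maximum as the predicted class, builds a confusion matrix, and averages the per-class recall. The keras metric returned by `balanced_acc_wrapper()` accumulates its confusion matrix across `update_state()` calls, so its details (for example, how classes with no samples are handled) may differ.

```r
# Hypothetical helper (not part of the package): balanced accuracy as the
# mean per-class recall, for one-hot y_true and class-probability y_pred.
balanced_accuracy_sketch <- function(y_true, y_pred) {
  num_targets <- ncol(y_true)
  # True and predicted class = column index of the row maximum
  true_class <- max.col(y_true, ties.method = "first")
  pred_class <- max.col(y_pred, ties.method = "first")
  # Confusion matrix: rows = true class, columns = predicted class
  cm <- table(factor(true_class, levels = seq_len(num_targets)),
              factor(pred_class, levels = seq_len(num_targets)))
  # Per-class recall = correct predictions / number of samples of that class
  support <- rowSums(cm)
  recall <- diag(cm) / support
  # Recall is undefined for classes absent from y_true; this sketch skips them
  mean(recall[support > 0])
}

y_true <- matrix(c(1, 0, 0, 1,
                   0, 1, 0, 0,
                   0, 0, 1, 0), ncol = 3)
y_pred <- matrix(c(0.9, 0.1, 0.2, 0.1,
                   0.05, 0.7, 0.2, 0.0,
                   0.05, 0.2, 0.6, 0.9), ncol = 3)
balanced_accuracy_sketch(y_true, y_pred)  # 0.8333333
```

Here class 1 has two samples, one of which is predicted correctly (recall 0.5), and classes 2 and 3 are each predicted correctly (recall 1), so the score is (0.5 + 1 + 1) / 3. Plain accuracy would be 3/4 = 0.75.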

Usage

balanced_acc_wrapper(num_targets, cm_dir)

Arguments

num_targets

Number of targets.

cm_dir

Directory where the confusion matrix used to compute balanced accuracy is stored.

Value

A keras metric.

Examples

if (FALSE) { # reticulate::py_module_available("tensorflow")

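# One-hot targets: 4 samples, 3 mutually exclusive classes
y_true <- matrix(c(1, 0, 0, 1,
                   0, 1, 0, 0,
                   0, 0, 1, 0), ncol = 3)
# Predicted class probabilities for the same samples
y_pred <- matrix(c(0.9, 0.1, 0.2, 0.1,
                   0.05, 0.7, 0.2, 0.0,
                   0.05, 0.2, 0.6, 0.9), ncol = 3)

# Temporary directory for the confusion matrix
cm_dir <- tempfile()
dir.create(cm_dir)
bal_acc_metric <- balanced_acc_wrapper(num_targets = 3L, cm_dir = cm_dir)
# Accumulate one batch, then read the score and the confusion matrix
bal_acc_metric$update_state(y_true, y_pred)
bal_acc_metric$result()
as.array(bal_acc_metric$cm)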
}