library(here)
library(tibble)
library(ranger)
library(caret)
library(ggplot2)

source(here("v2","src","evidencegeom.R"))

set.seed(42)

Dataset 3: UCI Heart Disease Cleveland

TL;DR

This notebook demonstrates the Evidence Geometry framework on the UCI Heart Disease Cleveland dataset.

Instead of relying only on classifier probability, the framework analyzes each case through per-feature log-likelihood ratios between the two classes. For each feature, a marginal likelihood distribution family is selected based on the feature type and the magnitude of the log-likelihood.

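This log-likelihood ratio space (evidence space) is then analyzed using two geometric risk signals:

  • d_dist: the difference of the Mahalanobis distances of the evidence vector from the negative-class and the positive-class evidence manifolds.
  • proj: the projection of the benign-relative z-scores of the evidences onto the class-separation direction learned from training.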

Validation Set Risk Map

Validation-set risk map.

Validation-set risk map in the low-RF probability region (p < 0.45).

Main Observation

Within the subset of cases assigned low probability by the Random Forest classifier, the two geometric signals reveal additional structure:

  • The boundary of ambiguity, d_dist = 0, occurs much earlier in the feature space than the classifier's ambiguous region around p = 0.5.
  • A dominant benign cluster is present in the region d_dist < 0, proj < 0.
  • Pathological cases stand out distinctly among the points that cross out of the central benign region (d_dist < 0 and proj < 0).

This suggests that geometric evidence can detect hidden risk structure earlier than classifier probability alone.

Test Set Result

With a conservative triage policy on the test set, the risk signals help achieve:

  • Identification of true ambiguity in the data
  • Abstention from classifying ambiguous cases
  • A 98.6% false negative capture rate through abstention (1 of 34 Heart Disease cases is auto-labelled Healthy)
  • An automated classification rate of 61%

Takeaway

d_dist and proj provide a compact and interpretable view of how risk accumulates geometrically, especially in cases where standard discriminative classifiers assign deceptively low probability.

Full Analysis

Comment Block : Data prep of raw UCI Heart Disease Cleveland data

# heartdisease <- read.csv(here("v2","src","data","heartdisease","processed.cleveland.data"))
# 
# names(heartdisease) <- c("age","sex","cp","trestbps","chol","fbs","restecg","thalach","exang","oldpeak","slope","ca","thal","num")
# 
# heartdisease$slope <- as.factor(
#   case_when(
#     heartdisease$slope == 1 ~ "Upsloping",
#     heartdisease$slope == 2 ~ "Flat",
#     heartdisease$slope == 3 ~ "Downsloping",
#     TRUE ~ "Unknown"
#   )
# )
# 
# 
# heartdisease$restecg <- as.factor(
#   case_when(
#     heartdisease$restecg == 0 ~ "Normal",
#     heartdisease$restecg == 1 ~ "ST-T Wave Abnormality",
#     heartdisease$restecg == 2 ~ "Probable Left Ventricular Hypertrophy",
#     TRUE ~ "Unknown"
#   )
# )
#   
# 
# heartdisease$cp <- as.factor(case_when(
#   heartdisease$cp == 1 ~ "Typical Angina",
#   heartdisease$cp == 2 ~ "Atypical Angina",
#   heartdisease$cp == 3 ~ "Non-anginal Pain",
#   heartdisease$cp == 4 ~ "Asymptomatic",
#   TRUE ~ "Unknown"
# ))
# 
# 
# heartdisease$thal <- as.factor(
#   case_when(
#     heartdisease$thal == "3.0" ~ "Normal",
#     heartdisease$thal == "6.0" ~ "Fixed Defect",
#     heartdisease$thal == "7.0" ~ "Reversible Defect",
#     TRUE ~ "Unknown"
#   )
# )
# 
# 
# heartdisease$ca <- as.integer(heartdisease$ca)
# 
# 
# heartdisease$fbs <- ifelse(heartdisease$fbs == 1, TRUE, FALSE)
# heartdisease$exang <- ifelse(heartdisease$exang == 1, TRUE, FALSE)
# 
# 
# heartdisease$sex <- as.factor(ifelse(heartdisease$sex == 1, "M", "F"))
# 
# 
# heartdisease <- drop_na(heartdisease)
# 
# 
# heartdisease$num <- as.factor(
#   case_when(
#     heartdisease$num < 1 ~ "Healthy",
#     heartdisease$num >= 1 ~ "Heart Disease",
#     TRUE ~ "Unknown"
#   )
# )
# heartdisease_train_idx <- caret::createDataPartition(heartdisease$num, p=0.5, list=FALSE)
# heartdisease_train <- heartdisease[heartdisease_train_idx, ]
# heartdisease_eval <- heartdisease[-heartdisease_train_idx, ]
# 
# heartdisease_val_idx <- createDataPartition(heartdisease_eval$num, p=0.5, list=FALSE)
# heartdisease_val <- heartdisease_eval[heartdisease_val_idx, ]
# heartdisease_test <- heartdisease_eval[-heartdisease_val_idx, ]
# write.csv(heartdisease_train, here("v2","src","data","heartdisease","heartdisease_train.csv"))
# write.csv(heartdisease_val, here("v2","src","data","heartdisease","heartdisease_val.csv"))
# write.csv(heartdisease_test, here("v2","src","data","heartdisease","heartdisease_test.csv"))

Load Test-Val-Train Splits

# Read one pre-computed split from disk and coerce its columns for modelling.
#
# split: "train", "val" or "test"; selects heartdisease_<split>.csv.
# Returns the data frame with every character column, and `ca`, as factors.
load_heartdisease_split <- function(split) {
  csv_path <- here(
    "v2", "src", "data", "heartdisease",
    paste0("heartdisease_", split, ".csv")
  )
  df <- read.csv(csv_path)
  # vapply() always yields a logical vector; sapply() may fall back to a list.
  is_chr <- vapply(df, is.character, logical(1))
  df[is_chr] <- lapply(df[is_chr], as.factor)
  # `ca` is modelled as a categorical feature, so store it as a factor.
  df$ca <- as.factor(df$ca)
  df
}

heartdisease_train <- load_heartdisease_split("train")
heartdisease_val <- load_heartdisease_split("val")
heartdisease_test <- load_heartdisease_split("test")

Class Ratio to input for generating Evidence Space

# Share of training rows labelled "Heart Disease". Taking the mean of the
# indicator avoids row-subsetting, which would count NA labels as positives.
alpha <- mean(heartdisease_train$num == "Heart Disease")

Evidence Space Transformation

risk_spec : Object to store function argument values for evidence space generation

fit : Learn feature-wise marginal likelihoods, and geometries and eigenmodes of class manifolds in evidence space

loglik_matrices : Compute marginal positive-class and negative-class evidences, and marginal relative evidence for input data using fitted evidence generator

score_risk : Compute total feature-wise evidence, distance contrast, drift projection, and eigenmode energies

Generate Evidence Space

# Argument bundle for evidence-space generation. "Heart Disease" is the
# positive class; the label `num` and the column `X` are not used as features.
spec <- risk_spec(
  y_col = "num",
  positive = "Heart Disease",
  features = setdiff(names(heartdisease_train), c("num", "X")),
  alpha = alpha,
  laplace = 1,
  ridge = 1e-6,
  winsor_p = 0.01,
  weights = FALSE,
  weight_method = "mutual_info",
  numeric_candidates = c("gaussian", "lognormal", "gamma"),
  numeric_val_frac = 0.2,
  numeric_min_n = 25,
  llr_cap_quantile = 0.01,
  mi_nbins = 10
)

# Fit on the training split only, then list the likelihood family per feature.
obj <- heartdisease_train %>%
  select(-X) %>%
  fit(spec, k_eigen = 2, k_energy = 2, energy_ref = "both")
print_feature_likelihoods(obj)
##     feature    class feature_type likelihood_family
## 1       age positive      numeric          gaussian
## 2       age negative      numeric          gaussian
## 3  trestbps positive      numeric         lognormal
## 4  trestbps negative      numeric         lognormal
## 5      chol positive      numeric         lognormal
## 6      chol negative      numeric         lognormal
## 7   thalach positive      numeric          gaussian
## 8   thalach negative      numeric          gaussian
## 9   oldpeak positive      numeric          gaussian
## 10  oldpeak negative      numeric          gaussian
## 11      sex positive  categorical       categorical
## 12      sex negative  categorical       categorical
## 13       cp positive  categorical       categorical
## 14       cp negative  categorical       categorical
## 15      fbs positive  categorical       categorical
## 16      fbs negative  categorical       categorical
## 17  restecg positive  categorical       categorical
## 18  restecg negative  categorical       categorical
## 19    exang positive  categorical       categorical
## 20    exang negative  categorical       categorical
## 21    slope positive  categorical       categorical
## 22    slope negative  categorical       categorical
## 23       ca positive  categorical       categorical
## 24       ca negative  categorical       categorical
## 25     thal positive  categorical       categorical
## 26     thal negative  categorical       categorical
##                                                                                     parameters
## 1                                                                      mean=56.6481, sd=7.4490
## 2                                                                      mean=52.0579, sd=8.7900
## 3                                                                 meanlog=4.8988, sdlog=0.1306
## 4                                                                 meanlog=4.8506, sdlog=0.1325
## 5                                                                 meanlog=5.5039, sdlog=0.1912
## 6                                                                 meanlog=5.4564, sdlog=0.1860
## 7                                                                    mean=138.6870, sd=22.7442
## 8                                                                    mean=157.1099, sd=18.6807
## 9                                                                       mean=1.5085, sd=1.3013
## 10                                                                      mean=0.5844, sd=0.7581
## 11                                                                          M=0.8310, F=0.1690
## 12                                                                          M=0.5488, F=0.4512
## 13 Asymptomatic=0.7671, Non-anginal Pain=0.1233, Atypical Angina=0.0685, Typical Angina=0.0411
## 14 Non-anginal Pain=0.3929, Asymptomatic=0.2976, Atypical Angina=0.2262, Typical Angina=0.0833
## 15                                                                   FALSE=0.8732, TRUE=0.1268
## 16                                                                   FALSE=0.8537, TRUE=0.1463
## 17   Probable Left Ventricular Hypertrophy=0.5278, Normal=0.4306, ST-T Wave Abnormality=0.0417
## 18   Normal=0.5783, Probable Left Ventricular Hypertrophy=0.4096, ST-T Wave Abnormality=0.0120
## 19                                                                   TRUE=0.5915, FALSE=0.4085
## 20                                                                   FALSE=0.8293, TRUE=0.1707
## 21                                           Flat=0.6528, Upsloping=0.2500, Downsloping=0.0972
## 22                                           Upsloping=0.6747, Flat=0.2530, Downsloping=0.0723
## 23                                                      1=0.3425, 0=0.3288, 2=0.2329, 3=0.0959
## 24                                                      0=0.8095, 1=0.1429, 2=0.0238, 3=0.0238
## 25                                Reversible Defect=0.6250, Normal=0.2222, Fixed Defect=0.1528
## 26                                Normal=0.7952, Reversible Defect=0.1566, Fixed Defect=0.0482
# Evidence matrices for the validation split, from the train-only fit.
Lval <- heartdisease_val %>%
  select(-X) %>%
  loglik_matrices(obj$fit, alpha = spec$alpha)

# Keep the row id and the true label next to the evidence-space risk scores.
scores_val <- bind_cols(
  id = heartdisease_val$X,
  num = heartdisease_val$num,
  score_risk(
    Lval$l_pos, Lval$l_neg, obj$fit$weights, obj$geom,
    spec$alpha, spec$eps
  )
)

Learn baseline Random Forest and predict

# Baseline probability forest trained on the training split (X dropped).
rf_train_1 <- ranger(
  formula = num ~ .,
  data = heartdisease_train %>% select(-X),
  num.trees = 500,
  mtry = 4,
  min.node.size = 5,
  probability = TRUE,
  keep.inbag = TRUE
)

# Append the forest's "Heart Disease" probability to the score table as `p`.
scores_val <- bind_cols(
  scores_val,
  p = predict(
    rf_train_1,
    heartdisease_val %>% select(-X, -num)
  )$predictions[, "Heart Disease"]
)

Sample from evidence output

l_pos : Sum of marginal positive evidence
l_neg : Sum of marginal negative evidence
l : Sum of marginal evidence (difference of positive and negative evidence)
proj : Projection of benign-relative z scores of evidences on mean deviation class separation direction learned from training
d_dist : Difference of Mahalanobis Distances of evidence vector from negative-class evidence manifold and positive-class evidence manifold
E_pos : Total energy (sum of squares of projections over k principal components of positive-class) of evidence vector
p : Random Forest probability score for positive-class

# Preview the first five rows of the validation score table.
head(scores_val, 5)
##   id           num     l_pos     l_neg          l         t        proj
## 1  3       Healthy -32.72783 -31.81197 -0.9158537 -64.53980  0.19364106
## 2  8 Heart Disease -21.27192 -25.94749  4.6755715 -47.21941  0.95111451
## 3 13       Healthy -28.04745 -22.89964 -5.1478114 -50.94709 -0.07533242
## 4 14       Healthy -29.23021 -25.74143 -3.4887798 -54.97164  0.03387516
## 5 19       Healthy -27.38981 -20.55902 -6.8307984 -47.94883 -0.32935053
##       d_dist   l_norm      E_pos      E_neg         dE      eig_1      eig_2
## 1  0.7112194 6.459184 36.6215886 40.5967927 -3.9752041 -4.0419595  4.6754727
## 2  1.2833701 2.358627  1.3733554  8.0427772 -6.6694217 -0.7951663 -1.3908320
## 3 -0.4273574 2.912674  1.0661917  0.2625656  0.8036261  1.2476962  0.5504127
## 4 -0.3117642 2.752718  0.5386055  0.7794078 -0.2408023  1.0482897  0.3405261
## 5 -0.9436319 2.682278  0.9711189  1.1449501 -0.1738312  1.7888525  1.7904853
##           p
## 1 0.3388667
## 2 0.9379000
## 3 0.2757000
## 4 0.2797000
## 5 0.0198000

Feature Importance using ratio of feature-wise evidences

# Rank features on the training split using the "mutual_info" method.
# NOTE(review): df still contains the X column here, whereas the train+val
# call further down drops it; the printed table has no X row, so it looks
# ignored — confirm.
feature_importance(df=heartdisease_train, y_col="num", fit = obj$fit, method = "mutual_info", top_n = 15)
##     feature     weight abs_weight
## 1      thal 0.18129883 0.18129883
## 2        cp 0.15105583 0.15105583
## 3   thalach 0.12796960 0.12796960
## 4   oldpeak 0.12442872 0.12442872
## 5     slope 0.12368637 0.12368637
## 6        ca 0.12210565 0.12210565
## 7       age 0.08550260 0.08550260
## 8  trestbps 0.04523472 0.04523472
## 9      chol 0.02436100 0.02436100
## 10  restecg 0.01435668 0.01435668
## 11      sex 0.00000000 0.00000000
## 12      fbs 0.00000000 0.00000000
## 13    exang 0.00000000 0.00000000
# Distribution of the summed relative evidence `l`, filled by true class.
# `bins` is given explicitly (30 is the stat_bin() default that was applied
# implicitly) so the plot no longer emits the "Pick better value" message.
ggplot(scores_val, aes(x = l, fill = num)) +
  geom_histogram(bins = 30, alpha = 0.6) +
  xlab("Sum of relative marginal evidence from all features") +
  theme_minimal()
## `stat_bin()` using `bins = 30`. Pick better value `binwidth`.

# NOTE(review): qlogis() maps p = 0 or p = 1 to -Inf / Inf — confirm no
# validation point is silently dropped from the three log-odds plots.

# Summed evidence against the forest's log-odds.
ggplot(scores_val, aes(x = qlogis(p), y = l, color = num)) +
  geom_point(alpha = 0.7) +
  xlab("log-odds of random forest heart disease probability (p)") +
  theme_minimal()

# Distance contrast against the forest's log-odds.
ggplot(scores_val, aes(x = qlogis(p), y = d_dist, color = num)) +
  geom_point(alpha = 0.7) +
  xlab("log-odds of random forest heart disease probability (p)") +
  theme_minimal()

# Drift projection against the forest's log-odds.
ggplot(scores_val, aes(x = qlogis(p), y = proj, color = num)) +
  geom_point(alpha = 0.7) +
  xlab("log-odds of random forest heart disease probability (p)") +
  theme_minimal()

# The two geometric signals against each other, all validation points.
ggplot(scores_val, aes(x = d_dist, y = proj, color = num)) +
  geom_point(alpha = 0.7) +
  theme_minimal()

# Same view restricted to d_dist < 0.7.
ggplot(filter(scores_val, d_dist < 0.7), aes(x = d_dist, y = proj, color = num)) +
  geom_point(alpha = 0.7) +
  theme_minimal()

# Positive-class energy against each geometric signal, side by side with a
# shared legend; all validation points.
patchwork::wrap_plots(
  ggplot(scores_val, aes(x = d_dist, y = E_pos, color = num)) +
    geom_point(alpha = 0.7) +
    theme_minimal(),
  ggplot(scores_val, aes(x = proj, y = E_pos, color = num)) +
    geom_point(alpha = 0.7) +
    theme_minimal()
) + patchwork::plot_layout(guides = "collect")

# Same two panels restricted to d_dist < 0.7 and E_pos < 400.
patchwork::wrap_plots(
  ggplot(
    filter(scores_val, d_dist < 0.7, E_pos < 400),
    aes(x = d_dist, y = E_pos, color = num)
  ) +
    geom_point(alpha = 0.7) +
    theme_minimal(),
  ggplot(
    filter(scores_val, d_dist < 0.7, E_pos < 400),
    aes(x = proj, y = E_pos, color = num)
  ) +
    geom_point(alpha = 0.7) +
    theme_minimal()
) + patchwork::plot_layout(guides = "collect")

Weighted evidence matrix using optional weights. Default weights (when selected) : KL separation between positive and negative classes per evidence

# Weight the validation evidence matrix Lval$L with the weights stored on the
# train-only fit.
Lval_w <- apply_llr_weights(Lval$L, obj$fit$weights)

Eigenmode Decomposition (PCA) of Heart Disease subset

# Weighted evidence rows of the validation cases labelled "Heart Disease".
Lval_w_M <- Lval_w[heartdisease_val$num == "Heart Disease", , drop = FALSE]

# Class covariance with a 1e-6 ridge on the diagonal, then its symmetric
# eigendecomposition.
Lval_w_Sigma_M <- cov(Lval_w_M) + diag(1e-6, ncol(Lval_w_M))
Lval_w_eig_M <- eigen(Lval_w_Sigma_M, symmetric = TRUE)

Lval_w_eigvals_M <- Lval_w_eig_M[["values"]]
Lval_w_eigvecs_M <- Lval_w_eig_M[["vectors"]]

# Coordinates of the (uncentred) rows on the first two eigenvectors.
Lval_w_coords_M <- Lval_w_M %*% Lval_w_eigvecs_M[, c(1, 2)]

Variance Explained

# Share of the TOTAL variance carried by each of the first two eigenmodes.
# The denominator is the sum of all eigenvalues; dividing by the sum of only
# the first two forces the two shares to add up to 1 by construction.
Lval_w_eigvals_M[1:2] / sum(Lval_w_eigvals_M)
## [1] 0.7778224 0.2221776

Top 5 feature loadings

# Top 5 feature loadings on eigenmode 1 of the Heart Disease subset
# (presumably ranked by absolute loading, per the printed abs_loading column).
decompose_eigenmode(Lval_w_eigvecs_M, k=1, feature_names = obj$fit$features, top_n=5)
## # A tibble: 5 × 3
##   feature  loading abs_loading
##   <chr>      <dbl>       <dbl>
## 1 oldpeak   0.977       0.977 
## 2 slope     0.129       0.129 
## 3 exang     0.108       0.108 
## 4 thal      0.0935      0.0935
## 5 trestbps  0.0511      0.0511
# Top 5 feature loadings on eigenmode 2 of the Heart Disease subset
# (presumably ranked by absolute loading, per the printed abs_loading column).
decompose_eigenmode(Lval_w_eigvecs_M, k=2, feature_names = obj$fit$features, top_n=5)
## # A tibble: 5 × 3
##   feature loading abs_loading
##   <chr>     <dbl>       <dbl>
## 1 ca        0.565       0.565
## 2 thal     -0.565       0.565
## 3 thalach   0.512       0.512
## 4 age       0.249       0.249
## 5 exang    -0.130       0.130

Eigenmode Decomposition (PCA) of Healthy subset

# Weighted evidence rows of the validation cases labelled "Healthy".
Lval_w_B <- Lval_w[heartdisease_val$num == "Healthy", , drop = FALSE]

# Class covariance with a 1e-6 ridge on the diagonal, then its symmetric
# eigendecomposition.
Lval_w_Sigma_B <- cov(Lval_w_B) + diag(1e-6, ncol(Lval_w_B))
Lval_w_eig_B <- eigen(Lval_w_Sigma_B, symmetric = TRUE)

Lval_w_eigvals_B <- Lval_w_eig_B[["values"]]
Lval_w_eigvecs_B <- Lval_w_eig_B[["vectors"]]

# Coordinates of the (uncentred) rows on the first two eigenvectors.
Lval_w_coords_B <- Lval_w_B %*% Lval_w_eigvecs_B[, c(1, 2)]

Variance Explained

# Share of the TOTAL variance carried by each of the first two eigenmodes.
# The denominator is the sum of all eigenvalues; dividing by the sum of only
# the first two forces the two shares to add up to 1 by construction.
Lval_w_eigvals_B[1:2] / sum(Lval_w_eigvals_B)
## [1] 0.5197891 0.4802109

Top 5 feature loadings

# Top 5 feature loadings on eigenmode 1 of the Healthy subset
# (presumably ranked by absolute loading, per the printed abs_loading column).
decompose_eigenmode(Lval_w_eigvecs_B, k=1, feature_names = obj$fit$features, top_n=5)
## # A tibble: 5 × 3
##   feature loading abs_loading
##   <chr>     <dbl>       <dbl>
## 1 ca        0.738       0.738
## 2 oldpeak  -0.452       0.452
## 3 age       0.340       0.340
## 4 slope    -0.290       0.290
## 5 sex      -0.174       0.174
# Top 5 feature loadings on eigenmode 2 of the Healthy subset
# (presumably ranked by absolute loading, per the printed abs_loading column).
decompose_eigenmode(Lval_w_eigvecs_B, k=2, feature_names = obj$fit$features, top_n=5)
## # A tibble: 5 × 3
##   feature loading abs_loading
##   <chr>     <dbl>       <dbl>
## 1 thal     -0.604       0.604
## 2 oldpeak   0.592       0.592
## 3 ca        0.303       0.303
## 4 sex      -0.292       0.292
## 5 slope     0.247       0.247

Angle (in degrees) between Eigenmode 1 of Heart Disease and Eigenmode 1 of Healthy, derived from their cosine similarity

# Angle, in degrees, between the first Healthy and the first Heart Disease
# eigenvectors: acos of their cosine similarity. The value is a 1 x 1 matrix.
# NOTE(review): eigenvector signs are arbitrary, so an angle a and 180 - a
# describe the same pair of axes — confirm which one is meant to be reported.
local({
  healthy_1 <- Lval_w_eigvecs_B[, 1]
  disease_1 <- Lval_w_eigvecs_M[, 1]
  dot_product <- t(healthy_1) %*% disease_1
  norm_healthy <- sqrt(t(healthy_1) %*% healthy_1)
  norm_disease <- sqrt(t(disease_1) %*% disease_1)
  acos(dot_product / (norm_healthy * norm_disease)) * 180 / pi
})
##          [,1]
## [1,] 118.2988

Fitting Evidence Model and Random Forest on train+val set

# Refit the evidence model on train + val (X dropped from both).
# NOTE(review): spec, and therefore spec$alpha, was built from the training
# split alone — confirm the class ratio should not be recomputed on train + val.
obj_2 <- bind_rows(
  heartdisease_train %>% select(-X),
  heartdisease_val %>% select(-X)
) %>%
  fit(spec, k_eigen = 2, k_energy = 2, energy_ref = "both")

# Evidence matrices for the held-out test split.
Ltest <- heartdisease_test %>%
  select(-X) %>%
  loglik_matrices(obj_2$fit, alpha = spec$alpha)

# Keep the row id and the true label next to the test-set risk scores.
scores_test <- bind_cols(
  X = heartdisease_test$X,
  num = heartdisease_test$num,
  score_risk(
    Ltest$l_pos, Ltest$l_neg, obj_2$fit$weights, obj_2$geom,
    spec$alpha, spec$eps
  )
)

# Feature ranking on the same train + val data.
feature_importance(
  df = bind_rows(
    heartdisease_train %>% select(-X),
    heartdisease_val %>% select(-X)
  ),
  y_col = "num",
  fit = obj_2$fit,
  method = "mutual_info",
  top_n = 15
)
##     feature     weight abs_weight
## 1      thal 0.16758747 0.16758747
## 2   thalach 0.16458909 0.16458909
## 3        cp 0.16227125 0.16227125
## 4   oldpeak 0.15959587 0.15959587
## 5     slope 0.11678890 0.11678890
## 6        ca 0.08996638 0.08996638
## 7       age 0.08094144 0.08094144
## 8  trestbps 0.03139123 0.03139123
## 9      chol 0.02686837 0.02686837
## 10      sex 0.00000000 0.00000000
## 11      fbs 0.00000000 0.00000000
## 12  restecg 0.00000000 0.00000000
## 13    exang 0.00000000 0.00000000
# List the likelihood family and parameters chosen per feature by the
# train + val fit.
print_feature_likelihoods(obj_2)
##     feature    class feature_type likelihood_family
## 1       age positive      numeric          gaussian
## 2       age negative      numeric          gaussian
## 3  trestbps positive      numeric         lognormal
## 4  trestbps negative      numeric         lognormal
## 5      chol positive      numeric         lognormal
## 6      chol negative      numeric         lognormal
## 7   thalach positive      numeric          gaussian
## 8   thalach negative      numeric          gaussian
## 9   oldpeak positive      numeric          gaussian
## 10  oldpeak negative      numeric          gaussian
## 11      sex positive  categorical       categorical
## 12      sex negative  categorical       categorical
## 13       cp positive  categorical       categorical
## 14       cp negative  categorical       categorical
## 15      fbs positive  categorical       categorical
## 16      fbs negative  categorical       categorical
## 17  restecg positive  categorical       categorical
## 18  restecg negative  categorical       categorical
## 19    exang positive  categorical       categorical
## 20    exang negative  categorical       categorical
## 21    slope positive  categorical       categorical
## 22    slope negative  categorical       categorical
## 23       ca positive  categorical       categorical
## 24       ca negative  categorical       categorical
## 25     thal positive  categorical       categorical
## 26     thal negative  categorical       categorical
##                                                                                     parameters
## 1                                                                      mean=56.7608, sd=7.7632
## 2                                                                      mean=52.6480, sd=9.0901
## 3                                                                 meanlog=4.8907, sdlog=0.1303
## 4                                                                 meanlog=4.8529, sdlog=0.1282
## 5                                                                 meanlog=5.4964, sdlog=0.1941
## 6                                                                 meanlog=5.4560, sdlog=0.1806
## 7                                                                    mean=138.5413, sd=22.4257
## 8                                                                    mean=158.2420, sd=18.1325
## 9                                                                       mean=1.5662, sd=1.2677
## 10                                                                      mean=0.5543, sd=0.7729
## 11                                                                          M=0.8208, F=0.1792
## 12                                                                          M=0.5574, F=0.4426
## 13 Asymptomatic=0.7407, Non-anginal Pain=0.1481, Atypical Angina=0.0556, Typical Angina=0.0556
## 14 Non-anginal Pain=0.3952, Asymptomatic=0.2823, Atypical Angina=0.2419, Typical Angina=0.0806
## 15                                                                   FALSE=0.8491, TRUE=0.1509
## 16                                                                   FALSE=0.8361, TRUE=0.1639
## 17   Probable Left Ventricular Hypertrophy=0.5888, Normal=0.3832, ST-T Wave Abnormality=0.0280
## 18   Normal=0.6179, Probable Left Ventricular Hypertrophy=0.3740, ST-T Wave Abnormality=0.0081
## 19                                                                   TRUE=0.5660, FALSE=0.4340
## 20                                                                   FALSE=0.8607, TRUE=0.1393
## 21                                           Flat=0.6449, Upsloping=0.2617, Downsloping=0.0935
## 22                                           Upsloping=0.6585, Flat=0.2683, Downsloping=0.0732
## 23                                                      0=0.3241, 1=0.3241, 2=0.2222, 3=0.1296
## 24                                                      0=0.7742, 1=0.1371, 2=0.0565, 3=0.0323
## 25                Reversible Defect=0.6019, Normal=0.2685, Fixed Defect=0.1111, Unknown=0.0185
## 26                Normal=0.7823, Reversible Defect=0.1694, Fixed Defect=0.0403, Unknown=0.0081
# Probability forest refit on train + val, same hyper-parameters as rf_train_1.
rf_train_2 <- ranger(
  formula = num ~ .,
  data = bind_rows(
    heartdisease_train %>% select(-X),
    heartdisease_val %>% select(-X)
  ),
  num.trees = 500,
  mtry = 4,
  min.node.size = 5,
  probability = TRUE,
  keep.inbag = TRUE
)

Test Set Performance

# Append the refit forest's "Heart Disease" probability to the test scores.
scores_test <- bind_cols(
  scores_test,
  p = predict(
    rf_train_2,
    heartdisease_test %>% select(-X)
  )$predictions[, "Heart Disease"]
)

Triage Rule

If \(p \geq 0.65\) \(\rightarrow\) Heart Disease
Else If \(p > 0.45\) \(\rightarrow\) Review
Else
If \(proj > 0\) \(\rightarrow\) Review
Else If \(d\_dist \geq 0\) \(\rightarrow\) Review
Else \(\rightarrow\) Healthy

# Assign a triage action to every scored case.
#
# df: data frame with numeric columns `p` (classifier probability), `proj`
#   and `d_dist`.
# p_hi: cases with p >= p_hi are actioned "Disease".
# p_review: remaining cases with p > p_review are actioned "Review".
# tau_p, tau_d: in the low-probability region a case is sent to "Review" when
#   proj > tau_p or d_dist >= tau_d.
#
# Returns df with a character column `action` added (or overwritten) holding
# "Disease", "Review" or "Healthy".
#
# A case is only auto-labelled "Healthy" when p, proj and d_dist are all
# present and all below their thresholds. If a value needed for that decision
# is missing the case goes to "Review" instead of silently falling through to
# "Healthy". Written in base R so it does not rely on dplyr being attached.
apply_triage <- function(df, p_hi = 0.65, p_review = 0.45, tau_d = 0, tau_p = 0) {
  absent <- setdiff(c("p", "proj", "d_dist"), names(df))
  if (length(absent) > 0) {
    stop(
      "apply_triage(): missing column(s): ", paste(absent, collapse = ", "),
      call. = FALSE
    )
  }

  p <- df[["p"]]
  proj <- df[["proj"]]
  d_dist <- df[["d_dist"]]

  is_disease <- !is.na(p) & p >= p_hi
  rf_ambiguous <- !is.na(p) & p > p_review
  incomplete <- is.na(p) | is.na(proj) | is.na(d_dist)
  geom_flag <- !incomplete & (proj > tau_p | d_dist >= tau_d)

  # Later assignments take precedence, mirroring the rule order:
  # Disease > probability-based Review > geometric Review > Healthy.
  action <- rep("Healthy", nrow(df))
  action[incomplete | geom_flag] <- "Review"
  action[rf_ambiguous] <- "Review"
  action[is_disease] <- "Disease"

  df[["action"]] <- action
  df
}

# Cross-tabulate the triage action (rows) against the true label (columns).
#
# df: data frame holding both columns.
# truth_col: name of the ground-truth column.
# action_col: name of the action column; also the name given to the first
#   column of the result.
# Returns a data frame of counts with one row per action value.
triage_table <- function(df, truth_col, action_col = "action") {
  counts <- as.data.frame.matrix(table(df[[action_col]], df[[truth_col]]))
  tibble::rownames_to_column(counts, var = action_col)
}

Triage Results

# Triage the test-set scores with the thresholds stated in the triage rule,
# then tabulate the actions against the true labels.
test_triaged <- apply_triage(
  scores_test,
  p_hi = 0.65, p_review = 0.45, tau_d = 0, tau_p = 0
)

triage_table(test_triaged, "num", "action")
##    action Healthy Heart Disease
## 1 Disease       2            20
## 2 Healthy      22             1
## 3  Review      16            13
# Summarise triage outcomes on a scored and triaged data frame.
#
# df: data frame with an `action` column holding "Healthy" / "Review" /
#   "Disease", the truth column and the probability column.
# truth_col: name of the factor column holding the true class.
# levels: character vector of class labels; levels[1] is treated as the
#   negative class and levels[2] as the positive class.
# p_col: name of the classifier-probability column.
# p_review, p_hi: probability cut-offs for the low region (p <= p_review) and
#   the high region (p >= p_hi). The defaults are the values that were
#   previously hard-coded in the body.
#
# Returns a one-row tibble with:
#   auto_benign_fn_rate: share of positives among rows actioned "Healthy".
#   fn_capture_rate: 1 minus the share of ALL rows that are positive and
#     actioned "Healthy".
#     NOTE(review): the denominator is every row, not only the positives —
#     confirm this is the intended definition.
#   overall_review_rate: share of rows actioned "Review".
#   benign_region_review_rate: share of low-region rows actioned "Review".
#   malignant_region_fp_override_rate: share of high-region rows whose true
#     class is levels[1].
# A rate computed over an empty set of rows comes out as NaN.
test_metrics <- function(df, truth_col, levels, p_col,
                         p_review = 0.45, p_hi = 0.65) {
  truth <- df[[truth_col]]
  # The `levels` argument shadows levels(); call the function by namespace.
  stopifnot(
    length(levels) >= 2,
    all(base::levels(truth) %in% levels)
  )

  action <- df[["action"]]
  prob <- df[[p_col]]

  is_positive <- truth == levels[2]
  is_negative <- truth == levels[1]
  auto_healthy <- action == "Healthy"
  in_review <- action == "Review"
  low_region <- prob <= p_review
  high_region <- prob >= p_hi

  tibble::tibble(
    # which() drops rows whose action is NA, as a filter on the column would.
    auto_benign_fn_rate = mean(is_positive[which(auto_healthy)]),
    fn_capture_rate = 1 - mean(is_positive & auto_healthy),
    overall_review_rate = mean(in_review),
    benign_region_review_rate = mean(low_region & in_review) / mean(low_region),
    malignant_region_fp_override_rate =
      mean(high_region & is_negative) / mean(high_region) # optional
  )
}

# Summary rates for the triaged test set; class order is taken from the
# factor levels of the test labels.
test_metrics(test_triaged, truth_col="num", p_col="p", levels=levels(heartdisease_test$num))
## # A tibble: 1 × 5
##   auto_benign_fn_rate fn_capture_rate overall_review_rate benign_region_review…¹
##                 <dbl>           <dbl>               <dbl>                  <dbl>
## 1              0.0435           0.986               0.392                  0.378
## # ℹ abbreviated name: ¹​benign_region_review_rate
## # ℹ 1 more variable: malignant_region_fp_override_rate <dbl>