MATH 427: Class Imbalance

Eric Friedlander

Computational Set-Up

library(tidyverse)
library(tidymodels)
library(knitr)
library(kableExtra)

tidymodels_prefer()

set.seed(427)

Exploring with App

  • App
    • Break into groups
    • Investigate how your performance metrics change between balanced data and unbalanced data
    • Additional Considerations:
      • Impact of boundaries/models?
      • Impact of sample size?
      • Impact of noise level?
    • Please write down observations so we can discuss

Dealing with Class-Imbalance

Class-Imbalance

  • Class-imbalance occurs when the classes in your response variable differ greatly in how common they are
  • Occurs frequently:
    • Medicine: survival/death
    • Admissions: enrollment/non-enrollment
    • Finance: repaid loan/defaulted
    • Tech: Clicked on ad/Didn’t click
    • Tech: Churn rate
    • Finance: Fraud

Data: haberman

Study conducted between 1958 and 1970 at the University of Chicago’s Billings Hospital on the survival of patients who had undergone surgery for breast cancer.

Goal: predict whether a patient survived (here, lived five or more years) after undergoing surgery for breast cancer.

# The raw file has no header row, so supply column names explicitly
haberman <- read_csv("../data/haberman.data",
                     col_names = c("Age", "OpYear", "AxNodes", "Survival"))
haberman |> head() |> kable()
Age OpYear AxNodes Survival
30 64 1 1
30 62 3 1
30 65 0 1
31 59 2 1
31 65 4 1
33 58 10 1

Quick Clean

# Recode the numeric codes as a labeled factor; "Died" listed first so it is treated as the event level
haberman <- haberman |> 
  mutate(Survival = factor(if_else(Survival == 1, "Survived", "Died"),
                           levels = c("Died", "Survived")))
haberman |> head() |> kable()
Age OpYear AxNodes Survival
30 64 1 Survived
30 62 3 Survived
30 65 0 Survived
31 59 2 Survived
31 65 4 Survived
33 58 10 Survived

Split Data

set.seed(427)
hab_splits <- initial_split(haberman, prop = 0.75, strata = Survival)
hab_train <- training(hab_splits)
hab_test <- testing(hab_splits)

Visualizing Response

hab_train |> 
  ggplot(aes(y = Survival)) +
  geom_bar()

Fitting Model

lr_model <- logistic_reg() |> 
  set_engine("glm")

lr_fit <- lr_model |> 
  fit(Survival ~ . , data = hab_train)

Confusion Matrix

lr_fit |> augment(new_data = hab_test) |> 
  conf_mat(truth = Survival, estimate = .pred_class) |> autoplot(type = "heatmap")

Performance Metrics

hab_metrics <- metric_set(accuracy, precision, recall)

lr_fit |> augment(new_data = hab_test) |> 
  roc_auc(truth = Survival, .pred_Died) |> kable()
.metric .estimator .estimate
roc_auc binary 0.7284879
lr_fit |> augment(new_data = hab_test) |> 
  hab_metrics(truth = Survival, estimate = .pred_class) |> kable()
.metric .estimator .estimate
accuracy binary 0.7692308
precision binary 0.8000000
recall binary 0.1904762

Recall is BAD!

  • Since there are so few deaths, the model almost always predicts a low probability of death
  • Idea: just because a patient has a HIGHER probability of death doesn't mean they have a HIGH probability of death (i.e., one above the default 0.5 cutoff); the quick check below illustrates this
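
A quick check (a sketch I am adding here, not part of the original analysis) makes this concrete: it reports the average predicted probability of death and the share of test cases that clear the default 0.5 cutoff, which, given the recall above, should be small.

lr_fit |> 
  augment(new_data = hab_test) |> 
  summarise(mean_pred_died = mean(.pred_Died),          # average predicted probability of death
            prop_above_half = mean(.pred_Died > 0.5))   # share of test cases that would be labeled "Died"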

What do we do?

  • Depends on what your goal is…
  • Ask yourself: What is most important to my problem?
    • Accurate probabilities?
    • Overall accuracy?
    • Effective identification of a specific class (e.g. positives)?
    • Low false-positive rate?
  • Discussion: Let’s think of scenarios where each one of these is the most important.

Solutions to Class Imbalance

  • Sampling-based solutions (done during pre-processing)
    • Over-sample minority class
    • Under-sample majority class
    • Combination of both (e.g. SMOTE)
  • Weight the classes/objective function (a sketch of this option follows the list)
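
The sampling-based options are demonstrated on the next slides. For class/objective weighting, here is a minimal sketch using tidymodels case weights; the column name case_wt, the explicit formula, and the rounded weighting ratio are my own assumptions rather than anything from these slides.

# Give each "Died" row a larger weight so the fit pays the minority class as
# much total attention as the majority class. Rounding to an integer keeps
# glm() from warning about non-integer binomial weights.
wt_ratio <- with(hab_train, round(sum(Survival == "Survived") / sum(Survival == "Died")))

hab_weighted <- hab_train |> 
  mutate(case_wt = hardhat::importance_weights(if_else(Survival == "Died", wt_ratio, 1)))

weighted_fit <- workflow() |> 
  add_formula(Survival ~ Age + OpYear + AxNodes) |>   # weight column kept out of the formula
  add_case_weights(case_wt) |> 
  add_model(lr_model) |> 
  fit(hab_weighted)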

Over-sampling minority class

  • Upsample: think of bootstrapping, except the final sample is larger than the original
  • Idea: upsample the minority class until it is (roughly) the same size as the majority class

Visualizing Data

hab_train |> ggplot(aes(x = OpYear, y = Age, color = Survival)) +
  geom_jitter()

Upsample Recipe

library(themis)
# over_ratio = 1: resample the minority class until it is as large as the majority class
upsample_recipe <- recipe(Survival ~ ., data = hab_train) |> 
  step_upsample(Survival, over_ratio = 1)

hab_upsample <- upsample_recipe |> prep(hab_train) |> bake(new_data = NULL)

Upsampled Data

hab_upsample |>  ggplot(aes(x = Survival)) +
  geom_bar()

Visualizing Upsampled Data

hab_upsample |>  ggplot(aes(x = OpYear, y = Age, color = Survival)) +
  geom_jitter()

Visualizing Upsampled Data: No Jitter

hab_upsample |>  ggplot(aes(x = OpYear, y = Age, color = Survival)) +
  geom_point()

Performance consideration

  • Pro:
    • Preserves all information in the data set
  • Con:
    • Models will likely overfit to the noise in the minority class, since the same minority rows are repeated many times

Under-sampling majority class

  • Downsample: take a random sample that is smaller than the original sample
  • Idea: downsample the majority class until it is (roughly) the same size as the minority class

Visualizing Data

hab_train |> ggplot(aes(x = OpYear, y = Age, color = Survival)) +
  geom_jitter()

Downsample Recipe

# under_ratio = 1: sample the majority class down to the size of the minority class
downsample_recipe <- recipe(Survival ~ ., data = hab_train) |> 
  step_downsample(Survival, under_ratio = 1)

hab_downsample <- downsample_recipe |> prep(hab_train) |> bake(new_data = NULL)

Downsampled Data

hab_downsample |>  ggplot(aes(x = Survival)) +
  geom_bar()

Visualizing Downsampled Data

hab_downsample |>  ggplot(aes(x = OpYear, y = Age, color = Survival)) +
  geom_jitter()

Visualizing Downsampled Data: No Jitter

hab_downsample |>  ggplot(aes(x = OpYear, y = Age, color = Survival)) +
  geom_point()

Performance considerations

  • Pro:
    • Model doesn’t overfit to the noise in the minority class
  • Con:
    • Lose information from majority class

SMOTE

  • Basic idea:
    • Both upsample the minority class and downsample the majority class (the themis implementation used here, step_smote(), only upsamples)
  • Smarter upsampling: instead of just randomly replicating minority observations
    • Find the (minority-class) nearest neighbors of each minority observation
    • Draw a line segment between the observation and a randomly chosen neighbor
    • Generate a synthetic point at a random location along that segment (a toy example of this interpolation follows)
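
To make the interpolation step concrete, here is a toy illustration with made-up numbers (not the themis implementation): a synthetic point is a random convex combination of a minority observation and one of its minority-class neighbors.

x        <- c(Age = 34, OpYear = 60, AxNodes = 9)    # a minority ("Died") observation, made-up values
neighbor <- c(Age = 38, OpYear = 62, AxNodes = 11)   # one of its minority-class nearest neighbors, made-up
u <- runif(1)                                        # random position along the segment, in [0, 1]
synthetic <- x + u * (neighbor - x)                  # new synthetic minority observation
synthetic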

Visualizing Data

hab_train |> ggplot(aes(x = OpYear, y = Age, color = Survival)) +
  geom_jitter()

SMOTE Recipe

# SMOTE relies on nearest-neighbor distances, so put predictors on a common scale first
smote_recipe <- recipe(Survival ~ ., data = hab_train) |> 
  step_normalize(all_numeric_predictors()) |> 
  step_smote(Survival, over_ratio = 1, neighbors = 5)

hab_smote <- smote_recipe |> prep(hab_train) |> bake(new_data = NULL)

SMOTE Data

hab_smote |>  ggplot(aes(x = Survival)) +
  geom_bar()

Visualizing SMOTE Data

hab_smote |>  ggplot(aes(x = OpYear, y = Age, color = Survival)) +
  geom_jitter()

Visualizing SMOTE Data: No Jitter

hab_smote |>  ggplot(aes(x = OpYear, y = Age, color = Survival)) +
  geom_point()

Performance considerations

  • Pro:
    • Model doesn’t overfit (as much) to the noise in the minority class
    • Don’t lose (as much) information from the majority class
  • Con:
    • Creates synthetic observations that were never actually collected

Fitting models

oversamp_fit <- workflow() |> add_recipe(upsample_recipe) |> 
  add_model(lr_model) |> fit(hab_train)
downsamp_fit <- workflow() |> add_recipe(downsample_recipe) |> 
  add_model(lr_model)  |> fit(hab_train)
smote_fit <- workflow() |> add_recipe(smote_recipe) |> 
  add_model(lr_model)  |> fit(hab_train)

Evaluate Performance

oversamp_fit |> augment(new_data = hab_test) |> 
  roc_auc(truth = Survival, .pred_Died) |> kable()
.metric .estimator .estimate
roc_auc binary 0.7343358
downsamp_fit |> augment(new_data = hab_test) |> 
  roc_auc(truth = Survival, .pred_Died) |> kable()
.metric .estimator .estimate
roc_auc binary 0.7243108
smote_fit |> augment(new_data = hab_test) |> 
  roc_auc(truth = Survival, .pred_Died) |> kable()
.metric .estimator .estimate
roc_auc binary 0.7251462

Evaluate Performance

oversamp_fit |> augment(new_data = hab_test) |> 
  hab_metrics(truth = Survival, estimate = .pred_class) |> kable()
.metric .estimator .estimate
accuracy binary 0.7692308
precision binary 0.5714286
recall binary 0.5714286
downsamp_fit |> augment(new_data = hab_test) |> 
  hab_metrics(truth = Survival, estimate = .pred_class) |> kable()
.metric .estimator .estimate
accuracy binary 0.7435897
precision binary 0.5217391
recall binary 0.5714286
smote_fit |> augment(new_data = hab_test) |> 
  hab_metrics(truth = Survival, estimate = .pred_class) |> kable()
.metric .estimator .estimate
accuracy binary 0.7435897
precision binary 0.5217391
recall binary 0.5714286