Logistic regression

This logistic regression model was built for comparison with the Random Forest approach.

# Loading libraries
library(lubridate)

Attaching package: 'lubridate'
The following objects are masked from 'package:base':

    date, intersect, setdiff, union
library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr   1.1.4     ✔ readr   2.2.0
✔ forcats 1.0.1     ✔ stringr 1.6.0
✔ ggplot2 3.5.2     ✔ tibble  3.2.1
✔ purrr   1.2.1     ✔ tidyr   1.3.2
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.4.1 ──
✔ broom        1.0.12     ✔ rsample      1.3.2 
✔ dials        1.4.2      ✔ tailor       0.1.0 
✔ infer        1.1.0      ✔ tune         2.0.1 
✔ modeldata    1.5.1      ✔ workflows    1.3.0 
✔ parsnip      1.4.1      ✔ workflowsets 1.1.1 
✔ recipes      1.3.1      ✔ yardstick    1.3.2 
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ scales::discard() masks purrr::discard()
✖ dplyr::filter()   masks stats::filter()
✖ recipes::fixed()  masks stringr::fixed()
✖ dplyr::lag()      masks stats::lag()
✖ yardstick::spec() masks readr::spec()
✖ recipes::step()   masks stats::step()
library(here)
here() starts at C:/Users/marco/Documents/marcoreina-MADA-project
library(gt)
# Load the processed long-format AMR table (same input file as the Random
# Forest script).
amr_long_rf <- read_csv(file = here("data", "processed-data", "amr_long.csv"))
Rows: 688980 Columns: 10
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr  (8): Isolation.source, Source.type, BioSample, antigen_formula, serotyp...
dbl  (1): year
dttm (1): Create.date

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Keep an AMR class only when AMR_status is "COMPLETE"; every other status is
# relabelled "NO AMR". Rows with a missing AMR_status get NA from if_else().
amr_long_rf <- amr_long_rf %>%
  mutate(
    AMR_class_strict = if_else(AMR_status == "COMPLETE", AMR_class, "NO AMR")
  )

# One row per isolate (BioSample) with a 0/1 indicator column per AMR class.
amr_wide_rf <- amr_long_rf %>%
  select(Create.date, BioSample, AMR_class_strict, serotype, Source.type) %>%
  distinct(BioSample, AMR_class_strict, .keep_all = TRUE) %>%
  mutate(present = 1) %>%
  pivot_wider(
    id_cols = c(Create.date, BioSample, serotype, Source.type),
    names_from = AMR_class_strict,
    values_from = present,
    values_fill = 0
  ) %>%
  # Drop the "Efflux Pump" and "NO AMR" indicator columns; any_of() makes this
  # a no-op if either column is absent from the pivoted data.
  select(-any_of(c("Efflux Pump", "NO AMR")))

# Date features, the number of AMR classes detected per isolate, and binary
# outcomes derived from that count.
amr_wide_rf <- amr_wide_rf %>%
  mutate(
    sample_year = as.factor(year(Create.date)),
    sample_month = month(Create.date, label = TRUE),
    total_amr_classes = rowSums(across(all_of(c(
      "Aminoglycoside", "Fosfomycin", "Beta_lactam", "Gentamicin",
      "Trimethoprim", "Quinolone", "Phenicol", "Colistin", "Macrolide",
      "Bleomycin", "Tetracycline", "Sulfonamide", "Lincosamide"
    )))),
    amr_prevalence = if_else(total_amr_classes == 0, 0, 1),
    # NOTE(review): this marks an isolate as MDR only with 4 or more classes.
    # The widely used Magiorakos et al. (2012) definition is >= 3
    # antimicrobial categories -- confirm the intended threshold and keep it in
    # sync with the Random Forest script.
    mdr_prevalence = if_else(total_amr_classes <= 3, 0, 1)
  )

# Top 20 serovars; every other serotype is pooled into "Other".
top_20 <- c(
  "I 4,[5],12:i:-", "Infantis", "Typhimurium", "Newport", "Agona",
  "Kentucky", "Enteritidis", "Anatum", "Saintpaul", "Braenderup",
  "Muenchen", "Sandiego", "Javiana", "Montevideo", "Oranienburg",
  "Thompson", "Mississippi", "Bareilly", "Poona", "Rubislaw"
)

# New column `top_20`: the serotype if it is in the list above, else "Other".
# `.env$top_20` pins the lookup to the vector above: inside mutate() a bare
# `top_20` would resolve to the data column of the same name once that column
# exists (e.g. when this chunk is re-run).
amr_wide_rf <- amr_wide_rf %>%
  mutate(top_20 = if_else(serotype %in% .env$top_20, serotype, "Other"))

# Set seed
rngseed <- 1234
set.seed(rngseed)

# Encode the outcome and selected class indicators as factors with "1"
# (resistant / class present) as the FIRST level. yardstick treats the first
# level as the event by default, which is why the probability column used
# below is `.pred_1`.
amr_wide_rf <- amr_wide_rf %>%
  mutate(across(
    all_of(c(
      "amr_prevalence", "Tetracycline", "Aminoglycoside",
      "Sulfonamide", "Beta_lactam", "Phenicol"
    )),
    \(x) factor(x, levels = c(1, 0))
  ))

# 75/25 train/test split, stratified on amr_prevalence so both sets keep
# roughly the same share of resistant isolates. Re-seeded so the split is
# reproducible on its own.
set.seed(rngseed)
split <- initial_split(
  data = amr_wide_rf,
  prop = 3 / 4,
  strata = amr_prevalence
)
train_data <- training(x = split)
test_data <- testing(x = split)
# Model specification: logistic regression fitted with the "glm" engine.
# Named `lr_model_spec` (not `lr_spec`) so the specificity result that is
# later assigned to `lr_spec` does not overwrite the model specification.
lr_model_spec <- logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification")

# Preprocessing recipe (no scaling/standardisation is applied):
# - sample_month comes from month(label = TRUE), an ORDERED factor, for which
#   step_dummy() uses polynomial contrasts. step_unorder() turns it into a
#   plain factor so it gets ordinary reference-month dummies, which are
#   interpretable for a cyclical variable. Because both encodings span the
#   same design space, fitted probabilities should be unchanged.
# - step_dummy() then dummy-codes every nominal predictor.
lr_recipe <- recipe(
  amr_prevalence ~ Source.type + top_20 + sample_month,
  data = train_data
) %>%
  step_unorder(sample_month) %>%
  step_dummy(all_nominal_predictors())

# Bundle preprocessing and model into a single workflow.
lr_workflow <- workflow() %>%
  add_recipe(lr_recipe) %>%
  add_model(lr_model_spec)

# Fit the workflow (recipe preparation + glm fit) on the training set.
lr_fit <- fit(lr_workflow, data = train_data)

# In-sample ROC AUC on the training set (compare with the test AUC to gauge
# overfitting).
train_pred_lr <- bind_cols(
  train_data,
  predict(lr_fit, new_data = train_data, type = "prob")
)
train_auc <- roc_auc(train_pred_lr, truth = amr_prevalence, .pred_1)

cat("Training AUC:", train_auc[[".estimate"]])
Training AUC: 0.8268975
# Test-set predictions: hard class (.pred_class) plus class probabilities
# (.pred_1 / .pred_0) appended to the test data.
test_pred_lr <- test_data %>%
  bind_cols(
    predict(lr_fit, new_data = test_data),
    predict(lr_fit, new_data = test_data, type = "prob")
  )

# yardstick's default metric set for a two-class outcome with probabilities,
# tagged with the model name for side-by-side comparison. Stored as
# `lr_test_metrics` so it is not overwritten when `lr_metrics` is assigned
# from the confusion-matrix summary below.
lr_test_metrics <- test_pred_lr %>%
  metrics(truth = amr_prevalence, estimate = .pred_class, .pred_1) %>%
  mutate(model = "Logistic Regression")

# Confusion matrix of hard class predictions on the test set.
lr_conf_mat <- test_pred_lr %>%
  conf_mat(truth = amr_prevalence, estimate = .pred_class)

print(lr_conf_mat)
          Truth
Prediction     1     0
         1  3178  1109
         0  6449 50707
# Every metric yardstick derives from the confusion matrix (accuracy, kappa,
# sensitivity, specificity, PPV/NPV, MCC, F1, ...).
lr_metrics <- summary(object = lr_conf_mat)
print(lr_metrics)
# A tibble: 13 × 3
   .metric              .estimator .estimate
   <chr>                <chr>          <dbl>
 1 accuracy             binary        0.877 
 2 kap                  binary        0.399 
 3 sens                 binary        0.330 
 4 spec                 binary        0.979 
 5 ppv                  binary        0.741 
 6 npv                  binary        0.887 
 7 mcc                  binary        0.440 
 8 j_index              binary        0.309 
 9 bal_accuracy         binary        0.654 
10 detection_prevalence binary        0.0698
11 precision            binary        0.741 
12 recall               binary        0.330 
13 f_meas               binary        0.457 
# Headline test-set metrics for the model-comparison table.
lr_auc <- roc_auc(test_pred_lr, truth = amr_prevalence, .pred_1)
lr_sens <- sens(test_pred_lr, truth = amr_prevalence, estimate = .pred_class)
lr_spec <- spec(test_pred_lr, truth = amr_prevalence, estimate = .pred_class)

cat(
  "Test AUC:", lr_auc[[".estimate"]],
  "\nSensitivity:", lr_sens[[".estimate"]],
  "\nSpecificity:", lr_spec[[".estimate"]]
)
Test AUC: 0.8222631 
Sensitivity: 0.3301132 
Specificity: 0.9785973
# Density of the predicted probability of resistance (.pred_1), split by the
# observed class. Colours and legend labels are tied to the factor levels by
# name ("1" = resistant, "0" = susceptible) instead of by position, so they
# stay correct if the level order changes or a level is missing.
ggplot(test_pred_lr, aes(x = .pred_1, fill = amr_prevalence)) +
  geom_density(alpha = 0.6) +
  scale_fill_manual(
    values = c("1" = "#D55E00", "0" = "#56B4E9"),
    breaks = c("1", "0"),
    labels = c("Resistant", "Susceptible"),
    name = "Observed Status"
  ) +
  labs(
    title = "Logistic Regression: Predicted Probability Distribution",
    x = "Predicted Probability of Resistance",
    y = "Density"
  ) +
  theme_bw()