CROSS VALIDATION

# Load packages
library(lubridate)

Attaching package: 'lubridate'
The following objects are masked from 'package:base':

    date, intersect, setdiff, union
library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr   1.1.4     ✔ readr   2.2.0
✔ forcats 1.0.1     ✔ stringr 1.6.0
✔ ggplot2 3.5.2     ✔ tibble  3.2.1
✔ purrr   1.2.1     ✔ tidyr   1.3.2
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.4.1 ──
✔ broom        1.0.12     ✔ rsample      1.3.2 
✔ dials        1.4.2      ✔ tailor       0.1.0 
✔ infer        1.1.0      ✔ tune         2.0.1 
✔ modeldata    1.5.1      ✔ workflows    1.3.0 
✔ parsnip      1.4.1      ✔ workflowsets 1.1.1 
✔ recipes      1.3.1      ✔ yardstick    1.3.2 
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ scales::discard() masks purrr::discard()
✖ dplyr::filter()   masks stats::filter()
✖ recipes::fixed()  masks stringr::fixed()
✖ dplyr::lag()      masks stats::lag()
✖ yardstick::spec() masks readr::spec()
✖ recipes::step()   masks stats::step()
library(here)
here() starts at C:/Users/marco/Documents/marcoreina-MADA-project
library(vip)

Attaching package: 'vip'

The following object is masked from 'package:utils':

    vi
library(gt)

# Import the processed long-format AMR table
amr_long_rf <- here("data/processed-data/amr_long.csv") %>%
  read_csv()
Rows: 688980 Columns: 10
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr  (8): Isolation.source, Source.type, BioSample, antigen_formula, serotyp...
dbl  (1): year
dttm (1): Create.date

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Strict class label: the AMR class when AMR_status is "COMPLETE", otherwise
# "NO AMR" (a missing AMR_status yields NA).
amr_long_rf <- mutate(
  amr_long_rf,
  AMR_class_strict = if_else(AMR_status == "COMPLETE", AMR_class, "NO AMR")
)

# Reshape to wide format: one row per (Create.date, BioSample, serotype,
# Source.type) combination and one 0/1 indicator column per strict AMR class.
amr_wide_rf <- amr_long_rf %>%
  select(Create.date, BioSample, AMR_class_strict, serotype, Source.type) %>%
  distinct(BioSample, AMR_class_strict, .keep_all = TRUE) %>%
  mutate(present = 1) %>%
  pivot_wider(
    id_cols = c(Create.date, BioSample, serotype, Source.type),
    names_from = AMR_class_strict,
    values_from = present,
    values_fill = 0
  ) %>%
  # Drop the two indicator columns that are not used downstream
  # (silently skipped if either is absent).
  select(-any_of(c("Efflux Pump", "NO AMR")))

# Indicator columns that are summed into the per-row AMR class count.
amr_class_cols <- c(
  "Aminoglycoside", "Fosfomycin", "Beta_lactam", "Gentamicin",
  "Trimethoprim", "Quinolone", "Phenicol", "Colistin", "Macrolide",
  "Bleomycin", "Tetracycline", "Sulfonamide", "Lincosamide"
)

# Derive calendar features and the two binary outcomes:
#   amr_prevalence - 1 when at least one AMR class is present
#   mdr_prevalence - 1 when three or more AMR classes are present
amr_wide_rf <- amr_wide_rf %>%
  mutate(
    sample_year = as.factor(year(Create.date)),
    sample_month = month(Create.date, label = TRUE),
    total_amr_classes = rowSums(across(all_of(amr_class_cols))),
    amr_prevalence = if_else(total_amr_classes == 0, 0, 1),
    # Multidrug resistance is conventionally defined as resistance to >= 3
    # antimicrobial classes, so a row with exactly three classes counts as MDR.
    mdr_prevalence = if_else(total_amr_classes >= 3, 1, 0)
  )

# Serovars that keep their own level; every other serovar is collapsed.
top_20 <- c(
  "I 4,[5],12:i:-", "Infantis", "Typhimurium", "Newport", "Agona",
  "Kentucky", "Enteritidis", "Anatum", "Saintpaul", "Braenderup",
  "Muenchen", "Sandiego", "Javiana", "Montevideo", "Oranienburg",
  "Thompson", "Mississippi", "Bareilly", "Poona", "Rubislaw"
)

# Add a `top_20` column: the serovar name when it is in the list above,
# "Other" otherwise. `.env$top_20` forces the lookup to use the vector, not
# the column of the same name, so re-running this step on a data frame that
# already has a `top_20` column still compares against the vector.
amr_wide_rf <- amr_wide_rf %>%
  mutate(top_20 = if_else(serotype %in% .env$top_20, serotype, "Other"))

# Seed reused by every random step in this script
rngseed <- 1234
set.seed(rngseed)

# Convert the binary outcome columns to factors with level 1 listed first,
# so that 1 is the first (event) level in the yardstick metrics below.
amr_wide_rf <- amr_wide_rf %>%
  mutate(
    across(
      c(
        amr_prevalence, Tetracycline, Aminoglycoside,
        Sulfonamide, Beta_lactam, Phenicol
      ),
      ~ factor(.x, levels = c(1, 0))
    )
  )

# Re-seed so the partition is reproducible when this step is run on its own
set.seed(rngseed)

# 75/25 train/test partition, stratified on the AMR outcome
split <- amr_wide_rf %>%
  initial_split(prop = 0.75, strata = amr_prevalence)

test_data  <- testing(split)
train_data <- training(split)

# Random forest specification (ranger engine, permutation importance).
#
# mtry is set to 3 because the models in this script are fitted with three
# predictors (Source.type, top_20, sample_month). Deriving it from
# ncol(train_data) counted every column in the data frame, requested more
# columns than there are predictors, and made parsnip cap the value at 3 with
# a warning on every fit. Writing 3 keeps the fitted models unchanged and
# removes the warning.
# NOTE(review): mtry = 3 lets every split consider all predictors (bagging).
# If the sqrt(p) heuristic was intended, use floor(sqrt(3)) = 1 and re-run the
# downstream metrics and tables.
rf_spec <- rand_forest(
  mode = "classification",
  trees = 500,
  mtry = 3,
  min_n = 5
) %>%
  set_engine("ranger", seed = rngseed, importance = "permutation")
set.seed(rngseed)

# Fit the forest on the training partition using the three predictors
rf_fit_amr <- fit(
  rf_spec,
  amr_prevalence ~ Source.type + top_20 + sample_month,
  data = train_data
)
Warning: ! 5 columns were requested but there were 3 predictors in the data.
ℹ 3 predictors will be used.
# Print the fitted parsnip model object (ranger summary, OOB error)
rf_fit_amr
parsnip model object

Ranger result

Call:
 ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~floor(sqrt(ncol(train_data) -      1)), x), num.trees = ~500, min.node.size = min_rows(~5, x),      seed = ~rngseed, importance = ~"permutation", num.threads = 1,      verbose = FALSE, probability = TRUE) 

Type:                             Probability estimation 
Number of trees:                  500 
Sample size:                      184328 
Number of independent variables:  3 
Mtry:                             3 
Target node size:                 5 
Variable importance mode:         permutation 
Splitrule:                        gini 
OOB prediction error (Brier s.):  0.09430651 
# Test set with the predicted class and the class probabilities appended
test_pred5 <- bind_cols(
  test_data,
  predict(rf_fit_amr, new_data = test_data),
  predict(rf_fit_amr, new_data = test_data, type = "prob")
)

# Confusion matrix of predicted vs. true AMR status on the test set
test_pred5 %>%
  conf_mat(truth = amr_prevalence, estimate = .pred_class)
          Truth
Prediction     1     0
         1  3391  1166
         0  6236 50650
# Test-set sensitivity (event = first factor level, 1)
test_pred5 %>%
  sens(truth = amr_prevalence, estimate = .pred_class)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 sens    binary         0.352
# Test-set specificity
test_pred5 %>%
  spec(truth = amr_prevalence, estimate = .pred_class)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 spec    binary         0.977
# Test-set overall accuracy
test_pred5 %>%
  accuracy(truth = amr_prevalence, estimate = .pred_class)
# A tibble: 1 × 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.880
# Test-set precision (positive predictive value)
test_pred5 %>%
  precision(truth = amr_prevalence, estimate = .pred_class)
# A tibble: 1 × 3
  .metric   .estimator .estimate
  <chr>     <chr>          <dbl>
1 precision binary         0.744
# Test-set F measure
test_pred5 %>%
  f_meas(truth = amr_prevalence, estimate = .pred_class)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 f_meas  binary         0.478
# ROC curve for the predicted probability of class 1 on the test set
autoplot(roc_curve(test_pred5, truth = amr_prevalence, .pred_1))

# Area under that curve
test_pred5 %>%
  roc_auc(truth = amr_prevalence, .pred_1)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.825
# Density of the predicted probability of class 1, split by the true class
test_pred5 %>%
  ggplot(aes(x = .pred_1, fill = amr_prevalence)) +
  geom_density(alpha = 0.5) +
  theme_bw()

# Variable importance plot built from the underlying ranger fit
vip(extract_fit_engine(rf_fit_amr), num_features = 10) +
  theme_bw() +
  labs(title = "Variable Importance for AMR Prediction")

set.seed(rngseed)

# 1. Ten cross-validation folds drawn from the training partition only
folds <- train_data %>%
  vfold_cv(v = 10)

# 2. Workflow bundling the model specification with the model formula
rf_wf <- workflow() %>%
  add_model(rf_spec) %>%
  add_formula(amr_prevalence ~ Source.type + top_20 + sample_month)

# 3. Refit the workflow on each fold and score it with ROC AUC
rf_fit_cv <- rf_wf %>%
  fit_resamples(
    resamples = folds,
    metrics = metric_set(roc_auc)
  )
→ A | warning: ! 5 columns were requested but there were 3 predictors in the data.
               ℹ 3 predictors will be used.
There were issues with some computations   A: x1
There were issues with some computations   A: x2
There were issues with some computations   A: x3
There were issues with some computations   A: x4
There were issues with some computations   A: x5
There were issues with some computations   A: x6
There were issues with some computations   A: x7
There were issues with some computations   A: x8
There were issues with some computations   A: x9
There were issues with some computations   A: x10
There were issues with some computations   A: x10
# 4. Summarise the resampled ROC AUC (mean and standard error across folds)
rf_fit_cv %>%
  collect_metrics()
# A tibble: 1 × 6
  .metric .estimator  mean     n std_err .config        
  <chr>   <chr>      <dbl> <int>   <dbl> <chr>          
1 roc_auc binary     0.829    10 0.00145 pre0_mod0_post0

HONEST ASSESSMENT

# Class and probability predictions for the held-out test set
test_pred_class <- rf_fit_amr %>%
  predict(new_data = test_data)
test_pred_prob <- rf_fit_amr %>%
  predict(new_data = test_data, type = "prob")

# Truth column side by side with both sets of predictions
test_results <- bind_cols(
  select(test_data, amr_prevalence),
  test_pred_class,
  test_pred_prob
)

# Test-set ROC AUC
roc_auc(test_results, truth = amr_prevalence, .pred_1)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.825

PUBLICATION-READY TABLES

library(tibble)
library(gt)
library(tidyr)

# Pull the reported metrics from the fitted objects instead of typing them in,
# so the table cannot drift away from the model results.
cv_auc <- collect_metrics(rf_fit_cv) %>%
  dplyr::filter(.metric == "roc_auc")
test_auc <- roc_auc(test_pred5, truth = amr_prevalence, .pred_1)
test_sens <- sens(test_pred5, truth = amr_prevalence, estimate = .pred_class)

# One row per validation phase; metrics not computed for a phase stay NA
validation_results <- tibble(
  Phase = c("10-Fold Cross-Validation (Mean)", "Honest Assessment (Test Set)"),
  `ROC AUC` = c(cv_auc$mean, test_auc$.estimate),
  `Sensitivity` = c(NA, test_sens$.estimate),
  `SE (AUC)` = c(cv_auc$std_err, NA)
)

# Render and save the table; sub_missing() replaces the deprecated
# fmt_missing() for showing NA cells as "---".
validation_results %>%
  gt() %>%
  tab_header(title = "Model Stability and Performance") %>%
  fmt_number(columns = c(`ROC AUC`, `Sensitivity`), decimals = 3) %>%
  fmt_number(columns = `SE (AUC)`, decimals = 5) %>%
  sub_missing(columns = everything(), missing_text = "---") %>%
  opt_stylize(style = 1, color = "gray") %>%
  gtsave(here("results/tables/validation_results.html"))
Warning: Since gt v0.6.0 `fmt_missing()` is deprecated and will soon be removed.
ℹ Use `sub_missing()` instead.
This warning is displayed once every 8 hours.
# Confusion matrix counts for the held-out test set (not the cross-validation
# folds), computed from the stored test predictions. Rows of the yardstick
# table are predictions and columns are truth; level "1" is reported as
# Resistant and level "0" as Susceptible.
cm_counts <- conf_mat(
  test_pred5,
  truth = amr_prevalence,
  estimate = .pred_class
)$table

conf_matrix_data <- tibble(
  ` ` = c("Predicted: Resistant", "Predicted: Susceptible"),
  `Actual: Resistant` = as.numeric(cm_counts[c("1", "0"), "1"]),
  `Actual: Susceptible` = as.numeric(cm_counts[c("1", "0"), "0"])
)

conf_matrix_data %>%
  gt() %>%
  tab_header(
    title = "Confusion Matrix",
    subtitle = paste0(
      "Evaluation on unseen test data (n = ",
      format(sum(cm_counts), big.mark = ","),
      ")"
    )
  ) %>%
  tab_spanner(
    label = "Reference (Truth)",
    columns = 2:3
  ) %>%
  # Adding color to highlight the correct predictions (the diagonal)
  tab_style(
    style = cell_fill(color = "honeydew"),
    locations = cells_body(rows = 1, columns = 2)
  ) %>%
  tab_style(
    style = cell_fill(color = "honeydew"),
    locations = cells_body(rows = 2, columns = 3)
  ) %>%
  cols_align(align = "center", columns = 2:3) %>%
  opt_stylize(style = 1, color = "gray") %>%
  gtsave(here("results/tables/conf_matrix.html"))