Werken met Tidymodels, een suite voor machine learning

Tidymodels is het suite-pakket van Posit om met machinelearning te werken. Op de website staan vijf handleidingen die voor deze post wat bewerkt zijn en die de verschillende aspecten van het pakket laten zien. Het is een introductie, het vertelt iets over de belangrijkste onderdelen en er wordt een uitgebreide case-studie gepresenteerd.

modelleren
Author

Max Kuhn en Julia Silge, bewerkt door HarrieJonkman

Published

January 11, 2023

Introductie

Tidymodels is een nieuwe versie van Max Kuhns pakket CARET en kan voor verschillende machine learning taken worden gebruikt. Het is sterk geïnspireerd door Tidyverse. Ook Tidymodels is een suite van verschillende pakketten, van voorbereiding tot en met evaluatie en dat het mogelijk maakt om het uitvoeren van analyses steeds op een vergelijkbare manier uit te voeren. Over dit pakket is een zeer duidelijke website gemaakt met uitleg website. Tegelijkertijd is er het boek. Hieronder zie je een bewerkte versie van de vijf handleidingen die op de website zijn te vinden.

Handleiding 1: Overall

Met de eerste handleiding (see krijg je een idee hoe tidymodels werkt. Hierin zet je enkele belangrijke stappen: je opent de data, specificeert en traint het model, gebruikt daarbij verschillende engines (technieken) en je leert te begrijpen hoe het allemaal werkt. In deze handleiding staat het parsnip-pakket centraal.

Eerst moet je, zoals altijd, enkele pakketten openen.

# Het basispakket
library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.0.0 ──
✔ broom        1.0.1      ✔ recipes      1.0.1 
✔ dials        1.0.0      ✔ rsample      1.1.0 
✔ dplyr        1.0.10     ✔ tibble       3.1.8 
✔ ggplot2      3.4.0      ✔ tidyr        1.2.1 
✔ infer        1.0.3      ✔ tune         1.0.0 
✔ modeldata    1.0.1      ✔ workflows    1.1.0 
✔ parsnip      1.0.1      ✔ workflowsets 1.0.0 
✔ purrr        1.0.0      ✔ yardstick    1.1.0 
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ purrr::discard() masks scales::discard()
✖ dplyr::filter()  masks stats::filter()
✖ dplyr::lag()     masks stats::lag()
✖ recipes::step()  masks stats::step()
• Search for functions across packages at https://www.tidymodels.org/find/
## Enkele ondersteunende pakketten 
library(readr)       # voor importeren van data

Attaching package: 'readr'
The following object is masked from 'package:yardstick':

    spec
The following object is masked from 'package:scales':

    col_factor
library(broom.mixed) # om bayesiaanse modellen om te zetten naar tidy tibbles
library(dotwhisker)  # voor visualiseren van regressieresultaten

De dataset

In deze handleiding wordt met data van zeeëgels gewerkt. Hier vind je het artikel over voedingsregimes dat hieronder wordt uitgewerkt

Eerst maar eens die data inlezen.

urchins <-
# Data werden verzameld voor de handleiding
# zie https://www.flutterbys.com.au/stats/tut/tut7.5a.html
read_csv("https://tidymodels.org/start/models/urchins.csv") %>%
# Verander de namen om ze iets minder uitgebreid te laten zijn
setNames(c("food_regime", "initial_volume", "width")) %>%
# Factoren zijn handig bij modeleren, daarom een kolumn omgezet
mutate(food_regime = factor(food_regime, levels = c("Initial", "Low", "High")))
Rows: 72 Columns: 3
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr (1): TREAT
dbl (2): IV, SUTW

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Laten we de data vervolgens eens bekijken, met 72 rijen (zeeëgels) en drie variabelen.

glimpse(urchins)
Rows: 72
Columns: 3
$ food_regime    <fct> Initial, Initial, Initial, Initial, Initial, Initial, I…
$ initial_volume <dbl> 3.5, 5.0, 8.0, 10.0, 13.0, 13.0, 15.0, 15.0, 16.0, 17.0…
$ width          <dbl> 0.010, 0.020, 0.061, 0.051, 0.041, 0.061, 0.041, 0.071,…

Het is goed om de data dan ook eens te visualiseren.

 ggplot(urchins,
               aes(x = initial_volume,
                   y = width,
                   group = food_regime,
                   col = food_regime)) +
                geom_point() +
                geom_smooth(method = lm, se = FALSE) +
                scale_color_viridis_d(option = "plasma", end = .7)
`geom_smooth()` using formula = 'y ~ x'

Bouwen en fitten van een model

Een standaard twee-weg analyse van variantie (ANOVA) is zinvol voor deze dataset omdat deze zowel een continue als een categorische voorspeller bevat.

In dit geval draaien we een lineaire regressie.

lm_mod <-
        linear_reg() %>%
        set_engine("lm")
##  Train/fit/schatten van model 
lm_fit <-
                lm_mod %>%
                fit(width ~ initial_volume * food_regime, data = urchins)
## tidy uitprinten
tidy(lm_fit)
# A tibble: 6 × 5
  term                            estimate std.error statistic  p.value
  <chr>                              <dbl>     <dbl>     <dbl>    <dbl>
1 (Intercept)                     0.0331    0.00962      3.44  0.00100 
2 initial_volume                  0.00155   0.000398     3.91  0.000222
3 food_regimeLow                  0.0198    0.0130       1.52  0.133   
4 food_regimeHigh                 0.0214    0.0145       1.47  0.145   
5 initial_volume:food_regimeLow  -0.00126   0.000510    -2.47  0.0162  
6 initial_volume:food_regimeHigh  0.000525  0.000702     0.748 0.457   
## resultaten plotten
        tidy(lm_fit) %>%
                dwplot(dot_args = list(size = 2, color = "black"),
                       whisker_args = list(color = "black"),
                       vline = geom_vline(xintercept = 0, colour = "grey50", linetype = 2))

Een model gebruiken om te voorspellen

Het model hebben we gedefinieerd. Stel dat we vervolgens een voorspelling willen maken voor egels met een volume van 20ml. Zet deze punten erin.

 new_points <- expand.grid(initial_volume = 20,
                          food_regime = c("Initial", "Low", "High"))

Fit dan het model met deze nieuwe datapunten.

mean_pred <- predict(lm_fit, new_data = new_points)
mean_pred
# A tibble: 3 × 1
   .pred
   <dbl>
1 0.0642
2 0.0588
3 0.0961

Laat dan ook de betrouwbaarheidsintervallen hiervoor zien.

conf_int_pred <- predict(lm_fit,
                                 new_data = new_points,
                                 type = "conf_int")

Combineer de informatie.

plot_data <-
                new_points %>%
                bind_cols(mean_pred) %>%
                bind_cols(conf_int_pred)

En plot dan de resultaten.

 ggplot(plot_data, aes(x = food_regime)) +
                geom_point(aes(y = .pred)) +
                geom_errorbar(aes(ymin = .pred_lower,
                                  ymax = .pred_upper),
                              width = .2) +
                labs(y = "urchin size")

Model met een andere engine.

Laten we nu niet lineaire regressie op een standaardmanier uitvoeren. Stel dat we het nu Bayesiaans willen doen. In dat geval moet je eerst de prior-distributie vastzetten.

 prior_dist <- rstanarm::student_t(df = 1)
        set.seed(123)

Dan definiëren we het model opnieuw.

 bayes_mod <-
                linear_reg() %>%
                set_engine("stan",
                           prior_intercept = prior_dist,
                           prior = prior_dist)

Vervolgens trainen we het nieuwe model.

 bayes_fit <-
                bayes_mod %>%
                fit(width ~ initial_volume * food_regime, data = urchins)

Print de gegevens van het model vervolgens uit.

  print(bayes_fit, digits = 5)
parsnip model object

stan_glm
 family:       gaussian [identity]
 formula:      width ~ initial_volume * food_regime
 observations: 72
 predictors:   6
------
                               Median   MAD_SD  
(Intercept)                     0.03338  0.00947
initial_volume                  0.00155  0.00039
food_regimeLow                  0.01936  0.01348
food_regimeHigh                 0.02073  0.01395
initial_volume:food_regimeLow  -0.00125  0.00052
initial_volume:food_regimeHigh  0.00055  0.00069

Auxiliary parameter(s):
      Median  MAD_SD 
sigma 0.02143 0.00180

------
* For help interpreting the printed output see ?print.stanreg
* For info on the priors used see ?prior_summary.stanreg

Handleiding 2: Voorbereiding

In de tweede handleiding staan met name voorbereidende activiteiten centraal, activiteiten die je moet uitvoeren voordat je gaat modelleren. Hierbij gaat het bijvoorbeeld om het omzetten van variabelen zodat ze beter werken bij modelleren, variabelen naar andere schalen omzetten, hele groepen variabelen omzetten of om nadrukken te leggen op bepaalde aspecten van variabelen. Het gaat vooral om het pakket recipes.

Nu deze pakketten laden.

library(nycflights13)    # voor vluchtdata
library(skimr)           # voor samenvattingen van variabelen

De data

Het gaat hier om New York City vluchtdata.

# set seed om ervoor te zorgen dat herhalingen zelfde resultaten geven ----
set.seed(123)
## Laden van data ----
data(flights)
## Bekijken van data ----
skimr::skim(flights)
Data summary
Name flights
Number of rows 336776
Number of columns 19
_______________________
Column type frequency:
character 4
numeric 14
POSIXct 1
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
carrier 0 1.00 2 2 0 16 0
tailnum 2512 0.99 5 6 0 4043 0
origin 0 1.00 3 3 0 3 0
dest 0 1.00 3 3 0 105 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
year 0 1.00 2013.00 0.00 2013 2013 2013 2013 2013 ▁▁▇▁▁
month 0 1.00 6.55 3.41 1 4 7 10 12 ▇▆▆▆▇
day 0 1.00 15.71 8.77 1 8 16 23 31 ▇▇▇▇▆
dep_time 8255 0.98 1349.11 488.28 1 907 1401 1744 2400 ▁▇▆▇▃
sched_dep_time 0 1.00 1344.25 467.34 106 906 1359 1729 2359 ▁▇▇▇▃
dep_delay 8255 0.98 12.64 40.21 -43 -5 -2 11 1301 ▇▁▁▁▁
arr_time 8713 0.97 1502.05 533.26 1 1104 1535 1940 2400 ▁▃▇▇▇
sched_arr_time 0 1.00 1536.38 497.46 1 1124 1556 1945 2359 ▁▃▇▇▇
arr_delay 9430 0.97 6.90 44.63 -86 -17 -5 14 1272 ▇▁▁▁▁
flight 0 1.00 1971.92 1632.47 1 553 1496 3465 8500 ▇▃▃▁▁
air_time 9430 0.97 150.69 93.69 20 82 129 192 695 ▇▂▂▁▁
distance 0 1.00 1039.91 733.23 17 502 872 1389 4983 ▇▃▂▁▁
hour 0 1.00 13.18 4.66 1 9 13 17 23 ▁▇▇▇▅
minute 0 1.00 26.23 19.30 0 8 29 44 59 ▇▃▆▃▅

Variable type: POSIXct

skim_variable n_missing complete_rate min max median n_unique
time_hour 0 1 2013-01-01 05:00:00 2013-12-31 23:00:00 2013-07-03 10:00:00 6936

Laten we enkele veranderingen in de dataset aanbrengen.

flight_data <-
                flights %>%
                mutate(
# Converteer de 'arrival delay'-variabele in een factorvariabele
                  arr_delay = ifelse(arr_delay >= 30, "late", "on_time"),
                  arr_delay = factor(arr_delay),
# We zullen de datum en niet de tijd gebruiken
                  date = as.Date(time_hour)
                ) %>%
# Includeer ook de weersdata
inner_join(weather, by = c("origin", "time_hour")) %>%
# We gebruiken alleen specifieke kolommen
select(dep_time, flight, origin, dest, air_time, distance,
                       carrier, date, arr_delay, time_hour) %>%
# Missende data halen we eruit
na.omit() %>%
# Voor het draaien van modellen, is het beter om kwalitatieve data te hebbebn
# zet deze om in factoren (ipv karakter strings)
mutate_if(is.character, as.factor)

We zien dat 16% meer dan een half uur vertraging heeft.

flight_data %>% 
  count(arr_delay) %>% 
  mutate(prop = n/sum(n))
# A tibble: 2 × 3
  arr_delay      n  prop
  <fct>      <int> <dbl>
1 late       52540 0.161
2 on_time   273279 0.839

Laten we de veranderingen eens bekijken. We zien bv dat de variabele arr-delay een factor variabele geworden is. Dat is voor het trainen van een logistisch regressiemodel van belang. flight is een numerieke variabele en time-hour is een dttm variabele. Die gebruiken we niet in de training maar wel als eventuele controlevariabelen.

 glimpse(flight_data)
Rows: 325,819
Columns: 10
$ dep_time  <int> 517, 533, 542, 544, 554, 554, 555, 557, 557, 558, 558, 558, …
$ flight    <int> 1545, 1714, 1141, 725, 461, 1696, 507, 5708, 79, 301, 49, 71…
$ origin    <fct> EWR, LGA, JFK, JFK, LGA, EWR, EWR, LGA, JFK, LGA, JFK, JFK, …
$ dest      <fct> IAH, IAH, MIA, BQN, ATL, ORD, FLL, IAD, MCO, ORD, PBI, TPA, …
$ air_time  <dbl> 227, 227, 160, 183, 116, 150, 158, 53, 140, 138, 149, 158, 3…
$ distance  <dbl> 1400, 1416, 1089, 1576, 762, 719, 1065, 229, 944, 733, 1028,…
$ carrier   <fct> UA, UA, AA, B6, DL, UA, B6, EV, B6, AA, B6, B6, UA, UA, AA, …
$ date      <date> 2013-01-01, 2013-01-01, 2013-01-01, 2013-01-01, 2013-01-01,…
$ arr_delay <fct> on_time, on_time, late, on_time, on_time, on_time, on_time, …
$ time_hour <dttm> 2013-01-01 05:00:00, 2013-01-01 05:00:00, 2013-01-01 05:00:…

Er zijn 104 vluchtbestemmingen en 16 verschillende maatschappijen.

flight_data %>% 
  skimr::skim(dest, carrier) 
Data summary
Name Piped data
Number of rows 325819
Number of columns 10
_______________________
Column type frequency:
factor 2
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
dest 0 1 FALSE 104 ATL: 16771, ORD: 16507, LAX: 15942, BOS: 14948
carrier 0 1 FALSE 16 UA: 57489, B6: 53715, EV: 50868, DL: 47465

Data splitten

Vervolgens splitsenen we de dataset in training- en testdata. We splitten het en dan maken we twee datasets.

set.seed(555)
## Splitsen ----
data_split <- initial_split(flight_data, prop = 3/4)
## Training & Testing ----
train_data <- training(data_split)
test_data  <- testing(data_split)

Definieer het model en geef twee variabelen een nieuwe rol (ID). Deze kun je later gebruiken om te zien als iets niet helemaal goed gegaan is bij het voorspellen. Laat uiteindelijk zien hoe de dataset eruit ziet.

flights_rec <-
        recipe(arr_delay ~ ., data = train_data)
        
flights_rec <-
                recipe(arr_delay ~ ., data = train_data) %>%
                update_role(flight, time_hour, new_role = "ID")

summary(flights_rec)
# A tibble: 10 × 4
   variable  type    role      source  
   <chr>     <chr>   <chr>     <chr>   
 1 dep_time  numeric predictor original
 2 flight    numeric ID        original
 3 origin    nominal predictor original
 4 dest      nominal predictor original
 5 air_time  numeric predictor original
 6 distance  numeric predictor original
 7 carrier   nominal predictor original
 8 date      date    predictor original
 9 time_hour date    ID        original
10 arr_delay nominal outcome   original

We voegen nog enkele handelingen toe met recipe. Je kunt verschillende zaken tegelijk uitvoeren mbt verschillende variabelen.

flights_rec <-
        recipe(arr_delay ~ ., data = train_data) %>%
        update_role(flight, time_hour, new_role = "ID") %>%
        step_date(date, features = c("dow", "month")) %>%
        step_holiday(date, holidays = timeDate::listHolidays("US")) %>%
        step_rm(date) %>%
        step_dummy(all_nominal(), -all_outcomes()) %>%
        step_zv(all_predictors())

Model fitten

We specificeren het model als logistische regressie met de glm als engine en specificeren de workflow.

lr_mod <-
          logistic_reg() %>%
          set_engine("glm")
## Specificeren workflow ----
flights_wflow <-
          workflow() %>%
          add_model(lr_mod) %>%
          add_recipe(flights_rec)
flights_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps

• step_date()
• step_holiday()
• step_rm()
• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)

Computational engine: glm 

Vervolgens fitten we het model en kijken naar de resultaten.

 flights_fit <-
                flights_wflow %>%
                fit(data = train_data)
## Halen resultaten eruit
 flights_fit %>%
                pull_workflow_fit() %>%
                tidy()
Warning: `pull_workflow_fit()` was deprecated in workflows 0.2.3.
ℹ Please use `extract_fit_parsnip()` instead.
# A tibble: 158 × 5
   term                         estimate std.error statistic  p.value
   <chr>                           <dbl>     <dbl>     <dbl>    <dbl>
 1 (Intercept)                   5.25    2.72           1.93 5.40e- 2
 2 dep_time                     -0.00167 0.0000141   -118.   0       
 3 air_time                     -0.0438  0.000561     -78.0  0       
 4 distance                      0.00615 0.00150        4.10 4.09e- 5
 5 date_USChristmasDay           1.14    0.171          6.65 2.86e-11
 6 date_USColumbusDay            0.627   0.169          3.72 2.03e- 4
 7 date_USCPulaskisBirthday      0.702   0.133          5.29 1.25e- 7
 8 date_USDecorationMemorialDay  0.363   0.117          3.11 1.86e- 3
 9 date_USElectionDay            0.695   0.177          3.92 8.87e- 5
10 date_USGoodFriday             1.15    0.156          7.39 1.45e-13
# … with 148 more rows

Voorspellen

We gebruiken de getrainde werkflow om te voorspellen. Nu zetten we deze werkflow van de trainingsset in op de testdata.

predict(flights_fit, test_data)
# A tibble: 81,455 × 1
   .pred_class
   <fct>      
 1 on_time    
 2 on_time    
 3 on_time    
 4 on_time    
 5 on_time    
 6 on_time    
 7 on_time    
 8 on_time    
 9 on_time    
10 on_time    
# … with 81,445 more rows

Laat het voorspellen van de testdata zien en geef de waarschijnlijkheid terug.

flights_pred <-
                predict(flights_fit, test_data, type = "prob") %>%
                bind_cols(test_data %>% select(arr_delay, time_hour, flight))
        flights_pred
# A tibble: 81,455 × 5
   .pred_late .pred_on_time arr_delay time_hour           flight
        <dbl>         <dbl> <fct>     <dttm>               <int>
 1     0.0183         0.982 on_time   2013-01-01 06:00:00    461
 2     0.0426         0.957 on_time   2013-01-01 06:00:00   5708
 3     0.0413         0.959 on_time   2013-01-01 06:00:00     71
 4     0.0253         0.975 on_time   2013-01-01 06:00:00    194
 5     0.0306         0.969 on_time   2013-01-01 06:00:00   1743
 6     0.0236         0.976 on_time   2013-01-01 06:00:00   1077
 7     0.0119         0.988 on_time   2013-01-01 06:00:00    709
 8     0.137          0.863 on_time   2013-01-01 06:00:00    245
 9     0.0526         0.947 on_time   2013-01-01 06:00:00   4599
10     0.0246         0.975 on_time   2013-01-01 06:00:00   1019
# … with 81,445 more rows

Plot deze gegevens met name via yardstick-pakket

 flights_pred %>%
                roc_curve(truth = arr_delay, .pred_late) %>%
                autoplot()

Hoe groot is nu de AREA onder de curve? 76,1%, redelijk.

 flights_pred %>%
                roc_auc(truth = arr_delay, .pred_late)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.761

Handleiding 3: Evaluatie

In de derde handleiding gaat het vooral om het evalueren van het model. We willen de performance van het model weten en dat doen we vooral met het resampling-pakket

Eerst gaan we de data binnenhalen waar we in deze handleiding mee zullen werken. Het gaat om data die iets zeggen over de kwaliteit van celbeeld segementatie. Ze zitten in dit pakket.

Bij het vaststellen van effecten van drugs (medicijn wel of niet) wordt er vaak gekeken naar de effecten op de cellen. Dat is op de beelden te zien. Dan wordt er naar de kleur of de afmeting gekeken of naar segmentatie zoals hier.

# tidymodels hebben we al actief gemaakt
# voor de cellen data ----
library(modeldata) 

## Laad de data ----
data(cells, package = "modeldata")

## Dit zijn de data
cells
# A tibble: 2,019 × 58
   case  class angle_c…¹ area_…² avg_i…³ avg_i…⁴ avg_i…⁵ avg_i…⁶ conve…⁷ conve…⁸
   <fct> <fct>     <dbl>   <int>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>
 1 Test  PS       143.       185    15.7    4.95    9.55    2.21    1.12   0.920
 2 Train PS       134.       819    31.9  207.     69.9   164.      1.26   0.797
 3 Train WS       107.       431    28.0  116.     63.9   107.      1.05   0.935
 4 Train PS        69.2      298    19.5  102.     28.2    31.0     1.20   0.866
 5 Test  PS         2.89     285    24.3  112.     20.5    40.6     1.11   0.957
 6 Test  WS        40.7      172   326.   654.    129.    347.      1.01   0.993
 7 Test  WS       174.       177   260.   596.    124.    273.      1.01   0.984
 8 Test  PS       180.       251    18.3    5.73   17.2     1.55    1.20   0.831
 9 Test  WS        18.9      495    16.1   89.5    13.7    51.4     1.19   0.822
10 Test  WS       153.       384    17.7   89.9    20.4    63.1     1.16   0.865
# … with 2,009 more rows, 48 more variables: diff_inten_density_ch_1 <dbl>,
#   diff_inten_density_ch_3 <dbl>, diff_inten_density_ch_4 <dbl>,
#   entropy_inten_ch_1 <dbl>, entropy_inten_ch_3 <dbl>,
#   entropy_inten_ch_4 <dbl>, eq_circ_diam_ch_1 <dbl>,
#   eq_ellipse_lwr_ch_1 <dbl>, eq_ellipse_oblate_vol_ch_1 <dbl>,
#   eq_ellipse_prolate_vol_ch_1 <dbl>, eq_sphere_area_ch_1 <dbl>,
#   eq_sphere_vol_ch_1 <dbl>, fiber_align_2_ch_3 <dbl>, …
## Uitkomst variable is 'class'
## PS = "poorly segmented, slecht gesegementeerd" WS = "weekly segmented, zwak gesegementeerd"
        cells %>%
                count(class) %>%
                mutate(prop = n/sum(n))
# A tibble: 2 × 3
  class     n  prop
  <fct> <int> <dbl>
1 PS     1300 0.644
2 WS      719 0.356

Data splitsen

De functie rsample::initial_split() neemt de oorspronkelijke gegevens en slaat de informatie op over hoe de delen moeten worden gemaakt. In de oorspronkelijke analyse maakten de auteurs hun eigen trainings-/testset en die informatie staat in de kolom “case”. Om te demonstreren hoe we een splitsing maken, verwijderen we deze kolom voordat we onze eigen splitsing maken.

set.seed(123)
cell_split <- rsample::initial_split(cells %>% select(-case),
                            strata = class)

Hier hebben we het strata-argument gebruikt, dat een gestratificeerde splitsing uitvoert. Dit zorgt ervoor dat, ondanks de onevenwichtigheid die we in onze klassenvariabele hebben opgemerkt, onze trainings- en testdatasets ongeveer dezelfde proporties slecht gesegmenteerde en goed gesegmenteerde cellen behouden als in de oorspronkelijke gegevens. Na de initiële splitsing leveren de functies training() en test() de eigenlijke datasets op.

cell_train <- training(cell_split)
cell_test  <- testing(cell_split)

nrow(cell_train)
[1] 1514
nrow(cell_train)/nrow(cells)
[1] 0.7498762
# trainingset proporties volgens class
cell_train %>% 
  count(class) %>% 
  mutate(prop = n/sum(n))
# A tibble: 2 × 3
  class     n  prop
  <fct> <int> <dbl>
1 PS      975 0.644
2 WS      539 0.356
# testset proporties volgens class
cell_test %>% 
  count(class) %>% 
  mutate(prop = n/sum(n))
# A tibble: 2 × 3
  class     n  prop
  <fct> <int> <dbl>
1 PS      325 0.644
2 WS      180 0.356

Het meeste modelleerwerk wordt op de trainingset uitgevoerd.

Modelleren

Een van de voordelen van een random forest model is dat het zeer onderhoudsarm is; het vereist zeer weinig voorbewerking van de gegevens en de standaardparameters geven doorgaans redelijke resultaten. Om die reden zullen we geen recept maken voor de celgegevens en gaan meteen aan de slag.

rf_mod <-
        rand_forest(trees = 1000) %>%
        set_engine("ranger") %>%
        set_mode("classification")

Dit nieuwe object rf_fit is het model dat we hebben getraind op de trainingsgegevensverzameling

set.seed(234)
rf_fit <-
        rf_mod %>%
        fit(class ~ ., data = cell_train)

Schatten van de performance

Prestaties kunnen worden gemeten aan de hand van de algemene classificatienauwkeurigheid en de Receiver Operating Characteristic (ROC) curve. Het yardstick-pakket heeft functies voor het berekenen van beide maten, genaamd roc_auc() en accuracy(). Gebruik hiervoor niet de trainingsset. Je moet de trainingsset opnieuw bewerken om betrouwbare schattingen te krijgen.

Daarom modelleren we het opnieuw maar nu met resample.

 set.seed(345)
        folds <- vfold_cv(cell_train, v = 10)
        folds
#  10-fold cross-validation 
# A tibble: 10 × 2
   splits             id    
   <list>             <chr> 
 1 <split [1362/152]> Fold01
 2 <split [1362/152]> Fold02
 3 <split [1362/152]> Fold03
 4 <split [1362/152]> Fold04
 5 <split [1363/151]> Fold05
 6 <split [1363/151]> Fold06
 7 <split [1363/151]> Fold07
 8 <split [1363/151]> Fold08
 9 <split [1363/151]> Fold09
10 <split [1363/151]> Fold10
        rf_wf <-
                workflow() %>%
                add_model(rf_mod) %>%
                add_formula(class ~ .)

De kolom .metrics bevat de prestatiestatistieken die uit de 10 beoordelingssets zijn gemaakt. Deze kunnen handmatig worden ontnomen, maar het tune-pakket bevat een aantal eenvoudige functies die deze gegevens kunnen extraheren:

  set.seed(456)
        rf_fit_rs <-
                rf_wf %>%
                fit_resamples(folds)
## Om de metrieken te krijgen ----
        collect_metrics(rf_fit_rs)
# A tibble: 2 × 6
  .metric  .estimator  mean     n std_err .config             
  <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy binary     0.832    10 0.00952 Preprocessor1_Model1
2 roc_auc  binary     0.904    10 0.00610 Preprocessor1_Model1

Conclusie

Denk aan de waarden die we nu hebben voor nauwkeurigheid en AUC. Deze prestatiecijfers zijn nu realistischer (d.w.z. lager) dan onze eerste poging om prestatiecijfers te berekenen in de handleiding hierboven.

rf_testing_pred <-                      # originele slechte idee
        predict(rf_fit, cell_test) %>%
        bind_cols(predict(rf_fit, cell_test, type = "prob")) %>%
        bind_cols(cell_test %>% select(class))
rf_testing_pred %>%                   # testset voorspellingen
        roc_auc(truth = class, .pred_PS)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.891
rf_testing_pred %>%                   # test set voorspellingen
        accuracy(truth = class, .pred_class)
# A tibble: 1 × 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.816

Handleiding 4: Hyperparameters

Sommige modelparameters kunnen tijdens de modeltraining niet rechtstreeks uit een dataset worden geleerd; dit soort parameters worden hyperparameters genoemd. Enkele voorbeelden van hyperparameters zijn het aantal voorspellers dat wordt bemonsterd bij splitsingen in een ’tree’model (wij noemen dit mtry in tidymodels) of de leersnelheid in een ’boosted tree’model (wij noemen dit learn_rate). In plaats van dit soort hyperparameters te leren tijdens de modeltraining, kunnen we de beste waarden voor deze waarden schatten door veel modellen te trainen op opnieuw gesampelde gegevenssets en te onderzoeken hoe goed al deze modellen presteren. Dit proces heet tuning

# tidymodels heb je al geopend met daarin het tune pakket met de andere pakketten
# andere pakketten
# modeldata, voor de cellen data, ook al geopend
# vip om het belang van variabelen te plotten
library(vip)         # 

Attaching package: 'vip'
The following object is masked from 'package:utils':

    vi

De data openen

# door experts gelabelled als 'well-segmented' (WS) of 'poorly segmented' (PS).
data(cells, package = "modeldata")
cells
# A tibble: 2,019 × 58
   case  class angle_c…¹ area_…² avg_i…³ avg_i…⁴ avg_i…⁵ avg_i…⁶ conve…⁷ conve…⁸
   <fct> <fct>     <dbl>   <int>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>
 1 Test  PS       143.       185    15.7    4.95    9.55    2.21    1.12   0.920
 2 Train PS       134.       819    31.9  207.     69.9   164.      1.26   0.797
 3 Train WS       107.       431    28.0  116.     63.9   107.      1.05   0.935
 4 Train PS        69.2      298    19.5  102.     28.2    31.0     1.20   0.866
 5 Test  PS         2.89     285    24.3  112.     20.5    40.6     1.11   0.957
 6 Test  WS        40.7      172   326.   654.    129.    347.      1.01   0.993
 7 Test  WS       174.       177   260.   596.    124.    273.      1.01   0.984
 8 Test  PS       180.       251    18.3    5.73   17.2     1.55    1.20   0.831
 9 Test  WS        18.9      495    16.1   89.5    13.7    51.4     1.19   0.822
10 Test  WS       153.       384    17.7   89.9    20.4    63.1     1.16   0.865
# … with 2,009 more rows, 48 more variables: diff_inten_density_ch_1 <dbl>,
#   diff_inten_density_ch_3 <dbl>, diff_inten_density_ch_4 <dbl>,
#   entropy_inten_ch_1 <dbl>, entropy_inten_ch_3 <dbl>,
#   entropy_inten_ch_4 <dbl>, eq_circ_diam_ch_1 <dbl>,
#   eq_ellipse_lwr_ch_1 <dbl>, eq_ellipse_oblate_vol_ch_1 <dbl>,
#   eq_ellipse_prolate_vol_ch_1 <dbl>, eq_sphere_area_ch_1 <dbl>,
#   eq_sphere_vol_ch_1 <dbl>, fiber_align_2_ch_3 <dbl>, …

Voorspellen van beeldsegmentatie maar nu beter

Random forest-modellen is een methode om bomen te schatten en die presteren doorgaans goed met standaard hyperparameters. De nauwkeurigheid van sommige andere soortgelijke modellen kan echter gevoelig zijn voor de waarden van de hyperparameters. In dit artikel zullen we een beslisboommodel (decision tree model) trainen.

set.seed(123)
cell_split <- initial_split(cells %>% select(-case),
                            strata = class)
cell_train <- training(cell_split)
cell_test  <- testing(cell_split)

Hyperparameters afstemmen

Laten we beginnen met het parsnip pakket, met een decision_tree() model met de rpart engine. Om de hyperparameters cost_complexity en tree_depth van de beslisboom te tunen, maken we een modelspecificatie die aangeeft welke hyperparameters we willen tunen.

tune_spec <-
        decision_tree(
                cost_complexity = tune(),
                tree_depth = tune()
        ) %>%
        set_engine("rpart") %>%
        set_mode("classification")
tune_spec
Decision Tree Model Specification (classification)

Main Arguments:
  cost_complexity = tune()
  tree_depth = tune()

Computational engine: rpart 
## dials::grid_regular() 
tree_grid <- grid_regular(cost_complexity(),
                          tree_depth(),
                          levels = 5)
tree_grid
# A tibble: 25 × 2
   cost_complexity tree_depth
             <dbl>      <int>
 1    0.0000000001          1
 2    0.0000000178          1
 3    0.00000316            1
 4    0.000562              1
 5    0.1                   1
 6    0.0000000001          4
 7    0.0000000178          4
 8    0.00000316            4
 9    0.000562              4
10    0.1                   4
# … with 15 more rows

Gewapend met ons raster gevuld met 25 kandidaat-beslisboommodellen, laten we cross-validatie maken voor tuning:

set.seed(234)
cell_folds <- vfold_cv(cell_train)

We zijn klaar voor het afstellen! Laten we tune_grid() gebruiken om modellen te passen bij alle verschillende waarden die we hebben gekozen voor elke afgestemde hyperparameter. Er zijn verschillende mogelijkheden om het object voor tuning te bouwen:

  • Stem een modelspecificatie af samen met een recept of model, of
  • Een workflow() afstemmen die een modelspecificatie en een recept of model preprocessor bundelt.
    Hier gebruiken we een workflow() met een eenvoudige formule; indien dit model een meer gecompliceerde gegevensvoorbewerking vereist, zouden we add_recipe() kunnen gebruiken in plaats van add_formula().
 set.seed(345)
        tree_wf <- workflow() %>%
                add_model(tune_spec) %>%
                add_formula(class ~ .)

Zodra we onze resultaten over het afstellen hebben, kunnen we ze zowel via visualisatie verkennen als het beste resultaat selecteren.

 tree_res <-
                tree_wf %>%
                tune_grid(
                        resamples = cell_folds,
                        grid = tree_grid
                )
        tree_res
# Tuning results
# 10-fold cross-validation 
# A tibble: 10 × 4
   splits             id     .metrics          .notes          
   <list>             <chr>  <list>            <list>          
 1 <split [1362/152]> Fold01 <tibble [50 × 6]> <tibble [0 × 3]>
 2 <split [1362/152]> Fold02 <tibble [50 × 6]> <tibble [0 × 3]>
 3 <split [1362/152]> Fold03 <tibble [50 × 6]> <tibble [0 × 3]>
 4 <split [1362/152]> Fold04 <tibble [50 × 6]> <tibble [0 × 3]>
 5 <split [1363/151]> Fold05 <tibble [50 × 6]> <tibble [0 × 3]>
 6 <split [1363/151]> Fold06 <tibble [50 × 6]> <tibble [0 × 3]>
 7 <split [1363/151]> Fold07 <tibble [50 × 6]> <tibble [0 × 3]>
 8 <split [1363/151]> Fold08 <tibble [50 × 6]> <tibble [0 × 3]>
 9 <split [1363/151]> Fold09 <tibble [50 × 6]> <tibble [0 × 3]>
10 <split [1363/151]> Fold10 <tibble [50 × 6]> <tibble [0 × 3]>

De functie collect_metrics() geeft ons een nette tabel met alle resultaten. We hadden 25 kandidaat-modellen en twee metrieken, accuracy en roc_auc, en we krijgen een rij voor elke .metriek en model.

tree_res %>% 
  collect_metrics()
# A tibble: 50 × 8
   cost_complexity tree_depth .metric  .estimator  mean     n std_err .config   
             <dbl>      <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>     
 1    0.0000000001          1 accuracy binary     0.732    10  0.0148 Preproces…
 2    0.0000000001          1 roc_auc  binary     0.777    10  0.0107 Preproces…
 3    0.0000000178          1 accuracy binary     0.732    10  0.0148 Preproces…
 4    0.0000000178          1 roc_auc  binary     0.777    10  0.0107 Preproces…
 5    0.00000316            1 accuracy binary     0.732    10  0.0148 Preproces…
 6    0.00000316            1 roc_auc  binary     0.777    10  0.0107 Preproces…
 7    0.000562              1 accuracy binary     0.732    10  0.0148 Preproces…
 8    0.000562              1 roc_auc  binary     0.777    10  0.0107 Preproces…
 9    0.1                   1 accuracy binary     0.732    10  0.0148 Preproces…
10    0.1                   1 roc_auc  binary     0.777    10  0.0107 Preproces…
# … with 40 more rows

Laten we er een grafiek van maken.

tree_res %>%
                collect_metrics() %>%
                mutate(tree_depth = factor(tree_depth)) %>%
                ggplot(aes(cost_complexity, mean, color = tree_depth)) +
                geom_line(size = 1.5, alpha = 0.6) +
                geom_point(size = 2) +
                facet_wrap(~ .metric, scales = "free", nrow = 2) +
                scale_x_log10(labels = scales::label_number()) +
                scale_color_viridis_d(option = "plasma", begin = .9, end = 0)
Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
ℹ Please use `linewidth` instead.

## Wat is de beste?
best_tree <- tree_res %>%
                select_best("roc_auc")
        best_tree
# A tibble: 1 × 3
  cost_complexity tree_depth .config              
            <dbl>      <int> <chr>                
1        0.000562         11 Preprocessor1_Model19

Afronden

Wij kunnen ons workflow-object tree_wf bijwerken (of “finaliseren”) met de waarden van select_best().

 final_wf <-
                tree_wf %>%
                finalize_workflow(best_tree)
        final_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
class ~ .

── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (classification)

Main Arguments:
  cost_complexity = 0.000562341325190349
  tree_depth = 11

Computational engine: rpart 

Laatste fit

Tot slot passen we dit definitieve model toe op de opleidingsgegevens en gebruiken we onze testgegevens om de modelprestatie te schatten die we verwachten te zien met nieuwe gegevens. Wij kunnen de functie last_fit() gebruiken voor ons definitieve model; deze functie past het definitieve model toe op de volledige reeks opleidingsgegevens en evalueert het definitieve model op de testgegevens.

final_tree <-
            final_wf %>%
            fit(data = cell_train)

final_tree
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Formula
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
class ~ .

── Model ───────────────────────────────────────────────────────────────────────
n= 1514 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

   1) root 1514 539 PS (0.64398943 0.35601057)  
     2) total_inten_ch_2< 41732.5 642  33 PS (0.94859813 0.05140187)  
       4) shape_p_2_a_ch_1>=1.251801 631  27 PS (0.95721078 0.04278922)  
         8) avg_inten_ch_2< 125.8919 525  12 PS (0.97714286 0.02285714) *
         9) avg_inten_ch_2>=125.8919 106  15 PS (0.85849057 0.14150943)  
          18) var_inten_ch_4>=39.85951 82   6 PS (0.92682927 0.07317073) *
          19) var_inten_ch_4< 39.85951 24   9 PS (0.62500000 0.37500000)  
            38) inten_cooc_asm_ch_4>=0.2197672 12   0 PS (1.00000000 0.00000000) *
            39) inten_cooc_asm_ch_4< 0.2197672 12   3 WS (0.25000000 0.75000000) *
       5) shape_p_2_a_ch_1< 1.251801 11   5 WS (0.45454545 0.54545455) *
     3) total_inten_ch_2>=41732.5 872 366 WS (0.41972477 0.58027523)  
       6) fiber_width_ch_1< 11.37318 406 160 PS (0.60591133 0.39408867)  
        12) avg_inten_ch_1< 145.4883 293  85 PS (0.70989761 0.29010239)  
          24) fiber_width_ch_1< 7.878131 68   5 PS (0.92647059 0.07352941) *
          25) fiber_width_ch_1>=7.878131 225  80 PS (0.64444444 0.35555556)  
            50) total_inten_ch_1< 12969.5 74  15 PS (0.79729730 0.20270270)  
             100) inten_cooc_asm_ch_4< 0.06289989 34   2 PS (0.94117647 0.05882353) *
             101) inten_cooc_asm_ch_4>=0.06289989 40  13 PS (0.67500000 0.32500000)  
               202) neighbor_min_dist_ch_1>=32.71331 9   0 PS (1.00000000 0.00000000) *
               203) neighbor_min_dist_ch_1< 32.71331 31  13 PS (0.58064516 0.41935484)  
                 406) skew_inten_ch_4>=1.060929 16   3 PS (0.81250000 0.18750000) *
                 407) skew_inten_ch_4< 1.060929 15   5 WS (0.33333333 0.66666667) *
            51) total_inten_ch_1>=12969.5 151  65 PS (0.56953642 0.43046358)  
             102) kurt_inten_ch_1>=-0.3447192 110  37 PS (0.66363636 0.33636364)  
               204) diff_inten_density_ch_4>=112.6034 35   5 PS (0.85714286 0.14285714) *
               205) diff_inten_density_ch_4< 112.6034 75  32 PS (0.57333333 0.42666667)  
                 410) inten_cooc_contrast_ch_4< 3.122366 11   0 PS (1.00000000 0.00000000) *
                 411) inten_cooc_contrast_ch_4>=3.122366 64  32 PS (0.50000000 0.50000000)  
                   822) fiber_align_2_ch_4>=1.591445 11   1 PS (0.90909091 0.09090909) *
                   823) fiber_align_2_ch_4< 1.591445 53  22 WS (0.41509434 0.58490566)  
                    1646) neighbor_avg_dist_ch_1< 217.8143 21   7 PS (0.66666667 0.33333333)  
                      3292) eq_ellipse_lwr_ch_1>=1.942086 14   2 PS (0.85714286 0.14285714) *
                      3293) eq_ellipse_lwr_ch_1< 1.942086 7   2 WS (0.28571429 0.71428571) *
                    1647) neighbor_avg_dist_ch_1>=217.8143 32   8 WS (0.25000000 0.75000000) *
             103) kurt_inten_ch_1< -0.3447192 41  13 WS (0.31707317 0.68292683)  
               206) shape_bfr_ch_1>=0.635439 12   5 PS (0.58333333 0.41666667) *
               207) shape_bfr_ch_1< 0.635439 29   6 WS (0.20689655 0.79310345)  
                 414) shape_bfr_ch_1< 0.5196834 7   3 PS (0.57142857 0.42857143) *
                 415) shape_bfr_ch_1>=0.5196834 22   2 WS (0.09090909 0.90909091) *
        13) avg_inten_ch_1>=145.4883 113  38 WS (0.33628319 0.66371681)  
          26) total_inten_ch_3>=57919.5 33  10 PS (0.69696970 0.30303030)  
            52) spot_fiber_count_ch_3< 2.5 24   4 PS (0.83333333 0.16666667)  
             104) kurt_inten_ch_1>=-0.335807 17   0 PS (1.00000000 0.00000000) *
             105) kurt_inten_ch_1< -0.335807 7   3 WS (0.42857143 0.57142857) *
            53) spot_fiber_count_ch_3>=2.5 9   3 WS (0.33333333 0.66666667) *

...
and 40 more lines.
## variabele belang
library(vip)
final_tree %>%
                pull_workflow_fit() %>%
                vip(geom = "point")

Tot slot passen we dit definitieve model toe op de trainingsgegevens en gebruiken we onze testgegevens om de modelprestatie te schatten die we verwachten te zien met nieuwe gegevens.

Wij kunnen de functie last_fit() gebruiken voor ons definitieve model; deze functie past het definitieve model toe op de volledige reeks trainingsgegevens en evalueert het definitieve model op de testgegevens.

final_fit <-
          final_wf %>%
          last_fit(cell_split)
## verzamel de metrieken
final_fit %>%
          collect_metrics()
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.758 Preprocessor1_Model1
2 roc_auc  binary         0.839 Preprocessor1_Model1

Toon nog even de ROC-curve.

 final_fit %>%
                collect_predictions() %>%
                roc_curve(class, .pred_PS) %>%
                autoplot()

De prestatiecijfers van de testset geven aan dat we tijdens onze tuneprocedure niet te veel hebben aangepast.

Het object final_fit bevat een definitieve, passende workflow die je kunt gebruiken voor voorspellingen op nieuwe gegevens of om de resultaten verder te begrijpen. Je kunt dit object uitpakken met een van de helpfuncties extract_.

final_tree <- extract_workflow(final_fit)
final_tree
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Formula
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
class ~ .

── Model ───────────────────────────────────────────────────────────────────────
n= 1514 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

   1) root 1514 539 PS (0.64398943 0.35601057)  
     2) total_inten_ch_2< 41732.5 642  33 PS (0.94859813 0.05140187)  
       4) shape_p_2_a_ch_1>=1.251801 631  27 PS (0.95721078 0.04278922)  
         8) avg_inten_ch_2< 125.8919 525  12 PS (0.97714286 0.02285714) *
         9) avg_inten_ch_2>=125.8919 106  15 PS (0.85849057 0.14150943)  
          18) var_inten_ch_4>=39.85951 82   6 PS (0.92682927 0.07317073) *
          19) var_inten_ch_4< 39.85951 24   9 PS (0.62500000 0.37500000)  
            38) inten_cooc_asm_ch_4>=0.2197672 12   0 PS (1.00000000 0.00000000) *
            39) inten_cooc_asm_ch_4< 0.2197672 12   3 WS (0.25000000 0.75000000) *
       5) shape_p_2_a_ch_1< 1.251801 11   5 WS (0.45454545 0.54545455) *
     3) total_inten_ch_2>=41732.5 872 366 WS (0.41972477 0.58027523)  
       6) fiber_width_ch_1< 11.37318 406 160 PS (0.60591133 0.39408867)  
        12) avg_inten_ch_1< 145.4883 293  85 PS (0.70989761 0.29010239)  
          24) fiber_width_ch_1< 7.878131 68   5 PS (0.92647059 0.07352941) *
          25) fiber_width_ch_1>=7.878131 225  80 PS (0.64444444 0.35555556)  
            50) total_inten_ch_1< 12969.5 74  15 PS (0.79729730 0.20270270)  
             100) inten_cooc_asm_ch_4< 0.06289989 34   2 PS (0.94117647 0.05882353) *
             101) inten_cooc_asm_ch_4>=0.06289989 40  13 PS (0.67500000 0.32500000)  
               202) neighbor_min_dist_ch_1>=32.71331 9   0 PS (1.00000000 0.00000000) *
               203) neighbor_min_dist_ch_1< 32.71331 31  13 PS (0.58064516 0.41935484)  
                 406) skew_inten_ch_4>=1.060929 16   3 PS (0.81250000 0.18750000) *
                 407) skew_inten_ch_4< 1.060929 15   5 WS (0.33333333 0.66666667) *
            51) total_inten_ch_1>=12969.5 151  65 PS (0.56953642 0.43046358)  
             102) kurt_inten_ch_1>=-0.3447192 110  37 PS (0.66363636 0.33636364)  
               204) diff_inten_density_ch_4>=112.6034 35   5 PS (0.85714286 0.14285714) *
               205) diff_inten_density_ch_4< 112.6034 75  32 PS (0.57333333 0.42666667)  
                 410) inten_cooc_contrast_ch_4< 3.122366 11   0 PS (1.00000000 0.00000000) *
                 411) inten_cooc_contrast_ch_4>=3.122366 64  32 PS (0.50000000 0.50000000)  
                   822) fiber_align_2_ch_4>=1.591445 11   1 PS (0.90909091 0.09090909) *
                   823) fiber_align_2_ch_4< 1.591445 53  22 WS (0.41509434 0.58490566)  
                    1646) neighbor_avg_dist_ch_1< 217.8143 21   7 PS (0.66666667 0.33333333)  
                      3292) eq_ellipse_lwr_ch_1>=1.942086 14   2 PS (0.85714286 0.14285714) *
                      3293) eq_ellipse_lwr_ch_1< 1.942086 7   2 WS (0.28571429 0.71428571) *
                    1647) neighbor_avg_dist_ch_1>=217.8143 32   8 WS (0.25000000 0.75000000) *
             103) kurt_inten_ch_1< -0.3447192 41  13 WS (0.31707317 0.68292683)  
               206) shape_bfr_ch_1>=0.635439 12   5 PS (0.58333333 0.41666667) *
               207) shape_bfr_ch_1< 0.635439 29   6 WS (0.20689655 0.79310345)  
                 414) shape_bfr_ch_1< 0.5196834 7   3 PS (0.57142857 0.42857143) *
                 415) shape_bfr_ch_1>=0.5196834 22   2 WS (0.09090909 0.90909091) *
        13) avg_inten_ch_1>=145.4883 113  38 WS (0.33628319 0.66371681)  
          26) total_inten_ch_3>=57919.5 33  10 PS (0.69696970 0.30303030)  
            52) spot_fiber_count_ch_3< 2.5 24   4 PS (0.83333333 0.16666667)  
             104) kurt_inten_ch_1>=-0.335807 17   0 PS (1.00000000 0.00000000) *
             105) kurt_inten_ch_1< -0.335807 7   3 WS (0.42857143 0.57142857) *
            53) spot_fiber_count_ch_3>=2.5 9   3 WS (0.33333333 0.66666667) *

...
and 40 more lines.

Misschien willen we ook begrijpen welke variabelen belangrijk zijn in dit uiteindelijke model. Wij kunnen het vip-pakket gebruiken om het belang van variabelen te schatten op basis van de structuur van het model.

library(vip)

final_tree %>% 
  extract_fit_parsnip() %>% 
  vip()

Dit zijn de geautomatiseerde beeldanalysemetingen die het belangrijkst zijn voor de voorspelling van de segmentatiekwaliteit.

We laten het aan de lezer over om te onderzoeken of zij een andere beslisboom-hyperparameter willen afstemmen. Daarvoor kun je de referentiedocumenten raadplegen,of de functie args() gebruiken om te zien welke parsnip-objectargumenten beschikbaar zijn:

args(decision_tree)
function (mode = "unknown", engine = "rpart", cost_complexity = NULL, 
    tree_depth = NULL, min_n = NULL) 
NULL

Handleiding 5: Case-studie

De vier handleiding hiervoor waren steeds gericht op één taak met betrekking tot modelleren. Onderweg hebben we ook de kernpakketten in het tidymodels ecosysteem geïntroduceerd en enkele van de belangrijkste functies die je nodig hebt om met modellen te gaan werken.

De vijfde en laatste handleiding is een case-studie waarin we de voorgaande kennis gebruiken als basis om een voorspellend model van begin tot eind te bouwen met gegevens over hotelovernachtingen case-studie.

##  tidymodels moet geïnstalleerd zijn evenals vip
## Verder:
library(readr)       
# voor importeren van data

Data

Eerst de data binnenhalen, iets aanpassen en bekijken.

##  Inlezen
hotels <-
                read_csv('https://tidymodels.org/start/case-study/hotels.csv') %>%
                mutate_if(is.character, as.factor)
Rows: 50000 Columns: 23
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr  (11): hotel, children, meal, country, market_segment, distribution_chan...
dbl  (11): lead_time, stays_in_weekend_nights, stays_in_week_nights, adults,...
date  (1): arrival_date

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
dim(hotels)
[1] 50000    23

Allicht alle variabelen nog eens goed bekijken.

 glimpse(hotels)
Rows: 50,000
Columns: 23
$ hotel                          <fct> City_Hotel, City_Hotel, Resort_Hotel, R…
$ lead_time                      <dbl> 217, 2, 95, 143, 136, 67, 47, 56, 80, 6…
$ stays_in_weekend_nights        <dbl> 1, 0, 2, 2, 1, 2, 0, 0, 0, 2, 1, 0, 1, …
$ stays_in_week_nights           <dbl> 3, 1, 5, 6, 4, 2, 2, 3, 4, 2, 2, 1, 2, …
$ adults                         <dbl> 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 1, 2, …
$ children                       <fct> none, none, none, none, none, none, chi…
$ meal                           <fct> BB, BB, BB, HB, HB, SC, BB, BB, BB, BB,…
$ country                        <fct> DEU, PRT, GBR, ROU, PRT, GBR, ESP, ESP,…
$ market_segment                 <fct> Offline_TA/TO, Direct, Online_TA, Onlin…
$ distribution_channel           <fct> TA/TO, Direct, TA/TO, TA/TO, Direct, TA…
$ is_repeated_guest              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ previous_cancellations         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ previous_bookings_not_canceled <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ reserved_room_type             <fct> A, D, A, A, F, A, C, B, D, A, A, D, A, …
$ assigned_room_type             <fct> A, K, A, A, F, A, C, A, D, A, D, D, A, …
$ booking_changes                <dbl> 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ deposit_type                   <fct> No_Deposit, No_Deposit, No_Deposit, No_…
$ days_in_waiting_list           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ customer_type                  <fct> Transient-Party, Transient, Transient, …
$ average_daily_rate             <dbl> 80.75, 170.00, 8.00, 81.00, 157.60, 49.…
$ required_car_parking_spaces    <fct> none, none, none, none, none, none, non…
$ total_of_special_requests      <dbl> 1, 3, 2, 1, 4, 1, 1, 1, 1, 1, 0, 1, 0, …
$ arrival_date                   <date> 2016-09-01, 2017-08-25, 2016-11-19, 20…

De uitkomst variabele is children, een factorvariabele met twee niveaus (wel of geen kinderen. 8,1% van de gasten heeft kinderen bij zich tijdens de hotelovernachtingen.

hotels %>%
                count(children) %>%
                mutate(prop = n/sum(n))
# A tibble: 2 × 3
  children     n   prop
  <fct>    <int>  <dbl>
1 children  4038 0.0808
2 none     45962 0.919 

Splitsen van data

We reserveren 25% van de data voor de test-data. De variabele children is behoorlijk uit balans, dus we stratificeren de dataset op deze variabele als we deze opsplitsen.

set.seed(123)
        splits      <- initial_split(hotels, strata = children)
        hotel_other <- training(splits)
        hotel_test  <- testing(splits)

Zo ziet de trainingsset er nu uit qua children variabele.

hotel_other %>%
                count(children) %>%
                mutate(prop = n/sum(n))
# A tibble: 2 × 3
  children     n   prop
  <fct>    <int>  <dbl>
1 children  3027 0.0807
2 none     34473 0.919 

Zo ziet de testtest eruit op dezelfde variabele, vergelijkbaar:

hotel_test  %>%
                count(children) %>%
                mutate(prop = n/sum(n))
# A tibble: 2 × 3
  children     n   prop
  <fct>    <int>  <dbl>
1 children  1011 0.0809
2 none     11489 0.919 

Van de trainingsset maken we ook nog een aparte validitatie set. De

Opzet

Zo ziet het er dan uit.

We gebruiken de functie validation_split() om 20% van de hotel_other verblijven toe te wijzen aan de validatieset en 30.000 verblijven aan de trainingset. Dit betekent dat de prestatiecijfers van ons model worden berekend op een enkele set van 7.500 hotelovernachtingen. Dat is vrij groot, dus de hoeveelheid gegevens zou voldoende precisie moeten opleveren om een betrouwbare indicator te zijn voor hoe goed elk model de uitkomst voorspelt met een enkele iteratie van resampling.

 set.seed(234)
        val_set <- validation_split(hotel_other,
                                    strata = children,
                                    prop = 0.80)

Ook dit hebben we gestratificeerd op de uitkomstvariabele children.

Eerste model

Hier wordt, en ik gebruik toch maar even de Engelse woorden, een ‘penalized logistic regression’ model gebruikt via glmnet. De penalty=tune(),mixture = 1 haalt irrelevante predictoren weg.

  lr_mod <-
                logistic_reg(penalty = tune(), mixture = 1) %>%
                set_engine("glmnet")

Via het pakket recipe dat in tidymodels zit kun je enkele aanvullende voorbereidende handelingen verrichten. Zoals:

- step_date() creëert voorspellers voor het jaar, de maand en de dag van de week.
- step_holiday() genereert een reeks indicatorvariabelen voor specifieke feestdagen. Hoewel we niet weten waar deze twee hotels zich bevinden, weten we wel dat de landen van herkomst voor de meeste verblijven in Europa liggen.
- step_rm() verwijdert variabelen; hier gebruiken we het om de oorspronkelijke datumvariabele te verwijderen omdat we die niet langer in het model willen.

Bovendien moeten alle categorische voorspellers (bv. distribution-channel, hotel, …) worden omgezet naar dummy-variabelen en moeten alle numerieke voorspellers worden gecentreerd en geschaald.

- step_dummy() zet tekens of factoren (d.w.z. nominale variabelen) om in een of meer numerieke binaire modeltermen voor de niveaus van de oorspronkelijke gegevens.
- step_zv() verwijdert indicatorvariabelen die slechts één unieke waarde bevatten (bv. allemaal nullen). Dit is belangrijk omdat voor gestrafte modellen de voorspellers moeten worden gecentreerd en geschaald.

- step_normalize() centreert en schaalt numerieke variabelen.


Als we al deze stappen samenvoegen tot een recept voor ons gekozen model (’penalized logistic regression`), hebben we:

holidays <- c("AllSouls", "AshWednesday", "ChristmasEve", "Easter",
                      "ChristmasDay", "GoodFriday", "NewYearsDay", "PalmSunday")
 lr_recipe <-
                recipe(children ~ ., data = hotel_other) %>%
                step_date(arrival_date) %>%
                step_holiday(arrival_date, holidays = holidays) %>%
                step_rm(arrival_date) %>%
                step_dummy(all_nominal(), -all_outcomes()) %>%
                step_zv(all_predictors()) %>%
                step_normalize(all_predictors())

Laten we nu alles (‘model en ’recipe’) in een workflow plaatsen.

 lr_workflow <-
                workflow() %>%
                add_model(lr_mod) %>%
                add_recipe(lr_recipe)

Welke penalties moeten we gebruiken? Omdat we slechts een hyperparameter hoeven af te stellen, gebruiken we een grid met 30 verschillende waarden in een kolom.

lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30))
        lr_reg_grid %>% top_n(-5) # lowest penalty values
Selecting by penalty
# A tibble: 5 × 1
   penalty
     <dbl>
1 0.0001  
2 0.000127
3 0.000161
4 0.000204
5 0.000259
        lr_reg_grid %>% top_n(5)  # highest penalty values
Selecting by penalty
# A tibble: 5 × 1
  penalty
    <dbl>
1  0.0386
2  0.0489
3  0.0621
4  0.0788
5  0.1   
        ## 4.5 Train & Tune ----
        lr_res <-
                lr_workflow %>%
                tune_grid(val_set,
                          grid = lr_reg_grid,
                          control = control_grid(save_pred = TRUE),
                          metrics = metric_set(roc_auc))

Het is makkelijk om de validatieset metrieken te visualiseren door het gebied onder de ROC-curve uit te zetten tegen de reeks van waarden:

 lr_plot <-
                lr_res %>%
                collect_metrics() %>%
                ggplot(aes(x = penalty, y = mean)) +
                geom_point() +
                geom_line() +
                ylab("Gebied onder de ROC Curve") +
                scale_x_log10(labels = scales::label_number())

        lr_plot

De prestaties van ons model lijken overall het beste te doen bij de kleinere strafwaarden. Als we alleen uitgaan van de roc_auc-metriek zouden we meerdere opties voor de “beste” waarde van deze hyperparameter kunnen vinden:

top_models <-
  lr_res %>% 
  show_best("roc_auc", n = 15) %>% 
  arrange(penalty) 
top_models
# A tibble: 15 × 7
    penalty .metric .estimator  mean     n std_err .config              
      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
 1 0.000127 roc_auc binary     0.872     1      NA Preprocessor1_Model02
 2 0.000161 roc_auc binary     0.872     1      NA Preprocessor1_Model03
 3 0.000204 roc_auc binary     0.873     1      NA Preprocessor1_Model04
 4 0.000259 roc_auc binary     0.873     1      NA Preprocessor1_Model05
 5 0.000329 roc_auc binary     0.874     1      NA Preprocessor1_Model06
 6 0.000418 roc_auc binary     0.874     1      NA Preprocessor1_Model07
 7 0.000530 roc_auc binary     0.875     1      NA Preprocessor1_Model08
 8 0.000672 roc_auc binary     0.875     1      NA Preprocessor1_Model09
 9 0.000853 roc_auc binary     0.876     1      NA Preprocessor1_Model10
10 0.00108  roc_auc binary     0.876     1      NA Preprocessor1_Model11
11 0.00137  roc_auc binary     0.876     1      NA Preprocessor1_Model12
12 0.00174  roc_auc binary     0.876     1      NA Preprocessor1_Model13
13 0.00221  roc_auc binary     0.876     1      NA Preprocessor1_Model14
14 0.00281  roc_auc binary     0.875     1      NA Preprocessor1_Model15
15 0.00356  roc_auc binary     0.873     1      NA Preprocessor1_Model16

Als we select_best() zouden gebruiken, zou dit kandidaat-model 11 opleveren met een penalty-waarde van 0,00137. Kandidaat-model 12 met een strafwaarde van 0,00174 heeft in feite dezelfde prestaties als het numeriek beste model, maar kan meer voorspellers elimineren. Laten we deze nemen.

lr_best <- 
  lr_res %>% 
  collect_metrics() %>% 
  arrange(penalty) %>% 
  slice(12)
lr_best
# A tibble: 1 × 7
  penalty .metric .estimator  mean     n std_err .config              
    <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
1 0.00137 roc_auc binary     0.876     1      NA Preprocessor1_Model12

Laten we deze visualiseren:

lr_auc <- 
  lr_res %>% 
  collect_predictions(parameters = lr_best) %>% 
  roc_curve(children, .pred_children) %>% 
  mutate(model = "Logistic Regression")

autoplot(lr_auc)

Het prestatieniveau van dit logistische regressiemodel is goed, maar niet baanbrekend. Misschien is de lineaire aard van de voorspellingsvergelijking te beperkend voor deze dataset. Als volgende stap zouden we een sterk niet-lineair model kunnen overwegen dat wordt gegenereerd met behulp van een ‘vertakte’-methode.

‘Vertakte’-methode

Een effectieve en onderhoudsarme modelleringstechniek is een random forest. Vertakte modellen vereisen zeer weinig voorbewerking en kunnen vele soorten voorspellers aan (continu, categorisch, enz.).

Bouw het model zo dat het de trainingstijd reduceert. Het tune-pakket kan parallelle verwerking voor u doen en staat gebruikers toe om meerdere processors of aparte machines te gebruiken om modellen te fitten. Zo detecteer je de processoren:

 cores <- parallel::detectCores()
        cores
[1] 4

Vervolgens het model bouwen.

 rf_mod <-
                rand_forest(mtry = tune(), min_n = tune(), trees = 1000) %>%
                # tune() is voor later
                set_engine("ranger", num.threads = cores) %>%
                set_mode("classification")

Opgelet: Geen processoren vaststellen behalve voor random forest

In tegenstelling tot de `penalized logistic regression’ modellen zoals hierboven gebruikt, vraagt het ‘random forest model’ geen dummies of genormaliseerde voorspellers.

rf_recipe <-
                recipe(children ~ ., data = hotel_other) %>%
                step_date(arrival_date) %>%
                step_holiday(arrival_date) %>%
                step_rm(arrival_date)

Creëer vervolgens de workflow.

 rf_workflow <-
                workflow() %>%
                add_model(rf_mod) %>%
                add_recipe(rf_recipe)

Train en stel het model af. Laat zien wat er moet worden afgesteld.

 rf_mod %>%
                parameters()
Warning: `parameters.model_spec()` was deprecated in tune 0.1.6.9003.
ℹ Please use `hardhat::extract_parameter_set_dials()` instead.
Collection of 2 parameters for tuning

 identifier  type    object
       mtry  mtry nparam[?]
      min_n min_n nparam[+]

Model parameters needing finalization:
   # Randomly Selected Predictors ('mtry')

See `?dials::finalize` or `?dials::update.parameters` for more information.

Laat zien wel ruimte je hebt.

set.seed(345)
        rf_res <-
                rf_workflow %>%
                tune_grid(val_set,
                          grid = 25,
                          control = control_grid(save_pred = TRUE),
                          metrics = metric_set(roc_auc))
i Creating pre-processing data to finalize unknown parameter: mtry

Laat zien wat de beste keuze is.

rf_res %>%
                show_best(metric = "roc_auc")
# A tibble: 5 × 8
   mtry min_n .metric .estimator  mean     n std_err .config              
  <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
1     8     7 roc_auc binary     0.926     1      NA Preprocessor1_Model13
2    12     7 roc_auc binary     0.926     1      NA Preprocessor1_Model01
3    13     4 roc_auc binary     0.925     1      NA Preprocessor1_Model05
4     9    12 roc_auc binary     0.924     1      NA Preprocessor1_Model19
5     6    18 roc_auc binary     0.924     1      NA Preprocessor1_Model24

Het bereik van de y-as geeft echter aan dat het model zeer robuust is voor de keuze van deze parameterwaarden — op één na zijn alle ROC AUC-waarden groter dan 0,90.

autoplot(rf_res)

Selecteer de beste.

 rf_best <-
                rf_res %>%
                select_best(metric = "roc_auc")
        rf_best
# A tibble: 1 × 3
   mtry min_n .config              
  <int> <int> <chr>                
1     8     7 Preprocessor1_Model13

Stel het model af op de beste voorspelling.

rf_auc <-
                rf_res %>%
                collect_predictions(parameters = rf_best) %>%
                roc_curve(children, .pred_children) %>%
                mutate(model = "Random Forest")

Plot vervolgens het beste model.

bind_rows(rf_auc, lr_auc) %>%
                ggplot(aes(x = 1 - specificity, y = sensitivity, col = model)) +
                geom_path(lwd = 1.5, alpha = 0.8) +
                geom_abline(lty = 3) +
                coord_equal() +
                scale_color_viridis_d(option = "plasma", end = .6)

De laatste fit.

Bouw het model opnieuw op en neem de beste hyperparameter waarde voor ons ‘random forest model’. Definieer ook een nieuw argument: importance = "impurity"

Het laatste model ziet er dan zo uit.

last_rf_mod <-
                rand_forest(mtry = 8, min_n = 7, trees = 1000) %>%
                set_engine("ranger", num.threads = cores, importance = "impurity") %>%
                set_mode("classification")
        ## Laatste werkflow
        last_rf_workflow <-
                rf_workflow %>%
                update_model(last_rf_mod)

De laatste fit dan nu.

 set.seed(345)
        last_rf_fit <-
                last_rf_workflow %>%
                last_fit(splits)

Evalueer het model.

last_rf_fit %>%
                collect_metrics()
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.946 Preprocessor1_Model1
2 roc_auc  binary         0.923 Preprocessor1_Model1
        ## 6.5 review variable importance ----
        last_rf_fit %>%
                pluck(".workflow", 1) %>%
                pull_workflow_fit() %>%
                vip(num_features = 20)

Laatste roc, zelfde voor de validatie set. Goede voorspeller op de nieuwe data.

last_rf_fit %>%
                collect_predictions() %>%
                roc_curve(children, .pred_children) %>%
                autoplot()

Literatuur

  • https://www.tidymodels.org/, met name https://www.tidymodels.org/start/.
  • Kuhn, M. & Silge, J. (2022). Tidy Modeling with R. A Framework for Modeling in the Tidyverse. Boston: Sebastopol (CA): O’Reilly. zie: https://www.tmwr.org