Chapter 19 Recipe and Models

We will want to use our recipe across several steps as we train and test our models. We will:

  1. Process the recipe using the training set: This involves any estimation or calculation based on the training set. For our recipe, the training set is used to determine which predictors should be converted to dummy variables and which predictors have zero variance in the training set and should therefore be slated for removal.

  2. Apply the recipe to the training set: We create the final predictor set on the training set.

  3. Apply the recipe to the test set: We create the final predictor set on the test set. Nothing is recomputed and no information from the test set is used here; the dummy variable and zero-variance results from the training set are applied to the test set.
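The three steps above can be sketched using the recipes package directly, via prep() and bake(). This is only an illustration of what the workflow will do for us automatically; it assumes the flights_rec recipe and the train_data and test_data splits from earlier chapters are available.

```r
library(recipes)

# 1. Estimate the recipe using the training set only:
#    dummy-variable levels and zero-variance columns are determined here.
flights_prepped <- prep(flights_rec, training = train_data)

# 2. Apply the estimated recipe to the training set
#    (new_data = NULL returns the processed training set).
train_baked <- bake(flights_prepped, new_data = NULL)

# 3. Apply the same, already-estimated recipe to the test set;
#    nothing is recomputed from the test set.
test_baked <- bake(flights_prepped, new_data = test_data)
```

A workflow performs these steps for us when we call fit() and predict(), so we will not need to call prep() or bake() ourselves.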

To simplify this process, we can use a model workflow, which pairs a model and a recipe together. This is a useful approach because different models often need different recipes; when a model and recipe are bundled, it becomes easier to train and test them as a unit. We’ll use the workflows package from tidymodels to bundle each of our parsnip models (lr_mod etc.) with our recipe (flights_rec).

19.1 Fit models with workflows

To combine the data preparation with the model building, we use the workflows package. A workflow is an object that bundles together your pre-processing, modeling, and post-processing requests.

19.1.1 Logistic regression

flights_wflow_lr_mod <- 
  workflow() %>% 
  add_model(lr_mod) %>% 
  add_recipe(flights_rec)

flights_wflow_lr_mod
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Logistic Regression Model Specification (classification)
## 
## Computational engine: glm

19.1.2 Decision tree

flights_wflow_dt_mod <- 
  workflow() %>% 
  add_model(dt_mod) %>% 
  add_recipe(flights_rec)

flights_wflow_dt_mod
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: decision_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Decision Tree Model Specification (classification)
## 
## Computational engine: C5.0

19.1.3 Random forest

flights_wflow_rf_mod <- 
  workflow() %>% 
  add_model(rf_mod) %>% 
  add_recipe(flights_rec)

flights_wflow_rf_mod
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Random Forest Model Specification (classification)
## 
## Computational engine: ranger

19.1.4 XGBoost

flights_wflow_xgb_mod <- 
  workflow() %>% 
  add_model(xgb_mod) %>% 
  add_recipe(flights_rec)

flights_wflow_xgb_mod
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: boost_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Boosted Tree Model Specification (classification)
## 
## Computational engine: xgboost

19.2 Train models

Now a single call to fit() will both prepare the recipe and train the model on the resulting predictors.

19.2.1 Logistic regression

flights_fit_lr_mode <- 
  flights_wflow_lr_mod %>% 
  fit(data = train_data)

flights_fit_lr_mode
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## 
## Call:  stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)
## 
## Coefficients:
##                  (Intercept)                      dep_time  
##                    3.272e+01                    -1.710e-03  
##                     air_time           date_USChristmasDay  
##                   -4.929e-02                    -4.008e-01  
##           date_USColumbusDay      date_USCPulaskisBirthday  
##                    1.296e+01                     9.247e-01  
## date_USDecorationMemorialDay            date_USElectionDay  
##                    6.422e-02                     1.121e+00  
##            date_USGoodFriday        date_USInaugurationDay  
##                    1.491e-01                     4.502e-01  
##       date_USIndependenceDay               date_USLaborDay  
##                    1.082e+00                    -1.792e+00  
##      date_USLincolnsBirthday            date_USMemorialDay  
##                   -7.683e-01                     1.125e+00  
##       date_USMLKingsBirthday            date_USNewYearsDay  
##                    4.219e-01                     5.933e-01  
##         date_USPresidentsDay        date_USThanksgivingDay  
##                    4.346e-01                    -5.125e-01  
##           date_USVeteransDay    date_USWashingtonsBirthday  
##                   -3.183e-01                    -5.277e-01  
##                   origin_JFK                    origin_LGA  
##                    3.618e-01                     2.904e-02  
##                     dest_ACK                      dest_ALB  
##                   -1.290e+01                    -2.497e+01  
##                     dest_ATL                      dest_AUS  
##                   -2.217e+01                    -1.729e+01  
##                     dest_AVL                      dest_BDL  
##                   -2.367e+01                    -2.594e+01  
##                     dest_BGR                      dest_BHM  
##                   -2.383e+01                    -2.097e+01  
##                     dest_BNA                      dest_BOS  
##                   -2.204e+01                    -2.519e+01  
##                     dest_BQN                      dest_BTV  
##                   -1.687e+01                    -2.482e+01  
##                     dest_BUF                      dest_BUR  
##                   -2.492e+01                    -1.091e+01  
##                     dest_BWI                      dest_BZN  
##                   -2.586e+01                    -3.139e+00  
##                     dest_CAE                      dest_CAK  
##                   -2.550e+01                    -2.509e+01  
##                     dest_CHO                      dest_CHS  
##                   -1.032e+01                    -2.302e+01  
##                     dest_CLE                      dest_CLT  
##                   -2.362e+01                    -2.384e+01  
##                     dest_CMH                      dest_CRW  
##                   -2.379e+01                    -2.341e+01  
## 
## ...
## and 106 more lines.

19.2.2 Decision tree

flights_fit_dt_mode <- 
  flights_wflow_dt_mod %>% 
  fit(data = train_data)

flights_fit_dt_mode
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: decision_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## 
## Call:
## C5.0.default(x = x, y = y, trials = 1, control = C50::C5.0Control(minCases =
##  2, sample = 0))
## 
## Classification Tree
## Number of samples: 7500 
## Number of predictors: 147 
## 
## Tree size: 18 
## 
## Non-standard options: attempt to group attributes

19.2.3 Random forest

flights_fit_rf_mode <- 
  flights_wflow_rf_mod %>% 
  fit(data = train_data)

flights_fit_rf_mode
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Ranger result
## 
## Call:
##  ranger::ranger(x = maybe_data_frame(x), y = y, num.threads = 1,      verbose = FALSE, seed = sample.int(10^5, 1), probability = TRUE) 
## 
## Type:                             Probability estimation 
## Number of trees:                  500 
## Sample size:                      7500 
## Number of independent variables:  147 
## Mtry:                             12 
## Target node size:                 10 
## Variable importance mode:         none 
## Splitrule:                        gini 
## OOB prediction error (Brier s.):  0.1151285

19.2.4 XGBoost

flights_fit_xgb_mode <- 
  flights_wflow_xgb_mod %>% 
  fit(data = train_data)

flights_fit_xgb_mode
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: boost_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## ##### xgb.Booster
## raw: 47.3 Kb 
## call:
##   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
##     colsample_bytree = 1, min_child_weight = 1, subsample = 1), 
##     data = x$data, nrounds = 15, watchlist = x$watchlist, verbose = 0, 
##     objective = "binary:logistic", nthread = 1)
## params (as set within xgb.train):
##   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", min_child_weight = "1", subsample = "1", objective = "binary:logistic", nthread = "1", validate_parameters = "TRUE"
## xgb.attributes:
##   niter
## callbacks:
##   cb.evaluation.log()
## # of features: 147 
## niter: 15
## nfeatures : 147 
## evaluation_log:
##     iter training_error
##        1       0.137333
##        2       0.138533
## ---                    
##       14       0.130400
##       15       0.128267

19.3 Model recipe objects

The objects above contain both the finalized recipe and the fitted model. You may want to extract either one from the workflow; to do this, use the helper functions pull_workflow_fit() and pull_workflow_prepped_recipe().

For example, here we pull the fitted model object and then use the broom::tidy() function to get a tidy tibble of the logistic regression model coefficients.

19.3.1 Logistic regression

flights_fit_lr_mode %>% 
  pull_workflow_fit() %>% 
  tidy()
## # A tibble: 148 x 5
##    term                         estimate   std.error statistic  p.value
##    <chr>                           <dbl>       <dbl>     <dbl>    <dbl>
##  1 (Intercept)                  32.7     702.           0.0466 9.63e- 1
##  2 dep_time                     -0.00171   0.0000832  -20.5    8.07e-94
##  3 air_time                     -0.0493    0.00337    -14.6    2.28e-48
##  4 date_USChristmasDay          -0.401     0.613       -0.654  5.13e- 1
##  5 date_USColumbusDay           13.0     274.           0.0472 9.62e- 1
##  6 date_USCPulaskisBirthday      0.925     0.765        1.21   2.27e- 1
##  7 date_USDecorationMemorialDay  0.0642    0.579        0.111  9.12e- 1
##  8 date_USElectionDay            1.12      1.07         1.05   2.95e- 1
##  9 date_USGoodFriday             0.149     0.713        0.209  8.34e- 1
## 10 date_USInaugurationDay        0.450     1.06         0.425  6.71e- 1
## # … with 138 more rows

19.3.2 Decision tree

flights_fit_dt_mode %>% 
  pull_workflow_prepped_recipe()
## Data Recipe
## 
## Inputs:
## 
##       role #variables
##         ID          2
##    outcome          1
##  predictor          7
## 
## Training data contained 7500 data points and no missing data.
## 
## Operations:
## 
## Date features from date [trained]
## Holiday features from date [trained]
## Variables removed date [trained]
## Dummy variables from origin, dest, carrier, date_dow, date_month [trained]
## Zero variance filter removed dest_ANC, dest_EYW, dest_HDN, ... [trained]
## Correlation filter removed distance [trained]

19.3.3 Random forest

flights_fit_rf_mode %>% 
  pull_workflow_prepped_recipe()
## Data Recipe
## 
## Inputs:
## 
##       role #variables
##         ID          2
##    outcome          1
##  predictor          7
## 
## Training data contained 7500 data points and no missing data.
## 
## Operations:
## 
## Date features from date [trained]
## Holiday features from date [trained]
## Variables removed date [trained]
## Dummy variables from origin, dest, carrier, date_dow, date_month [trained]
## Zero variance filter removed dest_ANC, dest_EYW, dest_HDN, ... [trained]
## Correlation filter removed distance [trained]

19.3.4 XGBoost

flights_fit_xgb_mode %>% 
  pull_workflow_prepped_recipe()
## Data Recipe
## 
## Inputs:
## 
##       role #variables
##         ID          2
##    outcome          1
##  predictor          7
## 
## Training data contained 7500 data points and no missing data.
## 
## Operations:
## 
## Date features from date [trained]
## Holiday features from date [trained]
## Variables removed date [trained]
## Dummy variables from origin, dest, carrier, date_dow, date_month [trained]
## Zero variance filter removed dest_ANC, dest_EYW, dest_HDN, ... [trained]
## Correlation filter removed distance [trained]

19.4 Summary

Our goal was to predict whether a plane arrives more than 30 minutes late. We have just:

  1. Built the models (lr_mod etc.),

  2. Created a preprocessing recipe (flights_rec),

  3. Bundled each model with the recipe (flights_wflow_lr_mod etc.), and

  4. Trained each workflow with a single call to fit().

The next step is to use the trained workflows (flights_fit_lr_mode etc.) to predict with the unseen test data, which we will do with a single call to predict().
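As a preview, predicting with a trained workflow is a one-liner. This sketch uses the logistic regression fit; the same call works for the other fitted workflows, and it assumes test_data is available.

```r
# Predict hard class labels on the unseen test set; the workflow
# applies the trained recipe to test_data before predicting.
predict(flights_fit_lr_mode, new_data = test_data)

# type = "prob" returns class probabilities instead.
predict(flights_fit_lr_mode, new_data = test_data, type = "prob")
```

We will walk through these predictions, and how to evaluate them, in the next section.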