Chapter 19 Recipe and Models
We will want to use our recipe across several steps as we train and test our models. We will:
Process the recipe using the training set: This covers any estimation or calculation based on the training set. For our recipe, the training set determines which predictors should be converted to dummy variables and which predictors have zero variance in the training set and should be slated for removal.
Apply the recipe to the training set: We create the final predictor set on the training set.
Apply the recipe to the test set: We create the final predictor set on the test set. Nothing is recomputed and no information from the test set is used here; the dummy variable and zero-variance results from the training set are applied to the test set.
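The workflow approach below automates these steps, but they can also be carried out manually with the recipes functions prep() and bake(); a minimal sketch, assuming the flights_rec, train_data, and test_data objects from the earlier chapters:

```r
library(recipes)

# Estimate the recipe using only the training set:
# dummy-variable levels and zero-variance results are computed here
flights_prepped <- prep(flights_rec, training = train_data)

# Apply the trained recipe to the training set
train_baked <- bake(flights_prepped, new_data = NULL)

# Apply the same trained recipe to the test set;
# nothing is re-estimated from the test data
test_baked <- bake(flights_prepped, new_data = test_data)
```

A workflow runs this prep/bake sequence for us at fit and predict time, which is why none of it appears explicitly in the code below.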
To simplify this process, we can use a model workflow, which pairs a model and a recipe together. This is a helpful approach because different recipes are often needed for different models, so bundling a model with its recipe makes it easier to train and test workflows. We'll use the workflows package from tidymodels to bundle our parsnip models (lr_mod etc.) with our recipe (flights_rec).
19.1 Fit models with workflows
To combine the data preparation with the model building, we use the package workflows. A workflow is an object that can bundle together your pre-processing, modeling, and post-processing requests.
19.1.1 Logistic regression
flights_wflow_lr_mod <-
  workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(flights_rec)
flights_wflow_lr_mod
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
##
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Logistic Regression Model Specification (classification)
##
## Computational engine: glm
19.1.2 Decision tree
flights_wflow_dt_mod <-
  workflow() %>%
  add_model(dt_mod) %>%
  add_recipe(flights_rec)
flights_wflow_dt_mod
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: decision_tree()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
##
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Decision Tree Model Specification (classification)
##
## Computational engine: C5.0
19.1.3 Random forest
flights_wflow_rf_mod <-
  workflow() %>%
  add_model(rf_mod) %>%
  add_recipe(flights_rec)
flights_wflow_rf_mod
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
##
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Random Forest Model Specification (classification)
##
## Computational engine: ranger
19.1.4 XGBoost
flights_wflow_xgb_mod <-
  workflow() %>%
  add_model(xgb_mod) %>%
  add_recipe(flights_rec)
flights_wflow_xgb_mod
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: boost_tree()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
##
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Boosted Tree Model Specification (classification)
##
## Computational engine: xgboost
19.2 Train models
Now, a single function, fit(), can be used to prepare the recipe and train each model from the resulting predictors.
19.2.1 Logistic regression
flights_fit_lr_mode <-
  flights_wflow_lr_mod %>%
  fit(data = train_data)
flights_fit_lr_mode
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
##
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
##
## ── Model ───────────────────────────────────────────────────────────────────────
##
## Call: stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)
##
## Coefficients:
## (Intercept) dep_time
## 3.272e+01 -1.710e-03
## air_time date_USChristmasDay
## -4.929e-02 -4.008e-01
## date_USColumbusDay date_USCPulaskisBirthday
## 1.296e+01 9.247e-01
## date_USDecorationMemorialDay date_USElectionDay
## 6.422e-02 1.121e+00
## date_USGoodFriday date_USInaugurationDay
## 1.491e-01 4.502e-01
## date_USIndependenceDay date_USLaborDay
## 1.082e+00 -1.792e+00
## date_USLincolnsBirthday date_USMemorialDay
## -7.683e-01 1.125e+00
## date_USMLKingsBirthday date_USNewYearsDay
## 4.219e-01 5.933e-01
## date_USPresidentsDay date_USThanksgivingDay
## 4.346e-01 -5.125e-01
## date_USVeteransDay date_USWashingtonsBirthday
## -3.183e-01 -5.277e-01
## origin_JFK origin_LGA
## 3.618e-01 2.904e-02
## dest_ACK dest_ALB
## -1.290e+01 -2.497e+01
## dest_ATL dest_AUS
## -2.217e+01 -1.729e+01
## dest_AVL dest_BDL
## -2.367e+01 -2.594e+01
## dest_BGR dest_BHM
## -2.383e+01 -2.097e+01
## dest_BNA dest_BOS
## -2.204e+01 -2.519e+01
## dest_BQN dest_BTV
## -1.687e+01 -2.482e+01
## dest_BUF dest_BUR
## -2.492e+01 -1.091e+01
## dest_BWI dest_BZN
## -2.586e+01 -3.139e+00
## dest_CAE dest_CAK
## -2.550e+01 -2.509e+01
## dest_CHO dest_CHS
## -1.032e+01 -2.302e+01
## dest_CLE dest_CLT
## -2.362e+01 -2.384e+01
## dest_CMH dest_CRW
## -2.379e+01 -2.341e+01
##
## ...
## and 106 more lines.
19.2.2 Decision tree
flights_fit_dt_mode <-
  flights_wflow_dt_mod %>%
  fit(data = train_data)
flights_fit_dt_mode
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: decision_tree()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
##
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
##
## ── Model ───────────────────────────────────────────────────────────────────────
##
## Call:
## C5.0.default(x = x, y = y, trials = 1, control = C50::C5.0Control(minCases =
## 2, sample = 0))
##
## Classification Tree
## Number of samples: 7500
## Number of predictors: 147
##
## Tree size: 18
##
## Non-standard options: attempt to group attributes
19.2.3 Random forest
flights_fit_rf_mode <-
  flights_wflow_rf_mod %>%
  fit(data = train_data)
flights_fit_rf_mode
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
##
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1), probability = TRUE)
##
## Type: Probability estimation
## Number of trees: 500
## Sample size: 7500
## Number of independent variables: 147
## Mtry: 12
## Target node size: 10
## Variable importance mode: none
## Splitrule: gini
## OOB prediction error (Brier s.): 0.1151285
19.2.4 XGBoost
flights_fit_xgb_mode <-
  flights_wflow_xgb_mod %>%
  fit(data = train_data)
flights_fit_xgb_mode
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: boost_tree()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
##
## ● step_date()
## ● step_holiday()
## ● step_rm()
## ● step_dummy()
## ● step_zv()
## ● step_corr()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## ##### xgb.Booster
## raw: 47.3 Kb
## call:
## xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0,
## colsample_bytree = 1, min_child_weight = 1, subsample = 1),
## data = x$data, nrounds = 15, watchlist = x$watchlist, verbose = 0,
## objective = "binary:logistic", nthread = 1)
## params (as set within xgb.train):
## eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", min_child_weight = "1", subsample = "1", objective = "binary:logistic", nthread = "1", validate_parameters = "TRUE"
## xgb.attributes:
## niter
## callbacks:
## cb.evaluation.log()
## # of features: 147
## niter: 15
## nfeatures : 147
## evaluation_log:
## iter training_error
## 1 0.137333
## 2 0.138533
## ---
## 14 0.130400
## 15 0.128267
19.3 Model recipe objects
The objects above contain both the finalized recipe and the fitted model. To extract the model or recipe object from a workflow, you can use the helper functions pull_workflow_fit() and pull_workflow_prepped_recipe().
For example, here we pull the fitted model object, then use the broom::tidy() function to get a tidy tibble of the logistic regression model coefficients.
19.3.1 Logistic regression
flights_fit_lr_mode %>%
  pull_workflow_fit() %>%
  tidy()
## # A tibble: 148 x 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) 32.7 702. 0.0466 9.63e- 1
## 2 dep_time -0.00171 0.0000832 -20.5 8.07e-94
## 3 air_time -0.0493 0.00337 -14.6 2.28e-48
## 4 date_USChristmasDay -0.401 0.613 -0.654 5.13e- 1
## 5 date_USColumbusDay 13.0 274. 0.0472 9.62e- 1
## 6 date_USCPulaskisBirthday 0.925 0.765 1.21 2.27e- 1
## 7 date_USDecorationMemorialDay 0.0642 0.579 0.111 9.12e- 1
## 8 date_USElectionDay 1.12 1.07 1.05 2.95e- 1
## 9 date_USGoodFriday 0.149 0.713 0.209 8.34e- 1
## 10 date_USInaugurationDay 0.450 1.06 0.425 6.71e- 1
## # … with 138 more rows
19.3.2 Decision tree
flights_fit_dt_mode %>%
  pull_workflow_prepped_recipe()
## Data Recipe
##
## Inputs:
##
## role #variables
## ID 2
## outcome 1
## predictor 7
##
## Training data contained 7500 data points and no missing data.
##
## Operations:
##
## Date features from date [trained]
## Holiday features from date [trained]
## Variables removed date [trained]
## Dummy variables from origin, dest, carrier, date_dow, date_month [trained]
## Zero variance filter removed dest_ANC, dest_EYW, dest_HDN, ... [trained]
## Correlation filter removed distance [trained]
19.3.3 Random forest
flights_fit_rf_mode %>%
  pull_workflow_prepped_recipe()
## Data Recipe
##
## Inputs:
##
## role #variables
## ID 2
## outcome 1
## predictor 7
##
## Training data contained 7500 data points and no missing data.
##
## Operations:
##
## Date features from date [trained]
## Holiday features from date [trained]
## Variables removed date [trained]
## Dummy variables from origin, dest, carrier, date_dow, date_month [trained]
## Zero variance filter removed dest_ANC, dest_EYW, dest_HDN, ... [trained]
## Correlation filter removed distance [trained]
19.3.4 XGBoost
flights_fit_xgb_mode %>%
  pull_workflow_prepped_recipe()
## Data Recipe
##
## Inputs:
##
## role #variables
## ID 2
## outcome 1
## predictor 7
##
## Training data contained 7500 data points and no missing data.
##
## Operations:
##
## Date features from date [trained]
## Holiday features from date [trained]
## Variables removed date [trained]
## Dummy variables from origin, dest, carrier, date_dow, date_month [trained]
## Zero variance filter removed dest_ANC, dest_EYW, dest_HDN, ... [trained]
## Correlation filter removed distance [trained]
19.4 Summary
Our goal was to predict whether a plane arrives more than 30 minutes late. We have just:
Built the models (lr_mod etc.),
Created a preprocessing recipe (flights_rec),
Bundled each model with the recipe (flights_wflow_lr_mod etc.), and
Trained our workflows using a single call to fit().
The next step is to use the trained workflows (flights_fit_lr_mode etc.) to predict with the unseen test data, which we will do with a single call to predict().
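As a preview, the prediction call has the same shape for every trained workflow; a sketch, assuming the test_data object from the earlier split:

```r
# Hard class predictions from the trained logistic regression workflow;
# the workflow applies the trained recipe to test_data automatically
predict(flights_fit_lr_mode, new_data = test_data)

# Class probabilities, which we will need for ROC curves
predict(flights_fit_lr_mode, new_data = test_data, type = "prob")
```

The same calls work unchanged for the decision tree, random forest, and XGBoost workflows.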