Random forest in scikit-learn

We illustrate random forest regression on the “Hitters” data set, which contains 20 variables and 322 observations of Major League Baseball players. The goal is to predict a player’s salary from features describing his performance in the previous year. This notebook does not cover exploratory data analysis.

Setup

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.inspection import permutation_importance

Data

  • See Hitters data preparation for details about the data preprocessing steps.

  • We import the preprocessed data using this Python script, which yields:

    • X_train, X_test, y_train, y_test

    • df_train and df_test

    • feature_names

from hitters_data import *
  • Make contiguous flattened arrays (for our scikit-learn model):

y_train = np.ravel(y_train)
y_test = np.ravel(y_test)
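scikit-learn regressors expect a one-dimensional target. The flattening step can be sketched on a small stand-in array; `y_col` and its values are hypothetical and are not taken from the Hitters data.

```python
import numpy as np

# Hypothetical column-vector target of shape (4, 1),
# e.g. what a one-column DataFrame converts to
y_col = np.array([[1.0], [2.0], [3.0], [4.0]])

# np.ravel returns a contiguous flattened 1-D array of shape (4,)
y_flat = np.ravel(y_col)

print(y_col.shape, "->", y_flat.shape)
```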

Model

  • Define hyperparameters:

params = {
    "n_estimators": 500,
    "max_depth": 4,
    "min_samples_split": 5,
    "warm_start": True,
    "oob_score": True,
    "random_state": 42,
}
  • Build and fit model

reg = RandomForestRegressor(**params)

reg.fit(X_train, y_train)
RandomForestRegressor(max_depth=4, min_samples_split=5, n_estimators=500,
                      oob_score=True, random_state=42, warm_start=True)
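Because the hyperparameters set `oob_score` to `True`, the fitted forest also exposes an out-of-bag estimate of R² through the `oob_score_` attribute. The following is a minimal sketch on synthetic stand-in data (an assumption; it does not use the Hitters data, and the names `X_demo`, `y_demo` and `forest` are illustrative).

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in data (assumption: not the Hitters data)
X_demo, y_demo = make_regression(
    n_samples=200, n_features=5, noise=10.0, random_state=0
)

forest = RandomForestRegressor(n_estimators=100, oob_score=True, random_state=0)
forest.fit(X_demo, y_demo)

# R^2 estimated from the samples each tree did not see in its bootstrap sample
print(f"OOB R^2: {forest.oob_score_:.3f}")
```

The out-of-bag score gives a rough generalization estimate without touching the test set.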
  • Make predictions

y_pred = reg.predict(X_test)
  • Evaluate model with RMSE

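# RMSE as the square root of the MSE; newer scikit-learn releases
# deprecate the `squared=False` argument of mean_squared_error
np.sqrt(mean_squared_error(y_test, y_pred))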
296.37036964432764
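RMSE is the square root of the mean squared error, so it is expressed in the same units as the target. A small worked sketch with hypothetical values (not from the Hitters data) shows that the scikit-learn result matches the manual formula.

```python
import numpy as np
from sklearn.metrics import mean_squared_error

# Hypothetical targets and predictions (illustrative values only)
y_true_demo = np.array([100.0, 200.0, 300.0])
y_pred_demo = np.array([110.0, 190.0, 330.0])

mse = mean_squared_error(y_true_demo, y_pred_demo)  # mean of squared errors
rmse = np.sqrt(mse)  # back in the units of the target

# The same quantity computed directly from the definition
rmse_manual = np.sqrt(np.mean((y_true_demo - y_pred_demo) ** 2))

print(rmse, rmse_manual)
```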

Feature importance

  • Next, we look at the tree-based feature importance and the permutation importance.

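Mean decrease in impurity (MDI)

  • Mean decrease in impurity (MDI) measures feature importance in tree-based models: it is the total reduction in impurity (for regression, the squared error) achieved by splits on a feature, averaged over all trees in the forest and normalized to sum to one.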

Note

Visit this notebook to learn more about MDI

  • Feature importances are provided by the fitted attribute feature_importances_

# obtain feature importance
feature_importance = reg.feature_importances_

# sort features according to importance
sorted_idx = np.argsort(feature_importance)
pos = np.arange(sorted_idx.shape[0])

# plot feature importances
plt.barh(pos, feature_importance[sorted_idx], align="center")

plt.yticks(pos, np.array(feature_names)[sorted_idx])
plt.title("Feature Importance (MDI)")
plt.xlabel("Mean decrease in impurity");
[Figure: horizontal bar chart of MDI feature importances]
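To see where the values in `feature_importances_` come from, the following sketch recomputes them as the average of the per-tree importances. It uses synthetic stand-in data (an assumption; not the Hitters data), and the names `per_tree`, `mdi_mean` and `mdi_std` are illustrative.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in data (assumption: not the Hitters data)
X_demo, y_demo = make_regression(
    n_samples=300, n_features=6, n_informative=3, noise=5.0, random_state=0
)

forest = RandomForestRegressor(n_estimators=50, max_depth=4, random_state=0)
forest.fit(X_demo, y_demo)

# One row per tree, one column per feature
per_tree = np.array([tree.feature_importances_ for tree in forest.estimators_])

mdi_mean = per_tree.mean(axis=0)  # forest-level MDI
mdi_std = per_tree.std(axis=0)    # spread across trees, usable as error bars

print(np.allclose(mdi_mean, forest.feature_importances_))
print(forest.feature_importances_.sum())
```

The standard deviation across trees can be passed as `xerr` to `plt.barh` to indicate how stable each importance is.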

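Permutation feature importance

The permutation feature importance is defined as the decrease in a model score when the values of a single feature are randomly shuffled. Shuffling breaks the relationship between the feature and the target, so a large drop in the score indicates that the model relies on that feature.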

Note

Visit this notebook to learn more about permutation feature importance.
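The definition above can be sketched by hand. This is a simplified illustration on synthetic data, not the scikit-learn implementation; the helper `permutation_drop` and all variable names are hypothetical.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)

# Synthetic data (assumption): only column 0 drives the target
X_demo = rng.normal(size=(300, 3))
y_demo = 5.0 * X_demo[:, 0] + rng.normal(scale=0.1, size=300)

X_tr, X_te = X_demo[:200], X_demo[200:]
y_tr, y_te = y_demo[:200], y_demo[200:]

forest = RandomForestRegressor(n_estimators=50, random_state=0).fit(X_tr, y_tr)
baseline = forest.score(X_te, y_te)  # R^2 on held-out data


def permutation_drop(model, X, y, column, n_repeats=10, seed=0):
    """Mean drop in model.score(X, y) after shuffling one column of X."""
    local_rng = np.random.default_rng(seed)
    reference = model.score(X, y)
    drops = []
    for _ in range(n_repeats):
        X_shuffled = X.copy()
        X_shuffled[:, column] = local_rng.permutation(X_shuffled[:, column])
        drops.append(reference - model.score(X_shuffled, y))
    return float(np.mean(drops))


drops = [permutation_drop(forest, X_te, y_te, j) for j in range(3)]
print(baseline, drops)
```

Shuffling the informative column should reduce the score markedly, while shuffling the noise columns should leave it almost unchanged.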

result = permutation_importance(
    reg, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)

perm_importances = pd.Series(result.importances_mean, index=feature_names)
# sort features according to importance (positional indices, ascending)
sorted_idx = np.argsort(perm_importances.to_numpy())
pos = np.arange(sorted_idx.shape[0])

# plot feature importances
plt.barh(pos, perm_importances.iloc[sorted_idx], align="center")

plt.yticks(pos, np.array(feature_names)[sorted_idx])
plt.title("Permutation Importance (test set)")
plt.xlabel("Mean decrease in R² score");
[Figure: horizontal bar chart of permutation feature importances (test set)]
  • The same results plotted as a boxplot, showing the spread over the 10 repeats:

plt.boxplot(
    result.importances[sorted_idx].T,
    vert=False,
    labels=np.array(feature_names)[sorted_idx],
)

plt.title("Permutation Importance (test set)");
[Figure: boxplot of permutation feature importances (test set)]
  • We observe that both methods identify the same features as most important (e.g., CAtBat, CRBI, CHits, Walks, Years), although the relative importances vary (especially for the feature Years).
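One way to make such a comparison explicit is to place both importance measures and their ranks side by side. The following is a minimal sketch on synthetic data (an assumption; not the Hitters data), with hypothetical feature names `x0` to `x3`.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)

# Synthetic data (assumption): x0 matters most, x1 less, x2 and x3 are noise
X_demo = rng.normal(size=(400, 4))
y_demo = 5.0 * X_demo[:, 0] + 2.0 * X_demo[:, 1] + rng.normal(scale=0.5, size=400)
names = ["x0", "x1", "x2", "x3"]

X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, random_state=0)

forest = RandomForestRegressor(n_estimators=100, random_state=0).fit(X_tr, y_tr)
perm = permutation_importance(forest, X_te, y_te, n_repeats=10, random_state=0)

comparison = pd.DataFrame(
    {"mdi": forest.feature_importances_, "permutation": perm.importances_mean},
    index=names,
)
# Rank 1 = most important under each method
comparison["mdi_rank"] = comparison["mdi"].rank(ascending=False, method="first").astype(int)
comparison["permutation_rank"] = (
    comparison["permutation"].rank(ascending=False, method="first").astype(int)
)

print(comparison.sort_values("mdi_rank"))
```

Note that MDI is computed on the training data and is normalized to sum to one, whereas permutation importance here is measured on the test set in units of the score, so only the rankings are directly comparable.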