Random forest in scikit-learn

We illustrate random forest regression on a data set called “Hitters”, which contains 322 observations of major league baseball players and 20 variables. The goal is to predict a player’s salary from various features describing their performance in the previous year. Exploratory data analysis is not covered in this notebook.


import matplotlib.pyplot as plt
import numpy as np

from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.inspection import permutation_importance


  • See Hitters data preparation for details about the data preprocessing steps.

  • We import the preprocessed data from this Python script, which provides:

    • X_train, X_test, y_train, y_test

    • df_train and df_test

    • feature_names

from hitters_data import *

  • Flatten the targets into contiguous 1-D arrays, since scikit-learn regressors expect a 1-D target:

y_train = np.ravel(y_train)
y_test = np.ravel(y_test)
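Preprocessing often returns the target as a single-column 2-D array of shape (n, 1); scikit-learn's forest estimators expect a 1-D y and raise a DataConversionWarning for a column vector. A small sketch of what np.ravel does, using a made-up column vector rather than the Hitters target:

```python
import numpy as np

# A made-up target stored as a column vector, shape (5, 1)
y_col = np.arange(5.0).reshape(-1, 1)

# np.ravel flattens it to shape (5,) and the result is C-contiguous
y_flat = np.ravel(y_col)

print(y_col.shape, "->", y_flat.shape)  # (5, 1) -> (5,)
print(y_flat.flags["C_CONTIGUOUS"])     # True
```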


  • Define hyperparameters:

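params = {
    "n_estimators": 500,
    "max_depth": 4,
    "min_samples_split": 5,
    "random_state": 42,
    "oob_score": True,   # compute the out-of-bag R² score after fitting
    "warm_start": True,  # keep fitted trees when refitting with more estimators
}
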
  • Build and fit the model:

reg = RandomForestRegressor(**params)

reg.fit(X_train, y_train)
RandomForestRegressor(max_depth=4, min_samples_split=5, n_estimators=500,
                      oob_score=True, random_state=42, warm_start=True)
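The fitted model's repr lists oob_score=True and warm_start=True. A minimal sketch of what these two settings do, assuming synthetic data from make_regression instead of the Hitters data: oob_score_ is the R² computed on the bootstrap samples each tree did not see, and with warm_start=True, raising n_estimators and calling fit again adds new trees to the existing forest.

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in data (an assumption; not the Hitters data)
X, y = make_regression(n_samples=300, n_features=8, noise=10.0, random_state=0)

rf = RandomForestRegressor(
    n_estimators=100, max_depth=4, oob_score=True, warm_start=True, random_state=42
)
rf.fit(X, y)
first_tree = rf.estimators_[0]
oob_100 = rf.oob_score_  # out-of-bag R² with 100 trees

# With warm_start=True, raising n_estimators and refitting only adds new trees
rf.set_params(n_estimators=200)
rf.fit(X, y)
oob_200 = rf.oob_score_  # out-of-bag R² recomputed with 200 trees

print(len(rf.estimators_))              # 200
print(rf.estimators_[0] is first_tree)  # True: the first 100 trees are kept
print(round(oob_100, 3), round(oob_200, 3))
```

One consequence of warm_start=True: calling fit again on the same object with an unchanged n_estimators fits no new trees (scikit-learn warns instead of retraining).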
  • Make predictions on the test set:

y_pred = reg.predict(X_test)
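For regression, a random forest's prediction is the average of the predictions of its individual trees. A sketch that checks this on a small forest fitted to synthetic data (an assumption; not the Hitters data), using the fitted trees in estimators_:

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in data (an assumption; not the Hitters data)
X, y = make_regression(n_samples=200, n_features=5, noise=5.0, random_state=1)
X_new = X[:10]

forest = RandomForestRegressor(n_estimators=50, max_depth=4, random_state=42).fit(X, y)

# Predict with every fitted tree, then average across trees (axis 0)
per_tree = np.stack([tree.predict(X_new) for tree in forest.estimators_])
manual = per_tree.mean(axis=0)

print(per_tree.shape)                              # (50, 10)
print(np.allclose(manual, forest.predict(X_new)))  # True
```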
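  • Evaluate the model on the test set with the root mean squared error (RMSE):

# RMSE = square root of the mean squared error (the squared=False argument is removed in recent scikit-learn)
np.sqrt(mean_squared_error(y_test, y_pred))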
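The RMSE is the square root of the average squared residual, so it is expressed in the same units as the target. A sketch with made-up numbers that computes it by hand and compares it with scikit-learn:

```python
import numpy as np
from sklearn.metrics import mean_squared_error

# Made-up observed and predicted values, for illustration only
y_true = np.array([3.0, 5.0, 2.5, 7.0])
y_hat = np.array([2.5, 5.0, 3.0, 8.0])

# RMSE = sqrt(mean((y_true - y_hat)^2))
rmse_manual = np.sqrt(np.mean((y_true - y_hat) ** 2))
rmse_sklearn = np.sqrt(mean_squared_error(y_true, y_hat))

print(round(rmse_manual, 4))  # 0.6124
```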

Feature importance

  • Next, we look at the tree-based feature importance (MDI) and at the permutation importance.

Mean decrease in impurity (MDI)

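  • Mean decrease in impurity (MDI) is a feature importance measure for tree-based models: for each feature, it sums the impurity reduction (for this regressor, the reduction in squared error) achieved by all splits on that feature, weights each split by the fraction of samples reaching it, and averages the result over all trees in the forest.


Visit this notebook to learn more about MDI.

  • Feature importances are provided by the fitted attribute feature_importances_: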

# obtain feature importance
feature_importance = reg.feature_importances_

# sort features according to importance
sorted_idx = np.argsort(feature_importance)
pos = np.arange(sorted_idx.shape[0])

# plot feature importances
plt.barh(pos, feature_importance[sorted_idx], align="center")

plt.yticks(pos, np.array(feature_names)[sorted_idx])
plt.title("Feature Importance (MDI)")
plt.xlabel("Mean decrease in impurity");
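To make the MDI definition concrete, the sketch below recomputes feature_importances_ from the low-level tree_ arrays (children_left, children_right, feature, impurity, weighted_n_node_samples) of a forest fitted on synthetic data (an assumption; not the Hitters data). The helper tree_mdi is hypothetical, written only for this illustration.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in data (an assumption; not the Hitters data)
X, y = make_regression(n_samples=300, n_features=6, n_informative=3, random_state=0)
forest = RandomForestRegressor(n_estimators=50, max_depth=4, random_state=42).fit(X, y)


def tree_mdi(fitted_tree, n_features):
    """Hypothetical helper: impurity decrease per feature for one tree, normalized to sum to 1."""
    t = fitted_tree.tree_
    importances = np.zeros(n_features)
    for node in range(t.node_count):
        left, right = t.children_left[node], t.children_right[node]
        if left == -1:  # leaf node: no split, so no impurity decrease
            continue
        # weighted impurity of the parent minus weighted impurity of its two children
        decrease = (
            t.weighted_n_node_samples[node] * t.impurity[node]
            - t.weighted_n_node_samples[left] * t.impurity[left]
            - t.weighted_n_node_samples[right] * t.impurity[right]
        )
        importances[t.feature[node]] += decrease
    return importances / importances.sum()


# Forest MDI: average of the per-tree normalized importances
manual = np.mean([tree_mdi(est, X.shape[1]) for est in forest.estimators_], axis=0)
print(np.allclose(manual, forest.feature_importances_))  # True
```

Because MDI is built from the splits learned on the training data, it describes how the model fits the training set, which is one reason it can disagree with permutation importance measured on the test set.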

Permutation feature importance

The permutation feature importance is defined as the decrease in a model's score when the values of a single feature are randomly shuffled, which breaks the relationship between that feature and the target.


Visit this notebook to learn more about permutation feature importance.
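Following that definition, a hand-rolled sketch: compute a baseline score, shuffle one column at a time, and record the average drop over several shuffles. It assumes R² as the score (the default score of scikit-learn regressors, which permutation_importance uses when scoring is not given) and synthetic data in which only the first feature matters; manual_permutation_importance is a hypothetical helper, not a library function.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score

# Synthetic stand-in data (an assumption): feature 0 drives y, feature 1 is pure noise
rng = np.random.default_rng(0)
X = rng.normal(size=(400, 2))
y = 3.0 * X[:, 0] + rng.normal(scale=0.1, size=400)
X_fit, X_eval, y_fit, y_eval = X[:300], X[300:], y[:300], y[300:]

model = RandomForestRegressor(n_estimators=100, max_depth=4, random_state=42).fit(X_fit, y_fit)


def manual_permutation_importance(model, X, y, n_repeats=10, seed=0):
    """Hypothetical helper: mean drop in R² when each column of X is shuffled."""
    rng = np.random.default_rng(seed)
    baseline = r2_score(y, model.predict(X))
    drops = np.zeros((X.shape[1], n_repeats))
    for j in range(X.shape[1]):
        for r in range(n_repeats):
            X_perm = X.copy()
            # shuffle column j only, breaking its link with y
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops[j, r] = baseline - r2_score(y, model.predict(X_perm))
    return drops.mean(axis=1)


imp = manual_permutation_importance(model, X_eval, y_eval)
print(imp)  # large drop for feature 0, close to zero for feature 1
```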

result = permutation_importance(
    reg, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)

# mean importance of each feature over the n_repeats shuffles
perm_importance = result.importances_mean
# sort features according to importance
sorted_idx = np.argsort(perm_importance)
pos = np.arange(sorted_idx.shape[0])

# plot feature importances
plt.barh(pos, perm_importance[sorted_idx], align="center")

plt.yticks(pos, np.array(feature_names)[sorted_idx])
plt.title("Permutation Importance (test set)")
plt.xlabel("Mean decrease in R² score");
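  • Same data plotted as a boxplot (one box per feature, summarizing its n_repeats importance values):

plt.figure()
# result.importances has shape (n_features, n_repeats); boxplot draws one box per column
sorted_idx = result.importances_mean.argsort()
plt.boxplot(result.importances[sorted_idx].T, vert=False)
plt.yticks(np.arange(1, len(sorted_idx) + 1), np.array(feature_names)[sorted_idx])
plt.title("Permutation Importance (test set)")
plt.xlabel("Decrease in R² score");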
  • We observe that both methods detect the same features as most important (e.g., CAtBat, CRBI, CHits, Walks, Years), although their relative importances vary (especially for Years).