Random forest in scikit-learn
We illustrate random forest regression on a data set called “Hitters”, which includes 20 variables and 322 observations of major league baseball players. The goal is to predict a baseball player’s salary from features describing his performance in the previous year. Exploratory data analysis is not covered in this notebook.
Visit this documentation if you want to learn more about the data.
Setup#
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd  # used below to label the permutation importances
from collections import OrderedDict
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.inspection import permutation_importance
Data#
See Hitters data preparation for details about the data preprocessing steps.
We simply import the preprocessed data by using this Python script which will yield:
X_train, X_test, y_train, y_test
df_train and df_test
feature_names
from hitters_data import *
Make contiguous flattened arrays (for our scikit-learn model):
y_train = np.ravel(y_train)
y_test = np.ravel(y_test)
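As a small illustration of what the flattening step does, the sketch below applies `np.ravel` to a column vector. The array values are made up for the example and are not taken from the Hitters data.

```python
import numpy as np

# Hypothetical target stored as a column vector (shape (3, 1)),
# e.g. what a one-column DataFrame would give via .to_numpy()
y_column = np.array([[475.0], [480.0], [500.0]])

# np.ravel returns a contiguous, flattened 1-D array with the same values
y_flat = np.ravel(y_column)

print(y_column.shape)  # (3, 1)
print(y_flat.shape)    # (3,)
```

scikit-learn regressors expect a 1-D target for single-output problems, which is why the flattening is done before fitting.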
Model#
Define hyperparameters:
params = {
    "n_estimators": 500,
    "max_depth": 4,
    "min_samples_split": 5,
    "warm_start": True,
    "oob_score": True,
    "random_state": 42,
}
Build and fit the model:
reg = RandomForestRegressor(**params)
reg.fit(X_train, y_train)
RandomForestRegressor(max_depth=4, min_samples_split=5, n_estimators=500,
oob_score=True, random_state=42, warm_start=True)
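Because the model is fitted with `oob_score=True`, scikit-learn also stores an out-of-bag estimate of the generalization score in the fitted attribute `oob_score_`. The following is a minimal sketch on synthetic data; `X_demo`, `y_demo` and `rf_demo` are illustrative names and the data are not the Hitters data.

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in data (an assumption for this sketch, not the Hitters data)
X_demo, y_demo = make_regression(
    n_samples=200, n_features=5, noise=10.0, random_state=0
)

# Each tree is trained on a bootstrap sample; the observations left out of
# that sample ("out-of-bag") are used to score the forest
rf_demo = RandomForestRegressor(n_estimators=100, oob_score=True, random_state=0)
rf_demo.fit(X_demo, y_demo)

# For a regressor, the default out-of-bag score is R^2
print(rf_demo.oob_score_)
```

On the model above, the same value would be available as `reg.oob_score_` after fitting.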
Make predictions:
y_pred = reg.predict(X_test)
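Evaluate the model with the root mean squared error (RMSE):
# the `squared=False` argument was deprecated in scikit-learn 1.4 and removed in later releases;
# taking the square root of the MSE works across versions
np.sqrt(mean_squared_error(y_test, y_pred))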
296.37036964432764
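RMSE is the square root of the average squared difference between observed and predicted values, so it is expressed in the same units as the target (here, salary). The sketch below checks the definition by hand on three made-up values; the `_demo` arrays are hypothetical and unrelated to the result above.

```python
import numpy as np
from sklearn.metrics import mean_squared_error

# Made-up observed and predicted values, for illustration only
y_true_demo = np.array([100.0, 200.0, 300.0])
y_pred_demo = np.array([110.0, 190.0, 330.0])

# RMSE from the definition: sqrt(mean((y - y_hat)^2))
rmse_manual = np.sqrt(np.mean((y_true_demo - y_pred_demo) ** 2))

# Same quantity via scikit-learn's MSE
rmse_sklearn = np.sqrt(mean_squared_error(y_true_demo, y_pred_demo))

print(rmse_manual, rmse_sklearn)  # both are sqrt(1100 / 3), about 19.15
```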
Feature importance#
Next, we take a look at the tree-based feature importance and the permutation importance.
Mean decrease in impurity (MDI)#
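Mean decrease in impurity (MDI) is a measure of feature importance for decision tree models. A feature's importance is the total reduction in the split criterion (the impurity) brought by that feature, averaged over the trees of the forest and normalized so that the importances sum to one.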
Note
Visit this notebook to learn more about MDI.
Feature importances are provided by the fitted attribute feature_importances_:
# obtain feature importance
feature_importance = reg.feature_importances_
# sort features according to importance
sorted_idx = np.argsort(feature_importance)
pos = np.arange(sorted_idx.shape[0])
# plot feature importances
plt.barh(pos, feature_importance[sorted_idx], align="center")
plt.yticks(pos, np.array(feature_names)[sorted_idx])
plt.title("Feature Importance (MDI)")
plt.xlabel("Mean decrease in impurity");
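To make the averaging behind MDI concrete, the sketch below recomputes a forest's importances from its individual trees. It uses synthetic data and illustrative names (`X_demo`, `rf_demo`, `mdi_manual`), and it assumes every tree makes at least one split, which holds for this toy data.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in data (an assumption for this sketch, not the Hitters data)
X_demo, y_demo = make_regression(
    n_samples=200, n_features=5, noise=10.0, random_state=0
)
rf_demo = RandomForestRegressor(n_estimators=50, max_depth=4, random_state=0)
rf_demo.fit(X_demo, y_demo)

# Impurity-based importances of every single tree, shape (n_trees, n_features)
per_tree = np.array([tree.feature_importances_ for tree in rf_demo.estimators_])

# Average over the trees, then rescale so the importances sum to 1
mdi_manual = per_tree.mean(axis=0)
mdi_manual = mdi_manual / mdi_manual.sum()

print(np.allclose(mdi_manual, rf_demo.feature_importances_))
```

Because MDI is computed from the training data only, it says nothing about how useful a feature is on held-out data.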
Permutation feature importance#
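The permutation feature importance is defined as the decrease in a model score when the values of a single feature are randomly shuffled. Shuffling breaks the relationship between the feature and the target, so the drop in score indicates how much the model depends on that feature.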
Note
Visit this notebook to learn more about permutation feature importance.
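The idea can be sketched by hand: shuffle one column of the test data, re-score the model, and record how much the score drops. The example below is a simplified illustration on synthetic data, not scikit-learn's implementation; all names (`X_te`, `rf_demo`, `manual_importance`, and so on) are made up for the sketch.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

# Synthetic stand-in data with 2 informative features out of 4
X_demo, y_demo = make_regression(
    n_samples=300, n_features=4, n_informative=2, noise=5.0, random_state=0
)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, random_state=0)
rf_demo = RandomForestRegressor(n_estimators=100, random_state=0).fit(X_tr, y_tr)

baseline = rf_demo.score(X_te, y_te)  # R^2 on the untouched test data
rng = np.random.default_rng(0)

manual_importance = []
for j in range(X_te.shape[1]):
    drops = []
    for _ in range(10):  # repeat the shuffle to average out randomness
        X_perm = X_te.copy()
        X_perm[:, j] = rng.permutation(X_perm[:, j])  # shuffle one feature only
        drops.append(baseline - rf_demo.score(X_perm, y_te))
    manual_importance.append(np.mean(drops))
manual_importance = np.array(manual_importance)

print(manual_importance)
```

In practice, `permutation_importance` handles the repeats, the scoring and the parallelism, and returns the individual repeats alongside their mean.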
result = permutation_importance(
reg, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)
perm_importances = pd.Series(result.importances_mean, index=feature_names)
# sort features according to importance
sorted_idx = np.argsort(perm_importances.to_numpy())
pos = np.arange(sorted_idx.shape[0])
# plot permutation importances (.iloc selects by position, not by label)
plt.barh(pos, perm_importances.iloc[sorted_idx], align="center")
plt.yticks(pos, np.array(feature_names)[sorted_idx])
plt.title("Permutation Importance (test set)")
plt.xlabel("Mean decrease in score (R²)");
The same data plotted as a boxplot:
plt.boxplot(
    result.importances[sorted_idx].T,
    vert=False,
    labels=np.array(feature_names)[sorted_idx],
)
plt.title("Permutation Importance (test set)")
Text(0.5, 1.0, 'Permutation Importance (test set)')
We observe that both methods identify the same features as most important (e.g., CAtBat, CRBI, CHits, Walks, Years), although the relative importances vary, especially for the feature Years.
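One way to make such a comparison less dependent on reading two plots is to compare the top-ranked features of each method directly. The sketch below uses hypothetical importance values and placeholder feature names (`a` to `e`); `top_k` is a helper defined only for this example, and none of the numbers come from the model above.

```python
import numpy as np

# Placeholder feature names and made-up importance values, for illustration only
names = np.array(["a", "b", "c", "d", "e"])
mdi_demo = np.array([0.40, 0.25, 0.20, 0.10, 0.05])
perm_demo = np.array([0.30, 0.05, 0.35, 0.02, 0.15])


def top_k(importances, names, k):
    """Return the names of the k features with the largest importances."""
    order = np.argsort(importances)[::-1][:k]
    return list(names[order])


top_mdi = top_k(mdi_demo, names, 3)
top_perm = top_k(perm_demo, names, 3)

# Features ranked in the top 3 by both methods
shared = sorted(set(top_mdi) & set(top_perm))

print(top_mdi)   # ['a', 'b', 'c']
print(top_perm)  # ['c', 'a', 'e']
print(shared)    # ['a', 'c']
```

With the fitted model, `reg.feature_importances_` and `result.importances_mean` together with `feature_names` could be passed to the same helper.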