Simple scikit learn model
Does money make people happier? This is a simple version that fits and evaluates the models on the full dataset, without splitting it into training and test data.
Data#
Import data#
import pandas as pd
# Load the data from GitHub
LINK = "https://raw.githubusercontent.com/kirenz/datasets/master/oecd_gdp.csv"
df = pd.read_csv(LINK)
Data structure#
df
  | Country | GDP per capita | Life satisfaction |
---|---|---|---|
0 | Russia | 9054.914 | 6.0 |
1 | Turkey | 9437.372 | 5.6 |
2 | Hungary | 12239.894 | 4.9 |
3 | Poland | 12495.334 | 5.8 |
4 | Slovak Republic | 15991.736 | 6.1 |
5 | Estonia | 17288.083 | 5.6 |
6 | Greece | 18064.288 | 4.8 |
7 | Portugal | 19121.592 | 5.1 |
8 | Slovenia | 20732.482 | 5.7 |
9 | Spain | 25864.721 | 6.5 |
10 | Korea | 27195.197 | 5.8 |
11 | Italy | 29866.581 | 6.0 |
12 | Japan | 32485.545 | 5.9 |
13 | Israel | 35343.336 | 7.4 |
14 | New Zealand | 37044.891 | 7.3 |
15 | France | 37675.006 | 6.5 |
16 | Belgium | 40106.632 | 6.9 |
17 | Germany | 40996.511 | 7.0 |
18 | Finland | 41973.988 | 7.4 |
19 | Canada | 43331.961 | 7.3 |
20 | Netherlands | 43603.115 | 7.3 |
21 | Austria | 43724.031 | 6.9 |
22 | United Kingdom | 43770.688 | 6.8 |
23 | Sweden | 49866.266 | 7.2 |
24 | Iceland | 50854.583 | 7.5 |
25 | Australia | 50961.865 | 7.3 |
26 | Ireland | 51350.744 | 7.0 |
27 | Denmark | 52114.165 | 7.5 |
28 | United States | 55805.204 | 7.2 |
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 29 entries, 0 to 28
Data columns (total 3 columns):
 #   Column             Non-Null Count  Dtype
---  ------             --------------  -----
 0   Country            29 non-null     object
 1   GDP per capita     29 non-null     float64
 2   Life satisfaction  29 non-null     float64
dtypes: float64(2), object(1)
memory usage: 824.0+ bytes
Data corrections#
# Change column names (lower case and spaces to underscore)
df.columns = df.columns.str.lower().str.replace(' ', '_')
# show the first 5 rows
df.head()
  | country | gdp_per_capita | life_satisfaction |
---|---|---|---|
0 | Russia | 9054.914 | 6.0 |
1 | Turkey | 9437.372 | 5.6 |
2 | Hungary | 12239.894 | 4.9 |
3 | Poland | 12495.334 | 5.8 |
4 | Slovak Republic | 15991.736 | 6.1 |
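The renaming step can be checked in isolation. The following minimal sketch applies the same string methods to a small hypothetical frame (`demo`, which is not part of the notebook) that carries the three original column names.

```python
import pandas as pd

# Hypothetical two-row frame with the original column names
# (values copied from the first two rows of the table above)
demo = pd.DataFrame({
    "Country": ["Russia", "Turkey"],
    "GDP per capita": [9054.914, 9437.372],
    "Life satisfaction": [6.0, 5.6],
})

# Lower-case every column name, then replace each space with an underscore
demo.columns = demo.columns.str.lower().str.replace(" ", "_")

print(list(demo.columns))
```

Names without spaces can be used in Altair shorthand strings such as `'gdp_per_capita:Q'` without further quoting.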
Variable lists#
Define the outcome variable and the feature for later use.
# define outcome variable as y_label
y_label = 'life_satisfaction'
# select features
X = df[["gdp_per_capita"]]
# create response
y = df[y_label]
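Note the double brackets in `df[["gdp_per_capita"]]`: scikit-learn estimators expect the features as a two-dimensional table, while the response can be one-dimensional. A small sketch with a hypothetical three-row frame (`demo`, `X_demo` and `y_demo` are illustrative names) shows the difference in shape.

```python
import pandas as pd

# Hypothetical three-row frame (values copied from the table above)
demo = pd.DataFrame({
    "gdp_per_capita": [9054.914, 9437.372, 12239.894],
    "life_satisfaction": [6.0, 5.6, 4.9],
})

X_demo = demo[["gdp_per_capita"]]   # list of column names -> DataFrame (2D)
y_demo = demo["life_satisfaction"]  # single column name -> Series (1D)

print(type(X_demo).__name__, X_demo.shape)  # DataFrame (3, 1)
print(type(y_demo).__name__, y_demo.shape)  # Series (3,)
```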
Data exploration#
%matplotlib inline
import altair as alt
# Visualize the data
alt.Chart(df).mark_circle(size=100).encode(
    x='gdp_per_capita:Q',
    y='life_satisfaction:Q',
    color=alt.Color('country', legend=None),
    tooltip=['country', 'gdp_per_capita', 'life_satisfaction']
).interactive()
Linear regression model#
from sklearn.linear_model import LinearRegression
# Select a linear regression model
reg = LinearRegression()
Training#
# Fit the model
reg.fit(X, y)
LinearRegression()
# Model intercept
reg.intercept_
4.853052800266436
# Model coefficient
reg.coef_
array([4.91154459e-05])
Prediction#
# Prediction for our data
y_pred = reg.predict(X)
# Make a prediction for a specific GDP value
X_new = pd.DataFrame({'gdp_per_capita': [50000]})
reg.predict(X_new)
array([7.30882509])
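For a model with a single feature, the prediction is the intercept plus the coefficient times GDP per capita. The sketch below redoes the calculation by hand, assuming the intercept and the (rounded) coefficient printed above; the variable names are illustrative.

```python
# Values copied from the outputs of reg.intercept_ and reg.coef_ above
intercept = 4.853052800266436
coef = 4.91154459e-05  # change in life satisfaction per unit of GDP per capita

gdp = 50000
prediction = intercept + coef * gdp
print(round(prediction, 4))  # 7.3088

# The same coefficient expressed per 10,000 units of GDP per capita
print(round(coef * 10_000, 3))  # 0.491
```

Read this way, the fitted line associates an additional 10,000 in GDP per capita with roughly half a point of life satisfaction in this sample.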
Evaluation#
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
# Mean squared error
mean_squared_error(y, y_pred)
0.18075033705835142
# Root mean squared error
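# (the `squared=False` argument was deprecated in scikit-learn 1.4 and later removed;
# the square root of the MSE works in every version)
mean_squared_error(y, y_pred) ** 0.5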
0.4251474297915388
# Mean absolute error
mean_absolute_error(y, y_pred)
0.35530429427921734
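The three metrics differ only in how they aggregate the residuals. A minimal NumPy sketch with hypothetical predictions (`y_hat`) for three observed values makes the definitions explicit; it illustrates the formulas and is not a copy of scikit-learn's implementation.

```python
import numpy as np

y_true = np.array([6.0, 5.6, 4.9])  # observed values (first three rows of the table)
y_hat = np.array([5.5, 5.6, 5.3])   # hypothetical predictions, for illustration only

residuals = y_true - y_hat

mse = np.mean(residuals ** 2)       # mean squared error
rmse = np.sqrt(mse)                 # root mean squared error, in the units of y
mae = np.mean(np.abs(residuals))    # mean absolute error

print(mse, rmse, mae)
```

Because the RMSE and MAE are on the same scale as life satisfaction, they are easier to interpret than the MSE.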
K-nearest neighbor model#
from sklearn.neighbors import KNeighborsRegressor
# Select a k-nearest neighbor regression model with k = 2
reg2 = KNeighborsRegressor(n_neighbors=2)
# Fit the model
reg2.fit(X, y)
KNeighborsRegressor(n_neighbors=2)
# Prediction for our data
y_pred2 = reg2.predict(X)
# Make a prediction for a specific GDP value
reg2.predict(X_new)
array([7.35])
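With `n_neighbors=2` and the default uniform weights, the prediction is the average life satisfaction of the two countries whose GDP per capita is closest to the query value. A hand-rolled sketch on four rows of the table above, which include the two countries closest to 50,000, reproduces the result; the array names are illustrative.

```python
import numpy as np

# Four rows from the table above: United Kingdom, Sweden, Iceland, Australia
gdp = np.array([43770.688, 49866.266, 50854.583, 50961.865])
life_sat = np.array([6.8, 7.2, 7.5, 7.3])

query = 50000
distances = np.abs(gdp - query)        # distance in GDP per capita
nearest = np.argsort(distances)[:2]    # indices of the two closest countries

prediction = life_sat[nearest].mean()  # uniform average of their responses
print(nearest, prediction)
```

Here the two nearest neighbors are Sweden (7.2) and Iceland (7.5), whose average is the 7.35 returned by `reg2.predict(X_new)`.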
# Mean squared error (on the training data, where each point is one of its own two neighbors)
mean_squared_error(y, y_pred2)
0.06181034482758619
# Mean absolute error
mean_absolute_error(y, y_pred2)
0.20517241379310344