{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Simple scikit-learn model\n", "\n", "Does money make people happier? This notebook predicts a country's life satisfaction from its GDP per capita with two scikit-learn regression models: linear regression and k-nearest neighbors.\n", "\n", "Simple version without data splitting: both models are fitted and evaluated on the same data, so all reported errors are in-sample errors." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import data\n", "\n", "OECD data with one row per country (GDP per capita and life satisfaction), loaded as CSV from the `kirenz/datasets` GitHub repository." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "# Load the data from GitHub (requires network access)\n", "LINK = \"https://raw.githubusercontent.com/kirenz/datasets/master/oecd_gdp.csv\"\n", "df = pd.read_csv(LINK)\n", "\n", "# Fail early if the remote file no longer has the columns used below\n", "assert {'Country', 'GDP per capita', 'Life satisfaction'} <= set(df.columns), 'unexpected columns in the downloaded CSV'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data structure" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
CountryGDP per capitaLife satisfaction
0Russia9054.9146.0
1Turkey9437.3725.6
2Hungary12239.8944.9
3Poland12495.3345.8
4Slovak Republic15991.7366.1
5Estonia17288.0835.6
6Greece18064.2884.8
7Portugal19121.5925.1
8Slovenia20732.4825.7
9Spain25864.7216.5
10Korea27195.1975.8
11Italy29866.5816.0
12Japan32485.5455.9
13Israel35343.3367.4
14New Zealand37044.8917.3
15France37675.0066.5
16Belgium40106.6326.9
17Germany40996.5117.0
18Finland41973.9887.4
19Canada43331.9617.3
20Netherlands43603.1157.3
21Austria43724.0316.9
22United Kingdom43770.6886.8
23Sweden49866.2667.2
24Iceland50854.5837.5
25Australia50961.8657.3
26Ireland51350.7447.0
27Denmark52114.1657.5
28United States55805.2047.2
\n", "
" ], "text/plain": [ " Country GDP per capita Life satisfaction\n", "0 Russia 9054.914 6.0\n", "1 Turkey 9437.372 5.6\n", "2 Hungary 12239.894 4.9\n", "3 Poland 12495.334 5.8\n", "4 Slovak Republic 15991.736 6.1\n", "5 Estonia 17288.083 5.6\n", "6 Greece 18064.288 4.8\n", "7 Portugal 19121.592 5.1\n", "8 Slovenia 20732.482 5.7\n", "9 Spain 25864.721 6.5\n", "10 Korea 27195.197 5.8\n", "11 Italy 29866.581 6.0\n", "12 Japan 32485.545 5.9\n", "13 Israel 35343.336 7.4\n", "14 New Zealand 37044.891 7.3\n", "15 France 37675.006 6.5\n", "16 Belgium 40106.632 6.9\n", "17 Germany 40996.511 7.0\n", "18 Finland 41973.988 7.4\n", "19 Canada 43331.961 7.3\n", "20 Netherlands 43603.115 7.3\n", "21 Austria 43724.031 6.9\n", "22 United Kingdom 43770.688 6.8\n", "23 Sweden 49866.266 7.2\n", "24 Iceland 50854.583 7.5\n", "25 Australia 50961.865 7.3\n", "26 Ireland 51350.744 7.0\n", "27 Denmark 52114.165 7.5\n", "28 United States 55805.204 7.2" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 29 entries, 0 to 28\n", "Data columns (total 3 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 Country 29 non-null object \n", " 1 GDP per capita 29 non-null float64\n", " 2 Life satisfaction 29 non-null float64\n", "dtypes: float64(2), object(1)\n", "memory usage: 824.0+ bytes\n" ] } ], "source": [ "df.info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Clean column names\n", "\n", "Convert the column names to lower case and replace spaces with underscores so they are easier to reference in code." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
countrygdp_per_capitalife_satisfaction
0Russia9054.9146.0
1Turkey9437.3725.6
2Hungary12239.8944.9
3Poland12495.3345.8
4Slovak Republic15991.7366.1
\n", "
" ], "text/plain": [ " country gdp_per_capita life_satisfaction\n", "0 Russia 9054.914 6.0\n", "1 Turkey 9437.372 5.6\n", "2 Hungary 12239.894 4.9\n", "3 Poland 12495.334 5.8\n", "4 Slovak Republic 15991.736 6.1" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Change column names: lower case, spaces replaced by underscores\n", "# (regex=False makes this a literal replacement; re-running the cell leaves the names unchanged)\n", "df.columns = df.columns.str.lower().str.replace(' ', '_', regex=False)\n", "\n", "# show the first 5 rows\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Variable lists\n", "\n", "Define the outcome `y` and the feature matrix `X` used by all models below." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# define outcome variable as y_label\n", "y_label = 'life_satisfaction'\n", "\n", "# define feature names (a list, so that X stays a two-dimensional DataFrame)\n", "features = [\"gdp_per_capita\"]\n", "\n", "# select features\n", "X = df[features]\n", "\n", "# create response\n", "y = df[y_label]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data exploration" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import altair as alt\n", "\n", "# Visualize the data: one point per country\n", "# (altair renders its own charts, so no matplotlib magic is needed)\n", "alt.Chart(df).mark_circle(size=100).encode(\n", "    x=alt.X('gdp_per_capita:Q', title='GDP per capita'),\n", "    y=alt.Y('life_satisfaction:Q', title='Life satisfaction'),\n", "    color=alt.Color('country', legend=None),\n", "    tooltip=['country', 'gdp_per_capita', 'life_satisfaction']\n", ").interactive()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Linear regression model\n", "\n", "Fit a straight line $\\hat{y} = \\beta_0 + \\beta_1 x$, where $x$ is GDP per capita and $\\hat{y}$ is the predicted life satisfaction." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LinearRegression" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# Select a linear regression model (ordinary least squares)\n", "reg = LinearRegression()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
LinearRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "LinearRegression()" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fit the model\n", "reg.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4.853052800266436" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model intercept: predicted life satisfaction at a GDP per capita of 0\n", "# (an extrapolation: the observed GDP per capita ranges from about 9,000 to 56,000)\n", "reg.intercept_" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([4.91154459e-05])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model coefficient: change in predicted life satisfaction per one-unit increase in GDP per capita\n", "reg.coef_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# Prediction for our data (in-sample: the same X the model was fitted on)\n", "y_pred = reg.predict(X)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([7.30882509])" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Make a prediction for a specific GDP value\n", "# (X_new uses the same column name as the training data X)\n", "X_new = pd.DataFrame({'gdp_per_capita': [50000]})\n", "\n", "reg.predict(X_new)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluation\n", "\n", "All metrics below are computed on the data the model was fitted on." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import mean_absolute_error, mean_squared_error" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.18075033705835142" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Mean squared error\n", "mean_squared_error(y, y_pred)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.4251474297915388" ] }, "execution_count": 34,
"metadata": {}, "output_type": "execute_result" } ], "source": [ "# Root mean squared error (square root of the MSE)\n", "# `squared=False` is deprecated since scikit-learn 1.4 and removed in 1.6, so take the root explicitly\n", "mean_squared_error(y, y_pred) ** 0.5" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.35530429427921734" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Mean absolute error\n", "mean_absolute_error(y, y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## K-Nearest Neighbor Model\n", "\n", "Predict life satisfaction as the average of the `n_neighbors` countries with the closest GDP per capita." ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "tags": [] }, "outputs": [], "source": [ "from sklearn.neighbors import KNeighborsRegressor" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "# Select a k-nearest neighbors model that averages the 2 closest countries\n", "reg2 = KNeighborsRegressor(n_neighbors=2)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
KNeighborsRegressor(n_neighbors=2)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "KNeighborsRegressor(n_neighbors=2)" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fit the model\n", "reg2.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "# Prediction for our data (in-sample: the same X the model was fitted on)\n", "y_pred2 = reg2.predict(X)" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([7.35])" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Prediction for the same GDP value used with the linear model\n", "reg2.predict(X_new)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.06181034482758619" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Mean squared error\n", "mean_squared_error(y, y_pred2)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.20517241379310344" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Mean absolute error\n", "mean_absolute_error(y, y_pred2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Caveat: in-sample evaluation\n", "\n", "Both models are evaluated on the same data they were fitted on. For k-nearest neighbors this is especially optimistic: each country is its own nearest neighbor, so its prediction is the average of its own observed value and that of one other country. The lower KNN errors therefore do not show that KNN generalizes better; use a train/test split or cross-validation to compare the models fairly." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.12 ('base')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "nav_menu": {}, "toc": { "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 6, "toc_cell": false, "toc_section_display": "block", "toc_window_display": true }, "toc_position": { "height": "616px", "left": "0px", "right": "20px", "top": "106px", "width": "213px" }, "vscode": { "interpreter": { "hash": "463226f144cc21b006ce6927bfc93dd00694e52c8bc6857abb6e555b983749e9" } } }, "nbformat": 4, "nbformat_minor": 1 }