{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Sales prediction\n", "\n", "Version with data splitting." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import altair as alt\n", "\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.model_selection import cross_val_score\n", "from sklearn.linear_model import LinearRegression\n", "from sklearn.metrics import r2_score\n", "from sklearn.metrics import mean_squared_error\n", "from sklearn.metrics import mean_absolute_error" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv('https://raw.githubusercontent.com/kirenz/datasets/master/advertising.csv')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data structure" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
MarketTVradionewspapersales
01230.137.869.222.1
1244.539.345.110.4
2317.245.969.39.3
34151.541.358.518.5
45180.810.858.412.9
..................
19519638.23.713.87.6
19619794.24.98.19.7
197198177.09.36.412.8
198199283.642.066.225.5
199200232.18.68.713.4
\n", "

200 rows × 5 columns

\n", "
" ], "text/plain": [ " Market TV radio newspaper sales\n", "0 1 230.1 37.8 69.2 22.1\n", "1 2 44.5 39.3 45.1 10.4\n", "2 3 17.2 45.9 69.3 9.3\n", "3 4 151.5 41.3 58.5 18.5\n", "4 5 180.8 10.8 58.4 12.9\n", ".. ... ... ... ... ...\n", "195 196 38.2 3.7 13.8 7.6\n", "196 197 94.2 4.9 8.1 9.7\n", "197 198 177.0 9.3 6.4 12.8\n", "198 199 283.6 42.0 66.2 25.5\n", "199 200 232.1 8.6 8.7 13.4\n", "\n", "[200 rows x 5 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 200 entries, 0 to 199\n", "Data columns (total 5 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 Market 200 non-null int64 \n", " 1 TV 200 non-null float64\n", " 2 radio 200 non-null float64\n", " 3 newspaper 200 non-null float64\n", " 4 sales 200 non-null float64\n", "dtypes: float64(4), int64(1)\n", "memory usage: 7.9 KB\n" ] } ], "source": [ "df.info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data corrections" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# variable Market is categorical\n", "df['Market'] = df['Market'].astype('category')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Variable lists" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# define outcome variable as y_label\n", "y_label = 'sales'\n", "\n", "# select features\n", "features = df.drop(columns=[y_label, 'Market']).columns\n", "\n", "# create feature data\n", "X = df[features]\n", "\n", "# create response\n", "y = df[y_label]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data splitting" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = 
train_test_split(X, y, \n", " test_size=0.2,\n", " random_state=42)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# data exploration set\n", "df_train = pd.DataFrame(X_train.copy())\n", "\n", "df_train = df_train.join(pd.DataFrame(y_train))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Analysis" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
countmeanstdmin25%50%75%max
TV160.0150.01937584.4188570.777.750150.65218.825296.4
radio160.022.87562514.8052160.09.82521.2036.42549.6
newspaper160.029.94562520.3364490.312.87525.6044.500100.9
sales160.014.1000005.1087541.610.47513.2017.32527.0
\n", "
" ], "text/plain": [ " count mean std min 25% 50% 75% max\n", "TV 160.0 150.019375 84.418857 0.7 77.750 150.65 218.825 296.4\n", "radio 160.0 22.875625 14.805216 0.0 9.825 21.20 36.425 49.6\n", "newspaper 160.0 29.945625 20.336449 0.3 12.875 25.60 44.500 100.9\n", "sales 160.0 14.100000 5.108754 1.6 10.475 13.20 17.325 27.0" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.describe().T" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.RepeatChart(...)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "alt.Chart(df_train).mark_bar().encode(\n", " alt.X(alt.repeat(\"column\"), type=\"quantitative\", bin=True),\n", " y='count()',\n", ").properties(\n", " width=150,\n", " height=150\n", ").repeat(\n", " column=['sales', 'TV', 'radio', 'newspaper']\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.RepeatChart(...)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "alt.Chart(df_train).mark_circle().encode(\n", " alt.X(alt.repeat(\"column\"), type='quantitative'),\n", " alt.Y(alt.repeat(\"row\"), type='quantitative')\n", ").properties(\n", " width=150,\n", " height=150\n", ").repeat(\n", " row=['sales', 'TV', 'radio', 'newspaper'],\n", " column=['sales', 'TV', 'radio', 'newspaper']\n", ").interactive()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "sales 1.000000\n", "TV 0.768874\n", "radio 0.592373\n", "newspaper 0.237874\n", "Name: sales, dtype: float64" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# inspect correlation between outcome and possible predictors\n", "corr = df_train.corr()\n", "corr[y_label].sort_values(ascending=False)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 TVradionewspapersales
TV1.0000000.0538720.0190840.768874
radio0.0538721.0000000.3880740.592373
newspaper0.0190840.3880741.0000000.237874
sales0.7688740.5923730.2378741.000000
\n" ], "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# take a look at all correlations\n", "corr.style.background_gradient(cmap='Blues')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Select model" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# select the linear regression model\n", "reg = LinearRegression()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training and validation" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# cross-validation with 5 folds\n", "scores = cross_val_score(reg, \n", " X_train, \n", " y_train, \n", " cv=5, \n", " scoring='neg_mean_squared_error') *-1" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 lr
14.192954
21.500644
32.109080
42.541355
54.372931
\n" ], "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# store cross-validation scores\n", "df_scores = pd.DataFrame({\"lr\": scores})\n", "df_scores" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# label rows by fold number (1-5 instead of 0-4); assigning the index instead of shifting it with += keeps this cell safe to re-run\n", "df_scores.index = range(1, len(df_scores) + 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# display the scores with a blue colour gradient\n", "df_scores.style.background_gradient(cmap='Blues')" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "alt.Chart(df_scores.reset_index()).mark_line(\n", " point=alt.OverlayMarkDef()\n", ").encode(\n", " x=alt.X(\"index\", bin=False, title=\"Fold\", axis=alt.Axis(tickCount=5)),\n", " y=alt.Y(\"lr\", aggregate=\"mean\", title=\"Mean squared error (MSE)\")\n", ")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
countmeanstdmin25%50%75%max
lr5.02.9433931.2790831.5006442.109082.5413554.1929544.372931
\n", "
" ], "text/plain": [ " count mean std min 25% 50% 75% max\n", "lr 5.0 2.943393 1.279083 1.500644 2.10908 2.541355 4.192954 4.372931" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_scores.describe().T" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fit model" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
LinearRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "LinearRegression()" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fit the model to the complete training data\n", "reg.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Coefficients" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
NameCoefficient
0Intercept2.979
1TV0.045
2radio0.189
3newspaper0.003
\n", "
" ], "text/plain": [ " Name Coefficient\n", "0 Intercept 2.979\n", "1 TV 0.045\n", "2 radio 0.189\n", "3 newspaper 0.003" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# intercept\n", "intercept = pd.DataFrame({\n", " \"Name\": [\"Intercept\"],\n", " \"Coefficient\":[reg.intercept_]}\n", " )\n", "\n", "# make a slope table\n", "slope = pd.DataFrame({\n", " \"Name\": features,\n", " \"Coefficient\": reg.coef_}\n", ")\n", "\n", "# combine estimates of intercept and slopes\n", "table = pd.concat([intercept, slope], ignore_index=True, sort=False)\n", "\n", "round(table, 3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluation on test set" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [], "source": [ "# obtain predictions\n", "y_pred = reg.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.899" ] }, "execution_count": 94, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# R squared\n", "round(r2_score(y_test, y_pred), 3)" ] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3.174" ] }, "execution_count": 95, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# MSE\n", "round(mean_squared_error(y_test, y_pred), 3)" ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.782" ] }, "execution_count": 96, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# RMSE = sqrt(MSE); computed explicitly because the `squared=False` argument was deprecated in scikit-learn 1.4 and removed in 1.6\n", "round(np.sqrt(mean_squared_error(y_test, y_pred)), 3)" ] }, { "cell_type": "code", "execution_count": 97, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.461" ] }, "execution_count": 97, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# MAE\n", "round(mean_absolute_error(y_test, y_pred), 3)" ] } ], "metadata": { "kernelspec": { "display_name": 
"Python 3.9.12 ('base')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "463226f144cc21b006ce6927bfc93dd00694e52c8bc6857abb6e555b983749e9" } } }, "nbformat": 4, "nbformat_minor": 2 }