{
"cells": [
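{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Sales prediction\n",
"\n",
"Linear regression model for `sales` using the `TV`, `radio` and `newspaper` columns of the advertising data set. In this version the data is split into a training set and a test set before exploration and modelling."
]
},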
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import altair as alt\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.model_selection import cross_val_score\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import r2_score\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.metrics import mean_absolute_error"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('https://raw.githubusercontent.com/kirenz/datasets/master/advertising.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data structure"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Market | \n",
" TV | \n",
" radio | \n",
" newspaper | \n",
" sales | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1 | \n",
" 230.1 | \n",
" 37.8 | \n",
" 69.2 | \n",
" 22.1 | \n",
"
\n",
" \n",
" 1 | \n",
" 2 | \n",
" 44.5 | \n",
" 39.3 | \n",
" 45.1 | \n",
" 10.4 | \n",
"
\n",
" \n",
" 2 | \n",
" 3 | \n",
" 17.2 | \n",
" 45.9 | \n",
" 69.3 | \n",
" 9.3 | \n",
"
\n",
" \n",
" 3 | \n",
" 4 | \n",
" 151.5 | \n",
" 41.3 | \n",
" 58.5 | \n",
" 18.5 | \n",
"
\n",
" \n",
" 4 | \n",
" 5 | \n",
" 180.8 | \n",
" 10.8 | \n",
" 58.4 | \n",
" 12.9 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 195 | \n",
" 196 | \n",
" 38.2 | \n",
" 3.7 | \n",
" 13.8 | \n",
" 7.6 | \n",
"
\n",
" \n",
" 196 | \n",
" 197 | \n",
" 94.2 | \n",
" 4.9 | \n",
" 8.1 | \n",
" 9.7 | \n",
"
\n",
" \n",
" 197 | \n",
" 198 | \n",
" 177.0 | \n",
" 9.3 | \n",
" 6.4 | \n",
" 12.8 | \n",
"
\n",
" \n",
" 198 | \n",
" 199 | \n",
" 283.6 | \n",
" 42.0 | \n",
" 66.2 | \n",
" 25.5 | \n",
"
\n",
" \n",
" 199 | \n",
" 200 | \n",
" 232.1 | \n",
" 8.6 | \n",
" 8.7 | \n",
" 13.4 | \n",
"
\n",
" \n",
"
\n",
"
200 rows × 5 columns
\n",
"
"
],
"text/plain": [
" Market TV radio newspaper sales\n",
"0 1 230.1 37.8 69.2 22.1\n",
"1 2 44.5 39.3 45.1 10.4\n",
"2 3 17.2 45.9 69.3 9.3\n",
"3 4 151.5 41.3 58.5 18.5\n",
"4 5 180.8 10.8 58.4 12.9\n",
".. ... ... ... ... ...\n",
"195 196 38.2 3.7 13.8 7.6\n",
"196 197 94.2 4.9 8.1 9.7\n",
"197 198 177.0 9.3 6.4 12.8\n",
"198 199 283.6 42.0 66.2 25.5\n",
"199 200 232.1 8.6 8.7 13.4\n",
"\n",
"[200 rows x 5 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"RangeIndex: 200 entries, 0 to 199\n",
"Data columns (total 5 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Market 200 non-null int64 \n",
" 1 TV 200 non-null float64\n",
" 2 radio 200 non-null float64\n",
" 3 newspaper 200 non-null float64\n",
" 4 sales 200 non-null float64\n",
"dtypes: float64(4), int64(1)\n",
"memory usage: 7.9 KB\n"
]
}
],
"source": [
"df.info()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data corrections"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# variable Market is categorical\n",
"df['Market'] = df['Market'].astype('category')\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Variable lists"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# name of the outcome column\n",
"y_label = 'sales'\n",
"\n",
"# predictors: every column except the outcome and the Market column\n",
"features = df.columns.drop([y_label, 'Market'])\n",
"\n",
"# response vector and feature matrix\n",
"y = df[y_label]\n",
"X = df[features]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data splitting"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# hold out 20% of the rows as a test set; the fixed seed keeps the split reproducible\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y, test_size=0.2, random_state=42\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# data exploration set: training features together with the training response\n",
"# (join returns a new frame, so X_train itself is left untouched)\n",
"df_train = X_train.join(y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Analysis"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" count | \n",
" mean | \n",
" std | \n",
" min | \n",
" 25% | \n",
" 50% | \n",
" 75% | \n",
" max | \n",
"
\n",
" \n",
" \n",
" \n",
" TV | \n",
" 160.0 | \n",
" 150.019375 | \n",
" 84.418857 | \n",
" 0.7 | \n",
" 77.750 | \n",
" 150.65 | \n",
" 218.825 | \n",
" 296.4 | \n",
"
\n",
" \n",
" radio | \n",
" 160.0 | \n",
" 22.875625 | \n",
" 14.805216 | \n",
" 0.0 | \n",
" 9.825 | \n",
" 21.20 | \n",
" 36.425 | \n",
" 49.6 | \n",
"
\n",
" \n",
" newspaper | \n",
" 160.0 | \n",
" 29.945625 | \n",
" 20.336449 | \n",
" 0.3 | \n",
" 12.875 | \n",
" 25.60 | \n",
" 44.500 | \n",
" 100.9 | \n",
"
\n",
" \n",
" sales | \n",
" 160.0 | \n",
" 14.100000 | \n",
" 5.108754 | \n",
" 1.6 | \n",
" 10.475 | \n",
" 13.20 | \n",
" 17.325 | \n",
" 27.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" count mean std min 25% 50% 75% max\n",
"TV 160.0 150.019375 84.418857 0.7 77.750 150.65 218.825 296.4\n",
"radio 160.0 22.875625 14.805216 0.0 9.825 21.20 36.425 49.6\n",
"newspaper 160.0 29.945625 20.336449 0.3 12.875 25.60 44.500 100.9\n",
"sales 160.0 14.100000 5.108754 1.6 10.475 13.20 17.325 27.0"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_train.describe().T"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
""
],
"text/plain": [
"alt.RepeatChart(...)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"alt.Chart(df_train).mark_bar().encode(\n",
" alt.X(alt.repeat(\"column\"), type=\"quantitative\", bin=True),\n",
" y='count()',\n",
").properties(\n",
" width=150,\n",
" height=150\n",
").repeat(\n",
" column=['sales', 'TV', 'radio', 'newspaper']\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
""
],
"text/plain": [
"alt.RepeatChart(...)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"alt.Chart(df_train).mark_circle().encode(\n",
" alt.X(alt.repeat(\"column\"), type='quantitative'),\n",
" alt.Y(alt.repeat(\"row\"), type='quantitative')\n",
").properties(\n",
" width=150,\n",
" height=150\n",
").repeat(\n",
" row=['sales', 'TV', 'radio', 'newspaper'],\n",
" column=['sales', 'TV', 'radio', 'newspaper']\n",
").interactive()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sales 1.000000\n",
"TV 0.768874\n",
"radio 0.592373\n",
"newspaper 0.237874\n",
"Name: sales, dtype: float64"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# inspect correlation between outcome and possible predictors\n",
"corr = df_train.corr()\n",
"corr[y_label].sort_values(ascending=False)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
" \n",
" \n",
" | \n",
" TV | \n",
" radio | \n",
" newspaper | \n",
" sales | \n",
"
\n",
" \n",
" \n",
" \n",
" TV | \n",
" 1.000000 | \n",
" 0.053872 | \n",
" 0.019084 | \n",
" 0.768874 | \n",
"
\n",
" \n",
" radio | \n",
" 0.053872 | \n",
" 1.000000 | \n",
" 0.388074 | \n",
" 0.592373 | \n",
"
\n",
" \n",
" newspaper | \n",
" 0.019084 | \n",
" 0.388074 | \n",
" 1.000000 | \n",
" 0.237874 | \n",
"
\n",
" \n",
" sales | \n",
" 0.768874 | \n",
" 0.592373 | \n",
" 0.237874 | \n",
" 1.000000 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
""
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# take a look at all correlations\n",
"corr.style.background_gradient(cmap='Blues')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Select model"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# select the linear regression model\n",
"reg = LinearRegression()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training and validation"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# 5-fold cross-validation on the training data; the scorer reports the\n",
"# negated MSE, so flip the sign to get the MSE of each fold\n",
"scores = -cross_val_score(\n",
"    reg, X_train, y_train, cv=5, scoring='neg_mean_squared_error'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
" \n",
" \n",
" | \n",
" lr | \n",
"
\n",
" \n",
" \n",
" \n",
" 1 | \n",
" 4.192954 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.500644 | \n",
"
\n",
" \n",
" 3 | \n",
" 2.109080 | \n",
"
\n",
" \n",
" 4 | \n",
" 2.541355 | \n",
"
\n",
" \n",
" 5 | \n",
" 4.372931 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# store cross-validation scores\n",
"df_scores = pd.DataFrame({\"lr\": scores})\n",
"df_scores"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# label the rows 1..k so the index matches the fold number;\n",
"# assigning the index (instead of `+= 1`) keeps this cell idempotent,\n",
"# so re-running it does not shift the fold labels again\n",
"df_scores.index = pd.RangeIndex(start=1, stop=len(df_scores) + 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print dataframe\n",
"df_scores.style.background_gradient(cmap='Blues')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
""
],
"text/plain": [
"alt.Chart(...)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"alt.Chart(df_scores.reset_index()).mark_line(\n",
" point=alt.OverlayMarkDef()\n",
").encode(\n",
" x=alt.X(\"index\", bin=False, title=\"Fold\", axis=alt.Axis(tickCount=5)),\n",
" y=alt.Y(\"lr\", aggregate=\"mean\", title=\"Mean squared error (MSE)\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" count | \n",
" mean | \n",
" std | \n",
" min | \n",
" 25% | \n",
" 50% | \n",
" 75% | \n",
" max | \n",
"
\n",
" \n",
" \n",
" \n",
" lr | \n",
" 5.0 | \n",
" 2.943393 | \n",
" 1.279083 | \n",
" 1.500644 | \n",
" 2.10908 | \n",
" 2.541355 | \n",
" 4.192954 | \n",
" 4.372931 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" count mean std min 25% 50% 75% max\n",
"lr 5.0 2.943393 1.279083 1.500644 2.10908 2.541355 4.192954 4.372931"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_scores.describe().T"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fit model"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"LinearRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"LinearRegression()"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fit the model to the complete training data\n",
"reg.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Coefficients"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Name | \n",
" Coefficient | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Intercept | \n",
" 2.979 | \n",
"
\n",
" \n",
" 1 | \n",
" TV | \n",
" 0.045 | \n",
"
\n",
" \n",
" 2 | \n",
" radio | \n",
" 0.189 | \n",
"
\n",
" \n",
" 3 | \n",
" newspaper | \n",
" 0.003 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Name Coefficient\n",
"0 Intercept 2.979\n",
"1 TV 0.045\n",
"2 radio 0.189\n",
"3 newspaper 0.003"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# intercept\n",
"intercept = pd.DataFrame({\n",
" \"Name\": [\"Intercept\"],\n",
" \"Coefficient\":[reg.intercept_]}\n",
" )\n",
"\n",
"# make a slope table\n",
"slope = pd.DataFrame({\n",
" \"Name\": features,\n",
" \"Coefficient\": reg.coef_}\n",
")\n",
"\n",
"# combine estimates of intercept and slopes\n",
"table = pd.concat([intercept, slope], ignore_index=True, sort=False)\n",
"\n",
"round(table, 3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluation on test set"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"# obtain predictions\n",
"y_pred = reg.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.899"
]
},
"execution_count": 94,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# R squared\n",
"r2_score(y_test, y_pred).round(3)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3.174"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# MSE\n",
"mean_squared_error(y_test, y_pred).round(3)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.782"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# RMSE\n",
"mean_squared_error(y_test, y_pred, squared=False).round(3)"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.461"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# MAE\n",
"mean_absolute_error(y_test, y_pred).round(3)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "463226f144cc21b006ce6927bfc93dd00694e52c8bc6857abb6e555b983749e9"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}