{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Sales prediction\n",
"\n",
"Version without data splitting: the model is fitted and evaluated on the same full dataset, so all reported metrics are in-sample."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import altair as alt\n",
"\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import r2_score\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.metrics import mean_absolute_error"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# location of the advertising dataset (CSV)\n",
"DATA_URL = 'https://raw.githubusercontent.com/kirenz/datasets/master/advertising.csv'\n",
"\n",
"df = pd.read_csv(DATA_URL)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data structure"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Market | \n",
" TV | \n",
" radio | \n",
" newspaper | \n",
" sales | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1 | \n",
" 230.1 | \n",
" 37.8 | \n",
" 69.2 | \n",
" 22.1 | \n",
"
\n",
" \n",
" 1 | \n",
" 2 | \n",
" 44.5 | \n",
" 39.3 | \n",
" 45.1 | \n",
" 10.4 | \n",
"
\n",
" \n",
" 2 | \n",
" 3 | \n",
" 17.2 | \n",
" 45.9 | \n",
" 69.3 | \n",
" 9.3 | \n",
"
\n",
" \n",
" 3 | \n",
" 4 | \n",
" 151.5 | \n",
" 41.3 | \n",
" 58.5 | \n",
" 18.5 | \n",
"
\n",
" \n",
" 4 | \n",
" 5 | \n",
" 180.8 | \n",
" 10.8 | \n",
" 58.4 | \n",
" 12.9 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 195 | \n",
" 196 | \n",
" 38.2 | \n",
" 3.7 | \n",
" 13.8 | \n",
" 7.6 | \n",
"
\n",
" \n",
" 196 | \n",
" 197 | \n",
" 94.2 | \n",
" 4.9 | \n",
" 8.1 | \n",
" 9.7 | \n",
"
\n",
" \n",
" 197 | \n",
" 198 | \n",
" 177.0 | \n",
" 9.3 | \n",
" 6.4 | \n",
" 12.8 | \n",
"
\n",
" \n",
" 198 | \n",
" 199 | \n",
" 283.6 | \n",
" 42.0 | \n",
" 66.2 | \n",
" 25.5 | \n",
"
\n",
" \n",
" 199 | \n",
" 200 | \n",
" 232.1 | \n",
" 8.6 | \n",
" 8.7 | \n",
" 13.4 | \n",
"
\n",
" \n",
"
\n",
"
200 rows × 5 columns
\n",
"
"
],
"text/plain": [
" Market TV radio newspaper sales\n",
"0 1 230.1 37.8 69.2 22.1\n",
"1 2 44.5 39.3 45.1 10.4\n",
"2 3 17.2 45.9 69.3 9.3\n",
"3 4 151.5 41.3 58.5 18.5\n",
"4 5 180.8 10.8 58.4 12.9\n",
".. ... ... ... ... ...\n",
"195 196 38.2 3.7 13.8 7.6\n",
"196 197 94.2 4.9 8.1 9.7\n",
"197 198 177.0 9.3 6.4 12.8\n",
"198 199 283.6 42.0 66.2 25.5\n",
"199 200 232.1 8.6 8.7 13.4\n",
"\n",
"[200 rows x 5 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"RangeIndex: 200 entries, 0 to 199\n",
"Data columns (total 5 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Market 200 non-null int64 \n",
" 1 TV 200 non-null float64\n",
" 2 radio 200 non-null float64\n",
" 3 newspaper 200 non-null float64\n",
" 4 sales 200 non-null float64\n",
"dtypes: float64(4), int64(1)\n",
"memory usage: 7.9 KB\n"
]
}
],
"source": [
"df.info()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data corrections"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Market is a categorical variable, not a quantity: store it as a category\n",
"df['Market'] = pd.Categorical(df['Market'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Variable lists"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# define outcome variable as y_label\n",
"y_label = 'sales'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# select features: every column except the outcome and `Market`\n",
"non_feature_columns = [y_label, 'Market']\n",
"features = df.columns.drop(non_feature_columns).tolist()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# create feature data\n",
"X = df[features]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# create response\n",
"y = df[y_label]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Analysis"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Descriptive statistics"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" count | \n",
" mean | \n",
" std | \n",
" min | \n",
" 25% | \n",
" 50% | \n",
" 75% | \n",
" max | \n",
"
\n",
" \n",
" \n",
" \n",
" TV | \n",
" 200.0 | \n",
" 147.0425 | \n",
" 85.854236 | \n",
" 0.7 | \n",
" 74.375 | \n",
" 149.75 | \n",
" 218.825 | \n",
" 296.4 | \n",
"
\n",
" \n",
" radio | \n",
" 200.0 | \n",
" 23.2640 | \n",
" 14.846809 | \n",
" 0.0 | \n",
" 9.975 | \n",
" 22.90 | \n",
" 36.525 | \n",
" 49.6 | \n",
"
\n",
" \n",
" newspaper | \n",
" 200.0 | \n",
" 30.5540 | \n",
" 21.778621 | \n",
" 0.3 | \n",
" 12.750 | \n",
" 25.75 | \n",
" 45.100 | \n",
" 114.0 | \n",
"
\n",
" \n",
" sales | \n",
" 200.0 | \n",
" 14.0225 | \n",
" 5.217457 | \n",
" 1.6 | \n",
" 10.375 | \n",
" 12.90 | \n",
" 17.400 | \n",
" 27.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" count mean std min 25% 50% 75% max\n",
"TV 200.0 147.0425 85.854236 0.7 74.375 149.75 218.825 296.4\n",
"radio 200.0 23.2640 14.846809 0.0 9.975 22.90 36.525 49.6\n",
"newspaper 200.0 30.5540 21.778621 0.3 12.750 25.75 45.100 114.0\n",
"sales 200.0 14.0225 5.217457 1.6 10.375 12.90 17.400 27.0"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.describe().T"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Exploratory data analysis"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
""
],
"text/plain": [
"alt.RepeatChart(...)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hist_columns = ['sales', 'TV', 'radio', 'newspaper']\n",
"\n",
"# one binned histogram per variable, repeated across columns\n",
"alt.Chart(df).mark_bar().encode(\n",
"    alt.X(alt.repeat(\"column\"), type=\"quantitative\", bin=True),\n",
"    y='count()',\n",
").properties(width=150, height=150).repeat(column=hist_columns)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
""
],
"text/plain": [
"alt.RepeatChart(...)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"matrix_columns = ['sales', 'TV', 'radio', 'newspaper']\n",
"\n",
"# scatter-plot matrix: each listed variable against each listed variable\n",
"alt.Chart(df).mark_circle().encode(\n",
"    alt.X(alt.repeat(\"column\"), type='quantitative'),\n",
"    alt.Y(alt.repeat(\"row\"), type='quantitative')\n",
").properties(width=150, height=150).repeat(\n",
"    row=matrix_columns,\n",
"    column=matrix_columns\n",
").interactive()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Correlations"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sales 1.000000\n",
"TV 0.782224\n",
"radio 0.576223\n",
"newspaper 0.228299\n",
"Name: sales, dtype: float64"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# inspect correlation between outcome and possible predictors\n",
"# select the numeric columns explicitly: `Market` is categorical, and\n",
"# pandas >= 2.0 no longer drops non-numeric columns from corr() silently\n",
"corr = df[features + [y_label]].corr()\n",
"corr[y_label].sort_values(ascending=False)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
" \n",
" \n",
" | \n",
" TV | \n",
" radio | \n",
" newspaper | \n",
" sales | \n",
"
\n",
" \n",
" \n",
" \n",
" TV | \n",
" 1.000000 | \n",
" 0.054809 | \n",
" 0.056648 | \n",
" 0.782224 | \n",
"
\n",
" \n",
" radio | \n",
" 0.054809 | \n",
" 1.000000 | \n",
" 0.354104 | \n",
" 0.576223 | \n",
"
\n",
" \n",
" newspaper | \n",
" 0.056648 | \n",
" 0.354104 | \n",
" 1.000000 | \n",
" 0.228299 | \n",
"
\n",
" \n",
" sales | \n",
" 0.782224 | \n",
" 0.576223 | \n",
" 0.228299 | \n",
" 1.000000 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
""
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# take a look at all correlations\n",
"corr.style.background_gradient(cmap='Blues')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Select model"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# select the linear regression model\n",
"reg = LinearRegression()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fit model"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"LinearRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"LinearRegression()"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fit the model to the data\n",
"reg.fit(X, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Coefficients"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Name | \n",
" Coefficient | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Intercept | \n",
" 2.939 | \n",
"
\n",
" \n",
" 1 | \n",
" TV | \n",
" 0.046 | \n",
"
\n",
" \n",
" 2 | \n",
" radio | \n",
" 0.189 | \n",
"
\n",
" \n",
" 3 | \n",
" newspaper | \n",
" -0.001 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Name Coefficient\n",
"0 Intercept 2.939\n",
"1 TV 0.046\n",
"2 radio 0.189\n",
"3 newspaper -0.001"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# one-row table holding the fitted intercept\n",
"intercept = pd.DataFrame(\n",
"    [[\"Intercept\", reg.intercept_]],\n",
"    columns=[\"Name\", \"Coefficient\"]\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# one row per feature, paired with its fitted coefficient\n",
"slope_columns = {\"Name\": features, \"Coefficient\": reg.coef_}\n",
"slope = pd.DataFrame(slope_columns)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# stack the intercept row on top of the slope rows\n",
"estimates = [intercept, slope]\n",
"table = pd.concat(estimates, ignore_index=True, sort=False)\n",
"\n",
"table.round(3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# obtain predictions\n",
"y_pred = reg.predict(X)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.897"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# R squared\n",
"# built-in round() works whether the metric comes back as a NumPy scalar\n",
"# or as a plain Python float (which has no .round method)\n",
"round(r2_score(y, y_pred), 3)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.784"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# MSE\n",
"# built-in round() works whether the metric comes back as a NumPy scalar\n",
"# or as a plain Python float (which has no .round method)\n",
"round(mean_squared_error(y, y_pred), 3)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.669"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# RMSE\n",
"# take the square root of the MSE directly: the `squared=False` argument is\n",
"# deprecated in scikit-learn 1.4 and removed in 1.6\n",
"round(np.sqrt(mean_squared_error(y, y_pred)), 3)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.252"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# MAE\n",
"# built-in round() works whether the metric comes back as a NumPy scalar\n",
"# or as a plain Python float (which has no .round method)\n",
"round(mean_absolute_error(y, y_pred), 3)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "463226f144cc21b006ce6927bfc93dd00694e52c8bc6857abb6e555b983749e9"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}