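Sales prediction
This version fits and evaluates the model on the full dataset without a train/test split, so all evaluation metrics below are in-sample.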
Setup
import numpy as np
import pandas as pd
import altair as alt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
Data
Import data
# location of the advertising data set (CSV)
data_url = 'https://raw.githubusercontent.com/kirenz/datasets/master/advertising.csv'
df = pd.read_csv(data_url)
Data structure
# show the data frame (the notebook renders a truncated preview of the rows)
df
| | Market | TV | radio | newspaper | sales |
---|---|---|---|---|---|
0 | 1 | 230.1 | 37.8 | 69.2 | 22.1 |
1 | 2 | 44.5 | 39.3 | 45.1 | 10.4 |
2 | 3 | 17.2 | 45.9 | 69.3 | 9.3 |
3 | 4 | 151.5 | 41.3 | 58.5 | 18.5 |
4 | 5 | 180.8 | 10.8 | 58.4 | 12.9 |
... | ... | ... | ... | ... | ... |
195 | 196 | 38.2 | 3.7 | 13.8 | 7.6 |
196 | 197 | 94.2 | 4.9 | 8.1 | 9.7 |
197 | 198 | 177.0 | 9.3 | 6.4 | 12.8 |
198 | 199 | 283.6 | 42.0 | 66.2 | 25.5 |
199 | 200 | 232.1 | 8.6 | 8.7 | 13.4 |
200 rows × 5 columns
# column names, non-null counts and dtypes
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 200 entries, 0 to 199
Data columns (total 5 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Market 200 non-null int64
1 TV 200 non-null float64
2 radio 200 non-null float64
3 newspaper 200 non-null float64
4 sales 200 non-null float64
dtypes: float64(4), int64(1)
memory usage: 7.9 KB
Data corrections
# variable Market is categorical; a column -> dtype mapping converts only that column
df = df.astype({'Market': 'category'})
Variable lists
# outcome (response) variable
y_label = 'sales'

# predictors: every column except the outcome and the Market column
excluded = [y_label, 'Market']
features = list(df.drop(columns=excluded).columns)

# response vector and feature matrix
y = df[y_label]
X = df[features]
Analysis
Descriptive statistics
# summary statistics per numeric column, transposed so each variable is a row
summary = df.describe()
summary.transpose()
| | count | mean | std | min | 25% | 50% | 75% | max |
---|---|---|---|---|---|---|---|---|
TV | 200.0 | 147.0425 | 85.854236 | 0.7 | 74.375 | 149.75 | 218.825 | 296.4 |
radio | 200.0 | 23.2640 | 14.846809 | 0.0 | 9.975 | 22.90 | 36.525 | 49.6 |
newspaper | 200.0 | 30.5540 | 21.778621 | 0.3 | 12.750 | 25.75 | 45.100 | 114.0 |
sales | 200.0 | 14.0225 | 5.217457 | 1.6 | 10.375 | 12.90 | 17.400 | 27.0 |
Exploratory data analysis
# variables to show as histograms, one panel per column
hist_vars = ['sales', 'TV', 'radio', 'newspaper']

# binned bar chart; the x field is filled in by the column repeat below
histogram = alt.Chart(df).mark_bar().encode(
    x=alt.X(alt.repeat("column"), type="quantitative", bin=True),
    y='count()',
).properties(width=150, height=150)

histogram.repeat(column=hist_vars)
# variables paired against each other in the scatter-plot matrix
scatter_vars = ['sales', 'TV', 'radio', 'newspaper']

# one scatter panel; x/y fields are filled in by the row/column repeat below
scatter = alt.Chart(df).mark_circle().encode(
    x=alt.X(alt.repeat("column"), type='quantitative'),
    y=alt.Y(alt.repeat("row"), type='quantitative'),
).properties(width=150, height=150)

scatter.repeat(row=scatter_vars, column=scatter_vars).interactive()
Correlations
# inspect correlation between outcome and possible predictors.
# numeric_only=True excludes the categorical 'Market' column: since pandas 2.0
# DataFrame.corr no longer drops non-numeric columns by default, so without it
# Market would be coerced into the matrix (or the call would raise).
corr = df.corr(numeric_only=True)
corr[y_label].sort_values(ascending=False)
sales 1.000000
TV 0.782224
radio 0.576223
newspaper 0.228299
Name: sales, dtype: float64
# take a look at all correlations: colour each cell by its value ('Blues' colormap)
styled_corr = corr.style.background_gradient(cmap='Blues')
styled_corr
| | TV | radio | newspaper | sales |
---|---|---|---|---|
TV | 1.000000 | 0.054809 | 0.056648 | 0.782224 |
radio | 0.054809 | 1.000000 | 0.354104 | 0.576223 |
newspaper | 0.056648 | 0.354104 | 1.000000 | 0.228299 |
sales | 0.782224 | 0.576223 | 0.228299 | 1.000000 |
Model
Select model
# select the linear regression model; fit_intercept=True is the default and is
# spelled out because the intercept is reported in the coefficient table
reg = LinearRegression(fit_intercept=True)
Fit model
# Fit the model to the data (all rows; no train/test split).
# fit() returns the estimator itself, which the notebook displays as cell output.
reg.fit(X, y)
LinearRegression()
Coefficients
# intercept as a one-row table with columns Name / Coefficient
intercept = pd.DataFrame([{"Name": "Intercept", "Coefficient": reg.intercept_}])
| | Name | Coefficient |
---|---|---|
0 | Intercept | 2.939 |
1 | TV | 0.046 |
2 | radio | 0.189 |
3 | newspaper | -0.001 |
# slope table: one row per feature with its fitted coefficient
slope = pd.DataFrame(dict(Name=features, Coefficient=reg.coef_))
# stack the intercept row above the slope rows and renumber the index from 0
table = pd.concat([intercept, slope], sort=False, ignore_index=True)
table.round(decimals=3)
Evaluation
# obtain predictions (in-sample: X is the same data the model was fitted on)
y_pred = reg.predict(X)
# R squared
# builtin round() works whether r2_score returns a NumPy scalar (older
# scikit-learn) or a plain Python float (newer releases), which has no .round()
round(float(r2_score(y, y_pred)), 3)
0.897
# MSE
# builtin round() works whether mean_squared_error returns a NumPy scalar
# (older scikit-learn) or a plain Python float (newer releases)
round(float(mean_squared_error(y, y_pred)), 3)
2.784
# RMSE
# computed as the square root of the MSE: the `squared=False` argument of
# mean_squared_error was deprecated in scikit-learn 1.4 and removed in 1.6
round(float(np.sqrt(mean_squared_error(y, y_pred))), 3)
1.669
# MAE
# builtin round() works whether mean_absolute_error returns a NumPy scalar
# (older scikit-learn) or a plain Python float (newer releases)
round(float(mean_absolute_error(y, y_pred)), 3)
1.252