Fitting a line#
Setup#
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
Data#
Import data#
df = pd.read_csv('https://raw.githubusercontent.com/kirenz/datasets/master/possum.csv')
Data structure#
df
 | site | pop | sex | age | head_l | skull_w | total_l | tail_l |
---|---|---|---|---|---|---|---|---|
0 | 1 | Vic | m | 8.0 | 94.1 | 60.4 | 89.0 | 36.0 |
1 | 1 | Vic | f | 6.0 | 92.5 | 57.6 | 91.5 | 36.5 |
2 | 1 | Vic | f | 6.0 | 94.0 | 60.0 | 95.5 | 39.0 |
3 | 1 | Vic | f | 6.0 | 93.2 | 57.1 | 92.0 | 38.0 |
4 | 1 | Vic | f | 2.0 | 91.5 | 56.3 | 85.5 | 36.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... |
99 | 7 | other | m | 1.0 | 89.5 | 56.0 | 81.5 | 36.5 |
100 | 7 | other | m | 1.0 | 88.6 | 54.7 | 82.5 | 39.0 |
101 | 7 | other | f | 6.0 | 92.4 | 55.0 | 89.0 | 38.0 |
102 | 7 | other | m | 4.0 | 91.5 | 55.2 | 82.5 | 36.5 |
103 | 7 | other | f | 3.0 | 93.6 | 59.9 | 89.0 | 40.0 |
104 rows × 8 columns
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 104 entries, 0 to 103
Data columns (total 8 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 site 104 non-null int64
1 pop 104 non-null object
2 sex 104 non-null object
3 age 102 non-null float64
4 head_l 104 non-null float64
5 skull_w 104 non-null float64
6 total_l 104 non-null float64
7 tail_l 104 non-null float64
dtypes: float64(5), int64(1), object(2)
memory usage: 6.6+ KB
Variable lists#
Prepare the data for the scikit-learn model, with total_l as the single feature and head_l as the response:
y_label = "head_l"
X = df[["total_l"]]
y = df[y_label]
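scikit-learn estimators expect the features as a two-dimensional, table-like object (samples by features), while the response can be one-dimensional. That is why X is selected with double brackets and y with single brackets. A minimal sketch of the difference, assuming a toy frame built from the first five rows shown above (the names toy, X_toy and y_toy are illustrative only):

```python
import pandas as pd

# Toy frame built from the first five rows of the table above
toy = pd.DataFrame({
    "total_l": [89.0, 91.5, 95.5, 92.0, 85.5],
    "head_l": [94.1, 92.5, 94.0, 93.2, 91.5],
})

X_toy = toy[["total_l"]]  # double brackets: DataFrame with one column
y_toy = toy["head_l"]     # single brackets: Series

print(X_toy.shape, y_toy.shape)  # (5, 1) (5,)
```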
Model#
Select model#
# Choose the linear regression model
reg = LinearRegression()
Fit model#
# Fit the model to the data
reg.fit(X, y)
LinearRegression()
Coefficients#
# Intercept
reg.intercept_
42.70979314896378
# Slope
reg.coef_
array([0.57290128])
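Taken together, the two coefficients describe the fitted line head_l ≈ 42.71 + 0.573 × total_l. As a rough illustration, the sketch below applies that equation by hand to the total length of the first possum in the table; the variable names are chosen here for illustration and are not part of the original notebook.

```python
# Coefficients copied from the output above
intercept = 42.70979314896378
slope = 0.57290128

# Total length of the first possum in the table
total_l = 89.0

# Prediction from the fitted line: head_l = intercept + slope * total_l
head_l_pred = intercept + slope * total_l
print(round(head_l_pred, 2))  # 93.7
```

The observed head length for that row is 94.1, so the fitted line underestimates it by roughly 0.4.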
Make predictions#
# Make predictions on the data
y_pred = reg.predict(X)
Evaluation#
Mean squared error#
mean_squared_error(y, y_pred)
6.6061634260446445
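The mean squared error is the average of the squared differences between observed and predicted values. The sketch below spells out that calculation with NumPy, assuming only the first five rows of the table and the coefficients reported above, so its result is for that small subset and will not match the value computed on all 104 rows.

```python
import numpy as np

# First five rows of the table above
total_l = np.array([89.0, 91.5, 95.5, 92.0, 85.5])
head_l = np.array([94.1, 92.5, 94.0, 93.2, 91.5])

# Predictions from the fitted line, using the reported coefficients
head_l_pred = 42.70979314896378 + 0.57290128 * total_l

# Mean squared error: average of the squared residuals
residuals = head_l - head_l_pred
mse_subset = np.mean(residuals ** 2)
print(mse_subset)
```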
Root mean squared error#
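# Square root of the MSE (the squared argument is deprecated since scikit-learn 1.4 and removed in later releases)
mean_squared_error(y, y_pred) ** 0.5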
2.570245790978879
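The root mean squared error is the square root of the mean squared error, which puts the error back on the same scale as head_l. A small check of that relationship, using only the two values reported above:

```python
import math

mse = 6.6061634260446445   # mean squared error reported above
rmse = 2.570245790978879   # root mean squared error reported above

# The RMSE should equal the square root of the MSE
print(math.isclose(math.sqrt(mse), rmse, rel_tol=1e-9))  # True
```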