Fitting a line#

Setup#

import pandas as pd

from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

Data#

Import data#

df = pd.read_csv('https://raw.githubusercontent.com/kirenz/datasets/master/possum.csv')

Data structure#

df
     site    pop  sex  age  head_l  skull_w  total_l  tail_l
0       1    Vic    m  8.0    94.1     60.4     89.0    36.0
1       1    Vic    f  6.0    92.5     57.6     91.5    36.5
2       1    Vic    f  6.0    94.0     60.0     95.5    39.0
3       1    Vic    f  6.0    93.2     57.1     92.0    38.0
4       1    Vic    f  2.0    91.5     56.3     85.5    36.0
...   ...    ...  ...  ...     ...      ...      ...     ...
99      7  other    m  1.0    89.5     56.0     81.5    36.5
100     7  other    m  1.0    88.6     54.7     82.5    39.0
101     7  other    f  6.0    92.4     55.0     89.0    38.0
102     7  other    m  4.0    91.5     55.2     82.5    36.5
103     7  other    f  3.0    93.6     59.9     89.0    40.0

104 rows × 8 columns

df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 104 entries, 0 to 103
Data columns (total 8 columns):
 #   Column   Non-Null Count  Dtype  
---  ------   --------------  -----  
 0   site     104 non-null    int64  
 1   pop      104 non-null    object 
 2   sex      104 non-null    object 
 3   age      102 non-null    float64
 4   head_l   104 non-null    float64
 5   skull_w  104 non-null    float64
 6   total_l  104 non-null    float64
 7   tail_l   104 non-null    float64
dtypes: float64(5), int64(1), object(2)
memory usage: 6.6+ KB

Variable lists#

Prepare the data for the scikit-learn model:

y_label = "head_l"

X = df[["total_l"]]
y = df[y_label]
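
The double brackets in `df[["total_l"]]` matter: scikit-learn estimators expect the features as a two-dimensional table (rows are samples, columns are features), while the target can be one-dimensional. The sketch below illustrates the difference on a tiny frame built from the first three rows shown above; the names `toy`, `X_toy` and `y_toy` are hypothetical and exist only for this illustration.

```python
import pandas as pd

# Tiny illustrative frame: first three rows of total_l and head_l from the table above
toy = pd.DataFrame({
    "total_l": [89.0, 91.5, 95.5],
    "head_l": [94.1, 92.5, 94.0],
})

X_toy = toy[["total_l"]]  # double brackets -> DataFrame with shape (n_samples, 1)
y_toy = toy["head_l"]     # single brackets -> Series with shape (n_samples,)

print(type(X_toy).__name__, X_toy.shape)
print(type(y_toy).__name__, y_toy.shape)
```

Keeping `X` as a DataFrame also means further feature columns can be added later by extending the list inside the inner brackets.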

Model#

Select model#

# Choose the linear regression model
reg = LinearRegression()

Fit model#

# Fit the model to the data
reg.fit(X, y)
LinearRegression()

Coefficients#

# Intercept
reg.intercept_
42.70979314896378
# Slope
reg.coef_
array([0.57290128])
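
With these two values, the fitted line is roughly head_l = 42.71 + 0.573 * total_l, so a possum with total_l = 89.0 gets a predicted head_l of about 93.7. The sketch below shows that `predict` applies exactly this formula. To stay self-contained it uses made-up data lying on a known line rather than the possum data, and the names `X_toy`, `y_toy`, `toy_reg` and `manual` are hypothetical.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Made-up data lying exactly on y = 40 + 0.5 * x (illustrative, not the possum data)
X_toy = np.array([[80.0], [85.0], [90.0], [95.0]])
y_toy = 40.0 + 0.5 * X_toy[:, 0]

toy_reg = LinearRegression().fit(X_toy, y_toy)

# Prediction by hand: intercept + slope * x
manual = toy_reg.intercept_ + toy_reg.coef_[0] * X_toy[:, 0]

# The manual formula agrees with predict() up to floating-point error
print(np.allclose(manual, toy_reg.predict(X_toy)))
```

Note that `coef_` is an array with one entry per feature, which is why the slope is read with `coef_[0]`.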

Make predictions#

# Make predictions on the data
y_pred = reg.predict(X)

Evaluation#

Mean squared error#

mean_squared_error(y, y_pred)
6.6061634260446445

Root mean squared error#

# Take the square root of the MSE (the `squared=False` argument was deprecated in scikit-learn 1.4)
mean_squared_error(y, y_pred) ** 0.5
2.570245790978879
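
Both metrics follow directly from the residuals: the mean squared error is the average of the squared differences between observed and predicted values, and the root mean squared error is its square root, which brings the error back to the units of the target. The following is a minimal sketch of that arithmetic with NumPy; the arrays `y_true` and `y_hat` hold made-up numbers, not the possum data.

```python
import numpy as np

# Made-up observed and predicted values (illustrative only)
y_true = np.array([3.0, 5.0, 7.0, 9.0])
y_hat = np.array([2.0, 5.0, 8.0, 11.0])

residuals = y_true - y_hat      # [1, 0, -1, -2]
mse = np.mean(residuals ** 2)   # (1 + 0 + 1 + 4) / 4 = 1.5
rmse = np.sqrt(mse)             # square root of the MSE

print(mse, rmse)
```

Applied to `y` and `y_pred` from the model above, the same two lines should reproduce the values returned by `mean_squared_error`.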