{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Random forest in scikit-learn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We illustrate random forest regression on a data set called \"Hitters\", which includes 20 variables and 322 observations of Major League Baseball players. The goal is to predict a baseball player’s salary on the basis of various features associated with performance in the previous year. This notebook does not cover exploratory data analysis.\n", "\n", "- Visit [this documentation](https://cran.r-project.org/web/packages/ISLR/ISLR.pdf) to learn more about the data." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.metrics import mean_squared_error\n", "from sklearn.inspection import permutation_importance" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Import" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(\"https://raw.githubusercontent.com/kirenz/datasets/master/Hitters.csv\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun \\\n", "0 293 66 1 30 29 14 1 293 66 1 \n", "1 315 81 7 24 38 39 14 3449 835 69 \n", "2 479 130 18 66 72 76 3 1624 457 63 \n", "3 496 141 20 65 78 37 11 5628 1575 225 \n", "4 321 87 10 39 42 30 2 396 101 12 \n", ".. ... ... ... ... ... ... ... ... ... ... \n", "317 497 127 7 65 48 37 5 2703 806 32 \n", "318 492 136 5 76 50 94 12 5511 1511 39 \n", "319 475 126 3 61 43 52 6 1700 433 7 \n", "320 573 144 9 85 60 78 8 3198 857 97 \n", "321 631 170 9 77 44 31 11 4908 1457 30 \n", "\n", " CRuns CRBI CWalks League Division PutOuts Assists Errors Salary \\\n", "0 30 29 14 A E 446 33 20 NaN \n", "1 321 414 375 N W 632 43 10 475.0 \n", "2 224 266 263 A W 880 82 14 480.0 \n", "3 828 838 354 N E 200 11 3 500.0 \n", "4 48 46 33 N E 805 40 4 91.5 \n", ".. ... ... ... ... ... ... ... ... ... \n", "317 379 311 138 N E 325 9 3 700.0 \n", "318 897 451 875 A E 313 381 20 875.0 \n", "319 217 93 146 A W 37 113 7 385.0 \n", "320 470 420 332 A E 1314 131 12 960.0 \n", "321 775 357 249 A W 408 4 3 1000.0 \n", "\n", " NewLeague \n", "0 A \n", "1 N \n", "2 A \n", "3 N \n", "4 N \n", ".. ... 
\n", "317 N \n", "318 A \n", "319 A \n", "320 A \n", "321 A \n", "\n", "[322 rows x 20 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 322 entries, 0 to 321\n", "Data columns (total 20 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 AtBat 322 non-null int64 \n", " 1 Hits 322 non-null int64 \n", " 2 HmRun 322 non-null int64 \n", " 3 Runs 322 non-null int64 \n", " 4 RBI 322 non-null int64 \n", " 5 Walks 322 non-null int64 \n", " 6 Years 322 non-null int64 \n", " 7 CAtBat 322 non-null int64 \n", " 8 CHits 322 non-null int64 \n", " 9 CHmRun 322 non-null int64 \n", " 10 CRuns 322 non-null int64 \n", " 11 CRBI 322 non-null int64 \n", " 12 CWalks 322 non-null int64 \n", " 13 League 322 non-null object \n", " 14 Division 322 non-null object \n", " 15 PutOuts 322 non-null int64 \n", " 16 Assists 322 non-null int64 \n", " 17 Errors 322 non-null int64 \n", " 18 Salary 263 non-null float64\n", " 19 NewLeague 322 non-null object \n", "dtypes: float64(1), int64(16), object(3)\n", "memory usage: 50.4+ KB\n" ] } ], "source": [ "df.info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Missing values\n", "\n", "Note that the salary is missing for some of the players:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AtBat 0\n", "Hits 0\n", "HmRun 0\n", "Runs 0\n", "RBI 0\n", "Walks 0\n", "Years 0\n", "CAtBat 0\n", "CHits 0\n", "CHmRun 0\n", "CRuns 0\n", "CRBI 0\n", "CWalks 0\n", "League 0\n", "Division 0\n", "PutOuts 0\n", "Assists 0\n", "Errors 0\n", "Salary 59\n", "NewLeague 0\n", "dtype: int64\n" ] } ], "source": [ "print(df.isnull().sum())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We simply drop the 
missing cases: " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# drop missing cases\n", "df = df.dropna()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Create label and features\n", "\n", "Since scikit-learn estimators require numeric input, we encode our categorical features as one-hot numeric features (dummy variables):" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "dummies = pd.get_dummies(df[['League', 'Division', 'NewLeague']])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Int64Index: 263 entries, 1 to 321\n", "Data columns (total 6 columns):\n", " # Column Non-Null Count Dtype\n", "--- ------ -------------- -----\n", " 0 League_A 263 non-null uint8\n", " 1 League_N 263 non-null uint8\n", " 2 Division_E 263 non-null uint8\n", " 3 Division_W 263 non-null uint8\n", " 4 NewLeague_A 263 non-null uint8\n", " 5 NewLeague_N 263 non-null uint8\n", "dtypes: uint8(6)\n", "memory usage: 3.6 KB\n" ] } ], "source": [ "dummies.info()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " League_A League_N Division_E Division_W NewLeague_A NewLeague_N\n", "1 0 1 0 1 0 1\n", "2 1 0 0 1 1 0\n", "3 0 1 1 0 0 1\n", "4 0 1 1 0 0 1\n", "5 1 0 0 1 1 0\n" ] } ], "source": [ "print(dummies.head())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we create the label `y`:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "y = df[['Salary']]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We drop the outcome variable (`Salary`) and the categorical columns for which we have already created dummy variables:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": 
[ "X_numerical = df.drop(['Salary', 'League', 'Division', 'NewLeague'], axis=1).astype('float64')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make a list of all numerical features (we need them later):" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['AtBat', 'Hits', 'HmRun', 'Runs', 'RBI', 'Walks', 'Years', 'CAtBat',\n", " 'CHits', 'CHmRun', 'CRuns', 'CRBI', 'CWalks', 'PutOuts', 'Assists',\n", " 'Errors'],\n", " dtype='object')" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list_numerical = X_numerical.columns\n", "list_numerical" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Int64Index: 263 entries, 1 to 321\n", "Data columns (total 19 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 AtBat 263 non-null float64\n", " 1 Hits 263 non-null float64\n", " 2 HmRun 263 non-null float64\n", " 3 Runs 263 non-null float64\n", " 4 RBI 263 non-null float64\n", " 5 Walks 263 non-null float64\n", " 6 Years 263 non-null float64\n", " 7 CAtBat 263 non-null float64\n", " 8 CHits 263 non-null float64\n", " 9 CHmRun 263 non-null float64\n", " 10 CRuns 263 non-null float64\n", " 11 CRBI 263 non-null float64\n", " 12 CWalks 263 non-null float64\n", " 13 PutOuts 263 non-null float64\n", " 14 Assists 263 non-null float64\n", " 15 Errors 263 non-null float64\n", " 16 League_N 263 non-null uint8 \n", " 17 Division_W 263 non-null uint8 \n", " 18 NewLeague_N 263 non-null uint8 \n", "dtypes: float64(16), uint8(3)\n", "memory usage: 35.7 KB\n" ] } ], "source": [ "# Create all features\n", "X = pd.concat([X_numerical, dummies[['League_N', 'Division_W', 'NewLeague_N']]], axis=1)\n", "X.info()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# Create a list of feature names\n", 
"feature_names = X.columns" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Split data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Randomly split the data set into a training set (70% of the observations) and a test set (the remaining 30%). `train_test_split` shuffles the rows before splitting, and `random_state=10` makes the split reproducible." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=10)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun \\\n", "260 496.0 119.0 8.0 57.0 33.0 21.0 7.0 3358.0 882.0 36.0 \n", "92 317.0 78.0 7.0 35.0 35.0 32.0 1.0 317.0 78.0 7.0 \n", "137 343.0 103.0 6.0 48.0 36.0 40.0 15.0 4338.0 1193.0 70.0 \n", "90 314.0 83.0 13.0 39.0 46.0 16.0 5.0 1457.0 405.0 28.0 \n", "100 495.0 151.0 17.0 61.0 84.0 78.0 10.0 5624.0 1679.0 275.0 \n", "\n", " CRuns CRBI CWalks PutOuts Assists Errors League_N Division_W \\\n", "260 365.0 280.0 165.0 155.0 371.0 29.0 1 1 \n", "92 35.0 35.0 32.0 45.0 122.0 26.0 0 0 \n", "137 581.0 421.0 325.0 211.0 56.0 13.0 0 0 \n", "90 156.0 159.0 76.0 533.0 40.0 4.0 0 1 \n", "100 884.0 1015.0 709.0 1045.0 88.0 13.0 0 0 \n", "\n", " NewLeague_N \n", "260 1 \n", "92 0 \n", "137 0 \n", "90 0 \n", "100 0 " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.head()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Salary\n", "260 875.0\n", "92 70.0\n", "137 430.0\n", "90 431.5\n", "100 2460.0\n", ".. ...\n", "274 200.0\n", "196 587.5\n", "159 200.0\n", "17 175.0\n", "162 75.0\n", "\n", "[184 rows x 1 columns]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data standardization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Some models (such as Lasso, Ridge or GAMs) perform best when all numerical features are centered around 0 and have variances of a similar order of magnitude.\n", "- To avoid [data leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)), standardize the numerical features only after splitting the data, using statistics estimated from the training data alone.\n", "- That is, we compute the mean and standard deviation of each numerical feature on the training data and apply these same statistics to the test data. We don't standardize the dummy variables (which only take the values 0 or 1)."
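, "\n", "Unlike Lasso, Ridge or GAMs, random forests do not require standardized features: each split only depends on the ordering of a feature's values, and standardization preserves that ordering. The following minimal sketch illustrates this with synthetic data from `make_regression` (an assumption; it does not use the Hitters data): two forests with the same hyperparameters and `random_state`, one fitted on raw and one on standardized features, should make the same predictions, up to possible floating-point edge cases.\n", "\n", "```python\n", "import numpy as np\n", "from sklearn.datasets import make_regression\n", "from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "# synthetic data (assumption: a stand-in for the Hitters features)\n", "X_demo, y_demo = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=0)\n", "X_demo_scaled = StandardScaler().fit_transform(X_demo)\n", "\n", "# identical hyperparameters and seed, different feature scaling\n", "rf_raw = RandomForestRegressor(n_estimators=100, random_state=0).fit(X_demo, y_demo)\n", "rf_scaled = RandomForestRegressor(n_estimators=100, random_state=0).fit(X_demo_scaled, y_demo)\n", "\n", "# compare the in-sample predictions of both forests\n", "print(np.allclose(rf_raw.predict(X_demo), rf_scaled.predict(X_demo_scaled)))\n", "```"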
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler\n", "\n", "# work on explicit copies so the column assignments below modify independent DataFrames\n", "X_train = X_train.copy()\n", "X_test = X_test.copy()\n", "\n", "# estimate mean and standard deviation from the training data only\n", "scaler = StandardScaler().fit(X_train[list_numerical])\n", "\n", "X_train[list_numerical] = scaler.transform(X_train[list_numerical])\n", "X_test[list_numerical] = scaler.transform(X_test[list_numerical])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Flatten the labels into 1-D arrays, since scikit-learn regressors expect `y` with shape `(n_samples,)`:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "y_train = np.ravel(y_train)\n", "y_test = np.ravel(y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Define hyperparameters:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "params = {\n", " \"n_estimators\": 500,\n", " \"max_depth\": 4,\n", " \"min_samples_split\": 5,\n", " \"warm_start\": True,\n", " \"oob_score\": True,\n", " \"random_state\": 42,\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Build and fit the model:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RandomForestRegressor(max_depth=4, min_samples_split=5, n_estimators=500,\n",
       "                      oob_score=True, random_state=42, warm_start=True)
" ], "text/plain": [ "RandomForestRegressor(max_depth=4, min_samples_split=5, n_estimators=500,\n", " oob_score=True, random_state=42, warm_start=True)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reg = RandomForestRegressor(**params)\n", "\n", "reg.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Make predictions on the test set:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "y_pred = reg.predict(X_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Evaluate the model with the root mean squared error (RMSE) on the test set:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "296.37036964432764" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# RMSE as the square root of the MSE (the squared=False option of mean_squared_error is deprecated since scikit-learn 1.4)\n", "np.sqrt(mean_squared_error(y_test, y_pred))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature importance\n", "\n", "- Next, we look at two measures of feature importance: the tree-based mean decrease in impurity (MDI) and the permutation importance.\n", "\n", "### Mean decrease in impurity (MDI)\n", "\n", "- Mean decrease in impurity (MDI) is a feature importance measure for tree-based models: for each feature, it sums the (weighted) impurity reduction of all splits on that feature and averages this over the trees of the forest. Note that MDI is computed from training-set statistics and tends to favor features with many unique values.\n", "\n", "```{Note}\n", "Visit [this notebook](https://kirenz.github.io/feature-engineering/docs/mdi.html#) to learn more about MDI.\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- In scikit-learn, the MDI values are provided by the fitted attribute `feature_importances_` (normalized to sum to 1):\n" ] }, { 
"cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# obtain feature importance\n", "feature_importance = reg.feature_importances_\n", "\n", "# sort features according to importance\n", "sorted_idx = np.argsort(feature_importance)\n", "pos = np.arange(sorted_idx.shape[0])\n", "\n", "# plot feature importances\n", "plt.barh(pos, feature_importance[sorted_idx], align=\"center\")\n", "\n", "plt.yticks(pos, np.array(feature_names)[sorted_idx])\n", "plt.title(\"Feature Importance (MDI)\")\n", "plt.xlabel(\"Mean decrease in impurity\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Permutation feature importance\n", "\n", "The permutation feature importance is defined as the decrease in a model score when the values of a single feature (one column of the data) are randomly shuffled. Shuffling breaks the relationship between that feature and the target, so a large drop in the score indicates that the model relies on the feature. Here, we compute it on the test set with 10 shuffles per feature; since no `scoring` argument is given, `permutation_importance` uses the regressor's default score, R².\n", "\n", "```{Note}\n", "Visit [this notebook](https://kirenz.github.io/feature-engineering/docs/permutation-feature-importance.html) to learn more about permutation feature importance.\n", "```" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "result = permutation_importance(\n", " reg, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2\n", ")\n", "\n", "# mean importance over the 10 repeats, indexed by feature name\n", "tree_importances = pd.Series(result.importances_mean, index=feature_names)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# sort features according to mean permutation importance\n", "sorted_idx = np.argsort(tree_importances.to_numpy())\n", "pos = np.arange(sorted_idx.shape[0])\n", "\n", "# plot feature importances\n", "plt.barh(pos, tree_importances.to_numpy()[sorted_idx], align=\"center\")\n", "\n", "plt.yticks(pos, np.array(feature_names)[sorted_idx])\n", "plt.title(\"Permutation Importance (test set)\")\n", "plt.xlabel(\"Mean decrease in R² score\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The same permutation importances shown as a boxplot, which also displays the spread over the 10 repeats:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'Permutation Importance (test set)')" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.boxplot(\n", " result.importances[sorted_idx].T,\n", " vert=False,\n", " labels=np.array(feature_names)[sorted_idx],\n", ")\n", "\n", "plt.title(\"Permutation Importance (test set)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Both methods identify the same features as the most important ones (e.g., `CAtBat`, `CRBI`, `CHits`, `Walks`, `Years`), although their relative importances differ (especially for `Years`)." ] } ], "metadata": { "kernelspec": { "display_name": "ds", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "0de8387c967863cc622aba8b7ea5b466d4dfde089153d484429677aa77034389" } } }, "nbformat": 4, "nbformat_minor": 2 }