{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Random forest in scikit-learn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We illustrate random forest regression on a data set called \"Hitters\", which includes 20 variables and 322 observations of major league baseball players. The goal is to predict a baseball player’s salary on the basis of various features associated with performance in the previous year. We don't cover exploratory data analysis in this notebook.\n", "\n", "- See the [ISLR package documentation](https://cran.r-project.org/web/packages/ISLR/ISLR.pdf) for a description of the variables" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "from collections import OrderedDict\n", "\n", "from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.metrics import mean_squared_error\n", "from sklearn.inspection import permutation_importance" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Import" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(\"https://raw.githubusercontent.com/kirenz/datasets/master/Hitters.csv\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| | AtBat | Hits | HmRun | Runs | RBI | Walks | Years | CAtBat | CHits | CHmRun | CRuns | CRBI | CWalks | League | Division | PutOuts | Assists | Errors | Salary | NewLeague |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 293 | 66 | 1 | 30 | 29 | 14 | 1 | 293 | 66 | 1 | 30 | 29 | 14 | A | E | 446 | 33 | 20 | NaN | A |
| 1 | 315 | 81 | 7 | 24 | 38 | 39 | 14 | 3449 | 835 | 69 | 321 | 414 | 375 | N | W | 632 | 43 | 10 | 475.0 | N |
| 2 | 479 | 130 | 18 | 66 | 72 | 76 | 3 | 1624 | 457 | 63 | 224 | 266 | 263 | A | W | 880 | 82 | 14 | 480.0 | A |
| 3 | 496 | 141 | 20 | 65 | 78 | 37 | 11 | 5628 | 1575 | 225 | 828 | 838 | 354 | N | E | 200 | 11 | 3 | 500.0 | N |
| 4 | 321 | 87 | 10 | 39 | 42 | 30 | 2 | 396 | 101 | 12 | 48 | 46 | 33 | N | E | 805 | 40 | 4 | 91.5 | N |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 317 | 497 | 127 | 7 | 65 | 48 | 37 | 5 | 2703 | 806 | 32 | 379 | 311 | 138 | N | E | 325 | 9 | 3 | 700.0 | N |
| 318 | 492 | 136 | 5 | 76 | 50 | 94 | 12 | 5511 | 1511 | 39 | 897 | 451 | 875 | A | E | 313 | 381 | 20 | 875.0 | A |
| 319 | 475 | 126 | 3 | 61 | 43 | 52 | 6 | 1700 | 433 | 7 | 217 | 93 | 146 | A | W | 37 | 113 | 7 | 385.0 | A |
| 320 | 573 | 144 | 9 | 85 | 60 | 78 | 8 | 3198 | 857 | 97 | 470 | 420 | 332 | A | E | 1314 | 131 | 12 | 960.0 | A |
| 321 | 631 | 170 | 9 | 77 | 44 | 31 | 11 | 4908 | 1457 | 30 | 775 | 357 | 249 | A | W | 408 | 4 | 3 | 1000.0 | A |

322 rows × 20 columns
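The data-preparation cells are not part of this section, but the next two tables imply three steps: rows with a missing Salary (such as row 0 above) are dropped, League, Division and NewLeague become the 0/1 columns League_N, Division_W and NewLeague_N, and the rows are split into training and test sets. Below is a minimal sketch of those steps, assuming `pd.get_dummies(..., drop_first=True)` and sklearn's `train_test_split` with `test_size=0.3`; the function name `prepare` and `random_state=1` are hypothetical choices, not the author's code. To stay self-contained, it runs on the five rows displayed above instead of downloading the CSV.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# The first five rows of the table above (illustration only, not the full data set)
cols = ["AtBat", "Hits", "HmRun", "Runs", "RBI", "Walks", "Years", "CAtBat", "CHits",
        "CHmRun", "CRuns", "CRBI", "CWalks", "League", "Division", "PutOuts",
        "Assists", "Errors", "Salary", "NewLeague"]
rows = [
    [293, 66, 1, 30, 29, 14, 1, 293, 66, 1, 30, 29, 14, "A", "E", 446, 33, 20, np.nan, "A"],
    [315, 81, 7, 24, 38, 39, 14, 3449, 835, 69, 321, 414, 375, "N", "W", 632, 43, 10, 475.0, "N"],
    [479, 130, 18, 66, 72, 76, 3, 1624, 457, 63, 224, 266, 263, "A", "W", 880, 82, 14, 480.0, "A"],
    [496, 141, 20, 65, 78, 37, 11, 5628, 1575, 225, 828, 838, 354, "N", "E", 200, 11, 3, 500.0, "N"],
    [321, 87, 10, 39, 42, 30, 2, 396, 101, 12, 48, 46, 33, "N", "E", 805, 40, 4, 91.5, "N"],
]
df = pd.DataFrame(rows, columns=cols)


def prepare(df, test_size=0.3, random_state=1):
    """Hypothetical preprocessing: drop missing salaries, dummy-encode, split."""
    df = df.dropna(subset=["Salary"])      # the target must be known
    y = df[["Salary"]]                     # one-column DataFrame, like the target table
    # drop_first=True keeps one indicator per factor: League_N, Division_W, NewLeague_N
    X = pd.get_dummies(df.drop(columns=["Salary"]), drop_first=True, dtype=int)
    return train_test_split(X, y, test_size=test_size, random_state=random_state)


X_train, X_test, y_train, y_test = prepare(df)
print(list(X_train.columns))
```

On the full `df` from the import cell, Hitters has 59 missing salaries, so 263 rows remain and `test_size=0.3` leaves 184 training rows, the same count as the target table. The sketch keeps integer dtypes, whereas the feature table shows floats, so the original preparation evidently differs in some details.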

| | AtBat | Hits | HmRun | Runs | RBI | Walks | Years | CAtBat | CHits | CHmRun | CRuns | CRBI | CWalks | PutOuts | Assists | Errors | League_N | Division_W | NewLeague_N |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 260 | 496.0 | 119.0 | 8.0 | 57.0 | 33.0 | 21.0 | 7.0 | 3358.0 | 882.0 | 36.0 | 365.0 | 280.0 | 165.0 | 155.0 | 371.0 | 29.0 | 1 | 1 | 1 |
| 92 | 317.0 | 78.0 | 7.0 | 35.0 | 35.0 | 32.0 | 1.0 | 317.0 | 78.0 | 7.0 | 35.0 | 35.0 | 32.0 | 45.0 | 122.0 | 26.0 | 0 | 0 | 0 |
| 137 | 343.0 | 103.0 | 6.0 | 48.0 | 36.0 | 40.0 | 15.0 | 4338.0 | 1193.0 | 70.0 | 581.0 | 421.0 | 325.0 | 211.0 | 56.0 | 13.0 | 0 | 0 | 0 |
| 90 | 314.0 | 83.0 | 13.0 | 39.0 | 46.0 | 16.0 | 5.0 | 1457.0 | 405.0 | 28.0 | 156.0 | 159.0 | 76.0 | 533.0 | 40.0 | 4.0 | 0 | 1 | 0 |
| 100 | 495.0 | 151.0 | 17.0 | 61.0 | 84.0 | 78.0 | 10.0 | 5624.0 | 1679.0 | 275.0 | 884.0 | 1015.0 | 709.0 | 1045.0 | 88.0 | 13.0 | 0 | 0 | 0 |

| | Salary |
|---|---|
| 260 | 875.0 |
| 92 | 70.0 |
| 137 | 430.0 |
| 90 | 431.5 |
| 100 | 2460.0 |
| ... | ... |
| 274 | 200.0 |
| 196 | 587.5 |
| 159 | 200.0 |
| 17 | 175.0 |
| 162 | 75.0 |

184 rows × 1 columns
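The fitted estimator's repr lists its constructor arguments, so the model can be rebuilt from them. Below is a minimal sketch, assuming a synthetic stand-in from `make_regression` with the same shape as the training features (184 rows, 19 columns) rather than the Hitters data. Because `oob_score=True`, the fitted forest exposes `oob_score_`: the R² of predicting each training row using only the trees whose bootstrap sample left that row out.

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in with the shape of the training features (184 x 19);
# this is NOT the Hitters data.
X_train, y_train = make_regression(n_samples=184, n_features=19, noise=10.0, random_state=0)

reg = RandomForestRegressor(
    n_estimators=500,       # number of trees in the forest
    max_depth=4,            # cap tree depth to limit overfitting
    min_samples_split=5,    # a node needs at least 5 samples to be split
    oob_score=True,         # evaluate on out-of-bag samples after fitting
    random_state=42,        # reproducible bootstrap samples and feature draws
    warm_start=True,        # later fit() calls add trees instead of starting over
)
reg.fit(X_train, y_train)

# Out-of-bag R^2: a validation estimate that needs no separate hold-out set
print(f"OOB R^2: {reg.oob_score_:.3f}")
```

With the notebook's one-column target table, pass `y_train.values.ravel()` (or `y_train["Salary"]`) to `fit`; a column-vector target makes sklearn emit a `DataConversionWarning`.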

RandomForestRegressor(max_depth=4, min_samples_split=5, n_estimators=500, oob_score=True, random_state=42, warm_start=True)
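Setting `warm_start=True` together with `oob_score=True` suggests the common pattern of growing the forest step by step and recording the out-of-bag error as trees are added; the setup cell's `OrderedDict` import fits that pattern. Below is a sketch under that assumption, again on synthetic stand-in data. The step of 50 trees and the name `oob_error` are arbitrary choices, not taken from the notebook.

```python
from collections import OrderedDict

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in data (not Hitters)
X_train, y_train = make_regression(n_samples=184, n_features=19, noise=10.0, random_state=0)

reg = RandomForestRegressor(
    max_depth=4, min_samples_split=5, oob_score=True,
    random_state=42, warm_start=True,
)

# n_estimators -> OOB error, defined here as 1 - OOB R^2
oob_error = OrderedDict()
for n in range(50, 501, 50):
    reg.set_params(n_estimators=n)
    reg.fit(X_train, y_train)        # warm_start: only the new trees are grown
    oob_error[n] = 1 - reg.oob_score_

for n, err in oob_error.items():
    print(f"{n:4d} trees: OOB error {err:.3f}")
```

`plt.plot(list(oob_error.keys()), list(oob_error.values()))` shows where adding trees stops paying off. On Python 3.7+ a plain `dict` also preserves insertion order, so `OrderedDict` mainly documents intent.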
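The setup cell also imports `mean_squared_error` and `permutation_importance`, which points to a test-set evaluation step. Below is a sketch of such a step, assuming a synthetic 263-row stand-in split 70/30, which gives 184 training and 79 test rows. Permutation importance uses the estimator's own `score` method by default (R² for a regressor), so each value is the mean drop in test R² when one column is shuffled.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

# Synthetic stand-in: 263 rows, 19 features (not Hitters)
X, y = make_regression(n_samples=263, n_features=19, noise=10.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

reg = RandomForestRegressor(n_estimators=500, max_depth=4, min_samples_split=5,
                            oob_score=True, random_state=42)
reg.fit(X_train, y_train)

# Root mean squared error on the held-out rows
rmse = np.sqrt(mean_squared_error(y_test, reg.predict(X_test)))
print(f"Test RMSE: {rmse:.2f}")

# Mean drop in test R^2 over 5 shuffles of each feature
result = permutation_importance(reg, X_test, y_test, n_repeats=5, random_state=42)
order = result.importances_mean.argsort()[::-1]
for i in order[:5]:
    print(f"feature {i}: {result.importances_mean[i]:.3f} "
          f"+/- {result.importances_std[i]:.3f}")
```

With the real data frames, `X_test.columns[i]` gives the feature name for index `i`. Computing importances on the test set rather than the training set avoids rewarding features the forest merely memorized.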