{ "cells": [ { "cell_type": "code", "execution_count": 93, "metadata": { "tags": [ "remove_cell" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "\n", "pd.set_option('display.max_rows', 7)\n", "\n", "data = pd.read_csv('data/toy_regression.csv')\n", "\n", "def plot3Dscatter(data, xcol, ycol, mdl, actual):\n", " import matplotlib.pyplot as plt\n", " from mpl_toolkits.mplot3d import Axes3D # registers the 3d projection on older matplotlib\n", "\n", " # plot the plane of best fit\n", " # not as useful as one would like!\n", "\n", " fig = plt.figure(figsize=(12, 8))\n", " ax = fig.add_subplot(111, projection='3d')\n", "\n", " # generate a mesh that covers the range of the data\n", " x_surf = np.arange(data[xcol].min() - 1, data[xcol].max() + 1)\n", " y_surf = np.arange(data[ycol].min() - 1, data[ycol].max() + 1)\n", " x_surf, y_surf = np.meshgrid(x_surf, y_surf)\n", "\n", " # predict the target at every mesh point\n", " zgrid = pd.DataFrame(\n", " {xcol: x_surf.ravel(), ycol: y_surf.ravel()}\n", " )\n", "\n", " out = mdl.predict(zgrid)\n", "\n", " ax.plot_surface(x_surf, y_surf,\n", " out.reshape(x_surf.shape),\n", " rstride=1,\n", " cstride=1,\n", " color='None',\n", " alpha=0.4)\n", "\n", " # overlay the observed data points\n", " ax.scatter(data[xcol], data[ycol], actual,\n", " c='blue',\n", " marker='o',\n", " alpha=0.7)\n", "\n", " ax.set_xlabel(xcol)\n", " ax.set_ylabel(ycol)\n", " ax.set_zlabel('y')\n", "\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Creating Modeling Pipelines\n", "---\n", "\n", "Building a statistical model involves the following steps:\n", "1. Create features that best reflect the meaning behind the data,\n", "2. Choose a model that captures the relationships between the features and the target,\n", "3. Select a loss function and fit the model,\n", "4. 
Evaluate the model using the appropriate notion of error.\n", "\n", "Once these steps are completed, the model can be used for prediction or inference.\n", "\n", "Each of these steps may involve complicated transformations and logic, often with a large number of choices about which features and parameters are most effective. Modeling pipelines manage this complexity by keeping track of these choices in an organized fashion.\n", "\n", "## Modeling Pipelines in Scikit-Learn\n", "\n", "Models in scikit-learn are *Estimators*. Their interface resembles that of the Transformers used for generating features. An Estimator implements the following methods:\n", "\n", "* `Estimator.set_params` sets the parameters (hyperparameters) that control how the model is fit.\n", "* `Estimator.fit(X, y)` fits the model on the features `X` and the target variable `y`.\n", "* `Estimator.predict(X)` returns the fitted model's predictions for the observations in `X`.\n", "* `Estimator.score(X, y)` returns an evaluation of the fitted model on data `X` and target `y`.\n", " - For classifiers, the default score is the *accuracy*.\n", " - For regressors, the default score is the *coefficient of determination* ($R^2$).\n", " - Other evaluation metrics are available in `sklearn.metrics`, or through the `scoring` parameter of model-selection tools such as `cross_val_score`.\n", " \n", "*Remark:* The main difference between an Estimator and a Transformer is the presence of the *target variable* `y`. In fact, a fitted model behaves like a transformer, with `predict` playing the role of `transform`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Example:** Fitting a multiple regression model is straightforward using Estimators. Below is a small dataset with two independent variables (`x1` and `x2`) and one target variable `y`." ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
 | x1 | x2 | y |
---|---|---|---|
0 | -9.623415 | -0.991037 | -19.420887 |
1 | -9.109035 | -1.145796 | -17.835602 |
2 | -8.583274 | 0.885662 | -14.760810 |
... | ... | ... | ... |
17 | 4.246154 | 2.712184 | 15.964720 |
18 | 4.759243 | 4.643905 | 17.381630 |
19 | 5.121245 | 3.419691 | 21.253360 |
20 rows × 3 columns
"
 | Id | MSSubClass | MSZoning | LotFrontage | LotArea | Street | Alley | LotShape | LandContour | Utilities | ... | PoolArea | PoolQC | Fence | MiscFeature | MiscVal | MoSold | YrSold | SaleType | SaleCondition | SalePrice |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 60 | RL | 65.0 | 8450 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 2 | 2008 | WD | Normal | 208500 |
1 | 2 | 20 | RL | 80.0 | 9600 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 5 | 2007 | WD | Normal | 181500 |
2 | 3 | 60 | RL | 68.0 | 11250 | Pave | NaN | IR1 | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 9 | 2008 | WD | Normal | 223500 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1457 | 1458 | 70 | RL | 66.0 | 9042 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | GdPrv | Shed | 2500 | 5 | 2010 | WD | Normal | 266500 |
1458 | 1459 | 20 | RL | 68.0 | 9717 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 4 | 2010 | WD | Normal | 142125 |
1459 | 1460 | 20 | RL | 75.0 | 9937 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 6 | 2008 | WD | Normal | 147500 |
1460 rows × 81 columns
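The Estimator interface and the four modeling steps listed at the start of this section fit together in a scikit-learn `Pipeline`, which is itself an Estimator. The block below is a minimal sketch under stated assumptions: the six rows are copied from the housing table above rather than loaded from a file, and the choice of columns (two numeric, two categorical), the preprocessing (median imputation, scaling, one-hot encoding) and the `LinearRegression` model are illustrative picks, not taken from this notebook.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Six rows copied from the housing table above (illustration only, not a real training set).
houses = pd.DataFrame({
    'LotFrontage': [65.0, 80.0, 68.0, 66.0, 68.0, 75.0],
    'LotArea': [8450, 9600, 11250, 9042, 9717, 9937],
    'LotShape': ['Reg', 'Reg', 'IR1', 'Reg', 'Reg', 'Reg'],
    'Fence': [np.nan, np.nan, np.nan, 'GdPrv', np.nan, np.nan],
    'SalePrice': [208500, 181500, 223500, 266500, 142125, 147500],
})
X = houses.drop(columns='SalePrice')
y = houses['SalePrice']

# Step 1: create features, with one transformer per column type (assumed choices).
features = ColumnTransformer([
    ('num', Pipeline([
        ('impute', SimpleImputer(strategy='median')),
        ('scale', StandardScaler()),
    ]), ['LotFrontage', 'LotArea']),
    ('cat', Pipeline([
        ('impute', SimpleImputer(strategy='constant', fill_value='missing')),
        ('onehot', OneHotEncoder(handle_unknown='ignore')),
    ]), ['LotShape', 'Fence']),
])

# Steps 2-3: chain the features with a model (ordinary least squares loss).
pl = Pipeline([('features', features), ('model', LinearRegression())])

# The pipeline is itself an Estimator: fit, predict and score as usual.
pl.fit(X, y)             # fit_transform on the feature step, then fit the model
preds = pl.predict(X)    # one predicted SalePrice per row of X
r2 = pl.score(X, y)      # step 4: the default regression score is R^2

# score() matches R^2 computed directly from the predictions
assert np.isclose(r2, r2_score(y, preds))

# Nested parameters are addressed as <step>__<param>; refit to apply the change.
pl.set_params(model__fit_intercept=False).fit(X, y)
print(preds.round(0), r2)
```

Because the fitted pipeline behaves like a single Estimator, it can be handed to tools such as `cross_val_score` or `GridSearchCV`, which tune nested `step__param` names (for example `model__fit_intercept`) in the same way `set_params` does here.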
\n", "