diff --git a/salary_prediction_with_linear_regression/salary-prediction-linear-regression-cpp.ipynb b/salary_prediction_with_linear_regression/salary-prediction-linear-regression-cpp.ipynb new file mode 100644 index 00000000..fd8f62f4 --- /dev/null +++ b/salary_prediction_with_linear_regression/salary-prediction-linear-regression-cpp.ipynb @@ -0,0 +1,396 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "94323844", + "metadata": {}, + "source": [ + "## Predicting Salary using Linear Regression\n", + "\n", + "### Objective\n", + "* We have to predict the salary of an employee given how many years of experience they have.\n", + "\n", + "### Dataset\n", + "* Salary_Data.csv has 2 columns — “Years of Experience” (feature) and “Salary” (target) for 30 employees in a company\n", + "\n", + "### Approach\n", + "* So in this example, we will train a Linear Regression model to learn the correlation between the number of years of experience of each employee and their respective salary. \n", + "* Once the model is trained, we will be able to do some sample predictions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "189dc5ff-22c4-4502-89a8-75e5ce51f3e1", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -q https://mlpack.org/datasets/Salary_Data.csv" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "behavioral-cycling", + "metadata": {}, + "outputs": [], + "source": [ + "// Import necessary library header.\n", + "#include \n", + "\n", + "#include \n", + "#include \n", + "#include \n", + "#include " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "db43325d", + "metadata": {}, + "outputs": [], + "source": [ + "#define WITHOUT_NUMPY 1\n", + "#include \"matplotlibcpp.h\"\n", + "#include \"xwidgets/ximage.hpp\"\n", + "\n", + "namespace plt = matplotlibcpp;" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9065ebb1", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack;\n", + "using namespace mlpack::regression;" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "victorian-donna", + "metadata": {}, + "outputs": [], + "source": [ + "// Load the dataset into armadillo matrix.\n", + "\n", + "arma::mat inputs;\n", + "data::Load(\"Salary_Data.csv\", inputs);" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "deluxe-present", + "metadata": {}, + "outputs": [], + "source": [ + "// Drop the first row as they represent header.\n", + "\n", + "inputs.shed_col(0);" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "desirable-experience", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Years Of Experience Salary\n", + " 1.1000e+00 3.9343e+04\n", + " 1.3000e+00 4.6205e+04\n", + " 1.5000e+00 3.7731e+04\n", + " 2.0000e+00 4.3525e+04\n", + " 2.2000e+00 3.9891e+04\n", + " 2.9000e+00 5.6642e+04\n", + "\n" + ] + } + ], + "source": [ + "// Display the first 5 rows of the input data.\n", + "\n", + "std::cout << std::setw(18) << \"Years Of 
Experience\" << std::setw(10) << \"Salary\" << std::endl;\n", + "std::cout << inputs.submat(0, 0, inputs.n_rows-1, 5).t() << std::endl;" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "associate-fifteen", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "912d932e54c14571a0ac726764dac35f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "A Jupyter widget with unique id: 912d932e54c14571a0ac726764dac35f" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "// Plot the input data.\n", + "\n", + "std::vector x = arma::conv_to>::from(inputs.row(0));\n", + "std::vector y = arma::conv_to>::from(inputs.row(1));\n", + "\n", + "plt::figure_size(800, 800);\n", + "\n", + "plt::scatter(x, y, 12, {{\"color\",\"coral\"}});\n", + "plt::xlabel(\"Years of Experience\");\n", + "plt::ylabel(\"Salary in $\");\n", + "plt::title(\"Experience vs. Salary\");\n", + "\n", + "plt::save(\"./scatter.png\");\n", + "auto img = xw::image_from_file(\"scatter.png\").finalize();\n", + "img" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "coordinate-canvas", + "metadata": {}, + "outputs": [], + "source": [ + "// Split the data into features (X) and target (y) variables\n", + "// targets are the last row.\n", + "\n", + "arma::Row targets = arma::conv_to>::from(inputs.row(inputs.n_rows - 1));" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "blank-mexican", + "metadata": {}, + "outputs": [], + "source": [ + "// Labels are dropped from the originally loaded data to be used as features.\n", + "\n", + "inputs.shed_row(inputs.n_rows - 1);" + ] + }, + { + "cell_type": "markdown", + "id": "8da116b5-83f2-4acd-8ac3-0d68adbd83ca", + "metadata": {}, + "source": [ + "### Train Test Split\n", + "The dataset has to be split into a training set and a test set.\n", + "This can be done using the `data::Split()` api from 
mlpack.\n", + "Here the dataset has 30 observations and the `testRatio` is taken as 40% of the total observations.\n", + "This indicates the test set should have 40% * 30 = 12 observations and the training set should have the remaining 18 observations." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "mechanical-laundry", + "metadata": {}, + "outputs": [], + "source": [ + "// Split the dataset into train and test sets using mlpack.\n", + "\n", + "arma::mat Xtrain;\n", + "arma::mat Xtest;\n", + "arma::Row Ytrain;\n", + "arma::Row Ytest;\n", + "data::Split(inputs, targets, Xtrain, Xtest, Ytrain, Ytest, 0.4);" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "friendly-petersburg", + "metadata": {}, + "outputs": [], + "source": [ + "// Convert armadillo Rows into rowvec. (Required by mlpack's LinearRegression API in this format).\n", + "\n", + "arma::rowvec yTrain = arma::conv_to::from(Ytrain);\n", + "arma::rowvec yTest = arma::conv_to::from(Ytest);" + ] + }, + { + "cell_type": "markdown", + "id": "99955e22", + "metadata": {}, + "source": [ + "## Linear Model\n", + "\n", + "Regression analysis is the most widely used method of prediction. Linear regression is used when the dataset has a linear correlation and as the name suggests, \n", + "simple linear regression has one independent variable (predictor) and one dependent variable (response).\n", + "\n", + "The simple linear regression equation is represented as $y = a+bx$ where $x$ is the explanatory variable, $y$ is the dependent variable, $b$ is the coefficient and $a$ is the intercept.\n", + "\n", + "To perform linear regression we'll be using `LinearRegression()` api from mlpack." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "published-illustration", + "metadata": {}, + "outputs": [], + "source": [ + "// Create and Train Linear Regression model.\n", + "\n", + "regression::LinearRegression lr(Xtrain, yTrain, 0.5);" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "detailed-mystery", + "metadata": {}, + "outputs": [], + "source": [ + "// Make predictions for test data points.\n", + "\n", + "arma::rowvec yPreds;\n", + "lr.Predict(Xtest, yPreds);" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "indian-ambassador", + "metadata": {}, + "outputs": [], + "source": [ + "// Convert armadillo vectors and matrices to vector for plotting purpose.\n", + "\n", + "std::vector XtestPlot = arma::conv_to>::from(Xtest);\n", + "std::vector yTestPlot = arma::conv_to>::from(yTest);\n", + "std::vector yPredsPlot = arma::conv_to>::from(yPreds);" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "related-approach", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "88f7de7663bd431382ce760f7f8a08a0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "A Jupyter widget with unique id: 88f7de7663bd431382ce760f7f8a08a0" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "// Visualize Predicted datapoints.\n", + "plt::figure_size(800, 800);\n", + "\n", + "plt::scatter(XtestPlot, yTestPlot, 12, {{\"color\", \"coral\"}});\n", + "plt::plot(XtestPlot,yPredsPlot);\n", + "plt::xlabel(\"Years of Experience\");\n", + "plt::ylabel(\"Salary in $\");\n", + "plt::title(\"Predicted Experience vs. 
Salary\");\n", + "\n", + "plt::save(\"./scatter1.png\");\n", + "auto img = xw::image_from_file(\"scatter1.png\").finalize();\n", + "img" + ] + }, + { + "cell_type": "markdown", + "id": "0a10abbb-6b3a-423f-a573-1c650ac60b85", + "metadata": {}, + "source": [ + "Test data is visualized with `XtestPlot` and `yPredsPlot`; the coral points indicate the data points and the blue line indicates the regression line or best fit line." + ] + }, + { + "cell_type": "markdown", + "id": "c24be191-959f-4244-8921-c1ee0ea98b3b", + "metadata": {}, + "source": [ + "## Evaluation Metrics for Regression model\n", + "\n", + "In the previous cell we have visualized our model performance by plotting the best fit line. Now we will use various evaluation metrics to understand how well our model has performed.\n", + "\n", + "* Mean Absolute Error (MAE) is the mean of the absolute differences between actual and predicted values, without considering the direction.\n", + "$$ MAE = \\frac{\\sum_{i=1}^n\\lvert y_{i} - \\hat{y_{i}}\\rvert} {n} $$\n", + "* Mean Squared Error (MSE) is calculated as the mean or average of the squared differences between predicted and expected target values in a dataset; a lower value is better.\n", + "$$ MSE = \\frac {1}{n} \\sum_{i=1}^n (y_{i} - \\hat{y_{i}})^2 $$\n", + "* Root Mean Squared Error (RMSE) is the square root of the MSE; it indicates the spread of the residual errors. 
It is always non-negative, and a lower value indicates better performance.\n", + "$$ RMSE = \\sqrt{\\frac {1}{n} \\sum_{i=1}^n (y_{i} - \\hat{y_{i}})^2} $$" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "id": "british-moment", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean Absolute Error: 5753.06\n", + "Mean Squared Error: 3.9482e+07\n", + "Root Mean Squared Error: 6283.47\n" + ] + } + ], + "source": [ + "// Model evaluation metrics.\n", + "\n", + "std::cout << \"Mean Absolute Error: \" << arma::mean(arma::abs(yPreds - yTest)) << std::endl;\n", + "std::cout << \"Mean Squared Error: \" << arma::mean(arma::pow(yPreds - yTest,2)) << std::endl;\n", + "std::cout << \"Root Mean Squared Error: \" << sqrt(arma::mean(arma::pow(yPreds - yTest,2))) << std::endl;" + ] + }, + { + "cell_type": "markdown", + "id": "17cd38d7-214a-4f5a-8c4d-0517f834e804", + "metadata": {}, + "source": [ + "From the above metrics we can notice that our model's MAE is ~5.7K, which is relatively small compared to our average salary of $76003. From this we can conclude that our model is a reasonably good fit." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "C++14", + "language": "C++14", + "name": "xcpp14" + }, + "language_info": { + "codemirror_mode": "text/x-c++src", + "file_extension": ".cpp", + "mimetype": "text/x-c++src", + "name": "c++", + "version": "14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/salary_prediction_with_linear_regression/salary-prediction-linear-regression-py.ipynb b/salary_prediction_with_linear_regression/salary-prediction-linear-regression-py.ipynb new file mode 100644 index 00000000..20a66613 --- /dev/null +++ b/salary_prediction_with_linear_regression/salary-prediction-linear-regression-py.ipynb @@ -0,0 +1,528 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "technical-identification", + "metadata": {}, + "source": [ + "## Predicting Salary using Linear Regression\n", + "\n", + "### Objective\n", + "* We have to predict the salary of an employee given how many years of experience they have.\n", + "\n", + "### Dataset\n", + "* Salary_Data.csv has 2 columns — “Years of Experience” and “Salary” for 30 employees in a company\n", + "\n", + "### Approach\n", + "* So in this example, we will train a Linear Regression model to learn the correlation between the number of years of experience of each employee and their respective salary. \n", + "* Once the model is trained, we will be able to do some sample predictions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "449a2f52", + "metadata": {}, + "outputs": [], + "source": [ + "# Import Libraries.\n", + "\n", + "import mlpack\n", + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "786e154b", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "# Uncomment below line to enable dark background style sheet.\n", + "# plt.style.use('dark_background')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9c7de4da", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the salary dataset.\n", + "data = pd.read_csv(\"https://mlpack.org/datasets/Salary_Data.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1d59786b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
YearsExperienceSalary
01.139343.0
11.346205.0
21.537731.0
32.043525.0
42.239891.0
\n", + "
" + ], + "text/plain": [ + " YearsExperience Salary\n", + "0 1.1 39343.0\n", + "1 1.3 46205.0\n", + "2 1.5 37731.0\n", + "3 2.0 43525.0\n", + "4 2.2 39891.0" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Display the first 5 samples from dataframe.\n", + "data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5a3a26af", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
YearsExperienceSalary
count30.00000030.000000
mean5.31333376003.000000
std2.83788827414.429785
min1.10000037731.000000
25%3.20000056720.750000
50%4.70000065237.000000
75%7.700000100544.750000
max10.500000122391.000000
\n", + "
" + ], + "text/plain": [ + " YearsExperience Salary\n", + "count 30.000000 30.000000\n", + "mean 5.313333 76003.000000\n", + "std 2.837888 27414.429785\n", + "min 1.100000 37731.000000\n", + "25% 3.200000 56720.750000\n", + "50% 4.700000 65237.000000\n", + "75% 7.700000 100544.750000\n", + "max 10.500000 122391.000000" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Generates basic statistical summary of the dataframe.\n", + "data.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8d8410cd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 30 entries, 0 to 29\n", + "Data columns (total 2 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 YearsExperience 30 non-null float64\n", + " 1 Salary 30 non-null float64\n", + "dtypes: float64(2)\n", + "memory usage: 608.0 bytes\n" + ] + } + ], + "source": [ + "# Generates a concise summary of the dataframe.\n", + "data.info()" + ] + }, + { + "cell_type": "markdown", + "id": "78f2eea6", + "metadata": {}, + "source": [ + "### Exploratory Data Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ef71b4dc", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZgAAAEWCAYAAABbgYH9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3de5hddX3v8fdnMnESCZCQBIRMQqiJWuAgLSOiOXIo+AhWBVqlxlNKqrFUD1Z7eUzAcyzFy1OCtqj1iEZAwkUgTWxDPYICsdVSLiYawk1KLJdMiBBCggkmQy7f88f6bbJms+eSmVl77b3n83qe/ey1f2v91vqtTdjf+V2XIgIzM7OR1lZ2AczMrDU5wJiZWSEcYMzMrBAOMGZmVggHGDMzK4QDjJmZFcIBxmwIJL1N0qNll6ORSPpXSR8uuxzWOBxgrKlIekLSDknbc6+v1rscEfHjiHh9va9bNEnHSPqBpC2StkpaLel3yy6XNaf2sgtgNgTviYg7yrq4pPaI2F3W9Qv2L8AVwLvT5zcBKvKCLf59jmquwVjLkHSFpGW5z4sk3anMKZK6JX1K0nOpJvSHuWM7JH1R0lOSnpH0dUnj075K3oWSfgl8q5KWy3+EpOWSNkl6XNLHc/v+RtJSSddK2ibpIUlduf3TJX0n5d2cr5FJ+pCkR1KN4vuSjuzj3m+T9LGqtPsl/X66/8slPSvpBUlrJR1b4xxTgKOAb0bES+l1V0T8e9o/SdJ3Uzm3pO3OPsrzWkkr0/08J+kGSRNz+59I3+da4EVJn5S0vOoc/yDpS7XOb83BAcZayV8Bx0n6Y0lvA+YD82LfekivAaYA04B5wGJJlWauRcDrgOOBWemYv86d+zXAIcCRwPn5i0pqI/vL//6U7zTgzyWdnjvsTOAmYCJwC/DVlHcM8F3gSWBmyn9T2nc28Cng94GpwI+BG/u4928DH8iV6ehU1v8HvAM4Od3fROD9wOYa59gMrAOul3S2pMOq9rcB30rnnQHsqNxHDQL+FjgC+E1gOvA3Vcd8AHhXKtP1wBmVICSpPZXzuj7Ob80gIvzyq2lewBPAdmBr7vUnuf0nAs+T/WB/IJd+CrAbOCCXthT4NNmP4YvAa3P73gI8nsv7EjCu6nzdafvNwFNV5bwI+Fba/hvgjty+o4EduetsAtpr3OutwPzc5zbg18CRNY49MN3Dkenz54Gr0/apwH8CJwFtA3y/nWRB4xfAXuBHwOw+jj0e2JL7/K/Ah/s49mzgZ1X/HT9U437/JG2/G3i47H9vfg3v5RqMNaOzI2Ji7vXNyo6IuA/4L7KgsbQq35aIeDH3+Umyv7CnAq8GVqeO7a3AbSm9YlNE7OyjPEcCR1TypvyfAvI1gF/mtn8NjEt/pU8HnozafRBHAl/OnfP5dF/Tqg+MiG1ktZW5KWkucEPat5IsaPxf4BlJiyUdVOtGIqI7Ij4WEa9N138RuBZA0qslfUPSk5J+RRZ8JqZaWC+SDpV0k6QN6djryWqPeeurPi8Bzk3b5+LaS9NzgLGWIukCoAN4GlhQtXuSpANyn2ek454ja+45Jhe0Do6ICblj+1t2fD1ZbScf9A6MiMGMvloPzEjBpta+P6067/iI+I8+znUj8AFJbwHGAz98ufARX4mIE4BjyJrKPjlQwSJiPVlQqvTX/BXweuDNEXEQWbMb1B4E8Ldk39lx6dhzaxxX/Z3+M1kT57FkNZgbBiqjNTYHGGsZkl4HfI7sx+yPgAWSjq867BJJr0p9NO8G/jEi9gLfBC6XdGg617SqPpT+3Af8KnVaj5c0RtKxkt40yLwbgUslHSBpnKQ5ad/XgYskHZPKdLCkc/o51/fIah2fAW5O94WkN0l6s6SxZDWSncCe6sypE/8SSbMktaVO/w8B96RDDiQLxFslHQJc3E9ZDiQ1ZUqaxuAC2k5gGVl/0n0R8dRAeayxOcBYM/oX9Z4H80+pBnA9sCgi7o+
Ix8iaqa6T1JHy/RLYQlZruQH4SET8PO1bSNbBfU9q0rmD7K/1AUXEHuA9ZH0Sj5PViK4EDt6PvLOAp4Buss5tIuKfyAYf3JTK9CDwzn7O1QN8B3g72Y90xUFkAXQLWbPgZuCLNU7xEtlAgzuAyvV6gD9O+79EVjN6jizo3NbPrV0C/DbwAlnT3Xf6OTZvCfDfcPNYS1CEHzhmrU/SKcD1EVFzWK01BkkzgJ8Dr4mIX5VdHhse12DMrCGk4d5/Cdzk4NIaPJPfzEqXBl88Q9aEd0bJxbER4iYyMzMrhJvIzMysEG4iS6ZMmRIzZ84suxhmZk1l9erVz0XE1Fr7HGCSmTNnsmrVqrKLYWbWVCQ92dc+N5GZmVkhHGDMzKwQDjBmZlYIBxgzMyuEA4yZmRXCAcbMrMVt3t7D/eu3snl7T12v62HKZmYtbMWaDSxcvpaxbW3s2ruXy957HGcev++ZdZu399C9ZQedk8YzeUJHP2fafw4wZmYtavP2HhYuX8vOXXvZyV4AFixfy5xZU5g8oWPA4DNcbiIzM2tR3Vt2MLat98/82LY2urfs6BV8tvXsZueuvSxYvnZEm9EcYMzMWlTnpPHs2ru3V9quvXvpnDS+3+AzUhxgzMxa1OQJHVz23uMYN7aNAzvaGTe2jcveexyTJ3T0G3xGSmEBRtLVkp6V9GAu7QuSfi5pbXrM7cTcvoskrZP0aP5Z6JJOkPRA2vcVSUrpHZJuTun3SpqZyzNP0mPpNa+oezQza3RnHj+NuxaeyvUffjN3LTz15T6W/oLPSCnseTCSTga2A9dGxLEp7R3AyojYLWkRQEQslHQ0cCNwInAE2TPBXxcReyTdB3yC7Bng3wO+EhG3SvpfwHER8RFJc4Hfi4j3SzoEWAV0AQGsBk6IiC39lberqyu82KWZjTbDHUUmaXVEdNXaV1gNJiJ+BDxflfaDiNidPt4DVJ6PfhbZY1J7IuJxYB1woqTDgYMi4u7IIuG1wNm5PEvS9jLgtFS7OR24PSKeT0HldvyEPDOzmiZP6OCN0yeO+BBlKLcP5kPArWl7GrA+t687pU1L29XpvfKkoPUCMLmfc72CpPMlrZK0atOmTcO6GTMz662UACPpfwO7gRsqSTUOi37Sh5qnd2LE4ojoioiuqVNrPi/HzMyGqO4BJnW6vxv4w9jXAdQNTM8d1gk8ndI7a6T3yiOpHTiYrEmur3OZmVkd1TXASDoDWAicGRG/zu26BZibRoYdBcwG7ouIjcA2SSel/pXzgBW5PJURYu8jGzwQwPeBd0iaJGkS8I6UZmZmdVTYUjGSbgROAaZI6gYuBi4COoDb02jjeyLiIxHxkKSlwMNkTWcXRMSedKqPAtcA48n6bCr9NlcB10laR1ZzmQsQEc9L+izwk3TcZyKi12ADM7NGUOQ6YI2gsGHKzcbDlM2snopeB6xeShmmbGZmtdVjHbBG4ABjZlZn9VgHrBE4wJiZ1Vk91gFrBA4wZmZ1Vo91wBqBHzhmZlaCM4+fxpxZU1p6FJkDjJm1nGYZ/jt5QkdDl2+4HGDMrKW0yvDfVuA+GDNrGaNl+G+zcIAxs5bR7MN/N2/v4f71W1smILqJzMxaRjMP/23Fpj3XYMysZTTr8N9WbdpzDcbMWkozDv+tNO3tZF/tq9K01wzl74sDjJm1nGYb/tvMTXv9cROZmVnJmrVpbyCuwZiZNYBmbNobiAOMmVmDaLamvYG4iczMrIZWm5NSBtdgzMyqtOKclDK4BmNmltOqc1LK4ABjZpbT7MvNNBIHGDOznFadk1IGBxgzs5xWnZNSBnfym5lVacU5KWVwgDEzq6HV5qSUwU1kZmZD5Lky/XMNxsxsCDxXZmCuwZiZ7SfPlRkcBxgzs/3kuTKD4wBjZrafPFdmcBxgzMz2k+fKDI47+c3MhsBzZQbmAGNmNkSeK9M/N5GZmVkhHGDMzKwQDjBm1lQ8e755uA/GzJqGZ883F9dgzKwpePZ88ykswEi
6WtKzkh7MpR0i6XZJj6X3Sbl9F0laJ+lRSafn0k+Q9EDa9xVJSukdkm5O6fdKmpnLMy9d4zFJ84q6RzOrH8+ebz5F1mCuAc6oSrsQuDMiZgN3ps9IOhqYCxyT8nxN0piU5wrgfGB2elXOOR/YEhGzgMuBRelchwAXA28GTgQuzgcyM2tOnj3ffAoLMBHxI+D5quSzgCVpewlwdi79pojoiYjHgXXAiZIOBw6KiLsjIoBrq/JUzrUMOC3Vbk4Hbo+I5yNiC3A7rwx0ZtZkPHu++dS7k/+wiNgIEBEbJR2a0qcB9+SO605pu9J2dXolz/p0rt2SXgAm59Nr5DGzJubZ882lUUaRqUZa9JM+1Dy9LyqdT9b8xowZMwYupZmVzrPnm0e9R5E9k5q9SO/PpvRuYHruuE7g6ZTeWSO9Vx5J7cDBZE1yfZ3rFSJicUR0RUTX1KlTh3FbZub5KVat3gHmFqAyqmsesCKXPjeNDDuKrDP/vtSctk3SSal/5byqPJVzvQ9Ymfppvg+8Q9Kk1Ln/jpRmZgVZsWYDcxat5Nwr72XOopXcsmZD2UWyBlBYE5mkG4FTgCmSuslGdl0KLJU0H3gKOAcgIh6StBR4GNgNXBARe9KpPko2Im08cGt6AVwFXCdpHVnNZW461/OSPgv8JB33mYioHmxgZiMkPz9lJ9korwXL1zJn1hQ3ZY1yhQWYiPhAH7tO6+P4zwOfr5G+Cji2RvpOUoCqse9q4OpBF9bMhqwyP6USXGDf/BQHmNHNM/nNbFg8P8X64gBjZsPi+SnWl0YZpmxmTczzU6wWBxgzGxGen2LV3ERmZmaFcIAxM7NCOMCYNSnPnLdG5z4YsybkJztaM3ANxqzJ+MmO1iwcYMyajJ/saM3CAcasyXjmvDULBxizJlPvmfMeTGBD5U5+syZUr5nzHkxgw+EAY9akip4572X4bbjcRGZmNXkwgQ2XA4yZ1eTBBDZcDjBmVpOX4bfhch+MmfXJy/DbcDjAmFm/vAy/DZWbyMzMrBAOMGZmVggHGDMzK4QDjJmZFcIBxszMCuEAY2ZmhXCAMTOzQjjAmJlZIRxgzMysEIMKMJLGFF0QMzNrLYOtwayT9AVJRxdaGjMzaxmDDTDHAf8JXCnpHknnSzqowHKZmVmTG1SAiYhtEfHNiHgrsAC4GNgoaYmkWYWW0KzJ+Zn2NloNajXl1AfzLuCDwEzg74AbgLcB3wNeV1D5zJqan2lvo9lgl+t/DPgh8IWI+I9c+jJJJ498scyan59pb6PdgE1kqfZyTUTMrwouAETExwspmVmT8zPtbbQbMMBExB7gd+pQFrOWUsQz7d2fY81ksE1k/yHpq8DNwIuVxIj4aSGlMmsBlWfaL6jqgxlq85j7c6zZKCIGPkj6YY3kiIhTR75I5ejq6opVq1aVXQxrQZu39wz7mfabt/cwZ9FKdu7aVyMaN7aNuxae6v4cK5Wk1RHRVWvfoGowETGiTWSS/gL4MBDAA2Sj015NVkOaCTwB/EFEbEnHXwTMB/YAH4+I76f0E4BrgPFko9k+EREhqQO4FjgB2Ay8PyKeGMl7MBuskXimfaU/pzJYAPb15zjAWKMa9Fpkkt4laYGkv668hnJBSdOAjwNdEXEsMAaYC1wI3BkRs4E702fS6gFzgWOAM4Cv5ZauuQI4H5idXmek9PnAloiYBVwOLBpKWc0aRRH9OWZFG+xaZF8H3g/8GSDgHODIYVy3HRgvqZ2s5vI0cBawJO1fApydts8CboqInoh4HFgHnCjpcOCgiLg7sna+a6vyVM61DDhNkoZRXrNSVfpzxo1t48COdsaNbRtWf45ZPQy2k/+tEXGcpLURcYmkvwO+M5QLRsQGSV8EngJ2AD+IiB9IOiwiNqZjNko6NGWZBtyTO0V3StuVtqvTK3nWp3PtlvQCMBl4Ll8WSeeT1YCYMWPGUG7HrG7OPH4ac2ZNGXZ/jlm9DLaJrDJw/9eSjiD7cT9qKBeUNImshnE
UcARwgKRz+8tSIy36Se8vT++EiMUR0RURXVOnTu2/4GYNYPKEDt44faKDizWFwQaY70qaCHwB+ClZJ/xNQ7zm24HHI2JTROwiqwm9FXgmNXuR3p9Nx3cD03P5O8ma1LrTdnV6rzypGe5g4PkhltfMzIZgsItdfjYitkbEcrK+lzdExKeHeM2ngJMkvTr1i5wGPALcAsxLx8wDVqTtW4C5kjokHUXWmX9fak7bJumkdJ7zqvJUzvU+YGUMZjy2mZmNmH77YCT9fj/7iIj97oeJiHslLSOrCe0GfgYsBiYASyXNJwtC56TjH5K0FHg4HX9BWl0A4KPsG6Z8a3oBXAVcJ2kdWc1l7v6W08zMhqffiZaSvtVP3oiID418kcrhiZZmZvtvyBMtI+KDxRTJzMxa3WCHKSPpXWSTHcdV0iLiM0UUyszMml9ZEy3NzKzFDXaY8lsj4jyy5VcuAd5C76HDZmZmvQx1ouVuhjjR0szMRofB9sFUJlpeBqxOaVcWUyQzM2sFA82DeROwPiI+mz5PIFte/+dkqxSbmZnVNFAT2TeAlwAknQxcmtJeIJscaWZmVtNATWRjIqKyhtf7gcVpuZjlktYUWzQzM2tmA9VgxqTFIiFbM2xlbt+g59CYmdnoM1CQuBH4N0nPkY0k+zGApFlkzWRmZmY1DbRUzOcl3QkcTvZgsMrCZW1kky7NzMxqGrCZKyLuqZH2n8UUx8zMWsVgJ1qamZntFwcYMzMrhAOMmZkVwgHGzMwK4QBjZmaFcIAxM7NCOMCY7YfN23u4f/1WNm/vKbsoZg3Py72YDdKKNRtYuHwtY9va2LV3L5e99zjOPH5a2cUya1iuwZgNwubtPSxcvpadu/ayrWc3O3ftZcHyta7JmPXDAcZsELq37GBsW+//Xca2tdG9ZUcfOczMAcZsEDonjWfX3r290nbt3UvnpPEllcis8TnAmA3C5AkdXPbe4xg3to0DO9oZN7aNy957HJMndADu/DerxZ38ZoN05vHTmDNrCt1bdtA5afzLwcWd/2a1OcCY7YfJEzpeDizQu/N/J1kT2oLla5kza0qv48xGIzeRmQ2DO//N+uYAYyNuNPVHuPPfrG9uIrMRNdr6Iyqd/wuq7tnNY2YOMDaCiu6P2Ly95xUd7I2gr85/s9HOAcZGTKU/ohJcYF9/xHB/dBu9ZlTd+W9m7oOxEVRUf4SXaTFrTg4wNmIGmow4VB6pZdac3ERmI6qI/giP1DJrTq7B2IibPKGDN06fOGJ9EkXVjEbTcGqzMrgGY01hpGtGjT5owKwVOMBY0xipkVpe3sWsPkppIpM0UdIyST+X9Iikt0g6RNLtkh5L75Nyx18kaZ2kRyWdnks/QdIDad9XJCmld0i6OaXfK2lm/e+ytTVz85IHDZjVR1l9MF8GbouINwBvBB4BLgTujIjZwJ3pM5KOBuYCxwBnAF+TNCad5wrgfGB2ep2R0ucDWyJiFnA5sKgeNzVarFizgTmLVnLulfcyZ9FKblmzoewi7RcPGjCrj7oHGEkHAScDVwFExEsRsRU4C1iSDlsCnJ22zwJuioieiHgcWAecKOlw4KCIuDsiAri2Kk/lXMuA0yq1GxueVpiTUtSgATPrrYw+mN8ANgHfkvRGYDXwCeCwiNgIEBEbJR2ajp8G3JPL353SdqXt6vRKnvXpXLslvQBMBp7LF0TS+WQ1IGbMmDFS99fSipytX09e3sWseGU0kbUDvw1cERG/BbxIag7rQ62aR/ST3l+e3gkRiyOiKyK6pk6d2n+pDWit5qWRHk5tZr2VEWC6ge6IuDd9XkYWcJ5JzV6k92dzx0/P5e8Enk7pnTXSe+WR1A4cDDw/4ncyCrl5ycwGq+5NZBHxS0nrJb0+Ih4FTgMeTq95wKXpfUXKcgvwbUl/DxxB1pl/X0TskbRN0knAvcB5wD/k8swD7gbeB6xM/TQ2AopuXmrUVZPNbP+UNQ/mz4AbJL0K+C/gg2S1qaWS5gNPAec
ARMRDkpaSBaDdwAURsSed56PANcB44Nb0gmwAwXWS1pHVXObW46ZGk6JWD96fCZD5QAQ4KJk1GPkP+0xXV1esWrWq7GI0tKJrFpu39zBn0Up27trXxzNubBt3LTz1FdfLB6Idu3YjiXHtYzwr36zOJK2OiK5a+zyT3walHkurDHaEWq2Z+BDs2rMb8Kx8s0bhxS5tQPWa+zLYEWq1ZuLneVa+WWNwgLEB1WtplcGOUOucNJ6du/f0cZbmHTZt1mrcRGYDqufcl8GOUKvVd3hAxxj27A0PmzZrEA4wNqBKzWJBVR9MXz/iwx0MMNAIte4tOxg/tp1tPbtfTjvgVWO45D3H8DtvONTBxaxBOMDYoAy2ZlGPwQC1alR7IhxczBqM+2Bs0AZaWqVegwG8moBZc3ANxkZMPRfC9GKVZo3PAcZGTL0XwixqNQEzGxluIrMR46YrM8tzDcZGlJuuzKzCAcZGnJuuzAzcRGZmZgVxgDEzs0I4wJiZWSEcYMzMrBAOMGZmVggHGDMzK4QDjJmZFcIBxszMCuEAY2ZmhXCAMTOzQjjAmJlZIRxgCrR5ew/3r9864g/cMjNrBl7ssiD1eHSwmVkjcw2mAPV6dLCZWSNzgClA5dHBeZVHB5uZjRYOMAWo96ODzcwakQNMAfzoYDMzd/IXpuhHB2/e3uPHEptZQ3OAKVBRjw72CDUzawZuImsyHqFmZs3CAabJeISamTULB5gm4xFqZtYsHGCajEeomVmzcCd/Eyp6hJqZ2UhwgGlSRY1QMzMbKaU1kUkaI+lnkr6bPh8i6XZJj6X3SbljL5K0TtKjkk7PpZ8g6YG07yuSlNI7JN2c0u+VNLPe92dmNtqV2QfzCeCR3OcLgTsjYjZwZ/qMpKOBucAxwBnA1ySNSXmuAM4HZqfXGSl9PrAlImYBlwOLir2VxuFHBJhZoyglwEjqBN4FXJlLPgtYkraXAGfn0m+KiJ6IeBxYB5wo6XDgoIi4OyICuLYqT+Vcy4DTKrWbIjTKj/qKNRuYs2gl5155L3MWreSWNRtKLY+ZjW5l9cF8CVgAHJhLOywiNgJExEZJh6b0acA9ueO6U9qutF2dXsmzPp1rt6QXgMnAcyN8Hw0zqz4/AXMn2TDmBcvXMmfWFPfVmFkp6l6DkfRu4NmIWD3YLDXSop/0/vJUl+V8Saskrdq0adMgi7NPI82q9wRMM2s0ZTSRzQHOlPQEcBNwqqTrgWdSsxfp/dl0fDcwPZe/E3g6pXfWSO+VR1I7cDDwfHVBImJxRHRFRNfUqVP3+0Ya6UfdEzDNrNHUPcBExEUR0RkRM8k671dGxLnALcC8dNg8YEXavgWYm0aGHUXWmX9fak7bJumk1L9yXlWeyrnel67xihrMcA3nR32k+208AdPMGk0jzYO5FFgqaT7wFHAOQEQ8JGkp8DCwG7ggIvakPB8FrgHGA7emF8BVwHWS1pHVXOYWUeDKj/qCqj6YgX7Ui+q38QRMM2skKuAP+6bU1dUVq1atGlLe/Xk2y+btPcxZtJKdu/bVfMaNbeOuhac6IJhZ05G0OiK6au1rpBpM09qfWfWVfpvKSC/Y12/jAGNmrcSLXdaZO+PNbLRwgKkzd8ab2WjhJrISuDPezEYDB5iSeDVkM2t1biIzM7NCOMCYmVkhHGDMzKwQDjBmZlYIBxgzMyuEl4pJJG0Cniy7HPtpCgU846bJjPbvYLTfP/g7gHK/gyMjouZy9A4wTUzSqr7WABotRvt3MNrvH/wdQON+B24iMzOzQjjAmJlZIRxgmtvisgvQAEb7dzDa7x/8HUCDfgfugzEzs0K4BmNmZoVwgDEzs0I4wDQhSdMl/VDSI5IekvSJsstUBkljJP1M0nfLLksZJE2UtEzSz9O/hbeUXaZ6kvQX6d//g5JulDSu7DIVTdLVkp6V9GAu7RBJt0t6LL1PKrOMeQ4wzWk38FcR8ZvAScAFko4uuUxl+ATwSNmFKNGXgdsi4g3AGxlF34W
kacDHga6IOBYYA8wtt1R1cQ1wRlXahcCdETEbuDN9bggOME0oIjZGxE/T9jayH5Zp5ZaqviR1Au8Criy7LGWQdBBwMnAVQES8FBFbyy1V3bUD4yW1A68Gni65PIWLiB8Bz1clnwUsSdtLgLPrWqh+OMA0OUkzgd8C7i23JHX3JWABsLfsgpTkN4BNwLdSM+GVkg4ou1D1EhEbgC8CTwEbgRci4gfllqo0h0XERsj++AQOLbk8L3OAaWKSJgDLgT+PiF+VXZ56kfRu4NmIWF12WUrUDvw2cEVE/BbwIg3UNFK01M9wFnAUcARwgKRzyy2VVXOAaVKSxpIFlxsi4jtll6fO5gBnSnoCuAk4VdL15Rap7rqB7oio1FyXkQWc0eLtwOMRsSkidgHfAd5acpnK8oykwwHS+7Mll+dlDjBNSJLI2t4fiYi/L7s89RYRF0VEZ0TMJOvYXRkRo+qv14j4JbBe0utT0mnAwyUWqd6eAk6S9Or0/8NpjKJBDlVuAeal7XnAihLL0kt72QWwIZkD/BHwgKQ1Ke1TEfG9Estk9fdnwA2SXgX8F/DBkstTNxFxr6RlwE/JRlX+jAZdLmUkSboROAWYIqkbuBi4FFgqaT5Z4D2nvBL25qVizMysEG4iMzOzQjjAmJlZIRxgzMysEA4wZmZWCAcYMzMrhAOMtTRl/l3SO3NpfyDptpLLtFTSWkkfr9r3OUkbJK3JvQ4suDzfL/oaNjp5mLK1PEnHAv9ItmbbGGANcEZE/GIY52yPiN1DzNsJ/FtEvLbGvs8Bz0XEl4Zatv0oh8h+A0brem5WMNdgrOVFxIPAvwALySamXRsRv5A0T9J9qZbwNUltAJIWS1qVnjXy15XzSOqW9GlJdwG/l55H8rCk+2stVSNpvKQlkh6Q9FNJJ6ddPwCOSNcd1PImkhZIWpy2j0/nHJ9qPEvS84Eek/ShXJ4L0/2trdyHpFnp+SlfJ5ukeHi6r4lp/yu+E0ntkrZKujTd692SDk3Hv0bSinSN+yW9ua/z7Nd/NGsNEeGXXy3/Ag4AHgUeADqAY4F/BtrT/sXA/0zbh6T3duDHwNHpczfwl7lzbgRelbYn1rjmQuCbafsY4EngVcAsYIWpBcsAAAJ+SURBVE0f5fwcsIGslrUGuCOltwF3kS3w+DPgpNzxPwXGka2i2w0cBvwu8DVAKe9tZGt1zSJbgfpNuWt2AxP7+k7S9xDAO1P63wMXpu3lwMdy39dB/X23fo2ul5eKsVEhIl6UdDOwPSJ6JL0deBOwKmspYjywPh3+gbTsRjvZSr1Hs2+dr5tzp30IuF7SCrIf1Gr/HfhCuv5Dkp4m+4F/aYDifiGqmsgiYq+kPyYLOl+NiHtyu/85InYCOyX9KN3X24F3kgUjgAnA68gWQvxFRPykxnX7+052RMStaXs18La0fQrpQV+RNRn+aoDv1kYRBxgbTfay7/kxAq6OiE/nD5A0m+xJmSdGxNbU9JV/FO+Lue3Tgf9BVqv4P5KOjYg9+dONcPlnA9vJgl5edUdqpGt/LiKuyu+QNIve99BrN7W/k3Z6B8U99P7tqL5+zfPY6ON2URut7gD+QNIUAEmTJc0ga+LZRvaX+OFkQeQVJI0BOiNiJfBJYCrZUxXzfgT8YTr+N4HDgXVDKWzqI7mcbKHTaZLyTy08W1JHupe3AauA7wPzlR5CJqmzcq/96Os76c8PgY+k48coe9LmUM5jLcg1GBuVIuIBSZcAd6QO6F1kP5SryJrDHiRbofiuPk7RDnw7De9tAxZF9vjqvH8AviHpgXT+8yLipdRs1J9PpuawivcAnwe+HBHrJH0wlfvf0/6fALcC04GLI+IZ4HuS3gDck663jaw/pU/9fCf9PYr4Y8A3Jf0p2arGfxoR9/VxnqcGunFrLR6mbNbE6jms2Wx/uYnMzMwK4RqMmZkVwjUYMzMrhAOMmZkVwgHGzMwK4QBjZmaFcIAxM7NC/H9uTSjoTqrWJwAAAABJRU5
ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Scatter plot of Experience vs Salary.\n", + "data.plot(x=\"YearsExperience\", y=\"Salary\",\n", + " kind=\"scatter\", title=\"Experience vs Salary\")\n", + "plt.xlabel(\"Years of Experience\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "739be4d1-2a46-49f1-9bc5-b07d598cbe28", + "metadata": {}, + "source": [ + "### Train Test Split\n", + "The dataset has to be split into a training set and a test set.\n", + "This can be done using the `preprocess_split()` api from mlpack.\n", + "Here the dataset has 30 observations and the `testRatio` is taken as 40% of the total observations.\n", + "This indicates the test set should have 40% * 30 = 12 observations and training test should have 18 observations respectively." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2cd31a2a", + "metadata": {}, + "outputs": [], + "source": [ + "# Split data into features (X) and targets (y).\n", + "\n", + "targets = data.Salary\n", + "features = data.drop(\"Salary\", axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9e82b675", + "metadata": {}, + "outputs": [], + "source": [ + "# Split the dataset using mlpack's preprocess_split method.\n", + "splitData = mlpack.preprocess_split(input=features, input_labels=targets, test_ratio=0.4, seed=101)" + ] + }, + { + "cell_type": "markdown", + "id": "91e0b6b8", + "metadata": {}, + "source": [ + "### Training the linear model\n", + "\n", + "Regression analysis is the most widely used method of prediction. 
Linear regression is used when the dataset has a linear correlation and as the name suggests, simple linear regression has one independent variable (predictor) and one dependent variable (response).\n", + "\n", + "The simple linear regression equation is represented as $y = a+bx$ where $x$ is the explanatory variable, $y$ is the dependent variable, $b$ is the coefficient and $a$ is the intercept.\n", + "\n", + "To perform linear regression we'll be using the `linear_regression()` API from mlpack." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "5a642645", + "metadata": {}, + "outputs": [], + "source": [ + "# Create and train Linear Regression model.\n", + "output = mlpack.linear_regression(training=splitData[\"training\"],\n", + " training_responses=splitData[\"training_labels\"], \n", + " lambda_=0.5, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8b2e2bb4", + "metadata": {}, + "outputs": [], + "source": [ + "model = output[\"output_model\"]" + ] + }, + { + "cell_type": "markdown", + "id": "bf6ce883", + "metadata": {}, + "source": [ + "### Making Predictions on Test set" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "e41657ad", + "metadata": {}, + "outputs": [], + "source": [ + "# Predict the values of the test data.\n", + "predictions = mlpack.linear_regression(input_model=model, test=splitData[\"test\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d3734f1a", + "metadata": {}, + "outputs": [], + "source": [ + "yPreds = predictions[\"output_predictions\"].reshape(-1, 1)" + ] + }, + { + "cell_type": "markdown", + "id": "53843549", + "metadata": {}, + "source": [ + "### Model Evaluation\n", + "Test data is visualized with `splitData[\"test\"]` and `yPreds`; the coral points indicate the data points and the blue line indicates the regression line or best fit line." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "531b842d", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAEWCAYAAABbgYH9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXxU1fnH8c+XECAiiAtaSEBQEQVciUgVkYoaBBf0py22Lq22ttal1hYFqYqigsVaW1sX6r5UsBTBqggqKi4IgqiIiqBECSCLCCqChPD8/rgncWaYhACZmUzyvF+veTH33O25d8I8c86591yZGc4551xNa5DpAJxzztVNnmCcc86lhCcY55xzKeEJxjnnXEp4gnHOOZcSnmCcc86lhCcYVytJOkrSvEzHUZtIeknSL9O4v+GSLkvTvoZKeiS8byvpG0k527CdqyTdU/MRxu1jhqTOqdxHXeEJxsWRVCxpXfgPXv76R7rjMLNXzKxjuvebapI6S5os6UtJqyXNktQ303ElktQSOAe4O0z3krQp/D18LWmepF+kYt9m9pmZ7WhmZVuIsZekkoR1bzKzVCfhW4DrU7yPOsETjEvmpPAfvPx1cTp3LqlhOveXZv8DngP2AHYHLgW+SuUOt/F8/hx4xszWxZQtMbMdgebAlcC/JHWqof1lkyeBH0lqlelAajtPMK7aJN0paWzM9M2SXlCkl6SS0ESxMtSEfhazbGNJt0j6TNIySXdJygvzyte9UtLnwP2Jv04ltZb0X0krJC2UdGnMvKGSHpf0UPh1PVdSYcz8NpLGhXW/iK2RSTpP0gehRjFJ0p6VHPuzki5OKHtH0mnh+P8qabmkNZLeldQlyTZ2A9oD/zKzDeH1mpm9GubvLOmpEOeX4X1BJfHsLWlKOJ6Vkh6V1CJmfnE4n+8CayUNlPTfhG3cLum2ZNsHTgBeTjbDIuOBL4FOktpJMknnS/oMmBK2313S66Gm9o6kXjH7bi/p5fB5PQfsFjOvfHsNw/Quku6XtCScl/GSmgITgdYxNe3WimlqC+ueHP4eVitqYtw/4Rz9MXxeaySNkdSk/LMK53+1pFWSXpHUIBz/emAWcHwl586VMzN/+aviBRQDx1YybwfgI6Jft0cBK4GCMK8XsBG4FWgMHA2sBTqG+bcR/fLbBWhG9Et+eMK6N4d180JZSZjfgOg/9DVAI2Av4BOgKMwfCqwH+gI5wHDgjTAvB3gH+CvQFGgC9Ajz+gMLgP2BhsCfgNcrOfZzgNdipjsBq0O8RSG+FoDC9lol2YaA+cBTYd97JMzfFfi/cJ6bAf8BxsfMfwn4ZXi/D3Bc2H9LYCpwW8Ln+DbQJpzPVuHzaBHmNwSWA10rOd4VwGEx04mfx6lAKdARaAcY8FA4x3lAPvBF+EwahFi/AFqGbUzj+7+VnsDXwCNhXvn2Gobpp4ExwM5ALnB0YkwxcQ6N2c6+4ZiPC+tdET7vRjHnaAbQmujv8gPgN2HecOCusF4u0d+7Yvbzd+DWTP9/re2vjAfgr9r1Cv/pvglfnuWvX8XM7wasAj4Fzowp70WUJJrGlD0OXE30xboW2Dtm3g+BhTHrbgCaJGyv/AvtcOCzhDgHA/eH90OB52PmdQLWxexnRfmXVcI2JgLnx0w3AL4F9kyybLNwDHuG6RuB+8L7Y4gSb3egwRbObwHwD+BjYBNRYuhQybIHA1/GTL9ESDBJlu0PzE74HM9Lcry/Cu9PBN6vIs5SYL+Ez2NT+HtYRZS8BoR57YgSwl4xy18JPJywzUnAuUDbJH8r/yZJgiFKjJuAnZPEWPE3ElM2NGY7VwOPJ3y+i4FeMe
forJj5fwbuCu+vByYA+1Ryfio+f39V/vImMpdMfzNrEfP6V/kMM5tBVHsQUQKJ9aWZrY2Z/pTo12FLol/ls0KTw2rg2VBeboVFTQ/J7EnUFLI6Zv2riPoxyn0e8/5boEloYmkDfGpmGyvZ7t9itrkqHFd+4oJm9jXRL+kBoWgA8GiYN4UoafwTWCZplKTmyQ7EzErM7GIz2zvsfy3RL38k7SDpbkmfSvqKKPm0UJKrqSTtLmm0pMVh2UeIaWYKFiVMPwicFd6fBTycLMbgS6KkGmtJ+HvYxcwONrPRVexvT+CMhM+sB1HCaE3yv5Vk2gCrzOzLKmKtTOvY7ZrZphBj7Oeb+HezY3g/kqi2M1nSJ5IGJWy7GVGydVXwBOO2iqSLiJo1lhA1OcTaObSNl2sbllsJrAM6xyStnSzqMC5X1bDei4hqO7FJr5mZVefqq0VAWyXveF4E/Dphu3lm9nol23oMOFPSD4magV6sCN7s72bWFehM1DQzcEuBmdkioqRU3l/zB6Imp8PNrDlR0xFESS/RcKJzdmBY9qwkyyWe0/HAgaF/6ERCgqzEu+E4tkbs/hYR1WBiz21TMxsBLCX530oyi4BdYvuXKtlfMkuIEh0AkkSUsBZv8UDMvjazP5jZXsBJwOWSescssj9R06urgicYV22S9gVuIPoyOxu4QtLBCYtdJ6mRpKOIvsT+E345/gv4q6Tdw7byJRVVc9czgK9Cp3WepBxJXSQdVs11lwIjJDWV1ETSkWHeXcBghXsaJO0k6YwqtvUM0RfW9cCYcFxIOkzS4ZJyiWok64HNLrENnfjXSdpHUoPQ6X8e8EZYpBlRIl4taRfg2ipiaUZoypSUT/US2npgLFFz1Awz+2wLx3r0lrZZhUeAkyQVhc+riaILNwrM7FNgJt//rfQg+hJPFvNSoqa9O8L5y5VUnniXAbtK2qmSGB4H+knqHT6bPwDfAZX9gKgg6cTwOYnoKr+y8EJSY6Ar0dWArgqeYFwy/1P8fTBPhBrAI8DNZvaOmc0naqZ6OPyHg6i54UuiX46PEnWYfhjmXUnU5PBGaNJ5nujX+hZZdD/ESUR9EguJakT3AJV9sSRbdx/gM6AE+EmY9wTRhQWjQ0zvEV09Vdm2vgPGAccSfUmXa06UQL8kapL5guheiUQbiPoXnif60nqP6Avv52H+bUQ1o5VESefZKg7tOuBQYA1R0924KpaN9SBwAFU3j0HUbNdX4Uq/rRVqZ6cQ/Y2sIKqJDOT775yfEvWtrSJKpA9VsbmzifqEPiS6MOGysI8PiWqVn4RmuNYJMcwj+jF0O9E5PYnoEvwN1TiEDkSf0zdEFyTcYWYvhXknAy+Z2ZJqbKdek5k/cMxtP0WXoD5iZkkvq3W1g6S2RF/UPzCzKu+/kXQTsNzMKruUuV6SNJ3o4pD3Mh1LbVfXb4hyzgXhPo7LgdFbSi4AZnZV6qPKPmZ2eKZjyBaeYJyrB0KH+jKiJrw+GQ7H1RPeROaccy4lvJPfOedcSngTWbDbbrtZu3btMh2Gc85llVmzZq00s5bJ5nmCCdq1a8fMmTMzHYZzzmUVSZWNwuBNZM4551LDE4xzzrmU8ATjnHMuJbwPpgqlpaWUlJSwfn1lg/y6dGrSpAkFBQXk5uZmOhTnXDV4gqlCSUkJzZo1o127dkRj3rlMMTO++OILSkpKaN++fabDcc5VgzeRVWH9+vXsuuuunlxqAUnsuuuuXpt0Lot4gtkCTy61h38WzmUXTzDOOVdPffPdRm6dPI/lX6emZcATTC234447blZ211138dBDVT0+o+b16tWLtm3bEjt2Xf/+/ZPGV5Wf//znjB07druXcc5tnwdeW0iXayfx9ykLmFW8LU+k3jLv5M9Cv/nNb1K6fTPDzGjQIP73R4sWLXjttdfo0aMHq1evZunSpSmNwzlX81at3cChw75/GOfZ3ffkhANapWRfXoPJQkOHDuWWW6IHJvbq1Ysrr7ySbt26se+++/
LKK68AUFZWxsCBAznssMM48MADufvuuwH45ptv6N27N4ceeigHHHAAEyZMAKC4uJj999+f3/72txx66KEsWrRos/0OGDCA0aNHAzBu3DhOO+20inlmxsCBA+nSpQsHHHAAY8aMqSi/+OKL6dSpE/369WP58uUV68yaNYujjz6arl27UlRU5AnLuRS7dfK8uOQybfAxDOvfJWX78xpMNV33v7m8v2SLz2jaKp1aN+fakzpv93Y2btzIjBkzeOaZZ7juuut4/vnnuffee9lpp5148803+e677zjyyCM5/vjjadOmDU888QTNmzdn5cqVdO/enZNPPhmAefPmcf/993PHHXck3U/v3r351a9+RVlZGaNHj2bUqFEMGzYMiBLO22+/zTvvvMPKlSs57LDD6NmzJ9OmTWPevHnMmTOHZcuW0alTJ8477zxKS0u55JJLmDBhAi1btmTMmDEMGTKE++67b7vPh3Mu3pLV6zhixJSK6cuO7cBlx+6b8v16gqkDymsSXbt2pbi4GIDJkyfz7rvvVvRlrFmzhvnz51NQUMBVV13F1KlTadCgAYsXL2bZsmUA7LnnnnTv3r3S/eTk5NCjRw/GjBnDunXriB19+tVXX+XMM88kJyeHPfbYg6OPPpo333yTqVOnVpS3bt2aY445BoiS2Xvvvcdxxx0HRDWuVq1SU013rj676ok5/Hv6ZxXTs68+jp2bNkrLvj3BVFNN1DRSpXHjxkCUADZu3AhETVO33347RUVFccs+8MADrFixglmzZpGbm0u7du0q7i1p2rTpFvc1YMAATj31VIYOHRpXXtWD65JdXmxmdO7cmWnTpm1xn865rbdg+dcce+vUiulhp3Tm7B+2i1tm/OzFjJw0jyWr19G6RR4DizrS/5D8GovB+2DqqKKiIu68805KS0sB+Oijj1i7di1r1qxh9913Jzc3lxdffJFPP610pO2kjjrqKAYPHsyZZ54ZV96zZ0/GjBlDWVkZK1asYOrUqXTr1o2ePXsyevRoysrKWLp0KS+++CIAHTt2ZMWKFRUJprS0lLlz59bAkTtXv5kZv3poZkVyyWkg5l5XlDS5DB43h8Wr12HA4tXrGDxuDuNnL66xWLwGU8t9++23FBQUVExffvnl1Vrvl7/8JcXFxRx66KGYGS1btmT8+PH87Gc/46STTqKwsJCDDz6Y/fbbb6vikcQf//jHzcpPPfVUpk2bxkEHHYQk/vznP/ODH/yAU089lSlTpnDAAQew7777cvTRRwPQqFEjxo4dy6WXXsqaNWvYuHEjl112GZ07196aonO13TuLVnPKP1+rmL79zEM46aDWSZcdOWke60rL4srWlZYxctK8GqvFqKqmjfqksLDQEh849sEHH7D//vtnKCKXjH8mzm1u0ybj1Dtf551FqwHYo3ljXrniGBo1rLyRqv2gp0n27S9g4Yh+1d63pFlmVphsntdgnHMui706fyVn3Tu9YvrB87px9L5Jn2Acp3WLPBavXpe0vKakrA9G0n2Slkt6L6ZspKQPJb0r6QlJLWLmDZa0QNI8SUUx5V0lzQnz/q7QYyypsaQxoXy6pHYx65wraX54nZuqY3TOuUwpLdvEkSOmVCSXAwt24uOb+lYruQAMLOpIXm5OXFlebg4DizrWWIyp7OR/AOiTUPYc0MXMDgQ+AgYDSOoEDAA6h3XukFR+5HcCFwAdwqt8m+cDX5rZPsBfgZvDtnYBrgUOB7oB10raeVsPwpsQaw//LJyLPP3uUjoMmVhRAxn32yN48uIe5DSo/oCw/Q/JZ/hpB5DfIg8B+S3yGH7aATV6FVnKmsjMbGpsrSKUTY6ZfAM4Pbw/BRhtZt8BCyUtALpJKgaam9k0AEkPAf2BiWGdoWH9scA/Qu2mCHjOzFaFdZ4jSkqPbe0xNGnShC+++MKH7K8Fyp8H06RJk0yH4lzGfLthIwddN5nSsujHVu/9dueecwu3+fup/yH5NZpQEmWyD+Y8YEx4n0+UcM
qVhLLS8D6xvHydRQBmtlHSGmDX2PIk62yVgoICSkpKWLFixbas7mpY+RMtnauPHp3+KUOeqOhxYPLve7LvHs0yGNGWZSTBSBoCbAQeLS9KsphVUb6t6yTGcQFR8xtt27bdbH5ubq4/PdE5l1HLvlrP4Te9UDH9k8I23Hz6gRmMqPrSnmBCp/uJQG/7vlG9BGgTs1gBsCSUFyQpj12nRFJDYCdgVSjvlbDOS8liMbNRwCiILlPe1mNyzrlUOOe+GUz96PsWlFev/BEFO++QwYi2Tlrv5JfUB7gSONnMvo2Z9SQwIFwZ1p6oM3+GmS0FvpbUPfSvnANMiFmn/Aqx04EpIWFNAo6XtHPo3D8+lDnnXFaY/skXtBv0dEVyObhNC4pH9Muq5AIprMFIeoyoJrGbpBKiK7sGA42B50Kn1Btm9hszmyvpceB9oqazi8ys/BbTC4muSMsj6tyfGMrvBR4OFwSsIroKDTNbJWkY8GZY7vryDn/nnKvt2g16Om766hP35/wee2Uomu3jd/IHye7kd865dBk/ezGXjXl7s/K83Jwav3y4Jvmd/M45V0uZGe0HP1Pp/JoeHyydPME451yG/PPFBYycNG+Lyy1JMqRLNvDh+p1zLs3KNhntBj0dl1xmX30c+ZWMA1aT44OlkycY55xLo4H/eYe9r/q+SeyIvXeleEQ/dm7aKC3jg6WTN5E551warNtQxv7XPBtX9uGwPjSJSSjl/SypfMpkOnmCcc65FDvtjtd467PVFdM/O7wtN556QNJlUz0+WDp5gnHOuRRZ+c13FN7wfFzZJzf1pcFWjHqczTzBOOdcCnS5dhLffLexYvqqvvtxQc+9MxhR+nmCcc65GrRw5Vp+dMtLcWXFW/EI4kTjZy/O2j4ZTzDOOVdDEod5+cdPD+HEA1tv8/bGz17M4HFzWFcajZy1ePU6Bo+bA5AVScYvU3bOue0069NVmyWX4hH9tiu5QHQ1WXlyKVd+Z3828BqMc85th8TE8vivf0i39rvUyLYru4M/W+7s9wTjnHPbYOKcpVz46FtxZdvT15JM6xZ5LE6STLLlzn5PMM45t5USay0v/OFo9m65Y43vZ2BRx7g+GMiuO/s9wTjnXDXd++pChj31fsV044YNmHfDCSnbX7bf2e8JxjnntmDTJmOvq+KH1H9zyLG0bNY45fvO5jv7PcE451wVrp3wHg9O+7Ri+qCCnZhwcY8MRpQ9PME451wS320so+Of4genfP/6InZo5F+b1eVnyjnnEpx973Remb+yYvq0Q/K59ScHZzCi7OQJxjnngtXfbuDg65+LK1tw4wk0zPF70reFJxjnnAO63/QCn3+1vmL698fuy++O7ZDBiLKfJxjnXL22aNW3HPXnF+PKFg7vi1Q/htRPJU8wzrl6K/GGyVvOOIjTuxZkKJq6xxOMc67emVOyhpP+8WpcWU0P8+I8wTjn6pnEWssj5x9Ojw67ZSiaus0TjHOuXpjy4TLOe2BmXJnXWlLLE4xzrs5LrLVM/N1R7N+qeYaiqT88wTjn6qxHp3/KkCfeiyvzWkv6eIJxztU5Zkb7wfGDU74xuDc/2KlJhiKqnzzBOOfqlBETP+Sulz+umN6rZVOm/KFX5gKqxzzBOOfqhNKyTXQYMjGubM7Q42nWJDdDETlPMM65rPebh2fx7NzPK6aLOu/B3WcXZjAiB55gnHNZ7NE3PmXI+PhO/Pk3nkCuD05ZK3iCcc5lpcRLjyF6Xv3T7y7N2idA1jWe5p1zWeWTFd8kTS4A60rLGDlpXpojcpVJWYKRdJ+k5ZLeiynbRdJzkuaHf3eOmTdY0gJJ8yQVxZR3lTQnzPu7whCnkhpLGhPKp0tqF7POuWEf8yWdm6pjdM6lV7tBT3PMX16ucpklq9elKRq3JamswTwA9EkoGwS8YGYdgBfCNJI6AQOAzmGdOyTlhHXuBC4AOoRX+TbPB740s32AvwI3h23tAlwLHA50A66NTWTOuewz/ZMvNq
u15LfIS7ps60rKXfqlLMGY2VRgVULxKcCD4f2DQP+Y8tFm9p2ZLQQWAN0ktQKam9k0MzPgoYR1yrc1FugdajdFwHNmtsrMvgSeY/NE55zLEu0GPc1PRr1RMT2sfxeKR/RjYFFH8nJz4pbNy81hYFHHdIfoKpHuTv49zGwpgJktlbR7KM8H3ohZriSUlYb3ieXl6ywK29ooaQ2wa2x5knXiSLqAqHZE27Ztt/2onHM1btxbJVz++DtxZbHDvJR35I+cNI8lq9fRukUeA4s6egd/LVJbriJL9ug4q6J8W9eJLzQbBYwCKCwsTLqMcy79EpvDHvjFYfTquPtmy/U/JN8TSi2W7gSzTFKrUHtpBSwP5SVAm5jlCoAlobwgSXnsOiWSGgI7ETXJlQC9EtZ5qWYPwzmXCn+ZPI/bpyyIK/PBKbNXui9TfhIov6rrXGBCTPmAcGVYe6LO/BmhOe1rSd1D/8o5CeuUb+t0YErop5kEHC9p59C5f3woc87VUmZGu0FPxyWXZy87ypNLlktZDUbSY0Q1id0klRBd2TUCeFzS+cBnwBkAZjZX0uPA+8BG4CIzKwubupDoirQ8YGJ4AdwLPCxpAVHNZUDY1ipJw4A3w3LXm1nixQbOuVrigodmMvn9ZXFlnljqBkU/+l1hYaHNnDlzyws652pEssEpZ1zVm92b+5D62UTSLDNLOvBbbenkd87VI0eOmMLihBsivdZS93iCcc6lzVfrSzlw6OS4sg+u70Neo5xK1nDZzBOMcy4tEi897ty6OU9felSGonHp4AnGOZdSJV9+S4+bX4wr++SmvjRokOyWNVeXeIJxzqVMYq3l9K4F3HLGQRmKxqWbJxjnXI17e9Fq+v/ztbgy78SvfzzBOOdqVGKt5aq++3FBz70zFI3LJE8wzrkaMXHOUi589K24Mq+11G+eYJxz2y2x1nLXWV3p0+UHGYrG1RaeYJxz2+zOlz7m5mc/jCvzWosr5wnGObdNEmstEy46koPatMhQNK428gTjnNsqlz/+NuPeWhxX5rUWl0yVCUZSAzPbFDP9M6AZ8JCZfZvq4JxztUfZJmPvq56JK3tt0DHkt8jLUESutttSDeZpSZeb2QeShgA9gU+A0cDJKY/OOVcr9LltKh9+/nVcmdda3JZUmmAkHU304K+WknYHzgauAr4A/iWpJ1BsZp+lJVLnXNp9u2Ejna6Jf17fnKHH06xJboYictlkSzWYBkBzoClQBqwkeub9+jDfBxNyro5K7MRvu8sOTL3iRxmKxmWjShOMmb0s6RHgZmBHYKiZTZW0K7DCzKamK0jnXPos+2o9h9/0QlzZghtPoGFOup+w7rJdlTUYM7tG0r+BjWZW/rDsBsAFKY/MOZd2ibWWvgf8gDt+1jVD0bhst8XLlM3sw4TpFcCKlEXknEu795d8Rd+/vxJXtnB4XyRvBXfbzu+Dca6eS6y1XNq7A5cft2+GonF1iScY5+qpFz9czi8eeDOuzC89djXJE4xz9VBireW2nxxM/0PyMxSNq6u2mGAkHQkMBfYMywswM9srtaE552rag68Xc+2Tc+PKvNbiUqU6NZh7gd8Ds4juhXHOZaHEWsvjv/4h3drvkqFoXH1QnQSzxswmpjwS51xKXDPhPR6a9mlcmddaXDpUJ8G8KGkkMA74rrzQzN6qfBXnXKaZGe0Hxw9O+dIfe9Fut6YZisjVN9VJMIeHfwtjygw4pubDcc7VhB/fNY0ZxaviyrzW4tKtOjda+uBDzmWJ9aVl7Hf1s3Flb19zHC12aJShiFx9VtVoymeZ2SOSLk8238xuTV1Yzrmttd/VE1lfWvH4JnbKy+Wda4/PYESuvquqBlPeUNssHYE457bNF998R9cbno8r++iGE2jU0AendJlV1WjKd4d/r0tfOM65rZF46fFRHXbj4fMPr2Rp59LL7+R3LgstXbOOHw6fElfmg1O62sYTjHNZJrHWsmPjhtzQv4snF1frbLGRVlJOOgJxzl
Xtg6VfbZZcAL75biODx81h/OzFGYjKucpVpwazQNJY4H4zez/VATnnNpcsscRaV1rGyEnzfMBKV6tU5zKTA4GPgHskvSHpAknNt2enkn4vaa6k9yQ9JqmJpF0kPSdpfvh355jlB0taIGmepKKY8q6S5oR5f1doI5DUWNKYUD5dUrvtide5THll/orNkktlDWFLVq9LfUDObYUtJhgz+9rM/mVmRwBXANcCSyU9KGmfrd2hpHzgUqDQzLoAOcAAYBDwgpl1AF4I00jqFOZ3BvoAd8Q0291J9PjmDuHVJ5SfD3xpZvsAfwVu3to4ncu0doOe5ux7Z1RMP3VJD4pH9KN1i7yky1dW7lymVKsPRtLJkp4A/gb8BdgL+B/wTJUrV64hkCepIbADsAQ4BXgwzH8Q6B/enwKMNrPvzGwhsADoJqkV0NzMppmZAQ8lrFO+rbFA7/LajXO13X9mLtqs1lI8oh9d8ncCYGBRR/Jy47tG83JzGFjUMW0xOlcd1emDmQ+8CIw0s9djysdK6rm1OzSzxZJuAT4D1gGTzWyypD3MbGlYZqmk3cMq+cAbMZsoCWWl4X1iefk6i8K2NkpaA+wKrIyNRdIFRDUg2rZtu7WH4lyNSjY45atX/oiCnXeIKyvvZxk5aR5LVq+jdYs8BhZ19P4XV+tUmWBCU9QDZnZ9svlmdunW7jD0rZwCtAdWA/+RdFZVqyTbdRXlVa0TX2A2ChgFUFhYuNl859Llr899xN9emF8xnd8ij9cGVT6ebP9D8j2huFqvygRjZmWSfgQkTTDb6FhgoZmtAJA0DjgCWCapVai9tAKWh+VLgDYx6xcQNamVhPeJ5bHrlIRmuJ2A+KFlnasFNpZtYp8h8Y9beuea49lph9wMReRczanOVWSvS/qHpKMkHVr+2o59fgZ0l7RD6BfpDXwAPAmcG5Y5F5gQ3j8JDAhXhrUn6syfEZrTvpbUPWznnIR1yrd1OjAl9NM4V2v8fszbccmlV8eWFI/o58nF1RnV6YM5IvwbW4vZ5ufBmNn0cF/NW8BGYDZRM9WOwOOSzidKQmeE5edKehx4Pyx/kZmVP7r5QuABIA+YGF4QPeb5YUkLiGouA7YlVudS4dsNG+l0zaS4snk39KFxQ7+n2dUt8h/2kcLCQps5c2amw3B13Em3v8qcxWsqpn9+RDuGntw5gxE5t30kzTKzwmTzqjUWmaR+RPehNCkvq6zj3zm3ueVfr6fbjS/ElX1yU18aNPCr513dtcUEI+kuomN70ycAABT9SURBVHtVfgTcQ9SnMaPKlZxzFfYdMpENZd8/COzakzrxiyPbZzAi59KjWn0wZnagpHfN7DpJfwHGpTow57LdguXfcOytL8eVFY/ol6FonEu/6iSY8gGOvpXUGviC6B4W51wlEu/Ev+usrvTp8oMMReNcZlQnwTwlqQUwkujKLyNqKnPOJZixcBU/vntaXJnXWlx9tcUEY2bDwtv/SnoKaGJma6pax7n6KLHW8t8Lj6DrnjtXsrRzdV+lCUbSaVXMw8y8H8Y54H/vLOGSx2bHlXmtxbmqazAnVTHP8I5+5zartbz0x160261phqJxrnapNMGY2S/SGYhz2eTulz9m+MQPK6abN2nIu0OLqljDufrHb7R0bits2mTsdVX8kPqz/nQsu+7YOEMROVd7+Y2WzlXT4HFzeGzGZxXTh7Xbmf/85ogq1nCufvMbLZ3bgvWlZex39bNxZR8O60OTXB+c0rmq+I2WzlXhJ3dPY/rC7x8l9JPCNtx8+oEZjMi57LGtN1r+K6VROZdhq9Zu4NBhz8WVfXxTX3J8cErnqs1vtHQuwaHDnmPV2g0V01f06chve+2TwYicy05V3Wh5GLDIzD4P0+cA/wd8KmmomfkjiF2d8ukXazl65EtxZX7DpHPbrqoazN3AsQCSegIjgEuAg4meQHl6yqNzLk0Sb5j824CDOeXg/AxF41zdUFWCyYmppfwEGGVm/yVqKns79aE5l3
qvL1jJT++ZHle2pVrL+NmLGTlpHktWr6N1izwGFnWk/yGejJxLVGWCkdTQzDYCvYELqrmec1khsdZy388LOWa/PapcZ/zsxQweN4d1pWUALF69jsHj5gB4knEuQVWJ4jHgZUkriS5VfgVA0j6Ad/K7rDV2Vgl//M87cWXV7WsZOWleRXIpt660jJGT5nmCcS5BVWOR3SjpBaAVMNnMLMxqQNQX41zWSay1PHVJD7rk71Tt9ZesXrdV5c7VZ1U2dZnZG0nKPkpdOM6lxq2T5/H3KQviyrblCrHWLfJYnCSZtG6Rt82xOVdXeV+Kq9OSDU752qBjyN/GhDCwqGNcHwxAXm4OA4s6blecztVFnmBcnXXRo2/x9JylFdONchrw0Y0nbNc2y/tZ/Coy57bME4yrc5INTvnu0ONp3iS3Rrbf/5B8TyjOVYMnGFen9P7LS3y8Ym3F9MFtWjD+oiMzGJFz9ZcnGFcnfLl2A4ckDE45/8YTyM1pkKGInHOeYFzWS7z0+MeFBfz59IMyFI1zrpwnGJe1kg1OuXB4XyQfUt+52sATjMtKibWWQSfsx2+O3jtD0TjnkvEE47LKrE9X8X93Tosr8yH1naudPMG4rJFYa/nnTw+l34GtMhSNc25LPMG4Wu+pd5dw8b9nx5XVVK3Fh953LnU8wbhaLbHW8t8Lf0jXPXepkW370PvOpVZGbhKQ1ELSWEkfSvpA0g8l7SLpOUnzw787xyw/WNICSfMkFcWUd5U0J8z7u8LlQ5IaSxoTyqdLapf+o3Tb486XPt4suRSP6FdjyQWqHnrfObf9MlWD+RvwrJmdLqkRsANwFfCCmY2QNAgYBFwpqRMwAOgMtAael7SvmZUBdxI9CO0N4BmgDzAROB/40sz2kTQAuJnoqZyuljMz2g+OH5zypT/2ot1uTWt8Xz70vnOplfYajKTmQE/gXgAz22Bmq4FTgAfDYg8C/cP7U4DRZvadmS0EFgDdJLUCmpvZtPCsmocS1inf1ligd3ntxtVeV4x9Z7PkUjyiX0qSC1Q+xL4Pve9czchEDWYvYAVwv6SDgFnA74A9zGwpgJktlbR7WD6fqIZSriSUlYb3ieXl6ywK29ooaQ2wK7AyNhBJFxAeBd22bduaOj63lUrLNtFhyMS4stlXH8fOTRuldL8+9L5zqZWJBNMQOBS4xMymS/obUXNYZZLVPKyK8qrWiS8wGwWMAigsLNxsvku9/v98jbcXra6Y3qtlU6b8oVd69u1D7zuXUplIMCVAiZlND9NjiRLMMkmtQu2lFbA8Zvk2MesXAEtCeUGS8th1SiQ1BHYCVqXiYNy2+Wp9KQcOnRxX9uGwPjTJzUlrHD70vnOpk/Y+GDP7HFgkqbwdojfwPvAkcG4oOxeYEN4/CQwIV4a1BzoAM0Jz2teSuof+lXMS1inf1unAlNBP42qBDkOeiUsu/Q5oRfGIfmlPLs651MrUVWSXAI+GK8g+AX5BlOwel3Q+8BlwBoCZzZX0OFES2ghcFK4gA7gQeADII7p6rLwh/17gYUkLiGouA9JxUK5qi1ev48gRU+LKanJwSr9p0rnaRf7DPlJYWGgzZ87MdBh1VuI9LZf27sDlx+1bY9tPvGkSog774acd4EnGuRSSNMvMCpPN8zv5XUq9t3gNJ97+alxZKganrOqmSU8wzmWGJxiXMom1llvOOIjTuxZUsvS2KW8WW+w3TTpX63iCcTXu+feX8cuH4psbU1FrSdYslshvmnQuczzBuBqVWGv59y8P54h9dkvJvpI1i8XymyadyyxPMK5GPPh6Mdc+OTeuLNUPAquq+SvfryJzLuM8wbjtllhree73PemwR7OU77d1i7ykfS/5LfJ4bdAxKd+/c65qGRmu39UNQ5+cm3RI/XQkF4jGEstLuDnTm8Wcqz28BuO2WtkmY++r4kc9fnPIsbRs1jitcfhYYs7Vbp5g3FY5657pvLrg+0GpWzZrzJtDjs1YPD6WmHO1lycYVy3fbt
hIp2smxZV9cH0f8hr5+GHOueQ8wbgt6jrsOb5Yu6Fi+uh9W/Lged0yGJFzLht4gnGVWv7Verrd9EJc2cc39SWngT8c1Dm3ZZ5gXFKJV4ed36M9V5/YKUPROOeykScYF2fe519TdNvUuLJU3zDpnKubPMG4Com1lmH9u3B29z0zFI1zLtt5gnG8On8lZ907Pa7May3Oue3lCaaeS6y13P+Lw/hRx90zFI1zri7xBFNPPT5zEVeMfTeuzGstzrma5AmmHkqstTx1SQ+65O+UoWicc3WVJ5h6ZOSkD/nnix/HlXmtxTmXKp5g6oFNm4y9EganfH3QMf60R+dcSnmCqeN+/fBMJs1dVjGdl5vDB8P6ZDAi51x94QmmjlpfWsZ+Vz8bVzZn6PE0a5KboYicc/WNJ5g6qNfIFyn+4tuK6a577sx/LzwigxE55+ojTzB1yKq1Gzh02HNxZQtuPIGGOf7gUudc+nmCqSMSLz0+s1sbhp92YIaicc45TzBZb+HKtfzolpfiy4b3RfIh9Z1zmeUJJkuNn72Yy8a8HVc2pO/+/KrnXhmKyDnn4nmCyUJ3v/wxwyd+GFeWl5tDy2aNMxSRc85tzhNMlknsaym3rrSMkZPm0f+Q/DRH5JxzyfnlRVliZvGqSpNLuSWr16UpGuec2zKvwWSBxMSye7PGLP/6u82W86FfnHO1iddgarFn3/s8Lrnsu8eOFI/ox1V99ycvNydu2bzcHAYWdUx3iM45VymvwdRCZkb7wfGDU7455NiKTvzyfpaRk+axZPU6WrfIY2BRR+9/cc7VKp5gapmHphVzzYS5FdPHd9qDUecUbrZc/0PyPaE452q1jCUYSTnATGCxmZ0oaRdgDNAOKAZ+bGZfhmUHA+cDZcClZjYplHcFHgDygGeA35mZSWoMPAR0Bb4AfmJmxWk7uG2wsWwT+wyZGFc297oimjb23wDOueyUyT6Y3wEfxEwPAl4wsw7AC2EaSZ2AAUBnoA9wR0hOAHcCFwAdwqt8HPrzgS/NbB/gr8DNqTqI8bMXc+SIKbQf9DRHjpjC+NmLt3obwyd+EJdcftmjPcUj+nlycc5ltYx8g0kqAPoBNwKXh+JTgF7h/YPAS8CVoXy0mX0HLJS0AOgmqRhobmbTwjYfAvoDE8M6Q8O2xgL/kCQzs5o8jvGzFzN43BzWlZYBsHj1OgaPmwNQrearbzdspNM1k+LK5t94Ark+OKVzrg7I1DfZbcAVwKaYsj3MbClA+Hf3UJ4PLIpZriSU5Yf3ieVx65jZRmANsGtiEJIukDRT0swVK1Zs9UGMnDSvIrmUK7/hcUv+MnleXHK55sROFI/o58nFOVdnpL0GI+lEYLmZzZLUqzqrJCmzKsqrWie+wGwUMAqgsLBwq2s3ld3YWNUNj99u2MhB102mtOz73fnglM65uigTP5ePBE4OTVyjgWMkPQIsk9QKIPy7PCxfArSJWb8AWBLKC5KUx60jqSGwE7Cqpg+kshsbKyt/dPqndLpmUlxyyW+Rx4S3lyRd3jnnslnaE4yZDTazAjNrR9R5P8XMzgKeBM4Ni50LTAjvnwQGSGosqT1RZ/6M0Iz2taTuin7+n5OwTvm2Tg/7qNH+F4CBRR2rdcPj6m830G7Q0wx54j0AcmJqK+X9NttycYBzztVmtanBfwRwnKT5wHFhGjObCzwOvA88C1xkZuUdHxcC9wALgI+JOvgB7gV2DRcEXE64Iq2m9T8kn+GnHUB+izxEVBsZftoBcR38/5gyn4Ov//4pk3s0a0xZQq6rbr+Nc85lE6Xgh31WKiwstJkzZ9bY9j5fs57uw1+omP5tr725os9+tB/09OadQUSdRgtH9Kux/TvnXDpImmVmm98Njt/JnxJDn5zLA68XV0zP/NOx7LZjNMxL6xZ5LE5yEYAPVOmcq2tqUxNZ1lu4ci3tBj1dkVz+1G9/ikf0q0guUP1+G+ecy3Zeg6kBZsYlj83mqXeXVpTNGXo8zZrkbrasD1
TpnKsvPMHUgHPum8Er81cC8JczDuL/uhZUubwPVOmcqw88wdSAAYe1ZZMZ9557GE0Smr+cc66+8gRTA/od2Ip+B7bKdBjOOVereCe/c865lPAE45xzLiU8wTjnnEsJTzDOOedSwhOMc865lPAE45xzLiU8wTjnnEsJTzDOOedSwofrDyStAD7NdBzbYDdgZaaDqAX8PET8PET8PETScR72NLOWyWZ4gslykmZW9iyG+sTPQ8TPQ8TPQyTT58GbyJxzzqWEJxjnnHMp4Qkm+43KdAC1hJ+HiJ+HiJ+HSEbPg/fBOOecSwmvwTjnnEsJTzDOOedSwhNMFpLURtKLkj6QNFfS7zIdUyZJypE0W9JTmY4lUyS1kDRW0ofh7+KHmY4pEyT9PvyfeE/SY5KaZDqmdJF0n6Tlkt6LKdtF0nOS5od/d05nTJ5gstNG4A9mtj/QHbhIUqcMx5RJvwM+yHQQGfY34Fkz2w84iHp4PiTlA5cChWbWBcgBBmQ2qrR6AOiTUDYIeMHMOgAvhOm08QSThcxsqZm9Fd5/TfRlkp/ZqDJDUgHQD7gn07FkiqTmQE/gXgAz22BmqzMbVcY0BPIkNQR2AJZkOJ60MbOpwKqE4lOAB8P7B4H+6YzJE0yWk9QOOASYntlIMuY24ApgU6YDyaC9gBXA/aGp8B5JTTMdVLqZ2WLgFuAzYCmwxswmZzaqjNvDzJZC9MMU2D2dO/cEk8Uk7Qj8F7jMzL7KdDzpJulEYLmZzcp0LBnWEDgUuNPMDgHWkuamkNog9C+cArQHWgNNJZ2V2ajqN08wWUpSLlFyedTMxmU6ngw5EjhZUjEwGjhG0iOZDSkjSoASMyuvxY4lSjj1zbHAQjNbYWalwDjgiAzHlGnLJLUCCP8uT+fOPcFkIUkiam//wMxuzXQ8mWJmg82swMzaEXXmTjGzeveL1cw+BxZJ6hiKegPvZzCkTPkM6C5ph/B/pDf18GKHBE8C54b35wIT0rnzhuncmasxRwJnA3MkvR3KrjKzZzIYk8usS4BHJTUCPgF+keF40s7MpksaC7xFdKXlbOrRkDGSHgN6AbtJKgGuBUYAj0s6nygBn5HWmHyoGOecc6ngTWTOOedSwhOMc865lPAE45xzLiU8wTjnnEsJTzDOOedSwhOMq9MUeVXSCTFlP5b0bIZjelzSu5IuTZh3g6TFkt6OeTVLcTyTUr0PVz/5ZcquzpPUBfgP0ZhtOcDbQB8z+3g7ttnQzDZu47oFwMtmtneSeTcAK83stm2NbSviENF3QH0ex82lkNdgXJ1nZu8B/wOuJLr57CEz+1jSuZJmhFrCHZIaAEgaJWlmeK7INeXbkVQi6WpJrwGnhmePvC/pnWRD1EjKk/SgpDmS3pLUM8yaDLQO+63WUCaSrpA0Krw/OGwzL9R4HgzPB5ov6byYdQaF43u3/Dgk7ROelXIX0Q2JrcJxtQjzNzsnkhpKWi1pRDjWaZJ2D8v/QNKEsI93JB1e2Xa26kNzdYOZ+ctfdf4FNAXmAXOAxkAXYDzQMMwfBfw0vN8l/NsQeAXoFKZLgMtjtrkUaBTet0iyzyuBf4X3nYFPgUbAPsDblcR5A7CYqJb1NvB8KG8AvEY0mONsoHvM8m8BTYhGyi0B9gD6AncACus+SzQu1z5EI08fFrPPEqBFZecknAcDTgjltwKDwvv/AhfHnK/mVZ1bf9Wvlw8V4+oFM1sraQzwjZl9J+lY4DBgZtRSRB6wKCx+ZhhaoyHRqLyd+H5srzExm50LPCJpAtEXaqIewMiw/7mSlhB9wW/YQrgjLaGJzMw2Sfo5UdL5h5m9ETN7vJmtB9ZLmhqO61jgBKJkBLAjsC/RYIcfm9mbSfZb1TlZZ2YTw/tZwFHhfS/CQ70sajL8agvn1tUjnmBcfbKJ758bI+A+M7s6dgFJHYiekNnNzFaHpq/Yx+6ujXlfBBxNVKv4k6QuZlYWu7kajr8D8A1R0ouV2JFqYd83mNm9sTMk7U
P8McTNJvk5aUh8Uiwj/rsjcf9Jt+PqH28XdfXV88CPJe0GIGlXSW2Jmni+Jvol3oooiWxGUg5QYGZTgIFAS6InKMaaCvwsLL8/0ApYsC3Bhj6SvxINdJovKfbJhP0lNQ7HchQwE5gEnK/w4DFJBeXHWoXKzklVXgR+E5bPUfR0zW3ZjquDvAbj6iUzmyPpOuD50AFdSvRFOZOoOew9olGJX6tkEw2Bf4fLexsAN1v0+OpYtwN3S5oTtn+OmW0IzUZVGRiaw8qdBNwI/M3MFkj6RYj71TD/TWAi0Aa41syWAc9I2g94I+zva6L+lEpVcU6qeuzwxcC/JP2aaATjX5vZjEq289mWDtzVLX6ZsnNZLJ2XNTu3tbyJzDnnXEp4DcY551xKeA3GOedcSniCcc45lxKeYJxzzqWEJxjnnHMp4QnGOedcSvw/KM+fji8RHCgAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Plot the linear model.\n", + "plt.scatter(splitData[\"test\"], splitData[\"test_labels\"])\n", + "plt.xlabel(\"Years of Experience\")\n", + "plt.ylabel(\"Salary in $\")\n", + "plt.title(\"Experience vs Salary (Predictions)\")\n", + "plt.plot(splitData[\"test\"], yPreds)\n", + "plt.legend([\"Linear Model\"])\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "twenty-qualification", + "metadata": {}, + "source": [ + "## Evaluation Metrics for Regression model\n", + "\n", + "In the previous cell, we visualized our model's performance by plotting the best-fit line. Now we will use various evaluation metrics to understand how well our model has performed.\n", + "\n", + "* Mean Absolute Error (MAE) is the average of the absolute differences between actual and predicted values, without considering the direction.\n", + "$$ MAE = \\frac{\\sum_{i=1}^n\\lvert y_{i} - \\hat{y_{i}}\\rvert} {n} $$\n", + "* Mean Squared Error (MSE) is calculated as the mean or average of the squared differences between predicted and expected target values in a dataset; a lower value is better.\n", + "$$ MSE = \\frac {1}{n} \\sum_{i=1}^n (y_{i} - \\hat{y_{i}})^2 $$\n", + "* Root Mean Squared Error (RMSE) is the square root of the MSE; it indicates the spread of the residual errors. 
It is always non-negative, and a lower value indicates better performance.\n", + "$$ RMSE = \\sqrt{\\frac {1}{n} \\sum_{i=1}^n (y_{i} - \\hat{y_{i}})^2} $$" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "c26ee546", + "metadata": {}, + "outputs": [], + "source": [ + "# Utility functions for evaluation metrics.\n", + "\n", + "def mae(y_true, y_preds):\n", + " return np.mean(np.abs(y_preds - y_true))\n", + "\n", + "def mse(y_true, y_preds):\n", + " return np.mean(np.power(y_preds - y_true, 2))" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "8ad80db1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---- Evaluation Metrics ----\n", + "Mean Absoulte Error: 4136.06\n", + "Mean Squared Error: 24922668.74\n", + "Root Mean Squared Error: 4992.26\n" + ] + } + ], + "source": [ + "print(\"---- Evaluation Metrics ----\")\n", + "print(f\"Mean Absoulte Error: {mae(splitData['test_labels'], yPreds):.2f}\")\n", + "print(f\"Mean Squared Error: {mse(splitData['test_labels'], yPreds):.2f}\")\n", + "print(f\"Root Mean Squared Error: {np.sqrt(mse(splitData['test_labels'], yPreds)):.2f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9f0899be-5069-432a-a3d3-b5c7f33c8417", + "metadata": {}, + "source": [ + "From the above metrics, we can see that our model's MAE is ~4K, which is relatively small compared to the average salary of $76003. From this, we can conclude that our model is a reasonably good fit." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tools/download_data_set.py b/tools/download_data_set.py index 51a22ec7..01fcb57b 100755 --- a/tools/download_data_set.py +++ b/tools/download_data_set.py @@ -133,13 +133,20 @@ def iris_dataset(): tar.extractall() tar.close() clean() + +def salary_dataset(): + print("Downloading salary dataset...") + salary = requests.get("http://mlpack.org/datasets/Salary_Data.csv") + progress_bar("Salary_Data.csv", salary) def all_datasets(): mnist_dataset() electricity_consumption_dataset() stock_exchange_dataset() iris_dataset() + salary_dataset() body_fat_dataset() + if __name__ == '__main__': @@ -161,6 +168,7 @@ def all_datasets(): stock : will download stock_exchange dataset iris : will downlaod the iris dataset bodyFat : will download the bodyFat dataset + salary: will download the salary dataset all : will download all datasets for all examples ''')) @@ -187,6 +195,9 @@ def all_datasets(): elif args.dataset_name == "bodyFat": create_dataset_dir() body_fat_dataset() + elif args.dataset_name == "salary": + create_dataset_dir() + salary_dataset() elif args.dataset_name == "all": create_dataset_dir() all_datasets()