-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #149 from Davidportlouis/simple_lr
Added Linear Regression for predicting salary.
- Loading branch information
Showing
3 changed files
with
935 additions
and
0 deletions.
There are no files selected for viewing
396 changes: 396 additions & 0 deletions
396
salary_prediction_with_linear_regression/salary-prediction-linear-regression-cpp.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,396 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "94323844", | ||
"metadata": {}, | ||
"source": [ | ||
"## Predicting Salary using Linear Regression\n", | ||
"\n", | ||
"### Objective\n", | ||
"* We have to predict the salary of an employee given how many years of experience they have.\n", | ||
"\n", | ||
"### Dataset\n", | ||
"* Salary_Data.csv has 2 columns — “Years of Experience” (feature) and “Salary” (target) for 30 employees in a company\n", | ||
"\n", | ||
"### Approach\n", | ||
"* So in this example, we will train a Linear Regression model to learn the correlation between the number of years of experience of each employee and their respective salary. \n", | ||
"* Once the model is trained, we will be able to do some sample predictions." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "189dc5ff-22c4-4502-89a8-75e5ce51f3e1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!wget -q https://mlpack.org/datasets/Salary_Data.csv" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "behavioral-cycling", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"// Import necessary library header.\n", | ||
"#include <mlpack/xeus-cling.hpp>\n", | ||
"\n", | ||
"#include <mlpack/core.hpp>\n", | ||
"#include <mlpack/core/data/split_data.hpp>\n", | ||
"#include <mlpack/methods/linear_regression/linear_regression.hpp>\n", | ||
"#include <cmath>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "db43325d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#define WITHOUT_NUMPY 1\n", | ||
"#include \"matplotlibcpp.h\"\n", | ||
"#include \"xwidgets/ximage.hpp\"\n", | ||
"\n", | ||
"namespace plt = matplotlibcpp;" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "9065ebb1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"using namespace mlpack;\n", | ||
"using namespace mlpack::regression;" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "victorian-donna", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"// Load the dataset into armadillo matrix.\n", | ||
"\n", | ||
"arma::mat inputs;\n", | ||
"data::Load(\"Salary_Data.csv\", inputs);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "deluxe-present", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"// Drop the first row as they represent header.\n", | ||
"\n", | ||
"inputs.shed_col(0);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "desirable-experience", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Years Of Experience Salary\n", | ||
" 1.1000e+00 3.9343e+04\n", | ||
" 1.3000e+00 4.6205e+04\n", | ||
" 1.5000e+00 3.7731e+04\n", | ||
" 2.0000e+00 4.3525e+04\n", | ||
" 2.2000e+00 3.9891e+04\n", | ||
" 2.9000e+00 5.6642e+04\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"// Display the first 5 rows of the input data.\n", | ||
"\n", | ||
"std::cout << std::setw(18) << \"Years Of Experience\" << std::setw(10) << \"Salary\" << std::endl;\n", | ||
"std::cout << inputs.submat(0, 0, inputs.n_rows-1, 5).t() << std::endl;" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"id": "associate-fifteen", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "912d932e54c14571a0ac726764dac35f", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"A Jupyter widget with unique id: 912d932e54c14571a0ac726764dac35f" | ||
] | ||
}, | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"// Plot the input data.\n", | ||
"\n", | ||
"std::vector<double> x = arma::conv_to<std::vector<double>>::from(inputs.row(0));\n", | ||
"std::vector<double> y = arma::conv_to<std::vector<double>>::from(inputs.row(1));\n", | ||
"\n", | ||
"plt::figure_size(800, 800);\n", | ||
"\n", | ||
"plt::scatter(x, y, 12, {{\"color\",\"coral\"}});\n", | ||
"plt::xlabel(\"Years of Experience\");\n", | ||
"plt::ylabel(\"Salary in $\");\n", | ||
"plt::title(\"Experience vs. Salary\");\n", | ||
"\n", | ||
"plt::save(\"./scatter.png\");\n", | ||
"auto img = xw::image_from_file(\"scatter.png\").finalize();\n", | ||
"img" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"id": "coordinate-canvas", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"// Split the data into features (X) and target (y) variables\n", | ||
"// targets are the last row.\n", | ||
"\n", | ||
"arma::Row<size_t> targets = arma::conv_to<arma::Row<size_t>>::from(inputs.row(inputs.n_rows - 1));" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"id": "blank-mexican", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"// Labels are dropped from the originally loaded data to be used as features.\n", | ||
"\n", | ||
"inputs.shed_row(inputs.n_rows - 1);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8da116b5-83f2-4acd-8ac3-0d68adbd83ca", | ||
"metadata": {}, | ||
"source": [ | ||
"### Train Test Split\n", | ||
"The dataset has to be split into a training set and a test set.\n", | ||
"This can be done using the `data::Split()` api from mlpack.\n", | ||
"Here the dataset has 30 observations and the `testRatio` is taken as 40% of the total observations.\n", | ||
"This indicates the test set should have 40% * 30 = 12 observations and training test should have 18 observations respectively." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"id": "mechanical-laundry", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"// Split the dataset into train and test sets using mlpack.\n", | ||
"\n", | ||
"arma::mat Xtrain;\n", | ||
"arma::mat Xtest;\n", | ||
"arma::Row<size_t> Ytrain;\n", | ||
"arma::Row<size_t> Ytest;\n", | ||
"data::Split(inputs, targets, Xtrain, Xtest, Ytrain, Ytest, 0.4);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"id": "friendly-petersburg", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"// Convert armadillo Rows into rowvec. (Required by mlpacks' LinearRegression API in this format).\n", | ||
"\n", | ||
"arma::rowvec yTrain = arma::conv_to<arma::rowvec>::from(Ytrain);\n", | ||
"arma::rowvec yTest = arma::conv_to<arma::rowvec>::from(Ytest);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "99955e22", | ||
"metadata": {}, | ||
"source": [ | ||
"## Linear Model\n", | ||
"\n", | ||
"Regression analysis is the most widely used method of prediction. Linear regression is used when the dataset has a linear correlation and as the name suggests, \n", | ||
"simple linear regression has one independent variable (predictor) and one dependent variable(response).\n", | ||
"\n", | ||
"The simple linear regression equation is represented as $y = a+bx$ where $x$ is the explanatory variable, $y$ is the dependent variable, $b$ is coefficient and $a$ is the intercept\n", | ||
"\n", | ||
"To perform linear regression we'll be using `LinearRegression()` api from mlpack." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"id": "published-illustration", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"// Create and Train Linear Regression model.\n", | ||
"\n", | ||
"regression::LinearRegression lr(Xtrain, yTrain, 0.5);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"id": "detailed-mystery", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"// Make predictions for test data points.\n", | ||
"\n", | ||
"arma::rowvec yPreds;\n", | ||
"lr.Predict(Xtest, yPreds);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"id": "indian-ambassador", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"// Convert armadillo vectors and matrices to vector for plotting purpose.\n", | ||
"\n", | ||
"std::vector<double> XtestPlot = arma::conv_to<std::vector<double>>::from(Xtest);\n", | ||
"std::vector<double> yTestPlot = arma::conv_to<std::vector<double>>::from(yTest);\n", | ||
"std::vector<double> yPredsPlot = arma::conv_to<std::vector<double>>::from(yPreds);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"id": "related-approach", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "88f7de7663bd431382ce760f7f8a08a0", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"A Jupyter widget with unique id: 88f7de7663bd431382ce760f7f8a08a0" | ||
] | ||
}, | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"// Visualize Predicted datapoints.\n", | ||
"plt::figure_size(800, 800);\n", | ||
"\n", | ||
"plt::scatter(XtestPlot, yTestPlot, 12, {{\"color\", \"coral\"}});\n", | ||
"plt::plot(XtestPlot,yPredsPlot);\n", | ||
"plt::xlabel(\"Years of Experience\");\n", | ||
"plt::ylabel(\"Salary in $\");\n", | ||
"plt::title(\"Predicted Experience vs. Salary\");\n", | ||
"\n", | ||
"plt::save(\"./scatter1.png\");\n", | ||
"auto img = xw::image_from_file(\"scatter1.png\").finalize();\n", | ||
"img" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0a10abbb-6b3a-423f-a573-1c650ac60b85", | ||
"metadata": {}, | ||
"source": [ | ||
"Test data is visualized with `XtestPlot` and `yPredsPlot`, the coral points indicates the data points and the blue line indicates the regression line or best fit line." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c24be191-959f-4244-8921-c1ee0ea98b3b", | ||
"metadata": {}, | ||
"source": [ | ||
"## Evaluation Metrics for Regression model\n", | ||
"\n", | ||
"In the Previous cell we have visualized our model performance by plotting the best fit line. Now we will use various evaluation metrics to understand how well our model has performed.\n", | ||
"\n", | ||
"* Mean Absolute Error (MAE) is the sum of absolute differences between actual and predicted values, without considering the direction.\n", | ||
"$$ MAE = \\frac{\\sum_{i=1}^n\\lvert y_{i} - \\hat{y_{i}}\\rvert} {n} $$\n", | ||
"* Mean Squared Error (MSE) is calculated as the mean or average of the squared differences between predicted and expected target values in a dataset, a lower value is better\n", | ||
"$$ MSE = \\frac {1}{n} \\sum_{i=1}^n (y_{i} - \\hat{y_{i}})^2 $$\n", | ||
"* Root Mean Squared Error (RMSE), Square root of MSE yields root mean square error (RMSE) it indicates the spread of the residual errors. It is always positive, and a lower value indicates better performance.\n", | ||
"$$ RMSE = \\sqrt{\\frac {1}{n} \\sum_{i=1}^n (y_{i} - \\hat{y_{i}})^2} $$" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 101, | ||
"id": "british-moment", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Mean Absolute Error: 5753.06\n", | ||
"Mean Squared Error: 3.9482e+07\n", | ||
"Root Mean Squared Error: 6283.47\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"// Model evaluation metrics.\n", | ||
"\n", | ||
"std::cout << \"Mean Absolute Error: \" << arma::mean(arma::abs(yPreds - yTest)) << std::endl;\n", | ||
"std::cout << \"Mean Squared Error: \" << arma::mean(arma::pow(yPreds - yTest,2)) << std::endl;\n", | ||
"std::cout << \"Root Mean Squared Error: \" << sqrt(arma::mean(arma::pow(yPreds - yTest,2))) << std::endl;" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "17cd38d7-214a-4f5a-8c4d-0517f834e804", | ||
"metadata": {}, | ||
"source": [ | ||
"From the above metrics we can notice that our model MAE is ~5K, which is relatively small compared to our average salary of $76003, from this we can conclude our model is resonably good fit." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "C++14", | ||
"language": "C++14", | ||
"name": "xcpp14" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": "text/x-c++src", | ||
"file_extension": ".cpp", | ||
"mimetype": "text/x-c++src", | ||
"name": "c++", | ||
"version": "14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.