diff --git a/.travis.yml b/.travis.yml
index ad1fe691..91cc8cb0 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,5 +1,5 @@
sudo: required
-dist: xenial
+dist: focal
language: cpp
before_install:
@@ -46,7 +46,7 @@ before_install:
- cd mlpack
- mkdir build
- cd build
- - cmake -DCMAKE_INSTALL_PREFIX=/usr -DUSE_OPENMP=OFF -DBUILD_CLI_EXECUTABLES=OFF -DBUILD_JULIA_BINDINGS=OFF -DBUILD_PYTHON_BINDINGS=OFF -DBUILD_MARKDOWN_BINDINGS=OFF -DBUILD_R_BINDINGS=OFF -DBUILD_TESTS=OFF ..
+ - cmake -DCMAKE_INSTALL_PREFIX=/usr -DUSE_OPENMP=ON -DBUILD_CLI_EXECUTABLES=ON -DBUILD_JULIA_BINDINGS=OFF -DBUILD_PYTHON_BINDINGS=OFF -DBUILD_MARKDOWN_BINDINGS=OFF -DBUILD_R_BINDINGS=OFF -DBUILD_TESTS=OFF ..
- make -j2
- sudo make install
- cd ../
@@ -58,32 +58,38 @@ install:
# Download datasets.
- pwd
- pip install tqdm
- - cd tools/
- - ./download_data_set.py
- - cd ../
- - ls data/
+ - cd tools/
+ - ./download_data_set.py
+ - cd ../
+ - ls data/
script:
# Finally, build all the examples.
- pwd
- - ls
- - ls data/
- - |
+ - ls
+ - ls data/
+ - |
for f in */Makefile; do
dir=`dirname $f`;
-
+
# TODO: this takes too long right now.
if [ "$dir" == "mnist_cnn" ];
then
continue;
fi
-
+
# TODO: the dataset cannot be loaded.
if [ "$dir" == "mnist_vae_cnn" ];
then
continue;
fi
-
+
+ # TODO: the dataset cannot be loaded.
+ if [ "$dir" == "spam" ];
+ then
+ continue;
+ fi
+
cd $dir;
target_name=`grep 'TARGET := ' Makefile | awk -F ' := ' '{ print $2 }'`;
echo "Make target $target_name in directory $dir.";
@@ -98,9 +104,9 @@ script:
fi
cd ../;
done
- # Print any failures.
- - ls
- - |
+ # Print any failures.
+ - ls
+ - |
if [ -f faillog ]; then
echo "Some examples failed!";
echo "";
@@ -108,7 +114,6 @@ script:
exit 1;
fi
-
notifications:
email:
- mlpack-git@lists.mlpack.org
diff --git a/README.md b/README.md
index dec68046..50c59e4c 100644
--- a/README.md
+++ b/README.md
@@ -61,6 +61,9 @@ description (just a little bit more than the title):
- `lstm_stock_prediction`: predict Google's historical stock price (daily high
_and_ low) using an LSTM-based recurrent neural network
+ - `spam`: predict whether a mobile phone text message in Indonesian is spam
+ or not using logistic regression
+
- `mnist_batch_norm`: use batch normalization in a simple feedforward neural
network to recognize the MNIST digits
diff --git a/spam/spam_classification.sh b/spam/spam_classification.sh
new file mode 100755
index 00000000..8a4d836b
--- /dev/null
+++ b/spam/spam_classification.sh
@@ -0,0 +1,325 @@
+#!/usr/bin/env bash
+# Spam Classification with mlpack on the command line
+
+## Introduction
+
+: <<'COMMENT'
+In this tutorial, the mlpack command line interface will
+be used to train a machine learning model to classify
+SMS spam. It will be assumed that mlpack has been
+successfully installed on your machine. The tutorial has
+been tested in a linux environment. It is written in
+bash - https://www.gnu.org/software/bash/
+
+Download the dataset using the `download_data_set.py` python script in the
+tools directory.
+
+Then run this example script using
+
+ bash spam_classification.sh
+
+If you are using a low power computer, you may want to run
+the script using
+
+ nice -20 bash spam_classification.sh
+
+so that it runs as a background task.
+COMMENT
+
+## Example
+
+: <<'COMMENT'
+As an example, we will train a machine learning model to classify
+spam SMS messages. We will use an example spam dataset in Indonesian
+provided by Yudi Wibisono.
+
+We will try to classify a message as spam or ham by the number of
+occurences of a word in a message. We first change the file line
+endings, merge lines 243 and 244 which should not be separated
+and then remove the header from the dataset. We then remove any
+blank lines and split our data into two files, labels and messages.
+Since the labels are at the end of the message, the message is
+reversed and then the labels are placed in one file. The messages
+are then placed in another file.
+COMMENT
+
+tr '\r' '\n' < ../data/dataset_sms_spam_bhs_indonesia_v1/dataset_sms_spam_v1.csv > dataset.txt
+sed '485{N;s/\n//;}' dataset.txt > dataset1.csv
+sed '1d' dataset1.csv > dataset2.csv
+sed '/^$/d' dataset2.csv > dataset.csv
+rev dataset.csv | cut -c1 | rev > labels.txt
+rev dataset.csv | cut -c2- | rev > messages.txt
+rm dataset.csv
+rm dataset1.csv
+rm dataset2.csv
+rm dataset.txt
+
+: <<'COMMENT'
+Machine learning works on numeric data, so we will use labels of
+1 for ham and 0 for spam. The dataset contains three labels, 0,
+normal sms (ham), 1, fraud (spam) and 2, promotion (spam). We will
+label all spam as 1, so promotions and fraud will be labelled as 1.
+COMMENT
+
+tr '2' '1' < labels.txt > labels.csv
+rm labels.txt
+
+: <<'COMMENT'
+The next step is to convert all text in the messages to lower case
+and for simplicity remove punctuation and any symbols that are not
+spaces, line endings or in the range a-z (one would need expand
+this range of symbols for production use)
+COMMENT
+
+tr '[:upper:]' '[:lower:]' < messages.txt > messagesLower.txt
+tr -Cd 'abcdefghijklmnopqrstuvwxyz \n' < messagesLower.txt > messagesLetters.txt
+rm messagesLower.txt
+
+: <<'COMMENT'
+We now obtain a sorted list of unique words used (this step may take
+a few minutes).
+COMMENT
+
+xargs -n1 < messagesLetters.txt > temp.txt
+sort temp.txt > temp2.txt
+uniq temp2.txt > words.txt
+rm temp.txt
+rm temp2.txt
+
+: <<'COMMENT'
+We then create a matrix, where for each message, the frequency of word
+occurrences is counted (more on this on Wikipedia,
+https://en.wikipedia.org/wiki/Tf–idf and
+https://en.wikipedia.org/wiki/Document-term_matrix ).
+COMMENT
+
+
+declare -a words=()
+declare -a letterstartind=()
+declare -a letterstart=()
+letter=" "
+i=0
+lettercount=0
+while IFS= read -r line; do
+ labels[$((i))]=$line
+ let "i++"
+done < labels.csv
+i=0
+while IFS= read -r line; do
+ words[$((i))]=$line
+ firstletter="$( echo $line | head -c 1 )"
+ if [ "$firstletter" != "$letter" ]
+ then
+ letterstartind[$((lettercount))]=$((i))
+ letterstart[$((lettercount))]=$firstletter
+ letter=$firstletter
+ let "lettercount++"
+ fi
+ let "i++"
+done < words.txt
+letterstartind[$((lettercount))]=$((i))
+echo "Created list of letters"
+
+touch wordfrequency.txt
+rm wordfrequency.txt
+touch wordfrequency.txt
+messagecount=0
+messagenum=0
+messages="$( wc -l messages.txt )"
+i=0
+while IFS= read -r line; do
+ let "messagenum++"
+ declare -a wordcount=()
+ declare -a wordarray=()
+ read -r -a wordarray <<< "$line"
+ let "messagecount++"
+ words=${#wordarray[@]}
+ for word in "${wordarray[@]}"; do
+ startletter="$( echo $word | head -c 1 )"
+ j=-1
+ while [ $((j)) -lt $((lettercount)) ]; do
+ let "j++"
+ if [ "$startletter" == "${letterstart[$((j))]}" ]
+ then
+ mystart=$((j))
+ fi
+ done
+ myend=$((mystart))+1
+ j=${letterstartind[$((mystart))]}
+ jend=${letterstartind[$((myend))]}
+ while [ $((j)) -le $((jend)) ]; do
+ wordcount[$((j))]=0
+ if [ "$word" == "${words[$((j))]}" ]
+ then
+ wordcount[$((j))]="$( echo $line | grep -o $word | wc -l )"
+ fi
+ let "j++"
+ done
+ done
+ for j in "${!wordcount[@]}"; do
+ wordcount[$((j))]=$(echo " scale=4; $((${wordcount[$((j))]})) / $((words))" | bc)
+ done
+ wordcount[$((words))+1]=$((words))
+ echo "${wordcount[*]}" >> wordfrequency.txt
+ echo "Processed message ""$messagenum"
+ let "i++"
+done < messagesLetters.txt
+
+# Create csv file
+tr ' ' ',' < wordfrequency.txt > data.csv
+
+: <<'COMMENT'
+Since Bash is an interpreted language, this simple implementation can
+take up to 30 minutes to complete.
+COMMENT
+
+: <<'COMMENT'
+Once the script has finished running, split the data into testing (30%)
+and training (70%) sets:
+COMMENT
+
+mlpack_preprocess_split \
+ --input_file data.csv \
+ --input_labels_file labels.csv \
+ --training_file train.data.csv \
+ --training_labels_file train.labels.csv \
+ --test_file test.data.csv \
+ --test_labels_file test.labels.csv \
+ --test_ratio 0.3 \
+ --verbose
+
+: <<'COMMENT'
+Now train a Logistic regression model
+(https://mlpack.org/doc/stable/cli_documentation.html#logistic_regression):
+COMMENT
+
+mlpack_logistic_regression --training_file train.data.csv \
+ --labels_file train.labels.csv \
+ --lambda 0.1 \
+ --output_model_file lr_model.bin
+
+: <<'COMMENT'
+Finally we test our model by producing predictions,
+COMMENT
+
+mlpack_logistic_regression --input_model_file lr_model.bin \
+ --test_file test.data.csv \
+ --output_file lr_predictions.csv
+
+: <<'COMMENT'
+and comparing the predictions with the exact results,
+COMMENT
+
+export incorrect=$(diff -U 0 lr_predictions.csv test.labels.csv | grep '^@@' | wc -l)
+export tests=$(wc -l < lr_predictions.csv)
+echo "scale=2; 100 * ( 1 - $((incorrect)) / $((tests)))" | bc
+
+: <<'COMMENT'
+This gives approximately 90% validation rate, similar to that
+obtained at
+https://towardsdatascience.com/spam-detection-with-logistic-regression-23e3709e522
+
+The dataset is composed of approximately 50% spam messages,
+so the validation rates are quite good without doing much parameter tuning.
+In typical cases, datasets are unbalanced with many more entries in
+some categories than in others. In these cases a good validation
+rate can be obtained by mispredicting the class with a few entries.
+Thus to better evaluate these models, one can compare the number of
+misclassifications of spam, and the number of misclassifications of ham.
+Of particular importance in applications is the number of false
+positive spam results as these are typically not transmitted. The next
+portion of the script creates a confusion matrix.
+COMMENT
+
+
+declare -a labels
+declare -a lr
+i=0
+while IFS= read -r line; do
+ labels[i]=$line
+ let "i++"
+done < test.labels.csv
+i=0
+while IFS= read -r line; do
+ lr[i]=$line
+ let "i++"
+done < lr_predictions.csv
+TruePositiveLR=0
+FalsePositiveLR=0
+TrueZerpLR=0
+FalseZeroLR=0
+Positive=0
+Zero=0
+for i in "${!labels[@]}"; do
+ if [ "${labels[$i]}" == "1" ]
+ then
+ let "Positive++"
+ if [ "${lr[$i]}" == "1" ]
+ then
+ let "TruePositiveLR++"
+ else
+ let "FalseZeroLR++"
+ fi
+ fi
+ if [ "${labels[$i]}" == "0" ]
+ then
+ let "Zero++"
+ if [ "${lr[$i]}" == "0" ]
+ then
+ let "TrueZeroLR++"
+ else
+ let "FalsePositiveLR++"
+ fi
+ fi
+
+done
+echo "Logistic Regression"
+echo "Total spam" $Positive
+echo "Total ham" $Zero
+echo "Confusion matrix"
+echo " Predicted class"
+echo " Ham | Spam "
+echo " ---------------"
+echo " Actual| Ham | " $TrueZeroLR "|" $FalseZeroLR
+echo " class | Spam | " $FalsePositiveLR " |" $TruePositiveLR
+echo ""
+
+: <<'COMMENT'
+You should get output similar to
+
+ Logistic Regression
+ Total spam 183
+ Total ham 159
+ Confusion matrix
+ Predicted class
+ -------------------
+ | Ham | Spam
+ | Actual | Ham | 128 | 26
+ | class | Spam | 31 | 157
+
+which indicates a reasonable level of classification.
+Other methods you can try in mlpack for this problem include:
+* Naive Bayes
+https://mlpack.org/doc/stable/cli_documentation.html#nbc
+* Random forest
+https://mlpack.org/doc/stable/cli_documentation.html#random_forest
+* Decision tree
+https://mlpack.org/doc/stable/cli_documentation.html#decision_tree
+* AdaBoost
+https://mlpack.org/doc/stable/cli_documentation.html#adaboost
+* Perceptron
+https://mlpack.org/doc/stable/cli_documentation.html#perceptron
+
+To improve the error rating, you can try other pre-processing methods
+on the initial data set. Neural networks can give up to 99.95%
+validation rates, see for example:
+
+https://thesai.org/Downloads/Volume11No1/Paper_67-The_Impact_of_Deep_Learning_Techniques.pdf
+https://www.kaggle.com/kredy10/simple-lstm-for-text-classification
+https://www.kaggle.com/xiu0714/sms-spam-detection-bert-acc-0-993
+
+However, using these techniques with mlpack is best covered in another tutorial.
+
+This tutorial is an adaptation of one that first appeared in the Fedora Magazine
+https://fedoramagazine.org/spam-classification-with-ml-pack/
+COMMENT
diff --git a/spam/tutorial.md b/spam/tutorial.md
new file mode 100644
index 00000000..941ebf4a
--- /dev/null
+++ b/spam/tutorial.md
@@ -0,0 +1,272 @@
+# Spam Classification with mlpack on the command line
+
+## Introduction
+
+In this tutorial, the mlpack command line interface will
+be used to train a machine learning model to classify
+SMS spam. It will be assumed that mlpack has been
+successfully installed on your machine. The tutorial has
+been tested in a linux environment.
+
+## Example
+
+As an example, we will train some machine learning models to classify spam SMS messages. We will use an example spam dataset in Indonesian provided by Yudi Wibisono.
+
+We will try to classify a message as spam or ham by the number of occurences of a word in a message. We first change the file line endings, remove line 243 which is missing a label and then remove the header from the dataset. Then, we split our data into two files, labels and messages. Since the labels are at the end of the message, the message is reversed and then the label removed and placed in one file. The message is then removed and placed in another file.
+
+```
+tr '\r' '\n' < dataset_sms_spam_v1.csv > dataset.txt
+sed '243d' dataset.txt > dataset1.csv
+sed '1d' dataset1.csv > dataset.csv
+rev dataset.csv | cut -c1 | rev > labels.txt
+rev dataset.csv | cut -c2- | rev > messages.txt
+rm dataset.csv
+rm dataset1.csv
+rm dataset.txt
+```
+
+Machine learning works on numeric data, so we will use labels to 1 for ham and 0 for spam. The dataset contains three labels, 0, normal sms (ham), 1, fraud (spam) and 2, promotion (spam). We will label all spam as 1, so promotions
+and fraud will be labelled as 1.
+
+
+```
+tr '2' '1' < labels.txt > labels.csv
+rm labels.txt
+```
+
+The next step is to convert all text in the messages to lower case and for simplicity remove punctuation and any symbols that are not spaces, line endings or in the range a-z (one would need expand this range of symbols for production use)
+
+```
+tr '[:upper:]' '[:lower:]' < messages.txt > messagesLower.txt
+tr -Cd 'abcdefghijklmnopqrstuvwxyz \n' < messagesLower.txt > messagesLetters.txt
+rm messagesLower.txt
+```
+
+We now obtain a sorted list of unique words used (this step may take a few minutes, so use nice to give it a low priority while you continue with other tasks on your computer).
+
+```
+nice -20 xargs -n1 < messagesLetters.txt > temp.txt
+sort temp.txt > temp2.txt
+uniq temp2.txt > words.txt
+rm temp.txt
+rm temp2.txt
+```
+
+We then create a matrix, where for each message, the frequency of word occurrences is counted (more on this on Wikipedia, [here](https://en.wikipedia.org/wiki/Tf–idf) and [here](https://en.wikipedia.org/wiki/Document-term_matrix)). This requires a few lines of code, so the full script, which should be saved as 'makematrix.sh' is below:
+
+```
+#!/bin/bash
+declare -a words=()
+declare -a letterstartind=()
+declare -a letterstart=()
+letter=" "
+i=0
+lettercount=0
+while IFS= read -r line; do
+ labels[$((i))]=$line
+ let "i++"
+done < labels.csv
+i=0
+while IFS= read -r line; do
+ words[$((i))]=$line
+ firstletter="$( echo $line | head -c 1 )"
+ if [ "$firstletter" != "$letter" ]
+ then
+ letterstartind[$((lettercount))]=$((i))
+ letterstart[$((lettercount))]=$firstletter
+ letter=$firstletter
+ let "lettercount++"
+ fi
+ let "i++"
+done < words.txt
+letterstartind[$((lettercount))]=$((i))
+echo "Created list of letters"
+
+touch wordfrequency.txt
+rm wordfrequency.txt
+touch wordfrequency.txt
+messagecount=0
+messagenum=0
+messages="$( wc -l messages.txt )"
+i=0
+while IFS= read -r line; do
+ let "messagenum++"
+ declare -a wordcount=()
+ declare -a wordarray=()
+ read -r -a wordarray <<< "$line"
+ let "messagecount++"
+ words=${#wordarray[@]}
+ for word in "${wordarray[@]}"; do
+ startletter="$( echo $word | head -c 1 )"
+ j=-1
+ while [ $((j)) -lt $((lettercount)) ]; do
+ let "j++"
+ if [ "$startletter" == "${letterstart[$((j))]}" ]
+ then
+ mystart=$((j))
+ fi
+ done
+ myend=$((mystart))+1
+ j=${letterstartind[$((mystart))]}
+ jend=${letterstartind[$((myend))]}
+ while [ $((j)) -le $((jend)) ]; do
+ wordcount[$((j))]=0
+ if [ "$word" == "${words[$((j))]}" ]
+ then
+ wordcount[$((j))]="$( echo $line | grep -o $word | wc -l )"
+ fi
+ let "j++"
+ done
+ done
+ for j in "${!wordcount[@]}"; do
+ wordcount[$((j))]=$(echo " scale=4; $((${wordcount[$((j))]})) / $((words))" | bc)
+ done
+ wordcount[$((words))+1]=$((words))
+ echo "${wordcount[*]}" >> wordfrequency.txt
+ echo "Processed message ""$messagenum"
+ let "i++"
+done < messagesLetters.txt
+# Create csv file
+tr ' ' ',' < wordfrequency.txt > data.csv
+```
+
+Since [Bash](https://www.gnu.org/software/bash/) is an interpreted language, this simple implementation can take up to 30 minutes to complete. If using the above Bash script on your primary workstation, run it as a task with low priority so that you can continue with other work while you wait:
+
+```
+nice -20 bash makematrix.sh
+```
+
+Once the script has finished running, split the data into testing (30%) and training (70%) sets:
+
+```
+mlpack_preprocess_split \
+ --input_file data.csv \
+ --input_labels_file labels.csv \
+ --training_file train.data.csv \
+ --training_labels_file train.labels.csv \
+ --test_file test.data.csv \
+ --test_labels_file test.labels.csv \
+ --test_ratio 0.3 \
+ --verbose
+```
+
+Now train a [Logistic regression model](https://mlpack.org/doc/mlpack-3.3.1/cli_documentation.html#logistic_regression):
+
+```
+mlpack_logistic_regression --training_file train.data.csv \
+ --labels_file train.labels.csv \
+ --lambda 0.1 \
+ --output_model_file lr_model.bin
+```
+
+Finally we test our model by producing predictions,
+
+```
+mlpack_logistic_regression --input_model_file lr_model.bin \
+ --test_file test.data.csv \
+ --output_file lr_predictions.csv
+```
+
+and comparing the predictions with the exact results,
+
+```
+export incorrect=$(diff -U 0 lr_predictions.csv test.labels.csv | grep '^@@' | wc -l)
+export tests=$(wc -l < lr_predictions.csv)
+echo "scale=2; 100 * ( 1 - $((incorrect)) / $((tests)))" | bc
+```
+
+This gives approximately 90% validation rate, similar to that obtained [here](https://towardsdatascience.com/spam-detection-with-logistic-regression-23e3709e522).
+
+The dataset is composed of approximately 50% spam messages, so the validation rates are quite good without doing much parameter tuning.
+In typical cases, datasets are unbalanced with many more entries in some categories than in others. In these cases a good validation
+rate can be obtained by mispredicting the class with a few entries.
+Thus to better evaluate these models, one can compare the number of misclassifications of spam, and the number of misclassifications of ham.
+Of particular importance in applications is the number of false positive spam results as these are typically not transmitted. The script below produces a confusion matrix which gives a better indication of misclassification.
+Save it as 'confusion.sh'
+
+```
+#!/bin/bash
+declare -a labels
+declare -a lr
+i=0
+while IFS= read -r line; do
+ labels[i]=$line
+ let "i++"
+done < test.labels.csv
+i=0
+while IFS= read -r line; do
+ lr[i]=$line
+ let "i++"
+done < lr_predictions.csv
+TruePositiveLR=0
+FalsePositiveLR=0
+TrueZerpLR=0
+FalseZeroLR=0
+Positive=0
+Zero=0
+for i in "${!labels[@]}"; do
+ if [ "${labels[$i]}" == "1" ]
+ then
+ let "Positive++"
+ if [ "${lr[$i]}" == "1" ]
+ then
+ let "TruePositiveLR++"
+ else
+ let "FalseZeroLR++"
+ fi
+ fi
+ if [ "${labels[$i]}" == "0" ]
+ then
+ let "Zero++"
+ if [ "${lr[$i]}" == "0" ]
+ then
+ let "TrueZeroLR++"
+ else
+ let "FalsePositiveLR++"
+ fi
+ fi
+
+done
+echo "Logistic Regression"
+echo "Total spam" $Positive
+echo "Total ham" $Zero
+echo "Confusion matrix"
+echo " Predicted class"
+echo " Ham | Spam "
+echo " ---------------"
+echo " Actual| Ham | " $TrueZeroLR "|" $FalseZeroLR
+echo " class | Spam | " $FalsePositiveLR " |" $TruePositiveLR
+echo ""
+
+```
+
+then run the script
+
+```
+bash confusion.sh
+```
+
+You should get output similar to
+
+> Logistic Regression
+> Total spam 183
+> Total ham 159
+> Confusion matrix
+> | | |Predicted |class
+> | --- |--- | --- |---
+> | | | **Ham** | **Spam**
+> | **Actual**| **Ham** | 128 | 26
+> | **class** | **Spam** | 31 | 157
+
+which indicates a reasonable level of classification.
+Other methods you can try in mlpack for this problem include:
+* [Naive Bayes](https://mlpack.org/doc/mlpack-3.3.1/cli_documentation.html#nbc)
+* [Random forest](https://mlpack.org/doc/mlpack-3.3.1/cli_documentation.html#random_forest)
+* [Decision tree](https://mlpack.org/doc/mlpack-3.3.1/cli_documentation.html#decision_tree)
+* [AdaBoost](https://mlpack.org/doc/mlpack-3.3.1/cli_documentation.html#adaboost)
+* [Perceptron](https://mlpack.org/doc/mlpack-3.3.1/cli_documentation.html#perceptron)
+
+To improve the error rating, you can try other pre-processing methods on the initial data set.
+Neural networks can give up to 99.95% validation rates, see for example [here](https://thesai.org/Downloads/Volume11No1/Paper_67-The_Impact_of_Deep_Learning_Techniques.pdf), [here](https://www.kaggle.com/kredy10/simple-lstm-for-text-classification) and [here](https://www.kaggle.com/xiu0714/sms-spam-detection-bert-acc-0-993). However, using these techniques with mlpack is best covered in another tutorial.
+
+This tutorial is an adaptation of one that first appeared in the [Fedora Magazine](https://fedoramagazine.org/spam-classification-with-ml-pack/).
diff --git a/tools/download_data_set.py b/tools/download_data_set.py
index 0de34d5b..b3432076 100755
--- a/tools/download_data_set.py
+++ b/tools/download_data_set.py
@@ -134,6 +134,15 @@ def iris_dataset():
tar.close()
clean()
+def spam_dataset():
+ print("Downloading spam dataset...")
+ spam = requests.get("https://www.mlpack.org/datasets/dataset_sms_spam_bhs_indonesia_v1.tar.gz")
+ progress_bar("dataset_sms_spam_bhs_indonesia_v1.tar.gz", spam)
+ tar = tarfile.open("dataset_sms_spam_bhs_indonesia_v1.tar.gz", "r:gz")
+ tar.extractall()
+ tar.close()
+ clean()
+
def salary_dataset():
print("Downloading salary dataset...")
salary = requests.get("http://datasets.mlpack.org/Salary_Data.csv")
@@ -154,6 +163,7 @@ def all_datasets():
iris_dataset()
salary_dataset()
body_fat_dataset()
+ spam_dataset()
cifar10_dataset()
@@ -177,6 +187,7 @@ def all_datasets():
stock : will download stock_exchange dataset
iris : will downlaod the iris dataset
bodyFat : will download the bodyFat dataset
+ spam : will download the spam dataset
salary: will download the salary dataset
cifar10: will download the cifar10 dataset
all : will download all datasets for all examples
@@ -205,6 +216,9 @@ def all_datasets():
elif args.dataset_name == "bodyFat":
create_dataset_dir()
body_fat_dataset()
+ elif args.dataset_name == "spam":
+ create_dataset_dir()
+ spam_dataset()
elif args.dataset_name == "salary":
create_dataset_dir()
salary_dataset()