From b71bde3574d3284bae0547324e1b2234b8c1ed26 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Mon, 9 Oct 2017 16:59:00 -0400 Subject: [PATCH] Overhaul logistic regression scripts; fix a number of bugs. --- methods/matlab/LOGISTIC_REGRESSION.m | 4 +- methods/matlab/logistic_regression.py | 16 ++-- methods/milk/logistic_regression.py | 18 ++++ methods/mlpack/logistic_regression.py | 2 +- methods/scikit/logistic_regression.py | 2 +- methods/shogun/logistic_regression.py | 7 ++ methods/weka/logistic_regression.py | 22 +++-- methods/weka/src/LOGISTICREGRESSION.java | 102 ----------------------- methods/weka/src/LogisticRegression.java | 17 +++- util/timer.py | 4 +- 10 files changed, 71 insertions(+), 123 deletions(-) delete mode 100644 methods/weka/src/LOGISTICREGRESSION.java diff --git a/methods/matlab/LOGISTIC_REGRESSION.m b/methods/matlab/LOGISTIC_REGRESSION.m index 5a1e3ba..2a2b635 100644 --- a/methods/matlab/LOGISTIC_REGRESSION.m +++ b/methods/matlab/LOGISTIC_REGRESSION.m @@ -26,7 +26,7 @@ function logistic_regression(cmd) X = csvread(regressorsFile{:}); if isempty(responsesFile) - y = X(:,end); + y = X(:,end) + 1; % We have to increment because labels must be positive. X = X(:,1:end-1); else y = csvread(responsesFile{:}); @@ -47,7 +47,7 @@ function logistic_regression(cmd) disp(sprintf('[INFO ] total_time: %fs', toc(total_time))) if ~isempty(testFile) - csvwrite('predictions.csv', idx); + csvwrite('predictions.csv', idx - 1); % Subtract extra label bit. csvwrite('matlab_lr_probs.csv', predictions); end diff --git a/methods/matlab/logistic_regression.py b/methods/matlab/logistic_regression.py index aa2ec49..c2a67ae 100644 --- a/methods/matlab/logistic_regression.py +++ b/methods/matlab/logistic_regression.py @@ -79,7 +79,7 @@ def RunMetrics(self, options): # If the dataset contains two files then the second file is the test # file. In this case we add this to the command line. - if len(self.dataset) == 2: + if len(self.dataset) >= 2: inputCmd = "-i " + self.dataset[0] + " -t " + self.dataset[1] else: inputCmd = "-i " + self.dataset[0] @@ -111,11 +111,15 @@ def RunMetrics(self, options): truelabels = np.genfromtxt(self.dataset[2], delimiter = ',') metrics['Runtime'] = timer.total_time confusionMatrix = Metrics.ConfusionMatrix(truelabels, predictions) - metrics['ACC'] = Metrics.AverageAccuracy(confusionMatrix) - metrics['MCC'] = Metrics.MCCMultiClass(confusionMatrix) - metrics['Precision'] = Metrics.AvgPrecision(confusionMatrix) - metrics['Recall'] = Metrics.AvgRecall(confusionMatrix) - metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictions) + + metrics['Avg Accuracy'] = Metrics.AverageAccuracy(confusionMatrix) + metrics['MultiClass Precision'] = Metrics.AvgPrecision(confusionMatrix) + metrics['MultiClass Recall'] = Metrics.AvgRecall(confusionMatrix) + metrics['MultiClass FMeasure'] = Metrics.AvgFMeasure(confusionMatrix) + metrics['MultiClass Lift'] = Metrics.LiftMultiClass(confusionMatrix) + metrics['MultiClass MCC'] = Metrics.MCCMultiClass(confusionMatrix) + metrics['MultiClass Information'] = Metrics.AvgMPIArray(confusionMatrix, truelabels, predictions) + metrics['Simple MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictions) Log.Info(("total time: %fs" % (metrics['Runtime'])), self.verbose) diff --git a/methods/milk/logistic_regression.py b/methods/milk/logistic_regression.py index 7510ded..3d21976 100644 --- a/methods/milk/logistic_regression.py +++ b/methods/milk/logistic_regression.py @@ -80,6 +80,9 @@ def RunLogisticRegressionMilk(): self.model = self.BuildModel() with totalTimer: self.model = self.model.train(trainData, labels) + if len(self.dataset) > 1: + # We get back probabilities; cast these to classes. + self.predictions = np.greater(self.model.apply(testData), 0.5) except Exception as e: return -1 @@ -112,4 +115,19 @@ def RunMetrics(self, options): # Datastructure to store the results. metrics = {'Runtime' : results} + + if len(self.dataset) >= 3: + truelabels = LoadDataset(self.dataset[2]) + + confusionMatrix = Metrics.ConfusionMatrix(truelabels, self.predictions) + + metrics['Avg Accuracy'] = Metrics.AverageAccuracy(confusionMatrix) + metrics['MultiClass Precision'] = Metrics.AvgPrecision(confusionMatrix) + metrics['MultiClass Recall'] = Metrics.AvgRecall(confusionMatrix) + metrics['MultiClass FMeasure'] = Metrics.AvgFMeasure(confusionMatrix) + metrics['MultiClass Lift'] = Metrics.LiftMultiClass(confusionMatrix) + metrics['MultiClass MCC'] = Metrics.MCCMultiClass(confusionMatrix) + metrics['MultiClass Information'] = Metrics.AvgMPIArray(confusionMatrix, truelabels, self.predictions) + metrics['Simple MSE'] = Metrics.SimpleMeanSquaredError(truelabels, self.predictions) + return metrics diff --git a/methods/mlpack/logistic_regression.py b/methods/mlpack/logistic_regression.py index dbc57c1..7022903 100644 --- a/methods/mlpack/logistic_regression.py +++ b/methods/mlpack/logistic_regression.py @@ -98,7 +98,7 @@ def OptionsToStr(self, options): optionsStr = "-e " + str(options.pop("epsilon")) if "max_iterations" in options: optionsStr = optionsStr + " -n " + str(options.pop("max_iterations")) - if "optimizer" in options: + if "algorithm" in options: optionsStr = optionsStr + " -O " + str(options.pop("optimizer")) if "step_size" in options: optionsStr = optionsStr + " -s " + str(options.pop("step_size")) diff --git a/methods/scikit/logistic_regression.py b/methods/scikit/logistic_regression.py index 710cf6d..990f569 100644 --- a/methods/scikit/logistic_regression.py +++ b/methods/scikit/logistic_regression.py @@ -85,7 +85,7 @@ def RunLogisticRegressionScikit(): # Use the last row of the training set as the responses. X, y = SplitTrainData(self.dataset) if "algorithm" in options: - self.opts["algorithm"] = str(options.pop("algorithm")) + self.opts["solver"] = str(options.pop("algorithm")) if "epsilon" in options: self.opts["epsilon"] = float(options.pop("epsilon")) if "max_iterations" in options: diff --git a/methods/shogun/logistic_regression.py b/methods/shogun/logistic_regression.py index 5031d16..c6ea056 100644 --- a/methods/shogun/logistic_regression.py +++ b/methods/shogun/logistic_regression.py @@ -51,6 +51,7 @@ def __init__(self, dataset, timeout=0, verbose=True): self.predictions = None self.z = 1 self.model = None + self.max_iter = None ''' Build the model for the Logistic Regression. @@ -63,6 +64,8 @@ def BuildModel(self, data, responses): # Create and train the classifier. model = MulticlassLogisticRegression(self.z, RealFeatures(data.T), MulticlassLabels(responses)) + if self.max_iter is not None: + model.set_max_iter(self.max_iter); model.train() return model @@ -87,6 +90,10 @@ def RunLogisticRegressionShogun(): # Use the last row of the training set as the responses. X, y = SplitTrainData(self.dataset) + # Get the maximum number of iterations. + if "max_iterations" in options: + self.max_iter = int(options.pop("max_iterations")) + # Get the regularization value. if "lambda" in options: self.z = float(options.pop("lambda")) diff --git a/methods/weka/logistic_regression.py b/methods/weka/logistic_regression.py index 107e5b2..9953e01 100644 --- a/methods/weka/logistic_regression.py +++ b/methods/weka/logistic_regression.py @@ -69,6 +69,11 @@ def __del__(self): def RunMetrics(self, options): Log.Info("Perform Logistic Regression.", self.verbose) + maxIterStr = "" + if 'max_iterations' in options: + maxIterStr = " -m " + str(options['max_iterations']) + " " + options.pop('max_iterations') + if len(options) > 0: Log.Fatal("Unknown parameters: " + str(options)) raise Exception("unknown parameters") @@ -79,8 +84,8 @@ def RunMetrics(self, options): # Split the command using shell-like syntax. cmd = shlex.split("java -classpath " + self.path + "/weka.jar" + - ":methods/weka" + " LOGISTICREGRESSION -t " + self.dataset[0] + " -T " + - self.dataset[1]) + ":methods/weka" + " LogisticRegression -t " + self.dataset[0] + " -T " + + self.dataset[1] + maxIterStr) # Run command with the nessecary arguments and return its output as a byte # string. We have untrusted input so we disable all shell based features. @@ -105,11 +110,14 @@ def RunMetrics(self, options): truelabels = np.genfromtxt(self.dataset[2], delimiter = ',') metrics['Runtime'] = timer.total_time confusionMatrix = Metrics.ConfusionMatrix(truelabels, predictions) - metrics['ACC'] = Metrics.AverageAccuracy(confusionMatrix) - metrics['MCC'] = Metrics.MCCMultiClass(confusionMatrix) - metrics['Precision'] = Metrics.AvgPrecision(confusionMatrix) - metrics['Recall'] = Metrics.AvgRecall(confusionMatrix) - metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictions) + metrics['Avg Accuracy'] = Metrics.AverageAccuracy(confusionMatrix) + metrics['MultiClass Precision'] = Metrics.AvgPrecision(confusionMatrix) + metrics['MultiClass Recall'] = Metrics.AvgRecall(confusionMatrix) + metrics['MultiClass FMeasure'] = Metrics.AvgFMeasure(confusionMatrix) + metrics['MultiClass Lift'] = Metrics.LiftMultiClass(confusionMatrix) + metrics['MultiClass MCC'] = Metrics.MCCMultiClass(confusionMatrix) + metrics['MultiClass Information'] = Metrics.AvgMPIArray(confusionMatrix, truelabels, predictions) + metrics['Simple MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictions) Log.Info(("total time: %fs" % (metrics['Runtime'])), self.verbose) diff --git a/methods/weka/src/LOGISTICREGRESSION.java b/methods/weka/src/LOGISTICREGRESSION.java deleted file mode 100644 index 0b08949..0000000 --- a/methods/weka/src/LOGISTICREGRESSION.java +++ /dev/null @@ -1,102 +0,0 @@ -/** - * @file LOGISTICREGRESSION.java - * - * Logistic Regression with weka. - */ - -import weka.classifiers.Classifier; -import weka.classifiers.RandomizableClassifier; -import weka.classifiers.functions.Logistic; -import weka.core.Instances; -import weka.core.Utils; -import weka.core.converters.ConverterUtils.DataSource; -import weka.filters.Filter; -import weka.filters.unsupervised.attribute.NumericToNominal; -import java.io.File; -import java.io.FileWriter; -import java.io.BufferedWriter; -import weka.core.Attribute; -import java.util.List; -import java.util.ArrayList; -/** - * This class use the weka libary to implement Logistic Regression. - */ -public class LOGISTICREGRESSION { - - private static final String USAGE = String - .format("This program trains the Logistic Regression on the given\n" - + "labeled training set and then uses the trained classifier to classify\n" - + "the points in the given test set.\n\n" - + "Required options:\n" - + "-T [string] A file containing the test set.\n" - + "-t [string] A file containing the training set."); - - public static void main(String args[]) { - Timers timer = new Timers(); - try { - // Get the data set path. - String trainFile = Utils.getOption('t', args); - String testFile = Utils.getOption('T', args); - if (trainFile.length() == 0 || testFile.length() == 0) - throw new IllegalArgumentException(); - - // Load train and test dataset. - DataSource source = new DataSource(trainFile); - Instances trainData = source.getDataSet(); - - // Use the last row of the training data as the labels. - trainData.setClassIndex((trainData.numAttributes() - 1)); - DataSource testsource = new DataSource(testFile); - Instances testData = testsource.getDataSet(); - - // Add pseudo class to the test set if no class information is provided. - if (testData.numAttributes() < trainData.numAttributes()) { - List labelslist = new ArrayList(); - for (int i = 0; i < trainData.classAttribute().numValues(); i++) { - labelslist.add(trainData.classAttribute().value(i)); - } - - testData.insertAttributeAt(new Attribute("class", labelslist), - testData.numAttributes()); - } - - // Use the last row of the training data as the labels. - testData.setClassIndex((testData.numAttributes() - 1)); - - timer.StartTimer("total_time"); - // Create and train the classifier. - Classifier cModel = (Classifier)new Logistic(); - cModel.buildClassifier(trainData); - - // Run Decision Tree Classifier on the test dataset. - // Write predicted class values for each intance to - // benchmarks/weka_predicted.csv. - double prediction = 0; - try{ - File predictedlabels = new File("weka_predicted.csv"); - if(!predictedlabels.exists()) { - predictedlabels.createNewFile(); - } - FileWriter writer = new FileWriter(predictedlabels.getName(), false); - - for (int i = 0; i < testData.numInstances(); i++) { - prediction = cModel.classifyInstance(trainData.instance(i)); - String pred = Double.toString(prediction+1); - writer.write(pred); - writer.write("\n"); - } - - writer.close(); - } catch(Exception e) { - e.printStackTrace(); - } - timer.StopTimer("total_time"); - timer.PrintTimer("total_time"); - - } catch (IllegalArgumentException e) { - System.err.println(USAGE); - } catch (Exception e) { - e.printStackTrace(); - } - } -} diff --git a/methods/weka/src/LogisticRegression.java b/methods/weka/src/LogisticRegression.java index 2f1a996..ae48905 100644 --- a/methods/weka/src/LogisticRegression.java +++ b/methods/weka/src/LogisticRegression.java @@ -8,6 +8,7 @@ import java.io.IOException; import weka.core.*; import weka.core.converters.ConverterUtils.DataSource; +import weka.core.converters.CSVLoader; import weka.filters.Filter; import weka.filters.unsupervised.attribute.NumericToNominal; @@ -29,7 +30,8 @@ public class LogisticRegression { + " the last row of the input file.\n\n" + "Options:\n\n" + "-t [string] Optional file containing containing\n" - + " test dataset"); + + " test dataset\n" + + "-m [int] Maximum number of iterations\n"); public static HashMap createClassMap(Instances Data) { HashMap classMap = new HashMap(); @@ -69,6 +71,8 @@ public static void main(String args[]) { // Load input dataset. DataSource source = new DataSource(regressorsFile); + if (source.getLoader() instanceof CSVLoader) + ((CSVLoader) source.getLoader()).setNoHeaderRowPresent(true); Instances data = source.getDataSet(); // Transform numeric class to nominal class because the @@ -81,12 +85,19 @@ public static void main(String args[]) { nm.setInputFormat(data); data = Filter.useFilter(data, nm); + boolean hasMaxIters = false; + int maxIter = Integer.parseInt(Utils.getOption('m', args)); + if (maxIter != 0) + hasMaxIters = true; + // Did the user pass a test file? String testFile = Utils.getOption('t', args); Instances testData = null; if (testFile.length() != 0) { source = new DataSource(testFile); + if (source.getLoader() instanceof CSVLoader) + ((CSVLoader) source.getLoader()).setNoHeaderRowPresent(true); testData = source.getDataSet(); // Weka makes the assumption that the structure of the training and test @@ -122,6 +133,8 @@ public static void main(String args[]) { // Perform Logistic Regression. timer.StartTimer("total_time"); weka.classifiers.functions.Logistic model = new weka.classifiers.functions.Logistic(); + if (hasMaxIters) + model.setMaxIts(maxIter); model.buildClassifier(data); // Use the testdata to evaluate the modell. @@ -140,7 +153,7 @@ public static void main(String args[]) { } FileWriter writer = new FileWriter(probabs.getName(), false); - File predictions = new File("weka_lr_predictions.csv"); + File predictions = new File("weka_predicted.csv"); if(!predictions.exists()) { predictions.createNewFile(); } diff --git a/util/timer.py b/util/timer.py index fe39880..58d196c 100644 --- a/util/timer.py +++ b/util/timer.py @@ -64,10 +64,10 @@ def timeout(fun, timeout=9000): p.join() Log.Warn("Script timed out after " + str(timeout) + " seconds") - return -2 + return [-2] else: try: r = q.get(timeout=3) except Exception as e: - r = -1 + r = [-1] return r