Skip to content

Commit

Permalink
Merge pull request #40 from Iron-Stark/logtest
Browse files Browse the repository at this point in the history
Benchmark test file for Logistic Regression.
  • Loading branch information
zoq authored Apr 5, 2017
2 parents ab37f16 + 048b03c commit ca2094a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
54 changes: 54 additions & 0 deletions tests/benchmark_logistic_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
'''
@file benchmark_logistic_regression.py
Test for the Logistic Regression Classifier scripts.
'''

import unittest

import os, sys, inspect

# Import the util path, this method even works if the path contains
# symlinks to modules.
cmd_subfolder = os.path.realpath(os.path.abspath(os.path.join(
os.path.split(inspect.getfile(inspect.currentframe()))[0], '../util')))
if cmd_subfolder not in sys.path:
sys.path.insert(0, cmd_subfolder)

from loader import *
'''
Test the Scikit Logistic Regression Classifier script.
'''
class LR_SCIKIT_TEST(unittest.TestCase):

'''
Test initialization.
'''
def setUp(self):
self.dataset = ['datasets/iris_train.csv', 'datasets/iris_test.csv']
self.verbose = False
self.timeout = 9000

module = Loader.ImportModuleFromPath("methods/scikit/logistic_regression.py")
obj = getattr(module, "LogisticRegression")
self.instance = obj(self.dataset, verbose=self.verbose, timeout=self.timeout)

'''
Test the constructor.
'''
def test_Constructor(self):
self.assertEqual(self.instance.verbose, self.verbose)
self.assertEqual(self.instance.timeout, self.timeout)
self.assertEqual(self.instance.dataset, self.dataset)

'''
Test the 'RunMetrics' function.
'''
def test_RunMetrics(self):
result = self.instance.RunMetrics("")
self.assertTrue(result["Runtime"] > 0)


if __name__ == '__main__':
unittest.main()

7 changes: 4 additions & 3 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
'benchmark_allkfn',
'benchmark_allknn',
'benchmark_allkrann',
'benchmark_ann',
'benchmark_det',
'benchmark_emst',
'benchmark_fastmks',
Expand All @@ -21,17 +22,17 @@
'benchmark_linear_regression',
'benchmark_linear_ridge_regression',
'benchmark_local_coordinate_coding',
'benchmark_logistic_regression',
'benchmark_lasso',
'benchmark_lsh',
'benchmark_nbc',
'benchmark_nca',
'benchmark_nmf',
'benchmark_pca',
'benchmark_random_forest',
'benchmark_range_search',
'benchmark_sparse_coding',
'benchmark_svr',
'benchmark_ann',
'benchmark_random_forest'
'benchmark_svr'
]

def load_tests(loader, tests, pattern):
Expand Down

0 comments on commit ca2094a

Please sign in to comment.