Skip to content

Commit

Permalink
Merge pull request mlpack#43 from Iron-Stark/dectree
Browse files Browse the repository at this point in the history
Added more options like min_samples_split,min_samples_leaf,max_features and max_nodes in Decision Tree scikit implementation.
  • Loading branch information
zoq authored and Iron-Stark committed Apr 6, 2017
2 parents 4d9c8ec + 9c70c11 commit 6471cd1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
33 changes: 30 additions & 3 deletions methods/scikit/dtc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,16 @@ def __init__(self, dataset, timeout=0, verbose=True):
self.criterion = 'gini'
self.max_depth = None
self.seed = 0

self.splitter = 'best'
self.max_depth = None
self.min_samples_split = 2
self.min_samples_leaf = 1
self.min_weight_fraction_leaf = 0.0
self.max_features = None
self.random_state = None
self.min_impurity_split = 1e-07
self.class_weight = None
self.presort = False
'''
Build the model for the Decision Tree Classifier.
Expand All @@ -62,7 +71,15 @@ def BuildModel(self, data, labels):
# Create and train the classifier.
dtc = DecisionTreeClassifier(criterion=self.criterion,
max_depth=self.max_depth,
random_state=self.seed)
random_state=self.seed,
splitter = self.splitter,
min_samples_split = self.min_samples_split,
min_weight_fraction_leaf=self.min_weight_fraction_leaf,
max_features = self.max_features,
max_leaf_nodes = self.max_leaf_nodes,
min_impurity_split = self.min_impurity_split,
class_weight = self.class_weight,
presort = self.presort)
dtc.fit(data, labels)
return dtc

Expand All @@ -85,9 +102,18 @@ def RunDTCScikit(q):
c = re.search("-c (\s+)", options)
d = re.search("-d (\s+)", options)
s = re.search("-s (\d+)", options)

mss = re.search("--min_samples_split (\d+)", options)
msl = re.search("--min_samples_leaf (\d+)", options)
mf = re.search("--max_features (\d+)", options)
mln = re.search("--max_leaf_nodes (\d+)", options)

self.criterion = 'gini' if not c else str(c.group(1))
self.max_depth = None if not d else int(d.group(1))
self.splitter = 'best' if not s else str(s.group(1))
self.min_samples_split = 2 if not mss else int(mss.group(1))
self.min_samples_leaf = 1 if not msl else int(msl.group(1))
self.max_features = None if not mf else int(mf.group(1))
self.max_leaf_nodes = None if not mln else int(mln.group(1))
self.seed = 0 if not s else int(s.group(1))

try:
Expand Down Expand Up @@ -148,3 +174,4 @@ def RunMetrics(self, options):
metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictedlabels)

return metrics

22 changes: 22 additions & 0 deletions return_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import yaml

class VERSION_INFO(object):

def ret_version_info():

stream = open("config.yaml", "r")
docs = yaml.load_all(stream)
count=1
ans = {}
for doc in docs:
if count==1:
ans = doc
else:
break
count+=1
versions = ans['settings']['version']
libraries = ans['settings']['libraries']
lib_ver = {}
for i in range(len(versions)):
lib_ver[libraries[i]] = versions[i]
return lib_ver

0 comments on commit 6471cd1

Please sign in to comment.