Skip to content

Commit

Permalink
Merge pull request #43 from Iron-Stark/dectree
Browse files Browse the repository at this point in the history
Added more options like min_samples_split,min_samples_leaf,max_features and max_nodes in Decision Tree scikit implementation.
  • Loading branch information
zoq authored Apr 6, 2017
2 parents 4d9c8ec + 9c70c11 commit 09148bf
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions methods/scikit/dtc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,16 @@ def __init__(self, dataset, timeout=0, verbose=True):
self.criterion = 'gini'
self.max_depth = None
self.seed = 0

self.splitter = 'best'
self.max_depth = None
self.min_samples_split = 2
self.min_samples_leaf = 1
self.min_weight_fraction_leaf = 0.0
self.max_features = None
self.random_state = None
self.min_impurity_split = 1e-07
self.class_weight = None
self.presort = False
'''
Build the model for the Decision Tree Classifier.
Expand All @@ -62,7 +71,15 @@ def BuildModel(self, data, labels):
# Create and train the classifier.
dtc = DecisionTreeClassifier(criterion=self.criterion,
max_depth=self.max_depth,
random_state=self.seed)
random_state=self.seed,
splitter = self.splitter,
min_samples_split = self.min_samples_split,
min_weight_fraction_leaf=self.min_weight_fraction_leaf,
max_features = self.max_features,
max_leaf_nodes = self.max_leaf_nodes,
min_impurity_split = self.min_impurity_split,
class_weight = self.class_weight,
presort = self.presort)
dtc.fit(data, labels)
return dtc

Expand All @@ -85,9 +102,18 @@ def RunDTCScikit(q):
c = re.search("-c (\s+)", options)
d = re.search("-d (\s+)", options)
s = re.search("-s (\d+)", options)

mss = re.search("--min_samples_split (\d+)", options)
msl = re.search("--min_samples_leaf (\d+)", options)
mf = re.search("--max_features (\d+)", options)
mln = re.search("--max_leaf_nodes (\d+)", options)

self.criterion = 'gini' if not c else str(c.group(1))
self.max_depth = None if not d else int(d.group(1))
self.splitter = 'best' if not s else str(s.group(1))
self.min_samples_split = 2 if not mss else int(mss.group(1))
self.min_samples_leaf = 1 if not msl else int(msl.group(1))
self.max_features = None if not mf else int(mf.group(1))
self.max_leaf_nodes = None if not mln else int(mln.group(1))
self.seed = 0 if not s else int(s.group(1))

try:
Expand Down Expand Up @@ -148,3 +174,4 @@ def RunMetrics(self, options):
metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictedlabels)

return metrics

0 comments on commit 09148bf

Please sign in to comment.