diff --git a/methods/scikit/dtc.py b/methods/scikit/dtc.py index 46eb69f..ea5b2c1 100644 --- a/methods/scikit/dtc.py +++ b/methods/scikit/dtc.py @@ -50,7 +50,16 @@ def __init__(self, dataset, timeout=0, verbose=True): self.criterion = 'gini' self.max_depth = None self.seed = 0 - + self.splitter = 'best' + self.max_depth = None + self.min_samples_split = 2 + self.min_samples_leaf = 1 + self.min_weight_fraction_leaf = 0.0 + self.max_features = None + self.random_state = None + self.min_impurity_split = 1e-07 + self.class_weight = None + self.presort = False ''' Build the model for the Decision Tree Classifier. @@ -62,7 +71,15 @@ def BuildModel(self, data, labels): # Create and train the classifier. dtc = DecisionTreeClassifier(criterion=self.criterion, max_depth=self.max_depth, - random_state=self.seed) + random_state=self.seed, + splitter = self.splitter, + min_samples_split = self.min_samples_split, + min_weight_fraction_leaf=self.min_weight_fraction_leaf, + max_features = self.max_features, + max_leaf_nodes = self.max_leaf_nodes, + min_impurity_split = self.min_impurity_split, + class_weight = self.class_weight, + presort = self.presort) dtc.fit(data, labels) return dtc @@ -85,9 +102,18 @@ def RunDTCScikit(q): c = re.search("-c (\s+)", options) d = re.search("-d (\s+)", options) s = re.search("-s (\d+)", options) - + mss = re.search("--min_samples_split (\d+)", options) + msl = re.search("--min_samples_leaf (\d+)", options) + mf = re.search("--max_features (\d+)", options) + mln = re.search("--max_leaf_nodes (\d+)", options) + self.criterion = 'gini' if not c else str(c.group(1)) self.max_depth = None if not d else int(d.group(1)) + self.splitter = 'best' if not s else str(s.group(1)) + self.min_samples_split = 2 if not mss else int(mss.group(1)) + self.min_samples_leaf = 1 if not msl else int(msl.group(1)) + self.max_features = None if not mf else int(mf.group(1)) + self.max_leaf_nodes = None if not mln else int(mln.group(1)) self.seed = 0 if not s else int(s.group(1)) try: @@ -148,3 +174,4 @@ def RunMetrics(self, options): metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictedlabels) return metrics + diff --git a/return_version.py b/return_version.py new file mode 100644 index 0000000..defd322 --- /dev/null +++ b/return_version.py @@ -0,0 +1,22 @@ +import yaml + +class VERSION_INFO(object): + + def ret_version_info(): + + stream = open("config.yaml", "r") + docs = yaml.load_all(stream) + count=1 + ans = {} + for doc in docs: + if count==1: + ans = doc + else: + break + count+=1 + versions = ans['settings']['version'] + libraries = ans['settings']['libraries'] + lib_ver = {} + for i in range(len(versions)): + lib_ver[libraries[i]] = versions[i] + return lib_ver