-
Notifications
You must be signed in to change notification settings - Fork 5
/
hoeffding_tree.go
255 lines (217 loc) · 8.68 KB
/
hoeffding_tree.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
package mlpack
/*
#cgo CFLAGS: -I./capi -Wall
#cgo LDFLAGS: -L. -lmlpack_go_hoeffding_tree
#include <capi/hoeffding_tree.h>
#include <stdlib.h>
*/
import "C"
import "gonum.org/v1/gonum/mat"
// HoeffdingTreeOptionalParam holds the optional input parameters accepted by
// HoeffdingTree().  Use HoeffdingTreeOptions() to obtain an instance with the
// default value of every field already filled in.
type HoeffdingTreeOptionalParam struct {
	// BatchMode, if true, considers samples in batch instead of as a stream.
	BatchMode bool
	// Bins is the number of bins for each numeric split when the 'domingos'
	// split strategy is used.
	Bins int
	// Confidence is the confidence before splitting (between 0 and 1).
	Confidence float64
	// InfoGain, if set, uses information gain instead of Gini impurity for
	// calculating Hoeffding bounds.
	InfoGain bool
	// InputModel is an input trained Hoeffding tree model.
	InputModel *hoeffdingTreeModel
	// Labels holds the labels for the training dataset.
	Labels *mat.Dense
	// MaxSamples is the maximum number of samples before splitting.
	MaxSamples int
	// MinSamples is the minimum number of samples before splitting.
	MinSamples int
	// NumericSplitStrategy is the splitting strategy to use for numeric
	// features: 'domingos' or 'binary'.
	NumericSplitStrategy string
	// ObservationsBeforeBinning is the number of samples observed before
	// binning is performed when the 'domingos' split strategy is used.
	ObservationsBeforeBinning int
	// Passes is the number of passes to take over the dataset.
	Passes int
	// Test is the testing dataset (may be categorical).
	Test *matrixWithInfo
	// TestLabels holds the labels of the test data.
	TestLabels *mat.Dense
	// Training is the training dataset (may be categorical).
	Training *matrixWithInfo
	// Verbose displays informational messages and the full list of
	// parameters and timers at the end of execution.
	Verbose bool
}
// HoeffdingTreeOptions returns a freshly allocated HoeffdingTreeOptionalParam
// with every optional parameter set to its default value.  Callers override
// the fields they care about and pass the result to HoeffdingTree().
func HoeffdingTreeOptions() *HoeffdingTreeOptionalParam {
	opts := new(HoeffdingTreeOptionalParam)

	// The boolean flags (BatchMode, InfoGain, Verbose) default to false and
	// the pointer fields (InputModel, Labels, Test, TestLabels, Training)
	// default to nil, which are exactly the zero values new() provides, so
	// only the non-zero defaults need to be assigned here.
	opts.Bins = 10
	opts.Confidence = 0.95
	opts.MaxSamples = 5000
	opts.MinSamples = 100
	opts.NumericSplitStrategy = "binary"
	opts.ObservationsBeforeBinning = 100
	opts.Passes = 1

	return opts
}
/*
This program implements Hoeffding trees, a form of streaming decision tree
suited best for large (or streaming) datasets. This program supports both
categorical and numeric data. Given an input dataset, this program is able to
train the tree with numerous training options, and save the model to a file.
The program is also able to use a trained model or a model from file in order
to predict classes for a given test set.
The training file and associated labels are specified with the "Training" and
"Labels" parameters, respectively. Optionally, if "Labels" is not specified,
the labels are assumed to be the last dimension of the training dataset.
The training may be performed in batch mode (like a typical decision tree
algorithm) by specifying the "BatchMode" option, but this may not be the best
option for large datasets.
When a model is trained, it may be saved via the "OutputModel" output
parameter. A model may be loaded from file for further training or testing
with the "InputModel" parameter.
Test data may be specified with the "Test" parameter, and if performance
statistics are desired for that test set, labels may be specified with the
"TestLabels" parameter. Predictions for each test point may be saved with the
"Predictions" output parameter, and class probabilities for each prediction
may be saved with the "Probabilities" output parameter.
For example, to train a Hoeffding tree with confidence 0.99 with data dataset,
saving the trained tree to tree, the following command may be used:
// Initialize optional parameters for HoeffdingTree().
param := mlpack.HoeffdingTreeOptions()
param.Training = dataset
param.Confidence = 0.99
tree, _, _ := mlpack.HoeffdingTree(param)
Then, this tree may be used to make predictions on the test set test_set,
saving the predictions into predictions and the class probabilities into
class_probs with the following command:
// Initialize optional parameters for HoeffdingTree().
param := mlpack.HoeffdingTreeOptions()
param.InputModel = &tree
param.Test = test_set
_, predictions, class_probs := mlpack.HoeffdingTree(param)
Input parameters:
- BatchMode (bool): If true, samples will be considered in batch
instead of as a stream. This generally results in better trees but at
the cost of memory usage and runtime.
- Bins (int): If the 'domingos' split strategy is used, this specifies
the number of bins for each numeric split. Default value 10.
- Confidence (float64): Confidence before splitting (between 0 and 1).
Default value 0.95.
- InfoGain (bool): If set, information gain is used instead of Gini
impurity for calculating Hoeffding bounds.
- InputModel (hoeffdingTreeModel): Input trained Hoeffding tree model.
- Labels (mat.Dense): Labels for training dataset.
- MaxSamples (int): Maximum number of samples before splitting.
Default value 5000.
- MinSamples (int): Minimum number of samples before splitting.
Default value 100.
- NumericSplitStrategy (string): The splitting strategy to use for
numeric features: 'domingos' or 'binary'. Default value 'binary'.
- ObservationsBeforeBinning (int): If the 'domingos' split strategy is
used, this specifies the number of samples observed before binning is
performed. Default value 100.
- Passes (int): Number of passes to take over the dataset. Default
value 1.
- Test (matrixWithInfo): Testing dataset (may be categorical).
- TestLabels (mat.Dense): Labels of test data.
- Training (matrixWithInfo): Training dataset (may be categorical).
- Verbose (bool): Display informational messages and the full list of
parameters and timers at the end of execution.
Output parameters:
- outputModel (hoeffdingTreeModel): Output for trained Hoeffding tree
model.
- predictions (mat.Dense): Matrix to output label predictions for test
data into.
- probabilities (mat.Dense): In addition to predicting labels, provide
prediction probabilities in this matrix.
*/
// HoeffdingTree runs the mlpack "hoeffding_tree" binding with the options in
// param and returns, in order, the output model, the predictions and the
// class probabilities read back from the binding.
//
// An option is forwarded to the binding only when its value differs from the
// default assigned by HoeffdingTreeOptions().  A nil param is treated as
// HoeffdingTreeOptions(), i.e. no optional parameter is passed.
func HoeffdingTree(param *HoeffdingTreeOptionalParam) (hoeffdingTreeModel, *mat.Dense, *mat.Dense) {
	// A nil options pointer would otherwise panic on the first field access
	// below; fall back to the defaults instead.
	if param == nil {
		param = HoeffdingTreeOptions()
	}

	params := getParams("hoeffding_tree")
	timers := getTimers()
	// Clean memory on every exit path, including a panic raised by one of the
	// conversion helpers below.  The order matches the original explicit
	// calls: params first, then timers.
	defer func() {
		cleanParams(params)
		cleanTimers(timers)
	}()

	disableBacktrace()
	disableVerbose()

	// Detect if the parameter was passed; set if so.
	if param.BatchMode {
		setParamBool(params, "batch_mode", param.BatchMode)
		setPassed(params, "batch_mode")
	}

	// Detect if the parameter was passed; set if so.
	if param.Bins != 10 {
		setParamInt(params, "bins", param.Bins)
		setPassed(params, "bins")
	}

	// Detect if the parameter was passed; set if so.
	if param.Confidence != 0.95 {
		setParamDouble(params, "confidence", param.Confidence)
		setPassed(params, "confidence")
	}

	// Detect if the parameter was passed; set if so.
	if param.InfoGain {
		setParamBool(params, "info_gain", param.InfoGain)
		setPassed(params, "info_gain")
	}

	// Detect if the parameter was passed; set if so.
	if param.InputModel != nil {
		setHoeffdingTreeModel(params, "input_model", param.InputModel)
		setPassed(params, "input_model")
	}

	// Detect if the parameter was passed; set if so.
	if param.Labels != nil {
		gonumToArmaUrow(params, "labels", param.Labels)
		setPassed(params, "labels")
	}

	// Detect if the parameter was passed; set if so.
	if param.MaxSamples != 5000 {
		setParamInt(params, "max_samples", param.MaxSamples)
		setPassed(params, "max_samples")
	}

	// Detect if the parameter was passed; set if so.
	if param.MinSamples != 100 {
		setParamInt(params, "min_samples", param.MinSamples)
		setPassed(params, "min_samples")
	}

	// Detect if the parameter was passed; set if so.
	if param.NumericSplitStrategy != "binary" {
		setParamString(params, "numeric_split_strategy", param.NumericSplitStrategy)
		setPassed(params, "numeric_split_strategy")
	}

	// Detect if the parameter was passed; set if so.
	if param.ObservationsBeforeBinning != 100 {
		setParamInt(params, "observations_before_binning", param.ObservationsBeforeBinning)
		setPassed(params, "observations_before_binning")
	}

	// Detect if the parameter was passed; set if so.
	if param.Passes != 1 {
		setParamInt(params, "passes", param.Passes)
		setPassed(params, "passes")
	}

	// Detect if the parameter was passed; set if so.
	if param.Test != nil {
		gonumToArmaMatWithInfo(params, "test", param.Test)
		setPassed(params, "test")
	}

	// Detect if the parameter was passed; set if so.
	if param.TestLabels != nil {
		gonumToArmaUrow(params, "test_labels", param.TestLabels)
		setPassed(params, "test_labels")
	}

	// Detect if the parameter was passed; set if so.
	if param.Training != nil {
		gonumToArmaMatWithInfo(params, "training", param.Training)
		setPassed(params, "training")
	}

	// Detect if the parameter was passed; set if so.  Verbose output was
	// switched off unconditionally above, so it is re-enabled here.
	if param.Verbose {
		setParamBool(params, "verbose", param.Verbose)
		setPassed(params, "verbose")
		enableVerbose()
	}

	// Mark all output options as passed.
	setPassed(params, "output_model")
	setPassed(params, "predictions")
	setPassed(params, "probabilities")

	// Call the mlpack program.
	C.mlpackHoeffdingTree(params.mem, timers.mem)

	// Initialize result variables and get output.
	var outputModel hoeffdingTreeModel
	outputModel.getHoeffdingTreeModel(params, "output_model")
	var predictionsPtr mlpackArma
	predictions := predictionsPtr.armaToGonumUrow(params, "predictions")
	var probabilitiesPtr mlpackArma
	probabilities := probabilitiesPtr.armaToGonumMat(params, "probabilities")

	// Return output(s); the deferred clean-up runs after the results above
	// have been captured as return values.
	return outputModel, predictions, probabilities
}