-
Notifications
You must be signed in to change notification settings - Fork 2
/
tree.h
47 lines (43 loc) · 1.13 KB
/
tree.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
//
// Created by squall on 18-6-11.
//
#ifndef MTREE_TREE_H
#define MTREE_TREE_H
#include <string>
#include <utility>
#include <vector>

#include "Node.h"
#include "LinearLoss.h"
class Tree {
public:
Tree(int max_depth, float lambda,
float beta,
int feature_size,
int min_sample_num,
float learning_rate,
string regularization,
Loss *loss,
Updater *objective) : max_depth(max_depth),
lambda(lambda),
beta(beta),
feature_size(feature_size),
min_sample_num(min_sample_num),
learning_rate(learning_rate),
regularization(regularization) {
root = NULL;
this->loss = loss;
this->objective = objective;
}
int train(Dataset const &dataset);
int predict(const Dataset &dataset, vector<float> &pred);
private:
Node *root;
float lambda;
float beta; // regularization coefficient
int max_depth;
Loss *loss;
int feature_size;
Updater *objective;
int min_sample_num;
float learning_rate;
string regularization;
vector<Node *> nodes;
};
#endif //MTREE_TREE_H