-
Notifications
You must be signed in to change notification settings - Fork 2
/
tree.cpp
147 lines (138 loc) · 4.66 KB
/
tree.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
//
// Created by squall on 18-6-11.
//
#include <queue>
#include <iostream>
#include "tree.h"
int Tree::train(Dataset const &dataset) {
/*
* 传到这里以后树的gradient应该已经被设定好了。
*/
vector<int> sample_index;
for (int i = 0; i < dataset.get_label_data().size(); ++i) {
sample_index.push_back(i);
}
cout << "begin generate root node" << endl;
root = new Node(sample_index, NULL, this->objective, this->min_sample_num);
root->calc_node_score(dataset.get_gradients(), lambda);
if (dataset.get_task_num() > 1) {
root->calc_node_scores(dataset, lambda);
root->find_split_point_common(dataset, lambda, beta, this->regularization);
} else {
root->find_split_point(dataset, lambda);
}
root->set_is_leaf(false);
cout << "end generate root node" << endl;
queue<Node *> node_queue;
if (root->get_left_node() != NULL)
node_queue.push(root->get_left_node());
if (root->get_right_node() != NULL)
node_queue.push(root->get_right_node());
for (int i = 1; i < this->max_depth; ++i) {
int queue_size = node_queue.size();
for (int j = 0; j < queue_size; j++) {
// cout<<"max_depth "<<i<<" queue "<<j<<endl;
Node *tmp = node_queue.front();
if (tmp != NULL && tmp->get_sample_size() > min_sample_num) {
tmp->calc_node_score(dataset.get_gradients(), lambda);
// not null and the sample size is greater than min_sample_num
if (dataset.get_task_num() > 1) {
tmp->calc_node_scores(dataset, lambda);
tmp->find_split_point_common(dataset, lambda, beta, this->regularization);
} else {
tmp->find_split_point(dataset, lambda);
}
if (tmp->get_right_node() == NULL && tmp->get_left_node() == NULL) {
tmp->set_is_leaf(true);
tmp->calc_node_weight(dataset.get_gradients(), lambda);
if (this->learning_rate > 0) {
tmp->set_weight(tmp->get_weight() * (this->learning_rate));
}
} else {
node_queue.push(tmp->get_left_node());
node_queue.push(tmp->get_right_node());
}
} else {
if (tmp != NULL) {
tmp->calc_node_score(dataset.get_gradients(), lambda);
if (dataset.get_task_num() > 1) {
tmp->calc_node_scores(dataset, lambda);
}
tmp->calc_node_weight(dataset.get_gradients(), lambda);
if (this->learning_rate > 0) {
tmp->set_weight(tmp->get_weight() * (this->learning_rate));
}
tmp->set_is_leaf(true);
}
}
if (tmp != NULL) {
nodes.push_back(tmp);
}
node_queue.pop();
}
}
// calculate the rest of leaf node score.
int node_size = node_queue.size();
cout << "This is rest of node size " << node_size << endl;
for (int i = 0; i < node_size; ++i) {
Node *tmp = node_queue.front();
// cout << "this is leaf sample size: " << tmp->get_sample_size() << endl;
tmp->calc_node_score(dataset.get_gradients(), lambda);
// cout << dataset.get_task_num() << endl;
if (dataset.get_task_num() > 1) {
tmp->calc_node_scores(dataset, lambda);
}
tmp->calc_node_weight(dataset.get_gradients(), lambda);
tmp->set_is_leaf(true);
if (this->learning_rate > 0) {
// cout << "this is weight: " << tmp->get_weight() << endl;
tmp->set_weight(tmp->get_weight() * (this->learning_rate));
// cout << "leaf node weight : " << tmp->get_weight() << endl;
}
node_queue.pop();
}
#ifdef DEBUG
cout << "train tree over" << endl;
for(int i = 0; i < nodes.size(); i++) {
if (nodes[i] != NULL) {
if (nodes[i]->get_is_leaf()) {
cout << nodes[i]->get_weight() << endl;
}
}
}
#endif
return 0;
}
/*
 * Routes every sample of `dataset` from the root down to a leaf and appends
 * that leaf's weight to `pred`, one value per sample, in sample order.
 *
 * At each internal node the sample's value for the node's feature
 * (data[feature_index][sample]) is compared with the node's cut point:
 * values >= cut_point go right, smaller values go left.
 *
 * Returns SUCCESS, or TREE_PREDICT_ERROR in these cases:
 *   - the dataset is empty or the tree has not been trained (root is NULL);
 *   - a traversal reaches a missing (NULL) child.
 * In the second case, `pred` is restored to the size it had on entry, so no
 * partial batch of predictions is left behind.
 */
int Tree::predict(const Dataset &dataset, vector<float> &pred) {
    if (dataset.get_data_size() == 0 || root == NULL) {
        return TREE_PREDICT_ERROR;
    }
    cout << "go into the predict" << endl;
    const Matrix &data = dataset.get_data();
    const size_t original_size = pred.size();
    for (int i = 0; i < dataset.get_data_size(); ++i) {
        Node *current = root;
        // Stop on a leaf, or on a NULL child (an incomplete split); the NULL
        // case previously dereferenced a null pointer.
        while (current != NULL && !current->get_is_leaf()) {
            int index = current->get_feature_index();
            float cut_point = current->get_cut_point();
#ifdef DEBUG
            cout << "this is index of the feature:" << index << endl;
            cout << "this is cut_point:" << cut_point << endl;
            cout << "data is : " << data[index][i] << endl;
#endif
            if (data[index][i] >= cut_point) {
                current = current->get_right_node();
            } else {
                current = current->get_left_node();
            }
#ifdef DEBUG
            if (current == NULL) {
                cout << "this is break" << endl;
            }
#endif
        }
        if (current == NULL) {
            pred.resize(original_size);
            return TREE_PREDICT_ERROR;
        }
        pred.push_back(current->get_weight());
    }
    return SUCCESS;
}