-
Notifications
You must be signed in to change notification settings - Fork 2
/
Booster.h
65 lines (60 loc) · 1.82 KB
/
Booster.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
//
// Created by squall on 18-6-12.
//
#ifndef MTREE_BOOSTER_H
#define MTREE_BOOSTER_H
#include <ctime>
#include <string>
#include <utility>
#include <vector>

#include "tree.h"
template<typename LOSS, typename UPDATER>
class Booster {
public:
Booster(int max_num_round,
int common_num_round,
int max_depth,
float lambda,
float beta,
int min_sample_leaf,
float learning_rate,
string regularization) {
this->max_num_round = max_num_round;
this->common_num_round = common_num_round;
this->max_depth = max_depth;
this->lambda = lambda;
this->beta = beta;
this->min_sample_leaf = min_sample_leaf;
this->learning_rate = learning_rate;
this->regularization = regularization;
}
int train(Dataset &dataset,
Dataset &eval_set,
string eval_metric,
int early_stopping_rounds,
bool verbose);
int predict(Dataset &dataset, vector<float> &score, const string &log_path);
// 对每个task单独预测
int single_predict(const Dataset &dataset, vector<float> &pred, const int &task_id, float &loss_score);
// calculate loss score
int calculate_loss_score(const vector<float> &label,
const vector<float> &pred,
const string &eval_metric,
const int &task_id,
float &loss_score);
private:
int max_num_round;
int common_num_round;
int max_depth;
float lambda;
float beta;
int min_sample_leaf;
float learning_rate;
string regularization;
vector<int> single_num_rounds;
string eval_metric;
int early_stopping_rounds;
bool verbose;
vector<Tree *> common_trees;
vector<vector<Tree *>> single_trees;
vector<pair<int, float> > commmon_best_iterations;
vector<pair<int, float> > single_best_iterations;
};
#endif //MTREE_BOOSTER_H