-
Notifications
You must be signed in to change notification settings - Fork 2
/
Random.h
118 lines (109 loc) · 2.83 KB
/
Random.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
//
// Created by zebang.zhzb on 2018/7/14.
//
#ifndef MTREE_RANDOM_H
#define MTREE_RANDOM_H
/*
* This Random implementation is copied from GitHub: https://github.com/Microsoft/LightGBM/blob/master/include/LightGBM/utils/random.h
*/
#include <cmath>
#include <cstdint>
#include <random>
#include <set>
#include <vector>
namespace common {
// NOTE(review): a using-directive in a header leaks every std name into all
// includers. It is kept here only because code elsewhere may already rely on it.
using namespace std;
/*!
* \brief Small, fast pseudo-random generator built on a 32-bit linear
*        congruential recurrence (x = 214013 * x + 2531011, modulo 2^32).
*
* The generator is deterministic for a given seed. Instances carry mutable
* state and perform no synchronization of their own.
*/
class Random {
public:
  /*!
  * \brief Constructor, with random seed.
  *
  * Draws the initial state from std::random_device through a Mersenne
  * Twister, uniformly in [0, 123456789] (the default value of the state).
  */
  Random() {
    std::random_device rd;
    auto generator = std::mt19937(rd());
    std::uniform_int_distribution<int> distribution(0, static_cast<int>(x));
    x = distribution(generator);
  }
  /*!
  * \brief Constructor, with specific seed
  * \param seed Initial generator state; equal seeds give equal sequences.
  *
  * Left non-explicit so that call sites relying on the implicit conversion
  * from int keep compiling.
  */
  Random(int seed) {
    x = seed;
  }
  /*!
  * \brief Generate random integer from a 15-bit draw, i.e. the raw value
  *        is in [0, 32767].
  * \param lower_bound lower bound (inclusive)
  * \param upper_bound upper bound (exclusive)
  * \return The random integer between [lower_bound, upper_bound). Ranges
  *         wider than 32768 cannot be covered completely. When
  *         lower_bound == upper_bound the range is empty and lower_bound is
  *         returned without advancing the generator.
  */
  inline int NextShort(int lower_bound, int upper_bound) {
    const int range = upper_bound - lower_bound;
    if (range == 0) {
      // Modulo by zero is undefined behaviour; fall back to the only bound.
      return lower_bound;
    }
    return RandInt16() % range + lower_bound;
  }
  /*!
  * \brief Generate random integer from a 31-bit draw
  * \param lower_bound lower bound (inclusive)
  * \param upper_bound upper bound (exclusive)
  * \return The random integer between [lower_bound, upper_bound). When
  *         lower_bound == upper_bound the range is empty and lower_bound is
  *         returned without advancing the generator.
  */
  inline int NextInt(int lower_bound, int upper_bound) {
    const int range = upper_bound - lower_bound;
    if (range == 0) {
      // Modulo by zero is undefined behaviour; fall back to the only bound.
      return lower_bound;
    }
    return RandInt32() % range + lower_bound;
  }
  /*!
  * \brief Generate random float data
  * \return The random float between [0.0, 1.0), with 15 bits of resolution
  */
  inline float NextFloat() {
    // A 15-bit draw divided by 2^15 is always strictly below 1.0.
    return static_cast<float>(RandInt16()) / (32768.0f);
  }
  /*!
  * \brief Sample K data from {0,1,...,N-1} without replacement
  * \param N Size of the population
  * \param K Number of elements to draw
  * \return K sampled values from {0,1,...,N-1} in ascending order; an empty
  *         vector when K <= 0 or K > N
  */
  inline std::vector<int> Sample(int N, int K) {
    std::vector<int> ret;
    if (K > N || K <= 0) {
      return ret;
    }
    // Reserve only after validation: a negative K would convert to a huge
    // size_t and make reserve() throw std::length_error.
    ret.reserve(static_cast<std::size_t>(K));
    if (K == N) {
      for (int i = 0; i < N; ++i) {
        ret.push_back(i);
      }
    } else if (K > 1 && K > (N / std::log2(K))) {
      // Dense case: one pass over the population, keeping index i with
      // probability (still needed) / (still available). Once those two
      // counts are equal the probability is 1, so exactly K are selected.
      for (int i = 0; i < N; ++i) {
        const int needed = K - static_cast<int>(ret.size());
        const double prob = static_cast<double>(needed) / static_cast<double>(N - i);
        if (NextFloat() < prob) {
          ret.push_back(i);
        }
      }
    } else {
      // Sparse case: draw until K distinct values were seen. std::set drops
      // duplicates and keeps the values ordered.
      std::set<int> sample_set;
      while (static_cast<int>(sample_set.size()) < K) {
        sample_set.insert(RandInt32() % N);
      }
      ret.assign(sample_set.begin(), sample_set.end());
    }
    return ret;
  }
private:
  /*! \brief Advance the state and return bits 16..30 of it, in [0, 32767]. */
  inline int RandInt16() {
    x = (214013 * x + 2531011);
    return static_cast<int>((x >> 16) & 0x7FFF);
  }
  /*! \brief Advance the state and return its low 31 bits. */
  inline int RandInt32() {
    x = (214013 * x + 2531011);
    return static_cast<int>(x & 0x7FFFFFFF);
  }
  /*! \brief Generator state; unsigned so the recurrence wraps modulo 2^32. */
  unsigned int x = 123456789;
};
}
#endif //MTREE_RANDOM_H