-
Notifications
You must be signed in to change notification settings - Fork 130
/
Copy pathRandomSearch.cc
190 lines (158 loc) · 6.05 KB
/
RandomSearch.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
// A random search for LLVM codesize using the C++ API.
//
// While not intended for the majority of users, it is entirely straightforward
// to skip the Python frontend and interact with the C++ API directly. This file
// demonstrates a simple parallelized random search implemented for the LLVM
// compiler service.
//
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the LICENSE file
// in the root directory of this source tree.
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <stdlib.h>
#include <time.h>

#include <algorithm>
#include <boost/filesystem.hpp>
#include <exception>
#include <iostream>
#include <limits>
#include <magic_enum.hpp>
#include <random>
#include <thread>
#include <vector>

#include "compiler_gym/envs/llvm/service/LlvmSession.h"
#include "compiler_gym/envs/llvm/service/ObservationSpaces.h"
#include "compiler_gym/service/runtime/CompilerGymService.h"
#include "compiler_gym/util/GrpcStatusMacros.h"
// Command-line flags (gflags) that configure the search.

// URI of the benchmark that every search thread starts its session on.
DEFINE_string(benchmark, "benchmark://cbench-v1/crc32", "The benchmark to use.");
// Number of random actions applied by each search thread.
DEFINE_int32(step_count, 100, "The number of steps to run for each random search");
// Number of concurrent search threads. The default is the value reported by
// std::thread::hardware_concurrency(), raised to 1 because that function is
// allowed to return 0 when the value cannot be determined.
DEFINE_int32(nproc, std::max(1u, std::thread::hardware_concurrency()),
"The number of parallel search threads to use");
namespace fs = boost::filesystem;
namespace compiler_gym {
using grpc::Status;
using llvm_service::LlvmAction;
using llvm_service::LlvmObservationSpace;
using LlvmService = runtime::CompilerGymService<llvm_service::LlvmSession>;
// A wrapper around an LLVM service. Here, we call the RPC endpoints directly
// on the service, we do not use RPC. This means that we do not get the
// reliability benefits of running the compiler service in a separate process,
// but we also do not pay the performance overhead.
//
// The template parameter selects the single observation space that is
// requested on every reset() and step().
template <LlvmObservationSpace observationSpace>
class Environment {
 public:
  // Construct an environment whose service uses `workingDir` and whose
  // sessions are started on `benchmark` (a benchmark URI).
  Environment(const fs::path& workingDir, const std::string& benchmark)
      : service_(workingDir), benchmark_(benchmark) {}

  // Reset the environment and compute the initial observation.
  //
  // Ends the current session first if one is open, starts a new session on
  // the benchmark, then issues an action-free step to fetch the observation,
  // which is written to `*observation`. Returns the first non-OK status
  // reported by the service, or OK.
  [[nodiscard]] Status reset(Event* observation) {
    if (inEpisode_) {
      RETURN_IF_ERROR(close());
    }

    StartSessionRequest startRequest;
    startRequest.mutable_benchmark()->set_uri(benchmark_);
    StartSessionReply startReply;
    RETURN_IF_ERROR(service_.StartSession(nullptr, &startRequest, &startReply));
    sessionId_ = startReply.session_id();
    // The session exists from this point on. Record that before the first
    // step so that, should that step fail, the next reset() still ends the
    // session instead of abandoning it.
    inEpisode_ = true;

    StepRequest stepRequest;
    StepReply stepReply;
    stepRequest.set_session_id(sessionId_);
    stepRequest.add_observation_space(static_cast<int>(observationSpace));
    RETURN_IF_ERROR(service_.Step(nullptr, &stepRequest, &stepReply));
    // Exactly one observation space was requested, so exactly one is expected.
    CHECK(stepReply.observation_size() == 1);
    *observation = stepReply.observation(0);
    return Status::OK;
  }

  // End the current session.
  //
  // The environment is marked as "not in an episode" regardless of whether
  // the service reports success.
  [[nodiscard]] Status close() {
    EndSessionRequest endRequest;
    EndSessionReply endReply;
    endRequest.set_session_id(sessionId_);
    inEpisode_ = false;
    return service_.EndSession(nullptr, &endRequest, &endReply);
  }

  // Apply the given action and compute an observation.
  //
  // The observation for the templated observation space is written to
  // `*observation`. Must be called after a successful reset().
  [[nodiscard]] Status step(LlvmAction action, Event* observation) {
    StepRequest request;
    StepReply reply;
    request.set_session_id(sessionId_);
    request.add_action()->set_int64_value(static_cast<int>(action));
    request.add_observation_space(static_cast<int>(observationSpace));
    RETURN_IF_ERROR(service_.Step(nullptr, &request, &reply));
    CHECK(reply.observation_size() == 1);
    *observation = reply.observation(0);
    return Status::OK;
  }

 private:
  LlvmService service_;
  const std::string benchmark_;
  // Whether a session started by reset() has not been ended yet. Must start
  // out false: reset() reads it before anything else writes it.
  bool inEpisode_ = false;
  // Identifier of the most recently started session.
  int64_t sessionId_ = 0;
};
// Run a single random search of FLAGS_step_count steps on FLAGS_benchmark.
//
// `workingDir` is handed to the environment's service. On return, `*bestCost`
// holds the lowest instruction count observed (starting from the cost after
// reset) and `*bestActions` holds the action sequence that first reached it.
// `*bestActions` is left untouched when no step improves on the initial cost.
// Returns the first non-OK status from the environment, or OK.
Status runSearch(const fs::path& workingDir, std::vector<int>* bestActions, int64_t* bestCost) {
  Environment<LlvmObservationSpace::IR_INSTRUCTION_COUNT> environment(workingDir, FLAGS_benchmark);

  // Reset the environment; the initial observation is the cost to beat.
  Event observation;
  RETURN_IF_ERROR(environment.reset(&observation));
  *bestCost = observation.int64_value();

  // This function runs concurrently on several threads. Use a generator that
  // is local to the call and seeded from the system entropy source: rand() is
  // not required to be thread-safe, and seeding every thread with time(NULL)
  // gives threads started within the same second identical seeds.
  std::mt19937 generator{std::random_device{}()};
  std::uniform_int_distribution<int> pickAction(
      0, static_cast<int>(magic_enum::enum_count<LlvmAction>()) - 1);

  // Run a bunch of actions randomly.
  std::vector<int> actions;
  if (FLAGS_step_count > 0) {
    actions.reserve(static_cast<std::size_t>(FLAGS_step_count));
  }
  for (int i = 0; i < FLAGS_step_count; ++i) {
    const int action = pickAction(generator);
    actions.push_back(action);

    Event obs;
    RETURN_IF_ERROR(environment.step(static_cast<LlvmAction>(action), &obs));
    const int64_t cost = obs.int64_value();
    if (cost < *bestCost) {
      // Snapshot the whole prefix: replaying it reproduces this cost.
      *bestCost = cost;
      *bestActions = actions;
    }
    VLOG(3) << "Step " << action << " " << cost << " " << *bestCost;
  }

  RETURN_IF_ERROR(environment.close());
  return Status::OK;
}
void runThread(std::vector<int>* bestActions, int64_t* bestCost) {
const fs::path workingDir = fs::unique_path();
fs::create_directories(workingDir);
const auto status = runSearch(workingDir, bestActions, bestCost);
if (!status.ok()) {
LOG(ERROR) << "ERROR " << status.error_code() << ": " << status.error_message();
}
fs::remove_all(workingDir);
}
// Run `numThreads` random searches concurrently.
//
// Each search runs runThread() on its own std::thread and writes into its own
// slot of `actions` / `costs`, so no synchronization beyond join() is needed.
// Once all threads have finished, the lowest cost and the action sequence that
// produced it are printed to stdout (the first thread wins ties).
//
// Always returns OK. NOTE(review): runThread() only logs failures, so a thread
// that failed before recording a cost leaves its slot at the int64_t maximum
// sentinel; if every thread fails, that sentinel is what gets printed.
Status runRandomSearches(unsigned numThreads) {
  if (numThreads == 0) {
    // Nothing to run, and indexing the (empty) result vectors below would be
    // out of bounds.
    std::cerr << "No search threads requested, nothing to do" << std::endl;
    return Status::OK;
  }

  std::cout << "Starting " << numThreads << " random search threads for benchmark "
            << FLAGS_benchmark << std::endl;

  // Sized up-front and never resized, so the element addresses handed to the
  // threads stay valid.
  std::vector<std::vector<int>> actions(numThreads);
  std::vector<int64_t> costs(numThreads, std::numeric_limits<int64_t>::max());

  std::vector<std::thread> threads;
  threads.reserve(numThreads);
  for (unsigned i = 0; i < numThreads; ++i) {
    threads.emplace_back(runThread, &actions[i], &costs[i]);
  }
  for (auto& thread : threads) {
    thread.join();
  }

  // std::min_element returns the first smallest element, matching a strict
  // "<" scan from the front.
  const auto best = std::min_element(costs.begin(), costs.end());
  const auto bestThread = static_cast<std::size_t>(best - costs.begin());

  std::cout << "Lowest cost achieved: " << *best << std::endl;
  std::cout << "Actions: ";
  for (const auto action : actions[bestThread]) {
    std::cout << action << " ";
  }
  std::cout << std::endl;

  return Status::OK;
}
} // namespace compiler_gym
// Program entry point: parse flags, initialize logging, and run the searches.
// Aborts (via CHECK) if --nproc is not positive or if the search reports a
// non-OK status.
int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/false);
  google::InitGoogleLogging(argv[0]);

  // FLAGS_nproc is a signed 32-bit flag, while runRandomSearches() takes an
  // unsigned count: reject zero and negative values rather than letting a
  // negative number wrap around to a huge thread count.
  CHECK(FLAGS_nproc > 0) << "--nproc must be a positive integer, got " << FLAGS_nproc;

  CHECK(compiler_gym::runRandomSearches(static_cast<unsigned>(FLAGS_nproc)).ok());
  return 0;
}