forked from facebookresearch/CompilerGym
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo_without_bazel.py
114 lines (91 loc) · 3.41 KB
/
demo_without_bazel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This script demonstrates how the Python example service without needing
to use the bazel build system. Usage:
$ python example_compiler_gym_service/demo_without_bazel.py
It is equivalent in behavior to the demo.py script in this directory.
"""
import logging
from pathlib import Path
from typing import Iterable
import gym
from compiler_gym.datasets import Benchmark, Dataset
from compiler_gym.spaces import Reward
from compiler_gym.util.registration import register
from compiler_gym.util.runfiles_path import site_data_path
# Location of the example service script. The path is relative, so it is
# resolved against the current working directory of the process.
EXAMPLE_PY_SERVICE_BINARY: Path = Path(
    "example_compiler_gym_service/service_py/example_service.py"
)
# Fail fast with an explicit exception instead of `assert`: assertions are
# stripped when Python runs with -O, which would let a missing script slip
# past this check and surface later as a harder-to-diagnose failure.
if not EXAMPLE_PY_SERVICE_BINARY.is_file():
    raise FileNotFoundError(
        f"Service script not found: {EXAMPLE_PY_SERVICE_BINARY}"
    )
class RuntimeReward(Reward):
    """Example reward driven by the "runtime" observation.

    Each update is rewarded with the amount by which the observed runtime
    dropped since the previous update: a lower runtime than before gives a
    positive reward, a higher one gives a negative reward.
    """

    def __init__(self):
        super().__init__(
            id="runtime",
            observation_spaces=["runtime"],
            default_value=0,
            default_negates_returns=True,
            deterministic=False,
            platform_dependent=True,
        )
        # Runtime seen on the previous update; None until the first update
        # that follows construction or a reset.
        self.previous_runtime = None

    def reset(self, benchmark: str, observation_view):
        """Forget the previously observed runtime."""
        del benchmark  # unused
        self.previous_runtime = None

    def update(self, action, observations, observation_view):
        """Return the decrease in runtime relative to the previous update.

        Only the first entry of ``observations`` is read. When there is no
        earlier runtime to compare against, the current observation is used
        as its own baseline, so the difference is zero.
        """
        del action
        del observation_view
        current_runtime = observations[0]
        if self.previous_runtime is None:
            self.previous_runtime = current_runtime
        delta = float(self.previous_runtime - current_runtime)
        # Remember this observation as the baseline for the next update.
        self.previous_runtime = current_runtime
        return delta
class ExampleDataset(Dataset):
    """A small in-memory dataset that exposes two example benchmarks."""

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not used; the
        # dataset metadata below is fixed.
        super().__init__(
            name="benchmark://example-v0",
            license="MIT",
            description="An example dataset",
            site_data_base=site_data_path("example_dataset"),
        )
        # Map each benchmark URI to a Benchmark built from placeholder bytes.
        self._benchmarks = {
            uri: Benchmark.from_file_contents(uri, "Ir data".encode("utf-8"))
            for uri in (
                "benchmark://example-v0/foo",
                "benchmark://example-v0/bar",
            )
        }

    def benchmark_uris(self) -> Iterable[str]:
        """Yield the URI of every benchmark held by this dataset."""
        yield from self._benchmarks

    def benchmark(self, uri: str) -> Benchmark:
        """Return the benchmark stored under ``uri``.

        Raises:
            LookupError: If ``uri`` is not one of this dataset's benchmarks.
        """
        if uri not in self._benchmarks:
            raise LookupError("Unknown program name")
        return self._benchmarks[uri]
# Register the environment under the id "example-v0" so that it can be
# created with gym.make("example-v0"). The environment is backed by the
# example service script and is configured with the example reward and
# dataset.
register(
    id="example-v0",
    entry_point="compiler_gym.envs:CompilerEnv",
    kwargs=dict(
        service=EXAMPLE_PY_SERVICE_BINARY,
        rewards=[RuntimeReward()],
        datasets=[ExampleDataset()],
    ),
)
def main():
    """Take 20 randomly sampled steps in the "example-v0" environment.

    The environment is reset once up front and again whenever a step
    reports that the episode is done.
    """
    # Debug verbosity prints out extra logging information.
    logging.basicConfig(level=logging.DEBUG)

    # The environment is created through the regular gym.make(...) interface
    # and used as a context manager for the duration of the loop.
    with gym.make("example-v0") as env:
        env.reset()
        for _step in range(20):
            action = env.action_space.sample()
            # Only the "done" flag is used; the other step results are
            # unpacked and ignored.
            _observation, _reward, episode_done, _info = env.step(action)
            if episode_done:
                env.reset()


if __name__ == "__main__":
    main()