-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcontroller_onnx.py
39 lines (29 loc) · 1.25 KB
/
controller_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import argparse
import msgpack
import onnxruntime
import zmq
from game.common_controller import create_input, output_to_keys
from game.communication import parse_readings_message
def main():
    """Run the ONNX-driven car controller until interrupted.

    Parses the command line, subscribes to the ``<car>_car`` topic on
    ``tcp://localhost:6000``, feeds every received readings message through
    the ONNX model given by ``--onnx``, and publishes the resulting keys as a
    msgpack payload on port 6001 (red car) or 6002 (blue car).

    Command-line options:
        --car   Which car to control: ``"red"`` (default) or ``"blue"``.
        --onnx  Path to the ONNX model file (required).

    Ctrl+C stops the loop cleanly. Both sockets are always closed with zero
    linger, so unsent messages cannot block interpreter shutdown.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--car", default="red", choices=["red", "blue"])
    parser.add_argument("--onnx", required=True)
    args = parser.parse_args()

    # Readings for this car arrive prefixed with this topic. It is used both
    # as the SUB filter and to strip the prefix before parsing.
    topic: bytes = args.car.encode() + b"_car"

    context = zmq.Context.instance()
    subscriber: zmq.Socket = context.socket(zmq.SUB)
    publisher: zmq.Socket = context.socket(zmq.PUB)
    try:
        # CONFLATE keeps only the newest inbound message, so the controller
        # always acts on the latest readings instead of a backlog. It is set
        # before connect() so it applies to this connection.
        subscriber.setsockopt(zmq.CONFLATE, 1)
        subscriber.connect("tcp://localhost:6000")
        subscriber.setsockopt(zmq.SUBSCRIBE, topic)

        # Each car publishes its key presses on its own port.
        publisher.bind(f"tcp://localhost:{6001 if args.car == 'red' else 6002}")

        onnx_session = onnxruntime.InferenceSession(args.onnx)
        # NOTE(review): assumes the model's first input is (batch, features)
        # with a fixed feature dimension at index 1 — confirm per exported model.
        input_size = onnx_session.get_inputs()[0].shape[1]

        while True:
            # The SUBSCRIBE prefix filter guarantees the topic is at the start.
            readings = parse_readings_message(subscriber.recv()[len(topic) :])
            onnx_input = create_input(readings, input_size)
            # NOTE(review): tensor names "input"/"output_0" are hard-coded and
            # must match the exported model's graph names.
            onnx_output = onnx_session.run(["output_0"], {"input": onnx_input})[0]
            keys = output_to_keys(onnx_output, logits=False)
            publisher.send(msgpack.packb(keys))
    except KeyboardInterrupt:
        pass  # Ctrl+C is the normal way to stop the controller.
    finally:
        # linger=0 discards any unsent messages so shutdown cannot hang on them.
        subscriber.close(linger=0)
        publisher.close(linger=0)
if __name__ == "__main__":
    # Start the control loop only when run as a script. Importing this module
    # must not parse argv, open sockets, or load a model.
    main()