client.py
# type: ignore
import asyncio
from redis import Redis
from uuid import UUID
from typing import Optional, Union, List
from ai.util import gen_uuid
from ai.util.timer import Timer
from ai.infer.common import QUEUES, encode, decode

# Seconds to sleep between Redis polls while waiting for a response.
SLEEP = 0.1

class InferencerUnresponsive(Exception):
    '''Raised by InferenceClient.ping() after `timeout` seconds without a response.'''

class InferenceClient:
'''Client of an inference worker.
    The __call__ method can be used as if it were the model itself;
    calling it sends a request to the worker. For example:
```
z = model(x, y)
```
is the same as
```
z = inference_client(x, y)
```
    NOTE: an InferenceClient should only be created by an Inferencer.

    NOTE: InferenceClient sets up the Redis client JIT because a Redis client
    can't be pickled. So don't use the InferenceClient until after it's been
    sent to the data worker.

    NOTE: if you have multiple requests, it's much faster to use the multi_infer
    method instead of using __call__ multiple times in succession.
'''
def __init__(s, redis_cfg: Union[list, tuple]):
'''
redis_cfg : list/tuple of length 3
redis host (str), redis port (int), db number (int)
'''
assert len(redis_cfg) == 3
s._redis_cfg = redis_cfg
s._broker = None # redis client initialized JIT because of pickle
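    # Construction sketch (hypothetical values; in practice the Inferencer
    # builds this client and hands it to the data worker):
    #
    #   client = InferenceClient(('localhost', 6379, 0))  # host, port, db
    #   client.ping()  # raises InferencerUnresponsive if no worker is up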
def __call__(s, *args, **kwargs):
'''Run inference request.'''
return s.infer(*args, **kwargs)
def ping(s, timeout: int = 5):
        '''Raise InferencerUnresponsive if the inferencer isn't running.
timeout : int
timeout in seconds
'''
req_id = gen_uuid()
s._add_to_queue(QUEUES.ping, req_id)
resp = s.wait_for_resp(req_id, timeout=timeout)
if resp != 'pong':
raise InferencerUnresponsive()
def debug(s) -> dict:
'''Fetch debug information from the worker.'''
req_id = gen_uuid()
s._add_to_queue(QUEUES.debug, req_id)
return s.wait_for_resp(req_id)
def update_params(s, params: dict):
'''Update the parameters of the model.
params : dict
model.state_dict()
'''
s._add_to_queue(QUEUES.update, params)
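    # Sketch of a training-loop hand-off (assumes `model` is the same
    # architecture the worker is serving):
    #
    #   client.update_params(model.state_dict())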
def infer(s, *args, **kwargs):
'''Run inference request.'''
req_id = s.infer_async(*args, **kwargs)
return s.wait_for_resp(req_id)
def multi_infer(s, reqs):
'''Bundle multiple inference requests.
reqs : list of tuples
each item is a tuple of (args, kwargs) where args is a list of
arguments and kwargs is a dict of keyword arguments
'''
req_ids = s.multi_infer_async(reqs)
return s.wait_for_resps(req_ids)
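    # Batched usage sketch (x1, x2 and the `temperature` kwarg are hypothetical
    # model inputs); all requests are queued up front and their responses are
    # awaited concurrently:
    #
    #   reqs = [((x1,), {}), ((x2,), {'temperature': 0.5})]
    #   z1, z2 = client.multi_infer(reqs)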
def infer_async(s, *args, **kwargs) -> UUID:
        '''Send an inference request and return the request id.

        Use InferenceClient.wait_for_resp(request_id) to get the response.
'''
req_id = gen_uuid()
s._add_to_queue(QUEUES.infer, (req_id, args, kwargs))
return req_id
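    # Fire-and-collect sketch (hypothetical input `x`); other work can happen
    # between submitting the request and reading the result:
    #
    #   req_id = client.infer_async(x)
    #   ...  # do other work
    #   z = client.wait_for_resp(req_id)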
def multi_infer_async(s, reqs: List[tuple]) -> List[UUID]:
'''Same as infer_async but for multiple requests.
        Use InferenceClient.wait_for_resps(request_ids) to get the responses.
reqs : list of tuples
each item is a tuple of (args, kwargs) where args is a list of
arguments and kwargs is a dict of keyword arguments
'''
req_ids = []
for args, kwargs in reqs:
req_id = gen_uuid()
s._add_to_queue(QUEUES.infer, (req_id, args, kwargs))
req_ids.append(req_id)
return req_ids
def wait_for_resp(s, req_id: UUID, timeout: Optional[int] = None):
'''Wait for a response to be available for the given request id.
req_id : uuid
the request id returned from infer_async
        timeout : int or None
optional timeout in seconds (returns None on timeout)
'''
resp = asyncio.run(_async_wait_for_resp(s._broker, req_id, timeout))
return decode(resp)
def wait_for_resps(s, req_ids: List[UUID], timeout: Optional[int] = None):
'''Same as wait_for_resp but for multiple requests.'''
resps = asyncio.run(_async_wait_for_resps(s._broker, req_ids, timeout))
return [decode(resp) for resp in resps]
def _add_to_queue(s, queue, req):
if s._broker is None:
s._setup_broker()
s._broker.rpush(queue, encode(req))
def _setup_broker(s):
host, port, db = s._redis_cfg
s._broker = Redis(host=host, port=port, db=db)

async def _async_wait_for_resps(broker, keys, timeout=None):
    # Wait on all request ids concurrently; results are returned in the same
    # order as `keys`.
    return await asyncio.gather(*[
        _async_wait_for_resp(broker, key, timeout) for key in keys])

async def _async_wait_for_resp(broker, key, timeout=None):
    # Poll Redis for the response stored under the request id, yielding to the
    # event loop between polls so concurrent waits can interleave.
    # Returns None if `timeout` seconds elapse without a response.
    timer = Timer(timeout)
    while True:
        resp = broker.get(key)
        if resp is not None or timer():
            break
        await asyncio.sleep(SLEEP)
    return resp
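A minimal end-to-end sketch of how this client might be driven, assuming a Redis server on localhost:6379, an Inferencer worker already consuming the queues, and an import path of `ai.infer.client` (the inputs `x` and `y` are placeholders for whatever the served model expects; in practice the Inferencer constructs the client for you):
```
from ai.infer.client import InferenceClient  # assumed module path

client = InferenceClient(('localhost', 6379, 0))  # host, port, db
client.ping(timeout=5)  # raises InferencerUnresponsive if no worker responds

# single request (synchronous from the caller's point of view)
z = client(x, y)

# several requests: queue them all, then wait for every response at once
req_ids = client.multi_infer_async([((x,), {}), ((y,), {})])
outs = client.wait_for_resps(req_ids)
```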