-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathutils.py
122 lines (94 loc) · 3.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import queue
import subprocess
import time
from multiprocessing import Queue
from threading import Thread
import numpy as np
import tensorflow as tf
def rewards_to_discounted_returns(rewards, discount_factor):
    """Convert a sequence of per-step rewards into discounted returns.

    The returns satisfy ``returns[-1] = rewards[-1]`` and
    ``returns[i] = rewards[i] + discount_factor * returns[i + 1]``.

    Args:
        rewards: sequence/array of rewards (assumed 1-D — TODO confirm callers).
        discount_factor: multiplier applied to the next step's return.

    Returns:
        A float32 numpy array shaped like ``rewards``. An empty ``rewards``
        yields an empty array instead of raising IndexError.
    """
    returns = np.zeros_like(rewards, dtype=np.float32)
    if len(rewards) == 0:
        # Nothing to discount; the original indexing of [-1] would raise here.
        return returns
    returns[-1] = rewards[-1]
    # Walk backwards so returns[i + 1] is already filled (and float32-rounded)
    # when returns[i] is computed.
    for i in reversed(range(len(rewards) - 1)):
        returns[i] = rewards[i] + discount_factor * returns[i + 1]
    return returns
def get_git_rev():
    """Return the short hash of the current git HEAD, or ``'unkrev'`` if unavailable.

    Runs ``git rev-parse --short HEAD`` in the current working directory.
    Falls back to ``'unkrev'`` both when git runs but fails (e.g. not inside a
    repository) and when the git executable cannot be launched at all.
    """
    cmd = ['git', 'rev-parse', '--short', 'HEAD']
    try:
        output = subprocess.check_output(cmd, stderr=subprocess.PIPE)
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git exited non-zero.
        # OSError (incl. FileNotFoundError): git is not installed / not on PATH.
        return 'unkrev'
    return output.decode().rstrip()
class MemoryProfiler:
    """Record memory usage of process ``pid`` to ``log_path`` from a background thread.

    Call ``start()`` to launch the sampling thread and ``stop()`` to end it.
    Sampling is delegated to ``memory_profiler.memory_usage`` (imported lazily
    inside ``profile()``), with child processes included.
    """

    # Command placed on ``cmd_queue`` to ask the sampling thread to exit.
    STOP_CMD = 0

    def __init__(self, pid, log_path):
        self.pid = pid
        self.log_path = log_path
        self.cmd_queue = Queue()
        # Sampling thread; None while not running.
        self.t = None

    def start(self):
        """Launch the sampling thread, which runs ``profile()``."""
        self.t = Thread(target=self.profile)
        self.t.start()

    def stop(self):
        """Ask the sampling thread to exit and wait for it.

        Does nothing if the profiler is not running. The thread only checks for
        the stop command between sampling rounds, so this presumably blocks
        until the current ``memory_usage`` call returns.
        """
        if self.t is None:
            return
        self.cmd_queue.put(self.STOP_CMD)
        self.t.join()
        # Forget the finished thread so a repeated stop() does not enqueue a
        # stale STOP_CMD that would end a later start() after one round.
        self.t = None

    def profile(self):
        """Thread body: write memory samples to ``log_path`` until STOP_CMD arrives."""
        import memory_profiler
        # The with-block closes the log file even if sampling raises.
        with open(self.log_path, 'w+') as f:
            while True:
                # 5 samples, 1 second apart
                memory_profiler.memory_usage(self.pid, stream=f, timeout=5, interval=1,
                                             include_children=True)
                f.flush()
                try:
                    cmd = self.cmd_queue.get(timeout=0.1)
                except queue.Empty:
                    continue
                if cmd == self.STOP_CMD:
                    break
class Timer:
    """
    A simple timer class.
    * Set the timer duration with the `duration_seconds` argument to the constructor.
    * Start the timer by calling `reset()`.
    * Check whether the timer is done by calling `done()`.

    Elapsed time is measured with `time.monotonic()`, so adjustments to the
    system wall clock cannot make the timer fire early or late.
    """

    def __init__(self, duration_seconds):
        self.duration_seconds = duration_seconds
        # Monotonic timestamp of the last reset(); None until the timer is started.
        self.start_time = None

    def reset(self):
        """(Re)start the timer from the current moment."""
        self.start_time = time.monotonic()

    def done(self):
        """Return True once strictly more than `duration_seconds` have elapsed since reset().

        Raises:
            RuntimeError: if reset() has not been called yet.
        """
        if self.start_time is None:
            raise RuntimeError('Timer.done() called before Timer.reset()')
        return time.monotonic() - self.start_time > self.duration_seconds
class TensorFlowCounter:
    """
    Counter implemented as a TensorFlow variable in the provided session's graph.
    Useful if you want the value to feed into some other operation, e.g. learning rate calculation.

    Uses the TF1 graph API (tf.Variable, tf.placeholder, Session.run).
    NOTE(review): the variable must be initialized (e.g. by a variables
    initializer run by the caller) before it is read or incremented — confirm.
    """

    def __init__(self, sess):
        self.sess = sess
        # Build the ops inside the session's own graph rather than whatever graph
        # happens to be the default, so sess.run() can always evaluate them.
        with sess.graph.as_default():
            self.value = tf.Variable(0, trainable=False)
            # Amount to add, fed in at increment() time.
            self.increment_by = tf.placeholder(tf.int32)
            self.increment_op = self.value.assign_add(self.increment_by)

    def __int__(self):
        """Return the counter's current value, evaluated in the session."""
        return int(self.sess.run(self.value))

    def increment(self, n=1):
        """Add `n` to the counter by running the assign-add op."""
        self.sess.run(self.increment_op, feed_dict={self.increment_by: n})
class RateMeasure:
    """Measure the per-second rate of change of a value between successive calls.

    Call ``reset(val)`` once to record a starting point, then ``measure(val)``
    repeatedly. Each ``measure()`` returns the change since the previous call
    divided by the seconds elapsed, timed with ``time.monotonic()``.
    """

    def __init__(self):
        # Value and monotonic timestamp from the latest reset()/measure() call.
        self.prev_t = self.prev_value = None

    def reset(self, val):
        """Record `val` and the current time as the reference point."""
        self.prev_value = val
        self.prev_t = time.monotonic()

    def measure(self, val):
        """Return (val - previous value) / elapsed seconds, then make `val` the new reference.

        Raises:
            RuntimeError: if reset() has not been called yet.
            ZeroDivisionError: possible when no clock time has elapsed since the
                previous call and the values are plain Python numbers.
        """
        if self.prev_t is None:
            raise RuntimeError('RateMeasure.measure() called before RateMeasure.reset()')
        val_change = val - self.prev_value
        cur_t = time.monotonic()
        rate = val_change / (cur_t - self.prev_t)
        self.prev_t = cur_t
        self.prev_value = val
        return rate