-
Notifications
You must be signed in to change notification settings - Fork 6
/
jax_wrappers.py
283 lines (218 loc) · 10.2 KB
/
jax_wrappers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
"""This is based on Gymnax's wrappers.py, but modified to work with our multi-agent environments."""
from functools import partial
from typing import Tuple
import chex
import jax
import jax.numpy as jnp
from flax import struct
from crazy_rl.multi_agent.jax.base_parallel_env import State
class Wrapper:
    """Base class for wrappers.

    Stores the wrapped environment in ``self._env`` and forwards any attribute that is
    not found on the wrapper itself to that environment.
    """

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        # __getattr__ only runs after normal attribute lookup fails. If `_env` itself is
        # missing (e.g. an instance created via __new__ by copy.copy or pickle before its
        # __dict__ is restored), forwarding would call __getattr__("_env") again and
        # recurse until RecursionError. Fail with a regular AttributeError instead.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)
class VecEnv(Wrapper):
    """Vectorize an environment over a leading batch axis with ``jax.vmap``.

    ``reset``, ``step`` and ``state`` become instance attributes holding vmapped versions
    of the wrapped environment's methods, mapped over axis 0 of every positional argument.
    """

    def __init__(self, env):
        super().__init__(env)
        base = self._env
        # Instance attributes take precedence over any class-level or forwarded lookup,
        # so callers always get the batched versions.
        self.reset = jax.vmap(base.reset, in_axes=(0,))
        self.step = jax.vmap(base.step)
        self.state = jax.vmap(base.state, in_axes=(0,))
@struct.dataclass
class LogEnvState:
    """Pytree state of ``LogWrapper``: the wrapped env's state plus episode statistics."""

    # State of the wrapped environment.
    env_state: State
    # Running team return of the current episode; LogWrapper.reset creates it as jnp.zeros(reward_dim).
    episode_returns: chex.Array
    # Number of steps taken in the current episode (a Python int 0 after reset, an array after step).
    episode_lengths: int
    # Return of the most recently finished episode (same shape as episode_returns).
    returned_episode_returns: chex.Array
    # Length of the most recently finished episode.
    returned_episode_lengths: int
    # Steps taken since the last LogWrapper.reset; step itself never zeroes it.
    timestep: int
    # Global step counter; LogWrapper.reset takes it as an argument so it can be carried across resets.
    total_timestep: int
class LogWrapper(Wrapper):
    """Track per-episode team returns and lengths and report them through ``info``.

    Args:
        env: the environment to wrap.
        reward_dim: length of the reward vector; sizes the return accumulators.
    """

    def __init__(self, env, reward_dim=1):
        self.reward_dim = reward_dim
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey, total_timestep: int = 0) -> Tuple[chex.Array, dict, LogEnvState]:
        """Reset the wrapped env and zero every episode statistic.

        ``total_timestep`` is stored as-is so a global step count can survive resets.
        """
        obs, info, inner_state = self._env.reset(key)
        log_state = LogEnvState(
            env_state=inner_state,
            episode_returns=jnp.zeros(self.reward_dim),
            episode_lengths=0,
            returned_episode_returns=jnp.zeros(self.reward_dim),
            returned_episode_lengths=0,
            timestep=0,
            total_timestep=total_timestep,
        )
        return obs, info, log_state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: LogEnvState,
        action: chex.Array,
        key: chex.PRNGKey,
    ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, dict, LogEnvState]:
        """Step the wrapped env and update the episode statistics.

        On the step where any agent terminates or is truncated, the running return and
        length are copied into the ``returned_*`` fields and then zeroed.
        """
        obs, rewards, terminateds, truncateds, info, inner_state = self._env.step(state.env_state, action, key)
        done = jnp.any(terminateds) | jnp.any(truncateds)
        keep = 1 - done  # 1 while the episode continues, 0 on its final step

        # rewards are summed over agents "team reward"
        running_return = state.episode_returns + rewards.sum(axis=0)
        running_length = state.episode_lengths + 1

        log_state = LogEnvState(
            env_state=inner_state,
            episode_returns=running_return * keep,
            episode_lengths=running_length * keep,
            returned_episode_returns=state.returned_episode_returns * keep + running_return * done,
            returned_episode_lengths=state.returned_episode_lengths * keep + running_length * done,
            timestep=state.timestep + 1,
            total_timestep=state.total_timestep + 1,
        )
        info.update(
            {
                "returned_episode_returns": log_state.returned_episode_returns,
                "returned_episode_lengths": log_state.returned_episode_lengths,
                "timestep": log_state.timestep,
                "total_timestep": log_state.total_timestep,
                "returned_episode": done,
            }
        )
        return obs, rewards, terminateds, truncateds, info, log_state

    def state(self, state: LogEnvState) -> chex.Array:
        """Return the wrapped environment's global state."""
        inner_state = state.env_state
        return self._env.state(inner_state)
class AddIDToObs(Wrapper):
    """Add agent id to observation as one hot encoding.

    Args:
        env: the environment to wrap.
        num_agents: length of the one-hot id vector appended to each observation row.
    """

    def __init__(self, env, num_agents):
        super().__init__(env)
        self.num_agents = num_agents

    def _add_id(self, obs: jnp.ndarray) -> jnp.ndarray:
        """Append the one-hot agent id to each row of ``obs``.

        Row ``i`` gets ``eye(num_agents)[i]`` appended, so the result has
        ``obs.shape[1] + num_agents`` columns. ``obs`` must be 2-D; this assumes the layout
        is (agents, features) — TODO confirm against the wrapped env. Rows beyond
        ``num_agents`` receive an all-zero id.
        """
        # The rows of an (n_rows, num_agents) identity are exactly the one-hot ids 0..n_rows-1.
        # Building them in one op avoids tracing a Python loop with one concat per agent.
        agent_ids = jnp.eye(obs.shape[0], self.num_agents)
        return jnp.concatenate([obs, agent_ids], axis=-1)

    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, dict, State]:
        """Reset the wrapped env and append agent ids to the initial observations."""
        obs, info, state = self._env.reset(key)
        return self._add_id(obs), info, state

    def step(
        self, state: State, action: jnp.ndarray, key: chex.PRNGKey
    ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, dict, State]:
        """Step the wrapped env and append agent ids to the new observations."""
        obs, rewards, term, trunc, info, state = self._env.step(state, action, key)
        return self._add_id(obs), rewards, term, trunc, info, state

    def state(self, state: State) -> chex.Array:
        """Return the wrapped environment's global state (no id is added)."""
        return self._env.state(state)
class AutoReset(Wrapper):
    """Automatically reset the environment when done.

    Based on Brax's wrapper; https://github.com/google/brax/blob/main/brax/envs/wrappers/training.py#L96C1-L123C65
    """

    def __init__(self, env):
        super().__init__(env)

    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, dict, State]:
        """Reset the wrapped environment."""
        return self._env.reset(key)

    def step(
        self, state: State, action: jnp.ndarray, key: chex.PRNGKey
    ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, dict, State]:
        """Step the wrapped env; if the episode ended, swap in a freshly reset obs and state.

        A reset is computed on every call and selected with ``jnp.where`` only when ``done``
        is true. ``rewards``, ``term``, ``trunc`` and ``info`` always come from the step
        that was just taken.

        Args:
            state: state of the wrapped environment.
            action: actions for all agents.
            key: PRNG key. It is split into independent keys for the step and the reset.

        Returns:
            (obs, rewards, term, trunc, info, state), with obs/state replaced where done.
        """
        # Split rather than reuse `key`: feeding one key to both step and reset would
        # correlate the step's randomness with the next episode's initial state.
        step_key, reset_key = jax.random.split(key)
        obs, rewards, term, trunc, info, state = self._env.step(state, action, step_key)

        # jnp.any without an axis reduces over every element, so `done` is a 0-d bool:
        # the episode ends if any agent terminates or is truncated.
        # NOTE(review): if `term`/`trunc` carried a batch axis (wrapper applied outside
        # VecEnv), this would reset every env as soon as one is done.
        done = jnp.logical_or(jnp.any(term), jnp.any(trunc))

        # LogWrapper.reset accepts the running total_timestep so it survives the reset.
        if isinstance(self._env, LogWrapper):
            new_obs, _, new_state = self._env.reset(reset_key, state.total_timestep)
        else:
            new_obs, _, new_state = self._env.reset(reset_key)

        def where_done(if_val, else_val):
            # `done` is 0-d, so it broadcasts against every leaf.
            return jnp.where(done, if_val, else_val)

        obs = where_done(new_obs, obs)
        state = jax.tree_util.tree_map(where_done, new_state, state)
        # TODO: expose the pre-reset observation as info["final_obs"]; does not work with VecEnv yet.
        return obs, rewards, term, trunc, info, state

    def state(self, state: State) -> chex.Array:
        """Return the wrapped environment's global state."""
        return self._env.state(state)
@struct.dataclass
class NormalizeVecRewEnvState:
    """Pytree state of ``NormalizeVecReward``: running return statistics plus the wrapped env's state."""

    # Running mean of the discounted team return (NormalizeVecReward.reset starts it at 0.0).
    mean: jnp.ndarray
    # Running variance of the discounted team return (reset starts it at 1.0).
    var: jnp.ndarray
    # Sample count behind mean/var; reset starts it at 1e-4, so the initial mean/var get
    # almost no weight in the first merge.
    count: float
    # Per-environment discounted team return; reset creates it as jnp.zeros((batch_count,)).
    return_val: jnp.ndarray
    # State of the wrapped environment (batched along axis 0).
    env_state: State
class NormalizeVecReward(Wrapper):
    """Normalize the reward over a vectorized environment.

    Rewards are divided by the running standard deviation of the discounted team return,
    estimated across the batch.
    Taken and adapted from https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/wrappers.py

    Args:
        env: vectorized environment (observations/rewards batched along axis 0).
        gamma: discount factor used for the running return.
    """

    def __init__(self, env, gamma):
        super().__init__(env)
        self.gamma = gamma

    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, dict, NormalizeVecRewEnvState]:
        """Reset the wrapped env and initialize the running statistics (mean 0, var 1)."""
        obs, info, state = self._env.reset(key)
        batch_count = obs.shape[0]
        state = NormalizeVecRewEnvState(
            mean=0.0,
            var=1.0,
            count=1e-4,
            return_val=jnp.zeros((batch_count,)),
            env_state=state,
        )
        return obs, info, state

    @staticmethod
    def _merge_moments(mean, var, count, batch_mean, batch_var, batch_count):
        """Fold a batch's (mean, var, count) into running moments (parallel variance update).

        Returns:
            (new_mean, new_var, new_count).
        """
        delta = batch_mean - mean
        tot_count = count + batch_count
        new_mean = mean + delta * batch_count / tot_count
        m_a = var * count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * count * batch_count / tot_count
        return new_mean, M2 / tot_count, tot_count

    def step(
        self, state: NormalizeVecRewEnvState, action: chex.Array, key: chex.PRNGKey
    ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, dict, NormalizeVecRewEnvState]:
        """Step the batched env and divide rewards by the running std of the discounted return.

        Args:
            state: the ``NormalizeVecRewEnvState`` returned by ``reset`` or a previous ``step``.
            action: batched actions, forwarded unchanged.
            key: PRNG key, forwarded unchanged.

        Returns:
            (obs, normalized_reward, term, truncated, info, new_state).
        """
        obs, reward, term, truncated, info, env_state = self._env.step(state.env_state, action, key)
        # Per-env episode end; assumes term/truncated are (batch, agents) — TODO confirm.
        done = jnp.logical_or(jnp.any(term, axis=1), jnp.any(truncated, axis=1))
        # The accumulated return of envs that finished this step is dropped before this
        # step's team reward is added (same convention as purejaxrl).
        return_val = state.return_val * self.gamma * (1 - done) + reward.sum(axis=1)  # team reward

        new_mean, new_var, new_count = self._merge_moments(
            state.mean,
            state.var,
            state.count,
            jnp.mean(return_val, axis=0),
            jnp.var(return_val, axis=0),
            obs.shape[0],
        )
        state = NormalizeVecRewEnvState(
            mean=new_mean,
            var=new_var,
            count=new_count,
            return_val=return_val,
            env_state=env_state,
        )
        # 1e-8 keeps the division finite when the variance estimate collapses to zero.
        return obs, reward / jnp.sqrt(state.var + 1e-8), term, truncated, info, state

    def state(self, state: NormalizeVecRewEnvState) -> chex.Array:
        """Return the wrapped environment's global state."""
        return self._env.state(state.env_state)
class NormalizeObservation(Wrapper):
    """Rescale the observation between low and high (min-max normalization).

    The source bounds are read once from ``observation_space(0)``, so agent 0's bounds are
    applied to every observation.

    Args:
        env: the environment to wrap.
        low: lower end of the target range.
        high: upper end of the target range.
    """

    def __init__(self, env, low=-1, high=1):
        super().__init__(env)
        self.max_obs = self._env.observation_space(0).high
        self.min_obs = self._env.observation_space(0).low
        self.low = low
        self.high = high

    def _normalize(self, x):
        """Map ``x`` linearly from [min_obs, max_obs] to [low, high].

        Where ``max_obs == min_obs`` the span would be zero and the division would yield
        inf/nan. The span is replaced by 1 there, giving ``low + (x - min_obs) * (high - low)``.
        Infinite (unbounded) bounds are not handled.
        """
        span = self.max_obs - self.min_obs
        span = jnp.where(span == 0, 1, span)
        return self.low + (x - self.min_obs) * (self.high - self.low) / span

    def reset(self, key):
        """Reset the wrapped env and normalize the initial observations."""
        obs, info, state = self._env.reset(key)
        return self._normalize(obs), info, state

    def step(self, state, action, key):
        """Step the wrapped env and normalize the new observations."""
        obs, reward, term, truncated, info, state = self._env.step(state, action, key)
        return self._normalize(obs), reward, term, truncated, info, state

    def state(self, state: State) -> chex.Array:
        """Return the wrapped env's global state, normalized with the per-agent obs bounds."""
        # NOTE(review): the global state is scaled with agent 0's observation bounds. This only
        # broadcasts correctly if its trailing dims match the per-agent observation space — confirm.
        return self._normalize(self._env.state(state))
class ClipActions(Wrapper):
    """Clamp actions into the action-space bounds before forwarding them.

    The bounds come from ``action_space(0)`` and are applied to the whole action array.
    """

    def __init__(self, env):
        super().__init__(env)

    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, dict, State]:
        """Forward the reset unchanged."""
        return self._env.reset(key)

    def step(
        self, state: State, action: jnp.ndarray, key: chex.PRNGKey
    ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, dict, State]:
        """Clip ``action`` to the [low, high] bounds of the action space, then step the wrapped env."""
        space = self._env.action_space(0)
        lower, upper = space.low, space.high
        clipped = jnp.clip(action, lower, upper)
        return self._env.step(state, clipped, key)

    def state(self, state: State) -> chex.Array:
        """Return the wrapped environment's global state unchanged."""
        return self._env.state(state)
class LinearizeReward(Wrapper):
    """Scalarize a multi-objective (MO) reward by a dot product with fixed weights.

    Args:
        env: the environment to wrap.
        weights: weight per objective, combined as ``jnp.dot(rewards, weights)``.
    """

    def __init__(self, env, weights: jnp.ndarray):
        self.weights = weights
        super().__init__(env)

    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, dict, State]:
        """Forward the reset unchanged."""
        return self._env.reset(key)

    def step(
        self, state: State, action: jnp.ndarray, key: chex.PRNGKey
    ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, dict, State]:
        """Step the wrapped env and replace its reward with the weighted sum over objectives."""
        transition = self._env.step(state, action, key)
        obs, mo_rewards, term, truncated, info, next_state = transition
        scalar_rewards = jnp.dot(mo_rewards, self.weights)
        return obs, scalar_rewards, term, truncated, info, next_state

    def state(self, state: State) -> chex.Array:
        """Return the wrapped environment's global state unchanged."""
        return self._env.state(state)