"""
This file is derived from https://github.com/shelhamer/ourl/envs.py
Originally written by Evan Shelhamer and modified by Deepak Pathak
"""
from __future__ import print_function
import numpy as np
from collections import deque
from PIL import Image
from gym.spaces.box import Box
import gym


class BufferedObsEnv(gym.ObservationWrapper):
"""Buffer observations and stack e.g. for frame skipping.
n is the length of the buffer, and number of observations stacked.
skip is the number of steps between buffered observations (min=1).
n.b. first obs is the oldest, last obs is the newest.
the buffer is zeroed out on reset.
*must* call reset() for init!
"""
def __init__(self, env=None, n=4, skip=4, shape=(84, 84),
channel_last=False):
super(BufferedObsEnv, self).__init__(env)
self.obs_shape = shape
# most recent raw observations (for max pooling across time steps)
self.obs_buffer = deque(maxlen=2)
self.n = n
self.skip = skip
self.buffer = deque(maxlen=self.n)
self.counter = 0 # init and reset should agree on this
shape = shape + (n,) if channel_last else (n,) + shape
        self.observation_space = Box(0.0, 255.0, shape)
        self.ch_axis = -1 if channel_last else 0
        self.scale = 1.0 / 255
        # observations are rescaled to [0, 1] in _observation/_reset below
        self.observation_space.high[...] = 1.0

    def _step(self, action):
obs, reward, done, info = self.env.step(action)
return self._observation(obs), reward, done, info

    def _observation(self, obs):
        obs = self._convert(obs)
        self.counter += 1
        # only every `skip`-th frame is pushed onto the stacking buffer
        if self.counter % self.skip == 0:
            self.buffer.append(obs)
obsNew = np.stack(self.buffer, axis=self.ch_axis)
return obsNew.astype(np.float32) * self.scale

    def _reset(self):
        """Clear buffer and re-fill with zero frames plus the first observation."""
self.obs_buffer.clear()
obs = self._convert(self.env.reset())
self.buffer.clear()
self.counter = 0
for _ in range(self.n - 1):
self.buffer.append(np.zeros_like(obs))
self.buffer.append(obs)
obsNew = np.stack(self.buffer, axis=self.ch_axis)
return obsNew.astype(np.float32) * self.scale

    def _convert(self, obs):
        # max-pool over the two most recent raw frames (removes flicker),
        # then convert to grayscale and resize to obs_shape
self.obs_buffer.append(obs)
max_frame = np.max(np.stack(self.obs_buffer), axis=0)
intensity_frame = self._rgb2y(max_frame).astype(np.uint8)
small_frame = np.array(Image.fromarray(intensity_frame).resize(
self.obs_shape, resample=Image.BILINEAR), dtype=np.uint8)
return small_frame

    def _rgb2y(self, im):
"""Converts an RGB image to a Y image (as in YUV).
These coefficients are taken from the torch/image library.
Beware: these are more critical than you might think, as the
monochromatic contrast can be surprisingly low.
"""
if len(im.shape) < 3:
return im
return np.sum(im * [0.299, 0.587, 0.114], axis=2)
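
# A minimal usage sketch for BufferedObsEnv; the Atari env id below is an
# illustrative assumption, not something this module depends on:
#
#   env = BufferedObsEnv(gym.make('PongDeterministic-v0'), n=4, skip=4)
#   obs = env.reset()                      # float32 array of shape (4, 84, 84)
#   obs, rew, done, info = env.step(env.action_space.sample())
#   assert 0.0 <= obs.min() and obs.max() <= 1.0   # scaled by 1/255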


class NoNegativeRewardEnv(gym.RewardWrapper):
"""Clip reward in negative direction."""
def __init__(self, env=None, neg_clip=0.0):
super(NoNegativeRewardEnv, self).__init__(env)
self.neg_clip = neg_clip

    def _reward(self, reward):
new_reward = self.neg_clip if reward < self.neg_clip else reward
return new_reward
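
# Sketch of the clipping rule: with the default neg_clip=0.0, rewards below
# zero are replaced by zero while non-negative rewards pass through, e.g.
#
#   env = NoNegativeRewardEnv(some_env)   # some_env is a placeholder
#   # _reward(-1.0) -> 0.0,  _reward(0.5) -> 0.5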


def create_doom(record=False, outdir=None):
from ppaquette_gym_doom import wrappers
env = gym.make('ppaquette/DoomMyWayHome-v0')
modewrapper = wrappers.SetPlayingMode('algo')
obwrapper = wrappers.SetResolution('160x120')
acwrapper = wrappers.ToDiscrete('minimal')
env = modewrapper(obwrapper(acwrapper(env)))
if record:
env = gym.wrappers.Monitor(env, outdir, force=True)
fshape = (42, 42)
env.seed(None)
#env = env_wrapper.NoNegativeRewardEnv(env)
env = BufferedObsEnv(env, skip=1, shape=fshape)
return env
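

# Hypothetical smoke test (guarded so importing this module has no side
# effects); assumes the ppaquette_gym_doom package used above is installed:
if __name__ == '__main__':
    env = create_doom()
    obs = env.reset()
    print('observation shape:', obs.shape)   # expected (4, 42, 42)
    obs, reward, done, info = env.step(env.action_space.sample())
    print('reward:', reward, 'done:', done)
    env.close()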